Using `tf.function` for performance in TensorFlow 2
By Dmitry Kabanov
TensorFlow 2 uses the so-called Eager mode by default. In this mode, it is easy to define tensors interactively, for example in IPython, and see their values. However, execution in Eager mode is slow, which becomes noticeable during model training.
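For instance, in Eager mode the value of a tensor is available immediately after the operation that produces it runs:

```python
import tensorflow as tf

# In Eager mode (the default in TensorFlow 2), operations execute
# immediately and the resulting values can be inspected right away.
x = tf.constant([1.0, 2.0])
print(x + 1)  # tf.Tensor([2. 3.], shape=(2,), dtype=float32)
```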
TensorFlow 2 offers another mode of execution called Graph mode. In this mode, the computational graph is built first and only then used to compute, for example, the loss function and its gradient. This mode is more efficient in terms of performance.
A function can be converted to Graph mode automatically by decorating it with the `tf.function` decorator, which gives a simple way to speed it up.
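Note that `tf.function` can also be applied as an ordinary function instead of as a decorator; the two forms in this small sketch are equivalent:

```python
import tensorflow as tf

@tf.function
def f_decorated(x):
    return tf.square(x)

# Equivalent: wrap an existing Python function explicitly.
def f_plain(x):
    return tf.square(x)

f_wrapped = tf.function(f_plain)
```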
Let’s compare the execution time when using each of these modes. I will use TensorFlow 2.2 in this post:
```python
import tensorflow as tf

print("Tensorflow version: ", tf.__version__)
```

```
Tensorflow version: 2.2.0
```
As an example, consider computing the quantity \[ \left(\sin x \right)^2 + \left( \cos y \right)^2 \] with \(x\) and \(y\) being two-dimensional arrays (matrices). We define two functions (one without and one with the decorator) and create the arrays:
```python
def f_slow(x, y):
    return tf.square(tf.sin(x)) + tf.square(tf.cos(y))


@tf.function
def f_fast(x, y):
    return tf.square(tf.sin(x)) + tf.square(tf.cos(y))


x = tf.random.normal((4000, 1000))
y = tf.random.normal((4000, 1000))
```
Now let’s measure the difference in performance:
```python
import timeit

t_slow = timeit.timeit(
    stmt='f_slow(x, y)',
    setup='from __main__ import f_slow, x, y',
    number=100,
)
print(t_slow)

t_fast = timeit.timeit(
    stmt='f_fast(x, y)',
    setup='from __main__ import f_fast, x, y',
    number=100,
)
print(t_fast)
```

```
1.02903594000054
0.6089697299994441
```
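One caveat: the first call to a `tf.function`-decorated function traces the Python code and builds the graph, so it is slower than the subsequent calls. To exclude this one-time tracing cost from the measurement, one can make a warm-up call before timing, for example:

```python
# Warm-up call: triggers tracing and graph construction once,
# so the timing below measures only graph execution.
f_fast(x, y)

t_fast = timeit.timeit(
    stmt='f_fast(x, y)',
    setup='from __main__ import f_fast, x, y',
    number=100,
)
```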
You can see that the runtime of the decorated function is only about 60% of that of the undecorated, eagerly executed one.
In general, the `tf.function` decorator should be used whenever you have a computation-heavy function.
A good example is a single training step of a neural network, in which you need to evaluate the loss function, compute its gradient, and then update the neural-network parameters.
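As an illustration, here is a minimal sketch of such a training step wrapped in `tf.function` (the model, optimizer, and loss function below are placeholders chosen just for the example):

```python
import tensorflow as tf

# Placeholder model, optimizer, and loss for the sketch.
model = tf.keras.Sequential([tf.keras.layers.Dense(10)])
optimizer = tf.keras.optimizers.SGD(learning_rate=0.01)
loss_fn = tf.keras.losses.MeanSquaredError()

@tf.function
def train_step(inputs, targets):
    # Record the forward pass for automatic differentiation.
    with tf.GradientTape() as tape:
        predictions = model(inputs)
        loss = loss_fn(targets, predictions)
    # Compute gradients of the loss w.r.t. the parameters and update them.
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    return loss
```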