Saving state for tf.function-decorated functions
By Dmitry Kabanov
When you decorate a function with the `tf.function` decorator, you sometimes need to keep state between invocations of this function. The problem is that if the state is stored in plain Python variables, changes to it will not be visible inside the decorated function.
To illustrate the problem, TensorFlow 2.2 is used:
import tensorflow as tf
print(tf.__version__)
2.2.0
To see the problem, let’s consider the following code. Assume that we need to scale a given tensor `x`, and for performance reasons we do it with a `tf.function`-decorated function `scale`. Additionally, the scaling factor is stored as an object attribute. The code looks like this:
# %% Initial version of the calculator.
class Calc:
    def __init__(self, alpha=1.0):
        self.alpha = alpha

    @tf.function
    def scale(self, x):
        return self.alpha * x
Now, when we use the `Calc.scale` function while changing the value of the scaling factor, we can see that the changes are ignored by TensorFlow:
# %% Data to work with.
x = tf.Variable([1.0, 2.0, 3.0], dtype=tf.float32)
# %% Run and see the bug.
c = Calc(alpha=1.0)
print(c.scale(x))
c.alpha = 2.0
print(c.scale(x))
2020-07-31 12:13:47.432493: I tensorflow/core/platform/cpu_feature_guard.cc:143] Your CPU supports instructions that this TensorFlow binary was not compiled to use: AVX2 FMA
2020-07-31 12:13:47.452824: I tensorflow/compiler/xla/service/service.cc:168] XLA service 0x7fcbd0434600 initialized for platform Host (this does not guarantee that XLA will be used). Devices:
2020-07-31 12:13:47.452839: I tensorflow/compiler/xla/service/service.cc:176] StreamExecutor device (0): Host, Default Version
tf.Tensor([1. 2. 3.], shape=(3,), dtype=float32)
tf.Tensor([1. 2. 3.], shape=(3,), dtype=float32)
As you can see above, the value of `alpha` has been changed from 1 to 2 between the two invocations, but the output of the `scale` function did not change!
The reason is that the variable `alpha` is a Python `float`, which TensorFlow does not track: when `tf.function` traces `scale` on the first call, the current value of `alpha` is baked into the computational graph as a constant, and subsequent calls reuse that graph without re-reading the Python attribute.
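To make the tracing behavior concrete, here is a minimal standalone sketch (my addition, not part of the original example; it assumes standard TensorFlow 2.x tracing semantics). The Python-level `print` executes only while the function is being traced, so it fires on the first call but stays silent on the second:
# %% Sketch: Python values are captured as constants at trace time.
factor = 1.0  # a plain Python float, invisible to TensorFlow after tracing

@tf.function
def scale_by_factor(x):
    # This Python print runs only during tracing, not on every call.
    print("Tracing scale_by_factor with factor =", factor)
    return factor * x

print(scale_by_factor(x))  # traces once and scales by 1.0
factor = 2.0               # rebinding the Python name does not trigger a retrace
print(scale_by_factor(x))  # reuses the traced graph: still scales by 1.0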
One possible fix is to use an object-oriented style of programming and store `alpha` internally as a TensorFlow variable, while letting the user of the code work only with a Python `float` for simplicity. This can be done using properties, and then the result is as expected:
# %% Fix problem with state mutation using object attributes.
class CalcVersionTwo:
    def __init__(self, alpha=1.0):
        self._alpha = tf.Variable(alpha, dtype=tf.float32)

    @property
    def alpha(self):
        return self._alpha

    @alpha.setter
    def alpha(self, value):
        self._alpha.assign(value)

    @tf.function
    def scale(self, x):
        return self.alpha * x
# %% Run and see that the issue is fixed.
c_v2 = CalcVersionTwo(alpha=1.0)
print(c_v2.scale(x))
c_v2.alpha = 2.0
print(c_v2.scale(x))
tf.Tensor([1. 2. 3.], shape=(3,), dtype=float32)
tf.Tensor([2. 4. 6.], shape=(3,), dtype=float32)
As you can see from the above example, when `alpha` changes, the output of the `scale` function changes as well.
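Incidentally, the setter calling `assign` (rather than creating a new `tf.Variable`) is essential here. The following is my addition, not part of the original post: the traced graph captures the original variable object, so rebinding the attribute to a fresh variable should bring the bug back. A minimal sketch, reusing `CalcVersionTwo` and `x` from above:
# %% Sketch: rebinding the attribute instead of assigning reintroduces the bug.
c_bad = CalcVersionTwo(alpha=1.0)
print(c_bad.scale(x))  # tf.Tensor([1. 2. 3.], shape=(3,), dtype=float32)

# The traced graph still points at the old variable, not the new one.
c_bad._alpha = tf.Variable(2.0, dtype=tf.float32)
print(c_bad.scale(x))  # expected: still [1. 2. 3.], the change is ignored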
Another solution is to use a functional style of programming and pass the value of `alpha` as a function argument:
# %% Fix problem using functional style of programming.
class CalcVersionThree:
    @tf.function
    def scale(self, x, alpha):
        return alpha * x
# %% See that the issue is fixed with using functional style of programming.
c_v3 = CalcVersionThree()
print(c_v3.scale(x, 1.0))
print(c_v3.scale(x, 2.0))
tf.Tensor([1. 2. 3.], shape=(3,), dtype=float32)
tf.Tensor([2. 4. 6.], shape=(3,), dtype=float32)
As you can see, again, the vector `x` is properly scaled when `alpha` changes its value.
Of course, in this particular example, the very existence of the class `CalcVersionThree` is a bit questionable, as it keeps no state; however, I keep it as a class so that it does not differ too much from the other code examples here.
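One caveat of the functional solution is worth flagging; this note is my addition, based on documented `tf.function` retracing behavior: when `alpha` is passed as a plain Python number, each new value is baked into the graph as a constant, so the function is retraced for every distinct `alpha`. Passing a tensor instead keeps a single trace for all scalar values:
# %% Avoid retracing by passing alpha as a tensor instead of a Python float.
print(c_v3.scale(x, tf.constant(1.0)))  # first call traces the function
print(c_v3.scale(x, tf.constant(2.0)))  # same signature, so no retracing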
Comparing the object-oriented and functional solutions, one can notice the usual differences between them. The object-oriented solution gives a better API for the user, especially when multiple vectors need to be scaled with the same value of `alpha`; however, the data flow is more difficult to follow. The functional solution gives a more direct data flow, but forces the user to pass `alpha` explicitly every time `scale` is called.
To read more about the constraints of using `tf.function`-decorated functions, you can read TensorFlow’s Wiki.