Content Overview
- Limitations
- Executing Python side effects
- All outputs of a tf.function must be return values
- Recursive tf.functions are not supported
- Known Issues
- Depending on Python global and free variables
- Depending on Python objects
- Creating tf.Variables
- Further reading
Limitations
tf.function has a few limitations by design that you should be aware of when converting a Python function to a tf.function.
Executing Python side effects
Side effects, like printing, appending to lists, and mutating globals, can behave unexpectedly inside a tf.function, sometimes executing twice or not at all. They only happen the first time you call a tf.function with a set of inputs. Afterwards, the traced tf.Graph is re-executed without running the Python code.
The general rule of thumb is to avoid relying on Python side effects in your logic and only use them to debug your traces. Otherwise, TensorFlow APIs like tf.data, tf.print, tf.summary, tf.Variable.assign, and tf.TensorArray are the best way to ensure your code will be executed by the TensorFlow runtime with each call.
@tf.function
def f(x):
  print("Traced with", x)
  tf.print("Executed with", x)

f(1)
f(1)
f(2)
Traced with 1
Executed with 1
Executed with 1
Traced with 2
Executed with 2
If you would like to execute Python code during each invocation of a tf.function, tf.py_function is an escape hatch. The drawbacks of tf.py_function are that it’s not portable or particularly performant, cannot be saved with SavedModel, and does not work well in distributed (multi-GPU, TPU) setups. Also, since tf.py_function has to be wired into the graph, it casts all inputs/outputs to tensors.
@tf.py_function(Tout=tf.float32)
def py_plus(x, y):
  print('Executing eagerly.')
  return x + y

@tf.function
def tf_wrapper(x, y):
  print('Tracing.')
  return py_plus(x, y)
The tf.function will trace the first time:
tf_wrapper(tf.constant(1.0), tf.constant(2.0)).numpy()
Tracing.
Executing eagerly.
3.0
But the tf.py_function inside executes eagerly every time:
tf_wrapper(tf.constant(1.0), tf.constant(2.0)).numpy()
Executing eagerly.
3.0
Changing Python global and free variables
Changing Python global and free variables counts as a Python side effect, so it only happens during tracing.
external_list = []

@tf.function
def side_effect(x):
  print('Python side effect')
  external_list.append(x)

side_effect(1)
side_effect(1)
side_effect(1)
# The list append only happened once!
assert len(external_list) == 1
Python side effect
Sometimes unexpected behaviors are very hard to notice. In the example below, the counter is intended to safeguard the increment of a variable. However, because it is a Python integer and not a TensorFlow object, its value is captured during the first trace. When the tf.function is used, the assign_add will be recorded unconditionally in the underlying graph. Therefore v will increase by 1 every time the tf.function is called. This issue is common among users that try to migrate their Graph-mode TensorFlow code to TensorFlow 2 using tf.function decorators, when Python side effects (the counter in the example) are used to determine what ops to run (assign_add in the example). Usually, users realize this only after seeing suspicious numerical results, or significantly lower performance than expected (e.g. if the guarded operation is very costly).
class Model(tf.Module):
  def __init__(self):
    self.v = tf.Variable(0)
    self.counter = 0

  @tf.function
  def __call__(self):
    if self.counter == 0:
      # A python side-effect
      self.counter += 1
      self.v.assign_add(1)

    return self.v

m = Model()
for n in range(3):
  print(m().numpy()) # prints 1, 2, 3
1
2
3
A workaround to achieve the expected behavior is using tf.init_scope to lift the operations outside of the function graph. This ensures that the variable increment is only done once during tracing time. Note that init_scope has other side effects, including clearing control flow and the gradient tape. Sometimes the usage of init_scope can become too complex to manage realistically.
class Model(tf.Module):
  def __init__(self):
    self.v = tf.Variable(0)
    self.counter = 0

  @tf.function
  def __call__(self):
    if self.counter == 0:
      # Lifts ops out of function-building graphs
      with tf.init_scope():
        self.counter += 1
        self.v.assign_add(1)

    return self.v

m = Model()
for n in range(3):
  print(m().numpy()) # prints 1, 1, 1
1
1
1
In summary, as a rule of thumb, you should avoid mutating Python objects such as integers or containers like lists that live outside the tf.function. Instead, use arguments and TF objects. For example, the section “Accumulating values in a loop” has one example of how list-like operations can be implemented.
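As a rough illustration of that advice, the sketch below replaces a Python list append with a tf.TensorArray. The function name collect_squares and its argument are made up for this example, and it assumes AutoGraph converts the for loop over tf.range into a graph loop:

```python
import tensorflow as tf

@tf.function
def collect_squares(n):
  # tf.TensorArray plays the role of a Python list inside the graph.
  ta = tf.TensorArray(tf.int32, size=0, dynamic_size=True)
  for i in tf.range(n):
    # write() returns a new TensorArray handle, so reassign it each step.
    ta = ta.write(i, i * i)
  # stack() turns the collected elements into a single tensor.
  return ta.stack()

squares = collect_squares(tf.constant(4))
print(squares.numpy())
```

Because the values travel through the return value, nothing outside the tf.function is mutated during tracing.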
You can, in some cases, capture and manipulate state if it is a tf.Variable. This is how the weights of Keras models are updated with repeated calls to the same ConcreteFunction.
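A minimal sketch of the difference between Python state and tf.Variable state, using hypothetical names (step, trace_log, increment) that do not appear elsewhere in this guide:

```python
import tensorflow as tf

step = tf.Variable(0)
trace_log = []  # Plain Python list, only touched while tracing.

@tf.function
def increment():
  trace_log.append("traced")  # Python side effect: runs during tracing only.
  step.assign_add(1)          # TensorFlow op: runs on every call.
  return step.read_value()

results = [int(increment()) for _ in range(3)]
print(results, len(trace_log))
```

The variable update is recorded in the graph and so happens on each call, while the list append happens only when the function is traced.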
Using Python iterators and generators
Many Python features, such as generators and iterators, rely on the Python runtime to keep track of state. In general, while these constructs work as expected in eager mode, they are examples of Python side effects and therefore only happen during tracing.
@tf.function
def buggy_consume_next(iterator):
  tf.print("Value:", next(iterator))

iterator = iter([1, 2, 3])
buggy_consume_next(iterator)
# This reuses the first value from the iterator, rather than consuming the next value.
buggy_consume_next(iterator)
buggy_consume_next(iterator)
Value: 1
Value: 1
Value: 1
Just like how TensorFlow has a specialized tf.TensorArray for list constructs, it has a specialized tf.data.Iterator for iteration constructs. See the section on AutoGraph transformations for an overview. Also, the tf.data API can help implement generator patterns:
@tf.function
def good_consume_next(iterator):
  # This is ok, iterator is a tf.data.Iterator
  tf.print("Value:", next(iterator))

ds = tf.data.Dataset.from_tensor_slices([1, 2, 3])
iterator = iter(ds)
good_consume_next(iterator)
good_consume_next(iterator)
good_consume_next(iterator)
Value: 1
Value: 2
Value: 3
All outputs of a tf.function must be return values
With the exception of tf.Variables, a tf.function must return all its outputs. Attempting to directly access any tensors from a function without going through return values causes “leaks”.
For example, the function below “leaks” the tensor a through the Python global x:
x = None

@tf.function
def leaky_function(a):
  global x
  x = a + 1  # Bad - leaks local tensor
  return a + 2

correct_a = leaky_function(tf.constant(1))

print(correct_a.numpy())  # Good - value obtained from function's returns
try:
  x.numpy()  # Bad - tensor leaked from inside the function, cannot be used here
except AttributeError as expected:
  print(expected)
3
'SymbolicTensor' object has no attribute 'numpy'
This is true even if the leaked value is also returned:
@tf.function
def leaky_function(a):
  global x
  x = a + 1  # Bad - leaks local tensor
  return x  # Good - uses local tensor

correct_a = leaky_function(tf.constant(1))

print(correct_a.numpy())  # Good - value obtained from function's returns
try:
  x.numpy()  # Bad - tensor leaked from inside the function, cannot be used here
except AttributeError as expected:
  print(expected)

@tf.function
def captures_leaked_tensor(b):
  b += x  # Bad - `x` is leaked from `leaky_function`
  return b

with assert_raises(TypeError):
  captures_leaked_tensor(tf.constant(2))
2
'SymbolicTensor' object has no attribute 'numpy'
Caught expected exception
<class 'TypeError'>:
Traceback (most recent call last):
File "/tmpfs/tmp/ipykernel_167534/3551158538.py", line 8, in assert_raises
yield
File "/tmpfs/tmp/ipykernel_167534/566849597.py", line 21, in <module>
captures_leaked_tensor(tf.constant(2))
TypeError: <tf.Tensor 'add:0' shape=() dtype=int32> is out of scope and cannot be used here. Use return values, explicit Python locals or TensorFlow collections to access it.
Please see https://www.tensorflow.org/guide/function#all_outputs_of_a_tffunction_must_be_return_values for more information.
<tf.Tensor 'add:0' shape=() dtype=int32> was defined here:
File "/usr/lib/python3.9/runpy.py", line 197, in _run_module_as_main
File "/usr/lib/python3.9/runpy.py", line 87, in _run_code
File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/ipykernel_launcher.py", line 18, in <module>
File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/traitlets/config/application.py", line 1075, in launch_instance
File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/ipykernel/kernelapp.py", line 739, in start
File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tornado/platform/asyncio.py", line 205, in start
File "/usr/lib/python3.9/asyncio/base_events.py", line 601, in run_forever
File "/usr/lib/python3.9/asyncio/base_events.py", line 1905, in _run_once
File "/usr/lib/python3.9/asyncio/events.py", line 80, in _run
File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/ipykernel/kernelbase.py", line 545, in dispatch_queue
File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/ipykernel/kernelbase.py", line 534, in process_one
File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/ipykernel/kernelbase.py", line 437, in dispatch_shell
File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/ipykernel/ipkernel.py", line 362, in execute_request
File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/ipykernel/kernelbase.py", line 778, in execute_request
File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/ipykernel/ipkernel.py", line 449, in do_execute
File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/ipykernel/zmqshell.py", line 549, in run_cell
File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/IPython/core/interactiveshell.py", line 3048, in run_cell
File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/IPython/core/interactiveshell.py", line 3103, in _run_cell
File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/IPython/core/async_helpers.py", line 129, in _pseudo_sync_runner
File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/IPython/core/interactiveshell.py", line 3308, in run_cell_async
File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/IPython/core/interactiveshell.py", line 3490, in run_ast_nodes
File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/IPython/core/interactiveshell.py", line 3550, in run_code
File "/tmpfs/tmp/ipykernel_167534/566849597.py", line 7, in <module>
File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/util/traceback_utils.py", line 150, in error_handler
File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/polymorphic_function.py", line 833, in __call__
File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/polymorphic_function.py", line 889, in _call
File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/polymorphic_function.py", line 696, in _initialize
File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/tracing_compilation.py", line 178, in trace_function
File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/tracing_compilation.py", line 283, in _maybe_define_function
File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/tracing_compilation.py", line 310, in _create_concrete_function
File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/func_graph.py", line 1059, in func_graph_from_py_func
File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/polymorphic_function.py", line 599, in wrapped_fn
File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/autograph_util.py", line 41, in autograph_handler
File "/tmpfs/tmp/ipykernel_167534/566849597.py", line 4, in leaky_function
File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/util/traceback_utils.py", line 150, in error_handler
File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/override_binary_operator.py", line 113, in binary_op_wrapper
File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/ops/tensor_math_operator_overrides.py", line 28, in _add_dispatch_factory
File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/util/traceback_utils.py", line 150, in error_handler
File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/util/dispatch.py", line 1260, in op_dispatch_handler
File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/ops/math_ops.py", line 1701, in _add_dispatch
File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/ops/gen_math_ops.py", line 490, in add_v2
File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/op_def_library.py", line 796, in _apply_op_helper
File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/func_graph.py", line 670, in _create_op_internal
File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py", line 2682, in _create_op_internal
File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py", line 1177, in from_node_def
The tensor <tf.Tensor 'add:0' shape=() dtype=int32> cannot be accessed from here, because it was defined in FuncGraph(name=leaky_function, id=139959630636096), which is out of scope.
Usually, leaks such as these occur when you use Python statements or data structures. In addition to leaking inaccessible tensors, such statements are also likely wrong because they count as Python side effects, and are not guaranteed to execute at every function call.
Common ways to leak local tensors also include mutating an external Python collection, or an object:
class MyClass:

  def __init__(self):
    self.field = None

external_list = []
external_object = MyClass()

@tf.function
def leaky_function():
  a = tf.constant(1)
  external_list.append(a)  # Bad - leaks tensor
  external_object.field = a  # Bad - leaks tensor
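One way to avoid this kind of leak, sketched here with a hypothetical non_leaky_function, is to hand the tensor back through the return value and do the Python-side bookkeeping in eager code:

```python
import tensorflow as tf

external_list = []

@tf.function
def non_leaky_function():
  a = tf.constant(1)
  # Return the tensor instead of storing it in an outside container.
  return a + 1

# The caller receives a concrete eager tensor, which is safe to keep.
external_list.append(non_leaky_function())
print(external_list[0].numpy())
```

The stored value is an eager tensor rather than a symbolic one, so calling .numpy() on it works as usual.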
Recursive tf.functions are not supported
Recursive tf.functions are not supported and could cause infinite loops. For example,
@tf.function
def recursive_fn(n):
  if n > 0:
    return recursive_fn(n - 1)
  else:
    return 1

with assert_raises(Exception):
  recursive_fn(tf.constant(5))  # Bad - maximum recursion error.
Caught expected exception
<class 'Exception'>:
Traceback (most recent call last):
File "/tmpfs/tmp/ipykernel_167534/3551158538.py", line 8, in assert_raises
yield
File "/tmpfs/tmp/ipykernel_167534/2233998312.py", line 9, in <module>
recursive_fn(tf.constant(5)) # Bad - maximum recursion error.
tensorflow.python.autograph.impl.api.StagingError: in user code:
File "/tmpfs/tmp/ipykernel_167534/2233998312.py", line 4, in recursive_fn *
return recursive_fn(n - 1)
    [... the same two lines repeat 117 more times ...]
RecursionError: maximum recursion depth exceeded while calling a Python object
Even if a recursive tf.function seems to work, the Python function will be traced multiple times and could have performance implications. For example,
@tf.function
def recursive_fn(n):
  if n > 0:
    print('tracing')
    return recursive_fn(n - 1)
  else:
    return 1

recursive_fn(5)  # Warning - multiple tracings
tracing
tracing
tracing
tracing
tracing
<tf.Tensor: shape=(), dtype=int32, numpy=1>
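Where the recursion can be expressed as a loop, an iterative rewrite is one possible alternative. The sketch below is only an illustration (countdown_sum is a hypothetical function, not part of the guide) and assumes AutoGraph turns the while statement on a tensor condition into a graph loop:

```python
import tensorflow as tf

@tf.function
def countdown_sum(n):
  total = tf.constant(0)
  # The condition depends on a tensor, so AutoGraph builds a graph loop
  # instead of unrolling or recursing in Python.
  while n > 0:
    total += n
    n -= 1
  return total

result = countdown_sum(tf.constant(5))
print(result.numpy())
```

With a tensor argument the function is traced once, and the loop itself runs inside the graph.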
Known Issues
If your tf.function is not evaluating correctly, the error may be explained by these known issues which are planned to be fixed in the future.
Depending on Python global and free variables
tf.function creates a new ConcreteFunction when called with a new value of a Python argument. However, it does not do that for the Python closure, globals, or nonlocals of that tf.function. If their value changes in between calls to the tf.function, the tf.function will still use the values they had when it was traced. This is different from how regular Python functions work.
For that reason, you should follow a functional programming style that uses arguments instead of closing over outer names.
@tf.function
def buggy_add():
  return 1 + foo

@tf.function
def recommended_add(foo):
  return 1 + foo

foo = 1
print("Buggy:", buggy_add())
print("Correct:", recommended_add(foo))
Buggy: tf.Tensor(2, shape=(), dtype=int32)
Correct: tf.Tensor(2, shape=(), dtype=int32)
print("Updating the value of `foo` to 100!")
foo = 100
print("Buggy:", buggy_add()) # Did not change!
print("Correct:", recommended_add(foo))
Updating the value of `foo` to 100!
Buggy: tf.Tensor(2, shape=(), dtype=int32)
Correct: tf.Tensor(101, shape=(), dtype=int32)
Another way to update a global value is to make it a tf.Variable and use the Variable.assign method instead.
@tf.function
def variable_add():
  return 1 + foo

foo = tf.Variable(1)
print("Variable:", variable_add())
Variable: tf.Tensor(2, shape=(), dtype=int32)
print("Updating the value of `foo` to 100!")
foo.assign(100)
print("Variable:", variable_add())
Updating the value of `foo` to 100!
Variable: tf.Tensor(101, shape=(), dtype=int32)
Depending on Python objects
Passing custom Python objects as arguments to tf.function is supported but has certain limitations.
For maximum feature coverage, consider transforming the objects into Extension types before passing them to tf.function. You can also use Python primitives and tf.nest-compatible structures.
However, as covered in the rules of tracing, when a custom TraceType is not provided by the custom Python class, tf.function is forced to use instance-based equality, which means it will not create a new trace when you pass the same object with modified attributes.
class SimpleModel(tf.Module):
  def __init__(self):
    # These values are *not* tf.Variables.
    self.bias = 0.
    self.weight = 2.

@tf.function
def evaluate(model, x):
  return model.weight * x + model.bias

simple_model = SimpleModel()
x = tf.constant(10.)
print(evaluate(simple_model, x))
tf.Tensor(20.0, shape=(), dtype=float32)
print("Adding bias!")
simple_model.bias += 5.0
print(evaluate(simple_model, x)) # Didn't change :(
Adding bias!
tf.Tensor(20.0, shape=(), dtype=float32)
Using the same tf.function to evaluate the modified instance of the model will be buggy since it still has the same instance-based TraceType as the original model.
For that reason, it is recommended that you write your tf.function so that it does not depend on mutable object attributes, or implement the Tracing Protocol for the objects to inform tf.function about such attributes.
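One simple way to avoid depending on mutable attributes, shown here as a sketch rather than the only option, is to pass the attribute values in as tensor arguments. The class PlainModel and the function evaluate_values are hypothetical names for this illustration:

```python
import tensorflow as tf

class PlainModel:
  def __init__(self):
    # Plain Python floats, not tf.Variables.
    self.bias = 0.
    self.weight = 2.

@tf.function
def evaluate_values(weight, bias, x):
  # All inputs arrive as arguments, so nothing is frozen at trace time.
  return weight * x + bias

plain_model = PlainModel()
x = tf.constant(10.)
before = evaluate_values(tf.constant(plain_model.weight), tf.constant(plain_model.bias), x)

plain_model.bias += 5.0
after = evaluate_values(tf.constant(plain_model.weight), tf.constant(plain_model.bias), x)
print(before.numpy(), after.numpy())
```

Wrapping the values in tf.constant keeps the argument signature stable, so the second call can reuse the existing trace while still seeing the updated bias.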
If that is not possible, one workaround is to make new tf.functions each time you modify your object to force retracing:
def evaluate(model, x):
  return model.weight * x + model.bias

new_model = SimpleModel()
evaluate_no_bias = tf.function(evaluate).get_concrete_function(new_model, x)
# Don't pass in `new_model`. `tf.function` already captured its state during tracing.
print(evaluate_no_bias(x))
tf.Tensor(20.0, shape=(), dtype=float32)
print("Adding bias!")
new_model.bias += 5.0
# Create new `tf.function` and `ConcreteFunction` since you modified `new_model`.
evaluate_with_bias = tf.function(evaluate).get_concrete_function(new_model, x)
print(evaluate_with_bias(x)) # Don't pass in `new_model`.
Adding bias!
tf.Tensor(25.0, shape=(), dtype=float32)
As retracing can be expensive, you can use tf.Variables as object attributes, which can be mutated in place (but not reassigned to a new object, careful!) for a similar effect without needing a retrace.
class BetterModel:

  def __init__(self):
    self.bias = tf.Variable(0.)
    self.weight = tf.Variable(2.)

@tf.function
def evaluate(model, x):
  return model.weight * x + model.bias

better_model = BetterModel()
print(evaluate(better_model, x))
tf.Tensor(20.0, shape=(), dtype=float32)
print("Adding bias!")
better_model.bias.assign_add(5.0) # Note: instead of better_model.bias += 5
print(evaluate(better_model, x)) # This works!
Adding bias!
tf.Tensor(25.0, shape=(), dtype=float32)
Creating tf.Variables
tf.function only supports singleton tf.Variables created once on the first call, and reused across subsequent function calls. The code snippet below would create a new tf.Variable in every function call, which results in a ValueError exception.
Example:
@tf.function
def f(x):
  v = tf.Variable(1.0)
  return v

with assert_raises(ValueError):
  f(1.0)
Caught expected exception
<class 'ValueError'>:
Traceback (most recent call last):
File "/tmpfs/tmp/ipykernel_167534/3551158538.py", line 8, in assert_raises
yield
File "/tmpfs/tmp/ipykernel_167534/3018268426.py", line 7, in <module>
f(1.0)
ValueError: in user code:
File "/tmpfs/tmp/ipykernel_167534/3018268426.py", line 3, in f *
v = tf.Variable(1.0)
ValueError: tf.function only supports singleton tf.Variables created on the first call. Make sure the tf.Variable is only created once or created outside tf.function. See https://www.tensorflow.org/guide/function#creating_tfvariables for more information.
A common pattern used to work around this limitation is to start with a Python None value, then conditionally create the tf.Variable if the value is None:
class Count(tf.Module):
  def __init__(self):
    self.count = None

  @tf.function
  def __call__(self):
    if self.count is None:
      self.count = tf.Variable(0)
    return self.count.assign_add(1)

c = Count()
print(c())
print(c())
tf.Tensor(1, shape=(), dtype=int32)
tf.Tensor(2, shape=(), dtype=int32)
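The error message also mentions a second option: creating the tf.Variable outside the tf.function. A minimal sketch of that alternative, assuming TensorFlow 2.x, is shown here. The class name EagerCount is hypothetical and used only for illustration:

```python
import tensorflow as tf


class EagerCount(tf.Module):  # hypothetical name for this illustration
  def __init__(self):
    super().__init__()
    # Created eagerly in the constructor, so tracing never creates a variable.
    self.count = tf.Variable(0)

  @tf.function
  def __call__(self):
    return self.count.assign_add(1)


counter = EagerCount()
first = int(counter())
second = int(counter())
print(first, second)
```

This avoids the None check entirely, at the cost of having to know the variable's initial value and shape before the first call.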
Using with multiple Keras optimizers
You may encounter "ValueError: tf.function only supports singleton tf.Variables created on the first call." when using more than one Keras optimizer with a tf.function. This error occurs because optimizers internally create tf.Variables when they apply gradients for the first time.
opt1 = tf.keras.optimizers.Adam(learning_rate = 1e-2)
opt2 = tf.keras.optimizers.Adam(learning_rate = 1e-3)

@tf.function
def train_step(w, x, y, optimizer):
  with tf.GradientTape() as tape:
    L = tf.reduce_sum(tf.square(w*x - y))
  gradients = tape.gradient(L, [w])
  optimizer.apply_gradients(zip(gradients, [w]))

w = tf.Variable(2.)
x = tf.constant([-1.])
y = tf.constant([2.])

train_step(w, x, y, opt1)
print("Calling `train_step` with different optimizer...")
with assert_raises(ValueError):
  train_step(w, x, y, opt2)
Calling `train_step` with different optimizer...
Caught expected exception
<class 'ValueError'>:
Traceback (most recent call last):
File "/tmpfs/tmp/ipykernel_167534/3551158538.py", line 8, in assert_raises
yield
File "/tmpfs/tmp/ipykernel_167534/950644149.py", line 18, in <module>
train_step(w, x, y, opt2)
ValueError: in user code:
File "/tmpfs/tmp/ipykernel_167534/950644149.py", line 9, in train_step *
optimizer.apply_gradients(zip(gradients, [w]))
File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/keras/src/optimizers/base_optimizer.py", line 291, in apply_gradients **
self.apply(grads, trainable_variables)
File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/keras/src/optimizers/base_optimizer.py", line 330, in apply
self.build(trainable_variables)
File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/keras/src/optimizers/adam.py", line 97, in build
self.add_variable_from_reference(
File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/keras/src/backend/tensorflow/optimizer.py", line 36, in add_variable_from_reference
return super().add_variable_from_reference(
File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/keras/src/optimizers/base_optimizer.py", line 227, in add_variable_from_reference
return self.add_variable(
File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/keras/src/optimizers/base_optimizer.py", line 201, in add_variable
variable = backend.Variable(
File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/keras/src/backend/common/variables.py", line 163, in __init__
self._initialize_with_initializer(initializer)
File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/keras/src/backend/tensorflow/core.py", line 40, in _initialize_with_initializer
self._value = tf.Variable(
ValueError: tf.function only supports singleton tf.Variables created on the first call. Make sure the tf.Variable is only created once or created outside tf.function. See https://www.tensorflow.org/guide/function#creating_tfvariables for more information.
If you need to change a stateful object between calls, it’s simplest to define a tf.Module subclass, and create instances to hold those objects:
class TrainStep(tf.Module):
  def __init__(self, optimizer):
    self.optimizer = optimizer

  @tf.function
  def __call__(self, w, x, y):
    with tf.GradientTape() as tape:
      L = tf.reduce_sum(tf.square(w*x - y))
    gradients = tape.gradient(L, [w])
    self.optimizer.apply_gradients(zip(gradients, [w]))

opt1 = tf.keras.optimizers.Adam(learning_rate = 1e-2)
opt2 = tf.keras.optimizers.Adam(learning_rate = 1e-3)

train_o1 = TrainStep(opt1)
train_o2 = TrainStep(opt2)

train_o1(w, x, y)
train_o2(w, x, y)
You could also do this manually by creating multiple instances of the @tf.function wrapper, one for each optimizer:
opt1 = tf.keras.optimizers.Adam(learning_rate = 1e-2)
opt2 = tf.keras.optimizers.Adam(learning_rate = 1e-3)

# Not a tf.function.
def train_step(w, x, y, optimizer):
  with tf.GradientTape() as tape:
    L = tf.reduce_sum(tf.square(w*x - y))
  gradients = tape.gradient(L, [w])
  optimizer.apply_gradients(zip(gradients, [w]))

w = tf.Variable(2.)
x = tf.constant([-1.])
y = tf.constant([2.])

# Make a new tf.function and ConcreteFunction for each optimizer.
train_step_1 = tf.function(train_step)
train_step_2 = tf.function(train_step)
for i in range(10):
  if i % 2 == 0:
    train_step_1(w, x, y, opt1)
  else:
    train_step_2(w, x, y, opt2)
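A variation on the same idea is a small factory that builds one tf.function per optimizer and captures the optimizer in a closure. This is only a sketch under the same assumptions as the examples above (TensorFlow 2.x with Keras optimizers). The helper name make_train_step is hypothetical and not part of any TensorFlow API:

```python
import tensorflow as tf


def make_train_step(optimizer):
  """Returns a fresh tf.function bound to a single optimizer (hypothetical helper)."""

  @tf.function
  def train_step(w, x, y):
    with tf.GradientTape() as tape:
      loss = tf.reduce_sum(tf.square(w * x - y))
    gradients = tape.gradient(loss, [w])
    # The optimizer creates its slot variables on the first call of this
    # particular tf.function, which is allowed.
    optimizer.apply_gradients(zip(gradients, [w]))
    return loss

  return train_step


w = tf.Variable(2.)
x = tf.constant([-1.])
y = tf.constant([2.])

step_1 = make_train_step(tf.keras.optimizers.Adam(learning_rate=1e-2))
step_2 = make_train_step(tf.keras.optimizers.Adam(learning_rate=1e-3))

# Loss at the initial value of w, computed before the first update is applied.
first_loss = float(step_1(w, x, y))
for i in range(3):
  step = step_2 if i % 2 == 0 else step_1
  step(w, x, y)
print(first_loss, w.numpy())
```

Because each returned function only ever sees one optimizer, no second round of variable creation happens inside an already-traced function.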
Using with multiple Keras models
You may also encounter "ValueError: tf.function only supports singleton tf.Variables created on the first call." when passing different model instances to the same tf.function.
This error occurs because Keras models (which do not have their input shape defined) and Keras layers create tf.Variables when they are first called. You may be attempting to initialize those variables inside a tf.function that has already been called. To avoid this error, try calling model.build(input_shape) to initialize all the weights before training the model.
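As a rough sketch of that advice, assuming TensorFlow 2.x with tf.keras available, the following builds two small models before passing them to the same tf.function. The helper make_model, the function predict, and the input shape (None, 3) are hypothetical choices made only for this illustration:

```python
import tensorflow as tf


def make_model():
  # No input shape is given here, so the weights do not exist yet.
  return tf.keras.Sequential([tf.keras.layers.Dense(1)])


@tf.function
def predict(model, inputs):
  return model(inputs)


inputs = tf.ones((2, 3))
model_a = make_model()
model_b = make_model()

# Create all weights eagerly, before either model reaches the tf.function.
for model in (model_a, model_b):
  model.build(input_shape=(None, 3))

out_a = predict(model_a, inputs)
out_b = predict(model_b, inputs)
print(out_a.shape, out_b.shape)
```

Each model still triggers its own trace, but neither trace has to create variables, so the second call does not hit the singleton restriction.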
Further reading
To learn about how to export and load a tf.function, see the SavedModel guide. To learn more about graph optimizations that are performed after tracing, see the Grappler guide. To learn how to optimize your data pipeline and profile your model, see the Profiler guide.
Originally published on the