Detects use of TensorFlow APIs that can introduce non-determinism,
such as `tf.compat.v1.Session` or
`tf.distribute.experimental.ParameterServerStrategy`.
def tensorflow_avoid_using_nondeterministic_api_noncompliant():
    """Noncompliant example: constructs a TF1-style ``tf.compat.v1.Session``.

    Op determinism is enabled, but the example still creates a
    ``tf.compat.v1.Session``, whose determinism cannot be guaranteed in
    TF2, so the detector flags it. The session object is deliberately
    discarded; it exists only to illustrate the flagged API. A tiny
    identity Keras model is then compiled and fit on a single all-ones
    sample.
    """
    import tensorflow as tf

    data = tf.ones((1, 1))
    tf.config.experimental.enable_op_determinism()
    # Noncompliant: determinism of tf.compat.v1.Session
    # cannot be guaranteed in TF2.
    tf.compat.v1.Session(
        target='', graph=None, config=None
    )
    layer = tf.keras.layers.Input(shape=[1])
    # Identity model: the Input tensor serves as both input and output.
    model = tf.keras.models.Model(inputs=layer, outputs=layer)
    model.compile(loss="categorical_crossentropy", metrics="AUC")
    model.fit(x=data, y=data)
def tensorflow_avoid_using_nondeterministic_api_compliant():
    """Compliant example: trains a Keras model using only deterministic APIs.

    Seeds TensorFlow's global random generator and enables op determinism
    before building, compiling, and fitting a tiny identity model on a
    single all-ones sample. No TF1 ``Session`` or other flagged API is
    used.
    """
    import tensorflow as tf

    # Op determinism alone does not make random ops repeatable; the
    # global seed must also be fixed.
    tf.random.set_seed(0)
    # Compliant: uses deterministic API.
    tf.config.experimental.enable_op_determinism()
    data = tf.ones((1, 1))
    layer = tf.keras.layers.Input(shape=[1])
    # Identity model: the Input tensor serves as both input and output.
    model = tf.keras.models.Model(inputs=layer, outputs=layer)
    model.compile(loss="categorical_crossentropy", metrics="AUC")
    model.fit(x=data, y=data)