| |
| import tensorflow as tf |
| from baselines.common.tf_util import ( |
| function, |
| initialize, |
| single_threaded_session |
| ) |
|
|
|
|
def test_function():
    """Exercise ``function`` with positional args, keyword args and ``givens``.

    ``y`` is bound to 0 through ``givens``, so a call that omits it should
    evaluate ``3 * x``; supplying ``y`` (positionally or by keyword)
    overrides that default.
    """
    with tf.Graph().as_default():
        x_ph = tf.compat.v1.placeholder(tf.int32, (), name="x")
        y_ph = tf.compat.v1.placeholder(tf.int32, (), name="y")
        linear = function([x_ph, y_ph], 3 * x_ph + 2 * y_ph, givens={y_ph: 0})

        # Each row: (positional args, keyword args, expected 3*x + 2*y).
        # The keyword names appear to be matched against the placeholders'
        # ``name=`` strings ("x", "y") by ``function`` -- not the Python locals.
        cases = [
            ((2,), {}, 6),
            ((), {"x": 3}, 9),
            ((2, 2), {}, 10),
            ((), {"x": 2, "y": 3}, 12),
        ]

        with single_threaded_session():
            initialize()
            for args, kwargs, expected in cases:
                assert linear(*args, **kwargs) == expected
|
|
|
|
def test_multikwargs():
    """Check ``function`` with two placeholders that both use ``name="x"``.

    The second placeholder is created inside the ``"other"`` variable scope
    (presumably giving it a scope-prefixed name -- confirm against TF docs)
    and is given a default of 0 through ``givens``.
    """
    with tf.Graph().as_default():
        x_outer = tf.compat.v1.placeholder(tf.int32, (), name="x")
        with tf.compat.v1.variable_scope("other"):
            x_scoped = tf.compat.v1.placeholder(tf.int32, (), name="x")
        total = 3 * x_outer + 2 * x_scoped

        linear = function([x_outer, x_scoped], total, givens={x_scoped: 0})
        with single_threaded_session():
            initialize()
            # Each pair: (positional args, expected 3*x_outer + 2*x_scoped).
            for args, expected in (((2,), 6), ((2, 2), 10)):
                assert linear(*args) == expected
|
|
|
|
if __name__ == '__main__':
    # Allow running the checks directly (no test runner), in definition order.
    for _test in (test_function, test_multikwargs):
        _test()
|
|