Introduction To Variables - TensorFlow Core
A TensorFlow variable is the recommended way to represent shared, persistent state your
program manipulates. This guide covers how to create, update, and manage instances of
tf.Variable (https://www.tensorflow.org/api_docs/python/tf/Variable) in TensorFlow.
Setup
This notebook discusses variable placement. To see which device your variables are placed
on, uncomment the second line below.
import tensorflow as tf
# tf.debugging.set_log_device_placement(True)
A variable looks and acts like a tensor, and, in fact, is a data structure backed by a tf.Tensor
(https://www.tensorflow.org/api_docs/python/tf/Tensor). Like tensors, variables have a dtype and a
shape, and can be exported to NumPy.
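As a minimal sketch of the point above, the following creates a variable from an initial value and inspects its shape, dtype, and NumPy value. The names my_tensor, my_variable, bool_variable, and complex_variable are illustrative, and the 2x2 initial value is an assumption chosen to match the output shown in this guide.

```python
import tensorflow as tf

# Assumed initial value: a 2x2 float tensor.
my_tensor = tf.constant([[1.0, 2.0], [3.0, 4.0]])
# The variable takes its dtype and shape from the initial value.
my_variable = tf.Variable(my_tensor)

# Variables can hold other dtypes too, such as booleans or complex numbers.
bool_variable = tf.Variable([False, False, False, True])
complex_variable = tf.Variable([5 + 4j, 6 + 1j])

print("Shape:", my_variable.shape)
print("DType:", my_variable.dtype)
print("As NumPy:", my_variable.numpy())
```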
Shape: (2, 2)
DType: <dtype: 'float32'>
As NumPy: [[1. 2.]
[3. 4.]]
Most tensor operations work on variables as expected, although variables cannot be reshaped.
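The sketch below illustrates this with a few common operations. It assumes a small 2x2 variable named my_variable (an illustrative name). Note that tf.reshape returns a new tensor rather than changing the variable itself.

```python
import tensorflow as tf

my_variable = tf.Variable([[1.0, 2.0], [3.0, 4.0]])

# Convert the variable to a plain tensor.
as_tensor = tf.convert_to_tensor(my_variable)

# Index of the largest value in each column (argmax reduces along axis 0 by default).
highest = tf.math.argmax(my_variable)

# tf.reshape creates a new tensor; the variable keeps its original shape.
reshaped = tf.reshape(my_variable, [1, 4])

print("Viewed as a tensor:", as_tensor)
print("Index of highest value:", highest)
print("Copying and reshaping:", reshaped)
print("Variable shape is unchanged:", my_variable.shape)
```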
As noted above, variables are backed by tensors. You can reassign the tensor using
tf.Variable.assign (https://www.tensorflow.org/api_docs/python/tf/Variable#assign). Calling
assign does not (usually) allocate a new tensor; instead, the existing tensor's memory is
reused.
a = tf.Variable([2.0, 3.0])
# This will keep the same dtype, float32
a.assign([1, 2])
# Not allowed as it resizes the variable:
try:
  a.assign([1.0, 2.0, 3.0])
except Exception as e:
  print(f"{type(e).__name__}: {e}")
If you use a variable like a tensor in operations, you will usually operate on the backing tensor.
Creating new variables from existing variables duplicates the backing tensors. Two variables
will not share the same memory.
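a = tf.Variable([2.0, 3.0])
# Create b based on the value of a
b = tf.Variable(a)
a.assign([5, 6])
print(a.numpy())  # a and b are different
print(b.numpy())
print(a.assign_add([2, 3]).numpy())  # other versions of assign: add in place
print(a.assign_sub([7, 9]).numpy())  # ... and subtract in place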
[5. 6.]
[2. 3.]
[7. 9.]
[0. 0.]
Variables can also be named, which can help you track and debug them. Names do not have to
be unique: you can give two variables the same name.
# Create a and b; they will have the same name but will be backed by
# different tensors.
a = tf.Variable(my_tensor, name="Mark")
# A new variable with the same name, but different value
# Note that the scalar add is broadcast
b = tf.Variable(my_tensor + 1, name="Mark")
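To check that two variables sharing a name remain independent, a self-contained sketch might look like the following. The value of my_tensor is an assumption made for illustration.

```python
import tensorflow as tf

# Assumed example value for the initial tensor.
my_tensor = tf.constant([[1.0, 2.0], [3.0, 4.0]])

a = tf.Variable(my_tensor, name="Mark")
b = tf.Variable(my_tensor + 1, name="Mark")

# Both variables report the same name.
print(a.name, b.name)

# The comparison is element-wise on the values; the shared name plays no role.
print(a == b)
```

Every element of the comparison is False here, because b was initialized to each value of a plus one.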
Variable names are preserved when saving and loading models. By default, variables in models
will acquire unique variable names automatically, so you don't need to assign them yourself
unless you want to.
Although variables are important for differentiation, some variables will not need to be
differentiated. You can turn off gradients for a variable by setting trainable to false at
creation. An example of a variable that would not need gradients is a training step counter.
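A minimal sketch of this, using illustrative names (step_counter, weight, frozen) and a toy loss that are assumptions, not part of the original guide. By default, tf.GradientTape watches only trainable variables, so the gradient for the non-trainable variable comes back as None.

```python
import tensorflow as tf

# A step counter is bookkeeping state, so it is excluded from training.
step_counter = tf.Variable(1, trainable=False)

weight = tf.Variable(3.0)                   # trainable by default
frozen = tf.Variable(2.0, trainable=False)  # hypothetical non-trainable value

with tf.GradientTape() as tape:
  loss = weight * frozen

# Non-trainable variables can still be updated manually.
step_counter.assign_add(1)

# The tape did not watch `frozen`, so its gradient is None.
grad_weight, grad_frozen = tape.gradient(loss, [weight, frozen])
print(step_counter.trainable, grad_weight, grad_frozen)
```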
By default, TensorFlow places most variables on a GPU if one is available. However, you can
override this. The following snippet places a float tensor and a variable on the CPU, even if a
GPU is available. By turning on device placement logging (see Setup), you can see where the
variable is placed.
If you run this notebook on different backends with and without a GPU you will see different
logging. Note that logging device placement must be turned on at the start of the session.
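with tf.device('CPU:0'):
  a = tf.Variable([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
  c = tf.matmul(a, tf.constant([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]))
print(c)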
tf.Tensor(
[[22. 28.]
[49. 64.]], shape=(2, 2), dtype=float32)
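If you prefer not to turn on placement logging, one way to check placement is to read the device attribute of a variable or tensor. This is a sketch with illustrative names v and t. The exact device string depends on your runtime, so no output is shown.

```python
import tensorflow as tf

with tf.device('CPU:0'):
  v = tf.Variable([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
  t = tf.constant([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])

# The device string ends with the device type and index, for example "CPU:0".
print(v.device)
print(t.device)
```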
It's possible to place a variable or tensor on one device and do the computation on another
device. This introduces a delay, as data needs to be copied between the devices.
You might do this, however, if you have multiple GPU workers but only want one copy of the
variables.
with tf.device('CPU:0'):
  a = tf.Variable([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
  b = tf.Variable([[1.0, 2.0, 3.0]])
with tf.device('GPU:0'):
  # Element-wise multiply
  k = a * b
print(k)
tf.Tensor(
[[ 1. 4. 9.]
[ 4. 10. 18.]], shape=(2, 3), dtype=float32)
Next steps
To understand how variables are typically used, see our guide on automatic differentiation
(/guide/autodiff).