TensorFlow: how to save/restore a model?


After you train a model in TensorFlow:

How do you save the trained model?
How do you later restore this saved model?

New and shorter way: simple_save

There are many good answers here; for completeness I’ll add my 2 cents: simple_save, along with a standalone code example using the tf.data.Dataset API.

Python 3; TensorFlow 1.7


Standalone example

The following code generates random data for the sake of the demonstration.

We start by creating the placeholders. They will hold the data at runtime. From them, we create the Dataset and then its Iterator. We get the iterator’s generated tensor, called input_tensor, which will serve as input to our model.
The model itself is built from input_tensor: a GRU-based bidirectional RNN followed by a dense classifier. Because why not.
The loss is softmax_cross_entropy_with_logits, optimized with Adam. After 2 epochs (of 2 batches each), we save the “trained” model with tf.saved_model.simple_save. If you run the code as is, the model will be saved in a folder called simple/ in your current working directory.
In a new graph, we then restore the saved model with tf.saved_model.loader.load. We grab the placeholders and logits with graph.get_tensor_by_name and the Iterator initializing operation with graph.get_operation_by_name.
Lastly, we run inference for both batches in the dataset and check that the saved and restored models yield the same values. They do!
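
For a quick feel of that flow, here is a minimal sketch of the same steps, with a few assumptions of mine that you should adjust. A single dense layer stands in for the bidirectional GRU to keep it short. The tensor and op names (x, y, logits, dataset_init) and the data sizes are made up for illustration. The model is exported to a temporary directory rather than simple/ so the snippet can be re-run. It imports tensorflow.compat.v1 so it should also run on later releases; on 1.7 itself, a plain import tensorflow as tf should expose the same calls.

```python
import os
import tempfile

import numpy as np
# Assumption: TF >= 1.13 or 2.x. On TF 1.7 itself, use `import tensorflow as tf`.
import tensorflow.compat.v1 as tf

# Random demo data: 4 examples, 5 features, 3 one-hot classes (arbitrary sizes).
rng = np.random.RandomState(0)
features = rng.rand(4, 5).astype(np.float32)
labels = np.eye(3, dtype=np.float32)[rng.randint(0, 3, size=4)]
batch_size, n_batches, n_epochs = 2, 2, 2

# simple_save expects an export directory that does not exist yet.
export_dir = os.path.join(tempfile.mkdtemp(), "simple")

# 1) Build the graph, "train" it, and save it.
with tf.Graph().as_default():
    # Placeholders hold the data at runtime; names are hypothetical.
    x = tf.placeholder(tf.float32, [None, 5], name="x")
    y = tf.placeholder(tf.float32, [None, 3], name="y")
    dataset = tf.data.Dataset.from_tensor_slices((x, y)).batch(batch_size)
    iterator = dataset.make_initializable_iterator()
    dataset_init = iterator.make_initializer(dataset, name="dataset_init")
    input_tensor, label_tensor = iterator.get_next()

    # Stand-in model: one dense layer instead of the GRU-based RNN.
    w = tf.Variable(tf.random_normal([5, 3], seed=0), name="w")
    b = tf.Variable(tf.zeros([3]), name="b")
    logits = tf.add(tf.matmul(input_tensor, w), b, name="logits")
    loss = tf.reduce_mean(
        tf.nn.softmax_cross_entropy_with_logits_v2(
            labels=label_tensor, logits=logits))
    train_op = tf.train.AdamOptimizer().minimize(loss)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        feed = {x: features, y: labels}
        for _ in range(n_epochs):
            # Re-initialize the iterator at the start of every epoch.
            sess.run(dataset_init, feed_dict=feed)
            for _ in range(n_batches):
                sess.run(train_op)

        # Reference predictions from the model we are about to save.
        original_logits = []
        sess.run(dataset_init, feed_dict=feed)
        for _ in range(n_batches):
            original_logits.append(sess.run(logits))

        tf.saved_model.simple_save(
            sess, export_dir,
            inputs={"x": x, "y": y}, outputs={"logits": logits})

# 2) Restore into a fresh graph and look everything up by name.
with tf.Graph().as_default() as restored_graph:
    with tf.Session() as sess:
        tf.saved_model.loader.load(
            sess, [tf.saved_model.tag_constants.SERVING], export_dir)
        x_r = restored_graph.get_tensor_by_name("x:0")
        y_r = restored_graph.get_tensor_by_name("y:0")
        logits_r = restored_graph.get_tensor_by_name("logits:0")
        init_r = restored_graph.get_operation_by_name("dataset_init")

        # Same data, same batches, but served by the restored model.
        restored_logits = []
        sess.run(init_r, feed_dict={x_r: features, y_r: labels})
        for _ in range(n_batches):
            restored_logits.append(sess.run(logits_r))
```

If the round trip worked, each array in restored_logits should match its counterpart in original_logits up to floating-point tolerance, which np.allclose can confirm. Naming the tensors and the initializer op when you build the graph is what makes the lookup by name possible after loading.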


This will print:
