edo1z blog

プログラミングなどに関するブログです

TensorFlow - パラメータを保存して使う

学習したパラメータを保存するのは、下記のようにやります。

Saverを呼び出して、

saver = tf.train.Saver()

saveメソッドで保存します。

saver.save(sess, '.\ckpt\model.ckpt')

Windowsだけかもしれませんが、save()の第二引数に保存するファイル名を入れますが、パスを明記しないとエラーでます。もしかしたら、上記の場合ckptディレクトリをつくっとかないとエラーでるかもです。サイトによっては絶対パスを設定しているのもありましたが、上記のように相対パスでもWindows10ですがエラーになりませんでした。

Saverを呼び出す前にVariableを書いておかないと保存できないっぽいです。

下記で、保存したモデル(パラメタ)があるかチェックできます。

ckpt = tf.train.get_checkpoint_state('.\ckpt')

下記のように、ckptあるかチェックして、あったら、ckptのmodel_checkpoint_pathをrestoreメソッドの第二引数に入れて、restoreメソッドを実行することで、保存したパラメタを読み込めます。

if ckpt:
  server.restore(sess, ckpt.model_checkpoint_path)