[Tensorflow] 使用 model.save_weights() 保存 / 加载 Keras Subclassed Model
原标题:[Tensorflow] 使用 model.save_weights() 保存 / 加载 Keras Subclassed Model
原文来自:博客园 原文链接:https://www.cnblogs.com/zlian2016/p/11158403.html
在 parameters.py 中,定义了各类参数。
"""Shared configuration for the data-generation, training and evaluation scripts."""

# training data directory
TRAINING_DATA_DIR = './data/'

# checkpoint directory
CHECKPOINT_DIR = './training_checkpoints/'

# training details
BATCH_SIZE = 16
BUFFER_SIZE = 128  # shuffle buffer size passed to tf.data.Dataset.shuffle
EPOCHS = 15
在 numpy_dataset.py 中,创建了 5000 组训练数据,模拟 $y = x^3 + 1$(并叠加标准差为 0.01 的高斯噪声),并以二进制格式写入文件。
"""Generate a noisy y = x^3 + 1 training set and write it to TRAINING_DATA_DIR.

Both arrays are float64 and are written as raw binary with ``ndarray.tofile``
(no header), so readers must supply ``dtype=np.float64`` when loading them.
"""
import os

import matplotlib.pyplot as plt
import numpy as np

from parameters import TRAINING_DATA_DIR

NUM_SAMPLES = 5000

# create training data: evenly spaced x on [-1, 1], shuffled,
# y = x^3 + 1 plus Gaussian noise with standard deviation 0.01
X = np.linspace(-1, 1, NUM_SAMPLES)
np.random.shuffle(X)
y = X ** 3 + 1 + np.random.normal(0, 0.01, (NUM_SAMPLES,))

# plot training data
plt.scatter(X, y)
plt.show()

# save data
if not os.path.exists(TRAINING_DATA_DIR):
    os.makedirs(TRAINING_DATA_DIR)
# join directory and file name as separate components so the path is correct
# whether or not TRAINING_DATA_DIR ends with a separator
X.tofile(os.path.join(TRAINING_DATA_DIR, 'training_data_X.bin'))
y.tofile(os.path.join(TRAINING_DATA_DIR, 'training_data_y.bin'))
在 subclassed_model.py 中,通过对 tf.keras.models.Model 进行子类化,设计了两个自定义模型。
"""Keras subclassed models: a two-layer Dense encoder and a one-unit Dense decoder."""
import tensorflow as tf

# TF 1.x: eager mode must be switched on at program startup, before any
# TensorFlow op is created; importing this module does that.
tf.enable_eager_execution()


# model definition
class Encoder(tf.keras.models.Model):
    """Encode inputs with Dense(16, relu) followed by Dense(8, relu)."""

    def __init__(self):
        super(Encoder, self).__init__()
        # Attribute names identify the variables in object-based checkpoints
        # written by save_weights(); renaming them breaks load_weights().
        self.fc1 = tf.keras.layers.Dense(units=16, activation='relu')
        self.fc2 = tf.keras.layers.Dense(units=8, activation='relu')

    def call(self, inputs):
        """Return ``fc2(fc1(inputs))``."""
        return self.fc2(self.fc1(inputs))


class Decoder(tf.keras.models.Model):
    """Project encoded features to a single linear output unit."""

    def __init__(self):
        super(Decoder, self).__init__()
        # Attribute name is part of the checkpoint key; keep it stable.
        self.fc = tf.keras.layers.Dense(units=1, activation=None)

    def call(self, inputs):
        """Return ``fc(inputs)`` (no activation)."""
        return self.fc(inputs)
在 loss_function.py 中,定义了损失函数。
"""Loss function used by the training script."""
import tensorflow as tf

tf.enable_eager_execution()


def loss(real, pred):
    """Return the mean squared error between ``real`` and ``pred``.

    Thin wrapper over the TF 1.x ``tf.losses.mean_squared_error``; with its
    default weights and reduction the result is a scalar averaged over all
    elements.

    Args:
        real: ground-truth tensor.
        pred: prediction tensor with the same shape as ``real``.
    """
    return tf.losses.mean_squared_error(labels=real, predictions=pred)
在 training.py 中,使用在 numpy_dataset.py 中创建的数据集训练模型,之后使用 model.save_weights() 保存 Keras Subclassed Model 模型,并创建验证集验证模型。
"""Train the Encoder/Decoder pair on the numpy_dataset.py data.

Weights are saved with ``Model.save_weights`` every CHECKPOINT_EVERY epochs so
that evaluate.py can restore them with ``Model.load_weights``; afterwards the
trained model is plotted on a fresh evaluation grid and probed at x = 0.5.
"""
import os

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf

from parameters import TRAINING_DATA_DIR, CHECKPOINT_DIR, BATCH_SIZE, BUFFER_SIZE, EPOCHS
# importing subclassed_model also enables eager execution
from subclassed_model import Encoder, Decoder
from loss_function import loss

CHECKPOINT_EVERY = 5  # epochs between weight snapshots

# load training data (raw float64 binary written by numpy_dataset.py)
training_X = np.fromfile(os.path.join(TRAINING_DATA_DIR, 'training_data_X.bin'), dtype=np.float64)
training_y = np.fromfile(os.path.join(TRAINING_DATA_DIR, 'training_data_y.bin'), dtype=np.float64)

# plot training data
plt.scatter(training_X, training_y)
plt.show()

# training dataset: shuffle individual samples *before* batching so batch
# composition changes every epoch (batch-then-shuffle only permutes whole batches)
training_dataset = (
    tf.data.Dataset.from_tensor_slices((training_X, training_y))
    .shuffle(BUFFER_SIZE)
    .batch(BATCH_SIZE)
)

# model instance
encoder = Encoder()
decoder = Decoder()

# optimizer
optimizer = tf.train.AdamOptimizer()

# checkpoint
checkpoint_prefix_encoder = os.path.join(CHECKPOINT_DIR, 'encoder/', 'ckpt')
checkpoint_prefix_decoder = os.path.join(CHECKPOINT_DIR, 'decoder/', 'ckpt')
for prefix in (checkpoint_prefix_encoder, checkpoint_prefix_decoder):
    if not os.path.exists(os.path.dirname(prefix)):
        os.makedirs(os.path.dirname(prefix))

# training step
for epoch in range(EPOCHS):
    epoch_loss = 0.0
    num_batches = 0
    for batch, (tx, ty) in enumerate(training_dataset):
        # float64 vectors -> float32 column tensors of shape (batch_size, 1);
        # the final batch of an epoch may be smaller than BATCH_SIZE
        x = tf.expand_dims(tf.cast(tx, tf.float32), axis=1)
        y = tf.expand_dims(tf.cast(ty, tf.float32), axis=1)
        with tf.GradientTape() as tape:
            y_ = encoder(x)  # shape (batch_size, 8)
            prediction = decoder(y_)  # shape (batch_size, 1)
            batch_loss = loss(real=y, pred=prediction)
        variables = encoder.trainable_variables + decoder.trainable_variables
        grads = tape.gradient(batch_loss, variables)
        optimizer.apply_gradients(zip(grads, variables),
                                  global_step=tf.train.get_or_create_global_step())
        # accumulate a Python float rather than a growing tensor
        epoch_loss += float(batch_loss.numpy())
        num_batches += 1
        if (batch + 1) % 100 == 0:
            print('Epoch {} Batch {} Loss {:.4f}'.format(epoch + 1, batch + 1, batch_loss.numpy()))
    # batch_loss is already a per-batch mean, so average over batches,
    # not over samples
    print('Epoch {} Loss {:.4f}'.format(epoch + 1, epoch_loss / max(num_batches, 1)))
    if (epoch + 1) % CHECKPOINT_EVERY == 0:
        encoder.save_weights(checkpoint_prefix_encoder)
        decoder.save_weights(checkpoint_prefix_decoder)

# create evaluation data
X = np.linspace(-1, 1, 3000)
np.random.shuffle(X)
evaluation_X = tf.data.Dataset.from_tensor_slices(X).batch(BATCH_SIZE)
ey = []
for ex in evaluation_X:
    x = tf.expand_dims(tf.cast(ex, tf.float32), axis=1)
    # one prediction row per input sample, in dataset (i.e. X) order
    ey.extend(decoder(encoder(x)).numpy())
plt.scatter(X, ey)
plt.show()

# evaluate
eval_x = [[0.5]]
tensor_x = tf.convert_to_tensor(eval_x)
print(decoder(encoder(tensor_x)))
验证集评价结果如下图所示。
使用测试样例 eval_x 进行测试,测试结果如下。
tf.Tensor([[1.122567]], shape=(1, 1), dtype=float32)
在 evaluate.py 中,使用 model.load_weights() 恢复 Keras Subclassed Model 模型,并在验证集上进行验证,验证结果如下图所示。
"""Restore the trained Encoder/Decoder with ``load_weights`` and evaluate them.

The models are not built explicitly before ``load_weights``; the restored
values are applied once the models are first called and create variables.
"""
import os

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf

from parameters import CHECKPOINT_DIR, BATCH_SIZE
# importing subclassed_model also enables eager execution
from subclassed_model import Encoder, Decoder


def _latest_checkpoint(subdir):
    """Return the newest checkpoint prefix under CHECKPOINT_DIR/subdir or raise."""
    ckpt_dir = os.path.join(CHECKPOINT_DIR, subdir)
    ckpt = tf.train.latest_checkpoint(ckpt_dir)
    if ckpt is None:
        # latest_checkpoint returns None when nothing was saved; fail with a
        # clear message instead of an opaque error inside load_weights
        raise FileNotFoundError('no checkpoint found in {}; run training.py first'.format(ckpt_dir))
    return ckpt


# load model
enc = Encoder()
dec = Decoder()
enc.load_weights(_latest_checkpoint('encoder/'))
dec.load_weights(_latest_checkpoint('decoder/'))

# create evaluation data
X = np.linspace(-1, 1, 3000)
np.random.shuffle(X)
evaluation_X = tf.data.Dataset.from_tensor_slices(X).batch(BATCH_SIZE)
ey = []
for ex in evaluation_X:
    x = tf.expand_dims(tf.cast(ex, tf.float32), axis=1)
    # one prediction row per input sample, in dataset (i.e. X) order
    ey.extend(dec(enc(x)).numpy())
plt.scatter(X, ey)
plt.show()

# evaluate
eval_x = [[0.5]]
tensor_x = tf.convert_to_tensor(eval_x)
print(dec(enc(tensor_x)))

# model summary (valid now that the models have been called and built)
enc.summary()
dec.summary()
使用测试样例 eval_x 进行测试,测试结果如下。
tf.Tensor([[1.122567]], shape=(1, 1), dtype=float32)
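恢复模型的测试结果与训练后模型的测试结果一致,且无需事先手动 build 模型:对尚未构建的 Subclassed Model 调用 load_weights() 加载 TensorFlow 格式的 checkpoint 时,权重恢复会延迟到模型首次被调用、创建变量时才执行。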
版权声明:本文为博主原创文章,欢迎转载,转载请注明作者及原文出处!
免责声明:本文来自互联网新闻客户端自媒体,不代表本网的观点和立场。
合作及投稿邮箱:E-mail:editor@tusaishared.com
热门资源
Python 爬虫(二)...
所谓爬虫就是模拟客户端发送网络请求,获取网络响...
TensorFlow从1到2...
原文第四篇中,我们介绍了官方的入门案例MNIST,功...
TensorFlow从1到2...
“回归”这个词,既是Regression算法的名称,也代表...
机器学习中的熵、...
熵 (entropy) 这一词最初来源于热力学。1948年,克...
TensorFlow2.0(10...
前面的博客中我们说过,在加载数据和预处理数据时...
智能在线
400-630-6780
聆听.建议反馈
E-mail: support@tusaishared.com