Python——h5模型文件转pb模型文件
原标题:Python——h5模型文件转pb模型文件
原文来自:CSDN 原文链接:https://blog.csdn.net/zichen_ziqi/article/details/102878198
本文主要介绍:
h5模型文件如何转存为pb模型,并加载进行测试;
bugfix:多个pb模型加载,除第一个pb模型成功加载,其余模型graph为空的bug。
第一部分——h5模型文件如何转存为pb模型,并加载进行测试
1、h5模型转pb模型的函数,以及加载pb模型的两种方式
#!/usr/bin/env python # -*- coding:utf-8 -*- """ @Time :2019/8/10 @Name :GeekZW @Contact :1223242863@qq.com @File :freeze_util.py @Software :Pycharm """ import tensorflow as tf from tensorflow.python.framework import graph_util from tensorflow.python.platform import gfile # h5模型转pb模型文件 def freeze_session(session, output, output_names=None): output_graph_def = graph_util.convert_variables_to_constants( # 模型持久化,将变量值进行固定 sess=session, input_graph_def=session.graph_def, # 等价于sess.graph_def output_node_names=output_names # 若有多个输出节点,以逗号隔开 ) with tf.gfile.GFile(output, "wb") as file: # 保存模型 file.write(output_graph_def.SerializeTostring()) # 序列化输出 print("%d ops in the final graph." % len(output_graph_def.node)) # 得到当前图有几个操作节点 # 加载pb模型文件 —— 方案1 def load_pb_model(pb_model_file_path): with gfile.FastGFile(pb_model_file_path, "rb") as file: new_graph = tf.GraphDef() new_graph.ParseFromString(file.read()) tf.import_graph_def(new_graph, name='') return tf.get_default_graph() # 加载pb模型文件 —— 方案2 def load_pb_model2(pb_model_file_path): graph = tf.Graph() with graph.as_default(): with tf.gfile.GFile(pb_model_file_path, "rb") as file: graph_def = tf.GraphDef() graph_def.ParseFromString(file.read()) tf.import_graph_def(graph_def, name='') sess = tf.Session(graph=graph) return sess
2、h5模型转pb模型的主函数
#!/usr/bin/env python # -*- coding:utf-8 -*- """ @Time :2019/8/10 @Name :GeekZW @Contact :1223242863@qq.com @File :demo.py @Software :Pycharm """ from tensorflow.python.keras.models import load_model import tensorflow as tf import os import os.path as osp from tensorflow.python.keras import backend from tensorflow.python.platform import gfile from freeze_util import freeze_session from freeze_util import load_pb_model # 输出pb模型文件的节点名 def get_pb_model_node_names(pb_model_file_path): with tf.Session() as sess: print("load graph") with gfile.FastGFile(pb_model_file_path, "rb") as file: graph_def = tf.GraphDef() graph_def.ParseFromString(file.read()) tf.import_graph_def(graph_def, name='') for i, n in enumerate(graph_def.node): print("Name of the node - %s" % n.name) def run(h5_model_file_path, output_pb_model_file_path): # 加载h5模型 h5_model = load_model(h5_model_file_path) # 显示h5模型的输出节点名字 output_names = [h5_model.output.name[:-2]] print(output_names) # 显示模型中所有层的节点 print(h5_model.get_layer(index=0).output_shape) for layer in h5_model.layers: print(layer.output) # 将h5模型转存为pb模型 freeze_session(backend.get_session(), output_pb_model_file_path, output_names) print("model saved!") # 加载pb模型,并打印所有的节点 pb_model = load_pb_model(output_pb_model_file_path) for idx, layer in enumerate(pb_model._nodes_by_name): print(pb_model._nodes_by_name[layer]._outputs) if __name__ == "__main__": # 输入路径 input_path = os.path.abspath('..') weight_file_name = "xxx.h5" weight_file_path = osp.join(input_path, weight_file_name) # 输出路径 output_graph_name = weight_file_name[:-3] + ".pb" output_pb_file_path = osp.join(input_path, output_graph_name) # 运行函数,将h5模型文件转存为pb模型文件 run(weight_file_path, output_pb_file_path) # 结果:xxx.h5模型转存为xxx.pb模型成功
3、加载pb模型进行预测
#!/usr/bin/env python # -*- coding:utf-8 -*- """ @Time :2019/8/10 @Name :GeekZW @Contact :1223242863@qq.com @File :predict.py @Software :Pycharm """ import tensorflow as tf from freeze_util import load_pb_model import time if __name__ == "__main__": test_content = "待测试内容" pb_model = load_pb_model("xxx.pb") # 加载pb模型文件 with tf.Session() as sess: pre = pb_model.get_tensor_by_name("输出节点名") # 如"output/Softmax:0" x = pb_model.get_tensor_by_name("输入节点名") # 如"input:0" # 利用pb模型进行预测 start_time = time.time() predict_res = sess.run(pre, feed_dict={x: [test_content]}) end_time = time.time() print("耗时: {0} ms !".format(round(1000 * (end_time - start_time), 3)))
说明:以上代码中的tensorflow的版本是1.14.0,如果不是该版本,无法进行加载,请更新版本,命令:pip install tensorflow == 1.14.0
第二部分——bugfix:多个pb模型加载,除第一个pb模型成功加载,其余模型graph为空的bug
主要问题:graph被覆盖,需要加载出第二个,第三个pb模型,需要对每个pb模型进行初始化。具体例子如下:
#!/usr/bin/env python # -*- coding:utf-8 -*- """ @Time :2019/8/10 @Name :GeekZW @Contact :1223242863@qq.com @File :predict.py @Software :Pycharm """ import tensorflow as tf from freeze_util import load_pb_model import time import numpy as np def test_only1_pb_model(): test_content = "待测试内容" pb_model = load_pb_model("xxx.pb") # 加载pb模型文件 with tf.Session() as sess: pre = pb_model.get_tensor_by_name("输出节点名") # 如"output/Softmax:0" x = pb_model.get_tensor_by_name("输入节点名") # 如"input:0" # 利用pb模型进行预测 start_time = time.time() predict_res = sess.run(pre, feed_dict={x: [test_content]}) end_time = time.time() print("耗时: {0} ms !".format(round(1000 * (end_time - start_time), 3))) def test_multi_pb_model(): test_content = "待测试内容" labels = ["class_A", "class_B", "class_C"] # 加载第一个pb模型进行预测 g1 = tf.Graph() sess1 = tf.Session(graph=g1) with sess1.as_default(): with g1.as_default(): pb_model1 = load_pb_model("xxx.pb") # 加载pb模型文件 with sess1.as_default(): with sess1.as_default(): pre = pb_model1.get_tensor_by_name("输出节点名") # 如"output/Softmax:0" x = pb_model1.get_tensor_by_name("输入节点名") # 如"input:0" # 利用pb模型进行预测,并计算耗时 start_time = time.time() predict_res = sess1.run(pre, feed_dict={x: [test_content]}) end_time = time.time() y = np.argmax(np.array(predict_res[0])) label, score = labels[y], predict_res[y] print("{0}的预测结果为{1}, 平均耗时为{2} ms !".format(test_content, label, round(1000 * (end_time - start_time), 3))) # 加载第二个pb模型进行预测 g2 = tf.Graph() sess2 = tf.Session(graph=g2) with sess2.as_default(): with g2.as_default(): pb_model2 = load_pb_model("yyy.pb") # 加载pb模型文件 with sess2.as_default(): with sess2.as_default(): pre = pb_model2.get_tensor_by_name("输出节点名") # 如"output/Softmax:0" x = pb_model2.get_tensor_by_name("输入节点名") # 如"input:0" # 利用pb模型进行预测,并计算耗时 start_time = time.time() predict_res = sess2.run(pre, feed_dict={x: [test_content]}) end_time = time.time() y = np.argmax(np.array(predict_res[0])) label, score = labels[y], predict_res[y] print("{0}的预测结果为{1}, 平均耗时为{2} ms !".format(test_content, label, round(1000 * (end_time - start_time), 3))) if __name__ == "__main__": test_only1_pb_model() # 单个pb模型进行预测 test_multi_pb_model() # 多个pb模型进行预测
说明:Python中的end_time - start_time得到的单位是:秒(s),不是毫秒(ms)
免责声明:本文来自互联网新闻客户端自媒体,不代表本网的观点和立场。
合作及投稿邮箱:E-mail:editor@tusaishared.com
热门资源
Python 爬虫(二)...
所谓爬虫就是模拟客户端发送网络请求,获取网络响...
TensorFlow从1到2...
原文第四篇中,我们介绍了官方的入门案例MNIST,功...
TensorFlow从1到2...
“回归”这个词,既是Regression算法的名称,也代表...
机器学习中的熵、...
熵 (entropy) 这一词最初来源于热力学。1948年,克...
TensorFlow2.0(10...
前面的博客中我们说过,在加载数据和预处理数据时...
智能在线
400-630-6780
聆听.建议反馈
E-mail: support@tusaishared.com