资源经验分享Python——h5模型文件转pb模型文件

Python——h5模型文件转pb模型文件

2019-11-05 | |  80 |   0

原标题:Python——h5模型文件转pb模型文件

原文来自:CSDN      原文链接:https://blog.csdn.net/zichen_ziqi/article/details/102878198


                                         Python——h5模型文件转pb模型文件

本文主要介绍:

  • h5模型文件如何转存为pb模型,并加载进行测试;

  • bugfix:多个pb模型加载,除第一个pb模型成功加载,其余模型graph为空的bug。

 

第一部分——h5模型文件如何转存为pb模型,并加载进行测试

1、h5模型转pb模型的函数,以及加载pb模型的两种方式

#!/usr/bin/env python
# -*- coding:utf-8 -*-
 
"""
@Time     :2019/8/10
@Name     :GeekZW
@Contact  :1223242863@qq.com
@File     :freeze_util.py
@Software :Pycharm
"""
 
 
import tensorflow as tf
from tensorflow.python.framework import  graph_util
from tensorflow.python.platform import gfile
 
 
# h5模型转pb模型文件
def freeze_session(session, output, output_names=None):
    output_graph_def = graph_util.convert_variables_to_constants(  # 模型持久化,将变量值进行固定
        sess=session,
        input_graph_def=session.graph_def,   # 等价于sess.graph_def
        output_node_names=output_names       # 若有多个输出节点,以逗号隔开
    )
 
    with tf.gfile.GFile(output, "wb") as file:  # 保存模型
        file.write(output_graph_def.SerializeTostring())  # 序列化输出
    print("%d ops in the final graph." % len(output_graph_def.node))  # 得到当前图有几个操作节点
 
 
# 加载pb模型文件 —— 方案1
def load_pb_model(pb_model_file_path):
    with gfile.FastGFile(pb_model_file_path, "rb") as file:
        new_graph = tf.GraphDef()
        new_graph.ParseFromString(file.read())
        tf.import_graph_def(new_graph, name='')
    return tf.get_default_graph()
 
 
# 加载pb模型文件 —— 方案2
def load_pb_model2(pb_model_file_path):
    graph = tf.Graph()
    with graph.as_default():
        with tf.gfile.GFile(pb_model_file_path, "rb") as file:
            graph_def = tf.GraphDef()
            graph_def.ParseFromString(file.read())
            tf.import_graph_def(graph_def, name='')
        sess = tf.Session(graph=graph)
    return sess

2、h5模型转pb模型的主函数

#!/usr/bin/env python
# -*- coding:utf-8 -*-
 
"""
@Time     :2019/8/10
@Name     :GeekZW
@Contact  :1223242863@qq.com
@File     :demo.py
@Software :Pycharm
"""
 
from tensorflow.python.keras.models import load_model
import tensorflow as tf
import os
import os.path as osp
from tensorflow.python.keras import backend
from tensorflow.python.platform import gfile
from freeze_util import freeze_session
from freeze_util import load_pb_model
 
 
# 输出pb模型文件的节点名
def get_pb_model_node_names(pb_model_file_path):
    with tf.Session() as sess:
        print("load graph")
        with gfile.FastGFile(pb_model_file_path, "rb") as file:
            graph_def = tf.GraphDef()
            graph_def.ParseFromString(file.read())
            tf.import_graph_def(graph_def, name='')
        for i, n in enumerate(graph_def.node):
            print("Name of the node - %s" % n.name)
 
 
def run(h5_model_file_path, output_pb_model_file_path):
    # 加载h5模型
    h5_model = load_model(h5_model_file_path)
 
    # 显示h5模型的输出节点名字
    output_names = [h5_model.output.name[:-2]]
    print(output_names)
 
    # 显示模型中所有层的节点
    print(h5_model.get_layer(index=0).output_shape)
    for layer in h5_model.layers:
        print(layer.output)
 
    # 将h5模型转存为pb模型
    freeze_session(backend.get_session(), output_pb_model_file_path, output_names)
    print("model saved!")
 
    # 加载pb模型,并打印所有的节点
    pb_model = load_pb_model(output_pb_model_file_path)
    for idx, layer in enumerate(pb_model._nodes_by_name):
        print(pb_model._nodes_by_name[layer]._outputs)
 
 
if __name__ == "__main__":
    
    # 输入路径
    input_path = os.path.abspath('..')
    weight_file_name = "xxx.h5"
    weight_file_path = osp.join(input_path, weight_file_name)
 
    # 输出路径
    output_graph_name = weight_file_name[:-3] + ".pb"
    output_pb_file_path = osp.join(input_path, output_graph_name)
 
    # 运行函数,将h5模型文件转存为pb模型文件
    run(weight_file_path, output_pb_file_path)
 
    # 结果:xxx.h5模型转存为xxx.pb模型成功

3、加载pb模型进行预测

#!/usr/bin/env python
# -*- coding:utf-8 -*-
 
"""
@Time     :2019/8/10
@Name     :GeekZW
@Contact  :1223242863@qq.com
@File     :predict.py
@Software :Pycharm
"""
 
import tensorflow as tf
from freeze_util import load_pb_model
import time
 
 
if __name__ == "__main__":
    test_content = "待测试内容"
    pb_model = load_pb_model("xxx.pb")  # 加载pb模型文件
 
    with tf.Session() as sess:
        pre = pb_model.get_tensor_by_name("输出节点名")  # 如"output/Softmax:0"
        x = pb_model.get_tensor_by_name("输入节点名")  # 如"input:0"
 
        # 利用pb模型进行预测
        start_time = time.time()
        predict_res = sess.run(pre, feed_dict={x: [test_content]})
        end_time = time.time()
        print("耗时: {0} ms !".format(round(1000 * (end_time - start_time), 3)))

说明:以上代码中的tensorflow的版本是1.14.0,如果不是该版本,无法进行加载,请更新版本,命令:pip install tensorflow == 1.14.0

 

第二部分——bugfix:多个pb模型加载,除第一个pb模型成功加载,其余模型graph为空的bug

主要问题:graph被覆盖,需要加载出第二个,第三个pb模型,需要对每个pb模型进行初始化。具体例子如下:

#!/usr/bin/env python
# -*- coding:utf-8 -*-
 
"""
@Time     :2019/8/10
@Name     :GeekZW
@Contact  :1223242863@qq.com
@File     :predict.py
@Software :Pycharm
"""
 
import tensorflow as tf
from freeze_util import load_pb_model
import time
import numpy as np
 
 
def test_only1_pb_model():
    test_content = "待测试内容"
    pb_model = load_pb_model("xxx.pb")  # 加载pb模型文件
 
    with tf.Session() as sess:
        pre = pb_model.get_tensor_by_name("输出节点名")  # 如"output/Softmax:0"
        x = pb_model.get_tensor_by_name("输入节点名")  # 如"input:0"
 
        # 利用pb模型进行预测
        start_time = time.time()
        predict_res = sess.run(pre, feed_dict={x: [test_content]})
        end_time = time.time()
        print("耗时: {0} ms !".format(round(1000 * (end_time - start_time), 3)))
 
 
def test_multi_pb_model():
    test_content = "待测试内容"
    labels = ["class_A", "class_B", "class_C"]
 
    # 加载第一个pb模型进行预测
    g1 = tf.Graph()
    sess1 = tf.Session(graph=g1)
    with sess1.as_default():
        with g1.as_default():
            pb_model1 = load_pb_model("xxx.pb")  # 加载pb模型文件
 
    with sess1.as_default():
        with sess1.as_default():
            pre = pb_model1.get_tensor_by_name("输出节点名")  # 如"output/Softmax:0"
            x = pb_model1.get_tensor_by_name("输入节点名")    # 如"input:0"
 
            # 利用pb模型进行预测,并计算耗时
            start_time = time.time()
            predict_res = sess1.run(pre, feed_dict={x: [test_content]})
            end_time = time.time()
 
            y = np.argmax(np.array(predict_res[0]))
            label, score = labels[y], predict_res[y]
 
            print("{0}的预测结果为{1}, 平均耗时为{2} ms !".format(test_content, label, round(1000 * (end_time - start_time), 3)))
 
    # 加载第二个pb模型进行预测
    g2 = tf.Graph()
    sess2 = tf.Session(graph=g2)
    with sess2.as_default():
        with g2.as_default():
            pb_model2 = load_pb_model("yyy.pb")  # 加载pb模型文件
 
    with sess2.as_default():
        with sess2.as_default():
            pre = pb_model2.get_tensor_by_name("输出节点名")  # 如"output/Softmax:0"
            x = pb_model2.get_tensor_by_name("输入节点名")  # 如"input:0"
 
            # 利用pb模型进行预测,并计算耗时
            start_time = time.time()
            predict_res = sess2.run(pre, feed_dict={x: [test_content]})
            end_time = time.time()
 
            y = np.argmax(np.array(predict_res[0]))
            label, score = labels[y], predict_res[y]
 
            print("{0}的预测结果为{1}, 平均耗时为{2} ms !".format(test_content, label, round(1000 * (end_time - start_time), 3)))
 
 
if __name__ == "__main__":
    test_only1_pb_model()  # 单个pb模型进行预测
    test_multi_pb_model()  # 多个pb模型进行预测

说明:Python中的end_time - start_time得到的单位是:秒(s),不是毫秒(ms)

免责声明:本文来自互联网新闻客户端自媒体,不代表本网的观点和立场。

合作及投稿邮箱:E-mail:editor@tusaishared.com

上一篇:确定单峰函数的极小点所在的区间

下一篇:window10万能搭建tensorflowGPU,CPU双环境并配置pycharm环境

用户评价
全部评价

热门资源

  • Python 爬虫(二)...

    所谓爬虫就是模拟客户端发送网络请求,获取网络响...

  • TensorFlow从1到2...

    原文第四篇中,我们介绍了官方的入门案例MNIST,功...

  • TensorFlow从1到2...

    “回归”这个词,既是Regression算法的名称,也代表...

  • 机器学习中的熵、...

    熵 (entropy) 这一词最初来源于热力学。1948年,克...

  • TensorFlow2.0(10...

    前面的博客中我们说过,在加载数据和预处理数据时...