资源经验分享机器学习之线性回归使用Python和tensorflow实现

机器学习之线性回归使用Python和tensorflow实现

2019-12-13 | |  108 |   0

原标题:机器学习之线性回归使用Python和tensorflow实现

原文来自:博客园      原文链接:http://www.xinhuanet.com/2019-09/09/c_1124976366.htm


导入依赖包

import tensorflow as tf
import numpy as np
import matplotlib.pylab as pltfrom pylab 
import mplmpl.rcParams['font.sans-serif'] = ['SimHei']

生成直线数据并加入噪音画图显示

train_x = np.linspace(-1, 1, 100)  # 生成 -1 到 1之间 分成100份
# print(train_x)noise = np.random.randn(*train_x.shape) * 0.3
train_y = 2 * train_x + noise  # 给每一个点加上噪音
# print(noise)plt.plot(train_x, train_y, "go", label="我的初始数据")
plt.legend()
plt.show()

定义模型的输入和输出

x = tf.placeholder("float")
y = tf.placeholder("float")
# 定义并初始化模型的权重偏置
w = tf.Variable(tf.random_normal([1]), name="weight")
b = tf.Variable(tf.zeros([1]), name="bias")
# 定义模型的前向传播过程y_predict = tf.multiply(w, x) + b

定义模型的损失函数,反向传播

cost = tf.reduce_mean(tf.square(y - y_predict))
learning_rate = 0.01  # 定义学习率
optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)  # 定义优化器
init = tf.global_variables_initializer()

定义超参数

train_epochs = 100
display_epoch = 2

训练模型

with tf.Session() as sess:
    sess.run(init)
    plotdata = {"epochs": [], "cost": []}  # 保存训练得到的参数
    for epoch in range(train_epochs):        for X, Y in zip(train_x, train_y):
            sess.run(optimizer, feed_dict={x: X, y: Y})        if epoch % display_epoch == 0:
            print("训练的epoch为:%d cost为:%f %f" % (epoch+1, sess.run(cost, feed_dict={x: train_x, y: train_y}), sess.run(b)))
            plotdata["epochs"].append(epoch+1)
            plotdata["cost"].append(sess.run(cost, feed_dict={x: train_x, y: train_y}))

    print("训练完成")
    print("模型训练的结果为: ", "w", sess.run(w), "b:", sess.run(b), "cost:",
          sess.run(cost, feed_dict={x: train_x, y: train_y}))

画图显示

    plt.plot(train_x, train_y, "go", label="我的初始数据")    
    plt.plot(train_x, sess.run(w) * train_x + sess.run(b), label='Fitted line')    
    plt.legend()
    plt.show()

免责声明:本文来自互联网新闻客户端自媒体,不代表本网的观点和立场。

合作及投稿邮箱:E-mail:editor@tusaishared.com

上一篇:强化学习算法Policy Gradient

下一篇:人工智能基础知识复习:机器学习

用户评价
全部评价

热门资源

  • Python 爬虫(二)...

    所谓爬虫就是模拟客户端发送网络请求,获取网络响...

  • TensorFlow从1到2...

    原文第四篇中,我们介绍了官方的入门案例MNIST,功...

  • TensorFlow从1到2...

    “回归”这个词,既是Regression算法的名称,也代表...

  • 机器学习中的熵、...

    熵 (entropy) 这一词最初来源于热力学。1948年,克...

  • TensorFlow2.0(10...

    前面的博客中我们说过,在加载数据和预处理数据时...