基于深度神经网络的多分类函数
原标题:基于深度神经网络的多分类函数
原文来自:CSDN 原文链接:https://blog.csdn.net/s0302017/article/details/102996872
运行软件:MATLAB
一、基于深度神经网络的多分类函数
函数 MultiClass10 为基于深度神经网络的多分类函数。
此处神经网络包括三层:输入层、一个隐藏层、输出层,其中输入层为待分类的数据,输出层为分类结果。
函数 MultiClass10 的代码如下(MultiClass10.m) function [ W1, W2 ] = MultiClass10(W1, W2, X, D)
%功能:利用深度神经网络对二维数据进行多分类(10 类) % 输入参数:W1 为输入层到隐藏层的权,W2 为隐藏层到输出层的权 % X 为待分类的数据,D 为真实分类数据 % 输出参数:W1 为训练后的输入层到隐藏层的权,W2 为训练后的隐藏层到输出层的权 alpha = 0.9; [ W, H, N ] = size(X); for k = 1 : N x = reshape( X(:, :, k), W*H, 1 ); %x: W*H x 1 d = D(k, :)'; %Forword Propatation v1 = W1 * x; y1 = Sigmoid(v1); v = W2 * y1; y = ReLU(v); %Back Propatation e = d - y; delta = e; %last layer error e1 = W2' * delta; delta1 = y1 .* (1-y1) .* e1; %first layer error,gradient descend %revise the weight cofficients dW1 = alpha * delta1 * x'; W1 = W1 + dW1; dW2 = alpha * delta * y1'; W2 = W2 + dW2; end
二、分类训练与测试
以下代码为调用函数 MultiClass10 对二维数据数据进行多分类,类别数为 10。
该程序包括两部分,第一部分为训练,得到神经网络模型。第二部分为测试,利用测试得到的神经网络进行实际分类。
训练数据为 10 组 5x5 的二维数据(灰度图像),数据内容为点阵格式的数字 0-9 测试数据亦为 10 组 5x5 的二维数据(灰度图像),数据内容为点阵格式的数字 0-9 为方便使用本程序,二维数据均在程序内部给定,也可以读取外部的图像数据。
代码如下:
clear clc %X:train data %W,H,N 分别表示二维数据的宽、高和组数,可以根据实际情况调整该数据 W = 5; H = 5; N = 10; X = zeros(W, H, N); X(:, :, 1) = [ 0 0 1 0 0; 0 1 0 1 0; 1 0 0 0 1; 0 1 0 1 0; 0 0 1 0 0 ]; X(:, :, 2) = [ 0 1 1 0 0; 0 0 1 0 0; 0 0 1 0 0; 0 0 1 0 0; 0 1 1 1 0 ]; X(:, :, 3) = [ 1 1 1 1 0; 0 0 0 0 1; 0 1 1 1 0; 1 0 0 0 0; 1 1 1 1 1 ]; X(:, :, 4) = [ 1 1 1 1 0; 0 0 0 0 1; 0 1 1 1 0; 0 0 0 0 1; 1 1 1 1 0 ]; X(:, :, 5) = [ 0 0 0 1 0; 0 0 1 1 0; 0 1 0 1 0; 1 1 1 1 1; 0 0 0 1 0 ]; X(:, :, 6) = [ 1 1 1 1 1; 1 0 0 0 0; 1 1 1 1 0; 0 0 0 0 1; 1 1 1 1 0 ]; X(:, :, 7) = [ 0 0 0 1 0; 0 0 1 0 0; 0 1 0 1 0; 1 0 0 1 0; 0 1 1 0 0 ]; X(:, :, 8) = [ 1 1 1 1 1; 0 0 0 1 0; 0 0 1 0 0; 0 1 0 0 0; 1 0 0 0 0 ]; X(:, :, 9) = [ 0 1 1 1 0; 0 1 0 1 0; 0 0 1 0 0; 0 1 0 1 0; 0 1 1 1 0 ]; X(:, :, 10) = [ 0 1 1 1 0; 0 1 0 1 0; 0 0 1 1 0; 0 0 0 1 0; 0 1 1 1 0 ]; %D: real classification result D = [ 1 0 0 0 0 0 0 0 0 0; 0 1 0 0 0 0 0 0 0 0; 0 0 1 0 0 0 0 0 0 0; 0 0 0 1 0 0 0 0 0 0; 0 0 0 0 1 0 0 0 0 0; 0 0 0 0 0 1 0 0 0 0; 0 0 0 0 0 0 1 0 0 0; 0 0 0 0 0 0 0 1 0 0; 0 0 0 0 0 0 0 0 1 0; 0 0 0 0 0 0 0 0 0 1; ]; W1 = 2 * rand(50, W*H) - 1; W2 = 2 * rand(10, 50) - 1; %%%%%%%%%%%%%%%%%%%% train begin %%%%%%%%%%%%%%%%%%%%%%%% for k = 1 : 1000 [ W1, W2 ] = MultiClass10( W1, W2, X, D ); end %results of train for k = 1 : N x = reshape( X(:, :, k), 25 ,1 ); v1 = W1 * x; y1 = Sigmoid(v1); v = W2 * y1; y = ReLU(v) end %%%%%%%%%%%%%%%%%%%% train end %%%%%%%%%%%%%%%%%%%%%%%% %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% %%%%%%%%%%%% %%%%%%%%%%%%%%%%%%%% test begin %%%%%%%%%%%%%%%%%%%%%%%% Xte = zeros(W, H, N);; Xte(:, :, 1) = [ 0 0 1 0 0; 0 1 0 1 0; 0 0 0 0 1; 0 1 0 0 0; 0 0 1 0 0 ]; Xte(:, :, 2) = [ 0 1 1 0 0; 0 1 1 0 0; 0 1 0 1 0; 0 0 0 1 0; 0 1 1 1 0 ]; Xte(:, :, 3) = [ 1 1 1 1 0; 0 0 0 0 1; 0 1 1 1 0; 1 0 0 0 1; 1 1 1 1 1 ]; Xte(:, :, 4) = [ 1 1 1 1 0; 0 0 0 0 1; 0 1 1 1 0; 1 0 0 0 1; 1 1 1 1 0 ]; Xte(:, :, 5) = [ 0 1 1 1 0; 0 1 0 0 0; 0 1 1 1 0; 0 0 0 1 0; 0 1 1 1 0 ]; Xte(:, :, 6) = [ 0 1 1 1 1; 1 0 0 0 0; 0 1 1 1 0; 0 0 0 1 0; 1 1 1 1 0 ]; Xte(:, :, 7) = [ 0 0 0 1 0; 0 0 1 0 0; 0 1 0 1 0; 1 0 0 1 0; 0 1 0 0 0 ]; Xte(:, :, 8) = [ 1 1 1 1 1; 0 0 0 1 0; 0 0 1 0 0; 0 1 0 0 0; 0 0 0 0 0 ]; Xte(:, :, 9) = [ 0 1 0 1 0; 0 1 0 1 0; 0 0 1 0 0; 0 1 0 1 0; 0 1 0 1 0 ]; Xte(:, :, 10) = [ 0 1 0 1 0; 0 1 0 1 0; 0 0 1 1 0; 0 0 0 1 0; 0 1 0 1 0 ]; %output results of test real_y = []; for k = 1 : N x = reshape( Xte(:, :, k), 25 ,1 ); v1 = W1 * x; y1 = Sigmoid(v1); v = W2 * y1; y = ReLU(v); disp( [ 'real_y(',num2str(k), ') = '] ) disp( y ) %display classification results end %show test data figure(1); imshow(1-X(:, :, 1)); figure(2); imshow(1-X(:, :, 2)) figure(3); imshow(1-X(:, :, 3)) figure(4); imshow(1-X(:, :, 4)) figure(5); imshow(1-X(:, :, 5)) figure(6); imshow(1-X(:, :, 6)) figure(7); imshow(1-X(:, :, 7)) figure(8); imshow(1-X(:, :, 8)) figure(9); imshow(1-X(:, :, 9)) figure(10); imshow(1-X(:, :, 10)) %%%%%%%%%%%%%%%%%%%% test end %%%%%%%%%%%%%%%%%%%%%%%%
5.测试结果略
作者:YangYF
免责声明:本文来自互联网新闻客户端自媒体,不代表本网的观点和立场。
合作及投稿邮箱:E-mail:editor@tusaishared.com
热门资源
Python 爬虫(二)...
所谓爬虫就是模拟客户端发送网络请求,获取网络响...
TensorFlow从1到2...
原文第四篇中,我们介绍了官方的入门案例MNIST,功...
TensorFlow从1到2...
“回归”这个词,既是Regression算法的名称,也代表...
机器学习中的熵、...
熵 (entropy) 这一词最初来源于热力学。1948年,克...
TensorFlow2.0(10...
前面的博客中我们说过,在加载数据和预处理数据时...
智能在线
400-630-6780
聆听.建议反馈
E-mail: support@tusaishared.com