资源经验分享仙人掌的图片识别

仙人掌的图片识别

2019-09-25 | |  98 |   0

原标题:仙人掌的图片识别

原文来自:CSDN      原文链接:https://blog.csdn.net/awd5174119/article/details/101154527


仙人掌的图片识别

数据来自kaggle的Aerial Cactus Identification项目(https://www.kaggle.com/c/aerial-cactus-identification)

下载数据

# 终端运行
kaggle competitions download -c aerial-cactus-identification
unzip train.zip
unzip test.zip
mkdir test_input
cp -r test/ test_input

# ls
# train.zip test.zip train test train.csv sample_submission.csv123456789

train.csv是训练集的标签信息表

sample_submission.csv是kaggle最后上传结果的文件夹

导入tensorflow包

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd

import tensorflow as tf
from tensorflow import keras

import os
from shutil import copyfile, move
from tqdm import tqdm
import h5py1234567891011

查看tensorflow版本和gpu是否可用

print(tf.__version__)
print(tf.test.is_gpu_available())
# 2.0.0-beta0
# False1234

整理数据

读取训练集信息表

training_df = pd.read_csv("train.csv")1

将训练集数据分类,分为正负两个文件夹

src = "train/"
dst = "sorted_training/"

os.mkdir(dst)
os.mkdir(dst+"true")
os.mkdir(dst+"false")

with tqdm(total=len(list(training_df.iterrows()))) as pbar:
    for idx, row in training_df.iterrows():
        pbar.update(1)
        if row["has_cactus"] == 1:
            copyfile(src+row["id"], dst+"true/"+row["id"])
        else:
            copyfile(src+row["id"], dst+"false/"+row["id"])1234567891011121314

将训练集的十分之一分到验证集,来验证结果

src = "sorted_training/"
dst = "sorted_validation/"

os.mkdir(dst)
os.mkdir(dst+"true")
os.mkdir(dst+"false")

validation_df = training_df.sample(n=int(len(training_df)/10))

with tqdm(total=len(list(validation_df.iterrows()))) as pbar:
    for idx, row in validation_df.iterrows():
        pbar.update(1)
        if row["has_cactus"] == 1:
            move(src+"true/"+row["id"], dst+"true/"+row["id"])
        else:
            move(src+"false/"+row["id"], dst+"false/"+row["id"])12345678910111213141516

构建模型

导入模型包

from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import InputLayer, Input
from tensorflow.keras.layers import Conv2D, Dense, Flatten, Dropout, Activation
from tensorflow.keras.layers import BatchNormalization, Reshape, MaxPooling2D, GlobalAveragePooling2D
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint123456

使用ImageDataGenerator导入数据要求把数据整理为

data/
	sorted_training/
		true/
			1.jpg
			2.jpg
			3.jpg
		false/
			1.jpg
			2.jpg
			3.jpg
	sorted_validation/
		true/
			1.jpg
			2.jpg
			3.jpg
		false/
			1.jpg
			2.jpg
			3.jpg
	test_input/
		test/
			1.jpg
			2.jpg
			3.jpg123456789101112131415161718192021222324

数据读取

batch_size = 64

train_datagen = ImageDataGenerator(
    rescale=1. / 255,
    horizontal_flip=True,
    vertical_flip=True)

train_data_dir = "sorted_training"
train_generator = train_datagen.flow_from_directory(
    train_data_dir,
    shuffle=True,
    target_size=(32, 32),
    batch_size=batch_size,
    class_mode='binary')


validation_datagen = ImageDataGenerator(rescale=1. / 255)
validation_data_dir = "sorted_validation"
validation_generator = validation_datagen.flow_from_directory(
    validation_data_dir,
    target_size=(32, 32),
    batch_size=batch_size,
    class_mode='binary')

input_shape = (32,32,3)
num_classes = 21234567891011121314151617181920212223242526

模型构建(引自https://www.kaggle.com/frlemarchand/simple-cnn-using-keras/notebook)

dropout_dense_layer = 0.6

model = Sequential()
model.add(Conv2D(32, (3, 3), input_shape=input_shape))
model.add(BatchNormalization())
model.add(Activation('relu'))
model.add(Conv2D(32, (3, 3)))
model.add(BatchNormalization())
model.add(Activation('relu'))
model.add(Conv2D(32, (3, 3)))
model.add(BatchNormalization())
model.add(Activation('relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))

model.add(Conv2D(64, (3, 3)))
model.add(BatchNormalization())
model.add(Activation('relu'))
model.add(Conv2D(64, (3, 3)))
model.add(BatchNormalization())
model.add(Activation('relu'))
model.add(Conv2D(64, (3, 3)))
model.add(BatchNormalization())
model.add(Activation('relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))

model.add(Conv2D(128, (3, 3)))
model.add(BatchNormalization())
model.add(Activation('relu'))

model.add(Flatten())
model.add(Dense(1024))
model.add(Activation('relu'))
model.add(Dropout(dropout_dense_layer))

model.add(Dense(256))
model.add(Activation('relu'))
model.add(Dropout(dropout_dense_layer))

model.add(Dense(1))
model.add(Activation('sigmoid'))12345678910111213141516171819202122232425262728293031323334353637383940

设定损失函数和优化器,并用准确度作为判读标准

model.compile(loss=keras.losses.binary_crossentropy,
              optimizer=keras.optimizers.Adam(lr=0.001),
              metrics=['accuracy'])123

25次epoch,loss没有降低,停止模型训练

callbacks = [EarlyStopping(monitor='val_loss', patience=25),
             ModelCheckpoint(filepath='best_model.h5', monitor='val_loss', save_best_only=True)]12

设定迭代100次

epochs = 100
history = model.fit_generator(train_generator,
          validation_data=validation_generator,
          epochs=epochs,
          verbose=1,
          shuffle=True,
          callbacks=callbacks)1234567

绘图loss和val_loss

plt.plot(history.history['loss'])
plt.plot(history.history['val_loss'])
plt.show()123

绘图准确度和验证集准确度

plt.plot(history.history['accuracy'])
plt.plot(history.history['val_accuracy'])
plt.show()123

模型训练完毕

测试模型

载入刚才获得的模型

model.load_weights("best_model.h5")1

然后载入测试集

test_folder = "test_input/"
test_datagen = ImageDataGenerator(
    rescale=1. / 255)

test_generator = test_datagen.flow_from_directory(
    directory=test_folder,
    target_size=(32,32),
    batch_size=1,
    class_mode='binary',
    shuffle=False
)1234567891011

预测测试集

pred=model.predict_generator(test_generator,verbose=1)
pred_binary = [0 if value<0.50 else 1 for value in pred]12

最后把预测的结果记录在csv文件里

免责声明:本文来自互联网新闻客户端自媒体,不代表本网的观点和立场。

合作及投稿邮箱:E-mail:editor@tusaishared.com

上一篇:MatConvNet训练自己的网络

下一篇:使用 matplotlib 画图的保存方法有两种

用户评价
全部评价

热门资源

  • Python 爬虫(二)...

    所谓爬虫就是模拟客户端发送网络请求,获取网络响...

  • TensorFlow从1到2...

    原文第四篇中,我们介绍了官方的入门案例MNIST,功...

  • TensorFlow从1到2...

    “回归”这个词,既是Regression算法的名称,也代表...

  • 机器学习中的熵、...

    熵 (entropy) 这一词最初来源于热力学。1948年,克...

  • TensorFlow2.0(10...

    前面的博客中我们说过,在加载数据和预处理数据时...