文章目录
猫狗大战背景介绍
代码示例
step1 对模型的修改
step2 数据的输入
step3 模型的重新训练与存储
step4 模型的复用
猫狗大战背景介绍猫狗大战数据集来源于Kaggle上的一个竞赛:Dogs vs. Cats,猫狗大战的数据集下载地址,其中数据集有12500只猫和12500只狗
http://www.kaggle.com/c/dogs-vs-cats
使用Finetuning对VGGNet进行调整,从而针对猫狗大战的训练集进行训练,创建工程文件,所有素材如下
代码示例首先是对模型的修改(VGG16_model.py文件),在这里原先的输出结果是对1000个不同的类别进行判定,而在此是对2个图像,也就是猫和狗的判断,因此首先第一步就是修改输出层的全连接数据
这里是最后一层的输出通道被设置成2,而对于其他部分,定义创建卷积层和全连接层的方法则无需做出太大改动。
对于修改后的模型,需要对其进行重新训练,而首要条件就是数据输入,在这里笔者使用数据的输入流方式。代码如下
这里定义的get_file函数对输入文件的文件夹进行分类,通过以不同的文件夹作为分类标准将图片分为2类,使用2个列表文件分别用来存储图片地址和对应的标记地址,同时我们需要按照程序的要求,将train文件夹中的图片,分成cat和dog 文件夹,如图所示:
get_batch函数是通过对列表地址的读取而循环载入具有参数batch_size大小而定的图片,并读取相应的图片标签作为数据标签一同进行训练,完整定义如下:
Finetuning最重要的一个步骤就是模型的重新训练与存储。首先对于模型的值的输出,在类中已经做了定义,因此只需要将定义的模型类初始化后输出赋予一个特定的变量即可
这里同时定义了损失函数已经最小化方法,完整代码如下:
在训练函数中使用了Tensorflow的队列方式进行数据输入,而对于权重的重新载入也使用的是前面文章类似的方式,最终数据进行200次迭代,存储模型在model文件夹中。
Copyright © 2024 妖气游戏网 www.17u1u.com All Rights Reserved