博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
TensorFlow
阅读量:4973 次
发布时间:2019-06-12

本文共 1530 字,大约阅读时间需要 5 分钟。

Mnist识别

首先载入tensorflow库,并创建一个新的interactiveSession,使用这个命令会将session注册为默认的session,之后的运算也会跑在这个session里,不同session之间的数据和运算应该是相互独立的,接下来创建一个placeholder,即输入数据的地方,placeholder的第一个参数是数据类型,第二个参数[None784]代表Tensorshape,也就是数据的尺寸,这里的None代表不限条数的输入,784代表每条输入是一个784维的向量。

   接下来给softmax 回归模型中的weightsbias创建variable对象,variable是存储模型参数的。不同于存储数据的Tensor一旦使用掉就会消失,variable在模型训练迭代中时持久化的(比如一直存放在显存中),他可以长期存在并且在每轮迭代中被更新。我们把weightsbias全部初始化为0,因为模型训练时会自动学习合适的值,所以对这个简单模型来说初始值不太重要,不过对于复杂的卷积网络,循环网络或者比较深的全连接网络,初始方法就比较重要。Wshape是【784,10】,784是特征的维数,而后面的10代表有10类,因为Labelone-hot编码后是10维的向量。

TensorFlow中的softmaxtf.nn下面的一个函数,tf.nn则包含了大量神经网络的组件,tf.matmulTensorFlow中的矩阵乘法函数,我们使用一行代码就定义了softmax regression,但他最厉害的地方不是定义公式,而是将forwardbackward的内容都自动实现,只要定义好loss,训练时就会自动求导并进行梯度下降。

为了训练模型,我们需要定义一个loss function来描述对问题的分类精度,loss越小代表模型的分类结果与真实值的偏差越小。

我们采用常见的随机梯度下降SGD,TensorFlow就会根据我们定义的整个流程图自动求导,并根据反向传播算法进行训练,在每一轮迭代时更新参数来减小loss,在后台TensorFlow会自动添加许多运算操作来实现反向传播和梯度下降  ,而我们提供的就是一个封装好的优化器,只需要在每轮迭代时feed数据给他就好,我们直接调用tf.train.GradientDescentOptimizer,并设置学习速率为0.5,优化目标设定为cross-entropy,得到进行训练的操作train-step,当然TensorFlow还有许多其他的优化器,使用起来也非常方便,只需要修改函数名就可以

 

下一步使用TensorFlow的全局参数初始化器tf.gloabl_variables_initializer,并执行它的run方法

tf.gloabl_variables_initializer().run()

最后一步开始迭代执行训练操作train_step对这些样本进行训练。使用一小部分样本进行训练称为随机梯度下降,与每次使用全部样本的传统梯度下降相对应,如果每次训练都使用全部样本,计算量太大,有时也不容易跳出局部最优。对于大部分机器学习问题,我们都使用一部分数据进行随机梯度下降,这种做法绝大多数时候比全样本训练的收敛速度快的多

现在我们完成了训练,接下来对模型的准确率进行验证,下面代码中的tf.argmax(y,1)就是求各个预测的数字中概率最大的那一个,tf.argmax(y_,1)则是找样本的真实数字类别,最后返回计算分类是否正确的操作。

转载于:https://www.cnblogs.com/yangwenzhe/p/9933878.html

你可能感兴趣的文章
[django]form的content-type(mime)
查看>>
JQUERY —— 绑定事件
查看>>
在TabControl中的TabPage选项卡中添加Form窗体
查看>>
oracle中SET DEFINE意思
查看>>
个人作业-最长英语链
查看>>
JMeter-性能测试之报表设定的注意事项
查看>>
1066-堆排序
查看>>
仿面包旅行个人中心下拉顶部背景放大高斯模糊效果
查看>>
强大的css3
查看>>
[Luogu] 引水入城
查看>>
放张图片试试
查看>>
【WEB】高并发Web服务的演变-节约系统内存和CPU
查看>>
逻辑漏洞挖掘方式
查看>>
Servlet 编写过滤器
查看>>
Redis 数据类型
查看>>
Console-算法-回文数
查看>>
C#常用格式输出
查看>>
创建数据库表的SQL语句
查看>>
在Visual Studio 2010[VC++]中使用ffmpeg类库
查看>>
POJ 1488 TEX Quotes
查看>>