![机器学习系统:设计和实现](https://wfqqreader-1252317822.image.myqcloud.com/cover/853/52842853/b_52842853.jpg)
上QQ阅读APP看书,第一时间看更新
2.2.5 训练及保存模型
MindSpore提供了回调(Callback)机制,可以在训练过程中执行自定义逻辑。代码2.6使用框架提供的ModelCheckpoint函数,ModelCheckpoint函数可以保存网络模型和参数,以便进行后续的Fine-tuning(微调)操作。
代码2.6 定义模型保存
![](https://epubservercos.yuewen.com/2564F9/31398141107520606/epubprivate/OEBPS/Images/Figure-P31_10673.jpg?sign=1739251080-PZmbm3v1kBAnOZrKHttAd4RXkzEqhhGm-0-f64fe11302116f7e1013165b0ce04d77)
通过MindSpore提供的model.train接口可以方便地进行网络的训练,同时使用Loss-Monitor可以监控训练过程中损失(loss)值的变化,如代码2.7所示。
代码2.7 定义模型训练
![](https://epubservercos.yuewen.com/2564F9/31398141107520606/epubprivate/OEBPS/Images/Figure-P31_10674.jpg?sign=1739251080-x2X4o3RSnUCICvfZ64GE8CGBBXD5Ei7L-0-0a9460c3dcb36031897d465a4439b4c6)
其中,dataset_sink_mode用于控制数据是否下沉,数据下沉是指数据通过通道直接传送到设备(Device)上,可以加快训练速度,dataset_sink_mode为真(True),表示数据下沉,否则为非下沉。
有了数据集、模型、损失函数、优化器后就可以进行训练了。代码2.8把train_epoch设置为1,对数据集进行1次迭代训练。在train_net方法中,加载了之前下载的训练数据集,mnist_path是MNIST数据集路径。
代码2.8 训练模型
![](https://epubservercos.yuewen.com/2564F9/31398141107520606/epubprivate/OEBPS/Images/Figure-P31_10675.jpg?sign=1739251080-gcyCP8wDCDVTAql4LZAnXiKyIPiXrHdz-0-5e44e45a929763c601e7decd406f2548)