传统神经网络训练方法介绍

本节演示如何使用shark的传统神经网络,以及如何保存和读取model
数据集使用经典的MNIST.csv,文件过大就不在这里提供下载了,网上都能找到

#include <iostream>
#include <fstream>
#include <shark/Data/Csv.h>
#include <shark/Models/LinearModel.h>
#include <shark/Models/ConcatenatedModel.h>
#include <shark/ObjectiveFunctions/ErrorFunction.h>
#include <shark/ObjectiveFunctions/Loss/CrossEntropy.h>
#include <shark/ObjectiveFunctions/Loss/ZeroOneLoss.h>
#include <shark/Algorithms/GradientDescent/Adam.h>
using namespace shark;
using std::cout;
using std::endl;

const unsigned nEpoch = 500;
const unsigned nBatchSize = 100;

//演示如何使用传统神经网络,以及如何保存和读取model
int main(int argc, const char * argv[])
{
    try
    {
        //1.读取数据
        ClassificationDataset dataTrain;
        importCSV(dataTrain, "MNIST.csv", FIRST_COLUMN, ',', '#', nBatchSize);       //可以在导入数据时指定Dataset里一个batch含多少元素,默认为256
        //提取数据信息
        auto nClassNum = numberOfClasses(dataTrain);
        cout << "Class number: " << nClassNum << endl;
        auto nInputDim = inputDimension(dataTrain);
        cout << "Input dimension: " << nInputDim << endl;
        auto vecClassSize = classSizes(dataTrain);      // number of occurrences of every class label
        cout << "Class distribution:" << endl;
        for (unsigned i = 0; i != 10; ++i)
            cout << i << ": " << vecClassSize[i] << endl;
        auto dataTest = splitAtElement(dataTrain, 0.7 * dataTrain.numberOfBatches());       //切分训练集和测试集也可使用batch划分
        /*关于Dataset的进阶说明:
         Dataset内部使用batch而不是element来组织内容,这样的好处是batch确保内存连续分配可以加快计算速度
         如果想要获取Dataset的某个元素,则需要遍历batch找到该元素,因此复杂度是O(n),这点要特别注意。有一个DataView类可以让查找为O(1)但是需要O(n)的空间
         Dataset可以进行分割,有两种方式:1.splice (aka splitAtBatch),这个是成员函数。2.splitAtElement,这个是全局函数
         分割或拷贝并不会引发deep copy,这样节省时间和内存。如果要deep copy则调用成员函数makeIndependent
        */
        //2.搭建神经网络拓扑结构
        LinearModel<RealVector, RectifierNeuron> layer1(nInputDim, 200, true);      //使用LinearModel来表示神经网络中的一层layer。第一个模板参数是input类型,第二个是activation function(这里用的是ReLu)
        LinearModel<RealVector, RectifierNeuron> layer2(200, 100, true);        //构造参数第一个是input dimension,第二个是output dimension,第三个指定是否与offset(即是否有常量x0 = 1)
        LinearModel<RealVector> output(100, nClassNum, true);       //最后一层不用指定activation,output应该与标签种类相等。注意层与层的input -> output的维度应该相接
        auto network = layer1 >> layer2 >> output;      //使用>>将若干层拼接得到最终的神经网络
        initRandomNormal(network,0.001);        //初始化神经网络,即将parameters随机赋值
        //3.训练(神经网络没有trainer所以需要手动处理)
        CrossEntropy<unsigned int, RealVector> lossFunc;        //定义loss function,用于衡量预测与实际差多少。这里的CrossEntropy常用于神经网络,可以计算两种概率分布之间的差异
        ErrorFunction<> errorFunc(dataTrain, &network, &lossFunc, true);//建立objective function,将data+model+loss合起来。最后一个参数指定是否使用mini-batch
        //注:这里ErrorFunction是一种objective function,语义上来说当error最小时也就达到了最优解
        errorFunc.init();           //目标函数需要初始化
        Adam<> optimizer;       //准备好目标函数后需要一个optimizer来实际解这个问题。此处Adaptive Moment Estimation Algorithm是一种常用的optimizer
        optimizer.setEta(0.001);           //set learning rate
        optimizer.init(errorFunc);      //optimizer也要初始化,并指定目标函数
        cout << "training network" << endl;
        for (unsigned i = 0; i != nEpoch; ++i)
        {
            //Mini-Batch Gradient Descent是梯度下降的变种,思路是在整个数据集里选一部分称为batch,然后用这部分求解,再更新模型参数
            //优点是快速,是神经网络的首选方法。在shark里一个minibatch是从Dataset的batch里随机选出的。这也意味着建立Dataset时指定的batch大小就是minibatch的大小
            optimizer.step(errorFunc);          //迭代一次
            cout<<i<<" "<<optimizer.solution().value<<endl;     //输出当前训练成果
        }
        network.setParameterVector(optimizer.solution().point);     //copy solution parameters into model
        //4.评估
        ZeroOneLoss<unsigned int,RealVector> loss01;
        Data<RealVector> predictionTrain = network(dataTrain.inputs());
        cout << "classification error,train: " << loss01.eval(dataTrain.labels(), predictionTrain) << endl;
        Data<RealVector> predictionTest = network(dataTest.inputs());
        cout << "classification error,test: " << loss01.eval(dataTest.labels(), predictionTest) << endl;
        //5.保存和读取模型(shark直接使用了boost::serialization库来保存和读取,该库可以读写任意对象)
        //save
        //std::ofstream ofs("ann.model");
        //boost::archive::polymorphic_text_oarchive oa(ofs);
        //network.write(oa);
        //ofs.close();
        //load(读取时对象结构必须跟保存时一样)
        //std::ifstream ifs("ann.model");
        //boost::archive::polymorphic_text_iarchive ia(ifs);
        //network.read(ia);
        //ifs.close();
    }
    catch (const std::exception &e)
    {
        cout << e.what() << endl;
    }
    return 0;
}