// This section demonstrates more advanced usage of the Shark library with Linear Regression.
// Dataset: red wine quality — the features are numeric wine attributes and the label is a quality score from 0 to 10.
#include <iostream> #include <shark/Data/Csv.h> #include <shark/Algorithms/Trainers/NormalizeComponentsUnitInterval.h> #include <shark/Algorithms/Trainers/NormalizeComponentsUnitVariance.h> #include <shark/Data/CVDatasetTools.h> #include <shark/Algorithms/Trainers/LinearRegression.h> #include <shark/ObjectiveFunctions/Loss/SquaredLoss.h> using namespace shark; using std::cout; using std::endl; int main(int argc, const char * argv[]) { try { //1.读取数据 RegressionDataset data; //类型为LabeledData<RealVector,RealVector>,专门用于预测数值使用 importCSV(data, "winequality-red.csv", LAST_COLUMN); //2.normalize data //normalize的思路跟训练类似,都是定义model和trainer然后train,得到的模型再作用于data Normalizer<> normalizer; //shark里大部分模板默认类型为RealVector,所以可以不写 //NormalizeComponentsUnitInterval<> normalizingTrainer; //normalize every input dimension to the range [0,1],所以称为unit interval,即区间是统一的 NormalizeComponentsUnitVariance<RealVector> normalizingTrainer(true); //adjust the variance of each component to one, and can optionally remove the mean,所以称为unit variance,即variance是统一的 normalizingTrainer.train(normalizer, data.inputs()); //使用data的feature部分进行训练,至于内部实现就不清楚了 data = transformInputs(data, normalizer); //得到normalizer模型后使用transformInputs作用于原始数据,返回处理过后的数据。此处也可以自己定义函数对象传入 //exportCSV(data, "normalize_2.csv", LAST_COLUMN); //3.使用k-fold-cross-validation进行训练 double dAcumResult = 0.0; CVFolds<RegressionDataset> folds = createCVSameSize(data, 10); //k = 10,每个fold规模一样。如果要使用stratified sampling用createCVSameSizeBalanced for (unsigned i = 0; i != folds.size(); ++i) { // access the fold RegressionDataset training = folds.training(i); RegressionDataset validation = folds.validation(i); // train LinearModel<> model; LinearRegression trainer; trainer.train(model, training); //evaluate Data<RealVector> prediction = model(validation.inputs()); SquaredLoss<> loss; double dResult = loss(validation.labels(), prediction); cout << "squared loss: " << dResult << endl; dAcumResult += dResult; } cout << "mean loss: " << dAcumResult / 
folds.size() << endl; //补充说明:貌似normalize没有什么作用,至少对这个数据集不normalize、UnitInterval、UnitVariance最后得到的loss都一样 //推测可能是trainer内部做了什么调整? } catch (const std::exception &e) { cout << e.what() << endl; } return 0; }