欢迎您访问 最编程 本站为您分享编程语言代码,编程技术文章!
您现在的位置是: 首页

基于 OpenCV4 的 SVM 算法用于手写数字识别

最编程 2024-03-31 10:10:12
...

SVM支持向量机是一种分类模型,其分类思想源于感知机,通常的说法是相当于有一个隐层的神经网络。
OpenCV4支持SVM算法的API调用,下面列出相关的代码实现。

训练的主要代码:
    // 读取大的手写数字图片,该图片的分辨率为2000x1000
    cv::Mat matBase = cv::imread("images.png");
    cv::cvtColor(matBase, matBase, cv::COLOR_BGR2GRAY);

    int nWidth = matBase.cols;
    int nHeight = matBase.rows;
    int nRows = nHeight / 20;
    int nCols = nWidth / 20;
    int nPicCount = nRows * nCols;

    // 提取训练图片,制作训练集和标签集
    std::vector< cv::Mat > vecTrain;
    cv::Mat matTrain = cv::Mat::zeros(nPicCount, 400, CV_8UC1);
    cv::Mat matLabel = cv::Mat::zeros(nPicCount, 1, CV_8UC1);
    for (int i = 0; i < nRows; ++i)
    {
        for (int j = 0; j < nCols; ++j)
        {
            cv::Mat matPic = matBase(cv::Rect(j * 20, i * 20, 20, 20));
            matPic.clone().reshape(1, 1).copyTo(matTrain(cv::Rect(0, i * nCols + j, 400, 1)));

            matLabel.at<uchar>(i * nCols + j, 0) = i / 5;
        }
    }

    matTrain.convertTo(matTrain, CV_32FC1);
    matLabel.convertTo(matLabel, CV_32SC1);

    // 创建训练集
    auto pTrainData = cv::ml::TrainData::create(matTrain, cv::ml::ROW_SAMPLE, matLabel);

    // 创建SVM模型,并执行训练
    auto pSVM = cv::ml::SVM::create();
    pSVM->setType(cv::ml::SVM::Types::C_SVC);
    pSVM->setKernel(cv::ml::SVM::KernelTypes::LINEAR);
    pSVM->setTermCriteria(cv::TermCriteria(cv::TermCriteria::MAX_ITER, 100, 1e-6));
    pSVM->train(pTrainData);
    pSVM->save("svm.xml");

该代码使用了一张包含5000个20x20手写字的图片用于训练,并将训练的结果保存为knn.yaml文件,手写字图片前文KNN算法相同,这里不再列出。

预测的主要代码:
    cv::Mat matPic2_1 = cv::imread("2_1.png");
    cv::Mat matPic7_1 = cv::imread("7_1.png");
    cv::cvtColor(matPic2_1, matPic2_1, cv::COLOR_BGR2GRAY);
    cv::cvtColor(matPic7_1, matPic7_1, cv::COLOR_BGR2GRAY);
    cv::resize(matPic2_1, matPic2_1, cv::Size(20, 20));
    cv::resize(matPic7_1, matPic7_1, cv::Size(20, 20));
    matPic2_1.convertTo(matPic2_1, CV_32F);
    matPic7_1.convertTo(matPic7_1, CV_32F);

    cv::Mat matPredict(2, 400, CV_32F);
    matPic2_1.reshape(1, 1).copyTo(matPredict(cv::Rect(0, 0, 400, 1)));
    matPic7_1.reshape(1, 1).copyTo(matPredict(cv::Rect(0, 1, 400, 1)));

    // 加载SVM模型
    auto pSVM = cv::ml::SVM::load("svm.xml");

    // 执行预测
    cv::Mat matResult;
    pSVM->predict(matPredict, matResult);

    printf("result: %.00f %.00f \n", matResult.at<float>(0, 0), matResult.at<float>(1, 0));

该代码使用了前文KNN算法中的两张手写照片,这里不再贴出。

运行预测程序,执行结果为:
result: 2 3
其中一个数字识别错误,应该是我写得不好,或者说训练的样本不够。