欢迎您访问 最编程 本站为您分享编程语言代码,编程技术文章!
您现在的位置是: 首页

Python 项目练习 - 4. 手写数字识别

最编程 2024-03-27 07:04:39
...
import numpy as np  
from sklearn import datasets, svm, metrics  
from sklearn.model_selection import train_test_split  
from sklearn.preprocessing import StandardScaler  
from sklearn.neural_network import MLPClassifier  
  
# 1. 数据准备  
# 加载MNIST数据集  
digits = datasets.load_digits()  
  
# 2. 数据预处理  
# 将图像数据展平为一维数组  
n_samples = len(digits.images)  
data = digits.images.reshape((n_samples, -1))  
  
# 划分训练集和测试集  
X_train, X_test, y_train, y_test = train_test_split(data, digits.target, test_size=0.5, shuffle=False)  
  
# 数据标准化  
scaler = StandardScaler()  
X_train = scaler.fit_transform(X_train)  
X_test = scaler.transform(X_test)  
  
# 3. 模型选择  
# 使用多层感知机(MLP)作为分类器  
mlp = MLPClassifier(hidden_layer_sizes=(50,), max_iter=10, alpha=1e-4,  
                    solver='sgd', verbose=10, random_state=1,  
                    learning_rate_init=.1)  
  
# 4. 模型训练  
mlp.fit(X_train, y_train)  
  
# 5. 模型评估  
# 预测测试集结果  
predictions = mlp.predict(X_test)  
  
# 计算准确率  
print("Classification report for classifier %s:\n%s\n"  
      % (mlp, metrics.classification_report(y_test, predictions)))  
print("Confusion matrix:\n%s" % metrics.confusion_matrix(y_test, predictions))