Python 项目练习 - 4. 手写数字识别
最编程
2024-03-27 07:04:39
...
import numpy as np
from sklearn import datasets, svm, metrics
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.neural_network import MLPClassifier
# 1. Data preparation
# Load the handwritten-digit dataset that ships with scikit-learn.
# NOTE(review): the original comment called this MNIST; `load_digits` is
# scikit-learn's own bundled digits dataset, not MNIST -- confirm the intent.
digits = datasets.load_digits()

# 2. Preprocessing
# Flatten every image into a single feature row (one row per sample).
n_samples = digits.images.shape[0]
data = digits.images.reshape(n_samples, -1)

# Hold out half of the samples for testing. shuffle=False keeps the original
# sample order, so the split is a plain first-half / second-half cut.
X_train, X_test, y_train, y_test = train_test_split(
    data,
    digits.target,
    test_size=0.5,
    shuffle=False,
)

# Standardize the features. The scaling statistics are learned from the
# training split only and then applied unchanged to both splits.
scaler = StandardScaler().fit(X_train)
X_train = scaler.transform(X_train)
X_test = scaler.transform(X_test)

# 3. Model selection
# Multi-layer perceptron with one hidden layer of 50 units, trained with SGD.
# max_iter=10 keeps the run short; random_state=1 makes it repeatable;
# verbose=10 prints training progress.
mlp = MLPClassifier(
    hidden_layer_sizes=(50,),
    max_iter=10,
    alpha=1e-4,
    solver='sgd',
    verbose=10,
    random_state=1,
    learning_rate_init=.1,
)

# 4. Training / 5. Evaluation
# Fit on the training split, then predict labels for the held-out split.
predictions = mlp.fit(X_train, y_train).predict(X_test)

# Print the classification report followed by the confusion matrix.
report = metrics.classification_report(y_test, predictions)
print("Classification report for classifier %s:\n%s\n" % (mlp, report))

confusion = metrics.confusion_matrix(y_test, predictions)
print("Confusion matrix:\n%s" % (confusion,))
from sklearn import datasets, svm, metrics
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.neural_network import MLPClassifier
# NOTE(review): this listing repeats the same pipeline that appears earlier in
# the file; running the file top to bottom loads the data and trains the
# classifier a second time. Confirm whether the repetition is intentional.

# 1. Data preparation
# Load the handwritten-digit dataset that ships with scikit-learn.
# NOTE(review): the original comment called this MNIST; `load_digits` is
# scikit-learn's own bundled digits dataset, not MNIST -- confirm the intent.
digits = datasets.load_digits()

# 2. Preprocessing
# Flatten every image into a single feature row (one row per sample).
n_samples = digits.images.shape[0]
data = digits.images.reshape(n_samples, -1)

# Hold out half of the samples for testing. shuffle=False keeps the original
# sample order, so the split is a plain first-half / second-half cut.
X_train, X_test, y_train, y_test = train_test_split(
    data,
    digits.target,
    test_size=0.5,
    shuffle=False,
)

# Standardize the features. The scaling statistics are learned from the
# training split only and then applied unchanged to both splits.
scaler = StandardScaler().fit(X_train)
X_train = scaler.transform(X_train)
X_test = scaler.transform(X_test)

# 3. Model selection
# Multi-layer perceptron with one hidden layer of 50 units, trained with SGD.
# max_iter=10 keeps the run short; random_state=1 makes it repeatable;
# verbose=10 prints training progress.
mlp = MLPClassifier(
    hidden_layer_sizes=(50,),
    max_iter=10,
    alpha=1e-4,
    solver='sgd',
    verbose=10,
    random_state=1,
    learning_rate_init=.1,
)

# 4. Training / 5. Evaluation
# Fit on the training split, then predict labels for the held-out split.
predictions = mlp.fit(X_train, y_train).predict(X_test)

# Print the classification report followed by the confusion matrix.
report = metrics.classification_report(y_test, predictions)
print("Classification report for classifier %s:\n%s\n" % (mlp, report))

confusion = metrics.confusion_matrix(y_test, predictions)
print("Confusion matrix:\n%s" % (confusion,))