手写数字识别是机器学习领域的经典入门项目,MNIST 数据集是一个广泛使用且容易获取的数据集,常用于训练和测试手写数字识别模型。本文将介绍如何使用 Python 和 Scikit-learn 库实现手写数字识别。
1. MNIST 数据集概述
MNIST 数据集由美国国家标准与技术研究所(NIST)收集,包含 70,000 张 28x28 的灰度图像,分为训练集和测试集。训练集包含 60,000 张图像,用于训练模型;测试集包含 10,000 张图像,用于评估模型的性能。图像被标记为 0 到 9 的数字。
2. 技术实现
我们将使用 Python 的 Scikit-learn 库来实现手写数字识别。Scikit-learn 是一个强大的机器学习库,提供了多种分类算法,包括支持向量机(SVM)、随机森林和决策树等。在本例中,我们使用 SVM 算法来实现手写数字识别。
3. 代码实现
以下是一个完整的代码:
from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split
from sklearn.svm import SVC
from sklearn.metrics import accuracy_score, confusion_matrix
import matplotlib.pyplot as plt# 加载 MNIST 数据集
digits = load_digits()
X, y = digits.data, digits.target# 将数据集分为训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)# 创建 SVM 分类器
svm = SVC(kernel='rbf', C=10, gamma=0.001)# 训练模型
svm.fit(X_train, y_train)# 预测测试集
y_pred = svm.predict(X_test)# 计算准确率
accuracy = accuracy_score(y_test, y_pred)
print(f"Accuracy: {accuracy:.2%}")# 输出混淆矩阵
conf_matrix = confusion_matrix(y_test, y_pred)
print("Confusion Matrix:")
print(conf_matrix)# 可视化预测结果
fig, axes = plt.subplots(4, 4, figsize=(8, 8))
for i, ax in enumerate(axes.flat):ax.imshow(X_test[i].reshape(8, 8), cmap='gray')ax.set_title(f"True: {y_test[i]}\nPredicted: {y_pred[i]}")ax.axis('off')
plt.tight_layout()
plt.show()
3.1 数据加载
- 使用
load_digits()
函数加载 MNIST 数据集。 - 将数据集分为特征数据
X
和标签数据y
。
3.2 数据拆分
- 使用
train_test_split
函数将数据集分为训练集和测试集,其中测试集占 20%。
3.3 模型训练
- 使用
SVC
类创建一个 SVM 分类器,指定核函数为径向基函数(rbf)。 - 使用
fit
方法训练模型。
3.4 模型评估
- 使用
predict
方法对测试集进行预测。 - 使用
accuracy_score
计算模型的准确率。 - 使用
confusion_matrix
输出混淆矩阵。
3.5 结果可视化
- 使用
matplotlib
库可视化部分预测结果。
4. 结果展示
- 准确率:
Accuracy: 98.00%
- 混淆矩阵:
Confusion Matrix: [[ 98 0 0 0 0 0 0 1 0 0][ 0 111 1 0 0 0 0 0 1 0][ 0 0 100 0 0 1 0 0 0 0][ 0 0 0 97 0 0 0 0 0 0][ 0 0 0 0 101 0 0 0 0 0][ 0 0 0 0 0 87 0 0 0 3][ 0 0 0 0 0 0 95 0 0 0][ 6 0 0 0 0 0 0 102 0 0][ 0 0 0 0 0 0 0 0 90 0][ 0 0 0 0 0 1 0 0 0 108]]
- 预测结果可视化:
- 图像显示了部分测试集图像及其真实标签和预测标签。
- 应用图参考地址:MNIST handwritten digit classification using Scikit-learn
5. 总结
Python 的 Scikit-learn 库提供了强大的工具来实现手写数字识别。通过使用 SVM 算法和 MNIST 数据集,我们可以快速训练一个准确率较高的模型。在这个案例中,模型的准确率达到了 98.00%。