在这里插入图片描述 手写数字识别是机器学习领域的经典入门项目,MNIST 数据集是一个广泛使用且容易获取的数据集,常用于训练和测试手写数字识别模型。本文将介绍如何使用 Python 和 Scikit-learn 库实现手写数字识别。

1. MNIST 数据集概述

在这里插入图片描述

MNIST 数据集由美国国家标准与技术研究所(NIST)收集,包含 70,000 张 28x28 的灰度图像,分为训练集和测试集。训练集包含 60,000 张图像,用于训练模型;测试集包含 10,000 张图像,用于评估模型的性能。图像被标记为 0 到 9 的数字。

2. 技术实现

我们将使用 Python 的 Scikit-learn 库来实现手写数字识别。Scikit-learn 是一个强大的机器学习库,提供了多种分类算法,包括支持向量机(SVM)、随机森林和决策树等。在本例中,我们使用 SVM 算法来实现手写数字识别。

3. 代码实现

以下是一个完整的代码:

from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split
from sklearn.svm import SVC
from sklearn.metrics import accuracy_score, confusion_matrix
import matplotlib.pyplot as plt# 加载 MNIST 数据集
digits = load_digits()
X, y = digits.data, digits.target# 将数据集分为训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)# 创建 SVM 分类器
svm = SVC(kernel='rbf', C=10, gamma=0.001)# 训练模型
svm.fit(X_train, y_train)# 预测测试集
y_pred = svm.predict(X_test)# 计算准确率
accuracy = accuracy_score(y_test, y_pred)
print(f"Accuracy: {accuracy:.2%}")# 输出混淆矩阵
conf_matrix = confusion_matrix(y_test, y_pred)
print("Confusion Matrix:")
print(conf_matrix)# 可视化预测结果
fig, axes = plt.subplots(4, 4, figsize=(8, 8))
for i, ax in enumerate(axes.flat):ax.imshow(X_test[i].reshape(8, 8), cmap='gray')ax.set_title(f"True: {y_test[i]}\nPredicted: {y_pred[i]}")ax.axis('off')
plt.tight_layout()
plt.show()

3.1 数据加载

  • 使用 load_digits() 函数加载 MNIST 数据集。
  • 将数据集分为特征数据 X 和标签数据 y

3.2 数据拆分

  • 使用 train_test_split 函数将数据集分为训练集和测试集,其中测试集占 20%。

3.3 模型训练

  • 使用 SVC 类创建一个 SVM 分类器,指定核函数为径向基函数(rbf)。
  • 使用 fit 方法训练模型。

3.4 模型评估

  • 使用 predict 方法对测试集进行预测。
  • 使用 accuracy_score 计算模型的准确率。
  • 使用 confusion_matrix 输出混淆矩阵。

3.5 结果可视化

  • 使用 matplotlib 库可视化部分预测结果。

4. 结果展示

  1. 准确率
    Accuracy: 98.00%
    
  2. 混淆矩阵
    Confusion Matrix:
    [[ 98   0   0   0   0   0   0   1   0   0][  0 111   1   0   0   0   0   0   1   0][  0   0 100   0   0   1   0   0   0   0][  0   0   0  97   0   0   0   0   0   0][  0   0   0   0 101   0   0   0   0   0][  0   0   0   0   0  87   0   0   0   3][  0   0   0   0   0   0  95   0   0   0][  6   0   0   0   0   0   0 102   0   0][  0   0   0   0   0   0   0   0  90   0][  0   0   0   0   0   1   0   0   0 108]]
    
  3. 预测结果可视化
    • 图像显示了部分测试集图像及其真实标签和预测标签。
    • 应用图参考地址:MNIST handwritten digit classification using Scikit-learn

5. 总结

Python 的 Scikit-learn 库提供了强大的工具来实现手写数字识别。通过使用 SVM 算法和 MNIST 数据集,我们可以快速训练一个准确率较高的模型。在这个案例中,模型的准确率达到了 98.00%。