国科大深度学习作业1-手写数字识别实验

背景介绍:单位实习,趁机摸鱼,由于电脑只安装了VSCode,所以算是从环境搭建写起。

目录

一、环境搭建

1. 安装Anaconda

2. 创建Python环境

3. 安装PyTorch

4. 安装其他必要库       

二、在 VSCode 中配置环境

1. 安装Python扩展

2. 选择正确的 Python 解释器

三、实验过程

1. 实验代码

2. 运行结果和输出​编辑

四、任务详细解析

1. 需求分析

2. 原理详解

        1)数据预处理

        2)CNN 网络结构分析

       3)CNN 关键设计原理

       4)特征提取原理 

        5)训练优化原理

3. 性能分析与优化

        1)模型复杂度分析

        2)性能优化策略

五、实验报告和讲解PPT


一、环境搭建

1. 安装Anaconda

        1. 访问 Anaconda官网 下载适合你操作系统的版本,这里建议选择 Mimiconda 安装。

        选择 Mimiconda 的原因:
        轻量级: 只包含conda、Python和少量必要包,下载快速

        灵活性: 可以根据需要安装特定的包,避免不必要的软件

        环境管理: 更好的虚拟环境管理功能

        适合深度学习: 对于我们的PyTorch项目来说完全够用

按步骤下载安装,勾选时全部选上,其余按照默认即可。

安装完毕后,打开 cmd 验证是否安装成功,有版本输出即为安装成功。

conda --version
python --version

        2. 安装完成后,打开Anaconda Prompt(Windows)或终端(Mac/Linux)

2. 创建Python环境

# 创建新的conda环境
conda create -n pytorch_env python=3.9
# 激活环境
conda activate pytorch_env

创建之后一定再次检查 python 版本,作业要求最好使用 3.9 版本

3. 安装PyTorch

访问 PyTorch官网 获取安装命令,或直接使用:

  • CPU版本:
pip install torch torchvision torchaudio
  • GPU版本(如果有NVIDIA显卡):此刻经历漫长的等待时间... ...     

pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118

4. 安装其他必要库       

pip install matplotlib numpy jupyter

二、在 VSCode 中配置环境

我安装环境时 已经在 VSCode 中进行了, 所以可以忽略第二步

1. 安装Python扩展

2. 选择正确的 Python 解释器

        在 VSCode 中按 Ctrl+Shift+P

   输入Python: Select Interpreter

        选择 pytorch_env 环境的 Python,例:

C:\Users\nnchen\AppData\Local\miniconda3\envs\pytorch_env\python.exe

三、实验过程

1. 实验代码

# 创建完整的手写数字识别代码 - 英文版图表
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np# 设置设备
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'使用设备: {device}')# 数据预处理
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,), (0.3081,))  # MNIST数据集的均值和标准差
])# 加载MNIST数据集
train_dataset = torchvision.datasets.MNIST(root='./data', train=True, transform=transform, download=True
)test_dataset = torchvision.datasets.MNIST(root='./data', train=False, transform=transform, download=True
)# 创建数据加载器
batch_size = 64
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)print(f'训练集大小: {len(train_dataset)}')
print(f'测试集大小: {len(test_dataset)}')# 定义CNN模型
class CNN(nn.Module):def __init__(self):super(CNN, self).__init__()# 第一个卷积层self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)self.pool1 = nn.MaxPool2d(2, 2)# 第二个卷积层self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)self.pool2 = nn.MaxPool2d(2, 2)# 第三个卷积层self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)# 全连接层self.fc1 = nn.Linear(128 * 7 * 7, 512)self.fc2 = nn.Linear(512, 128)self.fc3 = nn.Linear(128, 10)# Dropout层防止过拟合self.dropout = nn.Dropout(0.5)def forward(self, x):# 卷积层1 + ReLU + 池化x = self.pool1(F.relu(self.conv1(x)))# 卷积层2 + ReLU + 池化x = self.pool2(F.relu(self.conv2(x)))# 卷积层3 + ReLUx = F.relu(self.conv3(x))# 展平x = x.view(-1, 128 * 7 * 7)# 全连接层x = F.relu(self.fc1(x))x = self.dropout(x)x = F.relu(self.fc2(x))x = self.dropout(x)x = self.fc3(x)return x# 创建模型实例
model = CNN().to(device)
print(model)# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)# 训练函数
def train_model(model, train_loader, criterion, optimizer, epochs=10):model.train()train_losses = []train_accuracies = []for epoch in range(epochs):running_loss = 0.0correct = 0total = 0for batch_idx, (data, target) in enumerate(train_loader):data, target = data.to(device), target.to(device)# 梯度清零optimizer.zero_grad()# 前向传播output = model(data)loss = criterion(output, target)# 反向传播loss.backward()optimizer.step()# 统计running_loss += loss.item()_, predicted = torch.max(output.data, 1)total += target.size(0)correct += (predicted == target).sum().item()if batch_idx % 200 == 0:print(f'Epoch [{epoch+1}/{epochs}], Step [{batch_idx}/{len(train_loader)}], 'f'Loss: {loss.item():.4f}')epoch_loss = running_loss / len(train_loader)epoch_acc = 100 * correct / totaltrain_losses.append(epoch_loss)train_accuracies.append(epoch_acc)print(f'Epoch [{epoch+1}/{epochs}] - Loss: {epoch_loss:.4f}, Accuracy: {epoch_acc:.2f}%')return train_losses, train_accuracies# 测试函数
def test_model(model, test_loader):model.eval()correct = 0total = 0with torch.no_grad():for data, target in test_loader:data, target = data.to(device), target.to(device)output = model(data)_, predicted = torch.max(output, 1)total += target.size(0)correct += (predicted == target).sum().item()accuracy = 100 * correct / totalprint(f'测试集准确率: {accuracy:.2f}%')return accuracy# 可视化函数 - 修改为英文标题
def visualize_predictions(model, test_loader, num_images=10):model.eval()images, labels = next(iter(test_loader))images, labels = images.to(device), labels.to(device)with torch.no_grad():outputs = model(images)_, predicted = torch.max(outputs, 1)# 移到CPU进行可视化images = images.cpu()labels = labels.cpu()predicted = predicted.cpu()fig, axes = plt.subplots(2, 5, figsize=(12, 6))for i in range(num_images):ax = axes[i//5, i%5]ax.imshow(images[i].squeeze(), cmap='gray')ax.set_title(f'True: {labels[i]}, Pred: {predicted[i]}')ax.axis('off')# 总标题plt.suptitle('MNIST Digit Recognition Results', fontsize=16)plt.tight_layout()plt.show()# 开始训练
print("开始训练模型...")
train_losses, train_accuracies = train_model(model, train_loader, criterion, optimizer, epochs=10)# 测试模型
print("\n测试模型...")
test_accuracy = test_model(model, test_loader)plt.figure(figsize=(12, 4))plt.subplot(1, 2, 1)
plt.plot(train_losses, 'b-', linewidth=2)
plt.title('Training Loss')  # 改为英文
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.grid(True, alpha=0.3)plt.subplot(1, 2, 2)
plt.plot(train_accuracies, 'g-', linewidth=2)
plt.title('Training Accuracy')  # 改为英文
plt.xlabel('Epoch')
plt.ylabel('Accuracy (%)')
plt.grid(True, alpha=0.3)plt.tight_layout()
plt.show()# 可视化预测结果
print("\n可视化预测结果...")
visualize_predictions(model, test_loader)# 保存模型
torch.save(model.state_dict(), 'mnist_cnn_model.pth')
print("\n模型已保存为 mnist_cnn_model.pth")print(f"\n实验完成!最终测试准确率: {test_accuracy:.2f}%")
if test_accuracy >= 98.0:print("✅ 成功达到98%以上的准确率要求!")
else:print("⚠️ 未达到98%准确率要求,可以尝试调整超参数或增加训练轮数")

2. 运行结果和输出

四、任务详细解析

1. 需求分析

  • 目标:构建一个能够识别手写数字(0-9)的深度学习模型

  • 数据集:MNIST手写数字数据集

  • 期望准确率:≥98%

  • 技术要求:使用卷积神经网络(CNN)实现

# 功能需求
✅ 数据加载与预处理
✅ CNN模型设计与实现
✅ 模型训练与验证
✅ 性能评估与可视化

解决方案

数据加载  →  数据预处理  →  模型设计  →  训练  → 测试  →  可视化
    ↓                        ↓                      ↓                ↓             ↓              ↓          
  MNIST         标准化处理       CNN架构     优化器     评估          图表      

2. 原理详解

        1)数据预处理

        标准化处理:

transforms.Normalize((0.1307,), (0.3081,))

        原理解析:

                均值: 0.1307 是MNIST数据集的全局均值
                标准差: 0.3081 是MNIST数据集的全局标准差
                公式:$x_{\text {normalized }}=\frac{x-\mu}{\sigma}$

        作用机制:

# 原始像素值范围:[0, 255] → [0, 1] (ToTensor)
# 标准化后范围:约[-2, 2],均值≈0,标准差≈1
        2)CNN 网络结构分析
输入: 1×28×28 (灰度图像)↓
Conv1: 1→32, 3×3, padding=1↓ (32×28×28)
MaxPool1: 2×2↓ (32×14×14)
Conv2: 32→64, 3×3, padding=1↓ (64×14×14)
MaxPool2: 2×2↓ (64×7×7)
Conv3: 64→128, 3×3, padding=1↓ (128×7×7)
Flatten: 128×7×7 = 6272↓
FC1: 6272→512↓
Dropout(0.5)↓
FC2: 512→128↓
Dropout(0.5)↓
FC3: 128→10 (输出层)
       3)CNN 关键设计原理

        卷积层设计:

self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
  • kernel_size=3:3×3卷积核,平衡感受野和计算效率
  • padding=1:保持特征图尺寸不变
  • 通道数递增:32→64→128,逐步提取复杂特征

        池化层作用:

self.pool1 = nn.MaxPool2d(2, 2)
  • 降采样:减少参数量和计算量
  • 平移不变性:增强模型鲁棒性
  • 感受野扩大:捕获更大范围特征

        Dropout 机制:

self.dropout = nn.Dropout(0.5)
  • 正则化:随机丢弃50%神经元
  • 防止过拟合:提高泛化能力
  • 集成学习效果:相当于训练多个子网络
       4)特征提取原理 

        层级特征学习:

# 第一层:边缘检测
Conv1 → 检测基本边缘、线条# 第二层:形状组合
Conv2 → 组合边缘形成简单形状# 第三层:复杂模式
Conv3 → 识别数字的复杂模式和结构

        感受野计算:

# 感受野公式:RF = (RF_prev - 1) * stride + kernel_size
Layer 1: RF = 3
Layer 2: RF = (3-1)*2 + 3 = 7
Layer 3: RF = (7-1)*2 + 3 = 15
        5)训练优化原理

        Adam 优化器

optimizer = optim.Adam(model.parameters(), lr=0.001)

        交叉熵损失:

criterion = nn.CrossEntropyLoss()

3. 性能分析与优化

        1)模型复杂度分析

        参数量计算:

Conv1: 1×32×3×3 + 32 = 320
Conv2: 32×64×3×3 + 64 = 18,496
Conv3: 64×128×3×3 + 128 = 73,856
FC1: 6272×512 + 512 = 3,211,776
FC2: 512×128 + 128 = 65,664
FC3: 128×10 + 10 = 1,290总参数量 ≈ 3.37M

        计算复杂度:

# FLOPs (浮点运算次数)
Conv层: O(H×W×C_in×C_out×K²)
FC层: O(N_in×N_out)总FLOPs ≈ 50M (前向传播)
        2)性能优化策略

        数据增强:

# 可添加的数据增强
transforms.Compose([transforms.RandomRotation(10),      # 随机旋转transforms.RandomAffine(0, translate=(0.1, 0.1)),  # 平移transforms.ToTensor(),transforms.Normalize((0.1307,), (0.3081,))
])

        学习率调度:

# 学习率衰减
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)

        早停机制:

# 防止过拟合
if val_loss > best_val_loss:patience_counter += 1if patience_counter >= patience:break

五、实验报告和讲解PPT

想看?不给,嘿嘿嘿


本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.tpcf.cn/diannao/89629.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

基于Spring Boot的绿园社区团购系统的设计与实现

第1章 摘 要 本设计与实现的基于Spring Boot的绿园社区团购系统,旨在为社区居民提供一套高效、便捷的团购购物解决方案。随着电子商务的发展和社区居民对便捷购物需求的增加,传统的团购模式已无法满足用户的个性化需求。本系统通过整合现代化技术&…

【51单片机四位数码管从0循环显示到99,每0.5秒增加一个数字,打击键计数】2022-6-11

缘由 #include "REG52.h" unsigned char code smgduan[]{0x3f,0x06,0x5b,0x4f,0x66,0x6d,0x7d,0x07,0x7f,0x6f,0x77,0x7c,0x39,0x5e,0x79,0x71,0,64,15,56}; //共阴0~F消隐减号 unsigned char Js0, miao0;//中断计时 秒 分 时 毫秒 unsigned int shu0; //bit Mb0;//…

如何通过python脚本向redis和mongoDB传点位数据

向MongoDB传数据 from pymongo import MongoClient #导入库对应的库localhost "172.16.0.203" #数据库IP地址 baseName "GreenNagoya" client MongoClient(localhost, 27017, username"admin", password"zdiai123") #数…

昆仑通泰触摸屏Modbus TCP服务器工程 || TCP客户端工程

目录 一、Modbus TCP服务端 1.设备地址 2.实操及数据 二、Modbus TCP客户端 1.结果及协议解析 一、Modbus TCP服务端 1.设备地址 --单元标识符 DI输入/4个离散输入 DO输出/单个线圈输出 输入寄存器 读输入寄存器操作,写输入寄存器操作 保持寄存器 …

PyTorch 安装使用教程

一、PyTorch 简介 PyTorch 是由 Facebook AI Research 团队开发的开源深度学习框架。它以动态图机制、灵活性强、易于调试而著称,广泛应用于自然语言处理、计算机视觉和学术研究。 二、安装 PyTorch 2.1 通过官网选择安装命令(推荐) 访问官…

开源功能开关(feature flags) 和管理平台之unleash

文章目录 背景Flagsmith 和 Unleash什么是unleash架构Unleash Edge 安装和使用Unleash SDKs开放API Tokens访问**Server-side SDK (CLIENT)****查询所有 Feature Toggles****查询特定 Toggle** API token typesClient tokensFrontend tokensPersonal access tokensService acco…

细胞建模“图灵测试”:解析学习虚拟细胞挑战赛

一、AI能否预测细胞的未来? 想象一下,有一天我们不必一管管地做实验,就能在计算机中模拟细胞对基因敲除、药物处理乃至微环境变化的反应。这不再是科幻,而是“虚拟细胞”(Virtual Cell)研究的宏大目标。然…

centos9安装docker Dify

CentOS | Docker Docs yum -y install gcc gcc-c yum-utils Docker 官方的 YUM 软件仓库配置文件到系统,设置存储库 yum-config-manager --add-repo https://download.docker.com/linux/centos/docker-ce.repo 也可以从阿里云下(我选择上面的) yum-config-manager --add-re…

基于Jenkins和Kubernetes构建DevOps自动化运维管理平台

目录 引言 基础概念 DevOps概述 Jenkins简介 Kubernetes简介 Jenkins与Kubernetes的关系 Jenkins与Kubernetes的集成 集成架构 安装和配置 安装Jenkins 安装Kubernetes插件 配置Kubernetes连接 配置Jenkins Agent Jenkins Pipeline与Kubernetes集成 Pipeline定义…

MySQL 8.0 OCP 1Z0-908 题目解析(18)

题目69 Choose three. A MySQL server is monitored using MySQL Enterprise Monitor’s agentless installation. Which three features are available with this installation method? □ A) MySQL Replication monitoring □ B) security-related advisor warnings □ …

【mongodb】安装和使用mongod

文章目录 前言一、如何安装?二、使用步骤1. 开启mongod服务2. 客户端连接数据库3. 数据库指令 总结 前言 Mongodb的安装可以直接安装系统默认的版本,也可以安装官网维护的版本,相对而言更推荐安装官网维护的版本,版本也相当更新。…

云效DevOps vs Gitee vs 自建GitLab的技术选型

针对「云效DevOps vs Gitee vs 自建GitLab」的技术选型,我们从核心需求、成本、运维、扩展性四个维度进行深度对比,并给出场景化决策建议: 一、核心能力对比表 能力维度云效DevOpsGitee自建GitLab(社区版/企业版)代码…

CentOS 7 安装RabbitMQ详细教程

前言:在分布式系统架构中,消息队列作为数据流转的 “高速公路”,是微服务架构不可或缺的核心组件。RabbitMQ 凭借其稳定的性能、灵活的路由机制和强大的生态支持,成为企业级消息中间件的首选之一。不过,当我们聚焦 Cen…

Python爬虫用途和介绍

目录 什么是Python爬虫 Python爬虫用途 Python爬虫可以获得那些数据 Python爬虫的用途 反爬是什么 常见的反爬措施 Python爬虫技术模块总结 获取网站的原始响应数据 获取到响应数据对响应数据进行过滤 对收集好的数据进行存储 抵御反爬机制 Python爬虫框架 Python…

uni-app开发app保持登录状态

在 uni-app 中实现用户登录一次后在 token 过期前一直免登录的功能,可以通过以下几个关键步骤实现:本地持久化存储 Token、使用请求与响应拦截器自动处理 Token 刷新、以及在 App.vue 中结合 pages.json 设置登录状态跳转逻辑。 ✅ 一、pages.json 配置说…

21、MQ常见问题梳理

目录 ⼀ 、MQ如何保证消息不丢失 1 、哪些环节可能会丢消息 2 、⽣产者发送消息如何保证不丢失 2.1、⽣产者发送消息确认机制 2.2、Rocket MQ的事务消息机制 2.3 、Broker写⼊数据如何保证不丢失 2.3.1** ⾸先需要理解操作系统是如何把消息写⼊到磁盘的**。 2.3.2然后来…

MySQL数据库--SQL DDL语句

SQL--DDL语句 1,DDL-数据库操作2,DDL-表操作-查询3,DDL-表操作-创建4,DDL-表操作-数据类型4.1,DDL-表操作-数值类型4.2,DDL-表操作-字符串类型4.3,DDL-表操作-日期时间类型4.4,实例 …

Spring Cloud 服务追踪实战:使用 Zipkin 构建分布式链路追踪

Spring Cloud 服务追踪实战:使用 Zipkin 构建分布式链路追踪 在分布式微服务架构中,一个用户请求往往需要经过多个服务协作完成,如果出现性能瓶颈或异常,排查会非常困难。此时,分布式链路追踪(Distributed…

Linux云计算基础篇(6)

一、IO重定向和管道 stdin:standard input 标准输入 stdout:standard output 标准输出 stderr: standard error 标准错误输出 举例 find /etc/ -name passwd > find.out 将正确的输出重定向在这个find.ou…

Python将COCO格式分割标签绘制到对应的图片上

Python将COCO格式分割标签绘制到对应的图片上 前言前提条件相关介绍COCO 格式简介(实例分割)📁 主要目录结构:📄 JSON 标注文件结构示例:✅ 特点: 实验环境Python将COCO格式分割标签绘制到对应的…