深度学习——基于卷积神经网络实现食物图像分类【4】(使用最优模型)

文件目录

    • 引言
    • 一、环境准备
    • 二、数据预处理
      • 训练集预处理说明:
      • 验证集预处理说明:
    • 三、自定义数据集类
    • 四、设备选择
    • 五、CNN模型构建
    • 六、模型加载与评估
      • 1. 加载预训练模型
      • 2. 准备测试数据
      • 3. 测试函数
      • 4. 计算准确率
    • 七、完整代码
    • 八、总结

引言

本文将详细介绍如何使用PyTorch框架构建一个完整的食物图像分类系统,包含数据预处理、模型构建、训练优化以及模型保存等关键环节。与上一篇博客介绍的版本相比,本版本增加了使用最优模型这一流程。

一、环境准备

首先,我们需要导入必要的Python库:

import torch
import torchvision.models as models
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import numpy as np
import os

这些库中:

  • torchtorchvision是PyTorch的核心库
  • DatasetDataLoader用于数据加载和处理
  • transforms提供图像预处理功能
  • PIL用于图像处理
  • numpy用于数值计算

二、数据预处理

数据预处理是深度学习项目中至关重要的一环。PyTorch提供了transforms模块来方便地进行图像预处理:

data_transforms = {'train': transforms.Compose([transforms.Resize([300,300]),transforms.RandomRotation(45),transforms.CenterCrop(256),transforms.RandomHorizontalFlip(p=0.5),transforms.RandomVerticalFlip(p=0.5),transforms.ColorJitter(brightness=0.2, contrast=0.1, saturation=0.1, hue=0.1),transforms.RandomGrayscale(p=0.1),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),'valid': transforms.Compose([transforms.Resize([256,256]),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),
}

训练集预处理说明:

  1. Resize([300,300]):将图像调整为300×300像素
  2. RandomRotation(45):随机旋转图像(-45°到45°之间)
  3. CenterCrop(256):从中心裁剪256×256的区域
  4. RandomHorizontalFlip(p=0.5):以50%概率水平翻转图像
  5. RandomVerticalFlip(p=0.5):以50%概率垂直翻转图像
  6. ColorJitter:随机调整亮度、对比度、饱和度和色调
  7. RandomGrayscale(p=0.1):以10%概率将图像转为灰度
  8. ToTensor():将PIL图像转为PyTorch张量
  9. Normalize:标准化处理(使用ImageNet的均值和标准差)

验证集预处理说明:

验证集的预处理相对简单,只包括调整大小、转为张量和标准化,因为验证阶段不需要数据增强。

三、自定义数据集类

PyTorch的Dataset类允许我们自定义数据加载方式。我们创建了一个food_dataset类:

class food_dataset(Dataset):def __init__(self, file_path, transform=None):self.file_path = file_pathself.imgs = []self.labels = []self.transform = transformwith open(self.file_path) as f:samples = [x.strip().split(' ') for x in f.readlines()]for img_path, label in samples:self.imgs.append(img_path)self.labels.append(label)def __len__(self):return len(self.imgs)def __getitem__(self, idx):image = Image.open(self.imgs[idx])if self.transform:image = self.transform(image)label = self.labels[idx]label = torch.from_numpy(np.array(label, dtype=np.int64))return image, label

这个类的主要功能:

  1. __init__:初始化函数,读取包含图像路径和标签的文本文件
  2. __len__:返回数据集大小
  3. __getitem__:根据索引返回图像和对应的标签

四、设备选择

PyTorch支持在CPU、GPU(CUDA)和苹果M系列芯片(MPS)上运行。我们使用以下代码自动选择可用设备:

device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
print(f"Using {device} device")

五、CNN模型构建

我们构建了一个简单的CNN模型,包含三个卷积块和一个全连接层:

class CNN(nn.Module):def __init__(self):super(CNN, self).__init__()self.conv1 = nn.Sequential(nn.Conv2d(3, 16, 5, 1, 2),nn.ReLU(),nn.MaxPool2d(kernel_size=2),)self.conv2 = nn.Sequential(nn.Conv2d(16, 32, 5, 1, 2),nn.ReLU(),nn.MaxPool2d(kernel_size=2),)self.conv3 = nn.Sequential(nn.Conv2d(32, 64, 5, 1, 2),nn.ReLU(),nn.MaxPool2d(kernel_size=2),)self.out = nn.Linear(64*32*32, 20)def forward(self, x):x = self.conv1(x)x = self.conv2(x)x = self.conv3(x)x = x.view(x.size(0), -1)output = self.out(x)return output

模型结构说明:

  1. conv1:输入3通道,输出16通道,5×5卷积核,ReLU激活,2×2最大池化
  2. conv2:输入16通道,输出32通道,同上结构
  3. conv3:输入32通道,输出64通道,同上结构
  4. out:全连接层,将64×32×32的特征图映射到20个类别

六、模型加载与评估

1. 加载预训练模型

model = CNN().to(device)
model.load_state_dict(torch.load("best2025-04.pth"))
model.eval()

2. 准备测试数据

test_data = food_dataset(file_path='test.txt', transform=data_transforms['valid'])
test_dataloader = DataLoader(test_data, batch_size=1, shuffle=True)

3. 测试函数

result = []
labels = []def Test_true(dataloader, model):model.eval()with torch.no_grad():for X, y in dataloader:X, y = X.to(device), y.to(device)pred = model.forward(X)result.append(pred.argmax(1).item())labels.append(y.item())Test_true(test_dataloader, model)

4. 计算准确率

from sklearn.metrics import accuracy_score
accuracy = accuracy_score(labels, result)
print(f"准确率:{accuracy:.2%}")

七、完整代码

import torch
import torchvision.models as models
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import numpy as np
import osdata_transforms = { #字典'train':transforms.Compose([            #对图片预处理的组合transforms.Resize([300,300]),   #对数据进行改变大小transforms.RandomRotation(45),  #随机旋转,-45到45之间随机选transforms.CenterCrop(256),     #从中心开始裁剪[256,256]transforms.RandomHorizontalFlip(p=0.5),#随机水平翻转,p是指选择一个概率翻转,p=0.5表示百分之50transforms.RandomVerticalFlip(p=0.5),#随机垂直翻转transforms.ColorJitter(brightness=0.2,contrast=0.1,saturation=0.1,hue=0.1),transforms.RandomGrayscale(p=0.1),#概率转换成灰度率,3通道就是R=G=Btransforms.ToTensor(),#数据转换为tensortransforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])#标准化,均值,标准差]),'valid':transforms.Compose([transforms.Resize([256,256]),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])  # 标准化,均值,标准差]),
}#Dataset是用来处理数据的
class food_dataset(Dataset):        # food_dataset是自己创建的类名称,可以改为你需要的名称def __init__(self,file_path,transform=None):    #类的初始化,解析数据文件txtself.file_path = file_pathself.imgs = []self.labels = []self.transform = transformwith open(self.file_path) as f: #是把train.txt文件中的图片路径保存在self.imgssamples = [x.strip().split(' ') for x in f.readlines()]for img_path,label in samples:self.imgs.append(img_path)  #图像的路径self.labels.append(label)   #标签,还不是tensor# 初始化:把图片目录加到selfdef __len__(self):  #类实例化对象后,可以使用len函数测量对象的个数return  len(self.imgs)#training_data[1]def __getitem__(self, idx):    #关键,可通过索引的形式获取每一个图片的数据及标签image = Image.open(self.imgs[idx])  #读取到图片数据,还不是tensor,BGRif self.transform:                  #将PIL图像数据转换为tensorimage = self.transform(image)   #图像处理为256*256,转换为tensorlabel = self.labels[idx]    #label还不是tensorlabel = torch.from_numpy(np.array(label,dtype=np.int64))    #label也转换为tensorreturn image,label'''判断当前设备是否支持GPU,其中mps是苹果m系列芯片的GPU'''
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
print(f"Using {device} device")   #字符串的格式化,CUDA驱动软件的功能:pytorch能够去执行cuda的命令
# 神经网络的模型也需要传入到GPU,1个batch_size的数据集也需要传入到GPU,才可以进行训练''' 定义神经网络  类的继承这种方式'''
class CNN(nn.Module): #通过调用类的形式来使用神经网络,神经网络的模型,nn.mdouledef __init__(self): #输入大小:(3,256,256)super(CNN,self).__init__()  #初始化父类self.conv1 = nn.Sequential( #将多个层组合成一起,创建了一个容器,将多个网络组合在一起nn.Conv2d(              # 2d一般用于图像,3d用于视频数据(多一个时间维度),1d一般用于结构化的序列数据in_channels=3,      # 图像通道个数,1表示灰度图(确定了卷积核 组中的个数)out_channels=16,     # 要得到多少个特征图,卷积核的个数kernel_size=5,      # 卷积核大小 3×3stride=1,           # 步长padding=2,          # 一般希望卷积核处理后的结果大小与处理前的数据大小相同,效果会比较好),                      # 输出的特征图为(16,256,256)nn.ReLU(),  # Relu层,不会改变特征图的大小nn.MaxPool2d(kernel_size=2),    # 进行池化操作(2×2操作),输出结果为(16,128,128))self.conv2 = nn.Sequential(nn.Conv2d(16,32,5,1,2),  #输出(32,128,128)nn.ReLU(),  #Relu层  (32,128,128)nn.MaxPool2d(kernel_size=2),    #池化层,输出结果为(32,64,64))self.conv3 = nn.Sequential(nn.Conv2d(32, 64, 5, 1, 2),  # 输出(64,64,64)nn.ReLU(),  # Relu层  (64,64,64)nn.MaxPool2d(kernel_size=2),  # 池化层,输出结果为(64,32,32))self.out = nn.Linear(64*32*32,20)  # 全连接层得到的结果def forward(self,x):   #前向传播,你得告诉它 数据的流向 是神经网络层连接起来,函数名称不能改x = self.conv1(x)x = self.conv2(x)x = self.conv3(x)x = x.view(x.size(0),-1)    # flatten操作,结果为:(batch_size,32 * 64 * 64)output = self.out(x)return output
# 提取模型的2种方法:
#   1、读取参数的方法
model = CNN().to(device) #初始化模型,w都是随机初始化的
model.load_state_dict(torch.load("best2025-04.pth"))
#   2、读取完整模型的方法,无需提前创建model
#   model = CNN().to(device)
#   model = torch.load('best.pt')#w,b,cnn
# 模型保存的对不对?
model.eval() #固定模型参数和数据,防止后面被修改
print(model)test_data = food_dataset(file_path='test.txt', transform = data_transforms['valid'])
test_dataloader = DataLoader(test_data,batch_size=1,shuffle=True)result = [] #保存的预测的结果
labels = [] #真实结果def Test_true(dataloader,model):model.eval()        #测试,w就不能再更新with torch.no_grad():   #一个上下文管理器,关闭梯度计算。当你确认不会调用Tensor.backward()的时候for X,y in dataloader:X,y = X.to(device),y.to(device)pred = model.forward(X) #预测之后的结果result.append(pred.argmax(1).item())labels.append(y.item())
Test_true(test_dataloader,model)
print('预测值:\t',result)
print('真实值:\t',labels)from sklearn.metrics import accuracy_score
accuracy = accuracy_score(labels,result)
print(f"准确率:{accuracy:.2%}")

八、总结

本文详细介绍了使用PyTorch实现图像分类任务的完整流程,包括:

  1. 数据预处理与增强
  2. 自定义数据集类
  3. CNN模型构建
  4. 模型加载与评估

关键点:

  • 数据增强可以提高模型的泛化能力
  • 自定义Dataset类可以灵活处理不同格式的数据
  • CNN是图像分类任务的经典模型结构
  • 模型评估需要使用eval()模式和torch.no_grad()上下文

通过这个示例,读者可以掌握PyTorch进行图像分类的基本方法,并可以根据自己的需求调整模型结构和数据处理方式。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.tpcf.cn/bicheng/85645.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

C++基础算法————并查集

C++并查集详解与实战指南 一、引言 并查集(Union-Find)是一种高效的数据结构,用于处理一些不相交集合的合并与查询问题。它在图论、社交网络、网络连通性等领域有广泛的应用。并查集的核心思想是通过一个数组来记录每个元素的父节点,从而将元素组织成若干棵树,每棵树代表…

系统性能优化的关键手段

系统性能的提升方向 服务器并发处理能力:通过优化内存管理策略、选择合适的连接模式(长连接或短连接)、改进 I/O 模型(如 epoll、IOCP)、以及采用高效的服务器并发策略(如多线程、事件驱动等)&a…

httpclient实现http连接池

HTTP连接池是一种优化网络通信性能的技术,通过复用已建立的TCP连接减少重复握手开销,提升资源利用率。以下是关键要点: 核心原理与优势 ‌连接复用机制‌ 维护活跃连接队列,避免每次请求重复TCP三次握手/SSL协商,降低…

广义焦点丢失:学习用于密集目标检测的合格和分布式边界盒之GFL论文阅读

摘要 一阶段检测器通常将目标检测形式化为密集的分类与定位(即边界框回归)问题。分类部分通常使用 Focal Loss 进行优化,而边界框位置则在狄拉克δ分布下进行学习。最近,一阶段检测器的发展趋势是引入独立的预测分支来估计定位质量,所预测的质量可以辅助分类,从而提升检…

Real-World Deep Local Motion Deblurring论文阅读

Real-World Deep Local Motion Deblurring 1. 研究目标与实际问题意义1.1 研究目标1.2 实际问题1.3 产业意义2. 创新方法:LBAG模型与关键技术2.1 整体架构设计2.2 关键技术细节2.2.1 真实模糊掩码生成(LBFMG)2.2.2 门控块(Gate Block)2.2.3 模糊感知补丁裁剪(BAPC)2.3 损…

【Docker基础】Docker镜像管理:docker commit详解

目录 引言 1 docker commit命令概述 1.1 什么是docker commit 1.2 使用场景 1.3 优缺点分析 2 docker commit命令详解 2.1 基本语法 2.2 常用参数选项 2.3 实际命令示例 2.4 提交流程 2.5 步骤描述 3 docker commit与Dockerfile构建对比 3.1 构建流程对比 3.2 对…

可调式稳压二极管

1.与普通稳压二极管的比较: 项目普通稳压二极管可调式稳压二极管(如 TL431)输出电压固定(如5.1V、3.3V)可调(2.5V ~ 36V,取决于外部分压)精度低(5%~10%)高&a…

Kafka使用Elasticsearch Service Sink Connector直接传输topic数据到Elasticsearch

链接:Elasticsearch Service Sink Connector for Confluent Platform | Confluent Documentation 链接:Apache Kafka 一、搭建测试环境 下载Elasticsearch Service Sink Connector https://file.zjwlyy.cn/confluentinc-kafka-connect-elasticsearch…

讯方“教学有方”平台获华为昇腾应用开发技术认证!

教学有方 华为昇腾应用开发技术认证 权威认证 彰显实力 近日,讯方技术自研的教育行业大模型平台——“教学有方”,成功获得华为昇腾应用开发技术认证。这一认证不仅是对 “教学有方” 平台技术实力的高度认可,更标志着讯方在智慧教育领域的…

保护你的Electron应用:深度解析asar文件与Virbox Protector的安全策略

在现代软件开发中,Electron框架因其跨平台特性而备受开发者青睐。然而,随着Electron应用的普及,如何保护应用中的核心资源文件——asar文件,成为了开发者必须面对的问题。今天,我们将深入探讨asar文件的特性&#xff0…

端口安全配置示例

组网需求 如图所示,用户PC1、PC2、PC3通过接入设备连接公司网络。为了提高用户接入的安全性,将接入设备Router的接口使能端口安全功能,并且设置接口学习MAC地址数的上限为接入用户数,这样其他外来人员使用自己带来的PC无法访问公…

零基础RT-thread第四节:电容按键

电容按键 其实只需要理解,手指按上去后充电时间变长,我们可以利用定时器输入捕获功能计算充电时间,超过无触摸时的充电时间一定的阈值就认为是有手指触摸。 基本原理就是这样,我们开始写代码: 其实,看过了…

SQL基础操作:从增删改查开始

好的!SQL(Structured Query Language)是用于管理关系型数据库的标准语言。让我们从最基础的增删改查(CRUD)​​ 操作开始学习,我会用简单易懂的方式讲解每个操作。 🛠 准备工作(建表…

vim 编辑模式/命令模式/视图模式常用命令

以下是一份 Vim 命令大全,涵盖 编辑模式(Insert Mode)、命令模式(Normal Mode) 和 视图模式(Visual Mode) 的常用操作,适合初学者和进阶用户使用。 🧾 Vim 模式简介 Vim…

每天看一个Fortran文件(10)

今天来看下MCV模式调用物理过程的相关代码。我想改进有关于海气边界层方面的内容,因此我寻找相关的代码,发现在physics目录下有一个sfc_ocean.f的文件。 可以看见这个文件是在好多好多年前更新的了,里面内容不多,总共146行。是计算…

python打卡day37

疏锦行 知识点回顾: 1. 过拟合的判断:测试集和训练集同步打印指标 2. 模型的保存和加载 a. 仅保存权重 b. 保存权重和模型 c. 保存全部信息checkpoint,还包含训练状态 3. 早停策略 作业:对信贷数据集训练后保存权重&#xf…

【Spark征服之路-2.9-Spark-Core编程(五)】

RDD行动算子: 行动算子就是会触发action的算子,触发action的含义就是真正的计算数据。 1. reduce ➢ 函数签名 def reduce(f: (T, T) > T): T ➢ 函数说明 聚集 RDD 中的所有元素,先聚合分区内数据,再聚合分区间数据 val…

【入门】【练17.3 】比大小

| 时间限制:C/C 1000MS,其他语言 2000MS 内存限制:C/C 64MB,其他语言 128MB 难度:中等 分数:100 OI排行榜得分:12(0.1分数2难度) 出题人:root | 描述 试编一个程序,输入…

CppCon 2017 学习:Free Your Functions!

“Free Your Functions!” 这句话在C设计中有很深的含义,意思是: “Free Your Functions!” 的理解 “解放你的函数”,鼓励程序员: 不要把所有的函数都绑在类的成员函数里,优先考虑写成自由函数(non-mem…

日常运维问题汇总-19

60. OVF3维护成本中心与订货原因之间的对应关系时,报错提示,SYST: 不期望的日期 00/00/0000。消息号 FGV004,如下图所示: OVF3往右边拉动,有一个需要填入的字段“有效期自”,此字段值必须在成本中心定义的有…