DAY 50 预训练模型+CBAM模块

@浙大疏锦行https://blog.csdn.net/weixin_45655710

知识点回顾:

  1. resnet结构解析
  2. CBAM放置位置的思考
  3. 针对预训练模型的训练策略
    1. 差异化学习率
    2. 三阶段微调

作业:

  1. 好好理解下resnet18的模型结构
  2. 尝试对vgg16+cbam进行微调策略
ResNet-18 结构核心思想

可以将ResNet-18想象成一个高效的“图像信息处理流水线”,它分为三个核心部分

  1. “开胃菜” - 输入预处理 (Stem)

    • 组成:一个大的7x7卷积层 (conv1) + 一个最大池化层 (maxpool)。

    • 作用:对输入的原始大尺寸图像(如224x224)进行一次快速、大刀阔斧的特征提取和尺寸压缩。它迅速将图像尺寸减小到56x56,为后续更精细的处理做好准备,像是一道开胃菜,快速打开味蕾。

  2. “主菜” - 四组残差块 (Layer1, 2, 3, 4)

    • 组成:这是ResNet的心脏,由四组Sequential模块构成,每组里面包含2个BasicBlock(残差块)。

    • 作用:这是真正进行深度特征提取的地方。其最精妙的设计在于:

      • 层级递进:从layer1layer4,特征图的空间尺寸逐级减半(56→28→14→7),而通道数逐级翻倍(64→128→256→512)。这实现了“牺牲空间细节,换取更高层语义信息”的经典策略。

      • 残差连接:每个BasicBlock内部的“跳跃连接”(out += identity)是其灵魂。它允许信息和梯度“抄近道”,直接从块的输入流向输出,完美解决了深度网络中因信息丢失导致的“网络退化”和梯度消失问题。

  3. “甜点” - 分类头 (Head)

    • 组成:一个全局平均池化层 (avgpool) + 一个全连接层 (fc)。

    • 作用

      • avgpool:将layer4输出的512x7x7的复杂特征图,暴力压缩成一个512维的特征向量,浓缩了整张图最高级的语义信息。

      • fc:扮演最终“裁判”的角色,将这个512维的特征向量映射到最终的类别得分上(例如,ImageNet的1000类)。

总结来说,ResNet-18的优雅之处在于其清晰的模块化设计和革命性的残差连接,它通过“尺寸减半,通道加倍”的策略逐层加深语义理解,并利用“跳跃连接”保证了信息流的畅通,从而能够构建出既深又易于训练的强大网络。

对VGG16 + CBAM 进行微调

VGG16以其结构统一、简单(全是3x3卷积和2x2池化)而著称,但缺点是参数量巨大。我们将为其集成CBAM,并应用类似的分阶段微调策略。

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader
import time
from tqdm import tqdm# --- 模块定义 (CBAM 和数据加载器,与之前一致) ---
class ChannelAttention(nn.Module):def __init__(self, in_channels, ratio=16):super().__init__()self.avg_pool = nn.AdaptiveAvgPool2d(1)self.max_pool = nn.AdaptiveMaxPool2d(1)self.fc = nn.Sequential(nn.Linear(in_channels, in_channels // ratio, bias=False), nn.ReLU(),nn.Linear(in_channels // ratio, in_channels, bias=False))self.sigmoid = nn.Sigmoid()def forward(self, x):b, c, _, _ = x.shapeavg_out = self.fc(self.avg_pool(x).view(b, c))max_out = self.fc(self.max_pool(x).view(b, c))attention = self.sigmoid(avg_out + max_out).view(b, c, 1, 1)return x * attentionclass SpatialAttention(nn.Module):def __init__(self, kernel_size=7):super().__init__()self.conv = nn.Conv2d(2, 1, kernel_size, padding=kernel_size//2, bias=False)self.sigmoid = nn.Sigmoid()def forward(self, x):avg_out = torch.mean(x, dim=1, keepdim=True)max_out, _ = torch.max(x, dim=1, keepdim=True)pool_out = torch.cat([avg_out, max_out], dim=1)attention = self.conv(pool_out)return x * self.sigmoid(attention)class CBAM(nn.Module):def __init__(self, in_channels, ratio=16, kernel_size=7):super().__init__()self.channel_attn = ChannelAttention(in_channels, ratio)self.spatial_attn = SpatialAttention(kernel_size)def forward(self, x):return self.spatial_attn(self.channel_attn(x))def get_cifar10_loaders(batch_size=64, resize_to=224): # VGG需要224x224输入print(f"--- 正在准备数据 (图像将缩放至 {resize_to}x{resize_to}) ---")transform = transforms.Compose([transforms.Resize(resize_to),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)test_dataset = datasets.CIFAR10(root='./data', train=False, transform=transform)train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2, pin_memory=True)test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=2, pin_memory=True)print("✅ 数据加载器准备完成。")return train_loader, test_loader# --- 新增:VGG16 + CBAM 模型定义 ---
class VGG16_CBAM(nn.Module):def __init__(self, num_classes=10, pretrained=True):super().__init__()# 加载预训练的VGG16的特征提取部分vgg_features = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1 if pretrained else None).features# 我们将VGG的特征提取层按池化层分割,并在每个块后插入CBAMself.features = nn.ModuleList()self.cbam_modules = nn.ModuleList()current_channels = 3vgg_block = []for layer in vgg_features:vgg_block.append(layer)if isinstance(layer, nn.Conv2d):current_channels = layer.out_channelsif isinstance(layer, nn.MaxPool2d):self.features.append(nn.Sequential(*vgg_block))self.cbam_modules.append(CBAM(current_channels))vgg_block = [] # 开始新的块# VGG的分类器部分self.avgpool = nn.AdaptiveAvgPool2d((7, 7))self.classifier = nn.Sequential(nn.Linear(512 * 7 * 7, 4096), nn.ReLU(True), nn.Dropout(),nn.Linear(4096, 4096), nn.ReLU(True), nn.Dropout(),nn.Linear(4096, num_classes),)def forward(self, x):for feature_block, cbam_module in zip(self.features, self.cbam_modules):x = feature_block(x)x = cbam_module(x)x = self.avgpool(x)x = torch.flatten(x, 1)x = self.classifier(x)return x# --- 训练和评估框架 (复用) ---
def run_experiment(model_name, model, device, train_loader, test_loader, epochs):print(f"\n{'='*25} 开始实验: {model_name} {'='*25}")model.to(device)total_params = sum(p.numel() for p in model.parameters())print(f"模型总参数量: {total_params / 1e6:.2f}M")criterion = nn.CrossEntropyLoss()# 差异化学习率:为不同的部分设置不同的学习率optimizer = optim.Adam([{'params': model.features.parameters(), 'lr': 1e-5}, # 特征提取层使用极低学习率{'params': model.cbam_modules.parameters(), 'lr': 1e-4}, # CBAM模块使用中等学习率{'params': model.classifier.parameters(), 'lr': 1e-3} # 分类头使用较高学习率])for epoch in range(1, epochs + 1):model.train()loop = tqdm(train_loader, desc=f"Epoch [{epoch}/{epochs}] Training", leave=False)for data, target in loop:data, target = data.to(device), target.to(device)optimizer.zero_grad()output = model(data)loss = criterion(output, target)loss.backward()optimizer.step()loop.set_postfix(loss=loss.item())loop.close()model.eval()test_loss, correct = 0, 0with torch.no_grad():for data, target in test_loader:data, target = data.to(device), target.to(device)output = model(data)test_loss += criterion(output, target).item() * data.size(0)pred = output.argmax(dim=1)correct += pred.eq(target).sum().item()avg_test_loss = test_loss / len(test_loader.dataset)accuracy = 100. * correct / len(test_loader.dataset)print(f"Epoch {epoch} 完成 | 测试集损失: {avg_test_loss:.4f} | 测试集准确率: {accuracy:.2f}%")# --- 主执行流程 ---
if __name__ == "__main__":DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")EPOCHS = 10 # 仅作演示,VGG需要更多轮次BATCH_SIZE = 32 # VGG参数量大,减小batch size防止显存溢出train_loader, test_loader = get_cifar10_loaders(batch_size=BATCH_SIZE)vgg_cbam_model = VGG16_CBAM()run_experiment("VGG16+CBAM", vgg_cbam_model, DEVICE, train_loader, test_loader, EPOCHS)
VGG16+CBAM 微调策略解析
  1. 模型修改 (VGG16_CBAM)

    • 拆分与重组:VGG16的预训练模型中,特征提取部分model.features是一个包含所有卷积和池化层的nn.Sequential。我们不能直接在中间插入CBAM。因此,我们遍历了vgg_features中的所有层,以MaxPool2d为界,将它们拆分成了5个卷积块。

    • 插入CBAM:在每个卷积块之后,我们都插入了一个对应通道数的CBAM模块。

    • 保留分类头:原始的model.classifier(全连接层)被保留,只修改最后一层以适应CIFAR-10的10个类别。

  2. 数据预处理适配

    • VGG16在ImageNet上预训练时,接收的是224x224的图像。为了最大化利用预训练权重,我们在get_cifar10_loaders函数中,通过transforms.Resize(224)将CIFAR-10的32x32图像放大224x224

  3. 训练策略:差异化学习率

    • 由于VGG16的参数量巨大(超过1.3亿),如果全局使用相同的学习率进行微调,很容易破坏已经学得很好的预训练权重。

    • 我们采用了一种更精细的差异化学习率 (Differential Learning Rates) 策略:

      • 特征提取层 (model.features):这些是“资深专家”,权重已经很好了,我们给一个极低的学习率1e-5),让它们只做微小的调整。

      • CBAM模块 (model.cbam_modules):这些是新加入的“顾问”,需要学习,但不能太激进,给一个中等学习率1e-4)。

      • 分类头 (model.classifier):这是完全为新任务定制的“新员工”,需要从头快速学习,给一个较高的学习率1e-3)。

    • 这种策略通过optim.Adam接收一个参数组列表来实现,是微调大型模型时非常有效且常用的高级技巧。

  4. Batch Size调整
    批次大小调整

    • VGG16的参数量和中间激活值都非常大,对显存的消耗远超ResNet18。因此,我们将BATCH_SIZE减小到32,以防止显存溢出(OOM)错误。

通过这个实验,不仅能实践如何将注意力模块集成到一个全新的经典网络(VGG16)中,还能学习到微调大型模型时更高级、更精细的训练策略,如差异化学习率。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.tpcf.cn/news/913252.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

docker连接mysql

查看在运行的容器:docker ps -s 进入容器:docker exec -it 容器号或名 /bin/bash,如:docker exec -it c04c438ff177 /bin/bash 或docker exec -it mysql /bin/bash。 3. 登录mysql:mysql -uroot -p123456

javaweb第182节Linux概述~ 虚拟机连接不上FinalShell

问题描述 虚拟机无法连接到finalshell 报错 session.connect:java.net.socketexception:connection reset 或者 connection is closed by foreign host 解决 我经过一系列的排查,花费了一天的时间后,发现,只是因为,我将连接…

高压电缆护层安全的智能防线:TLKS-PLGD 监控设备深度解析

在现代电力系统庞大复杂的网络中,高压电缆护层是守护电力传输的 "隐形铠甲",其安全直接影响电网稳定。传统监测手段响应慢、精度低,难以满足安全运维需求。TLKS-PLGD 高压电缆护层环流监控设备应运而生,提供智能化解决方…

Element-Plus Cascader 级联选择器获取节点名称和value值方法

html 部分 <template><el-cascaderref"selectAeraRef":options"areas":disabled"disabled":props"optionProps"v-model"selectedOptions"filterablechange"handleChange"><template #default"…

STM32中实现shell控制台(命令解析实现)

文章目录一、核心设计思想二、命令系统实现详解&#xff08;含完整注释&#xff09;1. 示例命令函数实现2. 初始化命令系统3. 命令注册函数4. 命令查找函数5. 命令执行函数三、命令结构体&#xff08;cmd\_t&#xff09;四、运行效果示例五、小结在嵌入式系统的命令行控制台&am…

基于matlab的二连杆机械臂PD控制的仿真

基于matlab的二连杆机械臂PD控制的仿真。。。 chap3_5input.m , 1206 d2plant1.m , 1364 hs_err_pid2808.log , 15398 hs_err_pid4008.log , 15494 lx_plot.m , 885 PD_Control.mdl , 35066 tiaojie.m , 737 chap2_1ctrl.asv , 988 chap2_1ctrl.m , 905

TCP、HTTP/1.1 和HTTP/2 协议

TCP、HTTP/1.1 和 HTTP/2 是互联网通信中的核心协议&#xff0c;它们在网络分层中处于不同层级&#xff0c;各有特点且逐步演进。以下是它们的详细对比和关键特性&#xff1a;1. TCP&#xff08;传输控制协议&#xff09; 层级&#xff1a;传输层&#xff08;OSI第4层&#xff…

Java+Vue开发的进销存ERP系统,集采购、销售、库存管理,助力企业数字化运营

前言&#xff1a;在当今竞争激烈的商业环境中&#xff0c;企业对于高效管理商品流通、采购、销售、库存以及财务结算等核心业务流程的需求日益迫切。进销存ERP系统作为一种集成化的企业管理解决方案&#xff0c;能够整合企业资源&#xff0c;实现信息的实时共享与协同运作&…

【趣谈】Android多用户导致的UserID、UID、shareUserId、UserHandle术语混乱讨论

【趣谈】Android多用户导致的UserID、UID、shareUserId、UserHandle术语混乱讨论 备注一、概述二、概念对比1.UID2.shareUserId3.UserHandle4.UserID 三、结论 备注 2025/07/02 星期三 在与Android打交道时总遇到UserID、UID、shareUserId、UserHandle这些术语&#xff0c;但是…

P1424 小鱼的航程(改进版)

题目描述有一只小鱼&#xff0c;它平日每天游泳 250 公里&#xff0c;周末休息&#xff08;实行双休日)&#xff0c;假设从周 x 开始算起&#xff0c;过了 n 天以后&#xff0c;小鱼一共累计游泳了多少公里呢&#xff1f;输入格式输入两个正整数 x,n&#xff0c;表示从周 x 算起…

<二>Sping-AI alibaba 入门-记忆聊天及持久化

请看文档&#xff0c;流程不再赘述&#xff1a;官网及其示例 简易聊天 环境变量 引入Spring AI Alibaba 记忆对话还需要我们有数据库进行存储&#xff0c;mysql&#xff1a;mysql-connector-java <?xml version"1.0" encoding"UTF-8"?> <pr…

【机器学习深度学习】模型参数量、微调效率和硬件资源的平衡点

目录 一、核心矛盾是什么&#xff1f; 二、微调本质&#xff1a;不是全调&#xff0c;是“挑着调” 三、如何平衡&#xff1f; 3.1 核心策略 3.2 参数量 vs 微调难度 四、主流轻量微调方案盘点 4.1 冻结部分参数 4.2 LoRA&#xff08;低秩微调&#xff09; 4.3 量化训…

【V13.0 - 战略篇】从“完播率”到“价值网络”:训练能预测商业潜力的AI矩阵

在上一篇 《超越“平均分”&#xff1a;用多目标预测捕捉观众的“心跳曲线”》 中&#xff0c;我们成功地让AI学会了预测观众留存曲线&#xff0c;它的诊断能力已经深入到了视频的“过程”层面&#xff0c;能精确地指出观众是在哪个瞬间失去耐心。 我的AI现在像一个顶级的‘心…

java微服务(Springboot篇)——————IDEA搭建第一个Springboot入门项目

在正文开始之前我们先来解决一些概念性的问题 &#x1f355;&#x1f355;&#x1f355; 问题1&#xff1a;Spring&#xff0c;Spring MVC&#xff0c;Spring Boot和Spring Cloud之间的区别与联系&#xff1f; &#x1f36c;&#x1f36c;&#x1f36c;&#xff08;1&#xff0…

服务器间接口安全问题的全面分析

一、服务器接口安全核心威胁 文章目录**一、服务器接口安全核心威胁**![在这里插入图片描述](https://i-blog.csdnimg.cn/direct/6f54698b9a22439892f0c213bc0fd1f4.png)**二、六大安全方案深度对比****1. IP白名单机制****2. 双向TLS认证(mTLS)****3. JWT签名认证****4. OAuth…

vs code关闭函数形参提示

问题&#xff1a;函数内出现灰色的形参提示 需求/矛盾&#xff1a; 这个提示对老牛来说可能是一种干扰&#xff0c;比如不好对齐控制一行代码的长度&#xff0c;或者容易看走眼&#xff0c;造成眼花缭乱的体验。 关闭方法&#xff1a; 进入设置&#xff0c;输入inlay Hints&…

ESXi 8.0安装

使用群晖&#xff0c;突然nvme固态坏了 新nvme固态&#xff0c;先在PC上格式化下&#xff0c;不然可能N100可能不认 启动&#xff0c;等待很长时间 回车 F11 输入密码&#xff0c;字母小写字母大写数字 拔掉U盘&#xff0c;回车重启 网络配置 按F2&#xff0c; 输入密码&…

【git学习】第2课:查看历史与版本回退

好的&#xff0c;我们进入 第2课&#xff1a;版本查看与回退机制&#xff0c;本课你将学会如何查看提交历史、对比更改&#xff0c;并掌握多种回退版本的方法。&#x1f4d8; 第2课&#xff1a;查看历史与版本回退&#x1f3af; 本课目标熟练查看 Git 提交记录掌握差异查看、版…

摄像头AI智能识别工程车技术及应用前景展望

摄像头AI自动识别工程车是智能交通系统和工程安全管理领域的一项重要技术。它通过图像识别技术和深度学习算法&#xff0c;实现对工程车的自动检测和识别&#xff0c;从而提高了施工现场的安全性和管理效率。以下是对该技术及其应用的详细介绍&#xff1a;一、技术实现数据收集…

Windows服务器安全配置:组策略与权限管理最佳实践

Windows服务器是企业常用的服务器操作系统&#xff0c;但其开放性和复杂性也使其成为攻击者的目标。通过正确配置组策略和权限管理&#xff0c;可以有效提高安全性&#xff0c;防止未经授权的访问和恶意软件的入侵。以下是详细的安全配置指南和最佳实践。 1. 为什么组策略和权限…