菜码编程

  • 首页
  • 隐私政策
Caima Coding
专注于AI项目实战分享
  1. 首页
  2. 人工智能
  3. 正文

基于注意力机制的ResNet18网络架构的眼疾识别

2024年 8月 11日 1090点热度 0人点赞 0条评论
内容目录

1 介绍

眼疾是一种常见的眼部疾病,若不及时发现和治疗,会对视力造成严重影响。而通过机器学习技术,我们可以建立一个眼疾识别系统,帮助医生快速准确地诊断眼部疾病,提高诊断效率和准确性。 本项目旨在通过对眼底图像进行分类,实现眼疾的自动识别。数据集使用iChallenge-PM和眼病分类数据集,本文取上述两个数据集中的部分数据并已整理成*224224大小可直接使用。本文提出了基于注意力机制的ResNet18网络的眼疾识别算法**。主要使用了ResNet18和RenNet18_NAM两种卷积神经模型对患者眼底视网膜图像进行眼底疾病识别,对2种模型的损失函数值、模型参数量和准确率进行对比实验分析。

2 加载数据集

unzip -o -q -d dataset data/data220613/dataset.zip

2.1 分割数据集

from preproces_data import split_data
split_data(0.8)

2.2 加载数据到自定义的dataset

from dataset import MyDataset

train_dataset = MyDataset(csv_filepath='train.csv')
test_dataset = MyDataset(csv_filepath='test.csv')

3 模型构建

本文使用ResNet18和ResNet18-NAM两个模型进行实验

ResNet18-NAM是基于归一化的注意力机制的ResNet18模型,模型构建参考了【AI达人特训营】ResNet50-NAM:一种新的注意力计算方式复现
NAM是一种轻量级的高效的注意力机制,采用了CBAM的模块集成方式,重新设计了通道注意力和空间注意力子模块,这样,NAM可以嵌入到每个网络block的最后。对于残差网络,可以嵌入到残差结构的最后。对于通道注意力子模块,使用了Batch Normalization中的缩放因子,如式子(1),缩放因子反映出各个通道的变化的大小,也表示了该通道的重要性。为什么这么说呢,可以这样理解,缩放因子即BN中的方差,方差越大表示该通道变化的越厉害,那么该通道中包含的信息会越丰富,重要性也越大,而那些变化不大的通道,信息单一,重要性小。


其中 $\mu_B$ 和 $\sigma_B$ 为均值,$B$ 为标准差,$\gamma$ 和 $\beta$ 是可训练的仿射变换参数(尺度和位移)参考Batch Normalization.

通道注意力子模块如图(1)和式(2)所示:


其中$Mc$表示最后得到的输出特征,$\gamma$是每个通道的缩放因子,因此,每个通道的权值可以通过 $W\gamma =\gammai/\sum{j=0}\gamma_j$ 得到。我们也使用一个缩放因子 $BN$ 来计算注意力权重,称为像素归一化。像素注意力如图(2)和式(3)所示:


为了抑制不重要的特征,作者在损失函数中加入了一个正则化项,如式(4)所示。

import paddle
from train_and_test import train
from model import resnet18
from dataset import MyDataset
import warnings
warnings.filterwarnings("ignore")
net = resnet18(num_classes=6)
paddle.summary(net,(64,3,224,224))

4 模型训练

from train_and_test import train, test

save_path='./google/'

batch_size=32

train_loader = paddle.io.DataLoader(train_dataset, batch_size=batch_size)

eval_loader = paddle.io.DataLoader(test_dataset, batch_size=batch_size)

optim = paddle.optimizer.Adam(learning_rate=0.001, parameters=net.parameters())

train(
    model=net,
    opt=optim, 
    train_loader=train_loader, 
    valid_loader=eval_loader, 
    epoch_num=100, 
    save_path=save_path, 
    save_freq=20
)

output


图1 训练过程中的准确率



图2 训练过程中的损失函数

5 模型评估

from train_and_test import test
from model import resnet18;
net=resnet18(num_classes=6)
save_path='./resnet18-nam/'

test(
    model_path=save_path+'model/final.pdparams',
    net=net,
    test_dataloader=paddle.io.DataLoader(MyDataset(csv_filepath='test.csv'),
                                         batch_size=32),
    save_path=save_path
)

output

acc-> 0.9528
precision--> ([0.9221, 0.9828, 0.9032, 0.9649, 0.9636, 1.0], 0.9561000000000001)
recall--> ([0.9342, 0.9344, 0.9333, 0.9821, 0.9636, 0.9808], 0.9547333333333334)



图3 混淆矩阵

我们的网站是菜码编程。
如果你对我们的项目感兴趣可以扫码关注我们的公众号,我们会持续更新深度学习相关项目。
公众号二维码

本作品采用 知识共享署名-非商业性使用-禁止演绎 4.0 国际许可协议 进行许可
标签: ResNet18 人工智能 注意力机制 眼疾识别 计算机视觉 通道注意力机制
最后更新:2024年 9月 4日

Marlone

这个人很懒,什么都没留下

打赏 点赞
< 上一篇
下一篇 >

文章评论

razz evil exclaim smile redface biggrin eek confused idea lol mad twisted rolleyes wink cool arrow neutral cry mrgreen drooling persevering
取消回复
文章目录
  • 1 介绍
  • 2 加载数据集
    • 2.1 分割数据集
    • 2.2 加载数据到自定义的dataset
  • 3 模型构建
  • 4 模型训练
  • 5 模型评估

COPYRIGHT © 2024 菜码编程. ALL RIGHTS RESERVED.

Theme Kratos Made By Seaton Jiang

豫ICP备2024080801号