《Frontiers in Plant Science》:Enhancing multiclass plant disease classification using GAN-boosted vision transformer with XAI insights
编辑推荐:
本综述系统阐述了集成视觉Transformer(ViT)、生成式人工智能(GenAI)与可解释人工智能(XAI)的创新模型GRG-ViT,在解决水稻叶片病害多分类任务中的显著优势。文章重点分析了该模型通过条件生成对抗网络(C-GAN)平衡数据集、利用混合ReLU-GELU激活机制提升特征表征,并借助梯度加权类激活映射(Grad-CAM)和注意力图谱实现模型决策透明化,最终在七类病害识别中达到96%的准确率,为精准农业中的植物病害智能诊断提供了可靠方案。
引言
农业作为印度经济的支柱产业,水稻是其最重要的主食作物。然而,水稻生产受到多种植物病害的严重影响。深度学习和机器学习已成为解决计算机视觉问题的强大工具。本文旨在识别关键病害,并利用先进的深度学习模型解决这些突出问题。研究提出了一种名为GRG-ViT的新型多类水稻叶片病害识别模型,该模型整合了视觉Transformer(ViT)、生成式人工智能(GenAI)和可解释人工智能(XAI)技术,以取得更好结果。基于ViT的框架旨在捕捉叶片图像中的长程空间依赖性,从而增强模型识别细微病害模式的能力。由于数据集存在明显的类别不平衡,本模型采用基于GenAI的合成数据生成方法来创建平衡的训练样本,进而提高模型的鲁棒性。该模型还提出了一种基于混合整流线性单元(ReLU)和高斯误差线性单元(GELU)的激活机制,以获得有效的特征表示。获得的实验结果表明,所提出的GRG-ViT模型总体准确率接近96%,优于传统方法。XAI方法(如Grad-CAM)的融入通过强调影响模型决策的区域,提供了可解释性和透明度。这项研究展示了ViT、GenAI和XAI在精准农业水稻病害检测中产生可靠和高性能结果的融合力量。
相关研究工作
植物病害是导致作物减产的重要原因,给农民带来巨大损失。为帮助农民,许多研究人员利用人工智能学习方法,如机器学习、深度学习和迁移学习,开发了新的预测和分类模型。许多研究者已经提出了使用上述算法的精准农业分类技术。一些关键文献在本文的这一部分进行讨论。例如,有研究提出了用于各种玉米叶片病害分类的新型作物叶片GAN(CLGAN),旨在以最少的参数提高准确率并优化损失函数。另一项工作提出了一个基于物联网和深度学习的模型,用于苹果叶片的天气预报、田间监测和病害分类,使用门控循环单元(GRU)进行天气预报,ResNet-50进行病害预测,并通过传感器自动化以支持精准农业。还有研究为了实现对番茄植物叶片病害的精确分类和检测,在YOLOv7网络上融入了检测机制SimAM和DAiAM。图像使用SIFT技术进行分割以提取关键特征,并采用最大池化来减少信息损失。该目标检测模型预测了七种类型的叶片病害。Roopali Dogra等人提出了一个基于深度学习的模型,利用结合迁移学习的CNN-VGG19精确检测一种特定的水稻叶片病害——褐斑病,准确率达到93%,但该工作仅考虑了一种病害。随着Transformer在图像分类方面的进步,许多研究人员开始朝这个方向努力。José Maurício等人回顾了几篇近期论文,以确定在图像分类问题上Vision Transformer和卷积神经网络哪种性能更好。由于其长程依赖性和适应不同输入大小及噪声图像的能力,Vision Transformer优于CNN。在医学成像领域,最近也引入了几种基于AI的算法。其中一项工作使用Vision Transformer对皮肤癌图像进行常规分类。Vision Transformer的自注意力机制有助于提取图像的重要特征,同时排除产生噪声的特征,这反过来有助于癌细胞的早期预测。有研究者提出了一种在PlantVillage数据集上进行植物病害检测的分层方法。他们使用Vision Transformer进行训练和特征提取,并使用ResNet-9深度学习模型进行分类。这些模型产生了与其他预训练模型相当的结果。有研究使用深度谱生成对抗神经网络(DSGAN2)进行水稻植物叶片病害检测。通过引入GAN,研究人员增加了图像数据集的大小,从而提高了模型在植物病害检测中的性能。然而,该方法需要在其他作物上进一步测试以分析其是否可扩展。A.K. Singh等人开发了LeafyGAN,这是一个深度学习模型,结合了用于分割的Pix2PixGAN和用于图像转换的CycleGAN。通过实施这两种方法,研究人员成功生成了合成图像以平衡图像数据集。这些图像随后被输入轻量级MobileViT,并在两个不同的数据集PlantVillage和PlantDoc上训练以进行图像分类。该模型在PlantVillage数据集上表现良好,但在PlantDoc数据集上仅达到75%的准确率。研究人员提出了一个集成的VARMAx–CNN–GAN模型用于番茄叶片病害检测和管理。这是一个深度学习模型,集成了CNN、GAN和带外生回归量的向量自回归移动平均过程(VARMAx)。CNN用于特征提取,GAN用于生成合成图像,VARMAx组件用于改进病害分类。Amreen Abbas等人结合了条件生成对抗网络(C-GAN)和预训练的DenseNet121模型。使用条件GAN生成图像,DenseNet121用于病害分类。他们的研究旨在通过数据增强来增加有限图像数据集的大小,这反过来有助于他们的模型在多类数据集中对染病番茄叶片进行分类。然而,用于分类的模型是预训练的。为了解决图结构数据中节点分类任务的类别不平衡问题,Bojia Liu等人引入了类分布感知条件生成对抗网络(CDCGAN)。该模型旨在基于C-GAN的少数类增强模块和提取节点嵌入的类分布感知模块,生成多样化和可区分的少数类节点。该模型允许不同的GNN编码器在测试期间具有更大的泛化能力,但可能不适用于动态图场景。XAI是AI的一个子集,强调跨不同领域的增长和重要性。XAI突出重要特征,并为研究人员提供多学科方法。在开发XAI系统时应使用伦理的、以人为中心的和整体的方法。有研究提出了一种使用Vision Transformer和GRU进行阿尔茨海默病检测和分类的混合模型。在该研究中,他们融入了XAI方法以增强模型决策的可解释性。使用了LIME、SHAP和注意力图谱XAI技术来提供AI推理的透明视图。XAI-FruitNet是一个与平均和最大池化技术相结合的水果分类模型。这改善了特征辨别能力,并融入了可解释AI,通过Grad-CAM增强模型透明度。Grad-CAM解释了图像中贡献最大的部分,这有助于做出分类决策。有研究在预训练的Xception模型中使用迁移学习来分类和预测马铃薯叶片病害。这些结果使用可解释AI技术之一Grad-CAM进行解释,该技术通过可视化广泛强调叶片的核心区域,从而解决了现有研究中的关键差距。Natasha Nigar等人比较了四种深度学习模型——CNN、MobileNetV2、EfficientNetB0和ResNet-50,发现EfficientNetB0在预测植物叶片病害方面优于其他三种。他们的研究中包含了基于XAI的LIME技术用于解释所提出的模型。LIME被用来提供模型预测的视觉解释。有研究提出了针对PlantVillage数据集的Vision Transformer模型,该工作通过数据增强平衡不同类别的数据,达到了接近98%的准确率。另一项基于Transformer的工作提出了基于多尺度特征融合的方法,与其他最先进的基于CNN的模型相比,具有更好的泛化结果。有研究提出了一种基于集成的定制化EfficientNet模型,用于检测玉米、马铃薯和番茄等植物的病害。该模型以最低的误分类率达到了接近99%的准确率。Haridasan, A.等人使用CNN和SVM检测水稻数据集的五种不同水稻作物叶片病害,获得了91%的准确率。Deng, R.等人提出了一个集成模型,包含DenseNet-121、SE-ResNet-50和ResNeSt-50,预测具有六个病害类别的水稻数据集,即稻叶瘟、稻曲病、穗颈瘟、纹枯病、细菌性条斑病和褐斑病,达到91%的准确率。Elmitwally, N. S.等人选择了细菌性叶枯病、褐斑病和叶黑粉病叶片病害类别,并使用AlexNet进行训练预测,准确率达到99%,但仅选择了三个类别。Upadhyay, S. K.和Kumar, A.使用了Kaggle水稻叶片病害数据集,包含三个病害叶片类别和一个健康类别,并使用基于深度学习的CNN模型预测叶片病害,准确率达到99.7%,但所选类别数仅为四个。Gaurav Shrivastava和Harish Patidar提出了SVM与ANN相结合的方法,预测Kaggle水稻数据集的三个类别,准确率为91%。Rajpoot, V.等人提出了一个基于VGG-16的迁移学习Faster R-CNN模型,用于预测细菌性叶枯病、褐斑病和叶黑粉病病害叶片数据集,准确率为97.3%;然而,仅选择了三个病害叶片类别。Bhakta, I.等人使用细菌性叶枯病水稻数据集(这是一个二元分类问题),结合CNN获得了95%的准确率。T. Daniya和S. Vigneshwari提出了一种基于骑手亨利气体溶解度优化(RHGSO)的深度神经模糊网络(DNFN)模型,用于预测三个类别——BLB、稻瘟病和褐斑病病害叶片数据集,并获得93%的准确率。Paddy Doctor数据集在许多使用各种基于深度学习模型的研究中被使用。Villegas-Cubas等人部署了InceptionV3模型来分类和预测九个病害叶片图像类别和一个健康类别,准确率达到88%。Quan T. H.和Hoa N. T.提出了RiceNet分类模型,用于对10,407张图像进行分类,并获得93.8%的准确率。对于相同的数据集,Tasnim F.等人提出了混合关联规则挖掘(ARM)与逻辑回归相结合的方法,实现了92.8%的准确率。Garg等人对Paddy Doctor数据集实施了EfficientNet模型,准确率为91%。Klair等人从相同数据集中选择了几个类别,并实施了不同的模型,如ConvNet、ResNet和EfficientNet,准确率分别为87%、91%和94%。从Plant Doctor数据集中,仅选择了一个病害类别(白茎螟)和一个健康叶片类别,并提出了一个ViT模型用于二元分类的植物叶片病害预测,准确率为96%。基于所参考的文献,我们制定了研究问题和模型以解决已识别的差距,下一节将详细介绍所提出的模型和方法。
所用方法
本节提供了所用数据、所采用方法和结果分析的详细信息。整体系统流程图描述了拟议工作所涉及的逐步过程。第一步,数据准备,从数据收集和数据预处理开始。该步骤继续使用C-GAN生成合成图像,然后进行归一化和增强。第二步涉及模型构建和实施具有自注意力机制和多层感知器(MLP)分类器的Vision Transformer模型。为了训练和测试模型,数据集按80:20的比例分割。所提出的模型通过改变超参数配置来实施,以分析和评估它们对模型性能的影响。第三步是将所提出模型的性能与预训练的CNN模型进行比较。最后,借助Grad-CAM和注意力图谱可视化等技术来解释所提出模型的性能,以评估模型的可信度。
数据集描述
在这项工作中,我们考虑了来自IEEE DataPort上可用的Paddy Doctor数据集的水稻叶片图像。该数据集包含12类患病水稻叶片和一类健康水稻叶片。在从印度泰米尔纳德邦蒂鲁内尔维利地区周围收集的30,000多张图像经过清理和手动注释后,总共有16,225张图像。在这12个类别中,我们考虑了六种对作物总产量有重大影响的疾病。这六个类别是黑茎螟、白茎螟、黄茎螟、褐斑病、铁甲虫和BLB,以及健康叶片类别,以实现有效分类。在平衡数据集之前,为每个类别选择的图像数量已制成表格。所有六个病害类别和健康类别的样本测试图像如图所示。
提出的架构
本工作中实施了一种混合条件GAN与Vision Transformer的方法,用于改进多类水稻作物数据集的植物病害分类。为了进行数据准备,我们使用了七类水稻叶片图像,并通过探索性数据分析进行分析。从探索性数据分析中,我们发现当数据集不平衡时,Vision Transformer模型的性能相对较低。为了平衡该数据集,使用条件GAN算法来增加类别的数量。条件GAN模型使用U-Net生成器和PatchGAN判别器进行定制以生成合成图像。然后将这些作为输入提供给Vision Transformer多类分类模型,用于增强整个增强后的数据集。这些增强后的图像被分割成块,展平为一维线性投影,并附加位置嵌入。一旦在预处理阶段完成增强和块嵌入,它们将作为输入提供给Transformer编码器。Transformer编码器包含用于稳定训练过程的层归一化、用于捕获和分层图像不同块之间依赖关系的多头注意力层,以及一个MLP,通过引入非线性函数进一步处理图像。MHA和MLP的输出通过跳跃连接相加,并反馈给前一层输出。最后,Transformer编码器的输出被馈送到分类头,以分类不同类别的患病植物叶片和健康叶片。在这项工作中,实施了两种不同的模型配置以增强模型性能。第一个模型RG-ViT,使用八个Transformer编码器层和ReLU激活函数实施。激活函数用于深度学习模型中,通过其非线性函数学习复杂模式。在RG-ViT模型中,我们仅在编码器层和最终分类器层使用ReLU激活函数。为了进一步增强模型的性能,第二个模型GRG-ViT配置了12层Transformer编码器。与模型1的第二个区别是使用的激活函数类型。在该模型中,包含了高斯误差线性单元(GELU)和ReLU激活函数:ReLU激活函数在每个Transformer编码器内部使用,而GELU激活函数在分类头中使用,因为它比ReLU更能捕获全局依赖性。Vision Transformer模型的输出使用最先进的XAI技术(如Grad-CAM和注意力映射)进行说明。这些可视化方法捕获了Vision Transformer在进行分类决策时关注的核心区域。注意力映射可视化用于可视化和提取多头注意力模式。Grad-CAM映射通过生成类激活映射来突出影响类别预测的特定区域。这两种技术被纳入本工作,用于定性和定量分析,特别是为了理解模型预测能力、识别偏差以及增强模型设计和性能。所提出的GRG-ViT模型的架构如图所示。
用于数据平衡的条件GAN架构
在所提出的模型中,实施了一种条件GAN(一种GenAI方法)来生成与原始图像相似的合成图像。这里使用的数据集是不平衡的,一个类别有2,100张图像,而另一个类别只有506张图像。为了提高所提出的Vision Transformer模型的性能,需要平衡不平衡的类别。条件GAN是GAN的扩展版本,对两个对抗模型都应用了条件。条件是辅助信息(y),通过包含它来影响生成器(G)和判别器(D)。这可以表示为在噪声向量z上,对于真实图像x和条件y(这里y是类别标签)的目标函数,使用最小-最大函数确定。图描绘了条件GAN的架构表示,输入图像x与标签y和噪声向量z连接后馈送到U-Net生成器以生成合成图像。这些生成的图像与输入图像和条件一起作为输入提供给PatchGAN判别器,以对输入图像和合成图像进行分类。然后使用分类输出更新判别器和生成器,以训练条件GAN模型并提高其性能。
生成器和判别器
所提出的模型中条件GAN使用的生成器基于U-Net架构,该架构专门执行图像生成任务,同时保留空间信息。U-Net生成器通过实施收缩路径(编码器)和扩展路径(解码器)来生成图像。编码器和解码器的输出使用跳跃连接进行连接以生成最终图像。编码器通过实施七个卷积层块、批量归一化和LeakyReLU激活函数来对输入图像进行下采样。通过这样做,它减少了输入图像的空间维度,将分辨率从256 × 256降低到2 × 2,同时通道特征数量增加。对于每个相应的编码器层,解码器从2 × 2到256 × 256逐步上采样编码后的特征。然后使用跳跃连接在每个块级别将它们与编码器的下采样特征图连接起来。解码器的最终块实施tanh激活函数,以生成归一化在[?1, +1]之间的输出值,用于编码器和解码器的l层。U-Net生成器的输出G(y)表示为公式2。其中ue是生成器编码器,ud是生成器解码器,f是LeakyReLU激活函数,W是权重,b是偏差。使用PatchGAN判别器,因为它对图像块而不是整个图像进行分类。该方法捕获高频细节,并通过平均所有块响应产生最终输出。判别器D获得两个输入:要么是输入图像(x),要么是由U-Net生成器G(y, z)生成的图像,与条件(标签y)连接。使用四个卷积层块、批量归一化和LeakyReLU激活函数来处理连接后的输入以进行下采样。最终块使用sigmoid激活函数以二维矩阵形式产生输出,其中每个元素代表一个块。判别器中l层对U-Net生成器输出进行下采样的卷积输出表示为公式3。其中*表示卷积运算,uD(l)是判别器在层l的输出特征图,W是与最终层相关的滤波器权重,f是LeakyReLU激活函数,σ是sigmoid激活函数。如公式1所述,条件GAN生成器生成合成图像,由判别器评估。条件GAN中判别器的损失函数应最大化以分类真实图像和合成图像。生成器的损失函数应最小化,以便生成与输入图像x匹配的图像。计算两个网络的损失后,计算梯度以更新模型参数。条件GAN训练完成后,为少数类生成合成图像并添加到训练集中,以将不平衡数据集转换为平衡数据集。在这项工作中,合成图像被添加到细菌性叶枯病、黑茎螟和黄茎螟类别中。
用于植物叶片检测的Vision Transformer模型
Vision Transformer是深度学习模型中的最新发展,专为计算机视觉任务而设计。它使用Transformer作为其骨干架构,具有一种称为自注意力机制的独特能力。该模型可以识别图像块之间的依赖关系和关系,无论它们的距离如何。Vision Transformer的详细过程在架构图中进行了说明。七类植物病害图像被转换为块,嵌入到块嵌入中,然后使用线性投影展平。添加了位置嵌入的块嵌入通过堆叠的Transformer编码器作为输入提供,产生一组精炼的分类(CLS)标记。CLS标记输出代表整个图像的摘要,用于检测、分割或分类。编码器输出被馈送到MLP,在那里它通过引入非线性变换来增强CLS标记的表示,以提取更具表现力的特征。最后,Vision Transformer的分类器预测多类数据集中不同类别的患病叶片和健康叶片。
块提取和位置嵌入
收集的图像被调整为72 × 72,然后进行预处理,并展平为2D块,因为Transformer模型只能接收一维标记序列作为输入。例如,在这项工作中,输入图像x的分辨率取为72 × 72,具有三个通道,像素P大小为6 × 6。因此,在展平后我们将获得N = (72 × 72)/(6 × 6) = 144个图像块。添加位置编码以保留原始空间信息。然后这些块被映射到较低维度的可训练线性投影以生成块嵌入。块嵌入然后与位置信息相加以保留每个图像的原始位置,产生位置嵌入。这些位置嵌入作为输入提供给Transformer编码器,分类标记添加到块嵌入中以用于不同类别的最终表示。具有(H, W)分辨率和C个通道数的RGB图像x被表示为x ∈ RH×W×C,被转换为N个不重叠的图像块,具有(P, P)分辨率,xp∈ RN×(P2•C),其中N = (HW/P2)。这些块被映射到线性投影D以获得E ∈ RN×(P2•C)×D块嵌入pi。一旦知道这些块嵌入,就添加位置嵌入Epos以了解它们的空间信息Epos∈ R(N+1)×D。在预先添加分类标记xclass后,计算位置嵌入后获得序列嵌入向量t,如公式4所示。这个t作为输入提供给Transformer编码器。
Transformer编码器和注意力机制
ViT架构中的Transformer编码器由L层组成,每层都有一个交替的多头自注意力模块(MSA)和前馈MLP模块。Transformer编码器中的每一层都有一个归一化层,通过残差连接将归一化后的输入提供给其他两个模块。提供给Transformer每个i层的MSA模块的输入是ti,作为输入提供给多层感知器以获得ti+1= MLP(MSA(LN(ti))) + ti,如公式5所示。在每个多头自注意力模块内采用多个自注意力机制。自注意力机制能够使模型学习和理解图像每个块之间的关系和依赖关系。这是通过根据最重要信息的重要性分配分数来执行的。自注意力机制包括三个关键参数——查询(Q)、键(K)和值(V)——应用于图像的每个单独块。一个块的查询关注图像的所有其他块,以分析哪个块相对于其表示更相关和更重要。键有助于确定每个块如何与相应的查询匹配,值用于计算块的实际信息或特征。对于每个头,单个模型维度可以线性投影h次到查询、键和值的不同投影dq, dk, dv,以计算多头自注意力,如公式6所示。对于不同的投影值,查询、键、值和输出参数的相应投影矩阵可以公式化为WmQ∈ Rdmodel×dk, WmK∈ Rdmodel×dk, WmV∈ Rdmodel×dv和 WmO∈ Rhdmodel×dk。因此,多头注意力(Q, K, V)= 连接(头1, …, 头m)WO,其中头m= 注意力(QWmQ, KWmK, VWmV)且注意力(Q, K, V)= softmax(QKT/√dk)V。
多层感知器和分类
为了处理自注意力机制的输出并补充MSA层,在Transformer编码器中嵌入了一个MLP。这作为最终分类器的输入,该分类器是一个多层感知器。它捕获图像的关键模式,并通过将向量转换为更高维度来增强表示。这种转换是通过引入非线性来学习复杂关系来实现的,使用非线性激活函数GELU实施。公式5的MLP输出ti+ 1作为分类器的输入,经过层归一化以生成最终预测向量v。此外,SoftMax层为分类器的所有不同水稻叶片类别产生类别概率。
可解释AI技术
XAI是AI的一部分,有助于理解从不同模型获得的结果。基于AI的模型通常很复杂且越来越难以解释,尤其是当它们在关键时刻做出决策时,而AI模型可以做出改变生活的决策。本工作中使用XAI来更深入地了解所提出的模型如何基于注意力机制、重要特征和类别的激活函数做出决策。XAI使人类能够分析和改进AI系统性能,因为模型变得透明。在这项工作中,使用了两种重要的XAI技术来解释所提出的混合Vision Transformer模型如何对多类水稻作物数据集中的患病叶片和健康叶片进行分类和预测。这两种技术是Grad-CAM和注意力图谱,均采用定制架构实施。对Vision Transformer模型进行了特定调整,以提取基于梯度的可解释性特征和用于可视化的注意力机制。
注意力图谱
注意力图谱用于可视化输入图像的核心区域,这些区域是Vision Transformer模型通过注意力机制