MATLAB环境下基于深度学习的新型冠状病毒肺炎 (COVID-19) 检测

作者:来自代顿大学研究院(UDRI))的 Barath Narayanan 博士

新型冠状病毒肺炎 (COVID-19) 是 2019 年发现的一种新型人类疾病,无任何先例可循。冠状病毒是一个庞大的病毒家族,可导致患者出现轻重不一的病症,轻至普通感冒,重至急性呼吸道综合症,如中东呼吸综合症 (MERS-COV) 或严重急性呼吸道综合症 (SARS-COV)。该肺炎已造成全球大流行,目前世界各地均有大量人口发生感染和接受治疗。仅美国一地,COVID-19 疫情或可导致 1.6 亿到 2.14 亿人感染 。一些国家和地区已宣布进入紧急状态,隔离人数多达数百万。

检测和诊断工具可以为医生提供宝贵的第二诊疗意见,协助他们完成筛查。与此同时,此类机制还有助于快速向医生呈现检测结果。来自代顿大学研究院(UDRI))的 Barath Narayanan 博士将为我们介绍他的团队如何应用基于深度学习的技术,使用 MATLAB 根据胸片检测 COVID-19

背景

本文使用的 COVID-19 数据集由蒙特利尔大学博士后研究员 Joseph Cohen 博士管理,下载 ZIP 文件后,将文件解压到名为 “Covid 19” 的文件夹中,所得的每个子文件夹对应 “dataset” 中的一个类。标签 “covid” 表示在患者体内检出 COVID-19,“normal” 则表示未检出。数据均等分布在两个类中(各 25 张影像),所以此处不存在类不均衡问题。

加载数据集

首先,使用 imageDatastore 加载数据库。该函数用于加载影像及其标签以执行分析,具有较高的计算效率。

%Clear workspace
clear; close all; clc;
 
%ImagesDatapath–Please modify your path accordingly
datapath='dataset';
 
%ImageDatastore
imds=imageDatastore(datapath,...
    'IncludeSubfolders',true,...
'LabelSource','foldernames');%Determine the split uptotal_split=countEachLabel(imds)

影像可视化

可视化影像,了解各个类之间的影像差异。另外,这也有助于我们确定采用何种分类方法区分这两个类。根据影像,我们可以选择适当的预处理技术来帮助我们完成分类。根据类内相似性及类间差异性,我们可以确定研究所需的 CNN 架构类型。

%Number of Images
num_images=length(imds.Labels);
%Visualize random imagesperm=randperm(num_images,6);
figure;for idx=1:length(perm)    
    subplot(2,3,idx);
    imshow(imread(imds.Files{perm(idx)}));
    title(sprintf('%s',imds.Labels(perm(idx))))
    
end

K 折验证

上文已述,此数据集提供的影像数量有限,因此,我们将数据集拆分为 10 折进行分析,也就是使用数据集中的各组影像分别训练 10 个不同算法。此验证方法相比常用的留出验证法提供更为准确的性能预估。

本文采用 ResNet-50 架构,因为该架构经证实对各类医学成像应用均十分有效 [1,2]。

%Number of folds
num_folds=10;
%Loopfor each foldfor fold_idx=1:num_folds
    
    fprintf    
    fprintf('Processing %d among %d folds 
',fold_idx,num_folds);
    
  %TestIndicesfor current fold
    test_idx=fold_idx:num_folds:num_images;

    %Test cases for current fold
    imdsTest = subset(imds,test_idx);
    
   %Train indices for current fold
    train_idx=setdiff(1:length(imds.Files),test_idx);
    
   %Train cases for current fold
    imdsTrain = subset(imds,train_idx);
 
   %ResNetArchitecture
    net=resnet50;
    lgraph = layerGraph(net);
    clear net;
    
    %Number of categories
    numClasses = numel(categories(imdsTrain.Labels));
    
    %NewLearnableLayer
    newLearnableLayer = fullyConnectedLayer(numClasses,...
        'Name','new_fc',...
        'WeightLearnRateFactor',10,...
        'BiasLearnRateFactor',10);
    
    %Replacing the last layers withnew layers
    lgraph = replaceLayer(lgraph,'fc1000',newLearnableLayer);
    newsoftmaxLayer = softmaxLayer('Name','new_softmax');
    lgraph = replaceLayer(lgraph,'fc1000_softmax',newsoftmaxLayer);
    newClassLayer = classificationLayer('Name','new_classoutput');
    lgraph = replaceLayer(lgraph,'ClassificationLayer_fc1000',newClassLayer);
    
    
  %PreprocessingTechnique
    imdsTrain.ReadFcn=@(filename)preprocess_Xray(filename);
    imdsTest.ReadFcn=@(filename)preprocess_Xray(filename);
    
   %TrainingOptions, we choose a small mini-batch size due to limited images
    options = trainingOptions('adam',...
        'MaxEpochs',30,'MiniBatchSize',8,...
        'Shuffle','every-epoch',...
        'InitialLearnRate',1e-4,...
        'Verbose',false,...
        'Plots','training-progress');
    
   %DataAugumentation
    augmenter = imageDataAugmenter(...
        'RandRotation',[-55],'RandXReflection',1,...
        'RandYReflection',1,'RandXShear',[-0.050.05],'RandYShear',[-0.050.05]);
    
    %Resizing all training images to [224224]forResNet architecture
    auimds = augmentedImageDatastore([224224],imdsTrain,'DataAugmentation',augmenter);
    
   %Training
    netTransfer = trainNetwork(auimds,lgraph,options);
    
  %Resizing all testing images to [224224]forResNet architecture
    augtestimds = augmentedImageDatastore([224224],imdsTest);
   
   %Testingand their corresponding LabelsandPosteriorfor each Case
    [predicted_labels(test_idx),posterior(test_idx,:)]= classify(netTransfer,augtestimds);
    
    %Save the IndependentResNetArchitectures obtained for each Fold
    save(sprintf('ResNet50_%d_among_%d_folds',fold_idx,num_folds),'netTransfer','test_idx','train_idx');
    
   %Clearing unnecessary variables
    clearvars -except fold_idx num_folds num_images predicted_labels posterior imds netTransfer;
    
end

性能研究

我们通过混淆矩阵衡量算法性能,该指标同时也反映了查准率和查全率方面的性能。我们认为总体准确度是一个有效指标,因为本研究中使用的测试数据集为均匀分布(每个类别的图像数均等)。

混淆矩阵

%ActualLabels
actual_labels=imds.Labels;
%ConfusionMatrixfigure;
plotconfusion(actual_labels,predicted_labels')
title('ConfusionMatrix:ResNet');

ROC 曲线

ROC 协助医生根据误报率和检测率选择工作点。

test_labels=double(nominal(imds.Labels));
% ROC Curve-Our target classis the first classinthis scenario [fp_rate,tp_rate,T,AUC]=perfcurve(test_labels,posterior(:,1),1);figure;
plot(fp_rate,tp_rate,'b-');
grid on;
xlabel('False Positive Rate');
ylabel('Detection Rate');

%Area under the ROC curve value
AUC

AUC = 0.9776

类激活映射

将不同 COVID-19 病例经过这些网络处理后得到的类激活映射 (CAM) 结果可视化,这有助于医生了解算法决策背后的依据。以下是不同病例的相应结果:



基于其他公开数据集进行测试

为了进一步研究和分析算法性能,我们需要确定从不含 COVID-19 标签的其他公开数据集检测出 COVID-19 的概率。在此,我们使用 [2] 提供的病例,病例由放射科医生标记为正常、细菌性肺炎或病毒性肺炎。前文已述,每个网络分别使用 COVID-19 数据集中的一组不同影像进行训练。只要影像的冠状病毒后验概率大于 0.5,即视为假阳性 (FP)。结果清楚地表明,我们的算法具有较高的特异度和敏感度。在单核 GPU 上,每个测试用例的用时约 13 毫秒


结论

本文介绍了一种基于深度学习的简单分类方法,可用于 COVID-19 的计算机辅助诊断。基于 ResNet 的分类算法表现相对出色,总体准确度和 AUC 较高。迁移学习方法的良好性能再次印证,基于 CNN 的分类模型适于执行特征提取。使用几组新的带标签影像即可轻松重新训练算法,从而进一步增强性能。将上述结果与其他现有架构相结合,可以从 AUC 和总体准确度两方面提高性能。如果能就计算量(内存和时间)和性能两方面综合研究这些算法,将有助于相关专家有所侧重地选择算法。计算机辅助诊断不仅是医生进行 COVID-19 筛查的好帮手,还有助于提供宝贵的第二诊疗意见。

参考文献

[1] Narayanan, B. N., De Silva, M. S., Hardie, R. C., Kueterman, N. K., & Ali, R. (2019)."Understanding Deep Neural Network Predictions for Medical Imaging Applications". arXiv preprint arXiv:1912.09621.

[2] Narayanan, B. N., Davuluru, V. S. P., & Hardie, R. C. (2020, March)."Two-stage deep learning architecture for pneumonia detection and its diagnosis in chest radiographs".In Medical Imaging 2020:Imaging Informatics for Healthcare, Research, and Applications (Vol. 11318, p. 113180G).International Society for Optics and Photonics.

面包多代码

https://mbd.pub/o/GeBENHAGEN

此外,知乎付费咨询:哥廷根数学学派

擅长现代信号处理(改进小波分析系列,改进变分模态分解,改进经验小波变换,改进辛几何模态分解等等),改进机器学习,改进深度学习,机械故障诊断,改进时间序列分析(金融信号,心电信号,振动信号等)

展开阅读全文

页面更新:2024-05-27

标签:准确度   肺炎   综合症   病例   算法   架构   深度   影像   性能   医生   标签   环境   冠状病毒   数据

1 2 3 4 5

上滑加载更多 ↓
推荐阅读:
友情链接:
更多:

本站资料均由网友自行发布提供,仅用于学习交流。如有版权问题,请与我联系,QQ:4156828  

© CopyRight 2020-2024 All Rights Reserved. Powered By 71396.com 闽ICP备11008920号-4
闽公网安备35020302034903号

Top