MATLAB应用之深度学习网络到底在“看”哪里?

你有没有想过,你经常使用的深度学习网络在看图像的什么部分进行分类?

例如下图:

如果深度学习网络将此图像分类为“圆号”,你认为图片的哪个部分对分类最重要?

我们使用预训练好的 ResNet-50 网络进行此实验。

He, Kaiming, Zhang, Xiangyu, Ren, Shaoqing, Sun, Jian. "Deep Residual Learning for Image Recognition." In Proceedings of the IEEE conference on computer vision and pattern recognition, pp. 770-778. 2016
获取 MATLAB 中 ResNet-50 网络的方法是启动 Add-On Explorer并搜索 resnet。

net = resnet50;

我们需要注意 ResNet-50 需要输入特定尺寸的图像。网络的初始层提供了这一信息:

sz = net.Layers(1).InputSize(1:2)
sz =

   224   224

所需的图像尺寸可以直接传递给 imresize 函数。

rgb = imread(url);
rgb = imresize(rgb,sz);
imshow(rgb)

在网络中调用 classify ,查看图片可能的分类:

classify(net,rgb)
ans = 

  categorical

       French horn 

ResNet-50 认为这是圆号。

Birju 在一篇关于卷积神经网络可视化技术的论文中,了解到遮挡敏感性的概念。如果阻挡或遮挡图像的一部分,将如何影响网络的预测得分?遮挡不同的部分又将如何影响结果?

Birju 做了如下尝试:

rgb2 = rgb;
rgb2((1:71)+77,(1:71)+108,:) = 128;
imshow(rgb2)


classify(net,rgb2)
ans = 

  categorical

     notebook 

Hmm...估计网络“认为”灰色方块看起来像笔记本。被遮挡的区域对于图像分类来说应该很重要。再试试不同的遮挡位置:

rgb3 = rgb;rgb3((1:71)+15,(1:71)+80,:) = 128;imshow(rgb3)

classify(net,rgb3)
ans = 

  categorical

       French horn 

好吧,脑袋并不重要。

Birju 编写了一些 MATLAB 代码来系统地量化不同图像区域对分类结果的相对重要性。他使用 MATLAB 构建了大量图像,并对遮挡不同区域的图像进行批处理。对于遮挡的不同位置,记录预期类(本例为“法国号”)的概率得分。

我们制作一批带有 71x71 遮挡区域的图像。首先计算所有遮挡模块的顶点,用 (X1,Y1) 和 (X2,Y2) 表示。
mask_size = [71 71];[H,W,~] = size(rgb);X = 1:W;Y = 1:H;[X1, Y1] = meshgrid(X, Y);X1 = X1(:) - (mask_size(2)-1)/2;Y1 = Y1(:) - (mask_size(1)-1)/2;X2 = X1 + mask_size(2) - 1;Y2 = Y1 + mask_size(1) - 1;

注意不要让遮挡区域的顶点偏离图像边界。

X1 = max(1, X1);
Y1 = max(1, Y1);

X2 = min(W, X2);
Y2 = min(H, Y2);

批处理:

batch = repmat(rgb,[1 1 1 size(X1,1)]);

for i = 1:size(X1,1)
   c = X1(i):X2(i);
   r = Y1(i):Y2(i);
   batch(r,c,:,i) = 128; % gray mask.
end

注意:这一批包含 50,000 多张图像。你需要大量的 RAM 才能同时创建和处理如此大量的图像。

这里有一些遮挡的图像:

现在,我们将使用 predict(而不是 classify)来获取每个图像在每个类别中的预测分数。MiniBatchSize 参数是用来限制 GPU 内存的使用,意味着 predict 函数将一次发送 64 个图像到 GPU 进行处理。

s = predict(net, batch, 'MiniBatchSize',64);
size(s)
ans =

       50176        1000

我们获得了很多的概率得分!其中 51,529 个图像,共有 1,000 个类别。矩阵 s 具有每个类别和每个图像的预测分数。

我们重点关注预测原始图像类别的预测分数:

scores = predict(net,rgb);
[~,horn_idx] = max(scores);

这里是每一个圆号类别中的图像预测分数:

s_horn = s(:,horn_idx);

将圆号类别的分数转换为图像显示:

S_horn = reshape(s_horn,H,W);
imshow(-S_horn,[])
colormap(gca,'parula')

最亮的区域表示遮挡对概率得分影响最大的遮挡区间。

下面我们找到了最影响圆号概率得分的遮挡位置:

[min_score,min_idx] = min(s_horn);
rgb_min_score = batch(:,:,:,min_idx);
imshow(rgb_min_score)

结果可见,识别圆号的关键在于螺旋形管身和阀键,而不是号嘴。

面包多代码

https://mbd.pub/o/GeBENHAGEN

此外,知乎付费咨询:哥廷根数学学派

擅长现代信号处理(改进小波分析系列,改进变分模态分解,改进经验小波变换,改进辛几何模态分解等等),改进机器学习,改进深度学习,机械故障诊断,改进时间序列分析(金融信号,心电信号,振动信号等)

展开阅读全文

页面更新:2024-04-15

标签:深度   圆号   网络   遮挡   概率   得分   分数   图像   类别   区域

1 2 3 4 5

上滑加载更多 ↓
推荐阅读:
友情链接:
更多:

本站资料均由网友自行发布提供,仅用于学习交流。如有版权问题,请与我联系,QQ:4156828  

© CopyRight 2020-2024 All Rights Reserved. Powered By 71396.com 闽ICP备11008920号-4
闽公网安备35020302034903号

Top