用 PyTorch 实现神经网络的热力图可视化(Class Activation Mapping, CAM)
最近在做一个细粒度识别的项目,具体而言是为了做一个特定场景的车辆/车种检测,因为摄像头角度问题,大多时候只有车辆的一部分处于画面内,所以没有走检测的方式,而是尝试了一下通过各种数据增强(主要是裁剪)来指导网络对不同车种车辆的各个部位进行学习,从而指导车种分类以及有无车辆检测,最终达到监控车辆(车种)的目的。
但是因为数据集不大,尽管模型收敛很好,但心里还是没底,担心是过拟合。于是想到可以可视化一下网络的CAM,观察一下指导分类的高响应区域是否落在目标核心部位上。
Class Activation Mapping(CAM)是一个帮助可视化CNN的工具,通过它我们可以观察为了达到正确分类的目的,网络更侧重于哪块区域。比如,下面两幅图,一个是刷牙,一个是砍树,我们根据热力图可以看到高响应区域的确集中在我们认为最有助于作出判断的部位。
其计算方法如下图所示。对于一个CNN模型,对其最后一个featuremap做全局平均池化(GAP)计算各通道均值,然后通过FC层等映射到class score,找出argmax,计算最大的那一类的输出相对于最后一个featuremap的梯度,再把这个梯度可视化到原图上即可。直观来说,就是看一下网络抽取到的高层特征的哪部分对最终的classifier影响更大。
找到了一篇基于Keras的CAM实现,感谢:https://blog.****.net/Einstellung/article/details/82858974。但是我还是习惯用Pytorch一点,所以参考着改了一版Pytorch的实现。其中,有一个地方困扰了一下,因为Pytorch的自动求导机制,一般只会保存函数值对输入的导数值,而中间变量的导数值都没有保留,而此处我们需要计算输出层相对于最后一个feature map梯度,所以参考https://blog.****.net/qq_27061325/article/details/84728539解决了该问题。
基于Pytorch的CAM计算与绘制,具体代码如下:
#coding: utf-8
import os
from PIL import Image
import torch
import numpy as np
import cv2
import matplotlib.pyplot as plt
def draw_CAM(model, img_path, save_path, transform=None, visual_heatmap=False):
'''
绘制 Class Activation Map
:param model: 加载好权重的Pytorch model
:param img_path: 测试图片路径
:param save_path: CAM结果保存路径
:param transform: 输入图像预处理方法
:param visual_heatmap: 是否可视化原始heatmap(调用matplotlib)
:return:
'''
# 图像加载&预处理
img = Image.open(img_path).convert('RGB')
if transform:
img = transform(img)
img = img.unsqueeze(0)
# 获取模型输出的feature/score
model.eval()
features = model.features(img)
output = model.classifier(features)
# 为了能读取到中间梯度定义的辅助函数
def extract(g):
global features_grad
features_grad = g
# 预测得分最高的那一类对应的输出score
pred = torch.argmax(output).item()
pred_class = output[:, pred]
features.register_hook(extract)
pred_class.backward() # 计算梯度
grads = features_grad # 获取梯度
pooled_grads = torch.nn.functional.adaptive_avg_pool2d(grads, (1, 1))
# 此处batch size默认为1,所以去掉了第0维(batch size维)
pooled_grads = pooled_grads[0]
features = features[0]
# 512是最后一层feature的通道数
for i in range(512):
features[i, ...] *= pooled_grads[i, ...]
# 以下部分同Keras版实现
heatmap = features.detach().numpy()
heatmap = np.mean(heatmap, axis=0)
heatmap = np.maximum(heatmap, 0)
heatmap /= np.max(heatmap)
# 可视化原始热力图
if visual_heatmap:
plt.matshow(heatmap)
plt.show()
img = cv2.imread(img_path) # 用cv2加载原始图像
heatmap = cv2.resize(heatmap, (img.shape[1], img.shape[0])) # 将热力图的大小调整为与原始图像相同
heatmap = np.uint8(255 * heatmap) # 将热力图转换为RGB格式
heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET) # 将热力图应用于原始图像
superimposed_img = heatmap * 0.4 + img # 这里的0.4是热力图强度因子
cv2.imwrite(save_path, superimposed_img) # 将图像保存到硬盘