-
Notifications
You must be signed in to change notification settings - Fork 13
/
at.py
33 lines (28 loc) · 884 Bytes
/
at.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
from __future__ import absolute_import
from __future__ import print_function
from __future__ import division
import torch
import torch.nn as nn
import torch.nn.functional as F
'''
AT with sum of absolute values with power p
code from: https://github.com/AberHu/Knowledge-Distillation-Zoo
'''
class AT(nn.Module):
    '''
    Attention-transfer loss: mean squared error between the normalised
    attention maps of two feature maps.

    Paper: "Paying More Attention to Attention: Improving the Performance of
    Convolutional Neural Networks via Attention Transfer"
    https://arxiv.org/pdf/1612.03928.pdf
    '''

    def __init__(self, p):
        '''
        Args:
            p: exponent applied to the absolute feature values before they
                are summed over dim 1.
        '''
        super(AT, self).__init__()
        self.p = p

    def forward(self, fm_s, fm_t):
        '''
        Return ``F.mse_loss`` between the attention maps of both inputs.

        Args:
            fm_s: first feature map (presumably the student's - confirm
                against the caller).
            fm_t: second feature map (presumably the teacher's - confirm
                against the caller).

        Returns:
            The value produced by ``F.mse_loss`` with its default reduction.
        '''
        map_s = self.attention_map(fm_s)
        map_t = self.attention_map(fm_t)
        return F.mse_loss(map_s, map_t)

    def attention_map(self, fm, eps=1e-6):
        '''
        Collapse ``fm`` into a single-channel map normalised over dims 2, 3.

        Dims 1, 2 and 3 are indexed, so ``fm`` needs at least four
        dimensions (presumably (N, C, H, W) - confirm against the caller).

        Args:
            fm: feature map tensor.
            eps: added to the norm so the division never hits zero.

        Returns:
            Tensor shaped like ``fm`` except that dim 1 has size 1.
        '''
        # sum over dim 1 of |fm| ** p, keeping that dim so it broadcasts below
        energy = fm.abs().pow(self.p).sum(dim=1, keepdim=True)
        # one norm per (dim 0, dim 1) slice, offset by eps for stability
        scale = torch.norm(energy, dim=(2, 3), keepdim=True) + eps
        return energy / scale