AttentionBlock.py
# -*- coding: utf-8 -*-
import math

import torch
import torch.nn as nn

from utils import *
class MLP(nn.Module):
    """Two-layer feed-forward block applied after the attention sub-layer."""

    def __init__(self, D):
        super(MLP, self).__init__()
        self.fc1 = nn.Linear(D, D)
        self.fc2 = nn.Linear(D, D)
        self.relu = nn.ReLU()

    def forward(self, x):
        # ReLU between the two linear layers, so the block is non-linear
        return self.fc2(self.relu(self.fc1(x)))
class AttentionBlock(nn.Module):
    def __init__(self, t, D):
        """
        =INPUT=
        t : block type ('row' or 'col')
          : String
        D : embedding dimension
        """
        super(AttentionBlock, self).__init__()
        self.D = D
        self.Type = t
        # head 1
        self.Q1 = nn.Linear(D, D, bias=False)
        self.K1 = nn.Linear(D, D, bias=False)
        self.V1 = nn.Linear(D, D, bias=False)
        # head 2
        self.Q2 = nn.Linear(D, D, bias=False)
        self.K2 = nn.Linear(D, D, bias=False)
        self.V2 = nn.Linear(D, D, bias=False)
        # head 3
        self.Q3 = nn.Linear(D, D, bias=False)
        self.K3 = nn.Linear(D, D, bias=False)
        self.V3 = nn.Linear(D, D, bias=False)
        # head 4
        self.Q4 = nn.Linear(D, D, bias=False)
        self.K4 = nn.Linear(D, D, bias=False)
        self.V4 = nn.Linear(D, D, bias=False)
        # output projection over the 4 concatenated heads
        self.out = nn.Linear(4 * D, D)
        self.LN = nn.LayerNorm(D)
        self.mlp = MLP(D)
    def forward(self, input):
        """
        =INPUT=
        input : BATCH * M * N * D
        =RETURN=
        out   : BATCH * M * N * D
        """
        batch, row, col, _ = input.shape
        out = torch.empty_like(input)
        s = nn.Softmax(dim=-1)
        if self.Type == 'row':
            # Row attention: each row (length N) attends over its columns
            for i in range(row):
                ln_input = self.LN(input[:, i, :, :])
                A1 = s(torch.matmul(self.Q1(ln_input),
                                    self.K1(ln_input).transpose(1, 2)) / math.sqrt(self.D))
                A2 = s(torch.matmul(self.Q2(ln_input),
                                    self.K2(ln_input).transpose(1, 2)) / math.sqrt(self.D))
                A3 = s(torch.matmul(self.Q3(ln_input),
                                    self.K3(ln_input).transpose(1, 2)) / math.sqrt(self.D))
                A4 = s(torch.matmul(self.Q4(ln_input),
                                    self.K4(ln_input).transpose(1, 2)) / math.sqrt(self.D))
                # Ai : BATCH * N * N
                SA1 = torch.matmul(A1, self.V1(ln_input))
                SA2 = torch.matmul(A2, self.V2(ln_input))
                SA3 = torch.matmul(A3, self.V3(ln_input))
                SA4 = torch.matmul(A4, self.V4(ln_input))
                MSA = self.out(torch.cat((SA1, SA2, SA3, SA4), 2))
                tmp = MSA + input[:, i, :, :]
                out[:, i, :, :] = self.mlp(self.LN(tmp)) + tmp  # BATCH * N * D
        else:
            # Column attention: each column (length M) attends over its rows
            for j in range(col):
                ln_input = self.LN(input[:, :, j, :])
                A1 = s(torch.matmul(self.Q1(ln_input),
                                    self.K1(ln_input).transpose(1, 2)) / math.sqrt(self.D))
                A2 = s(torch.matmul(self.Q2(ln_input),
                                    self.K2(ln_input).transpose(1, 2)) / math.sqrt(self.D))
                A3 = s(torch.matmul(self.Q3(ln_input),
                                    self.K3(ln_input).transpose(1, 2)) / math.sqrt(self.D))
                A4 = s(torch.matmul(self.Q4(ln_input),
                                    self.K4(ln_input).transpose(1, 2)) / math.sqrt(self.D))
                # Ai : BATCH * M * M
                SA1 = torch.matmul(A1, self.V1(ln_input))
                SA2 = torch.matmul(A2, self.V2(ln_input))
                SA3 = torch.matmul(A3, self.V3(ln_input))
                SA4 = torch.matmul(A4, self.V4(ln_input))
                MSA = self.out(torch.cat((SA1, SA2, SA3, SA4), 2))
                tmp = MSA + input[:, :, j, :]
                out[:, :, j, :] = self.mlp(self.LN(tmp)) + tmp  # BATCH * M * D
        return out
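

# Usage sketch (not part of the original file): a quick shape check that runs
# row attention followed by column attention on a random 4-D tensor. The batch
# size, 5 x 7 grid and D = 16 below are arbitrary values chosen for illustration.
if __name__ == "__main__":
    x = torch.randn(2, 5, 7, 16)            # BATCH * M * N * D
    row_block = AttentionBlock('row', 16)
    col_block = AttentionBlock('col', 16)
    y = col_block(row_block(x))              # attend along rows, then along columns
    print(y.shape)                           # expected: torch.Size([2, 5, 7, 16])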