Skip to content

Commit

Permalink
fix cosine sim flash attention as well
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Jul 23, 2022
1 parent 06b7775 commit 35559a0
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,6 @@ def forward(ctx, q, k, v, mask, scale, causal, q_bucket_size, k_bucket_size):
o = torch.zeros_like(q)
all_row_sums = torch.zeros((*q.shape[:-1], 1), device = device)

q = q * scale

if not exists(mask):
mask = (None,) * math.ceil(q.shape[-2] / q_bucket_size)
else:
Expand All @@ -63,7 +61,7 @@ def forward(ctx, q, k, v, mask, scale, causal, q_bucket_size, k_bucket_size):
for k_ind, (kc, vc) in enumerate(col_splits):
k_start_index = k_ind * k_bucket_size

attn_weights = einsum('... i d, ... j d -> ... i j', qc, kc)
attn_weights = einsum('... i d, ... j d -> ... i j', qc, kc) * scale

if exists(row_mask):
attn_weights.masked_fill_(~row_mask, max_neg_value)
Expand Down Expand Up @@ -129,14 +127,13 @@ def backward(ctx, do):
for k_ind, (kc, vc, dkc, dvc) in enumerate(col_splits):
k_start_index = k_ind * k_bucket_size

qc_scaled = qc * scale
attn_weights = einsum('... i d, ... j d -> ... i j', qc_scaled, kc)
attn_weights = einsum('... i d, ... j d -> ... i j', qc, kc) * scale

if causal and q_start_index < (k_start_index + k_bucket_size - 1):
causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype = torch.bool, device = device).triu(q_start_index - k_start_index + 1)
attn_weights.masked_fill_(causal_mask, max_neg_value)

exp_attn_weights = torch.exp(attn_weights)
exp_attn_weights = torch.exp(attn_weights - scale)

if exists(row_mask):
exp_attn_weights.masked_fill_(~row_mask, 0.)
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name = 'memory-efficient-attention-pytorch',
packages = find_packages(exclude=[]),
version = '0.0.25',
version = '0.0.26',
license='MIT',
description = 'Memory Efficient Attention - Pytorch',
long_description_content_type = 'text/markdown',
Expand Down

0 comments on commit 35559a0

Please sign in to comment.