Skip to content

Commit

Permalink
test out flash attention in GPT
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Jul 23, 2022
1 parent 33fb78a commit 06b7775
Show file tree
Hide file tree
Showing 4 changed files with 13 additions and 8 deletions.
7 changes: 5 additions & 2 deletions memory_efficient_attention_pytorch/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import torch.nn.functional as F

from einops import rearrange
from memory_efficient_attention_pytorch import Attention
from memory_efficient_attention_pytorch import FlashAttention, Attention
from memory_efficient_attention_pytorch.reversible import ReversibleSequence

def exists(val):
Expand Down Expand Up @@ -51,6 +51,7 @@ def __init__(
heads = 8,
ff_mult = 4,
ff_chunks = 1,
use_flash_attn = True,
**kwargs
):
super().__init__()
Expand All @@ -59,10 +60,12 @@ def __init__(
self.token_emb = nn.Embedding(num_tokens, dim)
self.pos_emb = nn.Embedding(max_seq_len, dim)

attn_klass = FlashAttention if use_flash_attn else partial(Attention, memory_efficient = True)

self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(nn.ModuleList([
PreNorm(dim, Attention(dim = dim, dim_head = dim_head, heads = heads, causal = causal, **kwargs)),
PreNorm(dim, attn_klass(dim = dim, dim_head = dim_head, heads = heads, causal = causal, **kwargs)),
PreNorm(dim, FeedForward(dim = dim, mult = ff_mult, chunks = ff_chunks)),
]))

Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name = 'memory-efficient-attention-pytorch',
packages = find_packages(exclude=[]),
version = '0.0.23',
version = '0.0.25',
license='MIT',
description = 'Memory Efficient Attention - Pytorch',
long_description_content_type = 'text/markdown',
Expand Down
6 changes: 4 additions & 2 deletions tests/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,9 @@ def test_flash_attn_gradients_equal():
k = torch.randn(1, 8, 1024, 512).requires_grad_()
v = torch.randn(1, 8, 1024, 512).requires_grad_()

o = attention(q, k, v, causal = False)
mask = torch.ones(1, 1024).bool()

o = attention(q, k, v, mask = mask, causal = True)
o.sum().backward()

dq_grad = q.grad.clone()
Expand All @@ -102,7 +104,7 @@ def test_flash_attn_gradients_equal():
k.grad.zero_()
v.grad.zero_()

flash_o = FlashAttentionFunction.apply(q, k, v, None, False, 64, 64)
flash_o = FlashAttentionFunction.apply(q, k, v, mask, True, 64, 64)
flash_o.sum().backward()

flash_dq_grad = q.grad.clone()
Expand Down
6 changes: 3 additions & 3 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
LEARNING_RATE = 2e-4
VALIDATE_EVERY = 100
GENERATE_EVERY = 500
GENERATE_LENGTH = 4096
GENERATE_LENGTH = 1024
SEQ_LEN = 4096

# helpers
Expand All @@ -43,10 +43,10 @@ def decode_tokens(tokens):
depth = 6,
heads = 8,
causal = True,
memory_efficient = True,
q_bucket_size = 256,
k_bucket_size = 256,
ff_chunks = 5
ff_chunks = 5,
use_flash_attn = True
)

model = AutoregressiveWrapper(model)
Expand Down

0 comments on commit 06b7775

Please sign in to comment.