add ability to hybridize attention with external module, for aiming to resolve state tracking issue by next week end
lucidrains committed Jan 3, 2025
1 parent f944dd7 commit 39bbb08
Showing 4 changed files with 54 additions and 10 deletions.
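
For context, the new `attn_hybrid_module` keyword on the attention layers lets an external sequence module (here a GRU) run alongside attention, with its output averaged into the attention output. Below is a minimal usage sketch mirroring the test added in this commit; the GRU choice and all hyperparameters are simply the test's values, not recommendations:

```python
import torch
from torch.nn import GRU
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 128,
        depth = 6,
        heads = 8,
        attn_dim_head = 64,
        # the hybrid module receives (batch, seq, dim) and must return a (output, aux) pair,
        # with output matching heads * attn_dim_head (here 8 * 64 = 512) so it can be averaged with attention
        attn_hybrid_module = GRU(128, 64 * 8, batch_first = True)
    )
)

x = torch.randint(0, 20000, (2, 1024))
logits = model(x)
```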
11 changes: 11 additions & 0 deletions README.md
@@ -2374,4 +2374,15 @@ ids_out, num_out, is_number_mask = model.generate(start_ids, start_nums, 17)
}
```

```bibtex
@inproceedings{anonymous2024hymba,
title = {Hymba: A Hybrid-head Architecture for Small Language Models},
author = {Anonymous},
booktitle = {Submitted to The Thirteenth International Conference on Learning Representations},
year = {2024},
url = {https://openreview.net/forum?id=A1ztozypga},
note = {under review}
}
```

*solve intelligence... then use that to solve everything else.* - Demis Hassabis
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
setup(
name = 'x-transformers',
packages = find_packages(exclude=['examples']),
version = '1.43.5',
version = '1.44.0',
license='MIT',
description = 'X-Transformers - Pytorch',
author = 'Phil Wang',
19 changes: 19 additions & 0 deletions tests/test_x_transformers.py
@@ -613,3 +613,22 @@ def test_hyper_connections(tanh):
x = torch.randint(0, 20000, (2, 1024))

model(x)

def test_hybrid():
from torch.nn import GRU

model = TransformerWrapper(
num_tokens = 20000,
max_seq_len = 1024,
attn_layers = Decoder(
dim = 128,
depth = 6,
heads = 8,
attn_dim_head = 64,
attn_hybrid_module = GRU(128, 64 * 8, batch_first = True)
)
)

x = torch.randint(0, 20000, (2, 1024))

embed = model(x)
32 changes: 23 additions & 9 deletions x_transformers/x_transformers.py
@@ -2,6 +2,7 @@
from typing import Callable

import math
from copy import deepcopy
from random import random, randrange
from packaging import version

@@ -1136,6 +1137,7 @@ def __init__(
sigmoid = False,
selective = False,
custom_attn_fn: Callable | None = None,
hybrid_module: Module | None = None,
one_kv_head = False,
kv_heads = None,
shared_kv = False,
@@ -1335,6 +1337,10 @@ def __init__(

self.attn_on_attn = on_attn

# hybrid module, in same vein as hymba https://www.arxiv.org/abs/2411.13676

self.hybrid_module = deepcopy(hybrid_module) if exists(hybrid_module) else None

# output dimension by default same as input, but can be overridden

dim_out = default(dim_out, dim)
@@ -1407,6 +1413,16 @@ def forward(
value_residual_mix = self.to_value_residual_mix(q_input)
v = v * value_residual_mix + value_residual * (1. - value_residual_mix)

# qk normalization

if self.qk_norm:
qk_l2norm = partial(l2norm, groups = self.qk_norm_groups)
q, k = map(qk_l2norm, (q, k))
scale = self.qk_norm_scale

q = q * self.qk_norm_q_scale
k = k * self.qk_norm_k_scale

# take care of caching

if exists(cache):
@@ -1427,14 +1443,6 @@ def forward(
mem_len = mem.shape[-2] if exists(mem) else 0
cached_kv = (k[..., mem_len:, :], v[..., mem_len:, :])

if self.qk_norm:
qk_l2norm = partial(l2norm, groups = self.qk_norm_groups)
q, k = map(qk_l2norm, (q, k))
scale = self.qk_norm_scale

q = q * self.qk_norm_q_scale
k = k * self.qk_norm_k_scale

if exists(rotary_pos_emb):
freqs, xpos_scale = rotary_pos_emb
q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale ** -1.) if exists(xpos_scale) else (1., 1.)
@@ -1581,6 +1589,12 @@ def forward(

out = rearrange(out, 'b h n d -> b n (h d)')

# hybrid module

if exists(self.hybrid_module):
hybrid_out, _ = self.hybrid_module(x)
out = 0.5 * (out + hybrid_out)

# alphafold2 styled gating of the values

if exists(self.to_v_gate):
@@ -2003,7 +2017,7 @@ def __init__(

# determine whether can cache kv

self.can_cache_kv = all([module.can_cache_kv for module in self.modules() if isinstance(module, Attention) ])
self.can_cache_kv = all([module.can_cache_kv for module in self.modules() if isinstance(module, Attention)])

def forward(
self,

1 comment on commit 39bbb08

@lucidrains (Owner, Author) commented on 39bbb08, Jan 3, 2025

most likely will go for deltanet (linear attention + delta update rule) w/ negative eigenvalues as the first bandaid https://arxiv.org/abs/2411.12537
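
For readers unfamiliar with the reference: the delta rule replaces the plain additive update of linear attention with an erase-then-write step, S ← S + β (v − S k) kᵀ, and letting β exceed 1 gives the state transition (I − β k kᵀ) negative eigenvalues, which is what the linked paper argues helps with state tracking. Below is a rough sketch of what such a module might look like as a drop-in for the new `attn_hybrid_module` slot, following the (output, state) return convention the GRU satisfies; the class name, projections, and β parameterization are illustrative assumptions, not code from this repo or the paper:

```python
import torch
from torch import nn
import torch.nn.functional as F

class DeltaRule(nn.Module):
    # hypothetical single-head delta-rule linear attention module (names and projections are illustrative)
    def __init__(self, dim, dim_out, dim_key = 64):
        super().__init__()
        self.to_q = nn.Linear(dim, dim_key, bias = False)
        self.to_k = nn.Linear(dim, dim_key, bias = False)
        self.to_v = nn.Linear(dim, dim_out, bias = False)
        self.to_beta = nn.Linear(dim, 1, bias = False)

    def forward(self, x, state = None):
        batch, seq_len, _ = x.shape
        q, k, v = self.to_q(x), self.to_k(x), self.to_v(x)
        k = F.normalize(k, dim = -1)

        # beta in (0, 2); beta > 1 makes the eigenvalues of (I - beta k k^T) negative
        beta = self.to_beta(x).sigmoid() * 2

        if state is None:
            state = x.new_zeros(batch, v.shape[-1], k.shape[-1])

        outs = []
        for t in range(seq_len):
            qt, kt, vt = q[:, t], k[:, t], v[:, t]
            bt = beta[:, t].unsqueeze(-1)                               # (batch, 1, 1)

            retrieved = torch.einsum('b v k, b k -> b v', state, kt)    # what the state currently stores for kt
            state = state + bt * torch.einsum('b v, b k -> b v k', vt - retrieved, kt)  # S <- S + beta (v - S k) k^T

            outs.append(torch.einsum('b v k, b k -> b v', state, qt))   # readout with the query

        return torch.stack(outs, dim = 1), state

# it could then be plugged in the same way as the GRU in the test above, e.g.
# Decoder(dim = 128, heads = 8, attn_dim_head = 64, ..., attn_hybrid_module = DeltaRule(128, 64 * 8))
```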
