From 39bbb08cb956b2c95ac7ab648133067cc40ad2bf Mon Sep 17 00:00:00 2001
From: Phil Wang
Date: Fri, 3 Jan 2025 11:45:38 -0800
Subject: [PATCH] add ability to hybridize attention with an external module,
 aiming to resolve the state tracking issue by next weekend

---
 README.md                        | 11 +++++++++++
 setup.py                         |  2 +-
 tests/test_x_transformers.py     | 19 +++++++++++++++++++
 x_transformers/x_transformers.py | 32 +++++++++++++++++++++++---------
 4 files changed, 54 insertions(+), 10 deletions(-)

diff --git a/README.md b/README.md
index 6137949c..5aa3089e 100644
--- a/README.md
+++ b/README.md
@@ -2374,4 +2374,15 @@ ids_out, num_out, is_number_mask = model.generate(start_ids, start_nums, 17)
 }
 ```
 
+```bibtex
+@inproceedings{anonymous2024hymba,
+    title     = {Hymba: A Hybrid-head Architecture for Small Language Models},
+    author    = {Anonymous},
+    booktitle = {Submitted to The Thirteenth International Conference on Learning Representations},
+    year      = {2024},
+    url       = {https://openreview.net/forum?id=A1ztozypga},
+    note      = {under review}
+}
+```
+
 *solve intelligence... then use that to solve everything else.* - Demis Hassabis
diff --git a/setup.py b/setup.py
index 4e44950d..cc6455d8 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'x-transformers',
   packages = find_packages(exclude=['examples']),
-  version = '1.43.5',
+  version = '1.44.0',
   license='MIT',
   description = 'X-Transformers - Pytorch',
   author = 'Phil Wang',
diff --git a/tests/test_x_transformers.py b/tests/test_x_transformers.py
index 739b2619..91c9caed 100644
--- a/tests/test_x_transformers.py
+++ b/tests/test_x_transformers.py
@@ -613,3 +613,22 @@ def test_hyper_connections(tanh):
     x = torch.randint(0, 20000, (2, 1024))
 
     model(x)
+
+def test_hybrid():
+    from torch.nn import GRU
+
+    model = TransformerWrapper(
+        num_tokens = 20000,
+        max_seq_len = 1024,
+        attn_layers = Decoder(
+            dim = 128,
+            depth = 6,
+            heads = 8,
+            attn_dim_head = 64,
+            attn_hybrid_module = GRU(128, 64 * 8, batch_first = True)
+        )
+    )
+
+    x = torch.randint(0, 20000, (2, 1024))
+
+    embed = model(x)
diff --git a/x_transformers/x_transformers.py b/x_transformers/x_transformers.py
index c41fcaff..6f1df70e 100644
--- a/x_transformers/x_transformers.py
+++ b/x_transformers/x_transformers.py
@@ -2,6 +2,7 @@
 from typing import Callable
 
 import math
+from copy import deepcopy
 from random import random, randrange
 from packaging import version
 
@@ -1136,6 +1137,7 @@ def __init__(
         sigmoid = False,
         selective = False,
         custom_attn_fn: Callable | None = None,
+        hybrid_module: Module | None = None,
         one_kv_head = False,
         kv_heads = None,
         shared_kv = False,
@@ -1335,6 +1337,10 @@ def __init__(
 
         self.attn_on_attn = on_attn
 
+        # hybrid module, in same vein as hymba https://www.arxiv.org/abs/2411.13676
+
+        self.hybrid_module = deepcopy(hybrid_module) if exists(hybrid_module) else None
+
         # output dimension by default same as input, but can be overridden
 
         dim_out = default(dim_out, dim)
@@ -1407,6 +1413,16 @@ def forward(
             value_residual_mix = self.to_value_residual_mix(q_input)
             v = v * value_residual_mix + value_residual * (1. - value_residual_mix)
 
+        # qk normalization
+
+        if self.qk_norm:
+            qk_l2norm = partial(l2norm, groups = self.qk_norm_groups)
+            q, k = map(qk_l2norm, (q, k))
+            scale = self.qk_norm_scale
+
+            q = q * self.qk_norm_q_scale
+            k = k * self.qk_norm_k_scale
+
         # take care of caching
 
         if exists(cache):
@@ -1427,14 +1443,6 @@ def forward(
             mem_len = mem.shape[-2] if exists(mem) else 0
             cached_kv = (k[..., mem_len:, :], v[..., mem_len:, :])
 
-        if self.qk_norm:
-            qk_l2norm = partial(l2norm, groups = self.qk_norm_groups)
-            q, k = map(qk_l2norm, (q, k))
-            scale = self.qk_norm_scale
-
-            q = q * self.qk_norm_q_scale
-            k = k * self.qk_norm_k_scale
-
         if exists(rotary_pos_emb):
             freqs, xpos_scale = rotary_pos_emb
             q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale ** -1.) if exists(xpos_scale) else (1., 1.)
@@ -1581,6 +1589,12 @@ def forward(
 
         out = rearrange(out, 'b h n d -> b n (h d)')
 
+        # hybrid module
+
+        if exists(self.hybrid_module):
+            hybrid_out, _ = self.hybrid_module(x)
+            out = 0.5 * (out + hybrid_out)
+
         # alphafold2 styled gating of the values
 
         if exists(self.to_v_gate):
@@ -2003,7 +2017,7 @@ def __init__(
 
         # determine whether can cache kv
 
-        self.can_cache_kv = all([module.can_cache_kv for module in self.modules() if isinstance(module, Attention) ])
+        self.can_cache_kv = all([module.can_cache_kv for module in self.modules() if isinstance(module, Attention)])
 
     def forward(
         self,
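For downstream use, here is a minimal usage sketch of the new `attn_hybrid_module` keyword, mirroring the `test_hybrid` test added in this patch (the `attn_` prefix on `Decoder` kwargs forwards the module to each `Attention` layer). Per the diff, every layer keeps its own `deepcopy` of the module and averages the module's output with the attention output just before the output projection, so the GRU hidden size is chosen as `heads * attn_dim_head` (8 * 64 = 512 here) to make the shapes line up. Names beyond those in the patch (the `logits` variable, the shape comment) are illustrative only.

```python
import torch
from torch.nn import GRU

from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 128,
        depth = 6,
        heads = 8,
        attn_dim_head = 64,
        # GRU input size matches dim (128); hidden size matches heads * attn_dim_head (512),
        # since each attention layer averages the GRU output with its own pre-projection output
        attn_hybrid_module = GRU(128, 64 * 8, batch_first = True)
    )
)

x = torch.randint(0, 20000, (2, 1024))

logits = model(x)  # (2, 1024, 20000) per-token logits
```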