flexibly handle hybrid module outputs
lucidrains committed Jan 5, 2025
1 parent b81646f commit c51ecd3
Showing 3 changed files with 38 additions and 5 deletions.
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'x-transformers',
   packages = find_packages(exclude=['examples']),
-  version = '1.44.0',
+  version = '1.44.2',
   license='MIT',
   description = 'X-Transformers - Pytorch',
   author = 'Phil Wang',
19 changes: 17 additions & 2 deletions tests/test_x_transformers.py
@@ -617,7 +617,7 @@ def test_hyper_connections(tanh):
 def test_hybrid():
     from torch.nn import GRU

-    model = TransformerWrapper(
+    dec = TransformerWrapper(
         num_tokens = 20000,
         max_seq_len = 1024,
         attn_layers = Decoder(
@@ -631,4 +631,19 @@ def test_hybrid():

     x = torch.randint(0, 20000, (2, 1024))

-    embed = model(x)
+    embed = dec(x)
+
+    enc = TransformerWrapper(
+        num_tokens = 20000,
+        max_seq_len = 1024,
+        attn_layers = Encoder(
+            dim = 128,
+            depth = 6,
+            heads = 8,
+            attn_dim_head = 64,
+            attn_hybrid_module = GRU(128, 64 * 4, batch_first = True, bidirectional = True)
+        )
+    )
+
+    mask = torch.randint(0, 2, (2, 1024)).bool()
+    embed = enc(x, mask = mask)
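
The new encoder test passes `mask` to the wrapper but not into the GRU itself. Below is a minimal sketch (not part of this commit) of how the `hybrid_mask_kwarg` added in x_transformers.py further down could be wired so the hybrid module receives the padding mask. The `MaskedGRU` wrapper is hypothetical, and `attn_hybrid_mask_kwarg` assumes the same `attn_` prefix routing that `attn_hybrid_module` uses above.

# Sketch only: MaskedGRU is a hypothetical wrapper, and attn_hybrid_mask_kwarg
# assumes the usual attn_ prefix routing into the attention layer.

import torch
from torch.nn import GRU, Module
from x_transformers import TransformerWrapper, Encoder

class MaskedGRU(Module):
    def __init__(self, dim, hidden):
        super().__init__()
        self.gru = GRU(dim, hidden, batch_first = True, bidirectional = True)

    def forward(self, x, mask = None):
        if mask is not None:
            x = x * mask[..., None]   # zero out padded positions before the recurrence
        out, _ = self.gru(x)          # (batch, seq, hidden * 2), here 512 = heads * dim_head
        return out                    # a bare tensor is fine, since outputs are tree-flattened

enc = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Encoder(
        dim = 128,
        depth = 6,
        heads = 8,
        attn_dim_head = 64,
        attn_hybrid_module = MaskedGRU(128, 64 * 4),
        attn_hybrid_mask_kwarg = 'mask'   # forwarded as mask = ... to MaskedGRU.forward
    )
)

x = torch.randint(0, 20000, (2, 1024))
mask = torch.randint(0, 2, (2, 1024)).bool()
embed = enc(x, mask = mask)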
22 changes: 20 additions & 2 deletions x_transformers/x_transformers.py
@@ -7,10 +7,11 @@
 from packaging import version

 import torch
-from torch.amp import autocast
 import torch.nn.functional as F
 from torch import nn, einsum, Tensor
+from torch.utils._pytree import tree_flatten
 from torch.nn import Module, ModuleList, ModuleDict
+from torch.amp import autocast

 from functools import partial, wraps
 from collections import namedtuple
@@ -1138,6 +1139,7 @@ def __init__(
         selective = False,
         custom_attn_fn: Callable | None = None,
         hybrid_module: Module | None = None,
+        hybrid_mask_kwarg: str | None = None,
         one_kv_head = False,
         kv_heads = None,
         shared_kv = False,
@@ -1341,6 +1343,8 @@

         self.hybrid_module = deepcopy(hybrid_module) if exists(hybrid_module) else None

+        self.hybrid_mask_kwarg = hybrid_mask_kwarg # for bidirectional, can forward `mask` into the hybrid module and let it handle variable lengths
+
         # output dimension by default same as input, but can be overridden

         dim_out = default(dim_out, dim)
@@ -1592,7 +1596,21 @@ def forward(
         # hybrid module

         if exists(self.hybrid_module):
-            hybrid_out, _ = self.hybrid_module(x)
+
+            # hybrid input
+
+            hybrid_forward_kwargs = dict()
+
+            if not self.causal and exists(self.hybrid_mask_kwarg):
+                hybrid_forward_kwargs = {self.hybrid_mask_kwarg: mask}
+
+            # hybrid forward
+
+            hybrid_outputs = self.hybrid_module(x, **hybrid_forward_kwargs)
+
+            # handle hybrid out
+
+            (hybrid_out, *rest_hybrid_outs), _ = tree_flatten(hybrid_outputs)
             out = 0.5 * (out + hybrid_out)

         # alphafold2 styled gating of the values
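
The switch to `tree_flatten` is what makes the output handling flexible: the old line assumed an RNN-style `(output, hidden)` tuple, while the new code takes the first flattened tensor leaf of whatever the hybrid module returns. A small standalone illustration (shapes are arbitrary, not taken from the library):

# Illustration of the flattening behaviour relied on above (arbitrary shapes).
import torch
from torch.utils._pytree import tree_flatten

out = torch.randn(2, 1024, 512)   # the hybrid module's main output
h_n = torch.randn(2, 2, 256)      # e.g. a GRU's final hidden state

# RNN-style (output, hidden) tuple, the only shape the old code accepted
(hybrid_out, *rest), _ = tree_flatten((out, h_n))
assert hybrid_out is out and rest[0] is h_n

# a bare tensor now also works
(hybrid_out, *rest), _ = tree_flatten(out)
assert hybrid_out is out and len(rest) == 0

# as does any nested structure whose first leaf is the output tensor
(hybrid_out, *rest), _ = tree_flatten((out, {'hidden': h_n}))
assert hybrid_out is out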
