ModernBert: reuse GemmaRotaryEmbedding via modular + Integration tests #35459

Open · wants to merge 2 commits into main
59 changes: 45 additions & 14 deletions src/transformers/models/modernbert/modeling_modernbert.py
@@ -30,6 +30,7 @@
from ...activations import ACT2FN
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
from ...modeling_outputs import BaseModelOutput, MaskedLMOutput, SequenceClassifierOutput, TokenClassifierOutput
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
from ...modeling_utils import PreTrainedModel
from ...utils import (
add_code_sample_docstrings,
@@ -235,30 +236,62 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:


class ModernBertRotaryEmbedding(nn.Module):
-def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
def __init__(self, config: ModernBertConfig, dim: int, base: float, device: Optional[torch.device] = None):
super().__init__()
self.rope_kwargs = {"dim": dim, "base": base}
# BC: "rope_type" was originally "type"
if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
else:
self.rope_type = "default"
self.max_seq_len_cached = config.max_position_embeddings
self.original_max_seq_len = config.max_position_embeddings
self.config = None
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs)
self.register_buffer("inv_freq", inv_freq, persistent=False)
self.original_inv_freq = self.inv_freq

def _dynamic_frequency_update(self, position_ids, device):
"""
dynamic RoPE layers should recompute `inv_freq` in the following situations:
1 - growing beyond the cached sequence length (allow scaling)
2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
"""
seq_len = torch.max(position_ids) + 1
if seq_len > self.max_seq_len_cached: # growth
inv_freq, self.attention_scaling = self.rope_init_fn(
self.config, device, seq_len=seq_len, **self.rope_kwargs
)
self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
self.max_seq_len_cached = seq_len

-self.dim = dim
-self.max_position_embeddings = max_position_embeddings
-self.base = base
-inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float() / self.dim))
-self.register_buffer("inv_freq", tensor=inv_freq, persistent=False)
if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
self.max_seq_len_cached = self.original_max_seq_len

@torch.no_grad()
-def forward(self, x, position_ids, seq_len=None):
-# x: [bs, num_attention_heads, seq_len, head_size]
-self.inv_freq.to(x.device)
def forward(self, x, position_ids):
if "dynamic" in self.rope_type:
self._dynamic_frequency_update(position_ids, device=x.device)

# Core RoPE block
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
position_ids_expanded = position_ids[:, None, :].float()
-# Force float32 since bfloat16 loses precision on long contexts
-# See https://github.com/huggingface/transformers/pull/29285
# Force float32 (see https://github.com/huggingface/transformers/pull/29285)
device_type = x.device.type
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
with torch.autocast(device_type=device_type, enabled=False):
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos()
sin = emb.sin()

# Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention
cos = cos * self.attention_scaling
sin = sin * self.attention_scaling

return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)


@@ -462,9 +495,7 @@ def __init__(self, config: ModernBertConfig, layer_id: Optional[int] = None):
dim=self.head_dim, max_seqlen=max_position_embeddings, base=rope_theta
)
else:
-self.rotary_emb = ModernBertRotaryEmbedding(
-dim=self.head_dim, max_position_embeddings=max_position_embeddings, base=rope_theta
-)
self.rotary_emb = ModernBertRotaryEmbedding(config=config, dim=self.head_dim, base=rope_theta)

self.Wo = nn.Linear(config.hidden_size, config.hidden_size, bias=config.attention_bias)
self.out_drop = nn.Dropout(config.attention_dropout) if config.attention_dropout > 0.0 else nn.Identity()
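For context on how the cos/sin pair returned by this module is consumed: ModernBert applies it in the attention layers via the rotate-half convention, reusing apply_rotary_pos_emb imported from the Gemma modeling code (see the modular file below). A minimal sketch of that helper pair, written here only for illustration (the canonical version lives in modeling_gemma.py):

```python
import torch


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
    """Apply the rotary embedding to query/key states.

    cos/sin come from the rotary module with shape [batch, seq_len, head_dim];
    unsqueeze_dim=1 broadcasts them over the head dimension of q/k
    ([batch, num_heads, seq_len, head_dim]).
    """
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed
```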
37 changes: 7 additions & 30 deletions src/transformers/models/modernbert/modular_modernbert.py
@@ -40,7 +40,7 @@
logging,
)
from ...utils.import_utils import is_triton_available
-from ..gemma.modeling_gemma import apply_rotary_pos_emb
from ..gemma.modeling_gemma import GemmaRotaryEmbedding, apply_rotary_pos_emb


if is_flash_attn_2_available():
@@ -493,32 +493,11 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
return self.Wo(self.drop(self.act(input) * gate))


-class ModernBertRotaryEmbedding(nn.Module):
-def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
-super().__init__()
-
-self.dim = dim
-self.max_position_embeddings = max_position_embeddings
-self.base = base
-inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float() / self.dim))
-self.register_buffer("inv_freq", tensor=inv_freq, persistent=False)
-
-@torch.no_grad()
-def forward(self, x, position_ids, seq_len=None):
-# x: [bs, num_attention_heads, seq_len, head_size]
-self.inv_freq.to(x.device)
-inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
-position_ids_expanded = position_ids[:, None, :].float()
-# Force float32 since bfloat16 loses precision on long contexts
-# See https://github.com/huggingface/transformers/pull/29285
-device_type = x.device.type
-device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
-with torch.autocast(device_type=device_type, enabled=False):
-freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
-emb = torch.cat((freqs, freqs), dim=-1)
-cos = emb.cos()
-sin = emb.sin()
-return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
class ModernBertRotaryEmbedding(GemmaRotaryEmbedding):
def __init__(self, config: ModernBertConfig, dim: int, base: float, device: Optional[torch.device] = None):
super().__init__(config=config, device=device)
self.rope_kwargs = {"dim": dim, "base": base}
self.config = None


def eager_attention_forward(
@@ -687,9 +666,7 @@ def __init__(self, config: ModernBertConfig, layer_id: Optional[int] = None):
dim=self.head_dim, max_seqlen=max_position_embeddings, base=rope_theta
)
else:
-self.rotary_emb = ModernBertRotaryEmbedding(
-dim=self.head_dim, max_position_embeddings=max_position_embeddings, base=rope_theta
-)
self.rotary_emb = ModernBertRotaryEmbedding(config=config, dim=self.head_dim, base=rope_theta)

self.Wo = nn.Linear(config.hidden_size, config.hidden_size, bias=config.attention_bias)
self.out_drop = nn.Dropout(config.attention_dropout) if config.attention_dropout > 0.0 else nn.Identity()
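As a quick sanity check of the new config-driven constructor, the sketch below instantiates the rotary module directly and inspects the returned cos/sin shapes. It is a sketch only, assuming a transformers checkout containing this PR; head_dim and the use of config.global_rope_theta mirror what ModernBertAttention passes in.

```python
import torch
from transformers import ModernBertConfig
from transformers.models.modernbert.modeling_modernbert import ModernBertRotaryEmbedding

config = ModernBertConfig()  # default config; a real checkpoint config works the same way
head_dim = config.hidden_size // config.num_attention_heads

# New signature: rope_scaling / max_position_embeddings are read from the config,
# while dim and base are still passed by the attention layer (global vs. local RoPE theta).
rope = ModernBertRotaryEmbedding(config=config, dim=head_dim, base=config.global_rope_theta)

dummy = torch.randn(1, 8, config.num_attention_heads, head_dim)  # only dtype/device are used
position_ids = torch.arange(8).unsqueeze(0)
cos, sin = rope(dummy, position_ids)
print(cos.shape, sin.shape)  # torch.Size([1, 8, head_dim]) for both
```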
134 changes: 130 additions & 4 deletions tests/models/modernbert/test_modeling_modernbert.py
@@ -16,8 +16,9 @@
import unittest

import pytest
from packaging import version

-from transformers import ModernBertConfig, is_torch_available
from transformers import AutoTokenizer, ModernBertConfig, is_torch_available
from transformers.models.auto import get_values
from transformers.testing_utils import (
CaptureLogger,
@@ -362,6 +363,131 @@ def test_flash_attn_2_conversion(self):

@require_torch
class ModernBertModelIntegrationTest(unittest.TestCase):
"""
These still need to be written, once public models are available.
"""
@slow
def test_inference_masked_lm(self):
if version.parse(torch.__version__) < version.parse("2.4.0"):
self.skipTest(reason="This test requires torch >= 2.4 to run.")

model = ModernBertForMaskedLM.from_pretrained(
"answerdotai/ModernBERT-base", reference_compile=False, attn_implementation="sdpa"
)
tokenizer = AutoTokenizer.from_pretrained("answerdotai/ModernBERT-base")

inputs = tokenizer("Hello World!", return_tensors="pt")
with torch.no_grad():
output = model(**inputs)[0]
expected_shape = torch.Size((1, 5, 50368))
self.assertEqual(output.shape, expected_shape)

# compare the actual values for a slice.
expected_slice = torch.tensor(
[[[3.8387, -0.2017, 12.2839], [3.6300, 0.6869, 14.7123], [-5.1137, -3.8122, 11.9874]]]
)
self.assertTrue(torch.allclose(output[:, :3, :3], expected_slice, atol=1e-4))

@slow
def test_inference_no_head(self):
if version.parse(torch.__version__) < version.parse("2.4.0"):
self.skipTest(reason="This test requires torch >= 2.4 to run.")

model = ModernBertModel.from_pretrained(
"answerdotai/ModernBERT-base", reference_compile=False, attn_implementation="sdpa"
)
tokenizer = AutoTokenizer.from_pretrained("answerdotai/ModernBERT-base")

inputs = tokenizer("Hello World!", return_tensors="pt")
with torch.no_grad():
output = model(**inputs)[0]
expected_shape = torch.Size((1, 5, 768))
self.assertEqual(output.shape, expected_shape)

# compare the actual values for a slice.
expected_slice = torch.tensor(
[[[0.3151, -0.6417, -0.7027], [-0.7834, -1.5810, 0.4576], [1.0614, -0.7268, -0.0871]]]
)
self.assertTrue(torch.allclose(output[:, :3, :3], expected_slice, atol=1e-4))

@slow
def test_inference_token_classification(self):
if version.parse(torch.__version__) < version.parse("2.4.0"):
self.skipTest(reason="This test requires torch >= 2.4 to run.")

model = ModernBertForTokenClassification.from_pretrained(
"hf-internal-testing/tiny-random-ModernBertForTokenClassification",
reference_compile=False,
attn_implementation="sdpa",
)
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-ModernBertForTokenClassification")

inputs = tokenizer("Hello World!", return_tensors="pt")
with torch.no_grad():
output = model(**inputs)[0]
expected_shape = torch.Size((1, 5, 2))
self.assertEqual(output.shape, expected_shape)

expected = torch.tensor(
[[[2.0159, 4.6569], [-0.9430, 3.1595], [-3.8770, 3.2653], [1.5752, 4.5167], [-1.6939, 1.2524]]]
)
self.assertTrue(torch.allclose(output, expected, atol=1e-4))

@slow
def test_inference_sequence_classification(self):
if version.parse(torch.__version__) < version.parse("2.4.0"):
self.skipTest(reason="This test requires torch >= 2.4 to run.")

model = ModernBertForSequenceClassification.from_pretrained(
"hf-internal-testing/tiny-random-ModernBertForSequenceClassification",
reference_compile=False,
attn_implementation="sdpa",
)
tokenizer = AutoTokenizer.from_pretrained(
"hf-internal-testing/tiny-random-ModernBertForSequenceClassification"
)

inputs = tokenizer("Hello World!", return_tensors="pt")
with torch.no_grad():
output = model(**inputs)[0]
expected_shape = torch.Size((1, 2))
self.assertEqual(output.shape, expected_shape)

expected = torch.tensor([[1.6466, 4.5662]])
self.assertTrue(torch.allclose(output, expected, atol=1e-4))

@slow
def test_export(self):
if version.parse(torch.__version__) < version.parse("2.4.0"):
self.skipTest(reason="This test requires torch >= 2.4 to run.")

bert_model = "answerdotai/ModernBERT-base"
device = "cpu"
attn_implementation = "sdpa"
max_length = 512

tokenizer = AutoTokenizer.from_pretrained(bert_model)
inputs = tokenizer(
"the man worked as a [MASK].",
return_tensors="pt",
padding="max_length",
max_length=max_length,
)

model = ModernBertForMaskedLM.from_pretrained(
bert_model,
device_map=device,
attn_implementation=attn_implementation,
)

logits = model(**inputs).logits
eg_predicted_mask = tokenizer.decode(logits[0, 6].topk(5).indices)
self.assertEqual(eg_predicted_mask.split(), ["lawyer", "mechanic", "teacher", "doctor", "waiter"])

exported_program = torch.export.export(
model,
args=(inputs["input_ids"],),
kwargs={"attention_mask": inputs["attention_mask"]},
strict=True,
)

result = exported_program.module().forward(inputs["input_ids"], inputs["attention_mask"])
ep_predicted_mask = tokenizer.decode(result.logits[0, 6].topk(5).indices)
self.assertEqual(eg_predicted_mask, ep_predicted_mask)
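The integration tests above are all decorated with @slow, so a default pytest run skips them. A small sketch for running just this test class locally, assuming the usual RUN_SLOW environment gate of the transformers test suite and a checkout where the test file lives at the path below:

```python
import os
import subprocess

# @slow tests are skipped unless RUN_SLOW is set before the test process starts.
env = dict(os.environ, RUN_SLOW="1")
subprocess.run(
    [
        "python", "-m", "pytest", "-v",
        "tests/models/modernbert/test_modeling_modernbert.py",
        "-k", "ModernBertModelIntegrationTest",
    ],
    env=env,
    check=True,
)
```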