diff --git a/memory_efficient_attention_pytorch/cosine_sim_flash_attention.py b/memory_efficient_attention_pytorch/cosine_sim_flash_attention.py index 9213577..6fdefc7 100644 --- a/memory_efficient_attention_pytorch/cosine_sim_flash_attention.py +++ b/memory_efficient_attention_pytorch/cosine_sim_flash_attention.py @@ -37,8 +37,6 @@ def forward(ctx, q, k, v, mask, scale, causal, q_bucket_size, k_bucket_size): o = torch.zeros_like(q) all_row_sums = torch.zeros((*q.shape[:-1], 1), device = device) - q = q * scale - if not exists(mask): mask = (None,) * math.ceil(q.shape[-2] / q_bucket_size) else: @@ -63,7 +61,7 @@ def forward(ctx, q, k, v, mask, scale, causal, q_bucket_size, k_bucket_size): for k_ind, (kc, vc) in enumerate(col_splits): k_start_index = k_ind * k_bucket_size - attn_weights = einsum('... i d, ... j d -> ... i j', qc, kc) + attn_weights = einsum('... i d, ... j d -> ... i j', qc, kc) * scale if exists(row_mask): attn_weights.masked_fill_(~row_mask, max_neg_value) @@ -129,14 +127,13 @@ def backward(ctx, do): for k_ind, (kc, vc, dkc, dvc) in enumerate(col_splits): k_start_index = k_ind * k_bucket_size - qc_scaled = qc * scale - attn_weights = einsum('... i d, ... j d -> ... i j', qc_scaled, kc) + attn_weights = einsum('... i d, ... j d -> ... i j', qc, kc) * scale if causal and q_start_index < (k_start_index + k_bucket_size - 1): causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype = torch.bool, device = device).triu(q_start_index - k_start_index + 1) attn_weights.masked_fill_(causal_mask, max_neg_value) - exp_attn_weights = torch.exp(attn_weights) + exp_attn_weights = torch.exp(attn_weights - scale) if exists(row_mask): exp_attn_weights.masked_fill_(~row_mask, 0.) diff --git a/setup.py b/setup.py index c8a17e8..bf766cc 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ setup( name = 'memory-efficient-attention-pytorch', packages = find_packages(exclude=[]), - version = '0.0.25', + version = '0.0.26', license='MIT', description = 'Memory Efficient Attention - Pytorch', long_description_content_type = 'text/markdown',