Skip to content

Commit

Permalink
add torch version to test
Browse files Browse the repository at this point in the history
  • Loading branch information
yardeny-sony committed Sep 25, 2024
1 parent c6771fc commit 587b7d3
Showing 1 changed file with 13 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@

from tests.pytorch_tests.model_tests.base_pytorch_test import BasePytorchTest
from torch import nn

from packaging import version
import torch

class ScaledDotProductAttentionNet(nn.Module):
def __init__(self, dropout_p=0.0, scale=None, attn_mask=None, is_causal=False):
Expand Down Expand Up @@ -57,10 +58,17 @@ def __init__(self, unit_test, batch_size: int, q_and_k_embd_size: int, v_embd_si
self.is_causal = is_causal

def create_feature_network(self, input_shape):
    """Build the scaled-dot-product-attention network under test.

    The ``scale`` argument of ``torch.nn.functional.scaled_dot_product_attention``
    only exists from torch 2.1 onward, so it is forwarded to the net only when
    the installed torch version supports it.

    :param input_shape: Unused here; required by the base-test interface.
    :return: A ScaledDotProductAttentionNet configured from this test's fields.
    """
    kwargs = dict(dropout_p=self.dropout_p,
                  attn_mask=self.attn_mask,
                  is_causal=self.is_causal)
    # 'scale' was added to scaled_dot_product_attention in torch 2.1;
    # older torch versions don't have the scale argument.
    if version.parse(torch.__version__) >= version.parse("2.1"):
        kwargs["scale"] = self.scale
    return ScaledDotProductAttentionNet(**kwargs)

def create_inputs_shape(self):
q_shape = [self.batch_size, self.target_seq_len, self.q_and_k_embd_size]
Expand Down

0 comments on commit 587b7d3

Please sign in to comment.