From 37c412bb594303db334588b6f3921b7239f18494 Mon Sep 17 00:00:00 2001
From: aaz <41302741+alibabaz@users.noreply.github.com>
Date: Fri, 17 Dec 2021 14:16:00 -0500
Subject: [PATCH 1/4] init commit

---
 src/exabiome/nn/models/resnet.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/src/exabiome/nn/models/resnet.py b/src/exabiome/nn/models/resnet.py
index eaa7ce6..9855e49 100644
--- a/src/exabiome/nn/models/resnet.py
+++ b/src/exabiome/nn/models/resnet.py
@@ -314,6 +314,7 @@ def _forward_impl(self, x):
         x = self.layer2(x)
         x = self.layer3(x)
         x = self.layer4(x)
+
 
         if self.bottleneck is not None:
             x = self.bottleneck(x)

From 79000722e49dcb980bb470be4f1a6610342a2953 Mon Sep 17 00:00:00 2001
From: aaz <41302741+alibabaz@users.noreply.github.com>
Date: Fri, 17 Dec 2021 14:18:51 -0500
Subject: [PATCH 2/4] added attention to resnet

---
 src/exabiome/nn/models/resnet.py | 20 +++++++++++++++-----
 src/exabiome/nn/train.py         |  1 +
 2 files changed, 16 insertions(+), 5 deletions(-)

diff --git a/src/exabiome/nn/models/resnet.py b/src/exabiome/nn/models/resnet.py
index 9855e49..618903d 100644
--- a/src/exabiome/nn/models/resnet.py
+++ b/src/exabiome/nn/models/resnet.py
@@ -145,6 +145,8 @@ def __init__(self, hparams):
             hparams.simple_clf = False
         if not hasattr(hparams, 'dropout_clf'):
             hparams.dropout_clf = False
+        if not hasattr(hparams, 'attention'):
+            hparams.attention = False
 
         super(ResNet, self).__init__(hparams)
 
@@ -183,13 +185,17 @@ def __init__(self, hparams):
                                        dilate=replace_stride_with_dilation[2])
 
         n_output_channels = 512 * block.expansion
+
+        self.avgpool = nn.AdaptiveAvgPool1d(1)
+
         if hparams.bottleneck:
             self.bottleneck = FeatureReduction(n_output_channels, 64 * block.expansion)
             n_output_channels = 64 * block.expansion
         else:
             self.bottleneck = None
-
-        self.avgpool = nn.AdaptiveAvgPool1d(1)
+
+        if hparams.attention:
+            self.attention = nn.MultiheadAttention(n_output_channels, 16)
 
         if hparams.tgt_tax_lvl == 'all':
             self.fc = HierarchicalClassifier(n_output_channels, hparams.n_taxa_all)
@@ -315,15 +321,19 @@ def _forward_impl(self, x):
         x = self.layer3(x)
         x = self.layer4(x)
-
-
         if self.bottleneck is not None:
             x = self.bottleneck(x)
 
         x = self.avgpool(x)
+
+        if self.attention is not None:
+            x = x.permute(2, 0, 1)
+            x, _ = self.attention(x, x, x)
+            x = x.permute(1, 2, 0)
+
         x = torch.flatten(x, 1)
         x = self.fc(x)
 
-
+
         return x
 
     def forward(self, x):
diff --git a/src/exabiome/nn/train.py b/src/exabiome/nn/train.py
index 70a6742..ecc03fd 100644
--- a/src/exabiome/nn/train.py
+++ b/src/exabiome/nn/train.py
@@ -60,6 +60,7 @@ def get_conf_args():
         'classify': dict(action='store_true', help='run a classification problem', default=False),
         'manifold': dict(action='store_true', help='run a manifold learning problem', default=False),
         'bottleneck': dict(action='store_true', help='add bottleneck layer at the end of ResNet features', default=True),
+        'attention': dict(action='store_true', help='add an attention layer at the end of ResNet features', default=False),
         'tgt_tax_lvl': dict(choices=DeepIndexFile.taxonomic_levels, metavar='LEVEL', default='species',
                             help='the taxonomic level to predict. choices are phylum, class, order, family, genus, species'),
         'simple_clf': dict(action='store_true', help='Use a single FC layer as the classifier for ResNets', default=False),
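For context, a minimal standalone sketch (not code from this patch set) of the feature-head path PATCH 2 wires up: pooled conv features are permuted into the (seq_len, batch, embed_dim) layout that nn.MultiheadAttention expects by default, run through self-attention, and permuted back before flattening. The batch size and channel count below are illustrative assumptions; in the patch, embed_dim is n_output_channels and num_heads is 16, so the channel count must be divisible by 16.

```python
import torch
import torch.nn as nn

# Illustrative shapes only: after nn.AdaptiveAvgPool1d(1) the conv
# features are (batch, channels, 1); channels must divide by num_heads=16.
N, C, L = 2, 64, 1
attn = nn.MultiheadAttention(embed_dim=C, num_heads=16)

x = torch.randn(N, C, L)
x = x.permute(2, 0, 1)   # (L, N, C): sequence-first, MultiheadAttention's default layout
x, _ = attn(x, x, x)     # self-attention; the second return value (attention weights) is discarded
x = x.permute(1, 2, 0)   # back to (N, C, L)
x = torch.flatten(x, 1)  # (N, C*L), ready for the fc classifier
print(x.shape)           # torch.Size([2, 64])
```

Note that because avgpool runs before the attention block in the patched _forward_impl, the attention operates over a sequence of length 1.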
From 83c55beff43c2031213642e861b4823a45164dcb Mon Sep 17 00:00:00 2001
From: aaz <41302741+alibabaz@users.noreply.github.com>
Date: Fri, 17 Dec 2021 14:29:47 -0500
Subject: [PATCH 3/4] added default value for self.attention

---
 src/exabiome/nn/models/resnet.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/src/exabiome/nn/models/resnet.py b/src/exabiome/nn/models/resnet.py
index 618903d..10efb4c 100644
--- a/src/exabiome/nn/models/resnet.py
+++ b/src/exabiome/nn/models/resnet.py
@@ -196,6 +196,8 @@ def __init__(self, hparams):
 
         if hparams.attention:
             self.attention = nn.MultiheadAttention(n_output_channels, 16)
+        else:
+            self.attention = None
 
         if hparams.tgt_tax_lvl == 'all':
             self.fc = HierarchicalClassifier(n_output_channels, hparams.n_taxa_all)

From d245f8a04c063b61e7bfad4f582d1665c19bf715 Mon Sep 17 00:00:00 2001
From: aaz <41302741+alibabaz@users.noreply.github.com>
Date: Fri, 17 Dec 2021 14:38:18 -0500
Subject: [PATCH 4/4] added bottleneck to default if using attention

---
 src/exabiome/nn/models/resnet.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/src/exabiome/nn/models/resnet.py b/src/exabiome/nn/models/resnet.py
index 10efb4c..e9fe30b 100644
--- a/src/exabiome/nn/models/resnet.py
+++ b/src/exabiome/nn/models/resnet.py
@@ -185,14 +185,16 @@ def __init__(self, hparams):
                                        dilate=replace_stride_with_dilation[2])
 
         n_output_channels = 512 * block.expansion
-
-        self.avgpool = nn.AdaptiveAvgPool1d(1)
+        if hparams.attention:
+            hparams.bottleneck = True  # make sure the bottleneck is on when using attention
 
         if hparams.bottleneck:
             self.bottleneck = FeatureReduction(n_output_channels, 64 * block.expansion)
             n_output_channels = 64 * block.expansion
         else:
             self.bottleneck = None
+
+        self.avgpool = nn.AdaptiveAvgPool1d(1)
 
         if hparams.attention:
             self.attention = nn.MultiheadAttention(n_output_channels, 16)
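A small illustrative check of what PATCH 4 changes for the attention layer's width; the expansion factors are assumptions from torchvision-style ResNet blocks (BasicBlock=1, Bottleneck=4), and FeatureReduction is assumed to map 512 * expansion channels down to 64 * expansion as in the constructor above.

```python
# Channel counts feeding nn.MultiheadAttention(n_output_channels, 16),
# with and without the bottleneck (expansion factors are assumptions
# from torchvision-style ResNet blocks: BasicBlock=1, Bottleneck=4).
for expansion in (1, 4):
    full = 512 * expansion    # n_output_channels without the bottleneck
    reduced = 64 * expansion  # after FeatureReduction, as forced by this patch
    # Both satisfy embed_dim % num_heads == 0, so the forced bottleneck is
    # presumably about keeping the attention small, not about divisibility.
    assert full % 16 == 0 and reduced % 16 == 0
    print(f"expansion={expansion}: embed_dim {full} -> {reduced} "
          f"({(full // reduced) ** 2}x fewer projection parameters)")
```

Since MultiheadAttention's four projection matrices are each embed_dim x embed_dim, cutting embed_dim by 8x shrinks their parameter count by 64x.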