-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Revise package structure to ease usability
This is a user-experience-focused package. Without a cogent API, it would seem hypocritical and insincere to make another crap tool that promises the world.
- Loading branch information
1 parent
34dcc05
commit e0b12bb
Showing
12 changed files
with
110 additions
and
33 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1 @@ | ||
from .morph import morph | ||
from .nn.morph_net import * | ||
from .nn.morph import once # facilitate "morph.once" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
import torch | ||
import torch.nn as nn | ||
import torch.nn.functional as F | ||
|
||
class EasyMnist(nn.Module):
    """A minimal fully-connected MNIST classifier: 784 -> 1000 -> 30 -> 10.

    Accepts input of any shape whose trailing elements total 784 per sample
    (e.g. flat ``(N, 784)`` vectors or raw ``(N, 1, 28, 28)`` images).
    """

    def __init__(self):
        super().__init__()
        self.linear1 = nn.Linear(784, 1000)
        self.linear2 = nn.Linear(1000, 30)
        self.linear3 = nn.Linear(30, 10)

    def forward(self, x_batch: torch.Tensor) -> torch.Tensor:
        """Simple ReLU-based activations through all layers of the DNN.

        Simple and effectively deep neural network. No frills.

        Args:
            x_batch: input tensor with 784 elements per sample.

        Returns:
            A ``(N, 10)`` tensor of non-negative activations.
        """
        _input = x_batch.view(-1, 784)  # flatten to the shape linear1 expects
        # BUG FIX: previously fed the raw `x_batch` to linear1, leaving the
        # flattened `_input` unused and crashing on non-flat inputs.
        out1 = F.relu(self.linear1(_input))
        out2 = F.relu(self.linear2(out1))
        # NOTE(review): a ReLU on the final layer zeroes negative logits —
        # unusual for classification; kept to preserve existing behavior.
        out3 = F.relu(self.linear3(out2))

        return out3
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,2 @@ | ||
from .sparsify import * | ||
from .sparse import * | ||
from .widen import widen |
File renamed without changes.
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
from ._morph_net import Morph | ||
from .morph import once |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
from morph.layers.sparse import percent_waste | ||
import torch.nn as nn | ||
from torch.utils.data import DataLoader | ||
|
||
class Morph(nn.Module):
    """An encapsulation of the benefits of MorphNet, namely:

    1. Automatically shrinking and widening, to produce a new architecture
       w.r.t. layer widths.
    2. Training of the network, to match (or beat) model performance.
    3. (Remaining phases TBD — the original list was left unfinished.)
    """

    @classmethod
    def shrink_out(cls, child_layer):
        """Return a new ``nn.Linear`` with the same input width as
        ``child_layer`` but its output width scaled down by the layer's
        measured waste (``percent_waste``).
        """
        # NOTE(review): if percent_waste is small enough this can produce a
        # zero-width layer — confirm percent_waste's range upstream.
        new_out = int(child_layer.out_features * percent_waste(child_layer))
        return nn.Linear(child_layer.in_features, new_out)

    def __init__(self, net: nn.Module, epochs: int, dataloader: DataLoader):
        """Build a shrunk copy of ``net``'s child layers and record the
        training configuration.

        Args:
            net: the network whose children are shrunk layer-by-layer.
            epochs: number of training epochs for ``run_training``.
            dataloader: the data source for ``run_training``.
        """
        super().__init__()
        self.layers = nn.ModuleList([
            Morph.shrink_out(c) for c in net.children()
        ])
        # BUG FIX: epochs and dataloader were accepted but silently
        # discarded, leaving run_training with no configuration to use.
        self.epochs = epochs
        self.dataloader = dataloader

    def run_training(self):
        """Performs the managed training for this instance"""
        # TODO: implement using self.epochs and self.dataloader
        pass
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
import torch.nn as nn | ||
|
||
def once(net: nn.Module, experimental_support=False) -> nn.Module:
    """Run one pass of the (experimental) MorphNet algorithm on `net`,
    producing a new network by shrinking and re-widening its layers.

    Args:
        net: the network to morph.
        experimental_support: opt in to the experimental implementation.
            Currently has no effect — the algorithm is not yet implemented.

    Returns:
        Either `net` unchanged if `experimental_support == False`, or a
        MorphNet version of the supplied `net` once implemented. Today this
        is always `net` itself.
    """
    # TODO: run the algorithm (currently a stub that fixes only the
    # previously-truncated docstring; behavior is unchanged).
    return net
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters