-
Notifications
You must be signed in to change notification settings - Fork 10
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
16 changed files
with
428 additions
and
127 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1 @@ | ||
from .evaluator import Evaluator | ||
from .evaluator import * |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,4 @@ | ||
class Metrics:
    """Return a set of performance metrics for a collection of predictions and target values."""
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
class Framework:
    """Base class for framework adapters that turn a batch generator into a loader."""

    def create_loader(self, iterator, **kwargs):
        """Wrap ``iterator`` in a framework-specific data loader.

        Args:
            iterator: Zero-argument callable returning a fresh iterator over
                batches.
            **kwargs: Framework-specific loader options. Accepted here so the
                base signature matches how subclasses and callers in this
                change set invoke it (``create_loader(generator, **kwargs)``).

        Returns:
            ``None`` in this base implementation; subclasses are expected to
            override it and return a loader object.
        """
        return None
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
class Representation:
    """Marker base class for representation transforms; defines no behaviour itself."""
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1 @@ | ||
from .splitter import Splitter | ||
from .splitter import * |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
class MySplitter:
    """Randomly assign dataset positions to train/test/val in an 80/10/10 ratio.

    The split is reproducible for a given ``seed``. Sizes are rounded the same
    way as the original two-stage split: the hold-out part is ``ceil(n / 5)``
    and the validation part is ``ceil(holdout / 2)``.
    """

    def __init__(self, seed=None) -> None:
        # Local import: this module has no import block of its own.
        import numpy as np

        # `np.random.rng` does not exist; `default_rng` is the Generator factory.
        self.rng = np.random.default_rng(seed)

    def fit(self, dataset):
        """Compute the split for ``dataset`` and store it in ``self.lookup``.

        Args:
            dataset: Any sized collection; only ``len(dataset)`` is used.
        """
        n = len(dataset)
        order = self.rng.permutation(n)
        n_holdout = -(-n // 5)  # ceil(0.2 * n) without float rounding
        n_val = -(-n_holdout // 2)  # ceil(0.5 * holdout)
        n_train = n - n_holdout
        n_test = n_holdout - n_val
        train = order[:n_train]
        test = order[n_train:n_train + n_test]
        val = order[n_train + n_test:]
        # Keys are plain ints so lookups work with ordinary Python indices.
        self.lookup = {
            **{int(index): "train" for index in train},
            **{int(index): "test" for index in test},
            **{int(index): "val" for index in val},
        }

    def assign(self, index, protein):
        """Return the split name for position ``index``; ``protein`` is unused.

        Raises:
            KeyError: If ``index`` was not part of the fitted dataset.
        """
        return self.lookup[index]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,8 @@ | ||
class Split:
    """
    Abstract class for selecting train/val/test indices given a dataset.
    """

    @property
    def hash(self):
        """Identifier for this split: the name of the concrete class."""
        return type(self).__name__
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,35 +1,125 @@ | ||
from proteinshake.frontend.dataset import Dataset | ||
from proteinshake.frontend.splitters import Splitter | ||
from typing import Union | ||
import numpy as np | ||
import os | ||
from pathlib import Path | ||
from functools import partial | ||
from proteinshake.frontend.splitters import Split | ||
from proteinshake.frontend.targets import Target | ||
from proteinshake.frontend.evaluators import Evaluator | ||
from proteinshake.frontend.transforms import Transform | ||
from proteinshake.frontend.evaluators import Metrics | ||
from proteinshake.frontend.transforms import Transform, Compose | ||
from proteinshake.util import amino_acid_alphabet, sharded, save_shards, load, warn | ||
|
||
|
||
class Task:
    """Bundle a dataset with a split, target, metrics and augmentation.

    Subclasses are expected to set the class attributes ``split``, ``target``,
    ``metrics`` and ``augmentation`` to classes; ``__init__`` instantiates them
    with the corresponding ``*_kwargs`` and replaces the attributes with the
    instances.
    """

    dataset: str = ""
    split: Split = None
    target: Target = None
    metrics: Metrics = None
    augmentation: Transform = None

    def __init__(
        self,
        root: Union[str, None] = None,  # default path in ~/.proteinshake
        shard_size: int = 1024,
        split_kwargs: Union[dict, None] = None,
        target_kwargs: Union[dict, None] = None,
        metrics_kwargs: Union[dict, None] = None,
        augmentation_kwargs: Union[dict, None] = None,
    ) -> None:
        """Create the task directory and instantiate the task modules.

        Args:
            root: Base directory. Falls back to ``$PROTEINSHAKE_ROOT`` and then
                to ``~/.proteinshake``. The class name is appended to it.
            shard_size: Number of samples per shard written by ``transform``.
            split_kwargs: Keyword arguments for the ``split`` class.
            target_kwargs: Keyword arguments for the ``target`` class.
            metrics_kwargs: Keyword arguments for the ``metrics`` class.
            augmentation_kwargs: Keyword arguments for the ``augmentation``
                class.
        """
        # create root
        if root is None:
            root = os.environ.get("PROTEINSHAKE_ROOT", "~/.proteinshake")
        # expanduser: otherwise "~" is treated as a literal directory name
        # relative to the current working directory.
        root = Path(root).expanduser() / self.__class__.__name__
        os.makedirs(root, exist_ok=True)
        self.root = root
        self.shard_size = shard_size

        # assign task modules (``None`` defaults avoid shared mutable dicts)
        self.split = self.split(**(split_kwargs or {}))
        self.target = self.target(**(target_kwargs or {}))
        self.metrics = self.metrics(**(metrics_kwargs or {}))
        self.augmentation = self.augmentation(**(augmentation_kwargs or {}))

    @property
    def proteins(self):
        """Yield 100 randomly generated protein dicts (fixed seed 42).

        Each dict has the keys ``ID``, ``coords``, ``sequence``, ``label`` and
        ``split``. NOTE(review): this looks like placeholder data standing in
        for a real dataset iterator -- confirm.
        """
        # return dataset iterator
        rng = np.random.default_rng(42)
        return (
            {
                "ID": f"protein_{i}",
                "coords": rng.integers(0, 100, size=(300, 3)),
                "sequence": "".join(
                    rng.choice(list(amino_acid_alphabet)) for _ in range(300)
                ),
                "label": rng.random() * 100,
                "split": rng.choice(["train", "test", "val"]),
            }
            for i in range(100)
        )

    def _shard_dir(self, split=None):
        """Return the shard directory for ``split`` under the current transform."""
        base = self.root / self.split.hash / self.transform.hash
        if split is not None:
            # One directory per partition; otherwise every partition would be
            # written to (and read from) the same location.
            base = base / str(split)
        return base / "shards"

    def transform(self, *transforms) -> "Task":
        """Fit the transform pipeline and write transformed shards to disk.

        After this call ``self.transform`` is the composed transform object
        (it shadows this method on the instance), and a ``<name>_loader``
        attribute exists for every partition returned by the split.

        Args:
            *transforms: Transforms appended after the task's augmentation.

        Returns:
            The task itself, to allow chaining.
        """
        Xy = self.target(self.proteins)
        partitions = self.split(Xy)  # returns dict of generators[(X,...),y]
        self.transform = Compose(*[self.augmentation, *transforms])
        # cache from here
        # NOTE(review): if the partitions are one-shot generators, fitting on
        # partitions["train"] may exhaust it before it is sharded below --
        # confirm against the Split implementation.
        self.transform.fit(partitions["train"])
        for name, partition in partitions.items():
            shards = sharded(partition, shard_size=self.shard_size)
            data_transformed = (
                self.transform.deterministic_transform(shard) for shard in shards
            )
            save_shards(data_transformed, self._shard_dir(name))
            setattr(self, f"{name}_loader", partial(self.loader, split=name))
        return self

    def loader(
        self,
        split=None,
        batch_size=None,
        shuffle: bool = False,
        random_seed: Union[int, None] = None,
        **kwargs,
    ):
        """Create a framework loader that streams batches from saved shards.

        Args:
            split: Name of the partition to load (as used by ``transform``).
            batch_size: Samples per batch; defaults to ``shard_size``.
            shuffle: Shuffle the shard order on every pass.
            random_seed: Seed for the shard-order shuffle.
            **kwargs: Forwarded to ``self.transform.create_loader``.

        Raises:
            ValueError: If ``batch_size`` is not a positive number.
        """
        if batch_size is None:
            batch_size = self.shard_size
        if batch_size <= 0:
            raise ValueError("batch_size must be a positive integer.")
        rng = np.random.default_rng(random_seed)
        path = self._shard_dir(split)
        shard_index = load(path / "index.npy")
        if self.shard_size % batch_size != 0 and batch_size % self.shard_size != 0:
            warn(
                "batch_size is not a multiple of shard_size. This causes inefficient data loading."
            )

        def generator():
            if shuffle:
                rng.shuffle(shard_index)
            X_batch, y_batch = [], []
            for i in shard_index:
                shard_X, shard_y = load(path / f"{i}.pkl")
                offset = 0
                # Drain the whole shard; leftovers carry over into the next
                # batch instead of being discarded.
                while offset < len(shard_X):
                    take = batch_size - len(X_batch)
                    X_batch.extend(shard_X[offset:offset + take])
                    y_batch.extend(shard_y[offset:offset + take])
                    offset += take
                    if len(X_batch) == batch_size:
                        yield self.transform.stochastic_transform(
                            (np.asarray(X_batch), np.asarray(y_batch))
                        )
                        X_batch, y_batch = [], []
            if X_batch:
                # final, possibly smaller, batch
                yield self.transform.stochastic_transform(
                    (np.asarray(X_batch), np.asarray(y_batch))
                )

        return self.transform.create_loader(generator, **kwargs)

    def evaluate(self, y_true, y_pred):
        """Inverse-transform targets and predictions, then compute the metrics."""
        y_true = self.transform.inverse_transform(y_true)
        y_pred = self.transform.inverse_transform(y_pred)
        return self.metrics(y_true, y_pred)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,7 +1,17 @@ | ||
import torch | ||
from ..transform import FrameworkTransform | ||
import numpy as np | ||
from torch.utils.data import DataLoader, IterableDataset | ||
from ..transform import Transform | ||
from proteinshake.frontend.framework import Framework | ||
|
||
|
||
class TorchFrameworkTransform(Framework, Transform):
    """Framework adapter that serves batches through a torch ``DataLoader``."""

    def transform(self, X):
        """Return ``X`` unchanged."""
        return X

    def create_loader(self, iterator, **kwargs):
        """Wrap a batch generator in a ``DataLoader``.

        Args:
            iterator: Zero-argument callable returning a fresh iterator; it is
                called once per ``__iter__`` of the underlying dataset.
            **kwargs: Forwarded to ``torch.utils.data.DataLoader``.

        Returns:
            A ``DataLoader`` over an ``IterableDataset`` backed by ``iterator``.
        """

        class _IteratorDataset(IterableDataset):
            def __iter__(self):
                return iterator()

        # The iterator is assumed to yield complete batches already (that is
        # how Task.loader uses it), so automatic batching is switched off
        # unless the caller asks for it explicitly; the DataLoader default of
        # batch_size=1 would add an extra leading dimension to every batch.
        kwargs.setdefault("batch_size", None)
        return DataLoader(_IteratorDataset(), **kwargs)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
from proteinshake.frontend.transforms import Transform | ||
|
||
|
||
class MyLabelTransform(Transform):
    """Min-max scale labels using the range observed on the training split."""

    def fit(self, dataset):
        """Record the minimum and maximum training label.

        Raises:
            ValueError: If the training split has no proteins (from ``min``).
        """
        labels = [p["label"] for p in dataset.split("train").proteins]
        self.min, self.max = min(labels), max(labels)

    def _span(self):
        """Return ``max - min``, or 1 when all training labels are equal."""
        span = self.max - self.min
        # A zero range would otherwise divide by zero in ``transform``.
        return span if span != 0 else 1

    def transform(self, X, y, index):
        """Scale ``y`` by the fitted range; ``X`` and ``index`` pass through."""
        y_transformed = (y - self.min) / self._span()
        return X, y_transformed, index

    def inverse_transform(self, y):
        """Map a scaled label back to the original label range."""
        return y * self._span() + self.min
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,8 @@ | ||
from ..transform import RepresentationTransform | ||
from ..transform import Transform | ||
from proteinshake.frontend.representation import Representation | ||
import numpy as np | ||
|
||
|
||
class PointRepresentationTransform(Representation, Transform):
    """Represent each protein by its coordinate array."""

    def transform(self, X):
        """Stack the ``"coords"`` entry of every protein in ``X`` into one array."""
        coords = [protein["coords"] for protein in X]
        return np.asarray(coords)
Oops, something went wrong.