frontend draft
timkucera committed Jan 7, 2024
1 parent d65469a commit e22bafb
Showing 16 changed files with 428 additions and 127 deletions.
51 changes: 45 additions & 6 deletions proteinshake/frontend/dataset.py
@@ -1,14 +1,15 @@
import numpy as np
import random
from typing import Generic
from proteinshake.util import amino_acid_alphabet
from proteinshake.util import amino_acid_alphabet, save
from proteinshake.transforms import Compose
from functools import partial


class Dataset:
def __init__(
self,
path: str = "",
version: str = "latest",
shard_size: int = None,
shuffle: bool = False,
random_seed: int = 42,
@@ -35,10 +36,48 @@ def __init__(
def proteins(self):
return iter(self.dummy_proteins)

def apply(self, *transforms) -> Generic:
transforms = Compose(transforms)
save(transforms.transform_deterministic(self.proteins), self.root, shard_size=self.shard_size)
self.transforms = transforms.transform_nondeterministic
    def split(self, splitter):
        # computes splits and saves them as a dict of indices,
        # rearranges protein shards for optimal data loading,
        # and creates one {name}_loader property per split (one rng per loader)
        splitter.fit(self)
        for name, index in splitter.assign(self):
            # NOTE: assumes the partition is the subset of proteins selected by `index`
            partition = (p for i, p in enumerate(self.proteins) if i in index)
            save(partition, self.root, shard_size=self.shard_size)
            setattr(self, f"{name}_index", index)
            setattr(self, f"{name}_loader", partial(self.loader, name))

def apply(self, *transforms) -> None:
# prepares transforms
self.transform = Compose(transforms)
self.transform.fit()
for partition in self.partitions:
save(
self.transform.deterministic_transform(partition),
self.root,
shard_size=self.shard_size,
)

    def loader(self, split=None, batch_size=None):
        # TODO: check that batch_size is a multiple of shard_size
        # creates a generator that loads data from disk (optionally shuffled)
        # and applies the stochastic transforms, then wraps it in a framework
        # dataloader via self.transform.create_dataloader; the index returned
        # from the transforms is used to reshape the data into tuples
        def __iter__():
            # TODO: create shard order from rng
            def generator():
                # one pass over the proteins per epoch; the framework
                # dataloader calls __iter__ again for the next epoch
                self.current_shard = self.proteins
                for protein in self.current_shard:
                    # TODO: create item order from rng
                    yield self.transform.stochastic_transform(protein)  # reshape here

            return generator()

        return self.transform.create_dataloader(__iter__, batch_size=batch_size)

    def partition(self, index: dict[str, np.ndarray]):
"""
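A usage sketch of the draft Dataset API above, assuming the pieces from this commit (MySplitter from splitters/random.py, the representation and framework transforms from transforms/); the call sequence is my reading of the comments, not a tested example:

ds = Dataset(path="~/.proteinshake/enzymes", shuffle=True, random_seed=42)
ds.split(MySplitter(seed=42))  # creates ds.train_index and ds.train_loader, etc.
ds.apply(PointRepresentationTransform(), TorchFrameworkTransform())
for batch in ds.train_loader(batch_size=32):
    ...  # batches of stochastically transformed proteins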
2 changes: 1 addition & 1 deletion proteinshake/frontend/evaluators/__init__.py
@@ -1 +1 @@
from .evaluator import Evaluator
from .evaluator import *
2 changes: 1 addition & 1 deletion proteinshake/frontend/evaluators/evaluator.py
@@ -1,4 +1,4 @@
class Evaluator:
class Metrics:
"""For a collection of predictions and target values, return set of performance metrics.,"""

pass
3 changes: 3 additions & 0 deletions proteinshake/frontend/framework.py
@@ -0,0 +1,3 @@
class Framework:
def create_loader(self, iterator):
pass
2 changes: 2 additions & 0 deletions proteinshake/frontend/representation.py
@@ -0,0 +1,2 @@
class Representation:
pass
2 changes: 1 addition & 1 deletion proteinshake/frontend/splitters/__init__.py
@@ -1 +1 @@
from .splitter import Splitter
from .splitter import *
20 changes: 20 additions & 0 deletions proteinshake/frontend/splitters/random.py
@@ -0,0 +1,20 @@
import numpy as np
from sklearn.model_selection import train_test_split


class MySplitter:
    def __init__(self, seed=None) -> None:
        self.rng = np.random.default_rng(seed)

def fit(self, dataset):
n = len(dataset)
        # 80/10/10 random split; sklearn expects an integer random_state
        train, test_val = train_test_split(
            np.arange(n), test_size=0.2, random_state=int(self.rng.integers(2**32))
        )
        test, val = train_test_split(
            test_val, test_size=0.5, random_state=int(self.rng.integers(2**32))
        )
self.lookup = {
**{index: "train" for index in train},
**{index: "test" for index in test},
**{index: "val" for index in val},
}

def assign(self, index, protein):
return self.lookup[index]
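Hypothetical usage of the splitter above, assuming a dataset that supports len() and that fit is called before assign:

splitter = MySplitter(seed=0)
splitter.fit(dataset)  # builds the index -> split-name lookup
splitter.assign(17, protein=None)  # e.g. "train"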
6 changes: 4 additions & 2 deletions proteinshake/frontend/splitters/splitter.py
@@ -1,6 +1,8 @@
class Splitter:
class Split:
"""
Abstract class for selecting train/val/test indices given a dataset.
"""

pass
@property
def hash(self):
return self.__class__.__name__
146 changes: 118 additions & 28 deletions proteinshake/frontend/task.py
@@ -1,35 +1,125 @@
from proteinshake.frontend.dataset import Dataset
from proteinshake.frontend.splitters import Splitter
from typing import Union
import numpy as np
import os
from pathlib import Path
from functools import partial
from proteinshake.frontend.splitters import Split
from proteinshake.frontend.targets import Target
from proteinshake.frontend.evaluators import Evaluator
from proteinshake.frontend.transforms import Transform
from proteinshake.frontend.evaluators import Metrics
from proteinshake.frontend.transforms import Transform, Compose
from proteinshake.util import amino_acid_alphabet, sharded, save_shards, load, warn


class Task:
dataset: str = ""
split: Split = None
target: Target = None
metrics: Metrics = None
augmentation: Transform = None

def __init__(
self,
dataset: Dataset,
splitter: Splitter,
target: Target,
evaluator: Evaluator,
transform: Transform,
root: Union[str, None] = None, # default path in ~/.proteinshake
shard_size: int = 1024,
split_kwargs: dict = {},
target_kwargs: dict = {},
metrics_kwargs: dict = {},
augmentation_kwargs: dict = {},
) -> None:
# compute splits. `splitter` returns a dictionary of name:index pairs.
self.index = splitter(dataset)
# partition the dataset. the dataset will optimize data loading.
dataset.partition(self.index)
# fit the transforms
transform.fit(dataset)
# create X,y,dataloader for each item in the split.
for name, index in self.index.items():
# get the partition of the split, apply transforms, and save to disk.
X,y = dataset.split(name).apply(target, transform)
# create a dataloader for the framework
loader = dataset.create_dataloader(X, y)
# add attributes to the task object
setattr(self, f"X_{name}", X)
setattr(self, f"y_{name}", y)
setattr(self, f"{name}_dataloader", loader)
setattr(self, f"{name}_index", index)
# evaluator is a callable: `task.evaluate(y_true, y_pred)`
self.evaluate = evaluator
        # create root
        if root is None:
            root = os.environ.get("PROTEINSHAKE_ROOT", "~/.proteinshake")
        root = Path(root).expanduser() / self.__class__.__name__
        os.makedirs(root, exist_ok=True)
        self.root = root
self.shard_size = shard_size

# assign task modules
self.split = self.split(**split_kwargs)
self.target = self.target(**target_kwargs)
self.metrics = self.metrics(**metrics_kwargs)
self.augmentation = self.augmentation(**augmentation_kwargs)

@property
def proteins(self):
# return dataset iterator
rng = np.random.default_rng(42)
return (
{
"ID": f"protein_{i}",
"coords": rng.integers(0, 100, size=(300, 3)),
"sequence": "".join(
rng.choice(list(amino_acid_alphabet)) for _ in range(300)
),
"label": rng.random() * 100,
"split": rng.choice(["train", "test", "val"]),
}
for i in range(100)
)

    def transform(self, *transforms) -> "Task":
        Xy = self.target(self.proteins)
        partitions = self.split(Xy)  # returns a dict of name -> generator of ((X,...), y) pairs
self.transform = Compose(*[self.augmentation, *transforms])
# cache from here
self.transform.fit(partitions["train"])
for name, Xy in partitions.items():
Xy = sharded(Xy, shard_size=self.shard_size)
data_transformed = (
self.transform.deterministic_transform(shard) for shard in Xy
)
save_shards(
data_transformed,
self.root / self.split.hash / self.transform.hash / "shards",
)
setattr(self, f"{name}_loader", partial(self.loader, split=name))
return self

def loader(
self,
split=None,
batch_size=None,
shuffle: bool = False,
random_seed: Union[int, None] = None,
**kwargs,
):
rng = np.random.default_rng(random_seed)
path = self.root / self.split.hash / self.transform.hash / "shards"
shard_index = load(path / "index.npy")
        # NOTE: defaulting batch_size to one shard per batch is an assumption of this draft
        batch_size = batch_size or self.shard_size
        if self.shard_size % batch_size != 0 and batch_size % self.shard_size != 0:
            warn(
                "batch_size and shard_size are not multiples of each other. This causes inefficient data loading."
            )

        def generator():
            if shuffle:
                rng.shuffle(shard_index)
            shards = (load(path / f"{i}.pkl") for i in shard_index)
            # prime the first shard; leftovers carry over between batches
            current_X, current_y = next(shards, ([], []))
            while len(current_X) > 0:
                X_batch, y_batch = [], []
                while len(X_batch) < batch_size:
                    b = batch_size - len(X_batch)
                    X_piece, current_X = current_X[:b], current_X[b:]
                    y_piece, current_y = current_y[:b], current_y[b:]
                    X_batch = X_batch + list(X_piece)
                    y_batch = y_batch + list(y_piece)
                    if len(current_X) == 0:
                        try:
                            current_X, current_y = next(shards)
                        except StopIteration:
                            break  # no shards left; yield the final partial batch
                yield self.transform.stochastic_transform(
                    (np.asarray(X_batch), np.asarray(y_batch))
                )

return self.transform.create_loader(generator, **kwargs)

def evaluate(self, y_true, y_pred):
y_true = self.transform.inverse_transform(y_true)
y_pred = self.transform.inverse_transform(y_pred)
return self.metrics(y_true, y_pred)
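A sketch of the full task workflow implied by the class above; the concrete Target, Metrics, and augmentation subclasses are hypothetical stand-ins, and MySplitter is assumed to implement the Split interface:

class EnzymeClassTask(Task):
    split = MySplitter               # from splitters/random.py
    target = EnzymeClassTarget       # assumed Target subclass
    metrics = ClassificationMetrics  # assumed Metrics subclass
    augmentation = NoopTransform     # assumed identity Transform

task = EnzymeClassTask(shard_size=1024)
task.transform(PointRepresentationTransform(), TorchFrameworkTransform())
for X, y in task.train_loader(batch_size=32, shuffle=True):
    ...  # train a model
results = task.evaluate(y_true, y_pred)  # metrics on inverse-transformed labels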
18 changes: 14 additions & 4 deletions proteinshake/frontend/transforms/framework/torch.py
@@ -1,7 +1,17 @@
import torch
from ..transform import FrameworkTransform
import numpy as np
from torch.utils.data import DataLoader, IterableDataset
from ..transform import Transform
from proteinshake.frontend.framework import Framework


class TorchFrameworkTransform(FrameworkTransform):
def transform(self, representation):
return torch.tensor(representation)
class TorchFrameworkTransform(Framework, Transform):
def transform(self, X):
return X

def create_loader(self, iterator, **kwargs):
class Dataset(IterableDataset):
def __iter__(self):
return iterator()

return DataLoader(Dataset(), **kwargs)
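A quick check of create_loader, assuming the DataLoader's default collation (numpy items are converted to tensors and stacked):

import numpy as np

def items():
    for i in range(4):
        yield np.ones(3) * i

loader = TorchFrameworkTransform().create_loader(items, batch_size=2)
for batch in loader:
    print(batch.shape)  # torch.Size([2, 3])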
14 changes: 14 additions & 0 deletions proteinshake/frontend/transforms/label/MinMaxScaler.py
@@ -0,0 +1,14 @@
from proteinshake.frontend.transforms import Transform


class MyLabelTransform(Transform):
def fit(self, dataset):
labels = [p["label"] for p in dataset.split("train").proteins]
self.min, self.max = min(labels), max(labels)

def transform(self, X, y, index):
y_transformed = (y - self.min) / (self.max - self.min)
return X, y_transformed, index

def inverse_transform(self, y):
return y * (self.max - self.min) + self.min
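Round-trip sketch for the label transform above; min and max are set by hand here where fit would normally compute them from the train split:

t = MyLabelTransform()
t.min, t.max = 0.0, 100.0
_, y_scaled, _ = t.transform(X=None, y=75.0, index=None)  # y_scaled == 0.75
assert t.inverse_transform(y_scaled) == 75.0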
4 changes: 3 additions & 1 deletion proteinshake/frontend/transforms/post_framework/note.md
@@ -8,4 +8,6 @@ Example:
from torch_geometric.transforms import AddSelfLoops
ds = EnzymeDataset(...).to_graph(...).pyg(..., post_transform=AddSelfLoops)

"""
"""

Note: maybe we provide a wrapper here to cast native framework transforms to shake transforms (to deal with shake-specific concerns like batching, deterministic/stochastic application, etc.).
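One possible shape for such a wrapper, assuming the Transform interface from this draft; the class and its argument names are hypothetical:

class NativeTransformWrapper(Transform):
    # hypothetical: wraps a framework-native transform (e.g. a PyG transform)
    # so it can participate in a shake Compose pipeline
    def __init__(self, native_transform, stochastic=False):
        self.native_transform = native_transform
        self.stochastic = stochastic  # route through the stochastic path if it resamples

    def transform(self, X):
        return self.native_transform(X)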
10 changes: 6 additions & 4 deletions proteinshake/frontend/transforms/representation/point.py
@@ -1,6 +1,8 @@
from ..transform import RepresentationTransform
from ..transform import Transform
from proteinshake.frontend.representation import Representation
import numpy as np


class PointRepresentationTransform(RepresentationTransform):
def transform(self, protein):
return protein["coords"]
class PointRepresentationTransform(Representation, Transform):
def transform(self, X):
return np.asarray([protein["coords"] for protein in X])
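Example for the point representation above: a batch of two dummy proteins becomes a single (2, 300, 3) coordinate array:

proteins = [{"coords": np.zeros((300, 3))}, {"coords": np.ones((300, 3))}]
points = PointRepresentationTransform().transform(proteins)
print(points.shape)  # (2, 300, 3)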