Commit
Merge pull request #134 from m-novikov/tf
Initial test for tensorflow model
m-novikov authored Oct 21, 2020
2 parents 614426b + c2f6f73 commit 7dbf87f
Showing 20 changed files with 300 additions and 69 deletions.
1 change: 1 addition & 0 deletions .gitattributes
@@ -1,3 +1,4 @@
* text=auto eol=lf
*.zip binary
*.npy binary
*.py eol=lf diff=python
5 changes: 4 additions & 1 deletion Makefile
@@ -7,6 +7,9 @@ sample_model:
unet2d:
cd tests/data/unet2d && zip -r $(ROOT_DIR)/unet2d.tmodel ./*

dummy_tf:
cd tests/data/dummy_tensorflow && zip -r $(ROOT_DIR)/dummy_tf.tmodel ./*

protos:
python -m grpc_tools.protoc -I./proto --python_out=tiktorch/proto/ --grpc_python_out=tiktorch/proto/ ./proto/*.proto
sed -i -r 's/import (.+_pb2.*)/from . import \1/g' tiktorch/proto/*_pb2*.py
@@ -27,4 +30,4 @@ remove_devenv:
conda env remove --yes --name $(TIKTORCH_ENV_NAME)


.PHONY: protos version sample_model devenv remove_devenv
.PHONY: protos version sample_model devenv remove_devenv dummy_tf
29 changes: 29 additions & 0 deletions tests/conftest.py
@@ -17,6 +17,7 @@
TEST_DATA = "data"
TEST_PYBIO_ZIPFOLDER = "unet2d"
TEST_PYBIO_DUMMY = "dummy"
TEST_PYBIO_TENSORFLOW_DUMMY = "dummy_tensorflow"

NNModel = namedtuple("NNModel", ["model", "state"])

@@ -114,6 +115,34 @@ def pybio_dummy_model_bytes(data_path):
return data


def archive(directory):
result = io.BytesIO()

with ZipFile(result, mode="w") as zip_model:

def _archive(path_to_archive):
for path in path_to_archive.iterdir():
if str(path.name).startswith("__"):
continue

if path.is_dir():
_archive(path)

else:
with path.open(mode="rb") as f:
zip_model.writestr(str(path).replace(str(directory), ""), f.read())

_archive(directory)

return result


@pytest.fixture
def pybio_dummy_tensorflow_model_bytes(data_path):
pybio_net_dir = Path(data_path) / TEST_PYBIO_TENSORFLOW_DUMMY
return archive(pybio_net_dir)


@pytest.fixture
def cache_path(tmp_path):
return Path(getenv("PYBIO_CACHE_PATH", tmp_path))
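Note that the archive helper strips the directory prefix with str.replace, so entry names keep a leading path separator. A minimal round-trip sketch of the new fixture's output (paths assumed relative to the repository root):

from pathlib import Path
from zipfile import ZipFile

buf = archive(Path("tests/data/dummy_tensorflow"))  # BytesIO built by the helper above
with ZipFile(buf) as zf:
    print(zf.namelist())  # entry names like '/Dummy.model.yaml', '/dummy.py'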
43 changes: 43 additions & 0 deletions tests/data/dummy_tensorflow/Dummy.model.yaml
@@ -0,0 +1,43 @@
name: DummyTFModel
description: A dummy tensorflow model for testing
authors:
- ilastik team
cite:
- text: "Ilastik"
doi: https://doi.org
documentation: dummy.md
tags: [tensorflow]
license: MIT

format_version: 0.1.0
language: python
framework: tensorflow

source: dummy.py::TensorflowModelWrapper

test_input: null # ../test_input.npy
test_output: null # ../test_output.npy

# TODO double check inputs/outputs
inputs:
- name: input
axes: cyx
data_type: float32
data_range: [-inf, inf]
shape: [1, 128, 128]
outputs:
- name: output
axes: bcyx
data_type: float32
data_range: [0, 1]
shape:
reference_input: input # FIXME(m-novikov) ignoring for now
scale: [1, 1, 1]
offset: [0, 0, 0]
#halo: [0, 0, 32, 32] # Should be moved to outputs

prediction:
weights:
source: ./model
hash: {md5: TODO}
dependencies: conda:./environment.yaml
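As a rough illustration of how such a spec is consumed (mirroring test_exemplum below; the cache_path value is arbitrary):

from pathlib import Path

from pybio.spec import load_spec_and_kwargs

pybio_model = load_spec_and_kwargs(
    "tests/data/dummy_tensorflow/Dummy.model.yaml", cache_path=Path("/tmp/pybio-cache")
)
assert pybio_model.spec.framework == "tensorflow"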
12 changes: 12 additions & 0 deletions tests/data/dummy_tensorflow/dummy.py
@@ -0,0 +1,12 @@
class TensorflowModelWrapper:
def __init__(self):
self._model = None

def set_model(self, model):
self._model = model

def forward(self, input_):
return self._model.predict(input_)

def __call__(self, *args, **kwargs):
return self._model.predict(*args, **kwargs)
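The wrapper just delegates to whatever model is injected via set_model. A sketch of the intended wiring (the load_model call mirrors the commented-out tensorflow branch in _exemplum.py further down; the zeros batch is a placeholder matching the [1, 128, 128] cyx input spec above):

import numpy
import tensorflow as tf

wrapper = TensorflowModelWrapper()
wrapper.set_model(tf.keras.models.load_model("tests/data/dummy_tensorflow/model"))
batch = numpy.zeros((1, 128, 128), dtype=numpy.float32)
out = wrapper.forward(batch)  # delegates to self._model.predict(batch)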
Binary file added tests/data/dummy_tensorflow/model/saved_model.pb
8 changes: 3 additions & 5 deletions tests/test_server/test_exemplum.py
@@ -1,20 +1,18 @@
from pathlib import Path

import numpy
import torch
from pybio.spec import load_spec_and_kwargs

from tiktorch.server.exemplum import Exemplum
from tiktorch.server.model_adapter._exemplum import Exemplum


def test_exemplum(data_path, cache_path):
spec_path = data_path / "unet2d/UNet2DNucleiBroad.model.yaml"
assert spec_path.exists(), spec_path.absolute()
pybio_model = load_spec_and_kwargs(str(spec_path), cache_path=cache_path)

exemplum = Exemplum(pybio_model=pybio_model, _devices=[torch.device("cpu")])
exemplum = Exemplum(pybio_model=pybio_model, devices=[torch.device("cpu")])
test_ipt = numpy.load(pybio_model.spec.test_input) # test input with batch dim
out = exemplum.forward(test_ipt[0]) # todo: exemplum.forward should get batch with batch dim
out = exemplum.forward(test_ipt[0].astype(numpy.float32)) # todo: exemplum.forward should get batch with batch dim
# assert isinstance(out_seq, (list, tuple)) # todo: forward should return a list
# assert len(out_seq) == 1
# out = out_seq
10 changes: 8 additions & 2 deletions tests/test_server/test_reader.py
@@ -3,7 +3,7 @@

import pytest

from tiktorch.server.exemplum import Exemplum
from tiktorch.server.model_adapter import ModelAdapter
from tiktorch.server.reader import eval_model_zip, guess_model_path


@@ -20,4 +20,10 @@ def test_guess_model_path_without_model_file(paths):
def test_eval_model_zip(pybio_model_bytes, cache_path):
with ZipFile(pybio_model_bytes) as zf:
exemplum = eval_model_zip(zf, devices=["cpu"], cache_path=cache_path)
assert isinstance(exemplum, Exemplum)
assert isinstance(exemplum, ModelAdapter)


def test_eval_tensorflow_model_zip(pybio_dummy_tensorflow_model_bytes, cache_path):
with ZipFile(pybio_dummy_tensorflow_model_bytes) as zf:
exemplum = eval_model_zip(zf, devices=["cpu"], cache_path=cache_path)
assert isinstance(exemplum, ModelAdapter)
2 changes: 1 addition & 1 deletion tests/test_server/test_training/test_training.py
@@ -27,7 +27,7 @@ def set_break_callback(self, cb):
self._break_cb = cb

def forward(self, input_tensor):
return torch.Tensor([42])
return np.array([42])

def set_max_num_iterations(self, val):
self.max_num_iterations = val
21 changes: 21 additions & 0 deletions tiktorch/server/model_adapter/__init__.py
@@ -0,0 +1,21 @@
from typing import List

from pybio.spec import nodes

from ._base import ModelAdapter

__all__ = ["ModelAdapter", "create_model_adapter"]


def create_model_adapter(*, pybio_model: nodes.Model, devices: List[str]):
spec = pybio_model.spec
if spec.framework == "pytorch":
from ._exemplum import Exemplum

return Exemplum(pybio_model=pybio_model, devices=devices)
elif spec.framework == "tensorflow":
from ._tensorflow_model_adapter import TensorflowModelAdapter

return TensorflowModelAdapter(pybio_model=pybio_model, devices=devices)
else:
raise NotImplementedError(f"Unknown framework: {spec.framework}")
26 changes: 26 additions & 0 deletions tiktorch/server/model_adapter/_base.py
@@ -0,0 +1,26 @@
import abc
from typing import Callable


class ModelAdapter(abc.ABC):
@abc.abstractmethod
def forward(self, input_tensor):
...

@property
@abc.abstractmethod
def max_num_iterations(self) -> int:
...

@property
@abc.abstractmethod
def iteration_count(self) -> int:
...

@abc.abstractmethod
def set_break_callback(self, thunk: Callable[[], bool]) -> None:
...

@abc.abstractmethod
def set_max_num_iterations(self, val: int) -> None:
...
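A concrete adapter only has to implement these five members; a toy illustration (not part of the commit):

class IdentityAdapter(ModelAdapter):
    """Minimal ModelAdapter sketch that returns its input unchanged."""

    def __init__(self):
        self._max_num_iterations = 0

    def forward(self, input_tensor):
        return input_tensor

    @property
    def max_num_iterations(self) -> int:
        return self._max_num_iterations

    @property
    def iteration_count(self) -> int:
        return 0

    def set_break_callback(self, thunk: Callable[[], bool]) -> None:
        pass  # breaks are never requested in this toy example

    def set_max_num_iterations(self, val: int) -> None:
        self._max_num_iterations = val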
tiktorch/server/{exemplum.py → model_adapter/_exemplum.py}
@@ -6,14 +6,10 @@
from pybio.spec import nodes
from pybio.spec.utils import get_instance

logger = logging.getLogger(__name__)
# @dataclass
# class ValidationOutput(IterationOutput):
# pass
from ._base import ModelAdapter
from ._utils import has_batch_dim

# @dataclass
# class TrainingOutput(IterationOutput):
# pass
logger = logging.getLogger(__name__)


def _noop(tensor):
@@ -28,29 +24,15 @@ def _add_batch_dim(tensor):
return tensor.reshape((1,) + tensor.shape)


def _check_batch_dim(axes: str) -> bool:
try:
index = axes.index("b")
except ValueError:
return False
else:
if index != 0:
raise ValueError("Batch dimension is only supported in first position")
return True


class Exemplum:
class Exemplum(ModelAdapter):
def __init__(
self,
*,
pybio_model: nodes.Model,
batch_size: int = 1,
num_iterations_per_update: int = 2,
_devices=Sequence[torch.device],
        devices: Sequence[str],
):
self.max_num_iterations = 0
self.iteration_count = 0
self.devices = _devices
self._max_num_iterations = 0
self._iteration_count = 0
spec = pybio_model.spec
self.name = spec.name

@@ -65,7 +47,7 @@ def __init__(
self._internal_input_axes = _input.axes
self._internal_output_axes = _output.axes

if _check_batch_dim(self._internal_input_axes):
if has_batch_dim(self._internal_input_axes):
self.input_axes = self._internal_input_axes[1:]
self._input_batch_dimension_transform = _add_batch_dim
_input_shape = _input.shape[1:]
@@ -78,7 +60,7 @@

_halo = _output.halo or [0 for _ in _output.axes]

if _check_batch_dim(self._internal_output_axes):
if has_batch_dim(self._internal_output_axes):
self.output_axes = self._internal_output_axes[1:]
self._output_batch_dimension_transform = _remove_batch_dim
_halo = _halo[1:]
@@ -89,29 +71,34 @@ def __init__(
self.halo = list(zip(self.output_axes, _halo))

self.model = get_instance(pybio_model)
self.model.to(self.devices[0])
if spec.framework == "pytorch":
self.devices = [torch.device(d) for d in devices]
self.model.to(self.devices[0])
assert isinstance(self.model, torch.nn.Module)
if spec.prediction.weights is not None:
state = torch.load(spec.prediction.weights.source, map_location=self.devices[0])
self.model.load_state_dict(state)
# elif spec.framework == "tensorflow":
# import tensorflow as tf
# self.devices = []
# tf_model = tf.keras.models.load_model(spec.prediction.weights.source)
# self.model.set_model(tf_model)
else:
raise NotImplementedError

self._prediction_preprocess = make_concatenated_apply([get_instance(tf) for tf in spec.prediction.preprocess])
self._prediction_postprocess = make_concatenated_apply([get_instance(tf) for tf in spec.prediction.postprocess])
# inference_engine = ignite.engine.Engine(self._inference_step_function)
# .add_event_handler(Events.STARTED, self.prepare_engine)
# .add_event_handler(Events.COMPLETED, self.log_compute_time)

# def _validation_step_function(self) -> ValidationOutput:
# return ValidationOutput()
#
#
# def _training_step_function(self) -> TrainingOutput:
# return TrainingOutput()
@property
def max_num_iterations(self) -> int:
return self._max_num_iterations

@property
def iteration_count(self) -> int:
return self._iteration_count

def forward(self, batch) -> List[Any]:
batch = torch.from_numpy(batch)
with torch.no_grad():
batch = self._input_batch_dimension_transform(batch)
batch = self._prediction_preprocess(batch)
@@ -120,10 +107,14 @@ def forward(self, batch) -> List[Any]:
batch = self._prediction_postprocess(batch)
batch = self._output_batch_dimension_transform(batch)
assert all([bs > 0 for bs in batch[0].shape]), batch[0].shape
return batch[0]
result = batch[0]
if isinstance(result, torch.Tensor):
return result.detach().cpu().numpy()
else:
return result

def set_max_num_iterations(self, max_num_iterations: int) -> None:
self.max_num_iterations = max_num_iterations
self._max_num_iterations = max_num_iterations

def set_break_callback(self, cb):
        raise NotImplementedError
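The removed _check_batch_dim has moved into the shared _utils module as has_batch_dim; that file is not shown in this commit, but judging from the deletion above it presumably keeps the same body:

# tiktorch/server/model_adapter/_utils.py (inferred, not part of the shown diff)
def has_batch_dim(axes: str) -> bool:
    try:
        index = axes.index("b")
    except ValueError:
        return False
    else:
        if index != 0:
            raise ValueError("Batch dimension is only supported in first position")
        return True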