-
Notifications
You must be signed in to change notification settings - Fork 11
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #134 from m-novikov/tf
Initial test for tensoflow model
- Loading branch information
Showing
20 changed files
with
300 additions
and
69 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,4 @@ | ||
* text=auto eol=lf | ||
*.zip binary | ||
*.npy binary | ||
*.py eol=lf diff=python |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
name: DummyTFModel | ||
description: A dummy tensorflow model for testing | ||
authors: | ||
- ilastik team | ||
cite: | ||
- text: "Ilastik" | ||
doi: https://doi.org | ||
documentation: dummy.md | ||
tags: [tensorflow] | ||
license: MIT | ||
|
||
format_version: 0.1.0 | ||
language: python | ||
framework: tensorflow | ||
|
||
source: dummy.py::TensorflowModelWrapper | ||
|
||
test_input: null # ../test_input.npy | ||
test_output: null # ../test_output.npy | ||
|
||
# TODO double check inputs/outputs | ||
inputs: | ||
- name: input | ||
axes: cyx | ||
data_type: float32 | ||
data_range: [-inf, inf] | ||
shape: [1, 128, 128] | ||
outputs: | ||
- name: output | ||
axes: bcyx | ||
data_type: float32 | ||
data_range: [0, 1] | ||
shape: | ||
reference_input: input # FIXME(m-novikov) ignoring for now | ||
scale: [1, 1, 1] | ||
offset: [0, 0, 0] | ||
#halo: [0, 0, 32, 32] # Should be moved to outputs | ||
|
||
prediction: | ||
weights: | ||
source: ./model | ||
hash: {md5: TODO} | ||
dependencies: conda:./environment.yaml |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
class TensorflowModelWrapper: | ||
def __init__(self): | ||
self._model = None | ||
|
||
def set_model(self, model): | ||
self._model = model | ||
|
||
def forward(self, input_): | ||
return self._model.predict(input_) | ||
|
||
def __call__(self, *args, **kwargs): | ||
return self._model.predict(*args, **kwargs) |
Binary file not shown.
Binary file added
BIN
+596 Bytes
tests/data/dummy_tensorflow/model/variables/variables.data-00000-of-00001
Binary file not shown.
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
from typing import List | ||
|
||
from pybio.spec import nodes | ||
|
||
from ._base import ModelAdapter | ||
|
||
__all__ = ["ModelAdapter", "create_model_adapter"] | ||
|
||
|
||
def create_model_adapter(*, pybio_model: nodes.Model, devices=List[str]): | ||
spec = pybio_model.spec | ||
if spec.framework == "pytorch": | ||
from ._exemplum import Exemplum | ||
|
||
return Exemplum(pybio_model=pybio_model, devices=devices) | ||
elif spec.framework == "tensorflow": | ||
from ._tensorflow_model_adapter import TensorflowModelAdapter | ||
|
||
return TensorflowModelAdapter(pybio_model=pybio_model, devices=devices) | ||
else: | ||
raise NotImplementedError(f"Unknown framework: {spec.framework}") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
import abc | ||
from typing import Callable | ||
|
||
|
||
class ModelAdapter(abc.ABC): | ||
@abc.abstractmethod | ||
def forward(self, input_tensor): | ||
... | ||
|
||
@property | ||
@abc.abstractmethod | ||
def max_num_iterations(self) -> int: | ||
... | ||
|
||
@property | ||
@abc.abstractmethod | ||
def iteration_count(self) -> int: | ||
... | ||
|
||
@abc.abstractmethod | ||
def set_break_callback(self, thunk: Callable[[], bool]) -> None: | ||
... | ||
|
||
@abc.abstractmethod | ||
def set_max_num_iterations(self, val: int) -> None: | ||
... |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.