Skip to content

Commit

Permalink
update notebook for running in colab
Browse files Browse the repository at this point in the history
  • Loading branch information
elad cohen committed Oct 23, 2023
1 parent 35693e8 commit 3b0f66b
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 60 deletions.
76 changes: 50 additions & 26 deletions tutorials/notebooks/example_keras_effdet_lite0.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -37,13 +37,14 @@
"execution_count": null,
"outputs": [],
"source": [
"!pip install tensorflow\n",
"!pip install model-compression-toolkit\n",
"!pip install torch\n",
"!pip install torchvision\n",
"!pip install timm\n",
"!pip install effdet\n",
"!git clone https://github.com/sony/model_optimization/tree/main/tutorials/resources"
"!pip install -q tensorflow\n",
"!pip install -q model-compression-toolkit\n",
"!pip install -q torch\n",
"!pip install -q torchvision\n",
"!pip install -q timm\n",
"!pip install -q effdet\n",
"!pip install -q sony-custom-layers\n",
"!git clone -b add_ported_effdet_keras_tutorial https://github.com/sony/model_optimization.git local_mct"
],
"metadata": {
"collapsed": false
Expand Down Expand Up @@ -72,14 +73,43 @@
"from effdet import create_dataset, create_loader, create_evaluator\n",
"from effdet.data import resolve_input_config\n",
"import model_compression_toolkit as mct\n",
"from resources.efficientdet import EfficientDetKeras, TorchWrapper\n",
"from resources.utils import load_state_dict"
"import sys\n",
"sys.path.insert(0,\"/content/local_mct\")\n",
"from tutorials.resources.efficientdet import EfficientDetKeras, TorchWrapper\n",
"from tutorials.resources.utils import load_state_dict"
],
"metadata": {
"collapsed": false
},
"id": "38e460c939d89482"
},
{
"cell_type": "markdown",
"source": [
"### Load COCO evaluation set"
],
"metadata": {
"collapsed": false
},
"id": "f75abdac7950c038"
},
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [
"!wget -nc http://images.cocodataset.org/annotations/annotations_trainval2017.zip\n",
"!unzip -q -o annotations_trainval2017.zip -d /content/coco\n",
"!echo Done loading annotations\n",
"!wget -nc http://images.cocodataset.org/zips/val2017.zip\n",
"!unzip -q -o val2017.zip -d /content/coco\n",
"!echo Done loading val2017 images"
],
"metadata": {
"collapsed": false
},
"id": "1bf50c7706331ba8"
},
{
"cell_type": "markdown",
"source": [
Expand All @@ -96,7 +126,7 @@
"outputs": [],
"source": [
"def get_coco_dataloader(batch_size=16, split='val', config=None):\n",
" root = '/data/projects/swat/datasets_src/COCO'\n",
" root = '/content/coco'\n",
"\n",
" args = dict(interpolation='bilinear', mean=None, std=None, fill_color=None)\n",
" dataset = create_dataset('coco', root, split)\n",
Expand Down Expand Up @@ -164,19 +194,16 @@
"model_name = 'tf_efficientdet_lite0'\n",
"config = get_efficientdet_config(model_name)\n",
"\n",
"merged_outputs = True\n",
"use_custom_layer = True\n",
"pretrained_backbone = False\n",
"model = EfficientDetKeras(config,\n",
" pretrained_backbone=pretrained_backbone\n",
" ).get_model([*config.image_size] + [3],\n",
" merge_outputs=merged_outputs,\n",
" use_custom_layer=use_custom_layer)\n",
" pretrained_backbone=False\n",
" ).get_model([*config.image_size] + [3])\n",
"\n",
"state_dict = torch.hub.load_state_dict_from_url(config.url, progress=False,\n",
" map_location='cpu')\n",
"state_dict_numpy = {k: v.numpy() for k, v in state_dict.items()}\n",
"load_state_dict(model, state_dict_numpy)"
"load_state_dict(model, state_dict_numpy)\n",
"\n",
"model.save('/content/model.keras')"
],
"metadata": {
"collapsed": false
Expand All @@ -198,9 +225,7 @@
"execution_count": null,
"outputs": [],
"source": [
"wrapped_model = TorchWrapper(model,\n",
" merged_outputs=merged_outputs,\n",
" used_custom_layer=use_custom_layer)\n",
"wrapped_model = TorchWrapper(model)\n",
"\n",
"float_map = acc_eval(wrapped_model, batch_size=64, config=config)"
],
Expand All @@ -224,7 +249,7 @@
"execution_count": null,
"outputs": [],
"source": [
"loader, _ = get_coco_dataloader(split='train', config=config)\n",
"loader, _ = get_coco_dataloader(split='val', config=config)\n",
"\n",
"\n",
"def get_representative_dataset(n_iter):\n",
Expand All @@ -239,7 +264,8 @@
"\n",
"\n",
"quant_model, _ = mct.ptq.keras_post_training_quantization_experimental(model,\n",
" get_representative_dataset(20))"
" get_representative_dataset(20))\n",
"quant_model.save('/content/quant_model.keras')"
],
"metadata": {
"collapsed": false
Expand All @@ -261,9 +287,7 @@
"execution_count": null,
"outputs": [],
"source": [
"wrapped_model = TorchWrapper(quant_model,\n",
" merged_outputs=merged_outputs,\n",
" used_custom_layer=use_custom_layer)\n",
"wrapped_model = TorchWrapper(quant_model)\n",
"\n",
"quant_map = acc_eval(wrapped_model, batch_size=64, config=config)\n",
"\n",
Expand Down
54 changes: 20 additions & 34 deletions tutorials/resources/efficientdet/effdet_keras.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,9 @@


class TorchWrapper(torch.nn.Module):
def __init__(self, model, merged_outputs=False, used_custom_layer=False):
def __init__(self, model):
super(TorchWrapper, self).__init__()
self.model = model
self.merged_outputs = merged_outputs
self.used_custom_layer = used_custom_layer

@property
def config(self):
Expand All @@ -52,22 +50,11 @@ def forward(self, x, img_info: Optional[Dict[str, torch.Tensor]] = None):
device = x.device
keras_input = x.detach().cpu().numpy().transpose((0, 2, 3, 1))
outputs = self.model(keras_input)
if self.merged_outputs:
if self.used_custom_layer:
outs = [torch.Tensor(o.numpy()).to(device) for o in outputs]
outs[0] = outs[0][:, :, [1, 0, 3, 2]] # reorder (y, x, y2, x2) to (x, y, x2, y2)
outs[0] = outs[0] * img_info['img_scale'].view((-1, 1, 1)) # scale to original image size
return torch.cat([outs[0], outs[1].unsqueeze(2), outs[2].unsqueeze(2)+1], 2)
else:
class_out, box_out = outputs
class_out = torch.Tensor(class_out.numpy()).to(device)
box_out = torch.Tensor(box_out.numpy()).to(device)
return [class_out, box_out]
else:
class_out, box_out = outputs
class_out = [torch.Tensor(c.numpy().transpose((0, 3, 1, 2))).to(device) for c in class_out]
box_out = [torch.Tensor(b.numpy().transpose((0, 3, 1, 2))).to(device) for b in box_out]
return [class_out, box_out]

outs = [torch.Tensor(o.numpy()).to(device) for o in outputs]
outs[0] = outs[0][:, :, [1, 0, 3, 2]] # reorder (y, x, y2, x2) to (x, y, x2, y2)
outs[0] = outs[0] * img_info['img_scale'].view((-1, 1, 1)) # scale to original image size
return torch.cat([outs[0], outs[1].unsqueeze(2), outs[2].unsqueeze(2) + 1], 2)


def get_act_layer(act_type):
Expand Down Expand Up @@ -680,25 +667,24 @@ def toggle_head_bn_level_first(self):
self.class_net.toggle_bn_level_first()
self.box_net.toggle_bn_level_first()

def get_model(self, input_shape, merge_outputs=False, use_custom_layer=False):
def get_model(self, input_shape):
_input = tf.keras.layers.Input(shape=input_shape)
x = self.backbone(_input)
x = self.fpn(x)
x_class = self.class_net(x)
x_box = self.box_net(x)
outputs = [x_class, x_box]
if merge_outputs:
x_class = [tf.keras.layers.Reshape((-1, self.config.num_classes))(_x) for _x in x_class]
x_class = tf.keras.layers.Concatenate(axis=1)(x_class)
x_box = [tf.keras.layers.Reshape((-1, 4))(_x) for _x in x_box]
x_box = tf.keras.layers.Concatenate(axis=1)(x_box)
if use_custom_layer:
anchors = tf.constant(Anchors.from_config(self.config).boxes.detach().cpu().numpy())

ssd_pp = SSDPostProcess(anchors, [1, 1, 1, 1], [*self.config.image_size],
ScoreConverter.SIGMOID, score_threshold=0.001, iou_threshold=0.5,
max_detections=self.config.max_det_per_image)
outputs = ssd_pp((x_box, x_class))
else:
outputs = [x_class, x_box]

x_class = [tf.keras.layers.Reshape((-1, self.config.num_classes))(_x) for _x in x_class]
x_class = tf.keras.layers.Concatenate(axis=1)(x_class)
x_box = [tf.keras.layers.Reshape((-1, 4))(_x) for _x in x_box]
x_box = tf.keras.layers.Concatenate(axis=1)(x_box)

anchors = tf.constant(Anchors.from_config(self.config).boxes.detach().cpu().numpy())

ssd_pp = SSDPostProcess(anchors, [1, 1, 1, 1], [*self.config.image_size],
ScoreConverter.SIGMOID, score_threshold=0.001, iou_threshold=0.5,
max_detections=self.config.max_det_per_image)
outputs = ssd_pp((x_box, x_class))

return tf.keras.Model(inputs=_input, outputs=outputs)

0 comments on commit 3b0f66b

Please sign in to comment.