diff --git a/tutorials/notebooks/example_keras_effdet_lite0.ipynb b/tutorials/notebooks/example_keras_effdet_lite0.ipynb
index 776cec1be..44ff5cbad 100644
--- a/tutorials/notebooks/example_keras_effdet_lite0.ipynb
+++ b/tutorials/notebooks/example_keras_effdet_lite0.ipynb
@@ -37,13 +37,14 @@
    "execution_count": null,
    "outputs": [],
    "source": [
-    "!pip install tensorflow\n",
-    "!pip install model-compression-toolkit\n",
-    "!pip install torch\n",
-    "!pip install torchvision\n",
-    "!pip install timm\n",
-    "!pip install effdet\n",
-    "!git clone https://github.com/sony/model_optimization/tree/main/tutorials/resources"
+    "!pip install -q tensorflow\n",
+    "!pip install -q model-compression-toolkit\n",
+    "!pip install -q torch\n",
+    "!pip install -q torchvision\n",
+    "!pip install -q timm\n",
+    "!pip install -q effdet\n",
+    "!pip install -q sony-custom-layers\n",
+    "!git clone -b add_ported_effdet_keras_tutorial https://github.com/sony/model_optimization.git local_mct"
    ],
    "metadata": {
     "collapsed": false
@@ -72,14 +73,43 @@
    "from effdet import create_dataset, create_loader, create_evaluator\n",
    "from effdet.data import resolve_input_config\n",
    "import model_compression_toolkit as mct\n",
-    "from resources.efficientdet import EfficientDetKeras, TorchWrapper\n",
-    "from resources.utils import load_state_dict"
+    "import sys\n",
+    "sys.path.insert(0,\"/content/local_mct\")\n",
+    "from tutorials.resources.efficientdet import EfficientDetKeras, TorchWrapper\n",
+    "from tutorials.resources.utils import load_state_dict"
    ],
    "metadata": {
     "collapsed": false
    },
    "id": "38e460c939d89482"
   },
+  {
+   "cell_type": "markdown",
+   "source": [
+    "### Load COCO evaluation set"
+   ],
+   "metadata": {
+    "collapsed": false
+   },
+   "id": "f75abdac7950c038"
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "outputs": [],
+   "source": [
+    "!wget -nc http://images.cocodataset.org/annotations/annotations_trainval2017.zip\n",
+    "!unzip -q -o annotations_trainval2017.zip -d /content/coco\n",
+    "!echo Done loading annotations\n",
+    "!wget -nc http://images.cocodataset.org/zips/val2017.zip\n",
+    "!unzip -q -o val2017.zip -d /content/coco\n",
+    "!echo Done loading val2017 images"
+   ],
+   "metadata": {
+    "collapsed": false
+   },
+   "id": "1bf50c7706331ba8"
+  },
   {
    "cell_type": "markdown",
    "source": [
@@ -96,7 +126,7 @@
    "outputs": [],
    "source": [
     "def get_coco_dataloader(batch_size=16, split='val', config=None):\n",
-    "    root = '/data/projects/swat/datasets_src/COCO'\n",
+    "    root = '/content/coco'\n",
     "\n",
     "    args = dict(interpolation='bilinear', mean=None, std=None, fill_color=None)\n",
     "    dataset = create_dataset('coco', root, split)\n",
@@ -164,19 +194,16 @@
    "model_name = 'tf_efficientdet_lite0'\n",
    "config = get_efficientdet_config(model_name)\n",
    "\n",
-    "merged_outputs = True\n",
-    "use_custom_layer = True\n",
-    "pretrained_backbone = False\n",
    "model = EfficientDetKeras(config,\n",
-    "                          pretrained_backbone=pretrained_backbone\n",
-    "                          ).get_model([*config.image_size] + [3],\n",
-    "                                      merge_outputs=merged_outputs,\n",
-    "                                      use_custom_layer=use_custom_layer)\n",
+    "                          pretrained_backbone=False\n",
+    "                          ).get_model([*config.image_size] + [3])\n",
    "\n",
    "state_dict = torch.hub.load_state_dict_from_url(config.url, progress=False,\n",
    "                                                map_location='cpu')\n",
    "state_dict_numpy = {k: v.numpy() for k, v in state_dict.items()}\n",
-    "load_state_dict(model, state_dict_numpy)"
+    "load_state_dict(model, state_dict_numpy)\n",
+    "\n",
+    "model.save('/content/model.keras')"
    ],
    "metadata": {
     "collapsed": false
@@ -198,9 +225,7 @@
    "execution_count": null,
    "outputs": [],
    "source": [
-    "wrapped_model = TorchWrapper(model,\n",
-    "                             merged_outputs=merged_outputs,\n",
-    "                             used_custom_layer=use_custom_layer)\n",
+    "wrapped_model = TorchWrapper(model)\n",
    "\n",
    "float_map = acc_eval(wrapped_model, batch_size=64, config=config)"
    ],
@@ -224,7 +249,7 @@
    "execution_count": null,
    "outputs": [],
    "source": [
-    "loader, _ = get_coco_dataloader(split='train', config=config)\n",
+    "loader, _ = get_coco_dataloader(split='val', config=config)\n",
    "\n",
    "\n",
    "def get_representative_dataset(n_iter):\n",
@@ -239,7 +264,8 @@
    "\n",
    "\n",
    "quant_model, _ = mct.ptq.keras_post_training_quantization_experimental(model,\n",
-    "                                                                       get_representative_dataset(20))"
+    "                                                                       get_representative_dataset(20))\n",
+    "quant_model.save('/content/quant_model.keras')"
    ],
    "metadata": {
     "collapsed": false
@@ -261,9 +287,7 @@
    "execution_count": null,
    "outputs": [],
    "source": [
-    "wrapped_model = TorchWrapper(quant_model,\n",
-    "                             merged_outputs=merged_outputs,\n",
-    "                             used_custom_layer=use_custom_layer)\n",
+    "wrapped_model = TorchWrapper(quant_model)\n",
    "\n",
    "quant_map = acc_eval(wrapped_model, batch_size=64, config=config)\n",
    "\n",
diff --git a/tutorials/resources/efficientdet/effdet_keras.py b/tutorials/resources/efficientdet/effdet_keras.py
index da6ef128d..e6c3f910a 100644
--- a/tutorials/resources/efficientdet/effdet_keras.py
+++ b/tutorials/resources/efficientdet/effdet_keras.py
@@ -38,11 +38,9 @@
 
 
 class TorchWrapper(torch.nn.Module):
-    def __init__(self, model, merged_outputs=False, used_custom_layer=False):
+    def __init__(self, model):
         super(TorchWrapper, self).__init__()
         self.model = model
-        self.merged_outputs = merged_outputs
-        self.used_custom_layer = used_custom_layer
 
     @property
     def config(self):
@@ -52,22 +50,11 @@ def forward(self, x, img_info: Optional[Dict[str, torch.Tensor]] = None):
         device = x.device
         keras_input = x.detach().cpu().numpy().transpose((0, 2, 3, 1))
         outputs = self.model(keras_input)
-        if self.merged_outputs:
-            if self.used_custom_layer:
-                outs = [torch.Tensor(o.numpy()).to(device) for o in outputs]
-                outs[0] = outs[0][:, :, [1, 0, 3, 2]]  # reorder (y, x, y2, x2) to (x, y, x2, y2)
-                outs[0] = outs[0] * img_info['img_scale'].view((-1, 1, 1))  # scale to original image size
-                return torch.cat([outs[0], outs[1].unsqueeze(2), outs[2].unsqueeze(2)+1], 2)
-            else:
-                class_out, box_out = outputs
-                class_out = torch.Tensor(class_out.numpy()).to(device)
-                box_out = torch.Tensor(box_out.numpy()).to(device)
-                return [class_out, box_out]
-        else:
-            class_out, box_out = outputs
-            class_out = [torch.Tensor(c.numpy().transpose((0, 3, 1, 2))).to(device) for c in class_out]
-            box_out = [torch.Tensor(b.numpy().transpose((0, 3, 1, 2))).to(device) for b in box_out]
-            return [class_out, box_out]
+
+        outs = [torch.Tensor(o.numpy()).to(device) for o in outputs]
+        outs[0] = outs[0][:, :, [1, 0, 3, 2]]  # reorder (y, x, y2, x2) to (x, y, x2, y2)
+        outs[0] = outs[0] * img_info['img_scale'].view((-1, 1, 1))  # scale to original image size
+        return torch.cat([outs[0], outs[1].unsqueeze(2), outs[2].unsqueeze(2) + 1], 2)
 
 
 def get_act_layer(act_type):
@@ -680,25 +667,24 @@ def toggle_head_bn_level_first(self):
         self.class_net.toggle_bn_level_first()
         self.box_net.toggle_bn_level_first()
 
-    def get_model(self, input_shape, merge_outputs=False, use_custom_layer=False):
+    def get_model(self, input_shape):
         _input = tf.keras.layers.Input(shape=input_shape)
         x = self.backbone(_input)
         x = self.fpn(x)
         x_class = self.class_net(x)
         x_box = self.box_net(x)
         outputs = [x_class, x_box]
-        if merge_outputs:
-            x_class = [tf.keras.layers.Reshape((-1, self.config.num_classes))(_x) for _x in x_class]
-            x_class = tf.keras.layers.Concatenate(axis=1)(x_class)
-            x_box = [tf.keras.layers.Reshape((-1, 4))(_x) for _x in x_box]
-            x_box = tf.keras.layers.Concatenate(axis=1)(x_box)
-            if use_custom_layer:
-                anchors = tf.constant(Anchors.from_config(self.config).boxes.detach().cpu().numpy())
-
-                ssd_pp = SSDPostProcess(anchors, [1, 1, 1, 1], [*self.config.image_size],
-                                        ScoreConverter.SIGMOID, score_threshold=0.001, iou_threshold=0.5,
-                                        max_detections=self.config.max_det_per_image)
-                outputs = ssd_pp((x_box, x_class))
-            else:
-                outputs = [x_class, x_box]
+
+        x_class = [tf.keras.layers.Reshape((-1, self.config.num_classes))(_x) for _x in x_class]
+        x_class = tf.keras.layers.Concatenate(axis=1)(x_class)
+        x_box = [tf.keras.layers.Reshape((-1, 4))(_x) for _x in x_box]
+        x_box = tf.keras.layers.Concatenate(axis=1)(x_box)
+
+        anchors = tf.constant(Anchors.from_config(self.config).boxes.detach().cpu().numpy())
+
+        ssd_pp = SSDPostProcess(anchors, [1, 1, 1, 1], [*self.config.image_size],
+                                ScoreConverter.SIGMOID, score_threshold=0.001, iou_threshold=0.5,
+                                max_detections=self.config.max_det_per_image)
+        outputs = ssd_pp((x_box, x_class))
+
        return tf.keras.Model(inputs=_input, outputs=outputs)
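
For reviewers, a minimal sketch of how the simplified API fits together after this patch. It assumes the notebook's get_coco_dataloader, get_representative_dataset and acc_eval helpers are in scope (they are defined in the notebook, not in effdet_keras.py); all other names are taken from the diff above.

import torch
import model_compression_toolkit as mct
from effdet import get_efficientdet_config
from tutorials.resources.efficientdet import EfficientDetKeras, TorchWrapper
from tutorials.resources.utils import load_state_dict

# get_model() now always merges the per-level head outputs and appends the
# SSDPostProcess layer, so the merge_outputs/use_custom_layer flags are gone.
config = get_efficientdet_config('tf_efficientdet_lite0')
model = EfficientDetKeras(config, pretrained_backbone=False
                          ).get_model([*config.image_size] + [3])

# Port the pretrained PyTorch weights into the Keras graph.
state_dict = torch.hub.load_state_dict_from_url(config.url, progress=False,
                                                map_location='cpu')
load_state_dict(model, {k: v.numpy() for k, v in state_dict.items()})

# Post-training quantization with MCT, fed by the val2017-based
# representative dataset built in the notebook.
quant_model, _ = mct.ptq.keras_post_training_quantization_experimental(
    model, get_representative_dataset(20))

# TorchWrapper now takes only the model: forward() converts NCHW torch input
# to NHWC numpy, runs the Keras graph, and returns the merged detections
# rescaled by img_info['img_scale'].
float_map = acc_eval(TorchWrapper(model), batch_size=64, config=config)
quant_map = acc_eval(TorchWrapper(quant_model), batch_size=64, config=config)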