Skip to content

Commit

Permalink
Bug fix in the notebook's coco-evaluation functions. (#1106)
Browse files Browse the repository at this point in the history
  • Loading branch information
Idan-BenAmi authored Jun 16, 2024
1 parent f9a1f1b commit 1f62738
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 9 deletions.
11 changes: 5 additions & 6 deletions tutorials/mct_model_garden/evaluation_metrics/coco_evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
from typing import List, Dict, Tuple, Callable, Any
import random
from pycocotools import mask as mask_utils
import torch
from tqdm import tqdm

from ..models_pytorch.yolov8.yolov8_postprocess import scale_boxes, scale_coords
Expand Down Expand Up @@ -178,7 +177,7 @@ def format_results(self, outputs: List, img_ids: List, orig_img_dims: List, outp
'score': scores.tolist()[ind] if isinstance(scores.tolist(), list) else scores.tolist()
})

return detections
return detections

def load_and_preprocess_image(image_path: str, preprocess: Callable) -> np.ndarray:
"""
Expand Down Expand Up @@ -506,12 +505,13 @@ def evaluate_seg_model(annotation_file, results_file):
coco_eval.summarize()


def evaluate_yolov8_segmentation(model, data_dir, data_type='val2017', img_ids_limit=800, output_file='results.json',iou_thresh=0.7, conf=0.001, max_dets=300,mask_thresh=0.55):
def evaluate_yolov8_segmentation(model, model_predict_func, data_dir, data_type='val2017', img_ids_limit=800, output_file='results.json',iou_thresh=0.7, conf=0.001, max_dets=300,mask_thresh=0.55):
"""
Evaluate YOLOv8 model for instance segmentation on COCO dataset.
Parameters:
- model: The YOLOv8 model to be evaluated.
- model_predict_func: A function to execute the model preidction
- data_dir: The directory containing the COCO dataset.
- data_type: The type of dataset to evaluate against (default is 'val2017').
- img_ids_limit: The maximum number of images to evaluate (default is 800).
Expand All @@ -535,11 +535,10 @@ def evaluate_yolov8_segmentation(model, data_dir, data_type='val2017', img_ids_l

# Preprocess the image
input_img = load_and_preprocess_image(image_path, yolov8_preprocess_chw_transpose).astype('float32')
input_tensor = torch.from_numpy(input_img).unsqueeze(0) # Add batch dimension

# Run the model
with torch.no_grad():
output = model(input_tensor)
output = model_predict_func(model, input_img)

#run post processing (nms)
boxes, scores, classes, masks = postprocess_yolov8_inst_seg(outputs=output, conf=conf, iou_thres=iou_thresh, max_out_dets=max_dets)

Expand Down
22 changes: 22 additions & 0 deletions tutorials/mct_model_garden/models_pytorch/yolov8/yolov8.py
Original file line number Diff line number Diff line change
Expand Up @@ -498,6 +498,28 @@ def keypoints_model_predict(model: Any, inputs: np.ndarray) -> List:

return postprocess_yolov8_keypoints(output_np)

def seg_model_predict(model: Any,
inputs: np.ndarray) -> List:
"""
Perform inference using the provided PyTorch model on the given inputs.
This function handles moving the inputs to the appropriate torch data type and format,
and returns the outputs.
Args:
model (Any): The PyTorch model used for inference.
inputs (np.ndarray): Input data to perform inference on.
Returns:
List: List containing tensors of predictions.
"""
input_tensor = torch.from_numpy(inputs).unsqueeze(0) # Add batch dimension

# Run the model
with torch.no_grad():
outputs = model(input_tensor)

return outputs

def yolov8_pytorch(model_yaml: str) -> (nn.Module, Dict):
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -417,8 +417,9 @@
},
"outputs": [],
"source": [
"from tutorials.mct_model_garden.models_pytorch.yolov8.yolov8 import seg_model_predict\n",
"from tutorials.mct_model_garden.evaluation_metrics.coco_evaluation import evaluate_yolov8_segmentation\n",
"evaluate_yolov8_segmentation(model, data_dir='coco', data_type='val2017', img_ids_limit=100, output_file='results.json', iou_thresh=0.7, conf=0.001, max_dets=300,mask_thresh=0.55)"
"evaluate_yolov8_segmentation(model, seg_model_predict, data_dir='coco', data_type='val2017', img_ids_limit=100, output_file='results.json', iou_thresh=0.7, conf=0.001, max_dets=300,mask_thresh=0.55)"
]
},
{
Expand All @@ -442,7 +443,7 @@
"outputs": [],
"source": [
"from tutorials.mct_model_garden.evaluation_metrics.coco_evaluation import evaluate_yolov8_segmentation\n",
"evaluate_yolov8_segmentation(quant_model, data_dir='coco', data_type='val2017', img_ids_limit=100, output_file='results_quant.json', iou_thresh=0.7, conf=0.001, max_dets=300,mask_thresh=0.55)"
"evaluate_yolov8_segmentation(quant_model, seg_model_predict, data_dir='coco', data_type='val2017', img_ids_limit=100, output_file='results_quant.json', iou_thresh=0.7, conf=0.001, max_dets=300,mask_thresh=0.55)"
]
},
{
Expand All @@ -467,7 +468,7 @@
"outputs": [],
"source": [
"from tutorials.mct_model_garden.evaluation_metrics.coco_evaluation import evaluate_yolov8_segmentation\n",
"evaluate_yolov8_segmentation(gptq_quant_model, data_dir='coco', data_type='val2017', img_ids_limit=100, output_file='results_g_quant.json', iou_thresh=0.7, conf=0.001, max_dets=300,mask_thresh=0.55)"
"evaluate_yolov8_segmentation(gptq_quant_model, seg_model_predict, data_dir='coco', data_type='val2017', img_ids_limit=100, output_file='results_g_quant.json', iou_thresh=0.7, conf=0.001, max_dets=300,mask_thresh=0.55)"
]
},
{
Expand Down

0 comments on commit 1f62738

Please sign in to comment.