Skip to content

Commit

Permalink
yolov8n pose tutorial - Fix wrong representative dataloader for GPTQ (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
Idan-BenAmi authored Sep 2, 2024
1 parent 950ae8f commit 598caeb
Showing 1 changed file with 11 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@
" annotation_file=REPRESENTATIVE_DATASET_ANNOTATION_FILE,\n",
" preprocess=yolov8_preprocess_chw_transpose)\n",
"\n",
"representative_dataset_gen = DataLoader(representative_dataset, BATCH_SIZE, shuffle=True)\n",
"representative_dataloader = DataLoader(representative_dataset, BATCH_SIZE, shuffle=True)\n",
"\n",
"# Define representative dataset generator\n",
"def get_representative_dataset(n_iter: int, dataset_loader: Iterator[Tuple]):\n",
Expand All @@ -195,7 +195,7 @@
"\n",
"# Get representative dataset generator\n",
"representative_dataset_gen = get_representative_dataset(n_iter=n_iters,\n",
" dataset_loader=representative_dataset_gen)\n",
" dataset_loader=representative_dataloader)\n",
"\n",
"# Set IMX500-v1 TPC\n",
"tpc = mct.get_target_platform_capabilities(fw_name=\"pytorch\",\n",
Expand Down Expand Up @@ -269,9 +269,10 @@
"execution_count": null,
"outputs": [],
"source": [
"!wget -nc http://images.cocodataset.org/zips/train2017.zip\n",
"!unzip -q -o train2017.zip -d ./coco\n",
"!echo Done loading train2017 images\n",
"if not os.path.isdir('coco/train2017'):\n",
" !wget -nc http://images.cocodataset.org/zips/train2017.zip\n",
" !unzip -q -o train2017.zip -d ./coco\n",
" !echo Done loading train2017 images\n",
"\n",
"GPTQ_REPRESENTATIVE_DATASET_FOLDER = './coco/train2017/'\n",
"GPTQ_REPRESENTATIVE_DATASET_ANNOTATION_FILE = './coco/annotations/person_keypoints_train2017.json'\n",
Expand All @@ -283,7 +284,11 @@
" annotation_file=GPTQ_REPRESENTATIVE_DATASET_ANNOTATION_FILE,\n",
" preprocess=yolov8_preprocess_chw_transpose)\n",
"\n",
"representative_dataset_gen = DataLoader(representative_dataset, BATCH_SIZE, shuffle=True)"
"representative_dataloader = DataLoader(representative_dataset, BATCH_SIZE, shuffle=True)\n",
"\n",
"# Get representative dataset generator\n",
"representative_dataset_gen = get_representative_dataset(n_iter=n_iters,\n",
" dataset_loader=representative_dataloader)"
],
"metadata": {
"collapsed": false,
Expand Down

0 comments on commit 598caeb

Please sign in to comment.