-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathYolov8Training.py
77 lines (49 loc) · 2.16 KB
/
Yolov8Training.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import os
import supervision as sv
import cv2
import shutil
import multiprocessing
def main(product_folder='Carne', epochs=250):
    """Preview a sample of product images, then train a YOLOv8 target model.

    The preview images are read from ``<cwd>/Data/<product_folder>/Images`` and
    training uses the dataset description at ``<cwd>/Data/dataset/data.yaml``.
    A Dataset is a set of auto-labeled data produced by a Base Model; the
    Target Model (here YOLOv8) is a small supervised model distilled from it.

    Args:
        product_folder: Sub-folder of ``Data`` whose ``Images`` directory is
            counted and previewed. Defaults to ``'Carne'``.
        epochs: Number of training epochs passed to ``YOLOv8.train``.
            Defaults to 250.
    """
    # Paths are resolved relative to the current working directory, so the
    # script is expected to be launched from the project root.
    home = os.getcwd()
    data_dir = os.path.join(home, 'Data')
    images_dir = os.path.join(data_dir, product_folder, 'Images')
    data_yaml_path = os.path.join(data_dir, 'dataset', 'data.yaml')

    # Sanity check: list the supported image files and report how many exist.
    image_paths = sv.list_files_with_extensions(
        directory=images_dir,
        extensions=["png", "jpg", "bmp"])
    print('image count:', len(image_paths))

    # Preview up to SAMPLE_SIZE images in a 4x4 grid.
    SAMPLE_SIZE = 16
    SAMPLE_GRID_SIZE = (4, 4)
    SAMPLE_PLOT_SIZE = (16, 16)
    titles = []
    images = []
    for image_path in image_paths[:SAMPLE_SIZE]:
        image = cv2.imread(str(image_path))
        if image is None:
            # cv2.imread reports unreadable/corrupt files by returning None
            # instead of raising; skip them so the grid only gets real images
            # and titles stay aligned with their images.
            print(f'skipping unreadable image: {image_path}')
            continue
        titles.append(image_path.stem)
        images.append(image)
    if images:
        sv.plot_images_grid(images=images, titles=titles,
                            grid_size=SAMPLE_GRID_SIZE, size=SAMPLE_PLOT_SIZE)

    # Target Model: fine-tune the small YOLOv8n checkpoint on the dataset.
    # Imported lazily (as in the original script) so the heavy dependency is
    # only loaded once the preview has been shown.
    from autodistill_yolov8 import YOLOv8
    target_model = YOLOv8("yolov8n.pt")
    target_model.train(data_yaml_path, epochs=epochs)

    # NOTE(review): evaluation artifacts (confusion_matrix.png, results.png,
    # val_batch0_pred.jpg) were previously inspected under
    # <cwd>/runs/detect/train. That location is presumably the default
    # Ultralytics run directory; verify it before relying on it.
if __name__ == "__main__":
    # Needed when this script is bundled into a frozen Windows executable that
    # uses multiprocessing (e.g. training data-loader workers); otherwise a no-op.
    multiprocessing.freeze_support()
    main()