Skip to content

Commit

Permalink
Merge pull request #1 from shamspias/feat/export/options
Browse files Browse the repository at this point in the history
Feat export option for roboflow and CVAT
  • Loading branch information
shamspias authored Sep 4, 2024
2 parents 1e57fc8 + 7415851 commit 8373d1e
Show file tree
Hide file tree
Showing 7 changed files with 203 additions and 71 deletions.
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -163,4 +163,5 @@ cython_debug/

*.pt

outputs/*
outputs/*
object_class/*
78 changes: 10 additions & 68 deletions app/extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,58 +13,30 @@ class VideoFrameExtractor:
frame_rate (float): Desired frame rate to extract images.
output_dir (str): Base directory to store extracted images and annotations.
model_path (str): Path to the YOLO model for object detection.
class_config_path (str): Path to the class configuration file.
output_format (object): Format handler for saving annotations.
"""

def __init__(self, video_path, frame_rate, output_dir, model_path, class_config_path):
def __init__(self, video_path, frame_rate, output_dir, model_path, class_config_path, output_format):
self.video_path = video_path
self.frame_rate = frame_rate
self.output_dir = os.path.join(output_dir, 'train')
self.image_dir = os.path.join(self.output_dir, 'images')
self.label_dir = os.path.join(self.output_dir, 'labels')
self.output_dir = output_dir
self.yolo_model = YOLO(os.path.join('models', model_path))
# self.supported_classes = ['person', 'car', 'truck', 'tank']

# Load classes from YAML
self.output_format = output_format
self.supported_classes = self.load_classes(class_config_path)

# Ensure necessary directories exist
os.makedirs(self.image_dir, exist_ok=True)
os.makedirs(self.label_dir, exist_ok=True)

# Create metadata for training
self._create_data_yaml()

def load_classes(self, config_path):
    """
    Read the object class names listed in a YAML configuration file.

    Parameters:
        config_path (str): Path to the class configuration file.

    Returns:
        list: The ``name`` value of every entry under the top-level
        ``classes`` key, in the order they appear in the file.
    """
    with open(config_path, 'r') as config_file:
        parsed = yaml.safe_load(config_file)

    class_names = []
    for entry in parsed['classes']:
        class_names.append(entry['name'])
    return class_names

def _create_data_yaml(self):
"""
Creates a YAML file to store metadata about the training dataset.
"""
data = {
'train': os.path.abspath(self.image_dir),
'nc': len(self.supported_classes),
'names': self.supported_classes
}
with open(os.path.join(self.output_dir, '..', 'data.yaml'), 'w') as file:
yaml.dump(data, file)

def extract_frames(self, model_confidence):
"""
Extracts frames from the video file at the specified frame rate and saves them in the image directory.
model_confidence (float): Model confidence that help to annotated
Extracts frames from the video file at the specified frame rate, annotates them using the YOLO model,
and saves using the specified format.
"""
cap = cv2.VideoCapture(self.video_path)
if not cap.isOpened():
Expand All @@ -81,47 +53,17 @@ def extract_frames(self, model_confidence):

if frame_count % frame_interval == 0:
frame_filename = f"{self._get_video_basename()}_image{frame_count}.jpg"
frame_path = os.path.join(self.image_dir, frame_filename)
frame_path = os.path.join(self.output_dir, 'images', frame_filename)
cv2.imwrite(frame_path, frame)
self._annotate_frame(frame, frame_path, frame_filename, model_confidence)
results = self.yolo_model.predict(frame, conf=model_confidence)
self.output_format.save_annotations(frame, frame_path, frame_filename, results, self.supported_classes)

frame_count += 1

cap.release()

def _annotate_frame(self, frame, frame_path, frame_filename, model_conf):
"""
Annotates the frame using the YOLO model and saves the annotation to a file.
Parameters:
frame (np.array): The frame to be annotated.
frame_path (str): Path where the frame image is saved.
frame_filename (str): Filename of the frame image.
"""
results = self.yolo_model.predict(frame, conf=model_conf)
annotation_filename = frame_filename.replace('.jpg', '.txt')
annotation_path = os.path.join(self.label_dir, annotation_filename)
img_height, img_width = frame.shape[:2]

with open(annotation_path, 'w') as f:
for result in results:
if hasattr(result, 'boxes') and result.boxes is not None:
for box in result.boxes:
class_id = int(box.cls[0])
if self.supported_classes[class_id] in self.supported_classes:
confidence = box.conf[0]
xmin, ymin, xmax, ymax = box.xyxy[0]
x_center = ((xmin + xmax) / 2) / img_width
y_center = ((ymin + ymax) / 2) / img_height
width = (xmax - xmin) / img_width
height = (ymax - ymin) / img_height
f.write(f"{class_id} {x_center:.6f} {y_center:.6f} {width:.6f} {height:.6f}\n")

def _get_video_basename(self):
    """
    Return the video's file name with its directory and extension removed.

    Returns:
        str: The basename of ``self.video_path`` without its final extension.
    """
    file_name = os.path.basename(self.video_path)
    stem, _extension = os.path.splitext(file_name)
    return stem
19 changes: 17 additions & 2 deletions app/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,10 @@
import streamlit as st
from config import Config
from extractor import VideoFrameExtractor
from formats.roboflow_format import RoboflowFormat
from formats.cvat_format import CVATFormat

# Import other formats if available

config = Config()

Expand All @@ -22,6 +26,10 @@
frame_rate = st.number_input("Frame rate", value=config.default_frame_rate)
model_confidence = st.number_input("Model Confidence", value=0.1)

# Allow users to choose the output format
format_options = {'Roboflow': RoboflowFormat, 'CVAT': CVATFormat} # Add more formats to this dictionary
format_selection = st.selectbox("Choose output format:", list(format_options.keys()))

if st.button('Extract Frames'):
if uploaded_file is not None:
# Create temp directory if it does not exist
Expand All @@ -44,11 +52,18 @@
specific_output_dir = os.path.join(output_dir, unique_filename)
os.makedirs(specific_output_dir, exist_ok=True)

# Extract frames using the VideoFrameExtractor
# Instantiate the selected output format
output_format_instance = format_options[format_selection](specific_output_dir)

# Extract frames using the VideoFrameExtractor with the chosen format
try:
extractor = VideoFrameExtractor(video_path, frame_rate, specific_output_dir, model_selection,
class_config_path)
class_config_path, output_format_instance)
extractor.extract_frames(model_confidence)

if format_selection == "CVAT": # If CVAT export then it will save as zip format
output_format_instance.zip_and_cleanup()

st.success('Extraction Completed!')
# Delete the temporary video file after successful extraction
os.remove(video_path)
Expand Down
Empty file added formats/__init__.py
Empty file.
28 changes: 28 additions & 0 deletions formats/base_format.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
class BaseFormat:
    """
    Common interface for annotation output formats.

    Stores the output directory and declares the hooks a concrete format
    overrides. Every hook defined here raises ``NotImplementedError``.
    """

    def __init__(self, output_dir):
        # Left as None here; concrete formats assign the directories they use.
        self.label_dir = None
        self.image_dir = None
        self.output_dir = output_dir

    def ensure_directories(self):
        """
        Create the directories the format writes to.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError("Subclasses should implement this method.")

    def annotate_frame(self, frame, frame_path, frame_filename, model_conf, supported_classes):
        """
        Annotate a frame and store the annotation in a format-specific way.

        Parameters:
            frame (np.array): The frame to be annotated.
            frame_path (str): Path where the frame image is saved.
            frame_filename (str): Filename of the frame image.
            model_conf (float): Model confidence threshold for annotations.
            supported_classes (list): List of supported class names.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError("Subclasses should implement this method.")

    def save_annotations(self, frame, frame_path, frame_filename, results, supported_classes):
        """
        Persist the annotations for one frame.

        Parameters:
            frame (np.array): The frame the results belong to.
            frame_path (str): Path where the frame image is saved.
            frame_filename (str): Filename of the frame image.
            results: Model output for the frame.
            supported_classes (list): List of supported class names.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError("Subclasses should implement this method.")
97 changes: 97 additions & 0 deletions formats/cvat_format.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
import os
import cv2
import zipfile
from formats.base_format import BaseFormat


class CVATFormat(BaseFormat):
    """
    Output handler that stores frames and labels in the layout this project exports for CVAT.

    Each frame is written as a PNG next to a same-named ``.txt`` label file in
    ``<output_dir>/data/obj_train_data``; ``obj.names``, ``obj.data`` and
    ``train.txt`` are written to ``<output_dir>/data``.

    Attributes:
        output_dir (str): Base directory for all output.
        data_dir (str): ``<output_dir>/data``.
        image_dir (str): ``<data_dir>/obj_train_data``; holds images and labels.
    """

    def __init__(self, output_dir):
        super().__init__(output_dir)
        self.data_dir = os.path.join(output_dir, 'data')
        self.image_dir = os.path.join(self.data_dir, 'obj_train_data')
        self.ensure_directories()

    def ensure_directories(self):
        """
        Create ``image_dir`` (and therefore ``data_dir``) if missing.

        The base-class hook only raises ``NotImplementedError``, so it must not
        be delegated to; the directories are created here directly.
        """
        os.makedirs(self.image_dir, exist_ok=True)

    def save_annotations(self, frame, frame_path, frame_filename, results, supported_classes):
        """
        Save one frame as PNG plus its label file, then refresh the metadata files.

        Parameters:
            frame (np.array): Image data; ``frame.shape[:2]`` is read as (height, width).
            frame_path (str): Unused here; kept for interface compatibility.
            frame_filename (str): Frame file name; only its stem is used.
            results: Iterable of model results; entries whose ``boxes`` attribute
                is missing or ``None`` are skipped.
            supported_classes (list): Class names written to ``obj.names``.
        """
        # Derive both names from the stem so the image and its label always
        # share a stem and can never resolve to the same path, whatever
        # extension (or embedded ".jpg"/".png") the incoming name carries.
        stem = os.path.splitext(frame_filename)[0]
        image_path = os.path.join(self.image_dir, f"{stem}.png")
        cv2.imwrite(image_path, frame)  # Save the frame image

        annotation_path = os.path.join(self.image_dir, f"{stem}.txt")
        img_height, img_width = frame.shape[:2]

        with open(annotation_path, 'w', encoding='utf-8') as label_file:
            for result in results:
                boxes = getattr(result, 'boxes', None)
                if boxes is None:
                    continue
                for box in boxes:
                    # Only handle entries holding exactly one box of coordinates.
                    if not (box.xyxy.dim() == 2 and box.xyxy.shape[0] == 1):
                        continue
                    class_id = int(box.cls[0])
                    xmin, ymin, xmax, ymax = box.xyxy[0].tolist()
                    # Corner coordinates -> center/size, normalized by image size.
                    x_center = ((xmin + xmax) / 2) / img_width
                    y_center = ((ymin + ymax) / 2) / img_height
                    width = (xmax - xmin) / img_width
                    height = (ymax - ymin) / img_height
                    label_file.write(
                        f"{class_id} {x_center:.6f} {y_center:.6f} {width:.6f} {height:.6f}\n"
                    )

        # Refreshed after every frame because train.txt lists the images saved so far.
        self.create_metadata_files(supported_classes)

    def create_metadata_files(self, supported_classes):
        """
        Write ``obj.names``, ``obj.data`` and ``train.txt`` into ``data_dir``.

        Parameters:
            supported_classes (list): Class names, written one per line to ``obj.names``.
        """
        obj_names_path = os.path.join(self.data_dir, 'obj.names')
        obj_data_path = os.path.join(self.data_dir, 'obj.data')
        train_txt_path = os.path.join(self.data_dir, 'train.txt')

        # Create obj.names file
        with open(obj_names_path, 'w', encoding='utf-8') as names_file:
            names_file.writelines(f"{cls}\n" for cls in supported_classes)

        # Create obj.data file
        with open(obj_data_path, 'w', encoding='utf-8') as data_file:
            data_file.write(f"classes = {len(supported_classes)}\n")
            data_file.write("train = data/train.txt\n")
            data_file.write("names = data/obj.names\n")
            data_file.write("backup = backup/\n")

        # Create train.txt listing all image files; sorted so the listing does
        # not depend on the arbitrary order os.listdir returns.
        with open(train_txt_path, 'w', encoding='utf-8') as train_file:
            for image_file in sorted(os.listdir(self.image_dir)):
                if image_file.endswith('.png'):
                    train_file.write(f"data/obj_train_data/{image_file}\n")

    def zip_and_cleanup(self):
        """
        Pack ``data_dir`` into ``<output_dir>/cvat_data.zip`` and delete ``data_dir``.

        Archive member names are relative to ``data_dir``.

        Returns:
            str: Path of the created zip file.
        """
        # Local import: only this method needs shutil.
        import shutil

        zip_path = os.path.join(self.output_dir, 'cvat_data.zip')
        with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as zipf:
            for root, dirs, files in os.walk(self.data_dir):
                # Directory entries are written too so empty folders survive.
                for name in dirs + files:
                    entry_path = os.path.join(root, name)
                    zipf.write(entry_path, os.path.relpath(entry_path, self.data_dir))

        # The archive lives outside data_dir, so the whole tree can go.
        shutil.rmtree(self.data_dir)
        return zip_path
49 changes: 49 additions & 0 deletions formats/roboflow_format.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
from formats.base_format import BaseFormat
import os
import yaml


class RoboflowFormat(BaseFormat):
    """
    Output handler that writes one label text file per frame plus a ``data.yaml``.

    Attributes:
        output_dir (str): Base directory for all output.
        image_dir (str): ``<output_dir>/images``.
        label_dir (str): ``<output_dir>/labels``.
    """

    def __init__(self, output_dir):
        super().__init__(output_dir)
        self.image_dir = os.path.join(output_dir, 'images')
        self.label_dir = os.path.join(output_dir, 'labels')
        self.ensure_directories()

    def ensure_directories(self):
        """Create the image and label directories if they do not exist yet."""
        os.makedirs(self.image_dir, exist_ok=True)
        os.makedirs(self.label_dir, exist_ok=True)

    def save_annotations(self, frame, frame_path, frame_filename, results, supported_classes):
        """
        Write the label file for one frame and refresh ``data.yaml``.

        Each detected box becomes one line: the class id followed by the box
        center x/y and width/height, all normalized by the image size.

        Parameters:
            frame (np.array): Image data; ``frame.shape[:2]`` is read as (height, width).
            frame_path (str): Unused here; kept for interface compatibility.
            frame_filename (str): Frame file name; its stem names the label file.
            results: Iterable of model results; entries whose ``boxes`` attribute
                is missing or ``None`` are skipped.
            supported_classes (list): Class names; boxes whose class id is not a
                valid index into this list are skipped.
        """
        # splitext swaps only the real extension; str.replace would also
        # rewrite a ".jpg" embedded earlier in the name.
        stem = os.path.splitext(frame_filename)[0]
        annotation_path = os.path.join(self.label_dir, f"{stem}.txt")
        img_height, img_width = frame.shape[:2]
        class_count = len(supported_classes)

        with open(annotation_path, 'w', encoding='utf-8') as label_file:
            for result in results:
                boxes = getattr(result, 'boxes', None)
                if boxes is None:
                    continue
                for box in boxes:
                    class_id = int(box.cls[0])
                    # The previous check indexed the list and then tested the
                    # result for membership in the same list: always true for
                    # a valid index, IndexError otherwise. Skip unknown ids.
                    if not 0 <= class_id < class_count:
                        continue
                    # Plain floats keep formatting independent of the array type.
                    xmin, ymin, xmax, ymax = (float(value) for value in box.xyxy[0])
                    x_center = ((xmin + xmax) / 2) / img_width
                    y_center = ((ymin + ymax) / 2) / img_height
                    width = (xmax - xmin) / img_width
                    height = (ymax - ymin) / img_height
                    label_file.write(
                        f"{class_id} {x_center:.6f} {y_center:.6f} {width:.6f} {height:.6f}\n"
                    )

        # Generate metadata file if needed
        self.create_data_yaml(supported_classes)

    def create_data_yaml(self, supported_classes):
        """
        Write ``<output_dir>/data.yaml`` describing the dataset.

        Parameters:
            supported_classes (list): Class names stored under ``names``; their
                count is stored under ``nc``.
        """
        data = {
            'train': os.path.abspath(self.image_dir),
            'nc': len(supported_classes),
            # Copy into a list so a tuple is not dumped with a Python-specific tag.
            'names': list(supported_classes),
        }
        with open(os.path.join(self.output_dir, 'data.yaml'), 'w', encoding='utf-8') as file:
            yaml.dump(data, file)

0 comments on commit 8373d1e

Please sign in to comment.