This project demonstrates how to track a ball in a video using a custom-trained YOLO object detection model, combined with interpolation to handle stretches where detection fails.
- Training: Train a YOLO model for ball detection in your specific scenario using Roboflow.
- YOLO Object Detection: Use the trained YOLO model for real-time, accurate ball detection in video frames.
- Interpolation: Handle missing ball detections (e.g., due to occlusion) by interpolating the ball's position from the previous and subsequent detections.
- Visualization: Draw bounding boxes around detected balls and display frame numbers on the output video for analysis (see the sketch after this list).
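For illustration, a minimal sketch of the visualization step using OpenCV; the function name `draw_bboxes` comes from this project, but the exact signature shown here is an assumption:

```python
import cv2

def draw_bboxes(frames, ball_positions):
    """Draw the ball's bounding box and the frame number on each frame."""
    output_frames = []
    for frame_num, (frame, box) in enumerate(zip(frames, ball_positions)):
        if box:  # skip frames with no detection at all
            x1, y1, x2, y2 = map(int, box)
            cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 255, 0), 2)
        cv2.putText(frame, f"Frame: {frame_num}", (10, 30),
                    cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)
        output_frames.append(frame)
    return output_frames
```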
`Output_initial.mp4`: output without interpolation.
This output fails to track the ball in some instances, such as when the ball crosses the net. To solve this, we created the interpolation function, which tracks the ball through these instances by interpolating the current position from the previous and next detected positions. The value shown next to the bounding box during tracking is the confidence score (between 0 and 1), which indicates how confident the model is in the detection; it can also be read as the probability of the prediction.
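A minimal sketch of that interpolation step, assuming detections are stored as one `[x1, y1, x2, y2]` list per frame with an empty list where the ball was missed (the project's actual function may differ in detail):

```python
import pandas as pd

def interpolate_ball_positions(ball_positions):
    # One [x1, y1, x2, y2] box per frame; an empty list marks a missed detection.
    ball_positions = [box if box else [None] * 4 for box in ball_positions]
    df = pd.DataFrame(ball_positions, columns=["x1", "y1", "x2", "y2"])

    # Linearly interpolate gaps from the surrounding detections, then
    # back-fill in case the very first frames had no detection at all.
    df = df.interpolate().bfill()

    return df.values.tolist()
```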
- Install the necessary libraries:

  pip install ultralytics opencv-python pandas roboflow
- Create a Roboflow account: We need this to manage our computer vision datasets and training.

  from roboflow import Roboflow

  project = Roboflow(api_key="your_API_key").workspace("viren-dhanwani").project("tennis-ball-detection")
  version = project.version(6)
  dataset = version.download("yolov8")
- Connect your Roboflow project: Enter your Roboflow API key by replacing `"your_API_key"` (available in your account settings).
- Download the dataset: The code downloads dataset version 6 from the Roboflow project.
- Train the YOLO model: Train the model using the downloaded dataset and the following command:

  from ultralytics import YOLO

  !yolo task=detect mode=train model=yolov8x.pt data=tennis-ball-detection-6/data.yaml epochs=150 imgsz=640

  Then place the trained model file in the `models` directory.
- Prepare the input video: Place your input video file in the `input_files` directory.
- Run the script: Execute the `predict.py` script (see the example commands after this list).
- Output: The processed video with bounding boxes and frame numbers is saved in the `saved_outputs` directory.
The program reads the video and extracts individual frames using the `read_video` function. It creates a `BallTracker` object with the path to the YOLO model for ball detection. Each frame is passed to the `BallTracker` for object detection via `detect_frames`, and the detected ball's bounding box coordinates are stored for each frame. The `interpolate_ball_positions` function handles frames where the ball was missed by the custom YOLO model. Bounding boxes and frame numbers are drawn on each frame by `draw_bboxes`, and finally `save_video` writes the annotated frames to a new file.
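Putting that together, here is a condensed sketch of the flow in `predict.py`; the function and class names come from the description above, but the module layout, file names, and signatures are assumptions:

```python
# Assumed module layout; interpolate_ball_positions is as sketched earlier.
from utils import read_video, save_video
from trackers import BallTracker, interpolate_ball_positions

def main():
    # Read the input video into a list of frames.
    frames = read_video("input_files/input_video.mp4")  # hypothetical file name

    # Detect the ball in every frame with the custom-trained YOLO model.
    tracker = BallTracker(model_path="models/best.pt")
    ball_positions = tracker.detect_frames(frames)

    # Fill in frames the model missed (e.g., when the ball crosses the net).
    ball_positions = interpolate_ball_positions(ball_positions)

    # Draw bounding boxes and frame numbers, then save the result.
    output_frames = tracker.draw_bboxes(frames, ball_positions)
    save_video(output_frames, "saved_outputs/output_final.mp4")

if __name__ == "__main__":
    main()
```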
`Output_final.mp4`: final output video (with interpolation).
Contributions are welcome! Feel free to fork this repository, make improvements, and submit pull requests.
This project is licensed under the MIT License.