| import sys | |
| import PIL.Image as Image | |
| from ultralytics import YOLO | |
| import gradio as gr | |
| # Local imports | |
| from src.logger import logging | |
| from src.exception import CustomExceptionHandling | |
def predict_pose(
    img: str,
    conf_threshold: float,
    iou_threshold: float,
    max_detections: int,
    model_name: str,
) -> "Image.Image | None":
    """
    Estimate poses in an image using a YOLO model with adjustable confidence and IOU thresholds.

    Args:
        - img (str or numpy.ndarray): The input image or path to the image file.
        - conf_threshold (float): The confidence threshold for detections.
        - iou_threshold (float): The Intersection Over Union (IOU) threshold for non-max suppression.
        - max_detections (int): The maximum number of detections allowed.
        - model_name (str): The name or path of the YOLO model to be used for prediction.

    Returns:
        PIL.Image.Image | None: The image with the predictions plotted on it,
        or None when no image was provided or the model produced no results.

    Raises:
        CustomExceptionHandling: Wraps any exception raised while loading the
        model, running inference, or plotting the results.
    """
    try:
        # Without an early return, the original code fell through and called
        # model.predict(source=None, ...), which does not operate on any user
        # input (ultralytics substitutes default sample sources in that case).
        if img is None:
            gr.Warning("Please provide an image.")
            return None

        # Load the YOLO model (re-loaded on every call).
        model = YOLO(model_name)

        # Run inference with the user-selected thresholds.
        # NOTE(review): half=True is requested together with device="cpu";
        # FP16 is generally a GPU feature — confirm ultralytics ignores or
        # supports it on CPU for the models used here.
        results = model.predict(
            source=img,
            conf=conf_threshold,
            iou=iou_threshold,
            max_det=max_detections,
            show_labels=True,
            show_conf=True,
            imgsz=640,
            half=True,
            device="cpu",
        )

        # Guard against an empty result list; the previous loop left the
        # output variable unbound in that case.
        if not results:
            logging.warning("Model returned no results for the given image.")
            return None

        # The previous loop kept only the final result; index it directly.
        # r.plot() returns a BGR array, so reverse the channel axis to RGB
        # before building the PIL image.
        im_array = results[-1].plot()
        im = Image.fromarray(im_array[..., ::-1])

        # Log the successful prediction
        logging.info("Pose estimated successfully.")
        return im

    # Handle exceptions that may occur during the process
    except Exception as e:
        # Custom exception handling
        raise CustomExceptionHandling(e, sys) from e