Source code for heart_library.estimators.object_detection.pytorch

# MIT License
#
# Copyright (C) HEART Authors 2024
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated
# documentation files (the "Software"), to deal in the Software without restriction, including without limitation the
# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit
# persons to whom the Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all copies or substantial portions of the
# Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE
# WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
# TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
"""This module implements a JATIC compatible ART Object Detector."""

import sys
import uuid
from collections.abc import Sequence
from typing import Any, Optional, Union

import numpy as np
import torch
from art.estimators.object_detection import (
    PyTorchDetectionTransformer,
    PyTorchFasterRCNN,
    PyTorchObjectDetector,
    PyTorchYolo,
)
from maite.protocols import ArrayLike
from numpy.typing import NDArray

from heart_library.utils import process_inputs_for_art

COCO_YOLO_LABELS = [
    "person",
    "bicycle",
    "car",
    "motorcycle",
    "airplane",
    "bus",
    "train",
    "truck",
    "boat",
    "traffic light",
    "fire hydrant",
    "stop sign",
    "parking meter",
    "bench",
    "bird",
    "cat",
    "dog",
    "horse",
    "sheep",
    "cow",
    "elephant",
    "bear",
    "zebra",
    "giraffe",
    "backpack",
    "umbrella",
    "handbag",
    "tie",
    "suitcase",
    "frisbee",
    "skis",
    "snowboard",
    "sports ball",
    "kite",
    "baseball bat",
    "baseball glove",
    "skateboard",
    "surfboard",
    "tennis racket",
    "bottle",
    "wine glass",
    "cup",
    "fork",
    "knife",
    "spoon",
    "bowl",
    "banana",
    "apple",
    "sandwich",
    "orange",
    "broccoli",
    "carrot",
    "hot dog",
    "pizza",
    "donut",
    "cake",
    "chair",
    "couch",
    "potted plant",
    "bed",
    "dining table",
    "toilet",
    "tv",
    "laptop",
    "mouse",
    "remote",
    "keyboard",
    "cell phone",
    "microwave",
    "oven",
    "toaster",
    "sink",
    "refrigerator",
    "book",
    "clock",
    "vase",
    "scissors",
    "teddy bear",
    "hair drier",
    "toothbrush",
]
COCO_FASTER_RCNN_LABELS = COCO_DETR_LABELS = [
    "background",
    "person",
    "bicycle",
    "car",
    "motorbike",
    "aeroplane",
    "bus",
    "train",
    "truck",
    "boat",
    "trafficlight",
    "firehydrant",
    "streetsign",
    "stopsign",
    "parkingmeter",
    "bench",
    "bird",
    "cat",
    "dog",
    "horse",
    "sheep",
    "cow",
    "elephant",
    "bear",
    "zebra",
    "giraffe",
    "hat",
    "backpack",
    "umbrella",
    "shoe",
    "eyeglasses",
    "handbag",
    "tie",
    "suitcase",
    "frisbee",
    "skis",
    "snowboard",
    "sportsball",
    "kite",
    "baseballbat",
    "baseballglove",
    "skateboard",
    "surfboard",
    "tennisracket",
    "bottle",
    "plate",
    "wineglass",
    "cup",
    "fork",
    "knife",
    "spoon",
    "bowl",
    "banana",
    "apple",
    "sandwich",
    "orange",
    "broccoli",
    "carrot",
    "hotdog",
    "pizza",
    "donut",
    "cake",
    "chair",
    "sofa",
    "pottedplant",
    "bed",
    "mirror",
    "diningtable",
    "window",
    "desk",
    "toilet",
    "door",
    "tvmonitor",
    "laptop",
    "mouse",
    "remote",
    "keyboard",
    "cellphone",
    "microwave",
    "oven",
    "toaster",
    "sink",
    "refrigerator",
    "blender",
    "book",
    "clock",
    "vase",
    "scissors",
    "teddybear",
    "hairdrier",
    "toothbrush",
    "hairbrush",
]
SUPPORTED_DETECTORS: dict[str, str] = {
    "yolov3u": "YOLO3 model. Ref: https://github.com/ultralytics/ultralytics/blob/main/docs/en/models/yolov3.md",
    "yolov5n": "YOLO5 model. Ref: https://github.com/ultralytics/ultralytics/blob/main/docs/en/models/yolov5.md",
    "yolov6n": "YOLO6 model. Ref: https://github.com/ultralytics/ultralytics/blob/main/docs/en/models/yolov6.md",
    "yolov8n": "YOLO8 model. Ref: https://github.com/ultralytics/ultralytics/blob/main/docs/en/models/yolov8.md",
    "yolov9t": "YOLO9 model. Ref: https://github.com/ultralytics/ultralytics/blob/main/docs/en/models/yolov9.md",
    "yolov10n": "YOLO10 model. Ref: https://github.com/ultralytics/ultralytics/blob/main/docs/en/models/yolov10.md",
    "yolo11n": "YOLO11 model. Ref: https://github.com/ultralytics/ultralytics/blob/main/docs/en/models/yolov11.md",
    "yolo12n": "YOLO12 model. Ref: https://github.com/ultralytics/ultralytics/blob/main/docs/en/models/yolov12.md",
    "fasterrcnn_resnet50_fpn": "Faster R-CNN model. Ref: \
https://pytorch.org/vision/master/models/generated/torchvision.models\
.detection.fasterrcnn_resnet50_fpn.html#\
torchvision.models.detection.fasterrcnn_resnet50_fpn",
    "fasterrcnn_resnet50_fpn_v2": "Faster R-CNN model. Ref: \
https://pytorch.org/vision/master/models/generated/torchvision.models.\
detection.fasterrcnn_resnet50_fpn_v2.html#\
torchvision.models.detection.fasterrcnn_resnet50_fpn_v2",
    "fasterrcnn_mobilenet_v3_large_fpn": "Faster R-CNN model. Ref: \
https://pytorch.org/vision/master/models/generated/torchvision.models.\
detection.fasterrcnn_mobilenet_v3_large_fpn.html#\
torchvision.models.detection.fasterrcnn_mobilenet_v3_large_fpn",
    "fasterrcnn_mobilenet_v3_large_320_fpn": "Faster R-CNN model. Ref: \
https://pytorch.org/vision/master/models/generated/torchvision.models\
.detection.fasterrcnn_mobilenet_v3_large_320_fpn.html#\
torchvision.models.detection.fasterrcnn_mobilenet_v3_large_320_fpn",
    "detr_resnet50": "Detection Transformer. Ref: https://github.com/facebookresearch/detr",
    "detr_resnet101": "Detection Transformer. Ref: https://github.com/facebookresearch/detr",
    "detr_resnet50_dc5": "Detection Transformer. Ref: https://github.com/facebookresearch/detr",
    "detr_resnet101_dc5": "Detection Transformer. Ref: https://github.com/facebookresearch/detr",
}


[docs] class JaticPyTorchObjectDetectionOutput: """Object Detection Output""" def __init__(self, detection: dict[str, NDArray[np.float32]]) -> None: """JaticPyTorchObjectDetectionOutput initialization. Args: detection (dict[str, NDArray[np.float32]]): Detection data.""" self._boxes = detection["boxes"] self._labels = detection["labels"] self._scores = detection["scores"] @property def boxes(self) -> NDArray[np.float32]: """Return detection bounding boxes Returns: NDArray[np.float32]: The boxes in [y1, x1, y2, x2] format, with 0 <= x1 < x2 <= W and 0 <= y1 < y2 <= H. """ return self._boxes @boxes.setter def boxes(self, value: NDArray[np.float32]) -> None: """Update detection bounding boxes Args: value (NDArray[np.float32]): The boxes in [y1, x1, y2, x2] format, with 0 <= x1 < x2 <= W and 0 <= y1 < y2 <= H. """ self._boxes = value @property def labels(self) -> NDArray[np.float32]: """Return detection labels Returns: NDArray[np.float32]: The labels for each image. """ return self._labels @labels.setter def labels(self, value: NDArray[np.float32]) -> None: """Update detection labels Args: value (NDArray[np.float32]): The labels for each image. """ self._labels = value @property def scores(self) -> NDArray[np.float32]: """Return detection scores Returns: NDArray[np.float32]: The scores or each prediction. """ return self._scores @scores.setter def scores(self, value: NDArray[np.float32]) -> None: """Update detection scores Args: value (NDArray[np.float32]): The scores or each prediction. """ self._scores = value
[docs] class JaticPyTorchObjectDetector(PyTorchObjectDetector): """JATIC compatible extension of ART core PyTorchObjectDetector Args: PyTorchObjectDetector (PyTorchObjectDetector): ART core PyTorchObjectDetector. Examples -------- We can create a JaticPyTorchObjectDetector and pass in sample data for detection: >>> from heart_library.estimators.object_detection import JaticPyTorchObjectDetector >>> import torch >>> import numpy >>> from datasets import load_dataset >>> from torchvision.transforms import transforms Define the JaticPyTorchObjectDetector, in this case passing in a resnet model: >>> MEAN = [0.485, 0.456, 0.406] >>> STD = [0.229, 0.224, 0.225] >>> preprocessing = (MEAN, STD) >>> detector = JaticPyTorchObjectDetector( ... model_type="detr_resnet50_dc5", ... input_shape=(3, 800, 800), ... clip_values=(0, 1), ... attack_losses=( ... "loss_ce", ... "loss_bbox", ... "loss_giou", ... ), ... device_type="cpu", ... optimizer=torch.nn.CrossEntropyLoss(), ... preprocessing=preprocessing, ... ) Prepare images for detection: >>> data = load_dataset("guydada/quickstart-coco", split="train[20:25]") >>> preprocess = transforms.Compose([transforms.Resize(800), transforms.CenterCrop(800), transforms.ToTensor()]) >>> data = data.map(lambda x: {"image": preprocess(x["image"]), "label": None}) Execute object detection and return JaticPyTorchObjectDetectionOutput: >>> detections = detector(data) """ metadata: dict[str, Any] def __init__( self, model: Union["torch.nn.Module", str] = "", model_type: str = "", metadata_id: Optional[str] = None, **kwargs: Any, # noqa ANN401 ) -> None: """JaticPyTorchObjectDetector initialization. Args: model (Union[torch.nn.Module, str], optional): a loaded model or path to model. Defaults to "". model_type (str, optional): one of supported_detectors e.g. yolov5. Defaults to "". Raises: ValueError: if model type is not one of "yolo", "detr", or "fastercnn". ValueError: if model type is custom specified. Exception: if YOLO model was not successfully loaded. ValueError: if python version < 3.10 when the model is yolov5 ValueError: if model type is not one of "yolo", "detr", or "fastercnn". """ self.metadata = {"id": metadata_id if metadata_id is not None else str(uuid.uuid4())} self.model_type = model_type if "models" in sys.modules: sys.modules.pop("models") if isinstance(model, torch.nn.Module): super().__init__(model=model, **kwargs) # Define object detection framework for torch.nn.module based on provided model_type. self.__handle_torch_nn(model_type, model, **kwargs) elif isinstance(model, str): device: Any = None # Initialize device type and handle missing model type. device = self.__handle_device_model_type(device, model_type, kwargs) loaded_model: Any = None # Load model frameworks based on str input. self.__handle_model_str_inputs(model, model_type, device, loaded_model, **kwargs) def __handle_model_str_inputs( self, model: str, model_type: str, device: Any, # noqa ANN401 loaded_model: Any, # noqa ANN401 **kwargs: Any, # noqa ANN401 ) -> None: """Load model frameworks based on str input. Args: model (str): A loaded model or path to model. model_type (str): Pytorch framework. device (Any): Torch device type to be used. loaded_model (Any): Empty initialization to be wrapped. Raises: ValueError: If model type is not one of yolo, detr, fasterrcnn. """ # YOLO if "yolo" in model_type: self.__handle_yolo_model(model, model_type, **kwargs) # DETR elif "detr" in model_type: self.__handle_detr_model(model, model_type, device, loaded_model, **kwargs) # Faster-RCNN elif "fasterrcnn" in model_type: self.__handle_fasterrcnn_model(model, model_type, device, loaded_model, **kwargs) else: raise ValueError(f"Model type {model_type} is not supported. Try one of {SUPPORTED_DETECTORS}.") def __handle_torch_nn(self, model_type: str, model: torch.nn.Module, **kwargs: Any) -> None: # noqa ANN401 """Define object detection framework for torch.nn.module based on provided model_type. Args: model_type (str): Pytorch framework. model (torch.nn.module): A loaded model or path to model. Raises: ValueError: If model type is not one of yolo, detr, or fastercnn. """ if "yolo" in model_type: self._detector = PyTorchYolo(model, **kwargs) elif "detr" in model_type: self._detector = PyTorchDetectionTransformer(model, **kwargs) else: self._detector = PyTorchObjectDetector(model, **kwargs) def __handle_device_model_type( self, device: Any, # noqa ANN401 model_type: str, kwargs: Any, # noqa ANN401 ) -> Union[torch.device, str]: """Initialize device type and handle missing model type. Args: device (Any): Empty initialization. model_type (str): Pytorch framework. Raises: ValueError: If model type is not specified. Returns: Union[torch.device, str]: Torch device type to be used. """ if "device_type" in kwargs: device = kwargs["device_type"] else: if not torch.cuda.is_available(): device = torch.device("cpu") else: # pragma: no cover cuda_idx = torch.cuda.current_device() device = torch.device(f"cuda:{cuda_idx}") if model_type == "": raise ValueError( f"To use a local model, please specify param: type, \ with one of the supported models: {SUPPORTED_DETECTORS}", ) return device def __handle_yolo_model( self, model: str, model_type: str, **kwargs: Any, # noqa: ANN401 ) -> None: """Load YOLO model. Args: model (str): A loaded model or path to model. model_type (str): Pytorch framework. device (str): Torch device type to be used. loaded_model (Any): Empty initialization to be wrapped. Raises: Exception: If YOLO model was not successfully loaded. """ yolo_module = self.__import_yolo() # Handle weights path extension based on model_type if model_type == "yolo6n": weights_path = f"{model_type}.yml" if model == "" else model else: weights_path = f"{model_type}.pt" if model == "" else model raw_model = yolo_module(weights_path, task="detect") self._detector = PyTorchYolo(raw_model.model, is_ultralytics=True, model_name=model_type, **kwargs) def __import_yolo(self) -> type: """Safely import YOLO modules. Returns: YOLO module Raises: ImportError: If YOLO is not installed. """ try: from ultralytics import YOLO return YOLO except ImportError as e: raise ImportError("The 'YOLO' package is required but not installed.") from e def __handle_detr_model( self, model: str, model_type: str, device: Union[torch.device, str], loaded_model: Any, # noqa ANN401 **kwargs: Any, # noqa ANN401 ) -> None: """Load detr model. Args: model (str): A loaded model or path to model. model_type (str): Pytorch framework. device (Union[torch.device, str]): Torch device type to be used. loaded_model (Any): Empty initialization to be wrapped. """ if model == "": loaded_model = torch.hub.load("facebookresearch/detr", model_type, pretrained=True) else: # pragma: no cover checkpoint = torch.load(model, map_location=device) loaded_model = torch.hub.load("facebookresearch/detr", model_type, pretrained=False) loaded_model.load_state_dict(checkpoint["model"]) self._detector = PyTorchDetectionTransformer(loaded_model, **kwargs) def __handle_fasterrcnn_model( self, model: str, model_type: str, device: Union[torch.device, str], loaded_model: Any, # noqa ANN401 **kwargs: Any, # noqa ANN401 ) -> None: """Load Faster-RCNN model. Args: model (str): A loaded model or path to model. model_type (str): Pytorch framework. device (Union[torch.device, str]): Torch device type to be used. loaded_model (Any): Empty initialization to be wrapped. """ try: from torchvision.models import detection as fasterrcnn except ImportError as e: raise ImportError("The 'torchvision' package is required but is not installed.") from e if model == "": frcnn_detector = getattr(fasterrcnn, model_type) loaded_model = frcnn_detector( pretrained=True, progress=True, num_classes=91, pretrained_backbone=True, ) else: # pragma: no cover checkpoint = torch.load(model, map_location=device) n_classes = checkpoint["roi_heads.box_predictor.cls_score.weight"].shape[0] frcnn_detector = getattr(fasterrcnn, model_type) loaded_model = frcnn_detector( pretrained=False, progress=True, num_classes=n_classes, pretrained_backbone=True, ) loaded_model.load_state_dict(checkpoint) self._detector = PyTorchFasterRCNN(loaded_model, **kwargs) def __getattr__(self, attr: str) -> Any: # noqa ANN401 """Return value of detector attribute. Args: attr (str): string that contain's detector's name. Returns: Any: Model framework. """ return getattr(self._detector, attr) def __call__(self, data: Sequence[ArrayLike]) -> Sequence[JaticPyTorchObjectDetectionOutput]: """Convert JATIC supported data to ART supported data and perform prediction. Args: data (Any): JATIC supported data. Returns: Sequence[JaticPyTorchObjectDetectionOutput]: Predictions in JATIC supported type. """ # convert to ART supported type images, _, _ = process_inputs_for_art(data) # make prediction output = self._detector.predict(images) # convert back to JATIC supported type return [JaticPyTorchObjectDetectionOutput(detection) for detection in output] def _translate_labels(self, labels: list[dict[str, "torch.Tensor"]]) -> Any: # noqa ANN401 """Route to method of instantiated detector. Args: labels (List[dict[str, "torch.Tensor"]]): Object detection labels in format x1y1x2y2 (torchvision). Returns: Any: Object detection labels in format x1y1x2y2 (torchvision). """ return self._detector._translate_labels(labels) # noqa SLF001 def _translate_predictions(self, predictions: Any) -> list[dict[str, np.ndarray]]: # noqa ANN401 """Route to method of instantiated detector. Args: predictions (Any): Object detection predictions in format x1y1x2y2 (torchvision). Returns: List[dict[str, np.ndarray]]: Object detection predictions in format x1y1x2y2 (torchvision). """ return self._detector._translate_predictions(predictions) # noqa SLF001