# MIT License
#
# Copyright (C) The Adversarial Robustness Toolbox (HEART) Authors 2024
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated
# documentation files (the "Software"), to deal in the Software without restriction, including without limitation the
# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit
# persons to whom the Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all copies or substantial portions of the
# Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE
# WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
# TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
"""
Module implementing various metrics for assessing model robustness. These fall mainly under two categories:
attack-dependent and attack-independent.
"""
import logging
import uuid
from collections.abc import Sequence
from typing import Any, Optional
import numpy as np
from numpy.typing import NDArray
from heart_library.attacks.attack import JaticAttack
from heart_library.estimators.object_detection import JaticPyTorchObjectDetectionOutput
logger: logging.Logger = logging.getLogger(__name__)
class HeartMAPMetric:
    """
    Facilitating support for Torchmetric's MAP metric.

    The class wraps ``torchmetrics.detection.mean_ap.MeanAveragePrecision`` and converts each
    prediction / target object (anything exposing ``boxes``, ``scores`` and ``labels``) into the
    per-image dictionaries that the wrapped metric consumes.

    Examples
    --------

    We can define a MAP metric and evaluate it on a JaticPyTorchObjectDetector's performance:

    >>> from maite.workflows import evaluate
    >>> from heart_library.estimators.object_detection import JaticPyTorchObjectDetector
    >>> import torch
    >>> import numpy as np
    >>> from typing import Any
    >>> from datasets import load_dataset
    >>> from torchvision.transforms import transforms
    >>> from copy import deepcopy

    Define the JaticPyTorchObjectDetector, in this case passing in a resnet model:

    >>> MEAN = [0.485, 0.456, 0.406]
    >>> STD = [0.229, 0.224, 0.225]
    >>> preprocessing = (MEAN, STD)
    >>> detector = JaticPyTorchObjectDetector(
    ...     model_type="detr_resnet50_dc5",
    ...     input_shape=(3, 800, 800),
    ...     clip_values=(0, 1),
    ...     attack_losses=(
    ...         "loss_ce",
    ...         "loss_bbox",
    ...         "loss_giou",
    ...     ),
    ...     device_type="cpu",
    ...     optimizer=torch.nn.CrossEntropyLoss(),
    ...     preprocessing=preprocessing,
    ... )

    Prepare images for detection:

    >>> data = load_dataset("guydada/quickstart-coco", split="train[20:25]")
    >>> preprocess = transforms.Compose([transforms.Resize(800), transforms.CenterCrop(800), transforms.ToTensor()])
    >>> data = data.map(lambda x: {"image": preprocess(x["image"]), "label": None})

    Execute object detection and return JaticPyTorchObjectDetectionOutput:

    >>> detections = detector(data)

    Define data with detections:

    >>> class ImageDataset:
    ...     def __init__(self, images, groundtruth, threshold=0.8):
    ...         self.images = images
    ...         self.groundtruth = groundtruth
    ...         self.threshold = threshold
    ...
    ...     def __len__(self) -> int:
    ...         return len(self.images)
    ...
    ...     def __getitem__(self, ind: int) -> tuple[np.ndarray, np.ndarray, dict[str, Any]]:
    ...         image = np.asarray(self.images[ind]["image"]).astype(np.float32)
    ...         filtered_detection = self.groundtruth[ind]
    ...         filtered_detection.boxes = filtered_detection.boxes[filtered_detection.scores > self.threshold]
    ...         filtered_detection.labels = filtered_detection.labels[filtered_detection.scores > self.threshold]
    ...         filtered_detection.scores = filtered_detection.scores[filtered_detection.scores > self.threshold]
    ...         return (image, filtered_detection, None)
    >>> data_with_detections = ImageDataset(data, deepcopy(detections), threshold=0.9)

    Set the MAP parameters and evaluate:

    >>> map_args = {
    ...     "box_format": "xyxy",
    ...     "iou_type": "bbox",
    ...     "iou_thresholds": [0.5],
    ...     "rec_thresholds": [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0],
    ...     "max_detection_thresholds": [1, 10, 100],
    ...     "class_metrics": False,
    ...     "extended_summary": False,
    ...     "average": "macro",
    ... }
    >>> metric = HeartMAPMetric(**map_args)
    >>> results, _, metadata = evaluate(
    ...     model=detector,
    ...     dataset=data_with_detections,
    ...     metric=metric,
    ... )
    >>> results["map"]
    tensor(1.)
    """

    metadata: dict[str, Any]

    def __init__(self, metadata_id: Optional[str] = None, **kwargs: Any) -> None:  # noqa ANN401
        """HeartMAPMetric initialization.

        Args:
            metadata_id (Optional[str], optional): identifier stored under ``metadata["id"]``.
                When omitted, a random UUID4 string is generated. Defaults to None.
            **kwargs: arguments passed to Torchmetric's MAP metric.
        """
        # Imported at call time, so torchmetrics is only needed once a metric is instantiated.
        from torchmetrics.detection.mean_ap import MeanAveragePrecision

        self._metric = MeanAveragePrecision(**kwargs)
        self.metadata = {"id": metadata_id if metadata_id is not None else str(uuid.uuid4())}

    def reset(self) -> None:
        """Clear contents of current metric's cache of predictions and targets."""
        self._metric.reset()

    def update(
        self,
        preds_batch: Sequence[JaticPyTorchObjectDetectionOutput],
        targets_batch: Sequence[JaticPyTorchObjectDetectionOutput],
    ) -> None:
        """Add predictions and targets to metric's cache for later calculation.

        Predictions and targets are paired positionally; pairing stops at the shorter of the
        two sequences.

        Args:
            preds_batch (`Sequence[JaticPyTorchObjectDetectionOutput]`): predictions in ObjectDetectionTarget format.
            targets_batch (`Sequence[JaticPyTorchObjectDetectionOutput]`):
                groundtruth targets in ObjectDetectionTarget format.
        """
        import torch

        # Torchmetrics mAP expects list of dicts with one dict per image; each dict with:
        # - boxes: Tensor w/shape (num_boxes, 4)
        # - scores: Tensor w/shape (num_boxes)
        # - labels: Tensor w/shape (num_boxes)
        preds_dicts: list[dict[str, Any]] = []
        targets_dicts: list[dict[str, Any]] = []

        # iterate over images in batch
        for preds, targets in zip(preds_batch, targets_batch):
            # tensor bridge to PyTorch tensors (required by Torchmetrics)
            preds_dicts.append(
                {
                    "boxes": torch.as_tensor(preds.boxes),
                    "scores": torch.as_tensor(preds.scores),
                    "labels": torch.as_tensor(preds.labels),
                },
            )
            targets_dicts.append(
                {
                    "boxes": torch.as_tensor(targets.boxes),
                    "scores": torch.as_tensor(targets.scores),
                    "labels": torch.as_tensor(targets.labels),
                },
            )

        # Hand the whole batch over in one call instead of one call per image;
        # nothing is forwarded for an empty batch.
        if preds_dicts:
            self._metric.update(preds_dicts, targets_dicts)

    def compute(self) -> dict[str, Any]:
        """Compute MAP scores.

        Returns:
            dict[str, Any]: Final value from the state of the metric.
        """
        return self._metric.compute()
class HeartAccuracyMetric:
    """Facilitating support for Torchmetric's Accuracy metric.

    Depending on the ``task`` keyword argument, the class wraps Torchmetric's
    ``BinaryAccuracy``, ``MulticlassAccuracy`` or ``MultilabelAccuracy``.
    """

    metadata: dict[str, Any]

    def __init__(
        self,
        is_logits: bool = True,
        metadata_id: Optional[str] = None,
        **kwargs: Any,  # noqa ANN401
    ) -> None:
        """HeartAccuracyMetric initialization.

        Args:
            is_logits (bool, optional): bool indicating if predictions are logits. Defaults to True.
            metadata_id (Optional[str], optional): identifier stored under ``metadata["id"]``.
                When omitted, a random UUID4 string is generated. Defaults to None.
            **kwargs: arguments passed to Torchmetric's Accuracy metric. Must contain ``task``,
                one of "binary", "multiclass" or "multilabel"; ``task`` itself is not forwarded.

        Raises:
            KeyError: if ``task`` is not supplied.
            ValueError: if ``task`` is not one of the supported values.
        """
        self.is_logits = is_logits
        self._task = kwargs.pop("task")

        # Fail fast on an unknown task: otherwise no metric would be created and the
        # problem would only surface later as an AttributeError on ``self._metric``.
        if self._task not in ("binary", "multiclass", "multilabel"):
            raise ValueError(
                f"Unsupported task {self._task!r}. Expected one of 'binary', 'multiclass' or 'multilabel'.",
            )

        # Imported at call time, so torchmetrics is only needed once a metric is instantiated.
        from torchmetrics.classification import (
            BinaryAccuracy,
            MulticlassAccuracy,
            MultilabelAccuracy,
        )

        self._metric: BinaryAccuracy | MulticlassAccuracy | MultilabelAccuracy
        if self._task == "binary":
            self._metric = BinaryAccuracy(**kwargs)
        elif self._task == "multiclass":
            self._metric = MulticlassAccuracy(**kwargs)
        else:
            self._metric = MultilabelAccuracy(**kwargs)

        self.metadata = {"id": metadata_id if metadata_id is not None else str(uuid.uuid4())}

    def reset(self) -> None:
        """Clear contents of current metric's cache of predictions and targets."""
        self._metric.reset()

    def update(self, preds_batch: Sequence[NDArray[np.float32]], targets_batch: Sequence[NDArray[np.float32]]) -> None:
        """Add predictions and targets to metric's cache for later calculation.

        When ``is_logits`` is True the predictions are reduced with an argmax over axis 1
        before being flattened; otherwise they are flattened as given. Targets are always
        flattened as given.

        Args:
            preds_batch (Sequence[NDArray[np.float32]]): predictions in numpy array format.
            targets_batch (Sequence[NDArray[np.float32]]): groundtruth targets in numpy array format.
        """
        import torch

        if self.is_logits:
            preds = torch.as_tensor(np.argmax(np.asarray(preds_batch), axis=1).ravel())
        else:
            preds = torch.as_tensor(np.asarray(preds_batch)).ravel()
        targets = torch.as_tensor(np.asarray(targets_batch)).ravel()
        self._metric.update(preds, targets)

    def compute(self) -> dict[str, float]:
        """Compute accuracy score.

        Returns:
            dict[str, float]: Final value from the state of the metric, under the key "accuracy".
        """
        return {"accuracy": self._metric.compute().item()}
class RobustnessBiasMetric:
    """
    A metric which describes Robustness Bias for features
    of datasets. Currently supports only classification tasks.

    For every threshold tau and every label index, the metric records the proportion of
    successfully attacked samples of that label whose ``delta`` exceeds tau.
    """

    metadata: dict[str, Any]

    def __init__(
        self,
        metadata: Sequence[dict[str, Any]],
        labels: NDArray[np.float32],
        interval: int = 100,
        metadata_id: Optional[str] = None,
    ) -> None:
        """RobustnessBiasMetric initialization.

        Args:
            metadata (Sequence[dict[str, Any]]): the metadata computed during attack
                which contains delta between benign and adversarial images.
            labels (NDArray[np.float32]): classification labels. Only its length is used:
                label indices ``0 .. len(labels) - 1`` are evaluated.
            interval (int, optional): number of tau thresholds to evaluate. Defaults to 100.
            metadata_id (Optional[str], optional): identifier stored under ``metadata["id"]``.
                When omitted, a random UUID4 string is generated. Defaults to None.
        """
        self._state: dict[str, Any] = {}
        self._labels: np.ndarray = labels
        self._metadata: Sequence[dict[str, Any]] = metadata
        self._interval: int = interval
        self.metadata = {"id": metadata_id if metadata_id is not None else str(uuid.uuid4())}

    def reset(self) -> None:
        """Reset the metric to default values."""
        self._state = {}

    def update(self, preds_batch: Sequence[NDArray[np.float32]], targets_batch: Sequence[NDArray[np.float32]]) -> None:
        """Compute the robustness-bias series for the given predictions and targets.

        Each call replaces the previously stored state; results are not accumulated
        across calls.

        Args:
            preds_batch (Sequence[NDArray[np.float32]]): predictions in numpy array format.
            targets_batch (Sequence[NDArray[np.float32]]): groundtruth targets in numpy array format.

        Raises:
            KeyError: if Delta not computed for metadata.
        """
        try:
            errors = np.stack([item["delta"] for item in self._metadata])
        except KeyError as key_error:
            raise KeyError(
                "Delta not computed for metadata. Set norm > 0 of JaticAttack to compute delta.",
            ) from key_error

        # assuming the targets batch is the groundtruth or original predictions
        # and the preds batch are predictions for the augmented / attacked data
        target_classes = np.argmax(np.asarray(targets_batch), axis=1).ravel()
        pred_classes = np.argmax(np.asarray(preds_batch), axis=1).ravel()
        attack_succeeded = target_classes != pred_classes

        # Thresholds run from 0 to one past the largest observed delta.
        taus = np.linspace(0, np.max(errors) + 1, self._interval)

        # The per-label selection does not depend on tau, so it is done once up front.
        # Labels without any successfully attacked sample are left out of the result.
        errors_by_label: dict[int, np.ndarray] = {}
        for label in range(len(self._labels)):
            idxs_of_label = np.flatnonzero((target_classes == label) & attack_succeeded)
            errors_of_label = errors[idxs_of_label]
            if len(errors_of_label) != 0:
                errors_by_label[label] = errors_of_label

        series: dict[int, Any] = {}
        for tau in taus:
            for label, errors_of_label in errors_by_label.items():
                proportion = int(np.count_nonzero(errors_of_label > tau)) / len(errors_of_label)
                self.__populate_series(label, series, tau, proportion)
        self._state = series

    def compute(self) -> dict[str, Any]:
        """Returns the computed metric.

        Returns:
            dict[str, Any]: Final value from the state of the metric: a mapping from label index
            to a list of ``[tau, proportion]`` pairs; empty before the first ``update``.
        """
        return self._state

    def __populate_series(self, label: int, series: dict[int, Any], tau: float, proportion: float) -> None:
        """Add tau, proportion pair to series based on if label already exists.

        Args:
            label (int): Index of series.
            series (dict[int, Any]): dict to be added to.
            tau (float): tau.
            proportion (float): proportion.
        """
        series.setdefault(label, []).append([tau, proportion])
class AccuracyPerturbationMetric:
    """
    A metric for easily calculating the clean and robust accuracy
    as well as the perturbation between clean and adversarial input.
    """

    metadata: dict[str, Any]

    # Accuracy flavours understood by ``update``.
    _ACCURACY_TYPES: tuple[str, ...] = ("adversarial", "robust")

    def __init__(
        self,
        benign_predictions: Sequence[NDArray[np.float32]],
        metadata: Sequence[dict[str, Any]],
        accuracy_type: str = "robust",
        metadata_id: Optional[str] = None,
    ) -> None:
        """AccuracyPerturbationMetric initialization.

        Args:
            benign_predictions (Sequence[NDArray[np.float32]]): per-sample class scores predicted on the
                non-adversarial input; the predicted class is taken as the argmax over axis 1.
            metadata (Sequence[Dict[str, Any]]): per-sample dictionaries, each holding a "delta" entry
                whose values are averaged into "mean_delta".
            accuracy_type (str, optional): the type of accuracy to calculate. Choice of "adversarial" or "robust".
                - Robust accuracy is the accuracy of the model on all samples
                - Adversarial accuracy is the accuracy of the model only samples which were
                correctly predicted in the non-adversarial scenario. Defaults to "robust".
            metadata_id (Optional[str], optional): identifier stored under ``metadata["id"]``.
                When omitted, a random UUID4 string is generated. Defaults to None.

        Raises:
            ValueError: if ``accuracy_type`` is neither "adversarial" nor "robust".
        """
        # Reject unknown types up front instead of silently reporting an accuracy of 0.0 later.
        if accuracy_type not in self._ACCURACY_TYPES:
            raise ValueError(
                f"Unsupported accuracy_type {accuracy_type!r}. Expected 'adversarial' or 'robust'.",
            )
        self._state: dict = {}
        self._benign_predictions = benign_predictions
        self._metadata = metadata
        self._accuracy_type = accuracy_type
        self.metadata = {"id": metadata_id if metadata_id is not None else str(uuid.uuid4())}

    def reset(self) -> None:
        """Reset the metric to default values."""
        self._state = {}

    def update(self, preds_batch: Sequence[NDArray[np.float32]], targets_batch: Sequence[NDArray[np.float32]]) -> None:
        """Updates the metric value.

        Each call replaces the previously stored state. Targets are compared directly against
        the argmax class indices, i.e. they are used as given and not argmax-reduced.

        Args:
            preds_batch (Sequence[NDArray[np.float32]]): Predicted values.
            targets_batch (Sequence[NDArray[np.float32]]): Target values.

        Raises:
            KeyError: if Delta not computed for metadata.
        """
        y_orig = np.argmax(np.stack(self._benign_predictions), axis=1).ravel()
        y_pred = np.argmax(np.stack(preds_batch), axis=1).ravel()

        try:
            mean_delta = np.stack([item["delta"] for item in self._metadata]).mean()
        except KeyError as key_error:
            raise KeyError(
                "Delta not computed for metadata. Set norm > 0 of JaticAttack to compute delta.",
            ) from key_error

        y_true = np.stack(targets_batch)
        y_corr = y_orig == y_true
        n_correct = np.sum(y_corr)
        clean_acc = n_correct / len(y_orig)

        if self._accuracy_type == "adversarial":
            if n_correct:
                attack_acc = np.sum((y_pred == y_orig) & y_corr) / n_correct
            else:
                # No sample was classified correctly on clean input, so the ratio is undefined;
                # report NaN explicitly rather than through a division-by-zero warning.
                attack_acc = float("nan")
        else:
            attack_acc = np.mean(y_pred == y_true)

        self._state = {
            "clean_accuracy": float(clean_acc),
            f"{self._accuracy_type}_accuracy": float(attack_acc),
            "mean_delta": float(mean_delta),
        }

    def compute(self) -> dict[str, float]:
        """Returns the computed metric.

        Returns:
            dict[str, float]: Final value from the state of the metric, with the keys
            "clean_accuracy", "<accuracy_type>_accuracy" and "mean_delta"; empty before
            the first ``update``.
        """
        return self._state
class BlackBoxAttackQualityMetric:
    """A metric for extracting the black box quality metrics.

    The values are read from attributes of the wrapped attack object rather than
    from predictions or targets.
    """

    metadata: dict[str, Any]

    def __init__(self, attack: JaticAttack, metadata_id: Optional[str] = None) -> None:
        """BlackBoxAttackQualityMetric initialization.

        Args:
            attack (JaticAttack): the black-box attack (currently only HopSkipJump supported).
            metadata_id (Optional[str], optional): identifier stored under ``metadata["id"]``.
                When omitted, a random UUID4 string is generated. Defaults to None.
        """
        self._state: dict = {}
        self._attack = attack.get_attack()
        self.metadata = {"id": metadata_id if metadata_id is not None else str(uuid.uuid4())}

    def reset(self) -> None:
        """Reset the metric to default values."""
        self._state = {}

    def update(
        self,
        preds_batch: Optional[Sequence[Any]] = None,
        targets_batch: Optional[Sequence[Any]] = None,
    ) -> None:
        """Updates the metric value from the wrapped attack's recorded statistics.

        Each call replaces the previously stored state. Attributes missing on the attack
        fall back to empty containers.

        Args:
            preds_batch (Optional[Sequence[Any]], optional): ignored; accepted so that the method can be
                called with the same arguments as the other metrics in this module. Defaults to None.
            targets_batch (Optional[Sequence[Any]], optional): ignored; see ``preds_batch``. Defaults to None.
        """
        total_queries = getattr(self._attack, "total_queries", np.array([]))
        adv_query_idx = getattr(self._attack, "adv_query_idx", [])

        # One entry per element of adv_query_idx: how many queries it lists.
        adv_queries = [len(item) for item in adv_query_idx]
        # Whatever remains of each total after removing the adversarial queries.
        benign_queries = [total_queries[i] - n_adv for i, n_adv in enumerate(adv_queries)]

        self._state = {
            "total_queries": total_queries,
            "adv_queries": adv_queries,
            "benign_queries": benign_queries,
            "adv_query_idx": adv_query_idx,
            "adv_perturb_total": getattr(self._attack, "perturbs", []),
            "adv_perturb_iter": getattr(self._attack, "perturbs_iter", []),
            "adv_confs_total": getattr(self._attack, "confs", []),
        }

    def compute(self) -> dict[str, Any]:
        """Returns the computed metric in dict.

        Returns:
            dict[str, Any]: Final value from the state of the metric; empty before the first ``update``.
        """
        return self._state