Skip to content

Commit

Permalink
Merge pull request #5 from EMalagoli92/dev
Browse files Browse the repository at this point in the history
Dev
  • Loading branch information
EMalagoli92 authored Apr 10, 2024
2 parents d094ea5 + 5c971a6 commit 5459d30
Show file tree
Hide file tree
Showing 6 changed files with 216 additions and 129 deletions.
6 changes: 3 additions & 3 deletions src/od_metrics/od_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import numpy as np

from .constants import DEFAULT_COCO, _STANDARD_OUTPUT
from .utils import to_array, get_indexes, get_suffix, _Missing
from .utils import get_indexes, get_suffix, _Missing
from .validators import ConstructorModel, ComputeModel, MeanModel


Expand Down Expand Up @@ -899,8 +899,8 @@ def _get_mean(
# Default
default_value = {
"iou_threshold": self.iou_thresholds,
"label_id": to_array(label_ids),
"area_range_key": to_array(list(self.area_ranges.keys())),
"label_id": np.array(label_ids),
"area_range_key": np.array(list(self.area_ranges.keys())),
"max_detection_threshold": self.max_detection_thresholds,
}

Expand Down
35 changes: 5 additions & 30 deletions src/od_metrics/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,45 +4,18 @@

__all__ = [
"_Missing",
"to_array",
"get_indexes",
"get_suffix",
]

from typing import Literal, Any
from typing import Literal
import numpy as np


class _Missing:
"""Sentinel class for missing values."""


def to_array(
input_: Any,
) -> np.ndarray:
"""
Trasform input to `np.ndarray`.
Parameters
----------
input_ : Any | None, optional
Input to be converted.
Returns
-------
np.ndarray
Input converted to `np.ndarray`.
"""
if not isinstance(input_, np.ndarray):
output = np.array(input_)
else:
output = input_

if output.ndim == 0:
output = output.reshape(-1)
return output


def get_indexes(
array1: np.ndarray,
array2: np.ndarray
Expand Down Expand Up @@ -145,8 +118,10 @@ def to_xywh(
return xyxy_xywh(bbox)
if box_format == "cxcywh":
return cxcywh_xywh(bbox)
raise ValueError("`box_format` can be `'xyxy'`, `'xywh'`, `'cxcywh'`. "
f"Found {box_format}")
raise ValueError( # pragma: no cover
"`box_format` can be `'xyxy'`, `'xywh'`, `'cxcywh'`. "
f"Found {box_format}"
)


def get_suffix(
Expand Down
28 changes: 14 additions & 14 deletions src/od_metrics/validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,8 +227,8 @@ def iou_recall_validator(
or "default_value" not in info.context
or info.field_name is None
):
raise ValueError("Missing required context or field name "
"information.")
raise ValueError( # pragma: no cover
"Missing required context or field name information.")

return _common_validator(
name=info.field_name,
Expand Down Expand Up @@ -271,8 +271,8 @@ def max_detection_validator(
or "default_value" not in info.context
or info.field_name is None
):
raise ValueError("Missing required context or field name "
"information.")
raise ValueError( # pragma: no cover
"Missing required context or field name information.")

return _common_validator(
name=info.field_name,
Expand Down Expand Up @@ -315,8 +315,8 @@ def area_ranges_validator(
or "default_value" not in info.context
or info.field_name is None
):
raise ValueError("Missing required context or field name "
"information.")
raise ValueError( # pragma: no cover
"Missing required context or field name information.")

return _area_ranges_validator(
name=info.field_name,
Expand Down Expand Up @@ -680,8 +680,8 @@ def annotation_parser(
Ground truth or predictions annotations.
"""
if info.context is None or "box_format" not in info.context:
raise ValueError("Missing required context or `box_format` "
"information.")
raise ValueError( # pragma: no cover
"Missing required context or `box_format` information.")
box_format = info.context["box_format"]

# y_true
Expand Down Expand Up @@ -750,8 +750,8 @@ def iou_threshold_validator(
or "default_value" not in info.context
or info.field_name is None
):
raise ValueError("Missing required context or field name "
"information.")
raise ValueError( # pragma: no cover
"Missing required context or field name information.")

return _common_validator(
name=info.field_name,
Expand Down Expand Up @@ -789,8 +789,8 @@ def area_range_key_validator(
or "default_value" not in info.context
or info.field_name is None
):
raise ValueError("Missing required context or field name "
"information.")
raise ValueError( # pragma: no cover
"Missing required context or field name information.")

return _common_validator(
name=info.field_name,
Expand Down Expand Up @@ -828,8 +828,8 @@ def max_detection_label_id_validator(
or "default_value" not in info.context
or info.field_name is None
):
raise ValueError("Missing required context or field name "
"information.")
raise ValueError( # pragma: no cover
"Missing required context or field name information.")

return _common_validator(
name=info.field_name,
Expand Down
97 changes: 93 additions & 4 deletions tests/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -678,18 +678,41 @@
},
]

misc_tests = [

annotations_tests = [
{
"compute_settings": {"extended_summary": True},
"ids": "default_COCO",
"y_true": [
{"labels": [0, 2],
"boxes": np.array([[17, 83, 97, 47], [57, 86, 96, 73]])}
],
"y_pred": [
{"labels": [0, 2],
"boxes": [[17, 83, 97, 47], [57, 86, 96, 73]], "scores": [.2, .3]}
],
"ids": "annotations_boxes_numpy_array"
},
{
"compute_settings": {"extended_summary": True},
"y_true": [
{"labels": [0, 2],
"boxes": np.array([[17, 83, 97, 47], [57, 86, 96, 73]]),
"area": np.array([4559, 7008]),
}
],
"y_pred": [
{"labels": [0, 2],
"boxes": [[17, 83, 97, 47], [57, 86, 96, 73]], "scores": [.2, .3]}
],
"ids": "annotations_area_numpy_array"
},
{
"compute_settings": {"extended_summary": True},
"annotations_settings": {
"y_true": {"n_classes": 3},
"y_pred": {"n_classes": 7},
},
"ids": "misc_default_COCO_different_classes_y_true_y_pred"
"ids": "annotations_exception_different_classes_y_true_y_pred"
},
{
"compute_settings": {"extended_summary": True},
Expand All @@ -698,7 +721,68 @@
"y_pred": {"n_images": 5},
},
"exceptions": {"compute": ValidationError},
"ids": "misc_exception_compute_different_images"
"ids": "annotations_exception_different_images"
},
{
"compute_settings": {"extended_summary": True},
"y_true": [
{"labels": [0],
"boxes": [[17, 83, 97, 47], [57, 86, 96, 73]]}
],
"y_pred": [
{"labels": [0, 2],
"boxes": [[17, 83, 97, 47], [57, 86, 96, 73]], "scores": [.2, .3]}
],
"exceptions": {"compute": ValidationError},
"ids": "annotations_exception_different_attributes_length"
},
{
"compute_settings": {"extended_summary": True},
"y_true": [
{"labels": [0, 2],
"boxes": [[17, 83, 97], [57, 86, 96, 73]]}
],
"y_pred": [
{"labels": [0, 2],
"boxes": [[17, 83, 97, 47], [57, 86, 96, 73]], "scores": [.2, .3]}
],
"exceptions": {"compute": ValidationError},
"to_cover": {"pycoco_converter": False},
"ids": "annotations_exception_boxes_length"
},
{
"compute_settings": {"extended_summary": True},
"y_true": [
{"labels": [0, 2]}
],
"y_pred": [
{"labels": [0, 2],
"boxes": [[17, 83, 97, 47], [57, 86, 96, 73]], "scores": [.2, .3]}
],
"exceptions": {"compute": ValidationError},
"to_cover": {"pycoco_converter": False, "box_format_converter": False},
"ids": "annotations_exception_ytrue_no_boxes"
},
{
"compute_settings": {"extended_summary": True},
"y_true": [
{"labels": [0, 2],
"boxes": [[17, 83, 97, 47], [57, 86, 96, 73]]}
],
"y_pred": [
{"labels": [0, 2], "scores": [.2, .3]}
],
"exceptions": {"compute": ValidationError},
"to_cover": {"pycoco_converter": False, "box_format_converter": False},
"ids": "annotations_exception_ypred_no_boxes"
},
]


misc_tests = [
{
"compute_settings": {"extended_summary": True},
"ids": "default_COCO",
},
{
"compute_settings": {"extended_summary": "yes"},
Expand All @@ -717,6 +801,7 @@
+ objects_number_tests
+ objects_size_tests
+ mean_evaluator_tests
+ annotations_tests
+ misc_tests
)

Expand Down Expand Up @@ -750,4 +835,8 @@
"exceptions",
{}
)
test_tmp["to_cover"] = test_tmp.get(
"to_cover",
{}
)
TESTS.append(test_tmp)
Loading

0 comments on commit 5459d30

Please sign in to comment.