import json
from .image import Image
from .annotation import Annotation
from collections import defaultdict
from iaa_od.utils import load_categories_dictionary
from typing import Any
from dataclasses import dataclass, field
@dataclass(slots=True, kw_only=True)
class GroundTruth:
    """
    Represents a Ground Truth JSON file, parsed once on construction.

    Two layouts are handled by ``find_annotations``:

    * a "normal" GT with a non-empty top-level ``"annotations"`` list whose
      entries reference their image through ``"image_id"``;
    * an "expert" GT (no or empty top-level ``"annotations"``), where the
      annotations are read from each Image object's ``annotations`` attribute.

    Attributes:
        gt_filepath (str): Path to the Ground Truth JSON file.
        name (str): Name of the Ground Truth dataset.
        images (dict): Dictionary of Image objects, keyed by image filename.
        annotations (dict): Dictionary of lists of Annotation objects, keyed by image filename.
        categories_dict (dict): Dictionary mapping category IDs to category names.
        image_id_to_filename (dict): Dictionary mapping image IDs to filenames.
        image_filename_to_id (dict): Dictionary mapping filenames to image IDs.
    """

    # Properties
    gt_filepath: str
    name: str
    images: dict[str, Image] = field(init=False)
    annotations: dict[str, list[Annotation]] = field(init=False)
    categories_dict: dict[int, str] = field(init=False)
    image_id_to_filename: dict[int, str] = field(init=False)
    image_filename_to_id: dict[str, int] = field(init=False)

    def __post_init__(self) -> None:
        """Load the JSON file and populate images, annotations and categories."""
        # JSON is UTF-8 by specification; do not depend on the platform locale.
        with open(self.gt_filepath, "r", encoding="utf-8") as file:
            gt_data: dict[str, Any] = json.load(file)
        # Order matters: find_images also fills image_id_to_filename /
        # image_filename_to_id, which find_annotations relies on.
        self.images = self.find_images(gt_data)
        self.annotations = self.find_annotations(gt_data, self.images)
        # NOTE(review): the helper re-reads the file from disk rather than
        # reusing gt_data parsed above.
        self.categories_dict = load_categories_dictionary(self.gt_filepath)

    def find_images(self, gt_data: dict[str, Any]) -> dict[str, Image]:
        """
        Finds and returns a dictionary of Image objects from the Ground Truth data.

        Side effect: (re)builds ``self.image_id_to_filename`` and
        ``self.image_filename_to_id`` from the returned images.

        Parameters:
            gt_data (dict): The loaded Ground Truth JSON data.

        Returns:
            dict[str, Image]: A dictionary of Image objects, keyed by image filename.
            If several entries share a filename, the last one wins.

        Raises:
            ValueError: If the data has no ``"images"`` entry or it is empty.
        """
        image_data: Any | None = gt_data.get("images")
        if not image_data:
            raise ValueError(f"No images found in Ground Truth file: {self.gt_filepath}")

        images: dict[str, Image] = {}
        for image in image_data:
            image_obj = Image(image)
            images[image_obj.file_name] = image_obj

        # Build a mapping from image IDs to filenames and vice versa.
        self.image_id_to_filename = {img.id: img.file_name for img in images.values()}
        self.image_filename_to_id = {img.file_name: img.id for img in images.values()}
        return images

    def find_annotations(
        self, gt_data: dict[str, Any], images: dict[str, Image]
    ) -> dict[str, list[Annotation]]:
        """
        Finds and returns a dictionary of lists of Annotation objects from the
        Ground Truth data, keyed by image filename.

        Requires ``find_images`` to have been called first (it provides
        ``self.image_id_to_filename``).

        Parameters:
            gt_data (dict): The loaded Ground Truth JSON data.
            images (dict): A dictionary of Image objects, keyed by image filename.

        Returns:
            dict[str, list[Annotation]]: A ``defaultdict(list)`` of Annotation
            objects, keyed by image filename.

        Raises:
            ValueError: If an annotation references an image ID that does not
            belong to any image of the Ground Truth file.
        """
        annotations: defaultdict[str, list[Annotation]] = defaultdict(list)
        annotations_data: Any | None = gt_data.get("annotations")

        if annotations_data:
            # Normal GT: annotations live at the root of the JSON file and
            # point to their image through "image_id".
            for annotation_data in annotations_data:
                annotation_img_id = annotation_data.get("image_id")
                # Explicit None check: 0 is a valid image ID and must not be
                # skipped. Annotations without any image_id are ignored.
                if annotation_img_id is None:
                    continue
                image_name: str | None = self.image_id_to_filename.get(annotation_img_id)
                if image_name is None:
                    raise ValueError(f"Image ID {annotation_img_id} in annotations not found in images in Ground Truth file: {self.gt_filepath}")
                # image_id_to_filename is one-to-one, so grouping directly by
                # filename is equivalent to grouping by ID and then renaming.
                annotations[image_name].append(
                    Annotation(
                        gt_filepath=self.gt_filepath,
                        gt_name=self.name,
                        annotation=annotation_data,
                        image_name=image_name,
                    )
                )
        else:
            # Expert GT: annotations are stored within each image.
            # NOTE(review): unlike the branch above, image_name is not passed
            # to Annotation here — confirm whether that is intentional.
            for image_filename, image in images.items():
                image_annotations = image.annotations
                if not image_annotations:
                    continue
                for annotation_data in image_annotations:
                    annotations[image_filename].append(
                        Annotation(
                            gt_filepath=self.gt_filepath,
                            gt_name=self.name,
                            annotation=annotation_data,
                        )
                    )
        return annotations

    def __str__(self) -> str:
        """Return a human-readable dump of every image and its annotations."""
        parts: list[str] = ["Ground truth: " + self.name + "\n", "=" * 80 + "\n"]
        separator = "-" * 80 + "\n"
        for img_filename, image in self.images.items():
            parts.append("Image: " + str(img_filename) + "\n\n")
            parts.append(str(image))
            # .get() avoids inserting empty lists into the annotations defaultdict.
            image_annotations = self.annotations.get(img_filename)
            if image_annotations:
                parts.append("\nAnnotations for this image:\n\n")
                parts.extend(str(ann) for ann in image_annotations)
            else:
                parts.append("\nNo annotations for this image.\n")
            parts.append(separator)
        return "".join(parts)