from iaa_od.models import Result, Image, KAlphaUnit, AnnotationProtocol
from iaa_od.models.constants import GREEN, RED, DEFAULT_COLOR
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
from matplotlib.patches import Rectangle, Patch
import mplcursors
from .show_utils import get_image_path
import numpy as np
[docs]
def _agreement_colours(unit_annotations: list[AnnotationProtocol], categories_dict: dict, compare_categories: bool) -> list[str]:
    """
    Returns the edge colour of every annotation of a unit, in the same order as ``unit_annotations``.
    A unit with fewer than two annotations is always red (a lone box agrees with nobody).
    Without category comparison, any unit with two or more annotations is green.
    With category comparison, a unit is green when all its annotations share one category; otherwise the boxes
    of the most common category are green (only if at least two annotators chose it) and all other boxes are red.
    Parameters:
    unit_annotations (list[AnnotationProtocol]): The annotations grouped into one unit.
    categories_dict (dict): The category dictionary of the first ground truth; its key order breaks ties
        between equally common categories.
    compare_categories (bool): Whether categories are taken into account, or just localisation.
    Returns:
    list[str]: One colour per annotation.
    """
    n_annotations = len(unit_annotations)
    if n_annotations < 2:
        return [RED] * n_annotations
    if not compare_categories:
        return [GREEN] * n_annotations
    if len({unit_ann.category_id for unit_ann in unit_annotations}) == 1:
        # Full agreement: every annotator chose the same category
        return [GREEN] * n_annotations
    # Partial agreement: count each category. Seeding with the known categories makes ties resolve to the
    # category listed first in categories_dict (max() keeps the first maximal key it meets); categories
    # missing from the dictionary are still counted instead of raising KeyError.
    counts: dict[int, int] = dict.fromkeys(categories_dict, 0)
    for unit_ann in unit_annotations:
        counts[unit_ann.category_id] = counts.get(unit_ann.category_id, 0) + 1
    majority_category = max(counts, key=counts.__getitem__)
    # NOTE: This is subject to change in future versions, as this is very effective for three annotators but may not be for more.
    if counts[majority_category] < 2:
        # No category was chosen by more than one annotator: the whole unit disagrees
        return [RED] * n_annotations
    return [GREEN if unit_ann.category_id == majority_category else RED for unit_ann in unit_annotations]


def show_image_with_disagreements(result: Result, filename: str, filepath: str, /, *, compare_categories: bool = False, show_image_filename: bool = False) -> None:
    """
    Shows which bounding boxes on a given image contribute to agreement (in green) and which don't (in red).
    One subplot is drawn per ground truth, ordered by GT name, and every box is drawn on the subplot of the GT
    that produced it. Hovering over a box shows its annotator and category.
    Parameters:
    result (Result): The Result object containing the ground truths and units.
    filename (str): The filename of the image to display.
    filepath (str): The filepath where the image is located.
    compare_categories (bool): Whether to compare categories for agreement status, or just localisation. Defaults to False.
    show_image_filename (bool): Whether to show the image filename in the title of the plot. Defaults to False.
    Raises:
    ValueError: If filepath or filename is empty, if the Result object has no units, if no unit refers to the
        image, if the image is not in the first ground truth, or if an annotation refers to a unit that does
        not belong to the image.
    """
    # Check that both filepath and filename are provided
    if not filepath:
        raise ValueError("No filepath provided.")
    if not filename:
        raise ValueError("No filename provided.")
    # Check that the result object contains units to show
    if not result.units:
        raise ValueError("No units found in the provided Result object. Please run the k-alpha algorithm to generate them.")
    # Keep only the units of the selected image; if there are none, the image has nothing to show
    image_units: list[KAlphaUnit] = [unit for unit in result.units if unit.img_filename == filename]
    if not image_units:
        raise ValueError(f"No annotations found for image '{filename}' in the provided Result object. Maybe it was simply empty?")
    # Construct the full image path
    full_path: str = get_image_path(filename, filepath)
    gts = result.gts
    image: Image | None = gts[0].images.get(filename)
    if not image:
        raise ValueError(f"Image '{filename}' not found in the provided Ground Truths.")
    categories_dict = gts[0].categories_dict
    # Index the units by id once instead of scanning the list for every annotation (first unit wins on duplicate ids)
    units_by_id: dict = {}
    for unit in image_units:
        units_by_id.setdefault(unit.id, unit)
    # Collect this image's annotations from every GT; a GT with no entry for the image gets an empty list
    # instead of raising KeyError (which the batch wrapper would not catch)
    annotations: dict[str, list[AnnotationProtocol]] = {gt.name: gt.annotations.get(filename, []) for gt in gts}
    # Order the GTs by name so the subplots, and the subplot each box lands on, are stable
    ordered_gt_names = sorted(annotations)
    gt_index_lut = {gt_name: idx for idx, gt_name in enumerate(ordered_gt_names)}
    # Initialise the image plot. squeeze=False keeps `axes` an indexable row even when there is a single GT.
    n_plots = len(gts)
    _, axes_grid = plt.subplots(1, n_plots, figsize=(10 * n_plots, 10), sharex=True, sharey=True, squeeze=False)
    axes = axes_grid[0]
    img: np.ndarray = mpimg.imread(full_path)
    for ax in axes:
        ax.imshow(img, origin='upper')
        ax.set_xticks([])
        ax.set_yticks([])
        for spine in ax.spines.values():
            spine.set_visible(False)
    # Legend entries; with category comparison, agreement also requires matching categories
    if compare_categories:
        handles = [Patch(color=GREEN, label='Agreement'), Patch(color=RED, label='Disagreement')]
    else:
        handles = [Patch(color=GREEN, label='Localisation agreement'), Patch(color=RED, label='Localisation disagreement')]
    # Rectangles for the interactive cursor
    bboxes: list[Rectangle] = []
    # Unique ids of annotations already drawn: a unit is drawn once, as a whole, from its first annotation met
    drawn_annotations = set()
    for idx, gt_name in enumerate(ordered_gt_names):
        # Write the GT name above its subplot (the first one optionally also shows the image filename)
        title_str = f"{filename} - {gt_name}" if idx == 0 and show_image_filename else gt_name
        axes[idx].set_title(title_str, fontsize=16, fontweight='bold')
        for ann in annotations[gt_name]:
            if ann.unique_id in drawn_annotations:
                continue
            curr_unit: KAlphaUnit | None = units_by_id.get(ann.unit_id)
            if curr_unit is None:
                raise ValueError(f"Annotation with unit_id {ann.unit_id} does not belong to any unit in the current image.")
            unit_annotations = curr_unit.annotations
            colours = _agreement_colours(unit_annotations, categories_dict, compare_categories)
            # Draw every annotation of the unit on the subplot of the GT that produced it
            for unit_ann, colour in zip(unit_annotations, colours):
                bbox = unit_ann.bbox_coords.coords
                rect = Rectangle((bbox.x, bbox.y), bbox.w, bbox.h, linewidth=2, edgecolor=colour, facecolor='none')
                # Metadata read back by the interactive cursor callback below
                meta = {
                    "Annotator": unit_ann.gt_name,
                    "Category": categories_dict.get(unit_ann.category_id, "Unknown"),
                }
                setattr(rect, "_meta", meta)
                axes[gt_index_lut[unit_ann.gt_name]].add_patch(rect)
                bboxes.append(rect)
            # Mark the whole unit as drawn so it is not drawn again from another GT's annotations
            drawn_annotations.update(unit_ann.unique_id for unit_ann in unit_annotations)
    # Place the legend centred just below the middle subplot
    axes[n_plots // 2].legend(handles=handles, loc='lower center', bbox_to_anchor=(0.5, -0.05), ncols=2, fontsize='xx-large')
    # Initialise the interactive cursor: hovering a box shows who drew it and with which category.
    # `cursor` must stay referenced until plt.show() returns, which the local name guarantees.
    if bboxes:
        cursor = mplcursors.cursor(bboxes, hover=True)

        @cursor.connect("add")
        def on_add(sel):
            meta = getattr(sel.artist, "_meta", {})
            annotator = meta.get("Annotator", "Unknown")
            category = meta.get("Category", "Unknown")
            sel.annotation.set(text=f"Annotator: {annotator}\nCategory: {category}")
    # Show the plot
    plt.tight_layout()
    plt.show()
[docs]
def show_gts_with_disagreements(result: Result, filepath: str, /, *, compare_categories: bool = False, show_image_filename: bool = False) -> None:
    """
    A wrapper for the function "show_image_with_disagreements" that runs it for all images in a given ground truth set.
    The list of images is taken from the first ground truth in ``result.gts``. Images for which
    show_image_with_disagreements raises ValueError (e.g. images without annotations) are skipped
    and the reason is printed.
    Parameters:
    result (Result): The Result object containing the ground truths and units.
    filepath (str): The filepath where the images are located.
    compare_categories (bool): Whether to compare categories for agreement status, or just localisation. Defaults to False.
    show_image_filename (bool): Whether to show the image filename in the title of the plot. Defaults to False.
    Raises:
    ValueError: If no filepath is provided or the Result object contains no units.
    """
    # Check that filepath is provided
    if not filepath:
        raise ValueError("No filepath provided.")
    # Check that the result object contains units to show
    if not result.units:
        raise ValueError("No units found in the provided Result object. Please run the k-alpha algorithm to generate them.")
    # Snapshot the image filenames of the first ground truth before plotting them one by one
    filenames: list[str] = list(result.gts[0].images)
    for filename in filenames:
        try:
            show_image_with_disagreements(result, filename, filepath, compare_categories=compare_categories, show_image_filename=show_image_filename)
        except ValueError as e:
            print(f"Skipping image '{filename}': {e}")