Source code for iaa_od.visualisation.show_disagreements

from collections import Counter

import matplotlib.image as mpimg
import matplotlib.pyplot as plt
import mplcursors
import numpy as np
from matplotlib.patches import Patch, Rectangle

from iaa_od.models import AnnotationProtocol, Image, KAlphaUnit, Result
from iaa_od.models.constants import DEFAULT_COLOR, GREEN, RED

from .show_utils import get_image_path

def show_image_with_disagreements(
    result: Result,
    filename: str,
    filepath: str,
    /,
    *,
    compare_categories: bool = False,
    show_image_filename: bool = False,
) -> None:
    """
    Shows which bounding boxes on a given image contribute to agreement (in green) and which don't (in red).

    One subplot is drawn per ground truth (ordered by GT name), each showing the image together with
    that GT's boxes. Hovering over a box shows its annotator and category name.

    Parameters:
        result (Result): The Result object containing the ground truths and units.
        filename (str): The filename of the image to display.
        filepath (str): The filepath where the image is located.
        compare_categories (bool): Whether to compare categories for agreement status, or just localisation.
        show_image_filename (bool): Whether to show the image filename in the title of the plot. Defaults to False.

    Raises:
        ValueError: If filepath or filename is empty, the Result has no units or no ground truths,
            no unit refers to this image, the image is unknown to the first ground truth, or an
            annotation references a unit that does not belong to this image.
    """
    # Check that both filepath and filename are provided
    if not filepath:
        raise ValueError("No filepath provided.")
    if not filename:
        raise ValueError("No filename provided.")

    # Check that the result object contains units and ground truths to show
    if not result.units:
        raise ValueError("No units found in the provided Result object. Please run the k-alpha algorithm to generate them.")
    gts = result.gts
    if not gts:
        raise ValueError("No ground truths found in the provided Result object.")

    # Keep only the units of the selected image. If there are none, the image cannot be shown.
    image_units: list[KAlphaUnit] = [unit for unit in result.units if unit.img_filename == filename]
    if not image_units:
        raise ValueError(
            f"No annotations found for image '{filename}' in the provided Result object. "
            "Maybe it was simply empty?"
        )

    # Construct the full image path
    full_path: str = get_image_path(filename, filepath)
    image: Image | None = gts[0].images.get(filename)
    if image is None:
        raise ValueError(f"Image '{filename}' not found in the provided Ground Truths.")

    # Index the image's units by id. Iterating in reverse makes the first unit with a given id
    # win, which matches a first-match linear search.
    units_by_id = {unit.id: unit for unit in reversed(image_units)}

    # Collect every GT's annotations for this image. A GT without an entry for the image is
    # shown as an empty subplot instead of aborting with a KeyError.
    annotations: dict[str, list[AnnotationProtocol]] = {}
    for gt in gts:
        try:
            annotations[gt.name] = gt.annotations[filename]
        except KeyError:
            annotations[gt.name] = []
    # Sort by GT name so that subplot order is deterministic
    annotations = dict(sorted(annotations.items(), key=lambda item: item[0]))
    gt_index_lut = {gt_name: idx for idx, gt_name in enumerate(annotations)}

    # Initialise the image plot. squeeze=False always yields a 2-D grid of Axes, so indexing
    # and iteration also work when there is a single ground truth.
    n_plots = len(gts)
    _, axes_grid = plt.subplots(
        1, n_plots, figsize=(10 * n_plots, 10), sharex=True, sharey=True, squeeze=False
    )
    axes = axes_grid[0]
    img: np.ndarray = mpimg.imread(full_path)
    for ax in axes:
        ax.imshow(img, origin='upper')
        ax.set_xticks([])
        ax.set_yticks([])
        for spine in ax.spines.values():
            spine.set_visible(False)

    categories: dict[int, str] = gts[0].categories_dict
    # Rectangles handed to the interactive cursor
    bboxes: list[Rectangle] = []
    # Every annotation of a unit is drawn at once, so remember them to avoid drawing twice
    drawn_annotations = set()

    for idx, (gt_name, anns) in enumerate(annotations.items()):
        # Write the GT name above the image (prefixed by the filename on the first subplot if requested)
        title_str = f"{filename} - {gt_name}" if idx == 0 and show_image_filename else gt_name
        axes[idx].set_title(title_str, fontsize=16, fontweight='bold')

        for ann in anns:
            if ann.unique_id in drawn_annotations:
                continue
            unit = units_by_id.get(ann.unit_id)
            if unit is None:
                raise ValueError(f"Annotation with unit_id {ann.unit_id} does not belong to any unit in the current image.")

            unit_annotations = unit.annotations
            colours = _unit_colours(unit_annotations, compare_categories, categories)
            for unit_ann, colour in zip(unit_annotations, colours):
                # Each box goes on the subplot of the GT that drew it
                target_ax = axes[gt_index_lut[unit_ann.gt_name]]
                bboxes.append(_add_box(target_ax, unit_ann, colour, categories))
                drawn_annotations.add(unit_ann.unique_id)

    # Add the legend below the middle subplot
    axes[n_plots // 2].legend(
        handles=_legend_handles(compare_categories),
        loc='lower center',
        bbox_to_anchor=(0.5, -0.05),
        ncols=2,
        fontsize='xx-large',
    )

    # Initialise interactive cursor
    cursor = mplcursors.cursor(bboxes, hover=True)

    @cursor.connect("add")
    def on_add(sel):
        meta = getattr(sel.artist, "_meta", {})
        annotator = meta.get("Annotator", "Unknown")
        category = meta.get("Category", "Unknown")
        sel.annotation.set(text=f"Annotator: {annotator}\nCategory: {category}")

    # Show the plot
    plt.tight_layout()
    plt.show()


def _legend_handles(compare_categories: bool) -> list[Patch]:
    """Return the green/red legend patches, labelled for category or localisation agreement."""
    if compare_categories:
        return [Patch(color=GREEN, label='Agreement'), Patch(color=RED, label='Disagreement')]
    return [
        Patch(color=GREEN, label='Localisation agreement'),
        Patch(color=RED, label='Localisation disagreement'),
    ]


def _majority_category(unit_annotations: list[AnnotationProtocol], category_order: dict[int, str]) -> int | None:
    """Return the most frequent category id of a unit if it was chosen at least twice, else None.

    Ties are broken by the iteration order of ``category_order``. Category ids missing from it
    rank after all known ids, in order of first appearance.
    """
    counts = Counter(ann.category_id for ann in unit_annotations)
    ranked = [cat_id for cat_id in category_order if cat_id in counts]
    known = set(ranked)
    ranked.extend(cat_id for cat_id in counts if cat_id not in known)
    if not ranked:
        return None
    # max() returns the first maximal element, so earlier-ranked categories win ties
    best = max(ranked, key=counts.__getitem__)
    return best if counts[best] >= 2 else None


def _unit_colours(
    unit_annotations: list[AnnotationProtocol], compare_categories: bool, category_order: dict[int, str]
) -> list[str]:
    """Return one edge colour per annotation of a unit, aligned with ``unit_annotations``.

    Localisation mode: all boxes of a unit with at least two annotations are green, and a lone box is red.
    Category mode: a lone box is red. Otherwise, boxes of the majority category (chosen at least
    twice) are green and the rest red. When every box shares one category, all of them are green.
    """
    n_annotations = len(unit_annotations)
    if not compare_categories:
        return [GREEN if n_annotations >= 2 else RED] * n_annotations
    if n_annotations < 2:
        return [RED] * n_annotations
    # NOTE: The majority rule is very effective for three annotators but may not be for more,
    # and is subject to change in future versions.
    majority = _majority_category(unit_annotations, category_order)
    if majority is None:
        return [RED] * n_annotations
    return [GREEN if ann.category_id == majority else RED for ann in unit_annotations]


def _add_box(ax: plt.Axes, ann: AnnotationProtocol, colour: str, categories: dict[int, str]) -> Rectangle:
    """Draw one annotation's bounding box on ``ax``, attach hover metadata, and return the patch."""
    bbox = ann.bbox_coords.coords
    rect = Rectangle((bbox.x, bbox.y), bbox.w, bbox.h, linewidth=2, edgecolor=colour, facecolor='none')
    # Metadata read back by the interactive cursor callback; unknown category ids do not crash
    meta = {
        "Annotator": ann.gt_name,
        "Category": categories.get(ann.category_id, "Unknown"),
    }
    setattr(rect, "_meta", meta)
    ax.add_patch(rect)
    return rect
def show_gts_with_disagreements(
    result: Result,
    filepath: str,
    /,
    *,
    compare_categories: bool = False,
    show_image_filename: bool = False,
) -> None:
    """
    A wrapper for the function "show_image_with_disagreements" that runs it for all images in a given ground truth set.

    The image list is taken from the first ground truth in ``result.gts``, and the images are shown one after another.
    Images for which show_image_with_disagreements raises ValueError (for example images without any units)
    are skipped with a message printed to stdout.

    Parameters:
        result (Result): The Result object containing the ground truths and units.
        filepath (str): The filepath where the images are located.
        compare_categories (bool): Whether to compare categories for agreement status, or just localisation.
        show_image_filename (bool): Whether to show the image filename in the title of the plot. Defaults to False.

    Raises:
        ValueError: If filepath is empty, or the Result has no units or no ground truths.
    """
    # Check that filepath is provided
    if not filepath:
        raise ValueError("No filepath provided.")
    # Check that the result object contains units to show
    if not result.units:
        raise ValueError("No units found in the provided Result object. Please run the k-alpha algorithm to generate them.")
    # Fail with a clear message instead of an IndexError on result.gts[0]
    if not result.gts:
        raise ValueError("No ground truths found in the provided Result object.")

    # Snapshot the image filenames of the first ground truth
    images: dict[str, Image] = result.gts[0].images
    filenames: list[str] = list(images)

    # Run show_image_with_disagreements for all images, skipping those it rejects
    for filename in filenames:
        try:
            show_image_with_disagreements(
                result,
                filename,
                filepath,
                compare_categories=compare_categories,
                show_image_filename=show_image_filename,
            )
        except ValueError as e:
            print(f"Skipping image '{filename}': {e}")