from iaa_od.models import Result, Image
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
from matplotlib.patches import Rectangle, Patch
from .show_utils import upsert_handle, get_image_path
import numpy as np
[docs]
def show_units(result: Result, filename: str, filepath: str, /) -> None:
    """
    Display all units found for a given image in the provided Result object.

    The image is drawn with ticks and frame hidden. Every annotation of every
    unit belonging to that image is overlaid as an unfilled rectangle, with
    one colour per unit. A legend mapping colours to unit ids is placed to
    the right of the image.

    Parameters:
        result (Result): The Result object containing the ground truths and units.
        filename (str): The filename of the image to display.
        filepath (str): The filepath where the image is located.

    Raises:
        ValueError: If ``filepath`` or ``filename`` is None or empty, if
            ``result`` contains no units, or if ``filename`` is not present
            in the first ground truth (including when ``result`` has no
            ground truths at all).
    """
    if not filepath:
        raise ValueError("No filepath provided.")
    if not filename:
        raise ValueError("No filename provided.")

    # Units are generated by the k-alpha step; without them there is nothing to draw.
    if not result.units:
        raise ValueError("No units found in the provided Result object. Please run the k-alpha algorithm to generate them.")

    # Construct the full image path
    full_path: str = get_image_path(filename, filepath)

    # The lookup only validates that the image exists in the first ground
    # truth; `image` is not used afterwards. Guard against an empty `gts`
    # so callers get the documented ValueError rather than an IndexError.
    image: Image | None = result.gts[0].images.get(filename) if result.gts else None
    if image is None:
        raise ValueError(f"Image '{filename}' not found in the provided Ground Truths.")

    # Initialise the image plot with ticks and frame hidden
    _, ax = plt.subplots(1, 1, figsize=(10, 10))
    img: np.ndarray = mpimg.imread(full_path)
    ax.imshow(img, origin='upper')
    ax.set_xticks([])
    ax.set_yticks([])
    for spine in ax.spines.values():
        spine.set_visible(False)

    # Keep only the units that belong to the selected image
    image_units = [unit for unit in result.units if unit.img_filename == filename]

    # 'hsv' is cyclic (it starts and ends at red); sampling one extra colour
    # presumably keeps the first and last units from sharing the same hue.
    colour_map = plt.get_cmap('hsv', len(image_units) + 1)

    handles = []
    for idx, unit in enumerate(image_units):
        colour = colour_map(idx)
        for ann in unit.annotations:
            # NOTE(review): assumes coords are (x, y, w, h) in image pixel units — confirm
            bbox = ann.bbox_coords.coords
            rect = Rectangle((bbox.x, bbox.y), bbox.w, bbox.h, linewidth=2, edgecolor=colour, facecolor='none')
            ax.add_patch(rect)
        # One legend entry per unit; upsert_handle presumably de-duplicates by label — confirm
        new_handle = Patch(color=colour, label=f"Unit {unit.id}")
        upsert_handle(handles, new_handle)

    # Add the legend outside the axes, to the right of the subplot
    if handles:
        ax.legend(handles=handles, loc='center left', bbox_to_anchor=(1, 0.5), fontsize='large')

    # Show the plot
    plt.tight_layout()
    plt.show()