Source code for yieldplotlib.accessibility
"""Accessibility features for yieldplotlib."""
import matplotlib
import numpy as np
from colorspacious import cspace_converter
from matplotlib.colors import to_rgb
from yieldplotlib.logger import logger
from yieldplotlib.util import is_monotonic, rgetattr
[docs]
class AccessibilityManager:
"""Manages accessibility features for a given matplotlib.axes.Axes."""
def __init__(self, ax):
"""Manages accessibility features for a given matplotlib.axes.Axes.
Main feature is the run_checks function which will run a series of
accessibility checks. If a check is failed it will send a warning to
the console, but will not raise an error.
Args:
ax (matplotlib.axes.Axes):
Axes object on which to run the accessibility checks.
"""
self.ax = ax
self.warnings = []
[docs]
def run_checks(self):
"""Run accessibility checks."""
self.check_colors()
self.check_fonts()
if not self.warnings:
logger.info("All accessibility checks have passed!")
else:
logger.warning(f"{len(self.warnings)} accessibility checks failed.")
return self.warnings
[docs]
def check_colors(self):
"""Check if colors used are monotonic and span an acceptable lightness range."""
# Convert RGB colors into greyscale to check lightness.
rgb = []
# If cmap is defined, get a sample of those colors.
try:
rgb_values = [
image.cmap(np.arange(0, 1, 0.1))[:, :3] for image in self.ax.images
]
for val in rgb_values[0]:
rgb.append([val[0], val[1], val[2]])
except IndexError:
pass
# Get colors for all lines on the axes.
line_colors = [list(to_rgb(line.get_color())) for line in self.ax.get_lines()]
for lc in line_colors:
rgb.append(lc)
# Get colormap used for scatter points.
try:
scatter_cmap = self.ax.collections[0].get_cmap()
if isinstance(scatter_cmap, matplotlib.colors.LinearSegmentedColormap):
rgb_values = self.ax.collections[0].get_cmap()(np.arange(0, 1, 0.1))[
:, :3
]
cmap_colors = []
for val in rgb_values:
cmap_colors.append([val[0], val[1], val[2]])
else:
cmap_colors = scatter_cmap.colors
cmap_colors = [to_rgb(color) for color in cmap_colors]
cmap_lab = cspace_converter("sRGB1", "CAM02-UCS")(cmap_colors)
cmap_lightness = cmap_lab[:, 0]
if not is_monotonic(cmap_lightness):
warning = "Colormap for scatter points is not monotonic"
self.warnings.append(warning)
logger.warning(warning)
except IndexError:
pass
# Get colors for all faces.
fcs = [s.get_facecolor() for s in self.ax.collections]
face_colors = [list(to_rgb(fc)) for fc in np.squeeze(fcs)]
for fc in face_colors:
rgb.append(fc)
# Get colors for all containers (i.e. bars)
container_colors = [
list(to_rgb(b.get_facecolor())) for c in self.ax.containers for b in c
]
for cc in container_colors:
rgb.append(cc)
# Clean up the RBG values.
rgb = [el for el in rgb if el != []]
# Convert to greyscale to determine lightness.
lab = cspace_converter("sRGB1", "CAM02-UCS")(rgb)
lightness = lab[:, 0]
# If there is only one color, return.
if len(lightness) == 1:
return
if np.abs(np.max(lightness) - np.min(lightness)) < 50:
warning = (
f"Colors do not have enough dynamic range, lightness "
f"difference is"
f" {np.abs(np.max(lightness) - np.min(lightness)):.2f} "
f"and should be at least 50."
)
self.warnings.append(warning)
logger.warning(warning)
if not is_monotonic(lightness):
warning = "Colors may not be monotonic"
self.warnings.append(warning)
logger.warning(warning)
[docs]
def check_fonts(self, size_threshold=10):
"""Checks that all font sizes in the plot are larger than a given threshold."""
font_sizes = {}
# Fonts to check with a get_size() property.
attrs = [
"xaxis.label",
"yaxis.label",
"title",
"get_xticklabels",
"get_yticklabels",
]
for at in attrs:
objs = rgetattr(self.ax, at)
try:
font_sizes[at] = objs.get_size()
except AttributeError:
eval = objs()
for item in eval:
font_sizes[at] = item.get_size()
# Get font sizes for annotations and legends.
for child in self.ax.get_children():
if isinstance(child, matplotlib.text.Annotation):
font_sizes[child] = child.get_fontsize()
if isinstance(child, matplotlib.legend.Legend):
for leg in child.get_texts():
font_sizes[leg] = leg.get_fontsize()
# Get all fonts and sizes that are less than the specified size point threshold.
noncompliant_dict = {k: v for k, v in font_sizes.items() if v < size_threshold}
if bool(noncompliant_dict):
warning = (
f"The following font sizes are likely too small and could present "
f"readability issues {noncompliant_dict}"
)
self.warnings.append(warning)
logger.warning(warning)