"""Generic plotting utilities that work with any DirectoryNode.
This module extends matplotlib's Axes class with methods for plotting data directly
from DirectoryNode objects using key-based access.
"""
import inspect
import sys
import matplotlib.pyplot as plt
import numpy as np
from astropy import units as u
from matplotlib.axes import Axes
from matplotlib.cm import ScalarMappable
from matplotlib.colors import Normalize
[docs]
def ypl_plot(self, directory_node, x, y, c=None, autolabel=True, **kwargs):
    """Plot data from a DirectoryNode.

    Args:
        self (matplotlib.axes.Axes):
            The axes to plot the data on.
        directory_node (DirectoryNode):
            The data source to extract plotting variables from. Only its
            ``get(key)`` method is used.
        x (str):
            Key for x-axis data.
        y (str):
            Key for y-axis data.
        c (str or array-like, optional):
            Key for color data, or the color values themselves. A string key
            for which ``directory_node.get`` returns None is ignored.
        autolabel (bool, optional):
            Whether to automatically label the axes.
        **kwargs:
            Additional keyword arguments passed to the plot method. Two keys
            are consumed here and never forwarded, because ``Axes.plot`` does
            not accept them: ``colorbar`` (bool, default True; draw a colorbar
            for the color data) and ``cmap`` (colormap used for that colorbar,
            default viridis).

    Returns:
        list of matplotlib.lines.Line2D: The line(s) created.

    Raises:
        ValueError: If no data is found for ``x`` or ``y``.
    """
    # Extract data
    x_data = directory_node.get(x) if x else None
    y_data = directory_node.get(y) if y else None

    # Basic validation
    if x_data is None:
        raise ValueError(f"Could not find data for key: x='{x}'")
    if y_data is None:
        raise ValueError(f"Could not find data for key: y='{y}'")

    plot_kwargs = kwargs.copy()
    # These options belong to this wrapper; pop them so they are never
    # forwarded to Axes.plot, which rejects unknown Line2D properties.
    show_colorbar = plot_kwargs.pop("colorbar", True)
    cmap = plot_kwargs.pop("cmap", plt.cm.viridis)

    # Strip units from x and y, remembering them for the axis labels
    x_unit, y_unit = None, None
    if isinstance(x_data, u.Quantity):
        x_unit, x_data = x_data.unit, x_data.value
    if isinstance(y_data, u.Quantity):
        y_unit, y_data = y_data.unit, y_data.value

    # Set axis labels with units if available
    if autolabel:
        _xlabel = x.replace("_", " ").title()
        _ylabel = y.replace("_", " ").title()
        _xlabel += f" ({x_unit})" if x_unit else ""
        _ylabel += f" ({y_unit})" if y_unit else ""
        self.set_xlabel(_xlabel)
        self.set_ylabel(_ylabel)

    # Resolve the color data: a string is a key into the node, anything else
    # is taken as the values themselves. Only a key gives us a label.
    _clabel = None
    if isinstance(c, str):
        c_data = directory_node.get(c) if c else None
        _clabel = c.replace("_", " ").title()
    else:
        c_data = c
    # Treat empty color input like "no color data" instead of failing in min/max
    if c_data is not None and np.size(c_data) == 0:
        c_data = None

    if c_data is not None:
        # Handle units for color data
        if isinstance(c_data, u.Quantity):
            if _clabel is not None:
                _clabel += f" [{c_data.unit}]"
            c_data = c_data.value

        # NOTE(review): Axes.plot treats ``c`` as the Line2D color alias, so
        # the forwarded value must be a valid color spec; a plain numeric data
        # array is likely rejected by matplotlib -- confirm intended behavior.
        plot_kwargs["c"] = c_data

        # Add colorbar if requested
        if show_colorbar:
            norm = Normalize(vmin=np.min(c_data), vmax=np.max(c_data))
            sm = ScalarMappable(cmap=cmap, norm=norm)
            sm.set_array([])
            # Use the figure that owns these axes rather than pyplot's
            # "current" figure, which may be a different one.
            cbar = self.figure.colorbar(sm, ax=self)
            if _clabel is not None:
                cbar.set_label(_clabel)

    # Use the standard plot method with our extracted data
    return self.plot(x_data, y_data, **plot_kwargs)
[docs]
def ypl_scatter(self, directory_node, x, y, c=None, autolabel=True, **kwargs):
    """Create a scatter plot from DirectoryNode data.

    Args:
        self (matplotlib.axes.Axes):
            The axes to plot the scatter on.
        directory_node (DirectoryNode):
            The data source to extract plotting variables from. Only its
            ``get(key)`` method is used.
        x (str):
            Key for x-axis data.
        y (str):
            Key for y-axis data.
        c (str or array-like, optional):
            Key for color data or an array of color values. A string key for
            which ``directory_node.get`` returns None is ignored.
        autolabel (bool, optional):
            Whether to automatically label the axes.
        **kwargs:
            Additional keyword arguments passed to the scatter method. The
            ``colorbar`` key (bool, default False) is consumed here and never
            forwarded; when true and color data is present, a colorbar is
            added using ``cmap`` from kwargs (default viridis).

    Returns:
        matplotlib.collections.PathCollection: The scatter plot created.

    Raises:
        ValueError: If no data is found for ``x`` or ``y``.
    """
    # Extract data
    x_data = directory_node.get(x) if x else None
    y_data = directory_node.get(y) if y else None

    # Basic validation
    if x_data is None:
        raise ValueError(f"Could not get key '{x}' from {directory_node}")
    if y_data is None:
        raise ValueError(f"Could not get key '{y}' from {directory_node}")

    scatter_kwargs = kwargs.copy()
    # ``colorbar`` is an option of this wrapper, not of Axes.scatter: always
    # remove it so it can never leak into the scatter call.
    show_colorbar = scatter_kwargs.pop("colorbar", False)

    # Strip units from x and y, remembering them for the axis labels
    x_unit = None
    y_unit = None
    if isinstance(x_data, u.Quantity):
        x_unit, x_data = x_data.unit, x_data.value
    if isinstance(y_data, u.Quantity):
        y_unit, y_data = y_data.unit, y_data.value

    # Set axis labels with units if available
    if autolabel:
        _xlabel = x.replace("_", " ").title()
        _ylabel = y.replace("_", " ").title()
        _xlabel += f" ({x_unit})" if x_unit else ""
        _ylabel += f" ({y_unit})" if y_unit else ""
        self.set_xlabel(_xlabel)
        self.set_ylabel(_ylabel)

    # Resolve the color data: a string is a key into the node, anything else
    # is taken as the values themselves.
    # The label starts as None so the colorbar code below never sees an
    # unbound name when ``c`` was given as an array.
    _clabel = None
    if c is None:
        c_data = None
    elif isinstance(c, str):
        c_data = directory_node.get(c)
    else:
        c_data = c

    if c_data is not None:
        if isinstance(c, str):
            # Only a key gives us a name (and possibly a unit) to label with
            _clabel = c.replace("_", " ").title()
            if isinstance(c_data, u.Quantity):
                c_unit, c_data = c_data.unit, c_data.value
                _clabel += f" [{c_unit}]"

        scatter_kwargs["c"] = np.array(c_data)

        # Add colorbar if requested
        if show_colorbar:
            cmap = scatter_kwargs.get("cmap", plt.cm.viridis)
            norm = Normalize(vmin=np.min(c_data), vmax=np.max(c_data))
            sm = ScalarMappable(cmap=cmap, norm=norm)
            sm.set_array([])
            # Use the figure that owns these axes rather than pyplot's
            # "current" figure, which may be a different one.
            cbar = self.figure.colorbar(sm, ax=self)
            if _clabel is not None:
                cbar.set_label(_clabel)

    # Use the standard scatter method with our extracted data
    return self.scatter(x_data, y_data, **scatter_kwargs)
[docs]
def ypl_hist(self, directory_node, x, autolabel=True, reference_unit=None, **kwargs):
    """Create a histogram from DirectoryNode data.

    Args:
        self (matplotlib.axes.Axes):
            The axes to plot the histogram on.
        directory_node (DirectoryNode):
            The data source to extract plotting variables from. Only its
            ``get(key)`` method is used.
        x (str):
            Key for the data to be histogrammed.
        autolabel (bool, optional):
            Whether to automatically label the axes.
        reference_unit (astropy.units.Unit, optional):
            Reference unit to convert data to. If provided, Quantity data is
            converted to this unit before plotting. This ensures unit
            consistency across multiple datasets. Ignored for data that is
            not a Quantity.
        **kwargs:
            Additional keyword arguments passed to the hist method. If
            ``bins`` is a Quantity it is converted to the unit of the plotted
            data and passed on as plain numbers: array-valued bins keep their
            (possibly fractional) edges, a scalar is coerced to an int count.

    Returns:
        tuple: (n, bins, patches) as returned by hist

    Raises:
        ValueError: If no data is found for ``x``.
    """
    # Extract data
    x_data = directory_node.get(x) if x else None

    # Basic validation
    if x_data is None:
        raise ValueError(f"Could not get key '{x}' from {directory_node}")

    hist_kwargs = kwargs.copy()

    # Handle units for the data
    x_unit = None
    if isinstance(x_data, u.Quantity):
        # Save the original unit for labeling
        x_unit = x_data.unit
        # If reference unit provided, convert the data
        if reference_unit is not None:
            if x_data.unit != reference_unit:
                x_data = x_data.to(reference_unit)
            # Update the display unit
            x_unit = reference_unit
        # Extract numerical values for histogram
        x_data = x_data.value

    bins = hist_kwargs.get("bins")
    if isinstance(bins, u.Quantity):
        # Express the bins in the same unit as the plotted values; with
        # unit-less data there is nothing to convert to, so just strip the unit.
        bins_value = bins.to(x_unit).value if x_unit is not None else bins.value
        if np.ndim(bins_value) == 0:
            # A scalar can only be a bin count, which must be an integer
            hist_kwargs["bins"] = int(bins_value)
        else:
            # Bin edges are physical values: keep them as floats, since
            # truncating to int would move (or collapse) the edges.
            hist_kwargs["bins"] = bins_value

    # Set axis label with unit if available
    if autolabel:
        _xlabel = x.replace("_", " ").title()
        _xlabel += f" ({x_unit})" if x_unit else ""
        self.set_xlabel(_xlabel)

    # Use the standard hist method with our extracted data
    return self.hist(x_data, **hist_kwargs)
[docs]
def extend_matplotlib():
    """Attach this module's ``ypl_`` functions to matplotlib's Axes class.

    Every function found in this module whose name starts with ``ypl_`` is
    set as an attribute of ``Axes`` under the same name, so it can be called
    as a method on any axes. Names that ``Axes`` already has are left alone.
    """
    this_module = sys.modules[__name__]

    for func_name, func in inspect.getmembers(this_module, inspect.isfunction):
        # Only functions following the ypl_ naming convention are exported
        if not func_name.startswith("ypl_"):
            continue
        # Never overwrite an attribute that is already present on Axes
        if hasattr(Axes, func_name):
            continue
        setattr(Axes, func_name, func)
[docs]
def subplots(*args, **kwargs):
    """Create a figure and axes via plt.subplots() with extensions installed.

    Calls extend_matplotlib() first, then forwards everything unchanged to
    plt.subplots().

    Args:
        *args: Arguments to pass to plt.subplots().
        **kwargs: Keyword arguments to pass to plt.subplots().

    Returns:
        tuple: (fig, ax) as returned by plt.subplots()
    """
    # Install the extensions before any axes are handed back to the caller
    extend_matplotlib()
    figure_and_axes = plt.subplots(*args, **kwargs)
    return figure_and_axes