"""Utility functions."""
import functools
import matplotlib as mpl
import numpy as np
from astropy import units as u
from yieldplotlib.key_map import KEY_MAP
[docs]
def get_nice_number(value, round=False):
    """Calculates a "nice" number for labeling axes in a plot.

    The magnitude of the value is split into a power of ten and a leading
    fraction, and the fraction is replaced by one of 1, 2, 5 or 10. Negative
    values are handled by working on the magnitude and restoring the sign, so
    the result for ``-value`` is the negation of the result for ``value``.

    Args:
        value (float):
            The value to be transformed into a "nice" number.
        round (bool, optional):
            If True, the leading fraction is rounded to the nearest "nice"
            fraction (cut-offs at 1.5, 3 and 7). If False, the smallest "nice"
            fraction that is not below the leading fraction is used.

    Returns:
        float:
            A "nice" number that is a rounded or scaled version of the input
            value. Zero is returned unchanged as ``0``.
    """
    if value == 0:
        return 0

    # The logarithm is undefined for negative numbers (NumPy yields NaN), so
    # all of the scaling is done on the magnitude.
    magnitude = abs(value)

    # Calculate the exponent of the base 10 representation of the magnitude
    exponent = np.floor(np.log10(magnitude))

    # Calculate the leading fraction of the magnitude
    fraction = magnitude / 10**exponent

    # Determine the "nice" fraction based on whether rounding is required
    if round:
        if fraction < 1.5:
            nice_fraction = 1
        elif fraction < 3:
            nice_fraction = 2
        elif fraction < 7:
            nice_fraction = 5
        else:
            nice_fraction = 10
    else:
        if fraction <= 1:
            nice_fraction = 1
        elif fraction <= 2:
            nice_fraction = 2
        elif fraction <= 5:
            nice_fraction = 5
        else:
            nice_fraction = 10

    # Scale the nice fraction back by the exponent and restore the sign
    nice_number = nice_fraction * 10**exponent
    return -nice_number if value < 0 else nice_number
[docs]
def calculate_axis_limits_and_ticks(data_min, data_max, num_ticks=5, exact=False):
    """Calculates the axis limits and tick spacing for a plot.

    Args:
        data_min (float):
            The minimum value of the data.
        data_max (float):
            The maximum value of the data.
        num_ticks (int, optional):
            The desired number of tick marks on the axis. Default is 5. The
            spacing is derived from ``num_ticks - 1`` intervals, so this
            should be at least 2.
        exact (bool, optional):
            If True, use exact min and max values for the limits. If False, the
            limits are adjusted to "nice" values.

    Returns:
        tuple:
            nice_min (float):
                The adjusted minimum axis limit.
            nice_max (float):
                The adjusted maximum axis limit.
            tick_spacing (float):
                The spacing between ticks. This is 0 when ``data_min`` equals
                ``data_max``.
            offset (float):
                A small offset to apply to the axis limits for better
                visualization (2.5% of the tick spacing).
    """
    # Calculate the "nice" range span of the data
    range_span = get_nice_number(data_max - data_min, round=True)

    # Calculate the "nice" tick spacing based on the range span and desired
    # number of ticks
    tick_spacing = get_nice_number(range_span / (num_ticks - 1), round=True)

    if exact or tick_spacing == 0:
        # Use the data limits as they are. A zero spacing (data_min equal to
        # data_max) offers no grid to snap to and would otherwise cause a
        # division by zero below, so it is treated like the exact case.
        nice_min = data_min
        nice_max = data_max
    else:
        # Snap the limits outwards to the nearest multiples of the spacing
        nice_min = np.floor(data_min / tick_spacing) * tick_spacing
        nice_max = np.ceil(data_max / tick_spacing) * tick_spacing

    # Calculate a small offset for better visualization of the axis limits
    offset = 0.025 * tick_spacing
    return nice_min, nice_max, tick_spacing, offset
[docs]
def is_monotonic(x):
    """Checks if an array is monotonic.

    The check is non-strict: consecutive equal values are allowed, so the
    array counts as monotonic when its successive differences are all
    non-positive or all non-negative.
    """
    steps = np.diff(x)
    never_rises = np.all(steps <= 0)
    if never_rises:
        return never_rises
    # Not non-increasing, so the answer depends on it being non-decreasing
    return np.all(steps >= 0)
[docs]
def rgetattr(obj, attr, *args):
    """Recursively get attributes of an object.

    ``attr`` is a dotted path such as ``"a.b.c"``; each segment is looked up
    with ``getattr`` on the result of the previous one. Any extra positional
    arguments are forwarded to every ``getattr`` call, so an optional default
    applies at each level of the path and the lookup continues on that
    default when a segment is missing.
    """
    current = obj
    for name in attr.split("."):
        current = getattr(current, name, *args)
    return current
[docs]
def discretize_colormap(num_colors, colormap_name, start_frac=0.1, end_frac=0.9):
    """Returns evenly spaced discrete colors from a matplotlib colormap.

    The colormap registered under ``colormap_name`` is sampled at
    ``num_colors`` evenly spaced positions between ``start_frac`` and
    ``end_frac`` (both inclusive), and the colormap's output for those
    positions is returned.
    """
    palette = mpl.colormaps[colormap_name]
    sample_positions = np.linspace(start_frac, end_frac, num_colors)
    return palette(sample_positions)
[docs]
def find_unit_for_module_key(module_key, module_name, key_map):
    """Find the unit for a given module-specific key.

    Walks the entries of ``key_map`` in order and returns the unit of the
    first entry whose data for ``module_name`` has a 'name' field equal to
    ``module_key``.

    Args:
        module_key (str):
            The module-specific key (e.g., 'Angdiam (mas)' for AYO,
            'pixelScale' for EXOSIMS).
        module_name (str):
            The name of the module (e.g., 'AYOCSVFile', 'EXOSIMSInputFile').
        key_map (dict):
            The key mapping dictionary to use for lookups.

    Returns:
        str or None:
            The unit string of the matching entry, an empty string if the
            matching entry has no 'unit' field, or None if no entry matches.
    """
    for entry in key_map.values():
        # Entries without data for this module cannot match
        if module_name not in entry:
            continue
        info = entry[module_name]
        if info.get("name") == module_key:
            return info.get("unit", "")
    return None
[docs]
def get_unit(key, module_name, find_unit_func=None):
    """Get the associated unit for a given key.

    The unit string is resolved from exactly one source, chosen in this order:

    1. If ``key`` is in KEY_MAP and has data for ``module_name``, the 'unit'
       field of that entry is used (an empty string if the field is absent).
       No fallback is attempted in this case, even when the unit is empty.
    2. Otherwise, if ``find_unit_func`` is given, it is called with ``key``.
    3. Otherwise, find_unit_for_module_key searches KEY_MAP treating ``key``
       as a module-specific key.

    Args:
        key (str):
            The key to look up the unit for (can be either a yieldplotlib key
            or a module-specific key).
        module_name (str):
            The name of the module making the request (used for KEY_MAP lookup).
        find_unit_func (callable, optional):
            Optional custom fallback function that takes a key and returns a
            unit string. If None, the built-in find_unit_for_module_key
            function will be used.

    Returns:
        astropy.units.Unit or None:
            The astropy Unit built from the resolved unit string, or None when
            the resolved value is empty or None.
    """
    has_direct_entry = key in KEY_MAP and module_name in KEY_MAP[key]

    if has_direct_entry:
        # The key is a yieldplotlib key with data for this module
        unit_string = KEY_MAP[key][module_name].get("unit", "")
    elif find_unit_func is None:
        # No custom fallback supplied: use the built-in generic search
        unit_string = find_unit_for_module_key(key, module_name, KEY_MAP)
    else:
        # Defer to the caller-provided fallback
        unit_string = find_unit_func(key)

    return u.Unit(unit_string) if unit_string else None