Source code for yieldplotlib.plots.yield_hist

"""Script to test AYO and EXOSIMS loading."""

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from matplotlib.patches import Patch


# Temperature bins that can be drawn, in left-to-right order within a group.
_TEMP_ORDER = ("hot", "warm", "cold")

# Standard planet types in x-axis order; any other requested type follows them.
_KNOWN_PLANET_TYPES = ("Rocky", "Super Earth", "Sub Neptune", "Neptune", "Jupiter")

# Bar colours keyed by temperature bin ("earth" is the dedicated Earth group).
_DEFAULT_COLORS = {
    "earth": "#228B22",
    "hot": "#FF6347",
    "warm": "#FFD700",
    "cold": "#1E90FF",
}
_CYBERPUNK_COLORS = {
    "earth": "C3",
    "hot": "C1",
    "warm": "C2",
    "cold": "C0",
}

# Hatch patterns assigned to runs in order, cycling when there are more runs.
_HATCH_PATTERNS = ("", "///", "\\\\", "xx", "++", "..", "**")


def plot_hist(
    temps,
    planet_bins,
    runs,
    run_labels,
    ax=None,
    ax_kwargs=None,
    use_cyberpunk=False,
):
    """Plot a histogram of planet populations for different temperature bins.

    Bars are grouped by planet type along the x-axis. If "Earth" is among
    ``planet_bins``, a single Earth group (one bar per run) is drawn first.
    Every other planet type gets one bar per (temperature, run) pair.
    Temperatures are distinguished by colour and runs by hatching.

    Args:
        temps (list): List of temperature bins to plot, e.g., ["hot", "warm",
            "cold"]. Only hot/warm/cold (compared case-insensitively) are drawn.
        planet_bins (list): List of planet types to plot, e.g., ["Earth",
            "Rocky", "Super Earth"]. Rocky, Super Earth, Sub Neptune, Neptune
            and Jupiter are drawn in that order; any other type follows in the
            order it was requested.
        runs (list): List of EXOSIMSDirectories and AYODirectories to plot.
            Each run's ``get(key)`` is called with keys such as
            ``"yield_earth"`` or ``"yield_hot_rocky"``; missing or non-numeric
            values are plotted as 0.
        run_labels (list): List of labels for each run.
        ax (matplotlib.axes.Axes, optional): Axes to plot on. If None, a new
            figure is created.
        ax_kwargs (dict, optional): Keyword arguments to pass to ax.set().
        use_cyberpunk (bool, optional): Whether to use the mplcyberpunk style.
            Default is False.

    Returns:
        matplotlib.figure.Figure, matplotlib.axes.Axes: Figure and axes
        objects for the plot.
    """
    if ax_kwargs is None:
        ax_kwargs = {}

    if use_cyberpunk:
        # Imported only for its side effects (hence the noqa); the module
        # name itself is never used.
        import mplcyberpunk  # noqa: F401

        plt.style.use("cyberpunk")

    planet_populations = _population_keys(temps, planet_bins)
    df = _build_yield_frame(planet_populations, runs, run_labels)

    # (planet_type, temperature, run label) -> yield. Looking heights up by key
    # keeps every bar aligned with its x position regardless of row order.
    heights = {
        (row.planet_type, row.temperature, row.model): row.value
        for row in df.itertuples(index=False)
    }

    plotting_earths = "yield_earth" in planet_populations
    planet_types = _ordered_planet_types(df)
    present_temps = set(df["temperature"])
    temperatures = [temp for temp in _TEMP_ORDER if temp in present_temps]

    color_map = _CYBERPUNK_COLORS if use_cyberpunk else _DEFAULT_COLORS
    hatch_map = {
        label: _HATCH_PATTERNS[i % len(_HATCH_PATTERNS)]
        for i, label in enumerate(run_labels)
    }

    # Set up the plot
    if ax is None:
        fig, ax = plt.subplots(figsize=(15, 7.5))
    else:
        fig = ax.get_figure()

    n_planets = len(planet_types)
    n_runs = len(runs)

    bar_width = 0.3
    earth_group_width = bar_width * n_runs
    temp_group_width = bar_width * n_runs * len(temperatures)
    gap_width = bar_width / 2

    # Centres of the planet groups. An integer range times the step always has
    # exactly n_planets entries; np.arange with a float step can gain an extra
    # element through rounding, which would then mismatch the bar heights.
    planet_x = np.arange(n_planets) * (temp_group_width + gap_width)
    if plotting_earths:
        # The Earth group is centred on x=0; shift the other groups right so
        # they start one gap after its right edge.
        planet_x += earth_group_width / 2 + gap_width + temp_group_width / 2

    bar_style = {"edgecolor": "black", "linewidth": 1, "alpha": 0.8}

    # Plot the Earth bars first: one bar per run, centred around x=0.
    if plotting_earths:
        for j, (_run, label) in enumerate(zip(runs, run_labels, strict=False)):
            offset = (j - (n_runs - 1) / 2) * bar_width
            bars = ax.bar(
                offset,
                heights.get(("Earth", "earth", label), 0.0),
                bar_width,
                color=color_map["earth"],
                hatch=hatch_map.get(label),
                **bar_style,
            )
            autolabel(ax, bars, use_cyberpunk)

    # Within each planet group, bars run left-to-right by temperature and then
    # by run.
    for i, temp in enumerate(temperatures):
        for j, (_run, label) in enumerate(zip(runs, run_labels, strict=False)):
            offset = (i * n_runs + j + 0.5) * bar_width - temp_group_width / 2
            bar_heights = [
                heights.get((planet_type, temp, label), 0.0)
                for planet_type in planet_types
            ]
            bars = ax.bar(
                planet_x + offset,
                bar_heights,
                bar_width,
                color=color_map[temp],
                hatch=hatch_map.get(label),
                **bar_style,
            )
            autolabel(ax, bars, use_cyberpunk)

    # Build a single legend with two sections: colours, then run hatching.
    temp_patches = (
        [Patch(facecolor=color_map["earth"], label="Earth")] if plotting_earths else []
    ) + [
        Patch(facecolor=color_map[temp], label=temp.title()) for temp in temperatures
    ]
    model_patches = [
        Patch(facecolor="grey", hatch=hatch_map[label], label=label)
        for label in run_labels
    ]
    # Fully transparent patches act as section headings inside the legend.
    handles = [Patch(alpha=0), *temp_patches, Patch(alpha=0), *model_patches]
    labels = [
        "Planet Type",
        *(patch.get_label() for patch in temp_patches),
        "Run",
        *(patch.get_label() for patch in model_patches),
    ]
    combined_legend = ax.legend(handles=handles, labels=labels, ncol=1)
    # Also registered as a plain artist (as before), presumably so it stays on
    # the axes if a caller later calls ax.legend() again.
    ax.add_artist(combined_legend)

    # Set labels and ticks
    ax.set_ylabel("Yield")
    tick_positions = planet_x.tolist()
    tick_labels = list(planet_types)
    if plotting_earths:
        tick_positions.insert(0, 0)
        tick_labels.insert(0, "Earth")
    ax.set_xticks(tick_positions)
    ax.set_xticklabels(tick_labels, ha="center")
    ax.set(**ax_kwargs)

    return fig, ax


def _population_keys(temps, planet_bins):
    """Return the de-duplicated yield keys for every (temperature, bin) pair.

    "Earth" maps to the single key ``"yield_earth"``. Every other bin maps to
    ``"yield_<temp>_<bin>"``, with the bin lower-cased and spaces replaced by
    underscores. Keys keep the order in which they are first produced.
    """
    keys = []
    for temp in temps:
        for planet_bin in planet_bins:
            if planet_bin == "Earth":
                key = "yield_earth"
            else:
                key = f"yield_{temp}_{planet_bin.lower().replace(' ', '_')}"
            # A repeated key would add duplicate bars for the same population.
            if key not in keys:
                keys.append(key)
    return keys


def _parse_population_key(key):
    """Split a yield key into its ``(temperature, planet_type)`` labels.

    Keys whose temperature is not hot/warm/cold map to
    ``("unknown", "Unknown")`` and are never drawn.
    """
    if key == "yield_earth":
        return "earth", "Earth"
    parts = key.split("_")
    temperature = parts[1].lower()
    if temperature in _TEMP_ORDER:
        # e.g. "yield_hot_super_earth" -> ("hot", "Super Earth")
        return temperature, " ".join(parts[2:]).title()
    return "unknown", "Unknown"


def _build_yield_frame(population_keys, runs, run_labels):
    """Return a DataFrame with one row per (population key, run).

    Columns are ``planet_type``, ``temperature``, ``model`` (the run label)
    and ``value`` (the run's yield for the key, or 0.0 when it is missing or
    not numeric). Runs and labels are paired with ``zip``, so unpaired
    extras are ignored.
    """
    rows = []
    for key in population_keys:
        temperature, planet_type = _parse_population_key(key)
        for run, label in zip(runs, run_labels, strict=False):
            raw = run.get(key)
            try:
                value = float(raw)
            except (ValueError, TypeError):
                value = 0.0  # missing or not a number
            rows.append(
                {
                    "planet_type": planet_type,
                    "temperature": temperature,
                    "model": label,
                    "value": value,
                }
            )
    # Explicit columns keep the frame well-formed even when there are no rows.
    return pd.DataFrame(rows, columns=["planet_type", "temperature", "model", "value"])


def _ordered_planet_types(df):
    """Return the drawable (non-Earth) planet types in *df*, in x-axis order.

    Standard types come first, in ``_KNOWN_PLANET_TYPES`` order. Any other
    type follows, in the order it first appears.
    """
    drawable = df["temperature"].isin(list(_TEMP_ORDER))
    present = df.loc[drawable, "planet_type"].unique().tolist()
    standard = [ptype for ptype in _KNOWN_PLANET_TYPES if ptype in present]
    extra = [ptype for ptype in present if ptype not in _KNOWN_PLANET_TYPES]
    return standard + extra
# Function to attach a text label above each bar displaying its height
def autolabel(ax, rects, use_cyberpunk=False):
    """Attach a text label above each bar displaying its height.

    Args:
        ax (matplotlib.axes.Axes): Axes the bars were drawn on.
        rects: Iterable of bar rectangles, e.g. the container returned by
            ``ax.bar``.
        use_cyberpunk (bool, optional): If True the labels are white,
            otherwise black. Default is False.
    """
    label_color = "white" if use_cyberpunk else "black"
    for bar in rects:
        bar_height = bar.get_height()
        # Bars that are not strictly positive (zero, negative, NaN) get no label.
        if not bar_height > 0:
            continue
        center_x = bar.get_x() + bar.get_width() / 2
        ax.annotate(
            f"{bar_height:.1f}",
            xy=(center_x, bar_height),
            xytext=(0, 3),  # 3 points above the top edge of the bar
            textcoords="offset points",
            ha="center",
            va="bottom",
            fontsize=8,
            color=label_color,
        )