"""Plotly plotting helpers (module ``harbor.plotting.plotly_utils``)."""

import plotly.graph_objects as go
import pandas as pd
from matplotlib import colors as plt_colors


def rgb_to_rgba(rgb_str, alpha):
    """Turn an ``"rgb(r, g, b)"`` colour string into ``"rgba(r, g, b, alpha)"``.

    Characters ``r``, ``g``, ``b``, ``(`` and ``)`` are stripped from both ends
    of ``rgb_str``, and the remaining comma-separated parts are parsed as ints.

    Parameters
    ----------
    rgb_str : str
        Colour such as ``"rgb(31, 119, 180)"``.
    alpha : float or str
        Opacity value inserted verbatim as the fourth component.

    Returns
    -------
    str
        The colour in ``"rgba(r, g, b, alpha)"`` form.
    """
    inner = rgb_str.strip("rgb()")
    red, green, blue = (int(component) for component in inner.split(","))
    return f"rgba({red}, {green}, {blue}, {alpha})"
def hex_to_rgb(hex_color: str) -> tuple:
    """Convert a hex colour string to an ``(r, g, b)`` tuple of ints.

    Accepts ``"#rrggbb"`` / ``"rrggbb"`` and the CSS shorthand ``"#rgb"``,
    in which each digit is doubled (``"#f00"`` -> ``"#ff0000"``).

    Parameters
    ----------
    hex_color : str
        Hex colour, with or without leading ``#``.

    Returns
    -------
    tuple
        ``(r, g, b)`` with each channel in 0-255.
    """
    hex_color = hex_color.lstrip("#")
    if len(hex_color) == 3:
        # Shorthand expands digit-by-digit; repeating the whole string
        # ("f00" -> "f00f00") would give the wrong colour.
        hex_color = "".join(digit * 2 for digit in hex_color)
    return tuple(int(hex_color[i : i + 2], 16) for i in (0, 2, 4))
def update_traces(fig, labels):
    """Rewrite the display name of every named trace in ``fig.data`` in place.

    Substitutions are applied in this order: ``"_"`` -> ``" "``,
    ``"Split"`` -> ``""``, ``", "`` -> ``" | "``,
    ``"RMSD"`` -> ``"RMSD (Positive Control)"``, then every ``(key, value)``
    pair from ``labels`` in iteration order. Traces whose name is None are
    left untouched.

    Parameters
    ----------
    fig : plotly.graph_objects.Figure
        Figure whose traces are renamed.
    labels : dict
        Extra ``{substring: replacement}`` pairs applied after the built-ins.

    Returns
    -------
    The same ``fig`` object, for chaining.
    """
    builtin_rules = (
        ("_", " "),
        ("Split", ""),
        (", ", " | "),
        ("RMSD", "RMSD (Positive Control)"),
    )
    for trace in fig.data:
        new_name = trace.name
        if new_name is None:
            continue
        # Order matters: later rules see the output of earlier ones.
        for old, new in (*builtin_rules, *labels.items()):
            new_name = new_name.replace(old, new)
        trace.name = new_name
    return fig
def _error_bounds(
    trace_data, y_values, error_column, lower_error_column, upper_error_column
):
    """Return ``(lower, upper)`` band Series for one trace, or ``(None, None)``.

    A symmetric ``error_column`` (if present in ``trace_data``) wins; otherwise
    both ``lower_error_column`` and ``upper_error_column`` must be present and
    are used as absolute bounds.
    """
    if error_column is not None and error_column in trace_data.columns:
        errors = trace_data[error_column]
        return y_values - errors, y_values + errors
    if (
        lower_error_column is not None
        and upper_error_column is not None
        and lower_error_column in trace_data.columns
        and upper_error_column in trace_data.columns
    ):
        return trace_data[lower_error_column], trace_data[upper_error_column]
    return None, None


def create_plot_with_error_bands_and_dual_legend(
    df,
    x,
    y,
    error_column=None,
    lower_error_column=None,
    upper_error_column=None,
    color_category_column=None,
    dash_category_column=None,
    name_column=None,
    title=None,
    x_title=None,
    y_title=None,
    width=900,
    height=600,
):
    """
    Create a line plot with error bands and separate legends for colour and
    line-style groupings from a tidy (long-format) DataFrame.

    One line is drawn per unique value of ``name_column``. Its colour comes
    from the row's ``color_category_column`` value and its dash style from
    the ``dash_category_column`` value (both taken from the first row of that
    trace). The real traces are hidden from the legend. Legend-only dummy
    traces are added instead, one group explaining colours and one
    explaining dash styles.

    Parameters:
    -----------
    df : pandas.DataFrame
        Tidy DataFrame containing the data to plot (a copy is modified, not df).
    x : str
        Column name for x-axis values.
    y : str
        Column name for y-axis values.
    error_column : str or None
        Column of symmetric error magnitudes; the band spans
        ``y - error`` .. ``y + error``. Takes precedence over the
        lower/upper columns. Ignored if the column is absent.
    lower_error_column : str or None
        Column of absolute lower band bounds (only used if error_column does
        not apply and upper_error_column is also present).
    upper_error_column : str or None
        Column of absolute upper band bounds (see lower_error_column).
    color_category_column : str or None
        Column used for colour grouping. If None, all lines share a single
        "Default Color Category".
    dash_category_column : str or None
        Column used for dash-style grouping. If None, all lines share a single
        "Default Dash Category".
    name_column : str or None
        Column for trace names. If None, names are
        ``"<color category> - <dash category>"``.
    title : str or None
        Plot title (no title if None).
    x_title, y_title : str or None
        Axis titles (default to ``x`` / ``y``).
    width, height : int
        Plot dimensions.

    Returns:
    --------
    fig : plotly.graph_objects.Figure
    """
    fig = go.Figure()

    if x_title is None:
        x_title = x
    if y_title is None:
        y_title = y

    df_plot = df.copy()
    # Placeholder grouping columns: the name written must be the name read back.
    if color_category_column is None:
        color_category_column = "_color_category"
        df_plot[color_category_column] = "Default Color Category"
    if dash_category_column is None:
        dash_category_column = "_dash_category"
        df_plot[dash_category_column] = "Default Dash Category"

    if name_column is None:
        name_column = "_name"
        # astype(str) so non-string (e.g. numeric) category columns can be joined.
        df_plot[name_column] = (
            df_plot[color_category_column].astype(str)
            + " - "
            + df_plot[dash_category_column].astype(str)
        )

    unique_names = df_plot[name_column].unique()
    unique_categories = df_plot[color_category_column].unique()
    unique_types = df_plot[dash_category_column].unique()

    # Colours from matplotlib's "C0", "C1", ... colour names, as "rgb(r, g, b)".
    colors = {
        cat: f"rgb{tuple(int(c * 255) for c in plt_colors.to_rgb(f'C{i}'))}"
        for i, cat in enumerate(unique_categories)
    }
    # Cycle through the dash styles so more than six types never miss a key.
    dash_cycle = ["solid", "dash", "dot", "dashdot", "longdash", "longdashdot"]
    dash_styles = {
        typ: dash_cycle[i % len(dash_cycle)] for i, typ in enumerate(unique_types)
    }

    for name in unique_names:
        trace_data = df_plot[df_plot[name_column] == name]
        if len(trace_data) == 0:
            continue

        first_row = trace_data.iloc[0]
        color = colors[first_row[color_category_column]]
        dash = dash_styles[first_row[dash_category_column]]

        trace_data = trace_data.sort_values(by=x)
        x_values = trace_data[x]
        y_values = trace_data[y]

        y_lower, y_upper = _error_bounds(
            trace_data, y_values, error_column, lower_error_column, upper_error_column
        )
        if y_lower is not None:
            # Closed polygon: forward along the upper bound, back along the lower.
            # .iloc makes the reversal positional regardless of the index labels.
            fig.add_trace(
                go.Scatter(
                    x=list(x_values) + list(x_values.iloc[::-1]),
                    y=list(y_upper) + list(y_lower.iloc[::-1]),
                    fill="toself",
                    fillcolor=color.replace("rgb(", "rgba(").replace(")", ", 0.3)"),
                    line=dict(color="rgba(255,255,255,0)"),
                    hoverinfo="skip",
                    showlegend=False,
                    name=f"{name} Error Band",
                )
            )

        fig.add_trace(
            go.Scatter(
                x=x_values,
                y=y_values,
                mode="lines",
                line=dict(color=color, dash=dash, width=2),
                name=name,
                legendgroup=name,
                showlegend=False,  # legend is explained by the dummy traces below
            )
        )

    # Legend-only dummy traces explaining the colour encoding.
    for color_category, color in colors.items():
        fig.add_trace(
            go.Scatter(
                x=[None],
                y=[None],
                mode="lines",
                line=dict(color=color, width=2),
                name=color_category,
                legendgroup=color_category_column,
                legendgrouptitle_text=color_category_column,
                showlegend=True,
            )
        )

    # Legend-only dummy traces explaining the dash-style encoding.
    for dash_category, dash in dash_styles.items():
        fig.add_trace(
            go.Scatter(
                x=[None],
                y=[None],
                mode="lines",
                line=dict(color="black", dash=dash, width=2),
                name=dash_category,
                legendgroup=dash_category_column,
                legendgrouptitle_text=dash_category_column,
                showlegend=True,
            )
        )

    fig.update_layout(
        title=title,
        xaxis_title=x_title,
        yaxis_title=y_title,
        width=width,
        height=height,
        legend=dict(groupclick="toggleitem"),
        template="plotly_white",
    )
    return fig
# Example usage
if __name__ == "__main__":
    # Fields shared by every example record.
    common = {
        "Bootstraps": 1000,
        "StructureChoice": "Dock_to_All",
        "StructureChoice_Choose_N": "All",
        "Score": "POSIT_Probability",
        "Score_Choose_N": 1,
        "EvaluationMetric": "RMSD",
        "EvaluationMetric_Cutoff": 2.0,
        "Split": "SimilaritySplit",
        "N_Per_Split": -1,
        "Split_Variable": "Tanimoto",
        "PoseSelection": "Default",
        "PoseSelection_Choose_N": 1,
        "Include_Similar": False,
        "Higher_Is_More_Similar": True,
        "Aligned": True,
        "Type": "TanimotoCombo",
    }
    data = [
        {
            **common,
            "Min": 0.0,
            "Max": 0.0,
            "CI_Upper": 0.975,
            "CI_Lower": 0.025,
            "Total": 0,
            "Fraction": 0.0,
            "Similarity_Threshold": 0.0,
            "Error_Lower": -0.025,
            "Error_Upper": 0.975,
            "Engine": "POSIT",
        },
        {
            **common,
            "Min": 0.0181818181818181,
            "Max": 0.0181818181818181,
            "CI_Upper": 0.0519043474859503,
            "CI_Lower": 0.0066036075243498,
            "Total": 165,
            "Fraction": 0.0181818181818181,
            "Similarity_Threshold": 0.25,
            "Error_Lower": 0.0115782106574683,
            "Error_Upper": 0.03372252930413219,
            "Engine": "FRED",
        },
        {
            **common,
            "Min": 0.3715596330275229,
            "Max": 0.3715596330275229,
            "CI_Upper": 0.4375041030245344,
            "CI_Lower": 0.310142546161883,
            "Total": 218,
            "Fraction": 0.3715596330275229,
            "Similarity_Threshold": 0.5,
            "Error_Lower": 0.06141708686563985,
            "Error_Upper": 0.06594446999701153,
            "Engine": "POSIT",
        },
        {
            **common,
            "Min": 0.7522935779816514,
            "Max": 0.7522935779816514,
            "CI_Upper": 0.8048516150447931,
            "CI_Lower": 0.690843980335917,
            "Total": 218,
            "Fraction": 0.7522935779816514,
            "Similarity_Threshold": 0.75,
            "Error_Lower": 0.06144959764573443,
            "Error_Upper": 0.05255803706314166,
            "Engine": "FRED",
        },
    ]
    df = pd.DataFrame.from_records(data)

    # Asymmetric errors. lower/upper_error_column expect absolute bounds.
    # Error_Lower/Error_Upper hold distances from Fraction
    # (Fraction - CI_Lower, CI_Upper - Fraction), so pass the CI columns.
    fig = create_plot_with_error_bands_and_dual_legend(
        df=df,
        x="Similarity_Threshold",
        y="Fraction",
        error_column=None,  # don't use symmetric error
        lower_error_column="CI_Lower",
        upper_error_column="CI_Upper",
        color_category_column="Engine",
        dash_category_column="Score",
        title="Fraction of Ligands with RMSD < 2.0",
        x_title="TanimotoCombo Similarity",
        y_title="Fraction of Ligands with RMSD < 2.0",
    )
    fig.show()