Source code for harbor.analysis.cross_docking

import itertools
from pydantic import BaseModel, Field, model_validator, field_validator, ConfigDict
from typing_extensions import Self
import abc
import pandas as pd
import numpy as np
from tqdm import tqdm
from pathlib import Path
from typing import Optional
from enum import Flag, auto
import json
import yaml
from enum import Enum, StrEnum
from operator import eq, gt, lt, ge, le, ne
from pydantic import confloat


[docs] class Operator(StrEnum): EQ = "eq" GT = "gt" LT = "lt" GE = "ge" LE = "le" NE = "ne" IN = "in" def to_callable(self) -> callable: def isin(x, value): if isinstance(value, (list, tuple, set)): return x in value return False return { self.EQ: eq, self.GT: gt, self.LT: lt, self.GE: ge, self.LE: le, self.NE: ne, self.IN: isin, }[self]
class ColumnFilter(BaseModel):
    """
    Row filter that keeps the rows whose ``column`` satisfies ``operator``
    when compared against ``value``.
    """

    model_config = ConfigDict(arbitrary_types_allowed=True)

    column: str = Field(..., description="Column to filter on")
    value: str | int | float | list
    operator: Operator = Operator.EQ

    @model_validator(mode="after")
    def match_operator_with_value(self):
        """
        Check that the operator and the value shape agree.

        :raises ValueError: if ``in`` is paired with a non list/tuple value, or
            any other operator is paired with a list/tuple value.
        """
        value_is_sequence = isinstance(self.value, (list, tuple))
        if self.operator == Operator.IN and not value_is_sequence:
            raise ValueError("Operator 'in' requires value to be a list or tuple.")
        if self.operator != Operator.IN and value_is_sequence:
            raise ValueError(
                f"Operator '{self.operator.value}' requires a scalar value; "
                "list or tuple values are only valid with operator 'in'."
            )
        return self

    def filter(self, dataframe: pd.DataFrame) -> pd.DataFrame:
        """
        Return the rows of ``dataframe`` that pass this filter.

        :param dataframe: Data to filter; must contain ``self.column``.
        :return: The matching rows (the input is not modified).
        :raises ValueError: if ``self.column`` is not a column of ``dataframe``.
        """
        if self.column not in dataframe.columns:
            raise ValueError(f"Column '{self.column}' not found in DataFrame.")

        # Resolve the predicate once instead of rebuilding it for every row.
        predicate = self.operator.to_callable()
        target = self.value

        # Force a boolean mask: on an empty frame ``apply`` can return a
        # non-boolean Series, which pandas would treat as a column selection.
        mask = dataframe[self.column].apply(lambda x: predicate(x, target)).astype(bool)
        return dataframe[mask]
class ColumnSortFilter(BaseModel):
    """
    Sort a dataframe by one column and keep the first ``number_to_return`` rows
    of every group defined by ``key_columns``.
    """

    model_config = ConfigDict(arbitrary_types_allowed=True)

    key_columns: list[str] = Field(..., description="Columns to get unique data from")
    sort_column: str = Field(..., description="Columns to sort by")
    ascending: bool = Field(
        True, description="Sort in ascending order if True, descending if False"
    )
    # ``None`` is accepted so that sorters configured to "return all values"
    # can build this filter without a validation error.
    number_to_return: Optional[int] = Field(
        1,
        description="Number of rows to return per group. If None, all rows are returned.",
    )

    def filter(self, dataframe: pd.DataFrame) -> pd.DataFrame:
        """
        Return the top rows per group after sorting by ``sort_column``.

        :param dataframe: Data to filter; must contain ``sort_column``.
        :return: The sorted rows, limited to ``number_to_return`` per group
            (all sorted rows when ``number_to_return`` is None).
        :raises ValueError: if ``sort_column`` is not a column of ``dataframe``.
        """
        if self.sort_column not in dataframe.columns:
            raise ValueError(f"Column '{self.sort_column}' not found in DataFrame.")

        # The sort column cannot also be a grouping key. Work on a local list
        # so that filtering never mutates the model's own configuration.
        group_keys = [key for key in self.key_columns if key != self.sort_column]

        ordered = dataframe.sort_values(self.sort_column, ascending=self.ascending)
        if self.number_to_return is None:
            return ordered
        if not group_keys:
            # No grouping keys left: the whole frame is a single group.
            return ordered.head(self.number_to_return)
        return ordered.groupby(group_keys).head(self.number_to_return)
def merge_on_common_columns(df1: pd.DataFrame, df2: pd.DataFrame) -> pd.DataFrame:
    """
    Left-join ``df2`` onto ``df1`` using every column name the two share.

    :raises ValueError: if the frames have no column name in common.
    """
    shared = set(df1.columns).intersection(df2.columns)
    if len(shared) == 0:
        raise ValueError("No common columns found between DataFrames")
    return df1.merge(df2, how="left", on=list(shared))
[docs] class DataFrameType(StrEnum): """Enum for DataFrame types.""" REFERENCE = "ReferenceData" QUERY = "QueryData" PAIRED = "PairedData" POSE = "PoseData" CHEMICAL_SIMILARITY = "ChemicalSimilarityData" COMBINED = "CombinedData" def __or__(self, other): if not isinstance(other, DataFrameType): return NotImplemented return (self, other)
class DataFrameModelBase(BaseModel):
    """
    Base model pairing a pandas DataFrame with a name and a type label.

    The dataframe is excluded from ``model_dump``; ``serialize`` therefore
    writes the schema to JSON and the dataframe to a sibling parquet file.
    """

    name: str = Field(..., description="Unique name refering to this dataframe's data")
    type: str = Field(..., description="Data frame type. Used for grouping.")
    dataframe: pd.DataFrame = Field(
        ..., description="DataFrame containing model data", exclude=True
    )
    model_config = ConfigDict(arbitrary_types_allowed=True)

    def __eq__(self, other):
        """Two base models are equal when their dataframes are equal."""
        if not isinstance(other, DataFrameModelBase):
            return False
        return self.dataframe.equals(other.dataframe)

    @field_validator("type")
    @classmethod
    def check_matches_dataframe_types(cls, v):
        """
        Reject type labels that are not values of ``DataFrameType``.

        :raises ValueError: if ``v`` is not a valid ``DataFrameType`` value.
        """
        try:
            # The enum constructor raises for unknown values; it never returns
            # a falsy member, so a truthiness test would be dead code.
            DataFrameType(v)
        except ValueError as err:
            valid = ", ".join(member.value for member in DataFrameType)
            raise ValueError(
                f"'{v}' is not a valid DataFrameType; expected one of: {valid}"
            ) from err
        return v

    def serialize(self, file_path: str | Path) -> Path:
        """
        Write the model next to ``file_path`` as ``<stem>.json`` + ``<stem>.parquet``.

        :param file_path: Target path; its suffix is replaced for each file.
        :return: ``file_path`` as a ``Path``.
        """
        file_path = Path(file_path)
        # first write schema to json
        with open(file_path.with_suffix(".json"), "w", encoding="utf-8") as f:
            json.dump(self.model_dump(), f)
        # then write dataframe to parquet
        self.dataframe.to_parquet(file_path.with_suffix(".parquet"))
        return file_path

    @classmethod
    def deserialize(cls, file_path: str | Path) -> "DataFrameModelBase":
        """
        Rebuild a model from the JSON/parquet pair written by ``serialize``.

        :param file_path: Path whose suffix is replaced to locate both files.
        :return: A new instance of ``cls``.
        """
        file_path = Path(file_path)
        # load json schema
        with open(file_path.with_suffix(".json"), "r", encoding="utf-8") as f:
            model_schema = json.load(f)
        df = pd.read_parquet(file_path.with_suffix(".parquet"))
        return cls(dataframe=df, **model_schema)
class DataFrameModel(DataFrameModelBase):
    """
    A dataframe whose rows are uniquely identified by ``key_columns`` plus
    ``param_columns``; every remaining column is treated as a value column.
    """

    key_columns: list[str] = Field(
        ..., description="Columns that connect this data with other data"
    )
    param_columns: list[str] = Field(
        [],
        description="Columns that, combined with the key_columns, specify all the columns necessary to select unique rows",
    )
    value_columns: list[str] = Field(
        [], description="The other columns expected in the data"
    )

    def __eq__(self, other):
        """Equal when both the dataframes and the key columns match."""
        if not isinstance(other, DataFrameModel):
            return False
        return (
            self.dataframe.equals(other.dataframe)
            and self.key_columns == other.key_columns
        )

    @field_validator("key_columns", "value_columns")
    @classmethod
    def check_columns_are_unique(cls, v):
        """
        Reject column lists that contain the same name twice.

        :raises ValueError: if ``v`` contains duplicates.
        """
        if len(set(v)) < len(v):
            raise ValueError(f"Column names must be unique, got: {v}")
        return v

    @model_validator(mode="after")
    def check_key_and_params_specify_unique_rows(self):
        """
        Make sure that grouping by the combined key and param columns yields
        exactly one row per combination, then register every remaining
        dataframe column as a value column.

        :raises KeyError: if a key/param column is missing from the dataframe,
            if the grouping is empty, or if any group holds more than one row.
        """
        expected_cols = self.key_columns + self.param_columns
        for col in expected_cols:
            if col not in self.dataframe.columns:
                raise KeyError(
                    f"Expected Column '{col}' specified is not present in the DataFrame columns {self.dataframe.columns}."
                )

        # ``size`` counts rows per group. ``count`` only counts non-null
        # values, so it misses duplicates whose value columns are null or
        # when there are no value columns at all.
        group_sizes = self.dataframe.groupby(expected_cols).size()
        if len(group_sizes) == 0:
            raise KeyError(
                f"Grouping the dataframe by key_columns: '{self.key_columns}' "
                f"and param_columns: '{self.param_columns}' "
                f"resulted in an empty dataframe"
            )
        n_problem_rows = int((group_sizes > 1).sum())
        if n_problem_rows > 0:
            raise KeyError(
                f"Grouping the dataframe by key_columns: '{self.key_columns}' "
                f"and param_columns: '{self.param_columns}' "
                f"resulted in {n_problem_rows} rows with duplicate values. "
                f"Perhaps another column is needed!"
            )

        # Only append columns that are not already declared; otherwise a model
        # built with explicit value_columns would list each of them twice.
        already_known = set(expected_cols) | set(self.value_columns)
        self.value_columns.extend(
            col for col in self.dataframe.columns if col not in already_known
        )
        return self
class DockingDataModel(DataFrameModelBase):
    """
    Several ``DataFrameModel`` tables merged into a single dataframe.

    Per-source bookkeeping (type, key columns, param columns, value columns)
    is stored in dictionaries keyed by the source model's name.
    """

    dataframe: pd.DataFrame = Field(
        ..., description="DataFrame containing model data", exclude=True
    )
    data_types_dict: dict[str, str] = Field(
        ..., description="Dictionary mapping unique names to their type."
    )
    key_columns_dict: dict[str, list[str]] = Field(
        ...,
        description="Dictionary mapping a unique name to a list of keys that link all the data",
    )
    param_columns_dict: dict[str, list[str]] = Field(
        ...,
        description="Dictionary mapping a unique name to a list of columns, that when combined with the key_columns, result in a unique rows",
    )
    value_columns_dict: dict[str, list[str]] = Field(
        ...,
        description="Dictionary mapping a unique name to a list of columns with the values we want to query from each dataframe",
    )

    def __eq__(self, other):
        """
        Compare dataframes ignoring column order and index labels, and compare
        the key/param/value column sets and dataframe names as sets.
        """
        if not isinstance(other, DockingDataModel):
            return False

        # Columns are sorted so column order does not matter; the index is
        # reset so only row position (not index labels) is compared.
        df1_sorted = self.dataframe.reset_index(drop=True).sort_index(axis=1)
        df2_sorted = other.dataframe.reset_index(drop=True).sort_index(axis=1)
        return (
            df1_sorted.equals(df2_sorted)
            and set(self.get_key_columns()) == set(other.get_key_columns())
            and set(self.get_param_columns()) == set(other.get_param_columns())
            and set(self.get_value_columns()) == set(other.get_value_columns())
            and set(self.get_dataframe_names()) == set(other.get_dataframe_names())
        )

    def _name_for_type(self, data_type) -> str:
        """
        Return the first dataframe name registered with ``data_type``.

        :raises IndexError: if no dataframe of that type is registered.
        """
        matches = [
            name for name, type_ in self.data_types_dict.items() if type_ == data_type
        ]
        if not matches:
            raise IndexError(
                f"No dataframe of type '{data_type}' found in data_types_dict."
            )
        return matches[0]

    def get_ref_data_name(self) -> str:
        """Name of the first dataframe typed as reference data."""
        return self._name_for_type(DataFrameType.REFERENCE)

    def get_ref_column(self) -> str:
        """First key column of the reference dataframe."""
        return self.key_columns_dict[self.get_ref_data_name()][0]

    def get_lig_data_name(self) -> str:
        """Name of the first dataframe typed as query data."""
        return self._name_for_type(DataFrameType.QUERY)

    def get_lig_column(self) -> str:
        """First key column of the query dataframe."""
        return self.key_columns_dict[self.get_lig_data_name()][0]

    def get_unique_refs(self) -> list:
        """Unique values found in the reference column."""
        return list(self.dataframe[self.get_ref_column()].unique())

    def get_unique_ligs(self) -> list:
        """Unique values found in the query-ligand column."""
        return list(self.dataframe[self.get_lig_column()].unique())

    def get_dataframe_names(self) -> list:
        """Names of all source dataframes."""
        return list(self.key_columns_dict)

    def get_key_columns(self) -> list:
        """De-duplicated key columns across all sources (order not guaranteed)."""
        return list({col for cols in self.key_columns_dict.values() for col in cols})

    def get_param_columns(self) -> list:
        """De-duplicated param columns across all sources (order not guaranteed)."""
        return list({col for cols in self.param_columns_dict.values() for col in cols})

    def get_value_columns(self) -> list:
        """De-duplicated value columns across all sources (order not guaranteed)."""
        return list({col for cols in self.value_columns_dict.values() for col in cols})

    def get_groupby_columns(self, except_cols: tuple | list | set = ()) -> list:
        """
        Return every key and param column except those in ``except_cols``.

        If provided without any `except_cols`, grouping by the result should
        return the original dataframe.

        :param except_cols: Columns to leave out of the result.
        :return: List of column names (order not guaranteed).
        :raises ValueError: if ``except_cols`` is not a tuple, list or set.
        """
        if not isinstance(except_cols, tuple | list | set):
            raise ValueError(
                f"except_cols must be a list or set, not {type(except_cols)}"
            )
        kc = set(self.get_key_columns())
        pc = set(self.get_param_columns())
        cc = kc.union(pc) - set(except_cols)
        return list(cc)

    def get_pose_data_columns(self) -> list:
        """Key columns of the first dataframe typed as pose data."""
        return list(self.key_columns_dict[self._name_for_type(DataFrameType.POSE)])

    def get_total_poses(self) -> int:
        """Number of distinct combinations of the pose key columns."""
        return self.dataframe.groupby(self.get_pose_data_columns()).ngroups

    def get_lig_dataframe(self) -> pd.DataFrame:
        """First row of the combined dataframe for each query-ligand value."""
        return self.dataframe.groupby(self.get_lig_column()).head(1)

    def get_ref_dataframe(self) -> pd.DataFrame:
        """First row of the combined dataframe for each reference value."""
        return self.dataframe.groupby(self.get_ref_column()).head(1)

    @classmethod
    def from_models(cls, data_models) -> "DockingDataModel":
        """
        Merge several ``DataFrameModel`` instances into one combined model.

        Non-key columns are prefixed with ``"<model name>_"`` before the
        dataframes are left-merged, in order, on their common columns.

        :param data_models: Iterable of models with unique names.
        :return: The combined model.
        :raises ValueError: if no model is given or names are not unique.
        """
        from functools import reduce

        # Materialise once: the models are iterated several times below.
        data_models = list(data_models)
        if not data_models:
            raise ValueError("At least one DataFrameModel is required.")

        names = [model.name for model in data_models]
        if len(set(names)) < len(names):
            raise ValueError(f"DataFrameModels should have unique names: {names}")

        # Rename columns in each DataFrame to include the dataset name as a prefix
        renamed_dataframes = [
            model.dataframe.rename(
                columns={
                    col: f"{model.name}_{col}"
                    for col in model.dataframe.columns
                    if col not in model.key_columns
                }
            )
            for model in data_models
        ]

        # Merge the renamed DataFrames on common columns
        df = reduce(merge_on_common_columns, renamed_dataframes)

        # Key columns keep their names; param and value columns carry the prefix.
        key_columns_dict = {model.name: list(model.key_columns) for model in data_models}
        param_columns_dict = {
            model.name: [f"{model.name}_{col}" for col in model.param_columns]
            for model in data_models
        }
        value_columns_dict = {
            model.name: [f"{model.name}_{col}" for col in model.value_columns]
            for model in data_models
        }
        return cls(
            name="DockingDataModel",
            type=DataFrameType.COMBINED,
            dataframe=df,
            data_types_dict={model.name: model.type for model in data_models},
            key_columns_dict=key_columns_dict,
            param_columns_dict=param_columns_dict,
            value_columns_dict=value_columns_dict,
        )

    def to_models(self) -> list["DataFrameModel"]:
        """
        Split the combined dataframe back into one ``DataFrameModel`` per source.

        For each source, the first row of every distinct combination of its
        columns is kept and the ``"<model name>_"`` prefix is removed again.
        """
        models = []
        for model_name, model_type in self.data_types_dict.items():
            prefix = f"{model_name}_"

            # Extract relevant columns from the combined DataFrame
            relevant_columns = (
                self.key_columns_dict[model_name]
                + self.param_columns_dict[model_name]
                + self.value_columns_dict[model_name]
            )
            model_dataframe = self.dataframe.groupby(relevant_columns)[
                relevant_columns
            ].head(1)

            # Strip only the leading prefix; a plain ``replace`` would also
            # remove the same text occurring later in a column name.
            param_columns = [
                col.removeprefix(prefix) for col in self.param_columns_dict[model_name]
            ]
            value_columns = [
                col.removeprefix(prefix) for col in self.value_columns_dict[model_name]
            ]

            # Rename columns to remove the dataset name prefix
            model_dataframe = model_dataframe.rename(
                columns={f"{prefix}{col}": col for col in param_columns + value_columns}
            )

            models.append(
                DataFrameModel(
                    name=model_name,
                    type=model_type,
                    dataframe=model_dataframe,
                    key_columns=self.key_columns_dict[model_name],
                    param_columns=param_columns,
                    value_columns=value_columns,
                )
            )
        return models

    @model_validator(mode="after")
    def check_columns_in_dataframe(self):
        """
        Check that every declared key, param and value column exists in the dataframe.

        :raises ValueError: if a declared column is missing.
        """
        expected_cols = (
            self.get_key_columns() + self.get_param_columns() + self.get_value_columns()
        )
        for col in expected_cols:
            if col not in self.dataframe.columns:
                raise ValueError(
                    f"Expected Column '{col}' specified is not present in the DataFrame columns {self.dataframe.columns}."
                )
        return self

    def apply_filters(self, filters: list[ColumnFilter | ColumnSortFilter]):
        """
        Apply filters, in order, replacing ``self.dataframe`` with each result.

        :param filters: Filters exposing ``filter(dataframe) -> dataframe``.
        :return: None
        """
        for filter_ in filters:
            self.dataframe = filter_.filter(self.dataframe)
class ModelBase(BaseModel):
    """Common base for the models in this module; carries a ``type_`` tag."""

    type_: str = Field(..., description="Type of model")

    @abc.abstractmethod
    def plot_name(self) -> str:
        """Return this model's plot name; concrete subclasses must provide it."""

    def get_records(self) -> dict:
        """Return this model's record entries; the base contributes none."""
        return dict()
class EmptyModel(ModelBase):
    """Placeholder model whose plot name is the empty string."""

    type_: str = Field(default="EmptyModel", description="Empty model")

    def plot_name(self) -> str:
        """Return an empty plot name."""
        return str()
class SplitBase(ModelBase):
    """
    Base class for splitting the data (i.e. Random, Dataset, Scaffold, etc)
    """

    name: str = "SplitBase"
    type_: str = "SplitBase"
    deterministic: bool = Field(
        False,
        description="Whether the split is deterministic, i.e. if True it should not be run in the bootstrapping loop.",
    )
    split_level: int = Field(
        0,
        description="Level of the split, 0 indexed. The first level is applied first, and so on.",
    )

    @abc.abstractmethod
    def run(self, data: DockingDataModel) -> list[DockingDataModel]:
        """
        Split ``data`` and return the resulting models.

        :param data: The combined docking data to split.
        :return: A list of ``DockingDataModel`` instances.
        """
        # Annotated as ``list[...]``: a bare ``[DockingDataModel]`` is a list
        # literal, not a type.

    @property
    def plot_name(self) -> str:
        """Plot name of the split, which is its ``name``."""
        return f"{self.name}"
class ReferenceStructureSplitBase(SplitBase):
    """
    Base class for splits that select data by an attribute of the reference structure.
    """

    reference_structure_column: str = Field(
        ..., description="Name of the column to distinguish reference structures by"
    )
    n_reference_structures: Optional[int] = Field(
        None, description="Number of values per split to generate"
    )

    def get_records(self) -> dict:
        """Return the split name, structure count and reference column as a record dict."""
        records = {}
        records["Reference_Split"] = self.name
        records["N_Reference_Structures"] = self.n_reference_structures
        records["Reference_Structure_Column"] = self.reference_structure_column
        return records
def get_unique_structures_randomized_by_date(
    df: pd.DataFrame,
    structure_column: str,
    date_column: str,
    n_structures_to_return: int,
    n_days_to_randomize: int,
    date_format="%Y-%m-%d %H:%M:%S",
    bootstraps: int = 1,
) -> list[set]:
    """
    Draw random sets of structures from the earliest-dated structures.

    Structures are ordered by their earliest date. Every structure dated no
    later than the ``n_structures_to_return``-th earliest structure plus
    ``n_days_to_randomize`` days is a candidate, and each bootstrap draws
    ``n_structures_to_return`` candidates without replacement.

    Args:
        df: DataFrame containing structure and date information
        structure_column: Name of the column containing structure identifiers
        date_column: Name of the column containing dates
        n_structures_to_return: Number of structures in each returned set
        n_days_to_randomize: Width, in days, of the window added after the
            date of the n-th structure
        date_format: Format of the dates in date_column
        bootstraps: Number of independent sets to draw

    Returns:
        A list with ``bootstraps`` sets of selected structure identifiers

    Raises:
        ValueError: if there are fewer unique structures than requested
        RuntimeError: if fewer candidates than requested fall inside the window
    """
    unique_structures = df[structure_column].unique()
    if len(unique_structures) < n_structures_to_return:
        raise ValueError(
            f"Number of Unique Structures ({len(unique_structures)}) < "
            f"N Structures to Return ({n_structures_to_return})."
        )

    # One row per structure, dated by its earliest entry. Without this, a
    # structure listed under several dates would appear several times among
    # the candidates and a sample could contain fewer distinct structures
    # than requested.
    working_df = df[[structure_column, date_column]].drop_duplicates()
    working_df = working_df.assign(
        date=pd.to_datetime(working_df[date_column], format=date_format)
    )
    working_df = working_df.sort_values(by="date").drop_duplicates(
        subset=structure_column, keep="first"
    )

    # Date of the nth structure, extended by the randomization window
    last_date = working_df.iloc[n_structures_to_return - 1]["date"]
    last_date_with_buffer = last_date + pd.Timedelta(days=n_days_to_randomize)

    # All structures within the date range
    candidates = working_df.loc[
        working_df["date"] <= last_date_with_buffer, structure_column
    ].tolist()

    # The candidate pool is the same for every bootstrap, so check it once.
    if len(candidates) < n_structures_to_return:
        raise RuntimeError(
            f"{len(candidates)} candidates < {n_structures_to_return} structures to return."
        )

    return [
        set(
            np.random.choice(
                candidates, size=n_structures_to_return, replace=False
            ).tolist()
        )
        for _ in range(bootstraps)
    ]
def generate_random_samples(
    values: list, n_values: int, n_samples: int
) -> list[np.ndarray]:
    """
    Draw several independent random samples, each without replacement.

    Args:
        values: Values to sample from
        n_values: Size of each sample
        n_samples: How many samples to draw

    Returns:
        List of arrays, one per sample

    Raises:
        ValueError: if ``n_values`` exceeds the number of available values
    """
    available = len(values)
    if n_values > available:
        raise ValueError(
            f"Cannot sample {n_values} values from a list of {available} values."
        )

    pool = list(values)
    samples = []
    for _ in range(n_samples):
        samples.append(np.random.choice(pool, size=n_values, replace=False))
    return samples
class RandomSplit(ReferenceStructureSplitBase):
    """
    Randomly pick ``n_reference_structures`` reference structures and keep only
    the rows that refer to them.
    """

    name: str = "RandomSplit"
    type_: str = "RandomSplit"

    def run(
        self, data: DockingDataModel, bootstraps: int = 1
    ) -> list[DockingDataModel]:
        """
        Build ``bootstraps`` random subsets of ``data``.

        :param data: Combined docking data to split.
        :param bootstraps: Number of random subsets to generate.
        :return: ``[data]`` unchanged when no subsetting is requested (or the
            requested number equals the number of unique references);
            otherwise one new model per bootstrap.
        """
        unique_refs = data.dataframe[self.reference_structure_column].unique()

        # Nothing to subsample: return the input as the single split.
        if self.n_reference_structures is None or self.n_reference_structures == len(
            unique_refs
        ):
            return [data]

        random_ref_samples = generate_random_samples(
            unique_refs,
            n_values=self.n_reference_structures,
            n_samples=bootstraps,
        )
        return [
            DockingDataModel(
                dataframe=data.dataframe[
                    data.dataframe[self.reference_structure_column].isin(sample)
                ],
                # model_dump excludes the dataframe, so only the schema is copied
                **data.model_dump(),
            )
            for sample in random_ref_samples
        ]
class DateSplit(ReferenceStructureSplitBase):
    """
    Splits the data by date.
    """

    name: str = "DateSplit"
    type_: str = "DateSplit"
    date_column: str = Field(
        ...,
        description="Column corresponding to date deposition",
    )
    randomize_by_n_days: int = Field(
        0,
        description="Randomize the structures by n days. If 0 no randomization is done. If 1 or greater, for each structure, it can be randomly replaced by any other structure collected on that day or n-1 days from it's collection date.",
    )

    def get_records(self) -> dict:
        """Extend the base records with the date column and randomization window."""
        records = super().get_records()
        records.update(
            {
                "Randomize_by_N_Days": self.randomize_by_n_days,
                "Date_Column": self.date_column,
            }
        )
        return records

    def run(
        self, data: DockingDataModel, bootstraps: int = 1
    ) -> list[DockingDataModel]:
        """
        Build ``bootstraps`` date-based subsets of ``data``.

        :param data: Combined docking data to split.
        :param bootstraps: Number of subsets to generate.
        :return: ``[data]`` unchanged when no subsetting is requested (or the
            requested number equals the number of unique references);
            otherwise one new model per bootstrap.
        """
        unique_refs = data.dataframe[self.reference_structure_column].unique()

        # Nothing to subsample: return the input as the single split.
        if self.n_reference_structures is None or self.n_reference_structures == len(
            unique_refs
        ):
            return [data]

        ref_lists = get_unique_structures_randomized_by_date(
            data.dataframe,
            self.reference_structure_column,
            self.date_column,
            self.n_reference_structures,
            self.randomize_by_n_days,
            bootstraps=bootstraps,
        )
        return [
            DockingDataModel(
                dataframe=data.dataframe[
                    data.dataframe[self.reference_structure_column].isin(ref_list)
                ],
                # model_dump excludes the dataframe, so only the schema is copied
                **data.model_dump(),
            )
            for ref_list in ref_lists
        ]
class PairwiseSplitBase(SplitBase):
    """Base class for splits that record their name under ``PairwiseSplit``."""

    name: str = "PairwiseSplitBase"
    type_: str = "PairwiseSplitBase"

    def get_records(self) -> dict:
        """Return the parent records with this split's name added."""
        records = super().get_records()
        records["PairwiseSplit"] = self.name
        return records
class SimilaritySplit(PairwiseSplitBase):
    """
    Splits the structures available to dock to by similarity to the query ligand
    """

    name: str = "SimilaritySplit"
    type_: str = "SimilaritySplit"
    n_reference_structures: Optional[int] = Field(
        None, description="Number of values per split to generate"
    )
    similarity_column: str = Field(
        ...,
        description="Column name for the similarity between the query and reference ligands",
    )
    groupby: dict = Field(
        {},
        description="Column name : value pairs to group the Tanimoto similarity data by.",
    )
    query_ligand_column: str = Field(
        ...,
        description="Column name for the query ligand ID in order to pick the top N structures to dock to",
    )
    threshold: float = Field(
        0.5,
        description="Threshold to use to determine if two structures are similar enough to be in the same split",
    )
    higher_is_more_similar: bool = Field(
        True, description="Higher values are more similar"
    )
    include_similar: bool = Field(
        True,
        description="If True, include structures that are more similar than the threshold. Otherwise, include structures that are less similar.",
    )
    sort_instead_of_threshold: bool = Field(
        False,
        description="If True, sort the structures by similarity and take the top N reference structures. Otherwise, use the threshold to determine similarity.",
    )
    n_similar: Optional[int] = Field(
        None, description="Number of similar structures to return"
    )
    deterministic: bool = True

    def _has_reference_cap(self) -> bool:
        """True when ``n_reference_structures`` limits the rows kept per query ligand."""
        # None and -1 both mean "no cap".
        return (
            self.n_reference_structures is not None
            and self.n_reference_structures != -1
        )

    @model_validator(mode="after")
    def validate_model(self) -> Self:
        """
        Validate the sort settings and derive ``deterministic``.

        Sorting always gives the same result. Thresholding is random only when
        a cap on reference structures makes ``run`` sample rows.

        :raises NotImplementedError: if sorting is requested without ``n_similar``.
        """
        if self.sort_instead_of_threshold:
            # Parenthesised on purpose: the n_similar checks only apply when
            # sorting is requested.
            if self.n_similar is None or self.n_similar == -1:
                raise NotImplementedError(
                    "n_similar must be set if sort_instead_of_threshold is True"
                )
            self.deterministic = True
        else:
            self.deterministic = not self._has_reference_cap()
        return self

    def run(
        self, data: DockingDataModel, bootstraps: int = 1
    ) -> list[DockingDataModel]:
        """
        Filter ``data`` by similarity.

        :param data: Combined docking data to split.
        :param bootstraps: Number of random subsets to draw; only used when
            thresholding with a cap on reference structures.
        :return: A single model when the result is deterministic, otherwise one
            model per bootstrap.
        """
        df = data.dataframe

        # first just get the necessary data
        for key, value in self.groupby.items():
            df = df[df[key] == value]

        if self.sort_instead_of_threshold:
            # The most relevant rows must come first: sort descending when
            # "include similar" and "higher is more similar" agree.
            ascending = self.include_similar != self.higher_is_more_similar
            df = (
                df.sort_values(self.similarity_column, ascending=ascending)
                .groupby(self.query_ligand_column)
                .apply(lambda x: x.head(self.n_similar))
                .reset_index(drop=True)
            )
            return [DockingDataModel(dataframe=df, **data.model_dump())]

        # if include similar True and higher is MORE similar, or if similar False and higher is LESS similar
        if self.include_similar == self.higher_is_more_similar:
            df = df[df[self.similarity_column] >= self.threshold]
        # if include similar True and higher is LESS similar, or if similar False and higher is MORE similar
        else:
            df = df[df[self.similarity_column] <= self.threshold]

        if not self._has_reference_cap():
            return [DockingDataModel(dataframe=df, **data.model_dump())]

        cap = self.n_reference_structures
        return [
            DockingDataModel(
                dataframe=(
                    df.groupby(self.query_ligand_column)
                    # keep small groups whole, randomly subsample large ones
                    .apply(lambda x: x if len(x) <= cap else x.sample(n=cap))
                    .reset_index(drop=True)
                ),
                **data.model_dump(),
            )
            for _ in range(bootstraps)
        ]

    def get_records(self) -> dict:
        """Extend the parent records with the similarity settings and groupby pairs."""
        records = super().get_records()
        records.update(
            {
                "N_Reference_Structures": self.n_reference_structures,
                "Similarity_Column": self.similarity_column,
                "Similarity_Threshold": self.threshold,
                "Include_Similar": self.include_similar,
                "Higher_Is_More_Similar": self.higher_is_more_similar,
                "Sort_Instead_Of_Threshold": self.sort_instead_of_threshold,
                "N_Similar": self.n_similar,
            }
        )
        records.update(self.groupby)
        return records
[docs] class ScaffoldSplitFlags(Flag): NONE = 0 REQUIRES_QUERY_SUBSET = auto() REQUIRES_REFERENCE_SUBSET = auto() # one of the two must be passed REQUIRES_EITHER_SUBSET = auto() # both must be passed REQUIRES_BOTH = REQUIRES_QUERY_SUBSET | REQUIRES_REFERENCE_SUBSET # might not necessarily require them, but if they are passed, there should only be one of them REQUIRES_SINGLE_QUERY_SUBSET_IF_PASSED = auto() REQUIRES_SINGLE_QUERY_SUBSET = ( REQUIRES_SINGLE_QUERY_SUBSET_IF_PASSED | REQUIRES_QUERY_SUBSET ) REQUIRES_SINGLE_REFERENCE_SUBSET_IF_PASSED = auto() REQUIRES_SINGLE_REFERENCE_SUBSET = ( REQUIRES_SINGLE_REFERENCE_SUBSET_IF_PASSED | REQUIRES_REFERENCE_SUBSET ) REQUIRES_SINGLE_SUBSETS_IF_PASSED = ( REQUIRES_SINGLE_QUERY_SUBSET_IF_PASSED | REQUIRES_SINGLE_REFERENCE_SUBSET_IF_PASSED ) REQUIRES_SINGLE_SUBSETS = REQUIRES_BOTH | REQUIRES_SINGLE_SUBSETS_IF_PASSED ALLOW_OVERLAPPING_QUERY_AND_REFERENCE = auto()
class ScaffoldSplitOptions(StrEnum):
    """
    Options for how to split the structures by scaffold.

    For a dataset with scaffolds A-F, each option names one query-to-reference
    comparison. ``X_TO_Y`` requires that you set up separate evaluators for
    each combination of scaffolds.
    """

    # Dock X to X for X in [all your scaffolds]
    X_TO_X = "x_to_x"
    # Dock X to NOT X for X in [all your scaffolds]
    X_TO_NOT_X = "x_to_not_x"
    # Dock NOT X to X for X in [all your scaffolds]
    NOT_X_TO_X = "not_x_to_x"
    # Dock X to Y for X, Y in zip([all your scaffolds], [all your scaffolds])
    X_TO_Y = "x_to_y"
    # Dock X to all data for X in [all your scaffolds]
    X_TO_ALL = "x_to_all"
    # Dock all to X for X in [all your scaffolds]
    ALL_TO_X = "all_to_x"

    @property
    def flags(self) -> ScaffoldSplitFlags:
        """Subset requirements (as ``ScaffoldSplitFlags``) attached to this option."""
        overlap_ok = ScaffoldSplitFlags.ALLOW_OVERLAPPING_QUERY_AND_REFERENCE
        requirements = {
            ScaffoldSplitOptions.X_TO_X: (
                ScaffoldSplitFlags.REQUIRES_EITHER_SUBSET
                | ScaffoldSplitFlags.REQUIRES_SINGLE_SUBSETS_IF_PASSED
                | overlap_ok
            ),
            ScaffoldSplitOptions.X_TO_NOT_X: ScaffoldSplitFlags.REQUIRES_SINGLE_QUERY_SUBSET,
            ScaffoldSplitOptions.NOT_X_TO_X: ScaffoldSplitFlags.REQUIRES_SINGLE_REFERENCE_SUBSET,
            ScaffoldSplitOptions.X_TO_Y: ScaffoldSplitFlags.REQUIRES_SINGLE_SUBSETS,
            ScaffoldSplitOptions.X_TO_ALL: (
                ScaffoldSplitFlags.REQUIRES_SINGLE_QUERY_SUBSET | overlap_ok
            ),
            ScaffoldSplitOptions.ALL_TO_X: (
                ScaffoldSplitFlags.REQUIRES_SINGLE_REFERENCE_SUBSET | overlap_ok
            ),
        }
        return requirements[self]
class ScaffoldSplit(PairwiseSplitBase):
    """
    Splits the structures available to dock to by whether they share a scaffold with the query ligand.
    """

    name: str = "ScaffoldSplit"
    type_: str = "ScaffoldSplit"
    query_scaffold_id_column: str = Field(
        ..., description="Column name for the query scaffold ID"
    )
    reference_scaffold_id_column: str = Field(
        ..., description="Column name for the reference scaffold ID"
    )
    query_scaffold_id_subset: Optional[list[int | str]] = Field(
        None,
        description="List of query scaffold IDs to consider. If None, consider all scaffolds.",
    )
    reference_scaffold_id_subset: Optional[list[int | str]] = Field(
        None,
        description="List of reference scaffold IDs to consider. If None, consider all scaffolds.",
    )
    split_option: ScaffoldSplitOptions = Field(
        ...,
        description="How to split the data by scaffold",
    )
    deterministic: bool = Field(True, description="Deterministic split")

    @field_validator("split_option", mode="before")
    @classmethod
    def convert_to_string(cls, v):
        """Unwrap enum members to their value before field validation."""
        if isinstance(v, Enum):
            return v.value
        return v

    def get_records(self) -> dict:
        """Extend the parent records with the scaffold columns, option and subsets."""
        records = super().get_records()
        records.update(
            {
                "Query_Scaffold_ID_Column": self.query_scaffold_id_column,
                "Reference_Scaffold_ID_Column": self.reference_scaffold_id_column,
                "Scaffold_Split_Option": self.split_option,
                "Query_Scaffold_ID_Subset": self.query_scaffold_id_subset,
                "Reference_Scaffold_ID_Subset": self.reference_scaffold_id_subset,
            }
        )
        return records

    @model_validator(mode="after")
    def validate_model(self) -> Self:
        """
        Check the supplied scaffold subsets against the flags of ``split_option``.

        :raises ValueError: if a required subset is missing, a subset that must
            hold exactly one ID does not, or the subsets overlap when the
            option does not allow it.
        """
        option = self.split_option
        flags = self.split_option.flags
        query_ids = self.query_scaffold_id_subset
        reference_ids = self.reference_scaffold_id_subset

        # Check reference subset requirements
        if ScaffoldSplitFlags.REQUIRES_REFERENCE_SUBSET in flags and not reference_ids:
            raise ValueError(
                f"{option} requires at least one item in reference_scaffold_id_subset"
            )
        if (
            ScaffoldSplitFlags.REQUIRES_SINGLE_REFERENCE_SUBSET_IF_PASSED in flags
            and reference_ids
            and len(reference_ids) != 1
        ):
            raise ValueError(
                f"{option} requires exactly one item in reference_scaffold_id_subset"
            )

        # Check query subset requirements
        if ScaffoldSplitFlags.REQUIRES_QUERY_SUBSET in flags and not query_ids:
            raise ValueError(
                f"{option} requires at least one item in query_scaffold_id_subset"
            )
        if (
            ScaffoldSplitFlags.REQUIRES_SINGLE_QUERY_SUBSET_IF_PASSED in flags
            and query_ids
            and len(query_ids) != 1
        ):
            raise ValueError(
                f"{option} requires exactly one item in query_scaffold_id_subset"
            )

        # Check either subset requirement
        if ScaffoldSplitFlags.REQUIRES_EITHER_SUBSET in flags and not (
            query_ids or reference_ids
        ):
            raise ValueError(
                f"{option} requires at least one of query_scaffold_id_subset or reference_scaffold_id_subset"
            )

        # Check both subsets requirement
        if ScaffoldSplitFlags.REQUIRES_BOTH in flags and not (
            query_ids and reference_ids
        ):
            raise ValueError(
                f"{option} requires both query_scaffold_id_subset and reference_scaffold_id_subset"
            )

        # Check for overlapping scaffolds when not allowed
        if (
            query_ids
            and reference_ids
            and ScaffoldSplitFlags.ALLOW_OVERLAPPING_QUERY_AND_REFERENCE not in flags
        ):
            overlap = set(query_ids).intersection(reference_ids)
            if overlap:
                # Report the IDs that actually overlap, not the first query ID.
                raise ValueError(
                    f"Query and reference scaffold IDs overlap ({sorted(overlap, key=str)}), "
                    f"but {option} does not allow overlapping scaffolds."
                )
        return self

    def run(self, data: DockingDataModel) -> list[DockingDataModel]:
        """
        Split data based on scaffold relationships.

        :param data: Combined docking data to split.
        :return: A single-element list with the filtered model.
        """
        df = data.dataframe

        # Default missing subsets to every scaffold in this dataframe. Locals
        # are used so the model's own configuration is not overwritten by
        # values taken from one dataset.
        reference_ids = self.reference_scaffold_id_subset
        if reference_ids is None:
            reference_ids = df[self.reference_scaffold_id_column].unique().tolist()
        query_ids = self.query_scaffold_id_subset
        if query_ids is None:
            query_ids = df[self.query_scaffold_id_column].unique().tolist()

        # Filter by scaffold subsets first
        mask = df[self.query_scaffold_id_column].isin(query_ids)
        mask &= df[self.reference_scaffold_id_column].isin(reference_ids)
        df = df[mask]

        # Handle different split options
        if self.split_option == ScaffoldSplitOptions.X_TO_X:
            df = df[
                df[self.query_scaffold_id_column]
                == df[self.reference_scaffold_id_column]
            ]
        elif self.split_option in (
            ScaffoldSplitOptions.NOT_X_TO_X,
            ScaffoldSplitOptions.X_TO_NOT_X,
        ):
            df = df[
                df[self.query_scaffold_id_column]
                != df[self.reference_scaffold_id_column]
            ]
        return [DockingDataModel(dataframe=df, **data.model_dump())]
class ScaffoldDateSplit(ReferenceStructureSplitBase):
    """
    Returns results such that query structures are only docked to the first structure(s) for each scaffold.
    """

    name: str = "ScaffoldDateSplit"
    type_: str = "ScaffoldDateSplit"
    date_column: str = Field(
        ...,
        description="Column corresponding to date deposition",
    )
    scaffold_id_column: str = Field(
        ..., description="Column corresponding to the scaffold ID of the ligand"
    )
    randomize_by_n_days: int = Field(
        0,
        description="Randomize the structures by n days. If 0 no randomization is done. If 1 or greater, for each structure, it can be randomly replaced by any other structure collected on that day or n-1 days from it's collection date.",
    )
    n_refs_per_scaffold: Optional[int] = Field(
        1, description="Number of reference structures per scaffold"
    )

    def get_records(self) -> dict:
        """Extend the base records with the date, scaffold and per-scaffold settings."""
        records = super().get_records()
        records.update(
            {
                "Randomize_by_N_Days": self.randomize_by_n_days,
                "Date_Column": self.date_column,
                "Scaffold_ID_Column": self.scaffold_id_column,
                "N_Refs_Per_Scaffold": self.n_refs_per_scaffold,
            }
        )
        return records

    def run(
        self, data: DockingDataModel, bootstraps: int = 1
    ) -> list[DockingDataModel]:
        """
        Build ``bootstraps`` subsets restricted to the earliest references per scaffold.

        :param data: Combined docking data to split.
        :param bootstraps: Number of subsets to generate.
        :return: One new model per bootstrap.
        """
        # Earliest ``n_refs_per_scaffold`` rows of every scaffold, by date
        unique_refs = (
            data.dataframe.sort_values(self.date_column)
            .groupby(self.scaffold_id_column)
            .head(self.n_refs_per_scaffold)[self.reference_structure_column]
            .unique()
        )
        filtered_df = data.dataframe[
            data.dataframe[self.reference_structure_column].isin(unique_refs)
        ]

        # Default to all selected references. Kept local so the model's own
        # setting is not overwritten with a count taken from one dataset.
        n_reference_structures = self.n_reference_structures
        if n_reference_structures is None:
            n_reference_structures = len(unique_refs)

        ref_lists = get_unique_structures_randomized_by_date(
            filtered_df,
            self.reference_structure_column,
            self.date_column,
            n_reference_structures,
            self.randomize_by_n_days,
            bootstraps=bootstraps,
        )
        return [
            DockingDataModel(
                dataframe=data.dataframe[
                    data.dataframe[self.reference_structure_column].isin(ref_list)
                ],
                **data.model_dump(),
            )
            for ref_list in ref_lists
        ]
# TODO: There might be a better way to do this. ReferenceSplitType = RandomSplit | DateSplit | ScaffoldDateSplit SimilaritySplitType = SimilaritySplit | ScaffoldSplit
class SorterBase(ModelBase):
    """
    Abstract base for steps that rank rows by one column and keep the top N.

    Subclasses implement :meth:`run`. ``ascending`` is forwarded to the sort,
    so ``True`` keeps the *lowest* values of ``variable`` and ``False`` keeps
    the highest.
    """

    type_: str = "SorterBase"
    name: str = Field(..., description="Name of sorting method")
    category: str = Field(
        ..., description="Category of sort (i.e. why sorting is necessary here)"
    )
    variable: str = Field(..., description="Variable used to sort the data")
    ascending: bool = Field(
        True,
        description="Sort in ascending order, so the lowest values are ranked first. Defaults True",
    )
    number_to_return: Optional[int] = Field(
        None, description="Number of values to return. Returns all values if None."
    )

    @field_validator("number_to_return", mode="before")
    @classmethod
    def allow_number_to_return_to_be_none(cls, v):
        """Pass the raw value (including ``None``) through unchanged."""
        return v

    @abc.abstractmethod
    def run(self, data: DockingDataModel) -> DockingDataModel:
        """Apply the sort/selection to ``data`` and return the result."""
        pass

    @property
    def plot_name(self) -> str:
        """Label of the form ``<name>_Choose_<N|All>``."""
        # A falsy number_to_return (None or 0) is reported as "All".
        return f"{self.name}_Choose_{'All' if not self.number_to_return else self.number_to_return}"

    def get_records(self) -> dict:
        """Return ``{category: name, "<category>_Choose_N": N or "All"}``."""
        return {
            self.category: self.name,
            f"{self.category}_Choose_N": (
                "All" if not self.number_to_return else self.number_to_return
            ),
        }
class PoseSelector(SorterBase):
    """
    Keep the best ``number_to_return`` poses within each group.

    Groups are defined by the data model's group-by columns with ``Pose_ID``
    excluded; rows are ranked on ``variable``.
    """

    type_: str = "PoseSelector"
    category: str = "PoseSelection"

    def run(self, data: DockingDataModel) -> DockingDataModel:
        """Return a copy of ``data`` with the pose filter applied."""
        selected = data.model_copy()
        pose_filter = ColumnSortFilter(
            sort_column=self.variable,
            key_columns=selected.get_groupby_columns(except_cols=["Pose_ID"]),
            ascending=self.ascending,
            number_to_return=self.number_to_return,
        )
        selected.apply_filters([pose_filter])
        return selected
class Scorer(SorterBase):
    """
    Rank rows per ligand on ``variable`` and keep the top ``number_to_return``.

    Rows are grouped by the data model's ligand column (``get_lig_column()``).
    """

    category: str = "Score"
    type_: str = "Scorer"

    def run(self, data: DockingDataModel) -> DockingDataModel:
        """
        Return a scored copy of ``data``.

        The filter is applied to ``data.model_copy()`` rather than to ``data``
        itself, matching ``PoseSelector.run`` so the caller's model is not
        modified as a side effect.
        """
        scored = data.model_copy()
        score_filter = ColumnSortFilter(
            sort_column=self.variable,
            key_columns=[scored.get_lig_column()],
            ascending=self.ascending,
            number_to_return=self.number_to_return,
        )
        scored.apply_filters([score_filter])
        return scored
class POSITScorer(Scorer):
    """
    Scorer preset for the POSIT docking confidence.

    Sorts on ``docking-confidence-POSIT`` in descending order
    (``ascending=False``) with ``number_to_return`` defaulting to 1.
    """

    type_: str = "POSITScorer"
    name: str = "POSIT_Probability"
    variable: str = "docking-confidence-POSIT"
    ascending: bool = False
    number_to_return: int = 1
class RMSDScorer(Scorer):
    """
    Scorer preset for RMSD.

    Sorts on the ``RMSD`` column in ascending order with ``number_to_return``
    defaulting to 1.
    """

    type_: str = "RMSDScorer"
    name: str = "RMSD"
    variable: str = "RMSD"
    ascending: bool = True
    number_to_return: int = 1
class SuccessRate(ModelBase):
    """
    Fraction of successes out of ``total`` items, with optional replicates.

    With two or more replicates the confidence bounds are read from the sorted
    replicate values. With zero or one replicate they come from the 95%
    interval of a ``Beta(successes + 1, failures + 1)`` distribution.
    """

    name: str = "SuccessRate"
    type_: str = "SuccessRate"
    total: int = Field(..., description="Total number of items being evaluated")
    fraction: confloat(ge=0, le=1) = Field(..., description="Fraction of successes")
    replicates: list[float] = Field(
        [], description="Replicates used for error bar analysis"
    )

    def _beta_interval(self) -> tuple:
        """Return the (lower, upper) 95% interval of the beta posterior."""
        from scipy.stats import beta

        n_successes = self.fraction * self.total
        n_failures = (1 - self.fraction) * self.total
        return beta(n_successes + 1, n_failures + 1).interval(0.95)

    def _replicate_quantile(self, quantile: float) -> float:
        """Return the replicate at ``int(quantile * n)`` in sorted order."""
        # Sort a copy so reading a property never reorders ``self.replicates``.
        ordered = sorted(self.replicates)
        return ordered[int(quantile * len(ordered))]

    @property
    def min(self) -> float:
        """Smallest replicate, or ``fraction`` when there are no replicates."""
        if not self.replicates:
            return self.fraction
        return np.array(self.replicates).min()

    @property
    def max(self) -> float:
        """Largest replicate, or ``fraction`` when there are no replicates."""
        if not self.replicates:
            return self.fraction
        return np.array(self.replicates).max()

    @property
    def ci_upper(self):
        """Upper 95% confidence bound (see class docstring for the method)."""
        if len(self.replicates) <= 1:
            # Too few replicates for a bootstrap quantile; use the beta posterior.
            return self._beta_interval()[1]
        return self._replicate_quantile(0.975)

    @property
    def ci_lower(self):
        """Lower 95% confidence bound (see class docstring for the method)."""
        if len(self.replicates) <= 1:
            return self._beta_interval()[0]
        return self._replicate_quantile(0.025)

    @classmethod
    def from_replicates(cls, reps: list["SuccessRate"]) -> "SuccessRate":
        """
        Combine per-replicate results into one ``SuccessRate``.

        ``total`` is the largest replicate total, ``fraction`` the mean of the
        replicate fractions, and ``replicates`` the individual fractions.

        :raises ValueError: if ``reps`` is empty
        """
        if not reps:
            raise ValueError("Cannot build a SuccessRate from an empty list of replicates")
        all_fracs = [float(rep.fraction) for rep in reps]
        return SuccessRate(
            total=int(max(rep.total for rep in reps)),
            fraction=float(np.mean(all_fracs)),
            replicates=all_fracs,
        )

    def get_records(self) -> dict:
        """Return the summary statistics as a flat dict."""
        mydict = {
            "Min": self.min,
            "Max": self.max,
            "CI_Upper": self.ci_upper,
            "CI_Lower": self.ci_lower,
            "Total": self.total,
            "Fraction": self.fraction,
        }
        return mydict

    def plot_name(self) -> str:
        """Return the fixed label ``"Fraction"``."""
        return "Fraction"
class BinaryEvaluation(ModelBase):
    """
    Classify each row as good or bad by comparing ``variable`` to ``cutoff``.

    The input is expected to hold at most one row per ligand (i.e. it has
    already been scored down to one pose per ligand).
    """

    name: str = "BinaryEvaluation"
    type_: str = "BinaryEvaluation"
    variable: str = Field(..., description="Variable used to evaluate the results")
    cutoff: float = Field(
        ..., description="Cutoff used to determine if a result is good"
    )
    below_cutoff_is_good: bool = Field(
        True,
        description="Whether values below or above the cutoff are good. Defaults to below.",
    )

    def run(self, data: DockingDataModel) -> SuccessRate:
        """
        Compute the fraction of ligands whose value passes the cutoff.

        :param data: docking data with at most one row per ligand
        :return: ``SuccessRate`` with ``total`` set to the number of ligands
        :raises ValueError: if the dataframe has more rows than ligands
        """
        df = data.dataframe
        lig_column = data.get_lig_column()
        total_by_ligand = len(df.groupby(lig_column))
        if total_by_ligand < len(df):
            raise ValueError(
                f"There are more rows in your dataframe ({len(df)}) than can be selected by {lig_column}, ({total_by_ligand})"
            )
        if total_by_ligand == 0:
            return SuccessRate(total=0, fraction=0)

        # The cutoff itself counts as good in both directions.
        values = df[self.variable]
        is_good = values <= self.cutoff if self.below_cutoff_is_good else values >= self.cutoff
        fraction = float(is_good.sum() / total_by_ligand)
        return SuccessRate(total=total_by_ligand, fraction=fraction)

    def get_records(self) -> dict:
        """Return the evaluated column and cutoff as a flat dict."""
        return {
            "EvaluationMetric": self.variable,
            "EvaluationMetric_Cutoff": self.cutoff,
        }

    def plot_name(self) -> str:
        """Return ``<name>_<variable>_<cutoff>``."""
        # cutoff is a float; str.join only accepts strings.
        return "_".join([self.name, self.variable, str(self.cutoff)])
[docs] def get_class_from_name(name: str): """ Is this good? Is it safe? Is it smart? I don't know! :param name: :return: """ match name: case "RandomSplit": return RandomSplit case "DateSplit": return DateSplit case "SimilaritySplit": return SimilaritySplit case "ScaffoldSplit": return ScaffoldSplit case "ScaffoldDateSplit": return ScaffoldDateSplit case "Scorer": return Scorer case "RMSDScorer": return RMSDScorer case "POSITScorer": return POSITScorer case "PoseSelector": return PoseSelector case "FractionGood": return SuccessRate case "BinaryEvaluation": return BinaryEvaluation case "Evaluator": return Evaluator
def _bootstrap_worker(args): """Standalone worker function for parallel bootstrap processing""" bootstrap_idx, evaluator_json, pose_selected_data = args try: # Recreate evaluator from JSON evaluator_data = json.loads(evaluator_json) evaluator_copy = get_class_from_name(evaluator_data["type_"])(**evaluator_data) # Process the bootstrap result = evaluator_copy.process_single_bootstrap(pose_selected_data) return bootstrap_idx, result except Exception as e: print(f"Error processing bootstrap {bootstrap_idx}: {e}") return bootstrap_idx, None
class Evaluator(ModelBase):
    """
    Full evaluation pipeline: pose selection, splits, scoring, evaluation.

    ``run`` applies the pose selector once, then runs ``n_bootstraps``
    replicates of (dataset split, similarity split) -> scorer -> evaluator and
    aggregates them with ``SuccessRate.from_replicates``.
    """

    name: str = "Evaluator"
    type_: str = "Evaluator"
    pose_selector: PoseSelector = Field(
        PoseSelector(name="Default", variable="Pose_ID", number_to_return=1),
        description="How to choose which poses to keep",
    )
    dataset_split: Optional[ReferenceSplitType] = Field(
        None, description="Dataset split"
    )
    similarity_split: Optional[SimilaritySplitType] = Field(
        None, description="Additional dataset splits to be run after the first one"
    )
    scorer: Scorer = Field(..., description="How to score and rank resulting poses")
    evaluator: BinaryEvaluation = Field(
        ..., description="How to determine how good the results are"
    )
    n_bootstraps: int = Field(1, description="Number of bootstrap replicates to run")
    dataset_before_similarity: bool = Field(
        True,
        description="Whether to run the dataset split before the similarity split, or vice versa. Defaults True.",
    )

    @field_validator(
        "pose_selector",
        "dataset_split",
        "similarity_split",
        "scorer",
        "evaluator",
        mode="before",
    )
    @classmethod
    def class_from_dict(cls, v):
        """
        Rebuild a component from its dict form using its ``type_`` entry.

        Non-dict values and dicts without ``type_`` are returned unchanged and
        left to the normal field validation.
        """
        if isinstance(v, dict) and "type_" in v:
            component_cls = get_class_from_name(v["type_"])
            if component_cls is None:
                raise ValueError(f"Unknown component type '{v['type_']}'")
            return component_cls(**v)
        return v

    def run_pose_selector(
        self, data: list[DockingDataModel]
    ) -> list[DockingDataModel]:
        """Apply the pose selector to every model in ``data``."""
        return [self.pose_selector.run(data_) for data_ in data]

    def run_dataset_split(
        self, data_splits: list[DockingDataModel]
    ) -> list[DockingDataModel]:
        """Apply the dataset split (with ``n_bootstraps``) and flatten the result."""
        if self.dataset_split is None:
            return data_splits
        return [
            split
            for data_ in data_splits
            for split in self.dataset_split.run(data_, bootstraps=self.n_bootstraps)
        ]

    def run_similarity_split(
        self, data_splits: list[DockingDataModel]
    ) -> list[DockingDataModel]:
        """Apply the similarity split to every model and flatten the result."""
        if self.similarity_split is None:
            return data_splits
        return [
            split
            for data_ in data_splits
            for split in self.similarity_split.run(data_)
        ]

    def run_scorer(
        self, data_splits: list[DockingDataModel]
    ) -> list[DockingDataModel]:
        """Apply the scorer to every model in ``data_splits``."""
        return [self.scorer.run(data_) for data_ in data_splits]

    def calculate_results(self, data_splits: list[DockingDataModel]) -> SuccessRate:
        """Evaluate every split and aggregate them as replicates."""
        results = [self.evaluator.run(data_) for data_ in data_splits]
        return SuccessRate.from_replicates(results)

    def _run_single_dataset_split(
        self, current_data: list[DockingDataModel]
    ) -> list[DockingDataModel]:
        """Run the dataset split once (``bootstraps=1``) on the first model."""
        if self.dataset_split is None or not current_data:
            return current_data
        return self.dataset_split.run(current_data[0], bootstraps=1)

    def process_single_bootstrap(
        self, pose_selected_data: DockingDataModel
    ) -> Optional[SuccessRate]:
        """
        Process a single bootstrap replicate through the entire pipeline.

        :param pose_selected_data: data after pose selection
        :return: the evaluation of the first resulting split, or ``None`` when
            the splits leave nothing to evaluate
        """
        current_data = [pose_selected_data]

        if self.dataset_before_similarity:
            current_data = self._run_single_dataset_split(current_data)
            current_data = self.run_similarity_split(current_data)
        else:
            current_data = self.run_similarity_split(current_data)
            current_data = self._run_single_dataset_split(current_data)

        current_data = self.run_scorer(current_data)

        # NOTE: only the first split is evaluated, even if several remain.
        if not current_data:
            return None
        return self.evaluator.run(current_data[0])

    def run(self, data: DockingDataModel, n_cpus: int = 1) -> SuccessRate:
        """
        Run all bootstrap replicates sequentially or in parallel.

        Replicates that raise or return ``None`` are reported on stdout and
        skipped.

        :param data: input docking data
        :param n_cpus: number of worker processes; 1 runs in-process
        :return: aggregated ``SuccessRate`` over the successful replicates
        :raises ValueError: if no replicate produced a result
        """
        # Pose selection does not change between bootstraps, so do it once.
        pose_selected_data = self.run_pose_selector([data])[0]

        all_results = []
        if n_cpus == 1:
            for bootstrap_idx in range(self.n_bootstraps):
                try:
                    result = self.process_single_bootstrap(pose_selected_data)
                except Exception as e:
                    print(f"Error processing bootstrap {bootstrap_idx}: {e}")
                    continue
                if result is not None:
                    all_results.append(result)
        else:
            from concurrent.futures import ProcessPoolExecutor, as_completed
            import multiprocessing as mp

            n_cpus = min(n_cpus, mp.cpu_count())
            print(
                f"Running {self.n_bootstraps} bootstraps in parallel using {n_cpus} CPUs."
            )

            # Serialize once; every worker receives the same JSON string.
            evaluator_json = self.to_json_str()
            with ProcessPoolExecutor(max_workers=n_cpus) as executor:
                futures = [
                    executor.submit(
                        _bootstrap_worker,
                        (bootstrap_idx, evaluator_json, pose_selected_data),
                    )
                    for bootstrap_idx in range(self.n_bootstraps)
                ]
                for future in as_completed(futures):
                    _, result = future.result()
                    if result is not None:
                        all_results.append(result)

        if not all_results:
            raise ValueError(
                f"None of the {self.n_bootstraps} bootstrap replicates produced a result"
            )
        return SuccessRate.from_replicates(all_results)

    def to_json_str(self) -> str:
        """Serialize this evaluator to a JSON string."""
        return self.model_dump_json()

    def to_json_file(self, file_path: str | Path) -> Path:
        """Write the JSON form to ``file_path`` and return it as a ``Path``."""
        file_path = Path(file_path)
        with open(file_path, "w") as f:
            f.write(self.to_json_str())
        return file_path

    @classmethod
    def from_json_str(cls, model_dump_json_str) -> "Evaluator":
        """Build an evaluator from a JSON string."""
        data = json.loads(model_dump_json_str)
        return cls(**data)

    @classmethod
    def from_json_file(cls, file_path: str | Path) -> "Evaluator":
        """Build an evaluator from a JSON file."""
        with open(file_path, "r") as f:
            data = json.load(f)
        return cls(**data)

    @property
    def plot_name(self) -> str:
        """Join the dataset split and scorer plot names with the replicate count."""
        variables = [
            model.plot_name
            for model in [self.dataset_split, self.scorer]
            if model is not None
        ]
        variables += [f"{self.n_bootstraps}reps"]
        return "_".join(variables)

    def get_records(self) -> dict:
        """Merge the records of all configured components into one dict."""
        mydict = {"Bootstraps": self.n_bootstraps}
        for container in [
            self.scorer,
            self.evaluator,
            self.pose_selector,
        ]:
            if container is not None:
                mydict.update(container.get_records())
        if self.dataset_split:
            mydict.update(self.dataset_split.get_records())
        if self.similarity_split:
            mydict.update(self.similarity_split.get_records())
        return mydict
class Results(BaseModel):
    """Pairs an ``Evaluator`` with the ``SuccessRate`` it produced."""

    evaluator: Evaluator
    success_rate: SuccessRate = Field(
        ..., description="Resulting success rate, with some information about the data"
    )

    def get_records(self) -> dict:
        """Merge the evaluator's records with the success-rate records."""
        mydict = self.evaluator.get_records()
        mydict.update(self.success_rate.get_records())
        return mydict

    @classmethod
    def calculate_result(
        cls, evaluator: Evaluator, data: DockingDataModel, n_cpus: int = 1
    ) -> "Results":
        """Run one evaluator on ``data`` and wrap the outcome."""
        result = evaluator.run(data, n_cpus=n_cpus)
        return cls(evaluator=evaluator, success_rate=result)

    @classmethod
    def calculate_results(
        cls, data: DockingDataModel, evaluators: list[Evaluator], n_cpus: int = 1
    ) -> list["Results"]:
        """
        Run every evaluator on its own deep copy of ``data``.

        The copy is made just before each run instead of up front, so only one
        extra copy of the data is alive at a time.
        """
        results = []
        for ev in tqdm(evaluators, total=len(evaluators)):
            data_copy = data.model_copy(deep=True)
            result = ev.run(data_copy, n_cpus=n_cpus)
            results.append(cls(evaluator=ev, success_rate=result))
        return results

    @classmethod
    def df_from_results(
        cls, results: list["Results"], include_model: bool = True
    ) -> pd.DataFrame:
        """
        Build a dataframe with one row per result.

        :param results: results to tabulate; an empty list gives an empty frame
        :param include_model: when True, add an ``Evaluator_Model`` column
            holding each evaluator's JSON form
        """
        if not results:
            return pd.DataFrame()

        if include_model:
            records = [
                {
                    **result.get_records(),
                    "Evaluator_Model": result.evaluator.to_json_str(),
                }
                for result in results
            ]
        else:
            records = [result.get_records() for result in results]
        return pd.DataFrame.from_records(records)
class SettingsBase(BaseModel):
    """Base settings model with YAML import/export helpers."""

    def get_descriptions(self) -> dict:
        """Map each field name to its schema description ("" if it has none)."""
        schema = self.model_json_schema()
        return {
            field: field_info.get("description", "")
            for field, field_info in schema["properties"].items()
        }

    def to_yaml(self):
        """Return the model's data as plain JSON-compatible Python objects."""
        # Round-trip through JSON so the result holds only basic types.
        return json.loads(self.model_dump_json())

    def to_yaml_file(self, file_path) -> Path:
        """
        Write the settings to ``file_path`` as YAML.

        Field descriptions are written first as ``# key: description`` comment
        lines, followed by the YAML dump of the data.
        """
        output = self.to_yaml()
        descriptions = self.get_descriptions()

        with open(file_path, "w") as file:
            for key in output:
                if key in descriptions:
                    file.write(f"# {key}: {descriptions[key]}\n")
            yaml.dump(output, file, sort_keys=False)
        return Path(file_path)

    @classmethod
    def from_yaml(cls, yaml_str):
        """
        Build the settings from a YAML string.

        An empty document (or one holding only comments) is treated as an
        empty mapping, so validation reports any missing required fields.
        """
        data = yaml.safe_load(yaml_str) or {}
        return cls(**data)

    @classmethod
    def from_yaml_file(cls, file_path):
        """Build the settings from a YAML file."""
        with open(file_path, "r") as file:
            return cls.from_yaml(file.read())
class EvaluatorSettingsBase(SettingsBase):
    """Base class for all evaluator settings; adds the shared ``use`` switch."""

    # Off by default; subclasses that should be active override this to True.
    use: bool = Field(
        False, description="Whether this class of settings should be used"
    )
class PoseSelectionSettings(EvaluatorSettingsBase):
    """
    Settings used to generate pose selectors: one per entry in ``n_poses``.
    """

    use: bool = True
    pose_id_column: str = Field(
        "Pose_ID", description="Name of the column containing the pose id"
    )
    n_poses: list[int] = Field([1], description="Number of poses to select")
class RandomSplitSettings(EvaluatorSettingsBase):
    """Settings for random reference splits; only the inherited ``use`` flag."""

    pass
class DateSplitSettings(EvaluatorSettingsBase):
    """Settings for date-based splitting"""

    reference_structure_date_column: str = Field(
        "Reference_Structure_Date",
        description="Column containing reference structure deposition date",
    )
    randomize_by_n_days: int = Field(1, description="Days to randomize by")
class UpdateReferenceSettings(EvaluatorSettingsBase):
    """
    Settings for deriving the list of reference-structure counts from the data.

    When ``use`` and ``use_logarithmic_scaling`` are both set, the counts are
    generated by ``generate_logarithmic_scale`` with ``log_base`` as the base.
    """

    use_logarithmic_scaling: bool = False
    log_base: int = 10
class CompositSettingsBase(EvaluatorSettingsBase):
    """
    Settings made of several component settings.

    When the composite is enabled (``use=True``), at least one of its
    components must be enabled as well.
    """

    @abc.abstractmethod
    def get_component_settings(self) -> list[EvaluatorSettingsBase]:
        """Return the component settings governed by this composite."""
        pass

    @model_validator(mode="after")
    def validate_component_settings(self):
        """Reject an enabled composite whose components are all disabled."""
        if not self.use:
            return self
        enabled = [component.use for component in self.get_component_settings()]
        if any(enabled):
            return self
        raise ValueError(
            f"At least one of {self.get_component_settings()} must be set to use=True"
        )
class ReferenceSplitSettings(CompositSettingsBase):
    """
    Settings for reference-structure splits (random and/or date based).

    ``n_reference_structures`` lists the reference counts to try; ``None``
    entries are passed through to the split unchanged.
    """

    random_split_settings: RandomSplitSettings = RandomSplitSettings()
    date_split_settings: DateSplitSettings = DateSplitSettings()
    n_reference_structures: Optional[list[None | int]] = Field(
        [None], description="List of number of structures to try"
    )
    update_reference_settings: UpdateReferenceSettings = UpdateReferenceSettings()

    def get_component_settings(self) -> list[EvaluatorSettingsBase]:
        """Return the random and date split settings."""
        return [self.random_split_settings, self.date_split_settings]

    @field_validator("n_reference_structures", mode="before")
    @classmethod
    def convert_to_list(cls, v):
        """
        Normalize the raw value to a list before field validation.

        A single integer (Python or numpy scalar) becomes a one-element list
        and a numpy array is converted with ``tolist()``; anything else is
        returned unchanged.
        """
        # numpy integer scalars are not ``int`` instances, so test for both.
        if isinstance(v, np.integer):
            return [int(v)]
        if isinstance(v, int):
            return [v]
        if isinstance(v, np.ndarray):
            return v.tolist()
        return v
class ScaffoldSplitSettings(EvaluatorSettingsBase):
    """Settings for scaffold-based splitting"""

    # Which query-scaffold / reference-scaffold pairing to generate.
    scaffold_split_option: ScaffoldSplitOptions = Field(
        ScaffoldSplitOptions.X_TO_X, description="How to split data by scaffold"
    )
    query_scaffold_id_column: str = Field(
        "cluster_id", description="Column containing query scaffold ID"
    )
    reference_scaffold_id_column: str = Field(
        "cluster_id_Reference", description="Column containing reference scaffold ID"
    )
    query_scaffold_id_subset: Optional[list[int]] = Field(
        None, description="List of query scaffold IDs to consider"
    )
    reference_scaffold_id_subset: Optional[list[int]] = Field(
        None, description="List of reference scaffold IDs to consider"
    )
    query_scaffold_min_count: Optional[int] = Field(
        5, description="Minimum ligands in query scaffold"
    )
    reference_scaffold_min_count: Optional[int] = Field(
        5, description="Minimum ligands in reference scaffold"
    )
class SimilaritySplitSettings(EvaluatorSettingsBase):
    """
    Settings for similarity-based splitting.

    One split is generated per combination of a reference count from
    ``n_reference_structures`` and a threshold from
    :meth:`get_similarity_thresholds`.
    """

    similarity_column_name: Optional[str] = Field("Tanimoto")
    # [start, stop] of the threshold grid and the number of points on it.
    similarity_range: list[float] = Field([0, 1])
    similarity_n_thresholds: int = Field(21)
    similarity_groupby_dict: dict = {}
    higher_is_more_similar: bool = True
    include_similar: bool = True
    n_reference_structures: Optional[list[None | int]] = Field(
        [None], description="List of number of structures to try"
    )
    update_reference_settings: UpdateReferenceSettings = UpdateReferenceSettings()

    def get_similarity_thresholds(self) -> np.ndarray:
        """
        Generate similarity thresholds from the range and number of thresholds.

        :return: ``similarity_n_thresholds`` evenly spaced values from
            ``similarity_range[0]`` to ``similarity_range[1]`` (inclusive)
        """
        return np.linspace(
            self.similarity_range[0],
            self.similarity_range[1],
            self.similarity_n_thresholds,
        )
class PairwiseSplitSettings(CompositSettingsBase):
    """Composite settings for pairwise splits: similarity and/or scaffold."""

    similarity_split_settings: SimilaritySplitSettings = Field(
        SimilaritySplitSettings()
    )
    scaffold_split_settings: ScaffoldSplitSettings = Field(ScaffoldSplitSettings())

    def get_component_settings(self) -> list[EvaluatorSettingsBase]:
        """Return the similarity and scaffold split settings."""
        return [self.similarity_split_settings, self.scaffold_split_settings]
class POSITScorerSettings(EvaluatorSettingsBase):
    """Settings for the POSIT scorer (enabled by default)."""

    use: bool = True
    posit_score_column_name: str = Field(
        "docking-confidence-POSIT",
        description="Name of the column containing the POSIT score",
    )
    posit_name: str = Field("POSIT_Probability", description="Name of the POSIT score")
class RMSDScorerSettings(EvaluatorSettingsBase):
    """Settings for the RMSD scorer (enabled by default)."""

    use: bool = True
    # Column holding the RMSD values the scorer sorts on.
    rmsd_column_name: str = "RMSD"
    rmsd_name: str = Field("RMSD", description="Name of the RMSD score")
class ScorerSettings(CompositSettingsBase):
    """Composite settings for scorers: RMSD and/or POSIT (enabled by default)."""

    use: bool = True
    rmsd_scorer_settings: RMSDScorerSettings = RMSDScorerSettings()
    posit_scorer_settings: POSITScorerSettings = POSITScorerSettings()

    def get_component_settings(self) -> list[EvaluatorSettingsBase]:
        """Return the RMSD and POSIT scorer settings."""
        return [self.rmsd_scorer_settings, self.posit_scorer_settings]
class SuccessRateSettings(EvaluatorSettingsBase):
    """Settings for the success-rate evaluation (enabled by default)."""

    use: bool = True
    # Column compared against ``rmsd_cutoff`` to decide success.
    success_rate_column: str = "RMSD"
    rmsd_cutoff: float = Field(
        2.0, description="RMSD cutoff to label the resulting poses as successful"
    )
def generate_logarithmic_scale(n_max: int, base: int = 10) -> list[int]:
    """
    Generate a logarithmic scale with nice number spacing up to n_max.

    For every power ``p`` of ``base`` not exceeding ``n_max`` the scale holds
    ``p``, ``2p`` and ``5p`` (when they fit). When ``n_max <= 200`` and
    ``p >= 100`` the quarter steps ``1.25p`` and ``1.75p`` are added too.
    ``n_max`` itself is always included.

    Args:
        n_max: Maximum value in the sequence (must be at least 1)
        base: Logarithm base (default=10, must be greater than 1)

    Returns:
        Sorted list of unique integers representing the scale

    Raises:
        ValueError: if ``n_max`` is below 1 or ``base`` is not greater than 1

    Example:
        >>> generate_logarithmic_scale(300)
        [1, 2, 5, 10, 20, 50, 100, 200, 300]
        >>> generate_logarithmic_scale(150)
        [1, 2, 5, 10, 20, 50, 100, 125, 150]
    """
    if n_max < 1:
        raise ValueError(f"n_max must be at least 1, got {n_max}")
    if base <= 1:
        raise ValueError(f"base must be greater than 1, got {base}")

    scale = {n_max}
    # Step through the powers by multiplication: exact, unlike a float log.
    power = 1
    while power <= n_max:
        scale.add(int(power))
        for multiple in (2, 5):
            if power * multiple <= n_max:
                scale.add(int(power * multiple))
        # Quarter steps only apply to the 100..200 range.
        if power >= 100 and n_max <= 200:
            for factor in (1.25, 1.75):
                if power * factor <= n_max:
                    scale.add(int(power * factor))
        power *= base

    return sorted(scale)
class EvaluatorFactory(SettingsBase):
    """
    Builds every ``Evaluator`` combination described by a set of settings.

    ``create_evaluators`` combines each pose selector, scorer and success-rate
    evaluator with the configured reference and pairwise splits.
    """

    model_config = ConfigDict(validate_assignment=True)

    name: str = Field(..., description="Name of this collection of settings")
    pose_selection_settings: PoseSelectionSettings = Field(PoseSelectionSettings())
    reference_split_settings: ReferenceSplitSettings = Field(ReferenceSplitSettings())
    pairwise_split_settings: PairwiseSplitSettings = Field(PairwiseSplitSettings())
    scorer_settings: ScorerSettings = Field(ScorerSettings())
    success_rate_evaluator_settings: SuccessRateSettings = Field(SuccessRateSettings())
    combine_reference_and_similarity_splits: bool = Field(
        True,
        description="If both reference and pairwise splits are set to use=True, evaluate them at the same time. ",
    )
    dataset_before_similarity: bool = Field(
        True,
        description="Whether to run the dataset split before the similarity split, or vice versa. Defaults True.",
    )
    n_bootstraps: int = Field(1000, description="Number of bootstrapped samples to run")
    query_ligand_column: str = Field(
        "Query_Ligand", description="Name of the column containing the query ligand id"
    )
    reference_ligand_column: str = Field(
        "Reference_Ligand",
        description="Name of the column containing the reference ligand id",
    )
    reference_structure_column: str = Field(
        "Reference_Structure",
        description="Name of the column to distinguish reference structures by",
    )

    def to_yaml_file(self, directory: Path = Path("./")) -> Path:
        """
        Write the settings to ``<directory>/<name>.yaml``.

        Field descriptions are written first as comment lines, followed by the
        YAML dump of the data.

        :param directory: target directory (``str`` or ``Path``)
        :return: path of the written file
        """
        output = self.to_yaml()
        descriptions = self.get_descriptions()

        file_path = Path(directory) / f"{self.name}.yaml"
        with open(file_path, "w") as file:
            for key in output:
                if key in descriptions:
                    file.write(f"# {key}: {descriptions[key]}\n")
            yaml.dump(output, file, sort_keys=False)
        return file_path

    def create_pose_selectors(self) -> list[PoseSelector]:
        """Create one pose selector per entry in ``n_poses``."""
        return [
            PoseSelector(
                name="Default",
                variable=self.pose_selection_settings.pose_id_column,
                number_to_return=n,
            )
            for n in self.pose_selection_settings.n_poses
        ]

    def _log_scaled_reference_counts(self, update_settings, data) -> list[int]:
        """Return the logarithmic list of reference counts for ``data``."""
        if not update_settings.use_logarithmic_scaling:
            raise NotImplementedError(
                "Only logarithmic scaling is implemented for updating n_reference_structures"
            )
        if data is None:
            raise ValueError(
                "Must provide input data to update n_reference_structures"
            )
        return generate_logarithmic_scale(
            n_max=len(data.get_unique_refs()),
            base=update_settings.log_base,
        )

    def create_reference_splits(
        self, data: DockingDataModel = None
    ) -> list[ReferenceSplitType]:
        """
        Create the random and/or date splits, one per reference count.

        When ``update_reference_settings.use`` is set, the stored
        ``n_reference_structures`` is first replaced by a logarithmic scale
        derived from ``data``.

        :raises ValueError: if ``data`` is needed but not provided
        """
        ref_settings = self.reference_split_settings
        reference_splits = []

        if ref_settings.update_reference_settings.use:
            ref_settings.n_reference_structures = self._log_scaled_reference_counts(
                ref_settings.update_reference_settings, data
            )

        if ref_settings.random_split_settings.use:
            reference_splits.extend(
                RandomSplit(
                    reference_structure_column=self.reference_structure_column,
                    n_reference_structures=i,
                )
                for i in ref_settings.n_reference_structures
            )

        if ref_settings.date_split_settings.use:
            date_settings = ref_settings.date_split_settings
            if data is None:
                raise ValueError("Must provide input dataframe to use date split")
            reference_splits.extend(
                DateSplit(
                    reference_structure_column=self.reference_structure_column,
                    date_column=date_settings.reference_structure_date_column,
                    n_reference_structures=i,
                    randomize_by_n_days=date_settings.randomize_by_n_days,
                )
                for i in ref_settings.n_reference_structures
            )
        return reference_splits

    @staticmethod
    def _valid_scaffolds(dataframe, column, min_count) -> list:
        """Return scaffold IDs in ``column`` with at least ``min_count`` rows."""
        # A missing/zero minimum contributes no scaffolds from this side.
        if not min_count:
            return []
        counts = dataframe[column].value_counts()
        return counts[counts >= min_count].index.tolist()

    def _create_scaffold_splits(self, scaffold_settings, data) -> list:
        """Create the scaffold splits for the configured split option."""
        split_option = scaffold_settings.scaffold_split_option
        supported = {
            ScaffoldSplitOptions.X_TO_X,
            ScaffoldSplitOptions.X_TO_Y,
            ScaffoldSplitOptions.X_TO_NOT_X,
            ScaffoldSplitOptions.NOT_X_TO_X,
            ScaffoldSplitOptions.X_TO_ALL,
            ScaffoldSplitOptions.ALL_TO_X,
        }
        if split_option not in supported:
            raise NotImplementedError(
                f"ScaffoldSplit Option {split_option} not implemented"
            )
        if data is None:
            raise ValueError("Must provide input data to use scaffold split")
        if not (
            scaffold_settings.query_scaffold_min_count
            or scaffold_settings.reference_scaffold_min_count
        ):
            raise ValueError(
                "At least one of query_scaffold_min_count or reference_scaffold_min_count must be set to use scaffold split"
            )

        query_ids = self._valid_scaffolds(
            data.get_lig_dataframe(),
            scaffold_settings.query_scaffold_id_column,
            scaffold_settings.query_scaffold_min_count,
        )
        ref_ids = self._valid_scaffolds(
            data.get_ref_dataframe(),
            scaffold_settings.reference_scaffold_id_column,
            scaffold_settings.reference_scaffold_min_count,
        )
        # The same scaffolds serve both sides; dict.fromkeys de-duplicates
        # while keeping a deterministic order.
        scaffolds = list(dict.fromkeys([*query_ids, *ref_ids]))

        singles = [[s] for s in scaffolds]
        others = [[r for r in scaffolds if r != q] for q in scaffolds]
        everything = [list(scaffolds) for _ in scaffolds]

        if split_option == ScaffoldSplitOptions.X_TO_Y:
            # Every ordered pair of distinct scaffolds; empty when fewer than two.
            pairs = [
                (q, r) for q, r in itertools.product(singles, singles) if q != r
            ]
        else:
            query_side, ref_side = {
                ScaffoldSplitOptions.X_TO_X: (singles, singles),
                ScaffoldSplitOptions.X_TO_NOT_X: (singles, others),
                ScaffoldSplitOptions.NOT_X_TO_X: (others, singles),
                ScaffoldSplitOptions.X_TO_ALL: (singles, everything),
                ScaffoldSplitOptions.ALL_TO_X: (everything, singles),
            }[split_option]
            pairs = list(zip(query_side, ref_side))

        return [
            ScaffoldSplit(
                query_scaffold_id_column=scaffold_settings.query_scaffold_id_column,
                reference_scaffold_id_column=scaffold_settings.reference_scaffold_id_column,
                query_scaffold_id_subset=query,
                reference_scaffold_id_subset=ref,
                split_option=split_option,
            )
            for query, ref in pairs
        ]

    def _create_similarity_splits(self, sim_settings, data) -> list:
        """Create one similarity split per (reference count, threshold) pair."""
        if sim_settings.update_reference_settings.use:
            sim_settings.n_reference_structures = self._log_scaled_reference_counts(
                sim_settings.update_reference_settings, data
            )
        return [
            SimilaritySplit(
                n_reference_structures=refs,
                threshold=threshold,
                similarity_column=sim_settings.similarity_column_name,
                groupby=sim_settings.similarity_groupby_dict,
                higher_is_more_similar=sim_settings.higher_is_more_similar,
                include_similar=sim_settings.include_similar,
                query_ligand_column=self.query_ligand_column,
            )
            for refs in sim_settings.n_reference_structures
            for threshold in sim_settings.get_similarity_thresholds()
        ]

    def create_pairwise_split(
        self, data: DockingDataModel = None
    ) -> list[SimilaritySplitType]:
        """
        Create pairwise splits (scaffold or similarity) based on settings.

        :param data: docking data; required for scaffold splits and for
            updating the similarity reference counts
        :raises ValueError: if required data or scaffold minimum counts are missing
        :raises NotImplementedError: for an unsupported scaffold split option
        """
        settings = self.pairwise_split_settings
        splits = []
        if settings.scaffold_split_settings.use:
            splits.extend(
                self._create_scaffold_splits(settings.scaffold_split_settings, data)
            )
        if settings.similarity_split_settings.use:
            splits.extend(
                self._create_similarity_splits(
                    settings.similarity_split_settings, data
                )
            )
        return splits

    def create_scorers(self) -> list[Scorer]:
        """Create scorers based on settings"""
        scorers = []
        settings = self.scorer_settings

        if settings.rmsd_scorer_settings.use:
            rmsd_settings = settings.rmsd_scorer_settings
            scorers.append(
                RMSDScorer(
                    name=rmsd_settings.rmsd_name,
                    variable=rmsd_settings.rmsd_column_name,
                )
            )

        if settings.posit_scorer_settings.use:
            posit_settings = settings.posit_scorer_settings
            scorers.append(
                POSITScorer(
                    name=posit_settings.posit_name,
                    variable=posit_settings.posit_score_column_name,
                )
            )
        return scorers

    def create_success_rate_evaluator(self) -> list[BinaryEvaluation]:
        """Create the single cutoff-based evaluator described by the settings."""
        return [
            BinaryEvaluation(
                variable=self.success_rate_evaluator_settings.success_rate_column,
                cutoff=self.success_rate_evaluator_settings.rmsd_cutoff,
            )
        ]

    def _split_combinations(self, reference_splits, similarity_splits) -> list:
        """Return the ``(dataset_split, similarity_split)`` pairs to evaluate."""
        if reference_splits is None and similarity_splits is None:
            return [(None, None)]
        if reference_splits is not None and similarity_splits is not None:
            if self.combine_reference_and_similarity_splits:
                return [
                    (ref_split, sim_split)
                    for sim_split, ref_split in itertools.product(
                        similarity_splits, reference_splits
                    )
                ]
            return [(ref_split, None) for ref_split in reference_splits] + [
                (None, sim_split) for sim_split in similarity_splits
            ]
        if reference_splits is not None:
            return [(ref_split, None) for ref_split in reference_splits]
        return [(None, sim_split) for sim_split in similarity_splits]

    def create_evaluators(self, data: DockingDataModel = None) -> list[Evaluator]:
        """
        Create all evaluator combinations based on settings.

        :param data: docking data, forwarded to the split factories
        :return: one evaluator per (pose selector, scorer, evaluator, split
            combination)
        """
        pose_selectors = self.create_pose_selectors()
        reference_splits = (
            self.create_reference_splits(data)
            if self.reference_split_settings.use
            else None
        )
        similarity_splits = (
            self.create_pairwise_split(data)
            if self.pairwise_split_settings.use
            else None
        )
        scorers = self.create_scorers()
        success_rate_evaluators = self.create_success_rate_evaluator()
        combinations = self._split_combinations(reference_splits, similarity_splits)

        evaluators = []
        for pose_selector, scorer, success_rate_evaluator in itertools.product(
            pose_selectors, scorers, success_rate_evaluators
        ):
            base = Evaluator(
                pose_selector=pose_selector,
                scorer=scorer,
                evaluator=success_rate_evaluator,
                n_bootstraps=self.n_bootstraps,
                dataset_before_similarity=self.dataset_before_similarity,
            )
            for ref_split, sim_split in combinations:
                # Shallow copy with the splits swapped in, no re-validation.
                evaluators.append(
                    base.model_copy(
                        update={
                            "dataset_split": ref_split,
                            "similarity_split": sim_split,
                        }
                    )
                )
        return evaluators