Source code for tfitpy.indices.grn
import pandas as pd
import numpy as np
from typing import List, Dict, Tuple, Optional
[docs]
def grn_set_metrics(
    source: List[str],
    target: str,
    grn_data: pd.DataFrame
) -> Dict:
    """Score a predicted regulator set against ground-truth regulators.

    Predictions are treated as one unordered set: duplicates in ``source``
    are collapsed and no ranking or per-edge score is used.

    Args:
        source: Predicted regulator genes. ``None`` is treated as "no
            predictions" and a bare string as a single gene name.
        target: Target gene name. Kept for interface symmetry; it is not
            used in the computation, so ``grn_data`` must already be
            restricted to the rows that belong to this target.
        grn_data: Ground-truth regulators. Only the 'regulator' column is
            read; rows whose regulator is missing (NaN/None) are ignored.

    Returns:
        A dictionary with three float entries:
            - 'grn_collectri_precision': TP / number of distinct predictions.
            - 'grn_collectri_recall': TP / number of distinct true regulators.
            - 'grn_collectri_jaccard': TP / size of the union of both sets.
        Every entry is 0.0 when ``grn_data`` is empty or when the
        corresponding denominator is zero.
    """
    if grn_data.empty:
        # No ground truth at all: every score is defined as 0.
        return {
            'grn_collectri_precision': 0.0,
            'grn_collectri_recall': 0.0,
            'grn_collectri_jaccard': 0.0
        }

    # Normalise the predictions into a set of gene names. A bare string must
    # not go through set(), which would split it into single characters.
    if source is None:
        predicted = set()
    elif isinstance(source, str):
        predicted = {source}
    else:
        predicted = set(source)

    # Missing regulator names are not real regulators; counting them would
    # inflate the number of "true" regulators and deflate recall/Jaccard.
    true_regulators = set(grn_data['regulator'].dropna())

    hits = len(predicted & true_regulators)
    union_size = len(predicted | true_regulators)

    precision = hits / len(predicted) if predicted else 0.0
    recall = hits / len(true_regulators) if true_regulators else 0.0
    jaccard = hits / union_size if union_size else 0.0

    return {
        'grn_collectri_precision': precision,
        'grn_collectri_recall': recall,
        'grn_collectri_jaccard': jaccard
    }
def _get_GRN_for_target(
    dataset: Dict,
    grn_key: str,
    target: str
) -> pd.DataFrame:
    """Retrieve the GRN rows that regulate a specific target gene.

    Args:
        dataset: Dictionary of loaded GRN datasets, keyed by GRN name.
        grn_key: Key identifying which GRN to use. Only 'collectri' is
            supported.
        target: Target gene name to filter for.

    Returns:
        A new DataFrame with columns ['regulator', 'target', 'score'],
        restricted to rows whose 'target' equals ``target``. The frame is
        empty when the target has no entries. The 'regulator' column is
        the input's 'source' column (if present) and 'score' is the
        input's 'weight' column.

    Raises:
        ValueError: If ``dataset``, ``grn_key`` or ``target`` is None, or if
            ``grn_key`` is not a supported GRN name.
        KeyError: If ``dataset`` has no entry for the requested GRN or the
            table lacks the expected columns.
    """
    if dataset is None:
        raise ValueError("dataset is None")
    if grn_key is None:
        raise ValueError("no grn_key provided")
    if target is None:
        raise ValueError("no target provided")
    if grn_key != "collectri":
        raise ValueError(f"invalid grn_key: '{grn_key}'")

    network = dataset["collectri"]
    # Filter first: boolean indexing already yields a new frame, so the
    # stored table is never modified and no full-table copy is needed.
    rows = network[network["target"] == target]
    rows = rows.rename(columns={"source": "regulator"})
    return rows[["regulator", "target", "weight"]].rename(
        columns={"weight": "score"}
    )
# Registry of GRN-based metrics, keyed by method name.
GRN_METHODS = {
    "grn_collectri": {
        # Scores `sources` against the CollecTRI regulators of `target`.
        # Extra keyword arguments are accepted and ignored.
        "func": (
            lambda sources, target, datasets=None, **kwargs: grn_set_metrics(
                sources,
                target,
                _get_GRN_for_target(datasets, "collectri", target),
            )
        ),
        # NOTE(review): presumably tells the consumer that `func` returns
        # one value per name in "cols" -- confirm against the caller.
        "type": "df_columns",
        # Same names as the keys of the dict returned by grn_set_metrics.
        "cols": [
            "grn_collectri_precision",
            "grn_collectri_recall",
            "grn_collectri_jaccard",
        ],
        # NOTE(review): looks like the dataset names that must be loaded
        # and passed in as `datasets` -- confirm against the caller.
        "datasets": ["collectri"],
    },
}