Source code for medmodels.treatment_effect.estimate

"""Estimators class for calculating treatment effect metrics."""

from __future__ import annotations

from typing import TYPE_CHECKING, Literal, Set, Tuple, TypedDict

from medmodels.treatment_effect.continuous_estimators import (
    average_treatment_effect,
    cohens_d,
    hedges_g,
)
from medmodels.treatment_effect.matching.neighbors import NeighborsMatching
from medmodels.treatment_effect.matching.propensity import PropensityMatching

if TYPE_CHECKING:
    from medmodels.medrecord.medrecord import MedRecord
    from medmodels.medrecord.types import MedRecordAttribute, NodeIndex
    from medmodels.treatment_effect.matching.matching import Matching
    from medmodels.treatment_effect.treatment_effect import TreatmentEffect


class SubjectIndices(TypedDict):
    """Patient ids of the four contingency table groups, keyed by group name.

    Each value is the set of node indices that fall into the corresponding
    cell of the treatment/outcome contingency table.
    """

    # Treatment group.
    treated_outcome_true: Set[NodeIndex]  # treated subjects with the outcome
    treated_outcome_false: Set[NodeIndex]  # treated subjects without the outcome

    # Control group.
    control_outcome_true: Set[NodeIndex]  # control subjects with the outcome
    control_outcome_false: Set[NodeIndex]  # control subjects without the outcome
class ContingencyTable:
    """2x2 table of subject counts: treated/control versus outcome true/false."""

    number_treated_outcome_true: int
    number_treated_outcome_false: int
    number_control_outcome_true: int
    number_control_outcome_false: int

    def __init__(
        self,
        number_treated_outcome_true: int,
        number_treated_outcome_false: int,
        number_control_outcome_true: int,
        number_control_outcome_false: int,
    ) -> None:
        """Stores the four cell counts of the contingency table.

        Args:
            number_treated_outcome_true (int): Number of patients in the treatment
                group with the outcome.
            number_treated_outcome_false (int): Number of patients in the treatment
                group without the outcome.
            number_control_outcome_true (int): Number of patients in the control
                group with the outcome.
            number_control_outcome_false (int): Number of patients in the control
                group without the outcome.
        """
        # Treatment row.
        self.number_treated_outcome_true = number_treated_outcome_true
        self.number_treated_outcome_false = number_treated_outcome_false
        # Control row.
        self.number_control_outcome_true = number_control_outcome_true
        self.number_control_outcome_false = number_control_outcome_false

    def __str__(self) -> str:
        """Renders the contingency table as a fixed-width text table.

        The rows are the treatment and control groups, the columns are the
        number of subjects with the outcome (true) and without it (false).

        Example:
            .. code-block:: text

                -----------------------------------
                                   Outcome
                Group           True     False
                -----------------------------------
                Treated         2        1
                Control         3        3
                -----------------------------------
        """
        rule = "-" * 35
        rows = [
            rule,
            f"{'':<18} {'Outcome':<10}",
            f"{'Group':<15} {'True':<8} {'False':<8}",
            rule,
            (
                f"{'Treated':<15} {self.number_treated_outcome_true:<8} "
                f"{self.number_treated_outcome_false:<8}"
            ),
            (
                f"{'Control':<15} {self.number_control_outcome_true:<8} "
                f"{self.number_control_outcome_false:<8}"
            ),
            rule,
        ]
        return "\n".join(rows)

    def __getitem__(
        self,
        key: Literal[
            "treated_outcome_true",
            "treated_outcome_false",
            "control_outcome_true",
            "control_outcome_false",
        ],
    ) -> int:
        """Looks up the count of one contingency table cell by its group name.

        Args:
            key (Literal["treated_outcome_true", "treated_outcome_false", "control_outcome_true", "control_outcome_false"]):
                Name of the group whose subject count is requested.

        Returns:
            int: Number of subjects in the selected group.
        """  # noqa: W505
        # The attributes carry a "number_" prefix that the public keys omit.
        return getattr(self, f"number_{key}")
class Estimate:
    """Estimators class for calculating treatment effect metrics.

    All estimators take a MedRecord, sort its subjects into the four contingency
    table groups according to the wrapped treatment effect configuration
    (optionally matching controls to treated subjects) and derive the metric
    from those groups.
    """

    _treatment_effect: TreatmentEffect

    def __init__(self, treatment_effect: TreatmentEffect) -> None:
        """Initializes the Estimate object.

        Args:
            treatment_effect (TreatmentEffect): The treatment effect configuration
                whose groups and matching settings the estimators read.
        """
        self._treatment_effect = treatment_effect

    def _check_medrecord(self, medrecord: MedRecord) -> None:
        """Checks if the required groups are present in the MedRecord.

        Args:
            medrecord (MedRecord): The MedRecord object containing the data.

        Raises:
            ValueError: Raises Error if the required groups are not present in the
                MedRecord (patients, treatments, outcomes).
        """
        if self._treatment_effect._patients_group not in medrecord.groups:
            msg = (
                f"Patient group {self._treatment_effect._patients_group} not found in "
                f"the MedRecord. Available groups: {medrecord.groups}"
            )
            raise ValueError(msg)

        if self._treatment_effect._treatments_group not in medrecord.groups:
            msg = (
                "Treatment group not found in the MedRecord. "
                f"Available groups: {medrecord.groups}"
            )
            raise ValueError(msg)

        if self._treatment_effect._outcomes_group not in medrecord.groups:
            # The trailing space keeps the two sentences apart in the message.
            msg = (
                "Outcome group not found in the MedRecord. "
                f"Available groups: {medrecord.groups}"
            )
            raise ValueError(msg)

    def _sort_subjects_in_groups(
        self, medrecord: MedRecord
    ) -> Tuple[Set[NodeIndex], Set[NodeIndex], Set[NodeIndex], Set[NodeIndex]]:
        """Sorts subjects into the different groups of the contingency table.

        This means, sorting the subjects into treatment-outcome, treatment-no
        outcome, control-outcome and control-no outcome. The treatment group and
        control matching is determined based on the treatment effect configuration.

        Args:
            medrecord (MedRecord): The MedRecord object containing the data.

        Returns:
            Tuple[Set[NodeIndex], Set[NodeIndex], Set[NodeIndex], Set[NodeIndex]]:
                The patient ids of true and false subjects in the treatment and
                control groups, respectively.

        Raises:
            ValueError: Raises Error if the required groups are not present in the
                MedRecord (patients, treatments, outcomes).
        """
        self._check_medrecord(medrecord=medrecord)

        config = self._treatment_effect
        (
            treated_outcome_true,
            treated_outcome_false,
            control_outcome_true,
            control_outcome_false,
        ) = config._find_groups(medrecord)
        treated_set = treated_outcome_true | treated_outcome_false

        if config._matching_method:
            # Any configured method other than "nearest_neighbors" falls back to
            # propensity matching.
            matching: Matching = (
                NeighborsMatching(
                    number_of_neighbors=config._matching_number_of_neighbors,
                )
                if config._matching_method == "nearest_neighbors"
                else PropensityMatching(
                    number_of_neighbors=config._matching_number_of_neighbors,
                    model=config._matching_model,
                    hyperparameters=config._matching_hyperparameters,
                )
            )
            control_set = control_outcome_true | control_outcome_false

            matched_controls = matching.match_controls(
                medrecord=medrecord,
                treated_set=treated_set,
                control_set=control_set,
                patients_group=config._patients_group,
                essential_covariates=config._matching_essential_covariates,
                one_hot_covariates=config._matching_one_hot_covariates,
            )
            # Only the matched controls are split by outcome; the unmatched ones
            # are dropped from the control groups.
            control_outcome_true, control_outcome_false = config._find_controls(
                medrecord=medrecord,
                control_set=matched_controls,
                treated_set=treated_set,
            )

        return (
            treated_outcome_true,
            treated_outcome_false,
            control_outcome_true,
            control_outcome_false,
        )

    def _compute_subject_counts(
        self, medrecord: MedRecord
    ) -> Tuple[int, int, int, int]:
        """Computes the subject counts for the treatment and control groups.

        Args:
            medrecord (MedRecord): The MedRecord object containing the data.

        Returns:
            Tuple[int, int, int, int]: The number of true and false subjects in the
                treatment and control groups, respectively.

        Raises:
            ValueError: Raises error if the required groups are not present in the
                MedRecord (patients, treatments, outcomes).
            ValueError: If there are no subjects in the group of treated with no
                outcome, in the one of controls with outcome or in the one of
                controls with no outcome, an error is raised. This would result in
                division by zero errors.
        """
        (
            treated_outcome_true,
            treated_outcome_false,
            control_outcome_true,
            control_outcome_false,
        ) = self._sort_subjects_in_groups(medrecord=medrecord)

        # An empty treated-with-outcome group is allowed: it only ever appears in
        # numerators of the estimators below.
        if not treated_outcome_false:
            msg = "No subjects found in the group of treated with no outcome"
            raise ValueError(msg)
        if not control_outcome_true:
            msg = "No subjects found in the group of controls with outcome"
            raise ValueError(msg)
        if not control_outcome_false:
            msg = "No subjects found in the group of controls with no outcome"
            raise ValueError(msg)

        return (
            len(treated_outcome_true),
            len(treated_outcome_false),
            len(control_outcome_true),
            len(control_outcome_false),
        )

    def subject_indices(self, medrecord: MedRecord) -> SubjectIndices:
        """Overview of which subjects are in which group from the contingency table.

        Returns a dictionary with the patient ids of all contingency table groups,
        i.e., the treated group with and without the outcome, and the control group
        with and without the outcome.

        Args:
            medrecord (MedRecord): The MedRecord object containing the data.

        Returns:
            SubjectIndices: Dictionary with the patient ids of true and false
                subjects in the treatment and control groups, respectively.

        Raises:
            ValueError: Raises Error if the required groups are not present in the
                MedRecord (patients, treatments, outcomes).
        """
        (
            treated_outcome_true,
            treated_outcome_false,
            control_outcome_true,
            control_outcome_false,
        ) = self._sort_subjects_in_groups(medrecord=medrecord)

        return SubjectIndices(
            treated_outcome_true=treated_outcome_true,
            treated_outcome_false=treated_outcome_false,
            control_outcome_true=control_outcome_true,
            control_outcome_false=control_outcome_false,
        )

    def subject_counts(self, medrecord: MedRecord) -> ContingencyTable:
        """Overview of how many subjects are in which group from the contingency table.

        Returns a contingency table object with the number of subjects in the
        treatment and control groups with and without the outcome.

        Args:
            medrecord (MedRecord): The MedRecord object containing the data.

        Returns:
            ContingencyTable: The contingency table object containing the number of
                subjects in the treatment and control groups with and without the
                outcome.

        Raises:
            ValueError: Raises Error if the required groups are not present in the
                MedRecord (patients, treatments, outcomes).
            ValueError: If there are no subjects in the group of treated with no
                outcome, in the one of controls with outcome or in the one of
                controls with no outcome, an error is raised.
        """
        (
            number_treated_outcome_true,
            number_treated_outcome_false,
            number_control_outcome_true,
            number_control_outcome_false,
        ) = self._compute_subject_counts(medrecord=medrecord)

        return ContingencyTable(
            number_treated_outcome_true=number_treated_outcome_true,
            number_treated_outcome_false=number_treated_outcome_false,
            number_control_outcome_true=number_control_outcome_true,
            number_control_outcome_false=number_control_outcome_false,
        )

    def relative_risk(self, medrecord: MedRecord) -> float:
        """Calculates the relative risk (RR) of an event.

        RR is a key measure in epidemiological studies for estimating the likelihood
        of an event in one group relative to another, in this case, the treatment
        group compared to the control group.

        The interpretation of RR is as follows:

        - RR = 1 indicates no difference in risk between the two groups.
        - RR > 1 indicates a higher risk in the treatment group.
        - RR < 1 indicates a lower risk in the treatment group.

        Args:
            medrecord (MedRecord): The MedRecord object containing the data.

        Returns:
            float: The calculated relative risk between the treatment and control
                groups.

        Raises:
            ValueError: Raises Error if the required groups are not present in the
                MedRecord (patients, treatments, outcomes).
            ValueError: If there are no subjects in the group of treated with no
                outcome, in the one of controls with outcome or in the one of
                controls with no outcome, an error is raised. This would result in
                division by zero errors.
        """
        (
            number_treated_outcome_true,
            number_treated_outcome_false,
            number_control_outcome_true,
            number_control_outcome_false,
        ) = self._compute_subject_counts(medrecord=medrecord)

        return (
            number_treated_outcome_true
            / (number_treated_outcome_true + number_treated_outcome_false)
        ) / (
            number_control_outcome_true
            / (number_control_outcome_true + number_control_outcome_false)
        )

    def odds_ratio(self, medrecord: MedRecord) -> float:
        """Calculates the odds ratio (OR).

        The OR quantifies the association between exposure to a treatment and the
        occurrence of an outcome. OR compares the odds of an event occurring in the
        treatment group to the odds in the control group, providing insight into the
        strength of the association between the treatment and the outcome.

        Interpretation of the odds ratio:

        - OR = 1 indicates no difference in odds between the two groups.
        - OR > 1 suggests the event is more likely in the treatment group.
        - OR < 1 suggests the event is less likely in the treatment group.

        Args:
            medrecord (MedRecord): The MedRecord object containing the data.

        Returns:
            float: The calculated odds ratio between the treatment and control
                groups.

        Raises:
            ValueError: Raises Error if the required groups are not present in the
                MedRecord (patients, treatments, outcomes).
            ValueError: If there are no subjects in the group of treated with no
                outcome, in the one of controls with outcome or in the one of
                controls with no outcome, an error is raised. This would result in
                division by zero errors.
        """
        (
            number_treated_outcome_true,
            number_treated_outcome_false,
            number_control_outcome_true,
            number_control_outcome_false,
        ) = self._compute_subject_counts(medrecord=medrecord)

        return (number_treated_outcome_true / number_control_outcome_true) / (
            number_treated_outcome_false / number_control_outcome_false
        )

    def confounding_bias(self, medrecord: MedRecord) -> float:
        """Calculates the confounding bias (CB).

        The CB is used to assess the impact of potential confounders on the observed
        association between treatment and outcome. A confounder is a variable that
        influences both the dependent (outcome) and independent (treatment)
        variables, potentially biasing the study results.

        Interpretation of CB:

        - CB = 1 indicates no confounding bias.
        - CB != 1 suggests the presence of confounding bias.

        The method relies on the relative risk (RR) as an intermediary measure and
        adjusts the observed association for potential confounding effects. This
        adjustment helps in identifying whether the observed association might be
        influenced by factors other than the treatment.

        Args:
            medrecord (MedRecord): The MedRecord object containing the data.

        Returns:
            float: The calculated confounding bias.

        Raises:
            ValueError: Raises Error if the required groups are not present in the
                MedRecord (patients, treatments, outcomes).
            ValueError: If there are no subjects in the group of treated with no
                outcome, in the one of controls with outcome or in the one of
                controls with no outcome, an error is raised. This would result in
                division by zero errors.
        """
        (
            number_treated_outcome_true,
            number_treated_outcome_false,
            number_control_outcome_true,
            number_control_outcome_false,
        ) = self._compute_subject_counts(medrecord=medrecord)

        risk_treated = number_treated_outcome_true / (
            number_treated_outcome_true + number_treated_outcome_false
        )
        risk_control = number_control_outcome_true / (
            number_control_outcome_true + number_control_outcome_false
        )

        # Derive RR from the counts already at hand instead of calling
        # relative_risk(), which would sort (and match) the subjects a second time.
        relative_risk = risk_treated / risk_control

        if relative_risk == 1:
            return 1.0

        multiplier = relative_risk - 1
        numerator = risk_treated * multiplier + 1
        denominator = risk_control * multiplier + 1

        return numerator / denominator

    def absolute_risk_reduction(self, medrecord: MedRecord) -> float:
        """Calculates the absolute risk reduction (ARR).

        AR (absolute risk) is a measure of the incidence of an event in each group.
        ARR, in turn, quantifies the difference in risk between the treatment and
        control groups. It is positive if the treatment reduces the risk, and
        negative if it increases the risk.

        Args:
            medrecord (MedRecord): The MedRecord object containing the data.

        Returns:
            float: The calculated absolute risk reduction between the treatment and
                control groups.

        Raises:
            ValueError: Raises Error if the required groups are not present in the
                MedRecord (patients, treatments, outcomes).
            ValueError: If there are no subjects in the group of treated with no
                outcome, in the one of controls with outcome or in the one of
                controls with no outcome, an error is raised. This would result in
                division by zero errors.
        """
        (
            number_treated_outcome_true,
            number_treated_outcome_false,
            number_control_outcome_true,
            number_control_outcome_false,
        ) = self._compute_subject_counts(medrecord=medrecord)

        ar_treated_group = number_treated_outcome_true / (
            number_treated_outcome_true + number_treated_outcome_false
        )
        ar_control_group = number_control_outcome_true / (
            number_control_outcome_true + number_control_outcome_false
        )

        # Control minus treated, so a risk-lowering treatment yields a positive ARR.
        return ar_control_group - ar_treated_group

    def number_needed_to_treat(self, medrecord: MedRecord) -> float:
        """Calculates the number needed to treat (NNT) to prevent one extra bad outcome.

        NNT is derived from the absolute risk reduction (ARR) and provides an
        estimate of the number of patients that need to be treated to prevent one
        additional bad outcome.

        Args:
            medrecord (MedRecord): The MedRecord object containing the data.

        Returns:
            float: The calculated number needed to treat between the treatment and
                control groups.

        Raises:
            ValueError: Raises Error if the required groups are not present in the
                MedRecord (patients, treatments, outcomes).
            ValueError: If there are no subjects in the group of treated with no
                outcome, in the one of controls with outcome or in the one of
                controls with no outcome, an error is raised. This would result in
                division by zero errors.
            ValueError: If the ARR is zero, cannot calculate NNT.
        """
        absolute_risk_reduction = self.absolute_risk_reduction(medrecord)
        if absolute_risk_reduction == 0:
            msg = "Absolute Risk Reduction is zero, cannot calculate NNT."
            raise ValueError(msg)

        return 1 / absolute_risk_reduction

    def hazard_ratio(self, medrecord: MedRecord) -> float:
        """Calculates the hazard ratio (HR).

        HR is used to compare the hazard rates of two groups in survival analysis.
        Here the hazard of each group is taken as the proportion of its subjects
        that have the outcome, and HR is the treated proportion divided by the
        control proportion.

        Args:
            medrecord (MedRecord): The MedRecord object containing the data.

        Returns:
            float: The calculated hazard ratio between the treatment and control
                groups.

        Raises:
            ValueError: Raises Error if the required groups are not present in the
                MedRecord (patients, treatments, outcomes).
            ValueError: If there are no subjects in the group of treated with no
                outcome, in the one of controls with outcome or in the one of
                controls with no outcome, an error is raised. This would result in
                division by zero errors.
            ValueError: If the control hazard rate is zero, cannot calculate HR.
        """
        (
            number_treated_outcome_true,
            number_treated_outcome_false,
            number_control_outcome_true,
            number_control_outcome_false,
        ) = self._compute_subject_counts(medrecord=medrecord)

        hazard_treat = number_treated_outcome_true / (
            number_treated_outcome_true + number_treated_outcome_false
        )
        hazard_control = number_control_outcome_true / (
            number_control_outcome_true + number_control_outcome_false
        )

        # Defensive guard: _compute_subject_counts already rejects an empty
        # controls-with-outcome group, so this is not expected to trigger.
        if hazard_control == 0:
            msg = "Control hazard rate is zero, cannot calculate hazard ratio."
            raise ValueError(msg)

        return hazard_treat / hazard_control

    def average_treatment_effect(
        self,
        medrecord: MedRecord,
        outcome_variable: MedRecordAttribute,
        reference: Literal["first", "last"] = "last",
    ) -> float:
        """Calculates the Average Treatment Effect (ATE).

        It is calculated as the difference between the outcome means of the treated
        and control sets. A positive ATE indicates that the treatment increased the
        outcome, while a negative ATE suggests a decrease.

        The ATE is computed as follows when the numbers of observations in treated
        and control sets are N and M, respectively:

        .. code-block:: text

            ATE = (1 / N) * sum_i y1(i) - (1 / M) * sum_j y0(j)

        where y1(i) and y0(j) are the outcome values of the i-th treated and j-th
        control observation.

        Only the subjects that have the outcome (treated and control) are passed on
        to the underlying estimator.

        Args:
            medrecord (MedRecord): An instance of the MedRecord class containing
                medical data.
            outcome_variable (MedRecordAttribute): The attribute in the edge that
                contains the outcome variable. It must be numeric and continuous.
            reference (Literal["first", "last"], optional): The reference point for
                the exposure time. Options include "first" and "last". If "first",
                the function returns the earliest exposure edge. If "last", the
                function returns the latest exposure edge. Defaults to "last".

        Returns:
            float: The average treatment effect.

        Raises:
            ValueError: Raises Error if the required groups are not present in the
                MedRecord (patients, treatments, outcomes).
        """
        subjects = self.subject_indices(medrecord=medrecord)

        return average_treatment_effect(
            medrecord=medrecord,
            treatment_outcome_true_set=subjects["treated_outcome_true"],
            control_outcome_true_set=subjects["control_outcome_true"],
            outcome_group=self._treatment_effect._outcomes_group,
            outcome_variable=outcome_variable,
            reference=reference,
            time_attribute=self._treatment_effect._time_attribute,
        )

    def cohens_d(
        self,
        medrecord: MedRecord,
        outcome_variable: MedRecordAttribute,
        reference: Literal["first", "last"] = "last",
    ) -> float:
        """Calculates Cohen's D, the standardized mean difference between two sets.

        This measures the effect size of the difference between two outcome means.
        It's applicable for any two sets but is recommended for sets of the same
        size. Cohen's D indicates how many standard deviations the two groups differ
        by, with 1 standard deviation equal to 1 z-score.

        A rule of thumb for interpreting Cohen's D:

        - Small effect = ±0.2
        - Medium effect = ±0.5
        - Large effect = ±0.8

        If the difference is negative, it indicates the mean in the treated group is
        lower than the control group.

        This metric provides a dimensionless measure of effect size, facilitating
        the comparison across different studies and contexts.

        Args:
            medrecord (MedRecord): An instance of the MedRecord class containing
                medical data.
            outcome_variable (MedRecordAttribute): The attribute in the edge that
                contains the outcome variable. It must be numeric and continuous.
            reference (Literal["first", "last"], optional): The reference point for
                the exposure time. Options include "first" and "last". If "first",
                the function returns the earliest exposure edge. If "last", the
                function returns the latest exposure edge. Defaults to "last".

        Returns:
            float: The Cohen's D coefficient, representing the effect size.

        Raises:
            ValueError: Raises Error if the required groups are not present in the
                MedRecord (patients, treatments, outcomes).
        """
        subjects = self.subject_indices(medrecord=medrecord)

        return cohens_d(
            medrecord=medrecord,
            treatment_outcome_true_set=subjects["treated_outcome_true"],
            control_outcome_true_set=subjects["control_outcome_true"],
            outcome_group=self._treatment_effect._outcomes_group,
            outcome_variable=outcome_variable,
            reference=reference,
            time_attribute=self._treatment_effect._time_attribute,
        )

    def hedges_g(
        self,
        medrecord: MedRecord,
        outcome_variable: MedRecordAttribute,
        reference: Literal["first", "last"] = "last",
    ) -> float:
        """Calculates Hedges' g, the unbiased effect size estimate.

        Hedges' g is a corrected version of Cohen's d that provides an unbiased
        estimate of the effect size, especially important when sample sizes are
        small (under 50). The correction factor is applied regardless of the sample
        size.

        Args:
            medrecord (MedRecord): An instance of the MedRecord class containing
                medical data.
            outcome_variable (MedRecordAttribute): The attribute in the edge that
                contains the outcome variable. It must be numeric and continuous.
            reference (Literal["first", "last"], optional): The reference point for
                the exposure time. Options include "first" and "last". If "first",
                the function returns the earliest exposure edge. If "last", the
                function returns the latest exposure edge. Defaults to "last".

        Returns:
            float: The Hedges' g coefficient, representing the effect size.

        Raises:
            ValueError: Raises Error if the required groups are not present in the
                MedRecord (patients, treatments, outcomes).
        """
        subjects = self.subject_indices(medrecord=medrecord)

        return hedges_g(
            medrecord=medrecord,
            treatment_outcome_true_set=subjects["treated_outcome_true"],
            control_outcome_true_set=subjects["control_outcome_true"],
            outcome_group=self._treatment_effect._outcomes_group,
            outcome_variable=outcome_variable,
            reference=reference,
            time_attribute=self._treatment_effect._time_attribute,
        )