Source code for medmodels.treatment_effect.treatment_effect

"""This module provides a class for analyzing treatment effects in medical records.

The TreatmentEffect class facilitates the analysis of treatment effects over time or
across different patient groups. It allows users to identify patients who underwent
treatment and experienced outcomes, and find a control group with similar criteria but
without undergoing the treatment. The class supports customizable criteria filtering,
time constraints between treatment and outcome, and optional matching of control groups
to treatment groups using a specified matching class.

The default TreatmentEffect class performs a static analysis without considering time.
To perform a time-based analysis, users can specify a time attribute in the
configuration and set the washout period, grace period, and follow-up period.
"""

from __future__ import annotations

import logging
from datetime import timedelta
from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Set, Tuple

from medmodels import MedRecord
from medmodels.medrecord.querying import EdgeDirection, NodeOperand, NodeQuery
from medmodels.medrecord.types import (
    Group,
    MedRecordAttribute,
    MedRecordAttributeInputList,
    NodeIndex,
)
from medmodels.treatment_effect.builder import TreatmentEffectBuilder
from medmodels.treatment_effect.estimate import Estimate
from medmodels.treatment_effect.report import Report

if TYPE_CHECKING:
    from medmodels import MedRecord
    from medmodels.medrecord.types import (
        Group,
        MedRecordAttribute,
        MedRecordAttributeInputList,
        NodeIndex,
    )
    from medmodels.treatment_effect.matching.algorithms.propensity_score import Model
    from medmodels.treatment_effect.matching.matching import MatchingMethod


logger = logging.getLogger(__name__)


class TreatmentEffect:
    """The TreatmentEffect class for analyzing treatment effects in medical records.

    An instance only stores configuration (group names, time windows, matching
    options). The MedRecord to analyze is passed to the individual methods, which
    split the patients into treated/control groups with and without the outcome.
    """

    # Default follow-up period in days. `_set_configuration` also compares against
    # this value to detect whether the caller customised the follow-up period.
    _DEFAULT_FOLLOW_UP_PERIOD_DAYS = 1000 * 365

    _treatments_group: Group
    _outcomes_group: Group

    _patients_group: Group
    _time_attribute: Optional[MedRecordAttribute]

    _washout_period_days: Dict[Group, int]
    _washout_period_reference: Literal["first", "last"]

    _grace_period_days: int
    _grace_period_reference: Literal["first", "last"]

    _follow_up_period_days: int
    _follow_up_period_reference: Literal["first", "last"]

    _outcome_before_treatment_days: Optional[int]

    _filter_controls_query: Optional[NodeQuery]

    _matching_method: Optional[MatchingMethod]
    _matching_essential_covariates: MedRecordAttributeInputList
    _matching_one_hot_covariates: MedRecordAttributeInputList
    _matching_model: Model
    _matching_number_of_neighbors: int
    _matching_hyperparameters: Optional[Dict[str, Any]]

    def __init__(
        self,
        treatment: Group,
        outcome: Group,
    ) -> None:
        """Instantiates a Treatment Effect class.

        It requires the group of the Medrecord that contains the treatment node IDs
        and the group of the Medrecord that contains the outcome node IDs. Every
        other setting takes the default defined by `_set_configuration`.

        Args:
            treatment (Group): The group of treatments to analyze.
            outcome (Group): The group of outcomes to analyze.
        """
        TreatmentEffect._set_configuration(self, treatment=treatment, outcome=outcome)

    @classmethod
    def builder(cls) -> TreatmentEffectBuilder:
        """Creates a TreatmentEffectBuilder instance for the TreatmentEffect class.

        Returns:
            TreatmentEffectBuilder: A TreatmentEffectBuilder instance for the
                TreatmentEffect class.
        """
        return TreatmentEffectBuilder()

    @staticmethod
    def _set_configuration(
        treatment_effect: TreatmentEffect,
        *,
        treatment: Group,
        outcome: Group,
        patients_group: Group = "patients",
        time_attribute: Optional[MedRecordAttribute] = None,
        washout_period_days: Optional[Dict[Group, int]] = None,
        washout_period_reference: Literal["first", "last"] = "first",
        grace_period_days: int = 0,
        grace_period_reference: Literal["first", "last"] = "last",
        follow_up_period_days: int = _DEFAULT_FOLLOW_UP_PERIOD_DAYS,
        follow_up_period_reference: Literal["first", "last"] = "last",
        outcome_before_treatment_days: Optional[int] = None,
        filter_controls_query: Optional[NodeQuery] = None,
        matching_method: Optional[MatchingMethod] = None,
        matching_essential_covariates: Optional[MedRecordAttributeInputList] = None,
        matching_one_hot_covariates: Optional[MedRecordAttributeInputList] = None,
        matching_model: Model = "logit",
        matching_number_of_neighbors: int = 1,
        matching_hyperparameters: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Sets the configuration for the TreatmentEffect instance.

        Stores every setting on the given instance. The only validation performed
        is that the follow-up period is not shorter than the grace period; this is
        checked before anything is assigned, so a rejected configuration leaves
        the instance untouched. No MedRecord is inspected here. Time-related
        settings given without a time attribute only produce a warning.

        Args:
            treatment_effect (TreatmentEffect): The TreatmentEffect instance to
                configure.
            treatment (Group): The group of treatments to analyze.
            outcome (Group): The group of outcomes to analyze.
            patients_group (Group, optional): The group of patients to analyze.
                Defaults to "patients".
            time_attribute (Optional[MedRecordAttribute], optional): The time
                attribute. If None, the treatment effect analysis is performed in
                a static way (without considering time). Defaults to None.
            washout_period_days (Optional[Dict[Group, int]], optional): The washout
                period in days for each treatment group. In the case of no time
                attribute, it is not applied. Defaults to None, which is stored as
                an empty dictionary.
            washout_period_reference (Literal["first", "last"], optional): The
                reference point for the washout period. Defaults to "first".
            grace_period_days (int, optional): The grace period in days after the
                treatment. Defaults to 0.
            grace_period_reference (Literal["first", "last"], optional): The
                reference point for the grace period. Defaults to "last".
            follow_up_period_days (int, optional): The follow-up period in days
                after the treatment. Defaults to 365000.
            follow_up_period_reference (Literal["first", "last"], optional): The
                reference point for the follow-up period. Defaults to "last".
            outcome_before_treatment_days (Optional[int], optional): The number of
                days before the treatment to consider for outcomes. Defaults to
                None.
            filter_controls_query (Optional[NodeQuery], optional): An optional
                query to filter the control group based on specified criteria.
                Defaults to None.
            matching_method (Optional[MatchingMethod]): The method to match
                treatment and control groups. Defaults to None.
            matching_essential_covariates (Optional[MedRecordAttributeInputList], optional):
                The essential covariates to use for matching. Defaults to
                ["gender", "age"].
            matching_one_hot_covariates (Optional[MedRecordAttributeInputList], optional):
                The one-hot covariates to use for matching. Defaults to
                ["gender"].
            matching_model (Model, optional): The model to use for matching.
                Defaults to "logit".
            matching_number_of_neighbors (int, optional): The number of neighbors
                to match for each treated subject. Defaults to 1.
            matching_hyperparameters (Optional[Dict[str, Any]], optional): The
                hyperparameters for the matching model. Defaults to None.

        Raises:
            ValueError: If the follow-up period is less than the grace period.
        """  # noqa: W505
        # Validate first so a rejected configuration never leaves the instance
        # half-initialised.
        if follow_up_period_days < grace_period_days:
            msg = (
                "The follow-up period must be greater than or equal to the grace period"
            )
            raise ValueError(msg)

        # Fresh containers per call instead of shared mutable defaults.
        if washout_period_days is None:
            washout_period_days = {}
        if matching_essential_covariates is None:
            matching_essential_covariates = ["gender", "age"]
        if matching_one_hot_covariates is None:
            matching_one_hot_covariates = ["gender"]

        treatment_effect._patients_group = patients_group
        treatment_effect._time_attribute = time_attribute

        treatment_effect._treatments_group = treatment
        treatment_effect._outcomes_group = outcome

        treatment_effect._washout_period_days = washout_period_days
        treatment_effect._washout_period_reference = washout_period_reference
        treatment_effect._grace_period_days = grace_period_days
        treatment_effect._grace_period_reference = grace_period_reference
        treatment_effect._follow_up_period_days = follow_up_period_days
        treatment_effect._follow_up_period_reference = follow_up_period_reference
        treatment_effect._outcome_before_treatment_days = outcome_before_treatment_days
        treatment_effect._filter_controls_query = filter_controls_query

        treatment_effect._matching_method = matching_method
        treatment_effect._matching_essential_covariates = matching_essential_covariates
        treatment_effect._matching_one_hot_covariates = matching_one_hot_covariates
        treatment_effect._matching_model = matching_model
        treatment_effect._matching_number_of_neighbors = matching_number_of_neighbors
        treatment_effect._matching_hyperparameters = matching_hyperparameters

        # Compare against None explicitly (as `_query_node_within_time_window`
        # does) so that a falsy attribute key is not mistaken for "unset".
        if time_attribute is not None:
            return

        if washout_period_days:
            logger.warning(
                "Washout period is not applied because the time attribute is not set."
            )

        follow_up_customised = (
            follow_up_period_days != TreatmentEffect._DEFAULT_FOLLOW_UP_PERIOD_DAYS
        )
        if grace_period_days or follow_up_customised or outcome_before_treatment_days:
            msg = (
                "Time attribute is not set, thus the grace period, follow-up "
                "period, and outcome before treatment cannot be applied. The "
                "treatment effect analysis is performed in a static way."
            )
            logger.warning(msg)

    def _find_groups(
        self, medrecord: MedRecord
    ) -> Tuple[Set[NodeIndex], Set[NodeIndex], Set[NodeIndex], Set[NodeIndex]]:
        """Finds the treated and control groups in the MedRecord.

        This method finds the patients in the treated group and the control groups
        and whether they had the outcome or not. The washout period is only applied
        when a time attribute is configured. Patients rejected by the washout
        period or because of an outcome before treatment are excluded from the
        control group as well.

        Args:
            medrecord (MedRecord): An instance of the MedRecord class containing
                patient medical data.

        Returns:
            Tuple[Set[NodeIndex], Set[NodeIndex], Set[NodeIndex], Set[NodeIndex]]:
                A tuple containing the IDs of patients in the treated group who had
                the outcome (treated_outcome_true), the IDs of patients in the
                treated group who did not have the outcome
                (treated_outcome_false), the IDs of patients in the control group
                who had the outcome (control_outcome_true), and the IDs of patients
                in the control group who did not have the outcome
                (control_outcome_false).
        """
        # Find patients that underwent the treatment
        treated_set = self._find_treated_patients(medrecord)

        washout_nodes: Set[NodeIndex] = set()
        if self._time_attribute is not None:
            treated_set, washout_nodes = self._apply_washout_period(
                medrecord, treated_set
            )

        treated_set, treated_outcome_true, outcome_before_treatment_nodes = (
            self._find_outcomes(medrecord, treated_set)
        )
        treated_outcome_false = treated_set - treated_outcome_true

        # Find the controls (patients that did not undergo the treatment)
        control_set = set(medrecord.nodes_in_group(self._patients_group))
        control_outcome_true, control_outcome_false = self._find_controls(
            medrecord=medrecord,
            control_set=control_set,
            treated_set=treated_set,
            rejected_nodes=washout_nodes | outcome_before_treatment_nodes,
            filter_controls_query=self._filter_controls_query,
        )

        return (
            treated_outcome_true,
            treated_outcome_false,
            control_outcome_true,
            control_outcome_false,
        )

    def _find_treated_patients(self, medrecord: MedRecord) -> Set[NodeIndex]:
        """Find the patients that underwent the treatment.

        Args:
            medrecord (MedRecord): An instance of the MedRecord class containing
                patient medical data.

        Returns:
            Set[NodeIndex]: A set of patient nodes that underwent the treatment.

        Raises:
            ValueError: If no patients are found for the treatment groups in the
                MedRecord.
        """

        def query(node: NodeOperand) -> None:
            # Patient nodes with at least one neighbor (in either edge direction)
            # belonging to the treatment group.
            node.in_group(self._patients_group)
            node.neighbors(edge_direction=EdgeDirection.BOTH).in_group(
                self._treatments_group
            )

        # Create the group with all the patients that underwent the treatment
        treated_set = set(medrecord.select_nodes(query))
        if not treated_set:
            msg = "No patients found for the treatment group in this MedRecord"
            raise ValueError(msg)

        return treated_set

    def _check_outcomes_exist(self, medrecord: MedRecord) -> None:
        """Raise if the configured outcome group has no nodes in the MedRecord.

        Args:
            medrecord (MedRecord): An instance of the MedRecord class containing
                patient medical data.

        Raises:
            ValueError: If no outcomes are found in the MedRecord for the specified
                outcome group.
        """
        if not medrecord.nodes_in_group(self._outcomes_group):
            msg = f"No outcomes found in the MedRecord for group {self._outcomes_group}"
            raise ValueError(msg)

    def _find_outcomes(
        self, medrecord: MedRecord, treated_set: Set[NodeIndex]
    ) -> Tuple[Set[NodeIndex], Set[NodeIndex], Set[NodeIndex]]:
        """Find the patients that had the outcome after the treatment.

        If set in the configuration, remove the ones that already had the outcome
        before the treatment. The `treated_set` passed in is not modified; the
        filtered set is returned instead.

        Args:
            medrecord (MedRecord): An instance of the MedRecord class containing
                patient medical data.
            treated_set (Set[NodeIndex]): A set of patient nodes that underwent the
                treatment.

        Returns:
            Tuple[Set[NodeIndex], Set[NodeIndex], Set[NodeIndex]]: A tuple
                containing:
                - The updated set of patient nodes that underwent the treatment.
                - The nodes that had the outcome after the treatment.
                - The nodes that had the outcome before the treatment (to be
                  rejected). Only non-empty if outcome_before_treatment_days and
                  the time attribute are set.

        Raises:
            ValueError: If no outcomes are found in the MedRecord for the specified
                outcome group.
        """
        outcome_before_treatment_nodes: Set[NodeIndex] = set()
        outcome_before_treatment_days = self._outcome_before_treatment_days
        has_time_attribute = self._time_attribute is not None

        # Find nodes with the outcomes
        self._check_outcomes_exist(medrecord)

        if outcome_before_treatment_days and has_time_attribute:
            # Outcomes in the window [first treatment - N days, first treatment].
            outcome_before_treatment_nodes = set(
                medrecord.select_nodes(
                    lambda node: self._query_node_within_time_window(
                        node,
                        treated_set,
                        self._outcomes_group,
                        -outcome_before_treatment_days,
                        0,
                        "first",
                    )
                )
            )
            # Rebind rather than `-=` so the caller's set is left untouched.
            treated_set = treated_set - outcome_before_treatment_nodes

            dropped_num = len(outcome_before_treatment_nodes)
            msg = (
                f"{dropped_num} subject{' was' if dropped_num == 1 else 's were'} "
                f"dropped due to having an outcome before the treatment."
            )
            logger.warning(msg)

        if has_time_attribute:
            # NOTE(review): both bounds of this window are anchored at
            # `_follow_up_period_reference`; `_grace_period_reference` is stored by
            # `_set_configuration` but not read in this class. Confirm intended.
            treated_outcome_true = set(
                medrecord.select_nodes(
                    lambda node: self._query_node_within_time_window(
                        node,
                        treated_set,
                        self._outcomes_group,
                        self._grace_period_days,
                        self._follow_up_period_days,
                        self._follow_up_period_reference,
                    )
                )
            )
        else:
            treated_outcome_true = set(
                medrecord.select_nodes(
                    lambda node: self._query_set_outcome_true(node, treated_set)
                )
            )

        return treated_set, treated_outcome_true, outcome_before_treatment_nodes

    def _apply_washout_period(
        self, medrecord: MedRecord, treated_set: Set[NodeIndex]
    ) -> Tuple[Set[NodeIndex], Set[NodeIndex]]:
        """Apply the washout period to the treatment group.

        For every configured washout group, treated patients with an event of that
        group in the window [reference treatment - washout days, reference
        treatment] are dropped. The `treated_set` passed in is not modified.

        Args:
            medrecord (MedRecord): An instance of the MedRecord class containing
                patient medical data.
            treated_set (Set[NodeIndex]): A set of patient nodes that underwent the
                treatment.

        Returns:
            Tuple[Set[NodeIndex], Set[NodeIndex]]: A tuple containing the updated
                set of patient nodes that underwent the treatment and the nodes
                that were dropped due to the washout period.
        """
        washout_nodes: Set[NodeIndex] = set()
        if not self._washout_period_days:
            return treated_set, washout_nodes

        # Apply the washout period to the treatment group
        # TODO: washout in both directions? We need a List then  # noqa: TD003, TD002
        for washout_group_id, washout_days in self._washout_period_days.items():
            # Loop variables are bound as lambda defaults to avoid late binding.
            washout_nodes.update(
                medrecord.select_nodes(
                    lambda node,
                    group_id=washout_group_id,
                    days=washout_days,
                    treated=treated_set: self._query_node_within_time_window(
                        node,
                        treated,
                        group_id,
                        -days,
                        0,
                        self._washout_period_reference,
                    )
                )
            )

            # Rebind rather than `-=` so the caller's set is left untouched.
            treated_set = treated_set - washout_nodes

        if washout_nodes:
            dropped_num = len(washout_nodes)
            msg = (
                f"{dropped_num} subject{' was' if dropped_num == 1 else 's were'} "
                f"dropped due to having a treatment in the washout period."
            )
            logger.warning(msg)

        return treated_set, washout_nodes

    def _find_controls(
        self,
        medrecord: MedRecord,
        control_set: Set[NodeIndex],
        treated_set: Set[NodeIndex],
        rejected_nodes: Optional[Set[NodeIndex]] = None,
        filter_controls_query: Optional[NodeQuery] = None,
    ) -> Tuple[Set[NodeIndex], Set[NodeIndex]]:
        """Identifies control patients based on specified criteria.

        It takes the control group and removes the rejected nodes, the treated
        nodes, and applies the filter_controls_query if specified. Control groups
        are divided into those who had the outcome (control_outcome_true) and those
        who did not (control_outcome_false), based on the presence of the specified
        outcome codes.

        Args:
            medrecord (MedRecord): An instance of the MedRecord class containing
                patient medical data.
            control_set (Set[NodeIndex]): The candidate control patient nodes; the
                treated and rejected nodes are removed from it here.
            treated_set (Set[NodeIndex]): A set of patient nodes that underwent the
                treatment.
            rejected_nodes (Optional[Set[NodeIndex]], optional): A set of patient
                nodes that were rejected due to the washout period or outcome
                before treatment. Defaults to None (nothing rejected).
            filter_controls_query (Optional[NodeQuery], optional): An optional
                query to filter the control group based on specified criteria.
                Defaults to None.

        Returns:
            Tuple[Set[NodeIndex], Set[NodeIndex]]: Two sets representing the IDs of
                control patients. The first set includes patients who experienced
                the specified outcomes (control_outcome_true), and the second set
                includes patients who did not experience the outcomes
                (control_outcome_false).

        Raises:
            ValueError: If no patients are found for the control groups in the
                MedRecord.
            ValueError: If no outcomes are found in the MedRecord for the specified
                outcome group.
        """
        if rejected_nodes is None:
            rejected_nodes = set()

        # Apply the filter to the control group if specified
        if filter_controls_query:
            control_set = (
                set(medrecord.select_nodes(filter_controls_query)) & control_set
            )

        control_set = control_set - treated_set - rejected_nodes
        if not control_set:
            msg = "No patients found for control groups in this MedRecord"
            raise ValueError(msg)

        self._check_outcomes_exist(medrecord)

        # Finding the patients that had the outcome in the control group
        control_outcome_true = set(
            medrecord.select_nodes(
                lambda node: self._query_set_outcome_true(node, control_set)
            )
        )
        control_outcome_false = control_set - control_outcome_true

        return control_outcome_true, control_outcome_false

    def _query_set_outcome_true(self, node: NodeOperand, set: Set[NodeIndex]) -> None:
        """Query for nodes that are in the given set and have the outcome.

        Args:
            node (NodeOperand): The node to query.
            set (Set[NodeIndex]): The set of nodes to query.
        """
        # `set` shadows the builtin; the name is kept so keyword callers still work.
        node.index().is_in(list(set))
        node.neighbors(edge_direction=EdgeDirection.BOTH).in_group(self._outcomes_group)

    @staticmethod
    def _shift_by_days(reference_time: Any, days: int) -> Any:
        """Return a clone of a time operand shifted by a signed number of days.

        Args:
            reference_time (Any): The time operand to clone and shift.
            days (int): The signed offset in days; negative values shift backwards.

        Returns:
            Any: The shifted clone; `reference_time` itself is left unchanged.
        """
        shifted = reference_time.clone()
        # NOTE(review): negative offsets go through subtract() with a positive
        # duration instead of add() with a negative one — presumably the operand
        # API expects non-negative durations; confirm.
        if days < 0:
            shifted.subtract(timedelta(days=-days))
        else:
            shifted.add(timedelta(days=days))
        return shifted

    def _query_node_within_time_window(
        self,
        node: NodeOperand,
        treated_set: Set[NodeIndex],
        outcome_group: Group,
        start_days: int,
        end_days: int,
        reference: Literal["first", "last"],
    ) -> None:
        """Queries for nodes with edges containing time info within a time window.

        It queries for nodes that:
            - Are in the treated group.
            - Have edges with time information.
            - Have edges that connect to the treatment group.
            - Have edges that connect to the outcome group.
            - The time of the outcome is within the specified time window: greater
              than or equal to the reference treatment time plus `start_days`, and
              less than or equal to the reference treatment time plus `end_days`.
              The reference treatment time is the first or last time of treatment,
              depending on `reference`.

        Args:
            node (NodeOperand): The node to query.
            treated_set (Set[NodeIndex]): A set of patient nodes that underwent the
                treatment.
            outcome_group (Group): The group of outcomes to analyze.
            start_days (int): The start of the time window in days relative to the
                reference event. May be negative.
            end_days (int): The end of the time window in days relative to the
                reference event. May be negative.
            reference (Literal["first", "last"]): The reference point for the time
                window.

        Raises:
            ValueError: If the time attribute is not set.
        """
        # Fail before constraining the operand so no partial query is built.
        if self._time_attribute is None:
            msg = "Time attribute is not set."
            raise ValueError(msg)

        node.index().is_in(list(treated_set))

        edges_to_treatment = node.edges()
        edges_to_treatment.attribute(self._time_attribute).is_datetime()
        edges_to_treatment.either_or(
            lambda edge: edge.source_node().in_group(self._treatments_group),
            lambda edge: edge.target_node().in_group(self._treatments_group),
        )

        edges_to_outcome = node.edges()
        edges_to_outcome.attribute(self._time_attribute).is_datetime()
        edges_to_outcome.either_or(
            lambda edge: edge.source_node().in_group(outcome_group),
            lambda edge: edge.target_node().in_group(outcome_group),
        )

        if reference == "first":
            time_of_treatment = edges_to_treatment.attribute(self._time_attribute).min()
        else:
            time_of_treatment = edges_to_treatment.attribute(self._time_attribute).max()

        time_of_outcome = edges_to_outcome.attribute(self._time_attribute)

        min_time_window = self._shift_by_days(time_of_treatment, start_days)
        max_time_window = self._shift_by_days(time_of_treatment, end_days)

        time_of_outcome.greater_than_or_equal_to(min_time_window)
        time_of_outcome.less_than_or_equal_to(max_time_window)

    @property
    def estimate(self) -> Estimate:
        """Creates an Estimate object for the TreatmentEffect instance.

        Returns:
            Estimate: An Estimate object for the current TreatmentEffect instance.
        """
        return Estimate(self)

    @property
    def report(self) -> Report:
        """Creates a Report object for the TreatmentEffect instance.

        Returns:
            Report: A Report object for the current TreatmentEffect instance.
        """
        return Report(self)