Source code for medmodels.treatment_effect.builder

"""This module contains the TreatmentEffectBuilder class."""

from __future__ import annotations

from typing import TYPE_CHECKING, Any, Dict, Literal, Optional

import medmodels.treatment_effect.treatment_effect as tee

if TYPE_CHECKING:
    from medmodels.medrecord.querying import NodeQuery
    from medmodels.medrecord.types import (
        Group,
        MedRecordAttribute,
        MedRecordAttributeInputList,
    )
    from medmodels.treatment_effect.matching.algorithms.propensity_score import Model
    from medmodels.treatment_effect.matching.matching import MatchingMethod


class TreatmentEffectBuilder:
    """Builder class for the TreatmentEffect object.

    The TreatmentEffectBuilder class is used to build a TreatmentEffect object with
    the desired configurations for the treatment effect estimation using a builder
    pattern.

    By default, it configures a static treatment effect estimation. To configure a
    time-dependent treatment effect estimation, the time_attribute must be set.

    The attributes below are only declared (annotated) on the class; none of them
    has a class-level value. An attribute exists on an instance only after one of
    the builder methods has assigned it, and ``build`` forwards exactly the
    attributes that were assigned. Settings that were never configured are
    therefore left entirely to TreatmentEffect.
    """

    treatment: Group
    outcome: Group
    patients_group: Optional[Group]
    time_attribute: Optional[MedRecordAttribute]

    washout_period_days: Optional[Dict[str, int]]
    washout_period_reference: Optional[Literal["first", "last"]]

    grace_period_days: Optional[int]
    grace_period_reference: Optional[Literal["first", "last"]]

    follow_up_period_days: Optional[int]
    follow_up_period_reference: Optional[Literal["first", "last"]]

    outcome_before_treatment_days: Optional[int]

    filter_controls_query: Optional[NodeQuery]

    matching_method: Optional[MatchingMethod]
    matching_essential_covariates: Optional[MedRecordAttributeInputList]
    matching_one_hot_covariates: Optional[MedRecordAttributeInputList]
    matching_model: Optional[Model]
    matching_number_of_neighbors: Optional[int]
    matching_hyperparameters: Optional[Dict[str, Any]]

    def with_treatment(self, treatment: Group) -> TreatmentEffectBuilder:
        """Sets the treatment group for the treatment effect estimation.

        Args:
            treatment (Group): The treatment group.

        Returns:
            TreatmentEffectBuilder: The current instance of the TreatmentEffectBuilder
                with updated treatment group.
        """
        self.treatment = treatment

        return self

    def with_outcome(self, outcome: Group) -> TreatmentEffectBuilder:
        """Sets the outcome group for the treatment effect estimation.

        Args:
            outcome (Group): The group to be used as the outcome.

        Returns:
            TreatmentEffectBuilder: The current instance of the TreatmentEffectBuilder
                with updated outcome group.
        """
        self.outcome = outcome

        return self

    def with_patients_group(self, group: Group) -> TreatmentEffectBuilder:
        """Sets the group of patients to be used in the treatment effect estimation.

        Args:
            group (Group): The group of patients.

        Returns:
            TreatmentEffectBuilder: The current instance of the TreatmentEffectBuilder
                with updated patients group.
        """
        self.patients_group = group

        return self

    def with_time_attribute(
        self, attribute: MedRecordAttribute
    ) -> TreatmentEffectBuilder:
        """Sets the time attribute to be used in the treatment effect estimation.

        It turns the treatment effect estimation from a static to a time-dependent
        analysis.

        Args:
            attribute (MedRecordAttribute): The time attribute.

        Returns:
            TreatmentEffectBuilder: The current instance of the TreatmentEffectBuilder
                with updated time attribute.
        """
        self.time_attribute = attribute

        return self

    def with_washout_period(
        self,
        days: Optional[Dict[str, int]] = None,
        reference: Optional[Literal["first", "last"]] = None,
    ) -> TreatmentEffectBuilder:
        """Sets the washout period for the treatment effect estimation.

        The washout period is the period of time before the treatment that is not
        considered in the estimation.

        Args:
            days (Optional[Dict[str, int]], optional): The duration of the washout
                period in days. If None, the duration is left as it was.
                Defaults to None.
            reference (Optional[Literal['first', 'last']], optional): The reference
                point for the washout period. Must be either 'first' or 'last'. If
                None, the reference is left as it was. Defaults to None.

        Returns:
            TreatmentEffectBuilder: The current instance of the TreatmentEffectBuilder
                with updated washout period.
        """
        # None means "keep the previous value", so both settings can be updated
        # independently across several calls.
        if days is not None:
            self.washout_period_days = days
        if reference is not None:
            self.washout_period_reference = reference

        return self

    def with_grace_period(
        self,
        days: Optional[int] = None,
        reference: Optional[Literal["first", "last"]] = None,
    ) -> TreatmentEffectBuilder:
        """Sets the grace period for the treatment effect estimation.

        The grace period is the period of time after the treatment that is not
        considered in the estimation.

        Args:
            days (Optional[int], optional): The duration of the grace period in
                days. If None, the duration is left as it was. If it is never set,
                the value used in the estimation is decided by TreatmentEffect.
                Defaults to None.
            reference (Optional[Literal['first', 'last']], optional): The reference
                point for the grace period. Must be either 'first' or 'last'. If
                None, the reference is left as it was. Defaults to None.

        Returns:
            TreatmentEffectBuilder: The current instance of the TreatmentEffectBuilder
                with updated grace period.
        """
        # None means "keep the previous value".
        if days is not None:
            self.grace_period_days = days
        if reference is not None:
            self.grace_period_reference = reference

        return self

    def with_follow_up_period(
        self,
        days: Optional[int] = None,
        reference: Optional[Literal["first", "last"]] = None,
    ) -> TreatmentEffectBuilder:
        """Sets the follow-up period for the treatment effect estimation.

        Args:
            days (Optional[int], optional): The duration of the follow-up period
                in days. If None, the duration is left as it was. If it is never
                set, the value used in the estimation is decided by
                TreatmentEffect. Defaults to None.
            reference (Optional[Literal['first', 'last']], optional): The reference
                point for the follow-up period. Must be either 'first' or 'last'.
                If None, the reference is left as it was. Defaults to None.

        Returns:
            TreatmentEffectBuilder: The current instance of the TreatmentEffectBuilder
                with updated follow-up period.
        """
        # None means "keep the previous value".
        if days is not None:
            self.follow_up_period_days = days
        if reference is not None:
            self.follow_up_period_reference = reference

        return self

    def with_outcome_before_treatment_exclusion(
        self, days: int
    ) -> TreatmentEffectBuilder:
        """Define whether we allow the outcome to exist before the treatment or not.

        The outcome_before_treatment_days parameter is used to set the number of
        days before the treatment that the outcome should not exist. If not set,
        the outcome is allowed to exist before the treatment.

        Args:
            days (int): The number of days before the treatment that the outcome
                should not exist.

        Returns:
            TreatmentEffectBuilder: The current instance of the TreatmentEffectBuilder
                with updated outcome-before-treatment exclusion.
        """
        self.outcome_before_treatment_days = days

        return self

    def filter_controls(self, query: NodeQuery) -> TreatmentEffectBuilder:
        """Filter the control group based on the provided query.

        Args:
            query (NodeQuery): The query to be applied to the control group.

        Returns:
            TreatmentEffectBuilder: The current instance of the TreatmentEffectBuilder
                with updated control-group filter query.
        """
        self.filter_controls_query = query

        return self

    def with_propensity_matching(
        self,
        essential_covariates: Optional[MedRecordAttributeInputList] = None,
        one_hot_covariates: Optional[MedRecordAttributeInputList] = None,
        model: Model = "logit",
        number_of_neighbors: int = 1,
        hyperparameters: Optional[Dict[str, Any]] = None,
    ) -> TreatmentEffectBuilder:
        """Adjust the treatment effect estimate using propensity score matching.

        Every matching attribute of the builder is (re)assigned by this call.

        Args:
            essential_covariates (Optional[MedRecordAttributeInputList], optional):
                Covariates that are essential for matching. Defaults to
                ["gender", "age"].
            one_hot_covariates (Optional[MedRecordAttributeInputList], optional):
                Covariates that are one-hot encoded for matching. Defaults to
                ["gender"].
            model (Model, optional): Model to choose for the matching. Defaults to
                "logit".
            number_of_neighbors (int, optional): Number of neighbors to consider
                for the matching. Defaults to 1.
            hyperparameters (Optional[Dict[str, Any]], optional): Hyperparameters
                for the matching model. Defaults to None.

        Returns:
            TreatmentEffectBuilder: The current instance of the TreatmentEffectBuilder
                with updated matching configurations.
        """
        # Fresh default lists on every call, so builders never share (and never
        # accidentally mutate) the same list object.
        if essential_covariates is None:
            essential_covariates = ["gender", "age"]
        if one_hot_covariates is None:
            one_hot_covariates = ["gender"]

        self.matching_method = "propensity"
        self.matching_essential_covariates = essential_covariates
        self.matching_one_hot_covariates = one_hot_covariates
        self.matching_model = model
        self.matching_number_of_neighbors = number_of_neighbors
        self.matching_hyperparameters = hyperparameters

        return self

    def with_nearest_neighbors_matching(
        self,
        essential_covariates: Optional[MedRecordAttributeInputList] = None,
        one_hot_covariates: Optional[MedRecordAttributeInputList] = None,
        number_of_neighbors: int = 1,
    ) -> TreatmentEffectBuilder:
        """Adjust the treatment effect estimate using nearest neighbors matching.

        Note:
            matching_model and matching_hyperparameters are not modified here. Any
            values assigned by an earlier with_propensity_matching call remain on
            the builder and are passed on by build().

        Args:
            essential_covariates (Optional[MedRecordAttributeInputList], optional):
                Covariates that are essential for matching. Defaults to
                ["gender", "age"].
            one_hot_covariates (Optional[MedRecordAttributeInputList], optional):
                Covariates that are one-hot encoded for matching. Defaults to
                ["gender"].
            number_of_neighbors (int, optional): Number of neighbors to consider
                for the matching. Defaults to 1.

        Returns:
            TreatmentEffectBuilder: The current instance of the TreatmentEffectBuilder
                with updated matching configurations.
        """
        # Fresh default lists on every call (see with_propensity_matching).
        if essential_covariates is None:
            essential_covariates = ["gender", "age"]
        if one_hot_covariates is None:
            one_hot_covariates = ["gender"]

        self.matching_method = "nearest_neighbors"
        self.matching_essential_covariates = essential_covariates
        self.matching_one_hot_covariates = one_hot_covariates
        self.matching_number_of_neighbors = number_of_neighbors

        return self

    def build(self) -> tee.TreatmentEffect:
        """Builds the treatment effect with all the provided configurations.

        The TreatmentEffect instance is allocated with ``__new__``, so its
        ``__init__`` is not run. It is then configured through
        ``TreatmentEffect._set_configuration``, which receives every attribute
        assigned on this builder as a keyword argument.

        Returns:
            tee.TreatmentEffect: treatment effect object
        """
        treatment_effect = tee.TreatmentEffect.__new__(tee.TreatmentEffect)
        # vars(self) holds only the attributes that were explicitly assigned, so
        # unconfigured settings are simply not passed.
        tee.TreatmentEffect._set_configuration(treatment_effect, **vars(self))

        return treatment_effect