Source code for medmodels.medrecord.schema

"""This module contains the schema classes for the medrecord module."""

from __future__ import annotations

from enum import Enum, auto
from typing import TYPE_CHECKING, Dict, Iterator, List, Optional, Tuple, Union, overload

from medmodels._medmodels import (
    PyAttributeDataType,
    PyAttributeType,
    PyGroupSchema,
    PySchema,
)
from medmodels.medrecord.datatype import DataType

if TYPE_CHECKING:
    from medmodels.medrecord.types import Group, MedRecordAttribute


[docs] class AttributeType(Enum): """Enumeration of attribute types.""" Categorical = auto() Continuous = auto() Temporal = auto() @staticmethod def _from_py_attribute_type(py_attribute_type: PyAttributeType) -> AttributeType: """Converts a PyAttributeType to an AttributeType. Args: py_attribute_type (PyAttributeType): The PyAttributeType to convert. Returns: AttributeType: The converted AttributeType. """ if py_attribute_type == PyAttributeType.Categorical: return AttributeType.Categorical if py_attribute_type == PyAttributeType.Continuous: return AttributeType.Continuous if py_attribute_type == PyAttributeType.Temporal: return AttributeType.Temporal return None def _into_py_attribute_type(self) -> PyAttributeType: """Converts an AttributeType to a PyAttributeType. Returns: PyAttributeType: The converted PyAttributeType. """ if self == AttributeType.Categorical: return PyAttributeType.Categorical if self == AttributeType.Continuous: return PyAttributeType.Continuous if self == AttributeType.Temporal: return PyAttributeType.Temporal msg = "Should never be reached" raise NotImplementedError(msg)
[docs] def __repr__(self) -> str: """Returns a string representation of the AttributeType instance. Returns: str: String representation of the attribute type. """ return f"AttributeType.{self.name}"
[docs] def __str__(self) -> str: """Returns a string representation of the AttributeType instance. Returns: str: String representation of the attribute type. """ return self.name
[docs] def __eq__(self, value: object) -> bool: """Compares the AttributeType instance to another object for equality. Args: value (object): The object to compare against. Returns: bool: True if the objects are equal, False otherwise. """ if isinstance(value, PyAttributeType): return self._into_py_attribute_type() == value if isinstance(value, AttributeType): return str(self) == str(value) return False
[docs] class AttributesSchema: """A schema for a collection of attributes.""" _attributes_schema: Dict[ MedRecordAttribute, Tuple[DataType, Optional[AttributeType]] ] def __init__( self, attributes_schema: Dict[ MedRecordAttribute, Tuple[DataType, Optional[AttributeType]] ], ) -> None: """Initializes a new instance of AttributesSchema. Args: attributes_schema (Dict[MedRecordAttribute, Tuple[DataType, Optional[AttributeType]]]): A dictionary mapping MedRecordAttributes to their data types and optional attribute types. """ # noqa: W505 self._attributes_schema = attributes_schema
[docs] def __repr__(self) -> str: """Returns a string representation of the AttributesSchema instance. Returns: str: String representation of the attribute schema. """ return self._attributes_schema.__repr__()
[docs] def __getitem__( self, key: MedRecordAttribute ) -> Tuple[DataType, Optional[AttributeType]]: """Gets the type and optional attribute type for a given MedRecordAttribute. Args: key (MedRecordAttribute): The attribute for which the data type is requested. Returns: Tuple[DataType, Optional[AttributeType]]: The data type and optional attribute type of the given attribute. """ return self._attributes_schema[key]
[docs] def __contains__(self, key: MedRecordAttribute) -> bool: """Checks if a given MedRecordAttribute is in the attributes schema. Args: key (MedRecordAttribute): The attribute to check. Returns: bool: True if the attribute exists in the schema, False otherwise. """ return key in self._attributes_schema
[docs] def __iter__(self) -> Iterator[MedRecordAttribute]: """Returns an iterator over the attributes schema. Returns: Iterator: An iterator over the attribute keys. """ return self._attributes_schema.__iter__()
[docs] def __len__(self) -> int: """Returns the number of attributes in the schema. Returns: int: The number of attributes. """ return len(self._attributes_schema)
[docs] def __eq__(self, value: object) -> bool: """Compares the AttributesSchema instance to another object for equality. Args: value (object): The object to compare against. Returns: bool: True if the objects are equal, False otherwise. """ if not (isinstance(value, (AttributesSchema, dict))): return False attribute_schema = ( value._attributes_schema if isinstance(value, AttributesSchema) else value ) if not attribute_schema.keys() == self._attributes_schema.keys(): return False for key in self._attributes_schema: if ( not isinstance(attribute_schema[key], tuple) or not isinstance( attribute_schema[key][0], type(self._attributes_schema[key][0]) ) or attribute_schema[key][1] != self._attributes_schema[key][1] ): return False return True
[docs] def keys(self): # noqa: ANN201 """Returns the attribute keys in the schema. Returns: KeysView: A view object displaying a list of dictionary's keys. """ return self._attributes_schema.keys()
[docs] def values(self): # noqa: ANN201 """Returns the attribute values in the schema. Returns: ValuesView: A view object displaying a list of dictionary's values. """ return self._attributes_schema.values()
[docs] def items(self): # noqa: ANN201 """Returns the attribute key-value pairs in the schema. Returns: ItemsView: A set-like object providing a view on D's items. """ return self._attributes_schema.items()
@overload def get( self, key: MedRecordAttribute ) -> Optional[Tuple[DataType, Optional[AttributeType]]]: ... @overload def get( self, key: MedRecordAttribute, default: Tuple[DataType, Optional[AttributeType]] ) -> Tuple[DataType, Optional[AttributeType]]: ...
[docs] def get( self, key: MedRecordAttribute, default: Optional[Tuple[DataType, Optional[AttributeType]]] = None, ) -> Optional[Tuple[DataType, Optional[AttributeType]]]: """Gets the data type and optional attribute type for a given attribute. It returns a default value if the attribute is not present. Args: key (MedRecordAttribute): The attribute for which the data type is requested. default (Optional[Tuple[DataType, Optional[AttributeType]]], optional): The default data type and attribute type to return if the attribute is not found. Defaults to None. Returns: Optional[Tuple[DataType, Optional[AttributeType]]]: The data type and optional attribute type of the given attribute or the default value. """ return self._attributes_schema.get(key, default)
[docs] class GroupSchema: """A schema for a group of nodes and edges.""" _group_schema: PyGroupSchema def __init__( self, *, nodes: Optional[ Dict[MedRecordAttribute, Union[DataType, Tuple[DataType, AttributeType]]] ] = None, edges: Optional[ Dict[MedRecordAttribute, Union[DataType, Tuple[DataType, AttributeType]]] ] = None, strict: bool = False, ) -> None: """Initializes a new instance of GroupSchema. Args: nodes (Dict[MedRecordAttribute, Union[DataType, Tuple[DataType, AttributeType]]]): A dictionary mapping node attributes to their data types and optional attribute types. Defaults to an empty dictionary. edges (Dict[MedRecordAttribute, Union[DataType, Tuple[DataType, AttributeType]]]): A dictionary mapping edge attributes to their data types and optional attribute types. Defaults to an empty dictionary. strict (bool, optional): Indicates whether the schema should be strict. Defaults to False. """ # noqa: W505 if edges is None: edges = {} if nodes is None: nodes = {} def _convert_input( input: Union[DataType, Tuple[DataType, AttributeType]], ) -> PyAttributeDataType: if isinstance(input, tuple): return PyAttributeDataType( input[0]._inner(), input[1]._into_py_attribute_type() ) return PyAttributeDataType(input._inner(), None) self._group_schema = PyGroupSchema( nodes={x: _convert_input(nodes[x]) for x in nodes}, edges={x: _convert_input(edges[x]) for x in edges}, strict=strict, ) @classmethod def _from_pygroupschema(cls, group_schema: PyGroupSchema) -> GroupSchema: """Creates a GroupSchema instance from an existing PyGroupSchema. Args: group_schema (PyGroupSchema): The PyGroupSchema instance to convert. Returns: GroupSchema: A new GroupSchema instance. """ new_group_schema = cls() new_group_schema._group_schema = group_schema return new_group_schema @property def nodes(self) -> AttributesSchema: """Returns the node attributes in the GroupSchema instance. Returns: AttributesSchema: An AttributesSchema object containing the node attributes and their data types. """ def _convert_node( input: PyAttributeDataType, ) -> Tuple[DataType, Optional[AttributeType]]: return ( DataType._from_py_data_type(input.data_type), AttributeType._from_py_attribute_type(input.attribute_type) if input.attribute_type is not None else None, ) return AttributesSchema( { x: _convert_node(self._group_schema.nodes[x]) for x in self._group_schema.nodes } ) @property def edges(self) -> AttributesSchema: """Returns the edge attributes in the GroupSchema instance. Returns: AttributesSchema: An AttributesSchema object containing the edge attributes and their data types. """ def _convert_edge( input: PyAttributeDataType, ) -> Tuple[DataType, Optional[AttributeType]]: return ( DataType._from_py_data_type(input.data_type), AttributeType._from_py_attribute_type(input.attribute_type) if input.attribute_type is not None else None, ) return AttributesSchema( { x: _convert_edge(self._group_schema.edges[x]) for x in self._group_schema.edges } ) @property def strict(self) -> Optional[bool]: """Indicates whether the GroupSchema instance is strict. Returns: Optional[bool]: True if the schema is strict, False otherwise. """ return self._group_schema.strict
[docs] class Schema: """A schema for a collection of groups.""" _schema: PySchema def __init__( self, *, groups: Optional[Dict[Group, GroupSchema]] = None, default: Optional[GroupSchema] = None, strict: bool = False, ) -> None: """Initializes a new instance of Schema. Args: groups (Dict[Group, GroupSchema], optional): A dictionary of group names to their schemas. Defaults to an empty dictionary. default (Optional[GroupSchema], optional): The default group schema. Defaults to None. strict (bool, optional): Indicates whether the schema should be strict. Defaults to False. """ if groups is None: groups = {} if default is not None: self._schema = PySchema( groups={x: groups[x]._group_schema for x in groups}, default=default._group_schema, strict=strict, ) else: self._schema = PySchema( groups={x: groups[x]._group_schema for x in groups}, strict=strict, ) @classmethod def _from_py_schema(cls, schema: PySchema) -> Schema: """Creates a Schema instance from an existing PySchema. Args: schema (PySchema): The PySchema instance to convert. Returns: Schema: A new Schema instance. """ new_schema = cls() new_schema._schema = schema return new_schema @property def groups(self) -> List[Group]: """Lists all the groups in the Schema instance. Returns: List[Group]: A list of groups. """ return self._schema.groups
[docs] def group(self, group: Group) -> GroupSchema: """Retrieves the schema for a specific group. Args: group (Group): The name of the group. Returns: GroupSchema: The schema for the specified group. Raises: ValueError: If the group does not exist in the schema. """ # noqa: DOC502 return GroupSchema._from_pygroupschema(self._schema.group(group))
@property def default(self) -> Optional[GroupSchema]: """Retrieves the default group schema. Returns: Optional[GroupSchema]: The default group schema if it exists, otherwise None. """ if self._schema.default is None: return None return GroupSchema._from_pygroupschema(self._schema.default) @property def strict(self) -> Optional[bool]: """Indicates whether the Schema instance is strict. Returns: Optional[bool]: True if the schema is strict, False otherwise. """ return self._schema.strict