opsml.model.interfaces.base

  1from dataclasses import dataclass
  2from pathlib import Path
  3from typing import Any, Dict, List, Optional, Tuple, cast
  4from uuid import UUID
  5
  6import joblib
  7import numpy as np
  8import pandas as pd
  9from pydantic import BaseModel, ConfigDict, field_validator, model_validator
 10
 11from opsml.helpers.utils import get_class_name
 12from opsml.types import CommonKwargs, ModelReturn, OnnxModel
 13from opsml.types.extra import Suffix
 14
 15
 16def get_processor_name(_class: Optional[Any] = None) -> str:
 17    if _class is not None:
 18        return str(_class.__class__.__name__)
 19
 20    return CommonKwargs.UNDEFINED.value
 21
 22
 23def get_model_args(model: Any) -> Tuple[Any, str, List[str]]:
 24    assert model is not None, "Model must not be None"
 25
 26    model_module = model.__module__
 27    model_bases = [str(base) for base in model.__class__.__bases__]
 28
 29    return model, model_module, model_bases
 30
 31
 32@dataclass
 33class SamplePrediction:
 34    """Dataclass that holds sample prediction information
 35
 36    Args:
 37        prediction_type:
 38            Type of prediction
 39        prediction:
 40            Sample prediction
 41    """
 42
 43    prediction_type: str
 44    prediction: Any
 45
 46
 47class ModelInterface(BaseModel):
 48    model: Optional[Any] = None
 49    sample_data: Optional[Any] = None
 50    onnx_model: Optional[OnnxModel] = None
 51    task_type: str = CommonKwargs.UNDEFINED.value
 52    model_type: str = CommonKwargs.UNDEFINED.value
 53    data_type: str = CommonKwargs.UNDEFINED.value
 54    modelcard_uid: str = ""
 55
 56    model_config = ConfigDict(
 57        protected_namespaces=("protect_",),
 58        arbitrary_types_allowed=True,
 59        validate_assignment=False,
 60        validate_default=True,
 61        extra="allow",
 62    )
 63
 64    @property
 65    def model_class(self) -> str:
 66        return CommonKwargs.UNDEFINED.value
 67
 68    @model_validator(mode="before")
 69    @classmethod
 70    def check_model(cls, model_args: Dict[str, Any]) -> Dict[str, Any]:
 71        if model_args.get("modelcard_uid", False):
 72            return model_args
 73
 74        sample_data = cls._get_sample_data(sample_data=model_args[CommonKwargs.SAMPLE_DATA.value])
 75        model_args[CommonKwargs.SAMPLE_DATA.value] = sample_data
 76        model_args[CommonKwargs.DATA_TYPE.value] = get_class_name(sample_data)
 77
 78        return model_args
 79
 80    @field_validator("modelcard_uid", mode="before")
 81    @classmethod
 82    def check_modelcard_uid(cls, modelcard_uid: str) -> str:
 83        # empty strings are falsey
 84        if not modelcard_uid:
 85            return modelcard_uid
 86
 87        try:
 88            UUID(modelcard_uid, version=4)  # we use uuid4
 89            return modelcard_uid
 90
 91        except ValueError as exc:
 92            raise ValueError("ModelCard uid is not a valid uuid") from exc
 93
 94    def save_model(self, path: Path) -> None:
 95        """Saves model to path. Base implementation use Joblib
 96
 97        Args:
 98            path:
 99                Pathlib object
100        """
101        assert self.model is not None, "No model detected in interface"
102        joblib.dump(self.model, path)
103
104    def load_model(self, path: Path, **kwargs: Any) -> None:
105        """Load model from pathlib object
106
107        Args:
108            path:
109                Pathlib object
110            kwargs:
111                Additional kwargs
112        """
113        self.model = joblib.load(path)
114
115    def save_onnx(self, path: Path) -> ModelReturn:
116        """Saves the onnx model
117
118        Args:
119            path:
120                Path to save
121
122        Returns:
123            ModelReturn
124        """
125        import onnxruntime as rt
126
127        from opsml.model.onnx import _get_onnx_metadata
128
129        if self.onnx_model is None:
130            self.convert_to_onnx()
131            sess: rt.InferenceSession = self.onnx_model.sess
132            path.write_bytes(sess._model_bytes)  # pylint: disable=protected-access
133
134        else:
135            self.onnx_model.sess_to_path(path.with_suffix(Suffix.ONNX.value))
136
137        assert self.onnx_model is not None, "No onnx model detected in interface"
138        metadata = _get_onnx_metadata(self, cast(rt.InferenceSession, self.onnx_model.sess))
139
140        return metadata
141
142    def convert_to_onnx(self, **kwargs: Path) -> None:
143        """Converts model to onnx format"""
144        from opsml.model.onnx import _OnnxModelConverter
145
146        if self.onnx_model is not None:
147            return None
148
149        metadata = _OnnxModelConverter(self).convert_model()
150        self.onnx_model = metadata.onnx_model
151
152        return None
153
154    def load_onnx_model(self, path: Path) -> None:
155        """Load onnx model from pathlib object
156
157        Args:
158            path:
159                Pathlib object
160        """
161        from onnxruntime import InferenceSession
162
163        assert self.onnx_model is not None, "No onnx model detected in interface"
164        self.onnx_model.sess = InferenceSession(path)
165
166    def save_sample_data(self, path: Path) -> None:
167        """Serialized and save sample data to path.
168
169        Args:
170            path:
171                Pathlib object
172        """
173        joblib.dump(self.sample_data, path)
174
175    def load_sample_data(self, path: Path) -> None:
176        """Serialized and save sample data to path.
177
178        Args:
179            path:
180                Pathlib object
181        """
182
183        self.sample_data = joblib.load(path)
184
185    @classmethod
186    def _get_sample_data(cls, sample_data: Any) -> Any:
187        """Check sample data and returns one record to be used
188        during type inference and ONNX conversion/validation.
189
190        Returns:
191            Sample data with only one record
192        """
193
194        if isinstance(sample_data, list):
195            return [data[0:1] for data in sample_data]
196
197        if isinstance(sample_data, tuple):
198            return (data[0:1] for data in sample_data)
199
200        if isinstance(sample_data, dict):
201            return {key: data[0:1] for key, data in sample_data.items()}
202
203        return sample_data[0:1]
204
205    def get_sample_prediction(self) -> SamplePrediction:
206        assert self.model is not None, "Model is not defined"
207        assert self.sample_data is not None, "Sample data must be provided"
208
209        if isinstance(self.sample_data, (pd.DataFrame, np.ndarray)):
210            prediction = self.model.predict(self.sample_data)
211
212        elif isinstance(self.sample_data, dict):
213            try:
214                prediction = self.model.predict(**self.sample_data)
215            except Exception as _:  # pylint: disable=broad-except
216                prediction = self.model.predict(self.sample_data)
217
218        elif isinstance(self.sample_data, (list, tuple)):
219            try:
220                prediction = self.model.predict(*self.sample_data)
221            except Exception as _:  # pylint: disable=broad-except
222                prediction = self.model.predict(self.sample_data)
223
224        else:
225            prediction = self.model.predict(self.sample_data)
226
227        prediction_type = get_class_name(prediction)
228
229        return SamplePrediction(
230            prediction_type,
231            prediction,
232        )
233
234    @property
235    def model_suffix(self) -> str:
236        """Returns suffix for storage"""
237        return Suffix.JOBLIB.value
238
239    @property
240    def data_suffix(self) -> str:
241        """Returns suffix for storage"""
242        return Suffix.JOBLIB.value
243
244    @staticmethod
245    def name() -> str:
246        return ModelInterface.__name__
def get_processor_name(_class: Optional[Any] = None) -> str:
    """Return the class name of the given object.

    Args:
        _class:
            Any object. Despite the parameter name, an instance is expected:
            the name is read from ``_class.__class__``.

    Returns:
        The object's class name, or ``CommonKwargs.UNDEFINED.value`` when no
        object is given.
    """
    if _class is None:
        return CommonKwargs.UNDEFINED.value

    processor_type = _class.__class__
    return str(processor_type.__name__)
def get_model_args(model: Any) -> Tuple[Any, str, List[str]]:
24def get_model_args(model: Any) -> Tuple[Any, str, List[str]]:
25    assert model is not None, "Model must not be None"
26
27    model_module = model.__module__
28    model_bases = [str(base) for base in model.__class__.__bases__]
29
30    return model, model_module, model_bases
@dataclass
class SamplePrediction:
    """Container pairing a sample prediction with the name of its type.

    Args:
        prediction_type:
            Name of the prediction's type
        prediction:
            The prediction value itself
    """

    # type name of `prediction`
    prediction_type: str
    # raw prediction output
    prediction: Any

Dataclass that holds sample prediction information

Arguments:
  • prediction_type: Type of prediction
  • prediction: Sample prediction
SamplePrediction(prediction_type: str, prediction: Any)
prediction_type: str
prediction: Any
class ModelInterface(BaseModel):
    """Base interface that wraps a model together with its sample data.

    The base implementation serializes the model and sample data with joblib
    and delegates ONNX conversion to ``opsml.model.onnx``. Subclasses can
    override the save/load hooks and the suffix properties.

    Attributes:
        model:
            The wrapped model object.
        sample_data:
            Sample input for the model; trimmed to a single record by
            ``check_model`` unless ``modelcard_uid`` is set.
        onnx_model:
            Optional ONNX representation of ``model``.
        task_type:
            Task type label.
        model_type:
            Model type label.
        data_type:
            Class name of ``sample_data``, filled in by ``check_model``
            (assuming ``CommonKwargs.DATA_TYPE.value == "data_type"``).
        modelcard_uid:
            Uid of an associated ModelCard; must parse as a UUID when non-empty.
    """

    model: Optional[Any] = None
    sample_data: Optional[Any] = None
    onnx_model: Optional[OnnxModel] = None
    task_type: str = CommonKwargs.UNDEFINED.value
    model_type: str = CommonKwargs.UNDEFINED.value
    data_type: str = CommonKwargs.UNDEFINED.value
    modelcard_uid: str = ""

    model_config = ConfigDict(
        # replaces pydantic's default ("model_",) so fields like `model_type` are allowed
        protected_namespaces=("protect_",),
        arbitrary_types_allowed=True,
        validate_assignment=False,
        validate_default=True,
        extra="allow",
    )

    @property
    def model_class(self) -> str:
        """Name of the wrapped model's class; undefined in the base interface."""
        return CommonKwargs.UNDEFINED.value

    @model_validator(mode="before")
    @classmethod
    def check_model(cls, model_args: Dict[str, Any]) -> Dict[str, Any]:
        """Trim sample data to a single record and record its type.

        Skipped entirely when a truthy ``modelcard_uid`` is supplied
        (presumably when the interface is re-created from an existing card).

        Args:
            model_args:
                Raw constructor arguments.

        Returns:
            ``model_args`` with sample data trimmed and the data type set.

        Raises:
            KeyError: if the sample-data key is absent from ``model_args``.
        """
        if model_args.get("modelcard_uid", False):
            return model_args

        sample_data = cls._get_sample_data(sample_data=model_args[CommonKwargs.SAMPLE_DATA.value])
        model_args[CommonKwargs.SAMPLE_DATA.value] = sample_data
        model_args[CommonKwargs.DATA_TYPE.value] = get_class_name(sample_data)

        return model_args

    @field_validator("modelcard_uid", mode="before")
    @classmethod
    def check_modelcard_uid(cls, modelcard_uid: str) -> str:
        """Ensure a non-empty modelcard uid parses as a UUID.

        Args:
            modelcard_uid:
                Uid string; an empty string is accepted as-is.

        Returns:
            The unchanged uid.

        Raises:
            ValueError: if the uid is non-empty and not a valid UUID string.
        """
        # empty strings are falsey
        if not modelcard_uid:
            return modelcard_uid

        try:
            UUID(modelcard_uid, version=4)  # we use uuid4
            return modelcard_uid

        except ValueError as exc:
            raise ValueError("ModelCard uid is not a valid uuid") from exc

    def save_model(self, path: Path) -> None:
        """Saves model to path. The base implementation uses joblib.

        Args:
            path:
                Pathlib object
        """
        assert self.model is not None, "No model detected in interface"
        joblib.dump(self.model, path)

    def load_model(self, path: Path, **kwargs: Any) -> None:
        """Load model from pathlib object

        Args:
            path:
                Pathlib object
            kwargs:
                Additional kwargs (unused by the base implementation)
        """
        self.model = joblib.load(path)

    def save_onnx(self, path: Path) -> ModelReturn:
        """Saves the onnx model, converting the model first if needed.

        Args:
            path:
                Path to save

        Returns:
            ModelReturn
        """
        import onnxruntime as rt

        from opsml.model.onnx import _get_onnx_metadata

        if self.onnx_model is None:
            self.convert_to_onnx()
            # fail with a clear message instead of an AttributeError on `.sess`
            assert self.onnx_model is not None, "No onnx model detected in interface"
            sess: rt.InferenceSession = self.onnx_model.sess
            # NOTE(review): this branch writes to `path` as given, while the branch below
            # applies the ONNX suffix -- confirm which path the caller expects.
            path.write_bytes(sess._model_bytes)  # pylint: disable=protected-access

        else:
            self.onnx_model.sess_to_path(path.with_suffix(Suffix.ONNX.value))

        assert self.onnx_model is not None, "No onnx model detected in interface"
        metadata = _get_onnx_metadata(self, cast(rt.InferenceSession, self.onnx_model.sess))

        return metadata

    def convert_to_onnx(self, **kwargs: Any) -> None:
        """Converts model to onnx format. No-op if an onnx model already exists.

        Args:
            kwargs:
                Additional kwargs (unused by the base implementation)
        """
        from opsml.model.onnx import _OnnxModelConverter

        if self.onnx_model is not None:
            return None

        metadata = _OnnxModelConverter(self).convert_model()
        self.onnx_model = metadata.onnx_model

        return None

    def load_onnx_model(self, path: Path) -> None:
        """Load onnx model from pathlib object into the existing onnx model's session.

        Args:
            path:
                Pathlib object
        """
        from onnxruntime import InferenceSession

        assert self.onnx_model is not None, "No onnx model detected in interface"
        # older onnxruntime releases accept only str/bytes, not PathLike
        self.onnx_model.sess = InferenceSession(str(path))

    def save_sample_data(self, path: Path) -> None:
        """Serialize and save sample data to path.

        Args:
            path:
                Pathlib object
        """
        joblib.dump(self.sample_data, path)

    def load_sample_data(self, path: Path) -> None:
        """Load serialized sample data from path.

        Args:
            path:
                Pathlib object
        """

        self.sample_data = joblib.load(path)

    @classmethod
    def _get_sample_data(cls, sample_data: Any) -> Any:
        """Check sample data and returns one record to be used
        during type inference and ONNX conversion/validation.

        Lists, tuples and dicts are trimmed element-wise; anything else is
        sliced directly.

        Args:
            sample_data:
                Sliceable data, or a list/tuple/dict of sliceable data

        Returns:
            Sample data with only one record, in the same container type
        """

        if isinstance(sample_data, list):
            return [data[0:1] for data in sample_data]

        if isinstance(sample_data, tuple):
            # materialize a real tuple: a bare generator expression is consumed on
            # first use and would not pass isinstance(..., tuple) checks downstream
            return tuple(data[0:1] for data in sample_data)

        if isinstance(sample_data, dict):
            return {key: data[0:1] for key, data in sample_data.items()}

        return sample_data[0:1]

    def get_sample_prediction(self) -> SamplePrediction:
        """Run ``model.predict`` on the sample data.

        DataFrames and ndarrays are passed directly. Dict values are first tried
        as keyword arguments and list/tuple elements as positional arguments;
        if that call raises, the whole container is passed as one argument.

        Returns:
            SamplePrediction holding the prediction and its class name
        """
        assert self.model is not None, "Model is not defined"
        assert self.sample_data is not None, "Sample data must be provided"

        if isinstance(self.sample_data, (pd.DataFrame, np.ndarray)):
            prediction = self.model.predict(self.sample_data)

        elif isinstance(self.sample_data, dict):
            try:
                prediction = self.model.predict(**self.sample_data)
            except Exception as _:  # pylint: disable=broad-except
                prediction = self.model.predict(self.sample_data)

        elif isinstance(self.sample_data, (list, tuple)):
            try:
                prediction = self.model.predict(*self.sample_data)
            except Exception as _:  # pylint: disable=broad-except
                prediction = self.model.predict(self.sample_data)

        else:
            prediction = self.model.predict(self.sample_data)

        prediction_type = get_class_name(prediction)

        return SamplePrediction(
            prediction_type,
            prediction,
        )

    @property
    def model_suffix(self) -> str:
        """Returns suffix for storage"""
        return Suffix.JOBLIB.value

    @property
    def data_suffix(self) -> str:
        """Returns suffix for storage"""
        return Suffix.JOBLIB.value

    @staticmethod
    def name() -> str:
        """Returns the interface name"""
        return ModelInterface.__name__

Usage docs: https://docs.pydantic.dev/2.6/concepts/models/

A base class for creating Pydantic models.

Attributes:
  • __class_vars__: The names of classvars defined on the model.
  • __private_attributes__: Metadata about the private attributes of the model.
  • __signature__: The signature for instantiating the model.
  • __pydantic_complete__: Whether model building is completed, or if there are still undefined fields.
  • __pydantic_core_schema__: The pydantic-core schema used to build the SchemaValidator and SchemaSerializer.
  • __pydantic_custom_init__: Whether the model has a custom __init__ function.
  • __pydantic_decorators__: Metadata containing the decorators defined on the model. This replaces Model.__validators__ and Model.__root_validators__ from Pydantic V1.
  • __pydantic_generic_metadata__: Metadata for generic models; contains data used for a similar purpose to __args__, __origin__, __parameters__ in typing-module generics. May eventually be replaced by these.
  • __pydantic_parent_namespace__: Parent namespace of the model, used for automatic rebuilding of models.
  • __pydantic_post_init__: The name of the post-init method for the model, if defined.
  • __pydantic_root_model__: Whether the model is a RootModel.
  • __pydantic_serializer__: The pydantic-core SchemaSerializer used to dump instances of the model.
  • __pydantic_validator__: The pydantic-core SchemaValidator used to validate instances of the model.
  • __pydantic_extra__: An instance attribute with the values of extra fields from validation when model_config['extra'] == 'allow'.
  • __pydantic_fields_set__: An instance attribute with the names of fields explicitly set.
  • __pydantic_private__: Instance attribute with the values of private attributes set on the model instance.
model: Optional[Any]
sample_data: Optional[Any]
onnx_model: Optional[opsml.types.model.OnnxModel]
task_type: str
model_type: str
data_type: str
modelcard_uid: str
model_config = {'protected_namespaces': ('protect_',), 'arbitrary_types_allowed': True, 'validate_assignment': False, 'validate_default': True, 'extra': 'allow'}
model_class: str
@property
def model_class(self) -> str:
    """Name of the wrapped model's class; the base interface reports it as undefined."""
    undefined = CommonKwargs.UNDEFINED.value
    return undefined
@model_validator(mode="before")
@classmethod
def check_model(cls, model_args: Dict[str, Any]) -> Dict[str, Any]:
    """Trim sample data to a single record and record its type.

    Nothing is changed when a truthy ``modelcard_uid`` is supplied.

    Args:
        model_args:
            Raw constructor arguments.

    Returns:
        ``model_args``, updated in place.

    Raises:
        KeyError: if the sample-data key is absent from ``model_args``.
    """
    if not model_args.get("modelcard_uid", False):
        trimmed = cls._get_sample_data(sample_data=model_args[CommonKwargs.SAMPLE_DATA.value])
        model_args[CommonKwargs.SAMPLE_DATA.value] = trimmed
        model_args[CommonKwargs.DATA_TYPE.value] = get_class_name(trimmed)
    return model_args
@field_validator("modelcard_uid", mode="before")
@classmethod
def check_modelcard_uid(cls, modelcard_uid: str) -> str:
    """Ensure a non-empty modelcard uid parses as a UUID.

    Args:
        modelcard_uid:
            Uid string; empty values pass through untouched.

    Returns:
        The unchanged uid.

    Raises:
        ValueError: if the uid is non-empty and not a valid UUID string.
    """
    if modelcard_uid:
        try:
            UUID(modelcard_uid, version=4)  # uids are generated with uuid4
        except ValueError as exc:
            raise ValueError("ModelCard uid is not a valid uuid") from exc
    return modelcard_uid
def save_model(self, path: Path) -> None:
    """Persist the wrapped model to ``path`` with joblib.

    Args:
        path:
            Pathlib object
    """
    model = self.model
    assert model is not None, "No model detected in interface"
    joblib.dump(model, path)

Saves model to path. The base implementation uses joblib.

Arguments:
  • path: Pathlib object
def load_model(self, path: Path, **kwargs: Any) -> None:
    """Restore the wrapped model from a joblib file.

    Args:
        path:
            Pathlib object
        kwargs:
            Additional kwargs (not used here)
    """
    loaded = joblib.load(path)
    self.model = loaded

Load model from pathlib object

Arguments:
  • path: Pathlib object
  • kwargs: Additional kwargs
def save_onnx(self, path: Path) -> ModelReturn:
    """Saves the onnx model, converting the model first if needed.

    Args:
        path:
            Path to save

    Returns:
        ModelReturn
    """
    import onnxruntime as rt

    from opsml.model.onnx import _get_onnx_metadata

    if self.onnx_model is None:
        self.convert_to_onnx()
        # fail with a clear message instead of an AttributeError on `.sess`
        assert self.onnx_model is not None, "No onnx model detected in interface"
        sess: rt.InferenceSession = self.onnx_model.sess
        # NOTE(review): this branch writes to `path` as given, while the branch below
        # applies the ONNX suffix -- confirm which path the caller expects.
        path.write_bytes(sess._model_bytes)  # pylint: disable=protected-access

    else:
        self.onnx_model.sess_to_path(path.with_suffix(Suffix.ONNX.value))

    assert self.onnx_model is not None, "No onnx model detected in interface"
    metadata = _get_onnx_metadata(self, cast(rt.InferenceSession, self.onnx_model.sess))

    return metadata

Saves the onnx model

Arguments:
  • path: Path to save
Returns:

ModelReturn

def convert_to_onnx(self, **kwargs: Any) -> None:
    """Converts model to onnx format. No-op if an onnx model already exists.

    Args:
        kwargs:
            Additional kwargs (unused by the base implementation)
    """
    from opsml.model.onnx import _OnnxModelConverter

    if self.onnx_model is not None:
        return None

    metadata = _OnnxModelConverter(self).convert_model()
    self.onnx_model = metadata.onnx_model

    return None

Converts model to onnx format

def load_onnx_model(self, path: Path) -> None:
    """Load onnx model from pathlib object into the existing onnx model's session.

    Args:
        path:
            Pathlib object
    """
    from onnxruntime import InferenceSession

    assert self.onnx_model is not None, "No onnx model detected in interface"
    # older onnxruntime releases accept only str/bytes, not PathLike
    self.onnx_model.sess = InferenceSession(str(path))

Load onnx model from pathlib object

Arguments:
  • path: Pathlib object
def save_sample_data(self, path: Path) -> None:
    """Serialize the sample data to ``path`` with joblib.

    Args:
        path:
            Pathlib object
    """
    data = self.sample_data
    joblib.dump(data, path)

Serialize and save sample data to path.

Arguments:
  • path: Pathlib object
def load_sample_data(self, path: Path) -> None:
    """Load serialized sample data from path.

    Args:
        path:
            Pathlib object
    """

    self.sample_data = joblib.load(path)

Load serialized sample data from path.

Arguments:
  • path: Pathlib object
def get_sample_prediction(self) -> SamplePrediction:
    """Run ``model.predict`` on the sample data.

    DataFrames and ndarrays are passed directly. Dict values are first tried
    as keyword arguments and list/tuple elements as positional arguments;
    if that call raises, the whole container is passed as one argument.

    Returns:
        SamplePrediction holding the prediction and its class name
    """
    assert self.model is not None, "Model is not defined"
    assert self.sample_data is not None, "Sample data must be provided"

    sample = self.sample_data

    if isinstance(sample, (pd.DataFrame, np.ndarray)):
        prediction = self.model.predict(sample)
    elif isinstance(sample, (dict, list, tuple)):
        is_mapping = isinstance(sample, dict)
        try:
            prediction = self.model.predict(**sample) if is_mapping else self.model.predict(*sample)
        except Exception:  # pylint: disable=broad-except
            # best-effort fallback: hand the container over unchanged
            prediction = self.model.predict(sample)
    else:
        prediction = self.model.predict(sample)

    return SamplePrediction(prediction_type=get_class_name(prediction), prediction=prediction)
model_suffix: str
@property
def model_suffix(self) -> str:
    """File suffix used when storing the model (the JOBLIB suffix here)."""
    suffix = Suffix.JOBLIB.value
    return suffix

Returns suffix for storage

data_suffix: str
@property
def data_suffix(self) -> str:
    """File suffix used when storing sample data (the JOBLIB suffix here)."""
    suffix = Suffix.JOBLIB.value
    return suffix

Returns suffix for storage

@staticmethod
def name() -> str:
    """Return the interface's class name."""
    interface_cls = ModelInterface
    return interface_cls.__name__
model_fields = {'model': FieldInfo(annotation=Union[Any, NoneType], required=False), 'sample_data': FieldInfo(annotation=Union[Any, NoneType], required=False), 'onnx_model': FieldInfo(annotation=Union[OnnxModel, NoneType], required=False), 'task_type': FieldInfo(annotation=str, required=False, default='undefined'), 'model_type': FieldInfo(annotation=str, required=False, default='undefined'), 'data_type': FieldInfo(annotation=str, required=False, default='undefined'), 'modelcard_uid': FieldInfo(annotation=str, required=False, default='')}
model_computed_fields = {}
Inherited Members
pydantic.main.BaseModel
BaseModel
model_extra
model_fields_set
model_construct
model_copy
model_dump
model_dump_json
model_json_schema
model_parametrized_name
model_post_init
model_rebuild
model_validate
model_validate_json
model_validate_strings
dict
json
parse_obj
parse_raw
parse_file
from_orm
construct
copy
schema
schema_json
validate
update_forward_refs