opsml.model.interfaces.base
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, cast
from uuid import UUID

import joblib
import numpy as np
import pandas as pd
from pydantic import BaseModel, ConfigDict, field_validator, model_validator

from opsml.helpers.utils import get_class_name
from opsml.types import CommonKwargs, ModelReturn, OnnxModel
from opsml.types.extra import Suffix


def get_processor_name(_class: Optional[Any] = None) -> str:
    """Return the class name of a processor object.

    Args:
        _class:
            Processor instance. Optional.

    Returns:
        The instance's class name, or ``CommonKwargs.UNDEFINED.value`` when
        no instance is given.
    """
    if _class is not None:
        return str(_class.__class__.__name__)

    return CommonKwargs.UNDEFINED.value


def get_model_args(model: Any) -> Tuple[Any, str, List[str]]:
    """Collect basic introspection information about a model object.

    Args:
        model:
            Model instance. Must not be None.

    Returns:
        Tuple of (model, the model's ``__module__``, string forms of the
        model class's direct base classes).

    Raises:
        AssertionError: if ``model`` is None.
    """
    assert model is not None, "Model must not be None"

    model_module = model.__module__
    model_bases = [str(base) for base in model.__class__.__bases__]

    return model, model_module, model_bases


@dataclass
class SamplePrediction:
    """Dataclass that holds sample prediction information

    Args:
        prediction_type:
            Type of prediction
        prediction:
            Sample prediction
    """

    prediction_type: str
    prediction: Any


class ModelInterface(BaseModel):
    """Base interface bundling a model, its sample data and an optional ONNX model.

    The base implementation persists the model and the sample data with joblib.
    """

    model: Optional[Any] = None
    sample_data: Optional[Any] = None
    onnx_model: Optional[OnnxModel] = None
    task_type: str = CommonKwargs.UNDEFINED.value
    model_type: str = CommonKwargs.UNDEFINED.value
    # Filled in by ``check_model`` from the class name of the trimmed sample data.
    data_type: str = CommonKwargs.UNDEFINED.value
    # Empty string, or a string that parses as a UUID (see ``check_modelcard_uid``).
    modelcard_uid: str = ""

    model_config = ConfigDict(
        # Replaces pydantic's default protected namespace so ``model_*`` field
        # names (model, model_type, ...) are allowed.
        protected_namespaces=("protect_",),
        arbitrary_types_allowed=True,
        validate_assignment=False,
        validate_default=True,
        extra="allow",
    )

    @property
    def model_class(self) -> str:
        """Model class identifier; undefined for the base interface."""
        return CommonKwargs.UNDEFINED.value

    @model_validator(mode="before")
    @classmethod
    def check_model(cls, model_args: Dict[str, Any]) -> Dict[str, Any]:
        """Trim sample data to one record and record its type name.

        When a truthy ``modelcard_uid`` is supplied, the raw arguments are
        returned untouched. Otherwise the sample-data entry must be present
        (indexing it raises ``KeyError`` if it is missing).

        Args:
            model_args:
                Raw constructor arguments.

        Returns:
            The (possibly updated) argument mapping.
        """
        if model_args.get("modelcard_uid", False):
            return model_args

        sample_data = cls._get_sample_data(sample_data=model_args[CommonKwargs.SAMPLE_DATA.value])
        model_args[CommonKwargs.SAMPLE_DATA.value] = sample_data
        model_args[CommonKwargs.DATA_TYPE.value] = get_class_name(sample_data)

        return model_args

    @field_validator("modelcard_uid", mode="before")
    @classmethod
    def check_modelcard_uid(cls, modelcard_uid: str) -> str:
        """Validate that a non-empty ``modelcard_uid`` parses as a UUID.

        Raises:
            ValueError: if a non-empty value is not a valid UUID string.
        """
        # empty strings are falsey
        if not modelcard_uid:
            return modelcard_uid

        try:
            # we use uuid4. NOTE: ``version=4`` only overrides the version bits
            # of the parsed value; UUIDs of other versions are not rejected.
            UUID(modelcard_uid, version=4)
            return modelcard_uid

        except ValueError as exc:
            raise ValueError("ModelCard uid is not a valid uuid") from exc

    def save_model(self, path: Path) -> None:
        """Saves model to path. Base implementation uses joblib.

        Args:
            path:
                Pathlib object

        Raises:
            AssertionError: if no model is attached.
        """
        assert self.model is not None, "No model detected in interface"
        joblib.dump(self.model, path)

    def load_model(self, path: Path, **kwargs: Any) -> None:
        """Load model from pathlib object

        Args:
            path:
                Pathlib object
            kwargs:
                Additional kwargs (unused by this base implementation)
        """
        self.model = joblib.load(path)

    def save_onnx(self, path: Path) -> ModelReturn:
        """Saves the onnx model, converting the model to onnx first if needed.

        The file is written to ``path`` with its suffix replaced by the ONNX
        suffix, regardless of whether a conversion was needed.

        Args:
            path:
                Path to save

        Returns:
            ModelReturn
        """
        import onnxruntime as rt

        from opsml.model.onnx import _get_onnx_metadata

        # Use one destination for both branches; previously the convert branch
        # wrote to ``path`` as-is while the other branch forced the ONNX suffix.
        save_path = path.with_suffix(Suffix.ONNX.value)

        if self.onnx_model is None:
            self.convert_to_onnx()
            # Check before dereferencing ``.sess`` so a failed conversion gives a clear error.
            assert self.onnx_model is not None, "No onnx model detected in interface"
            sess: rt.InferenceSession = self.onnx_model.sess
            save_path.write_bytes(sess._model_bytes)  # pylint: disable=protected-access

        else:
            self.onnx_model.sess_to_path(save_path)

        metadata = _get_onnx_metadata(self, cast(rt.InferenceSession, self.onnx_model.sess))

        return metadata

    def convert_to_onnx(self, **kwargs: Any) -> None:
        """Converts model to onnx format.

        Does nothing when an onnx model is already attached.

        Args:
            kwargs:
                Unused by this base implementation.
        """
        from opsml.model.onnx import _OnnxModelConverter

        if self.onnx_model is not None:
            return None

        metadata = _OnnxModelConverter(self).convert_model()
        self.onnx_model = metadata.onnx_model

        return None

    def load_onnx_model(self, path: Path) -> None:
        """Load an onnxruntime session from a pathlib object into the onnx model.

        Args:
            path:
                Pathlib object

        Raises:
            AssertionError: if ``self.onnx_model`` has not been set.
        """
        from onnxruntime import InferenceSession

        assert self.onnx_model is not None, "No onnx model detected in interface"
        # Some older onnxruntime releases accept only ``str``/``bytes`` here and
        # reject ``os.PathLike`` objects, so pass a plain string.
        self.onnx_model.sess = InferenceSession(str(path))

    def save_sample_data(self, path: Path) -> None:
        """Serialize and save sample data to path.

        Args:
            path:
                Pathlib object
        """
        joblib.dump(self.sample_data, path)

    def load_sample_data(self, path: Path) -> None:
        """Load serialized sample data from path.

        Args:
            path:
                Pathlib object
        """

        self.sample_data = joblib.load(path)

    @classmethod
    def _get_sample_data(cls, sample_data: Any) -> Any:
        """Check sample data and returns one record to be used
        during type inference and ONNX conversion/validation.

        Args:
            sample_data:
                A sliceable object, or a list/tuple/dict of sliceable objects.

        Returns:
            Sample data with only one record. list, tuple and dict inputs keep
            their container type, and each element is sliced to its first record.
        """

        if isinstance(sample_data, list):
            return [data[0:1] for data in sample_data]

        if isinstance(sample_data, tuple):
            # Must be a real tuple: a bare generator expression is single-use,
            # is not a tuple, and cannot be pickled by joblib.
            return tuple(data[0:1] for data in sample_data)

        if isinstance(sample_data, dict):
            return {key: data[0:1] for key, data in sample_data.items()}

        return sample_data[0:1]

    def get_sample_prediction(self) -> SamplePrediction:
        """Run ``self.model.predict`` on the sample data.

        Dicts are first tried as keyword arguments and lists/tuples as
        positional arguments. If that call raises, the container is passed as
        a single argument instead.

        Returns:
            SamplePrediction with the prediction and its class name.

        Raises:
            AssertionError: if the model or the sample data is missing.
        """
        assert self.model is not None, "Model is not defined"
        assert self.sample_data is not None, "Sample data must be provided"

        if isinstance(self.sample_data, (pd.DataFrame, np.ndarray)):
            prediction = self.model.predict(self.sample_data)

        elif isinstance(self.sample_data, dict):
            try:
                prediction = self.model.predict(**self.sample_data)
            except Exception as _:  # pylint: disable=broad-except
                prediction = self.model.predict(self.sample_data)

        elif isinstance(self.sample_data, (list, tuple)):
            try:
                prediction = self.model.predict(*self.sample_data)
            except Exception as _:  # pylint: disable=broad-except
                prediction = self.model.predict(self.sample_data)

        else:
            prediction = self.model.predict(self.sample_data)

        prediction_type = get_class_name(prediction)

        return SamplePrediction(
            prediction_type,
            prediction,
        )

    @property
    def model_suffix(self) -> str:
        """Returns suffix for storage"""
        return Suffix.JOBLIB.value

    @property
    def data_suffix(self) -> str:
        """Returns suffix for storage"""
        return Suffix.JOBLIB.value

    @staticmethod
    def name() -> str:
        """Returns the interface name (``ModelInterface``)."""
        return ModelInterface.__name__
def
get_processor_name(_class: Optional[Any] = None) -> str:
def
get_model_args(model: Any) -> Tuple[Any, str, List[str]]:
@dataclass
class
SamplePrediction:
@dataclass
class SamplePrediction:
    """Result of running a model on its sample data.

    Args:
        prediction_type:
            Name of the prediction's type
        prediction:
            The prediction value itself
    """

    prediction_type: str
    prediction: Any
Dataclass that holds sample prediction information
Arguments:
- prediction_type: Type of prediction
- prediction: Sample prediction
class
ModelInterface(pydantic.main.BaseModel):
class ModelInterface(BaseModel):
    """Base interface bundling a model, its sample data and an optional ONNX model.

    The base implementation persists the model and the sample data with joblib.
    """

    model: Optional[Any] = None
    sample_data: Optional[Any] = None
    onnx_model: Optional[OnnxModel] = None
    task_type: str = CommonKwargs.UNDEFINED.value
    model_type: str = CommonKwargs.UNDEFINED.value
    # Filled in by ``check_model`` from the class name of the trimmed sample data.
    data_type: str = CommonKwargs.UNDEFINED.value
    # Empty string, or a string that parses as a UUID (see ``check_modelcard_uid``).
    modelcard_uid: str = ""

    model_config = ConfigDict(
        # Replaces pydantic's default protected namespace so ``model_*`` field
        # names (model, model_type, ...) are allowed.
        protected_namespaces=("protect_",),
        arbitrary_types_allowed=True,
        validate_assignment=False,
        validate_default=True,
        extra="allow",
    )

    @property
    def model_class(self) -> str:
        """Model class identifier; undefined for the base interface."""
        return CommonKwargs.UNDEFINED.value

    @model_validator(mode="before")
    @classmethod
    def check_model(cls, model_args: Dict[str, Any]) -> Dict[str, Any]:
        """Trim sample data to one record and record its type name.

        When a truthy ``modelcard_uid`` is supplied, the raw arguments are
        returned untouched. Otherwise the sample-data entry must be present
        (indexing it raises ``KeyError`` if it is missing).

        Args:
            model_args:
                Raw constructor arguments.

        Returns:
            The (possibly updated) argument mapping.
        """
        if model_args.get("modelcard_uid", False):
            return model_args

        sample_data = cls._get_sample_data(sample_data=model_args[CommonKwargs.SAMPLE_DATA.value])
        model_args[CommonKwargs.SAMPLE_DATA.value] = sample_data
        model_args[CommonKwargs.DATA_TYPE.value] = get_class_name(sample_data)

        return model_args

    @field_validator("modelcard_uid", mode="before")
    @classmethod
    def check_modelcard_uid(cls, modelcard_uid: str) -> str:
        """Validate that a non-empty ``modelcard_uid`` parses as a UUID.

        Raises:
            ValueError: if a non-empty value is not a valid UUID string.
        """
        # empty strings are falsey
        if not modelcard_uid:
            return modelcard_uid

        try:
            # we use uuid4. NOTE: ``version=4`` only overrides the version bits
            # of the parsed value; UUIDs of other versions are not rejected.
            UUID(modelcard_uid, version=4)
            return modelcard_uid

        except ValueError as exc:
            raise ValueError("ModelCard uid is not a valid uuid") from exc

    def save_model(self, path: Path) -> None:
        """Saves model to path. Base implementation uses joblib.

        Args:
            path:
                Pathlib object

        Raises:
            AssertionError: if no model is attached.
        """
        assert self.model is not None, "No model detected in interface"
        joblib.dump(self.model, path)

    def load_model(self, path: Path, **kwargs: Any) -> None:
        """Load model from pathlib object

        Args:
            path:
                Pathlib object
            kwargs:
                Additional kwargs (unused by this base implementation)
        """
        self.model = joblib.load(path)

    def save_onnx(self, path: Path) -> ModelReturn:
        """Saves the onnx model, converting the model to onnx first if needed.

        The file is written to ``path`` with its suffix replaced by the ONNX
        suffix, regardless of whether a conversion was needed.

        Args:
            path:
                Path to save

        Returns:
            ModelReturn
        """
        import onnxruntime as rt

        from opsml.model.onnx import _get_onnx_metadata

        # Use one destination for both branches; previously the convert branch
        # wrote to ``path`` as-is while the other branch forced the ONNX suffix.
        save_path = path.with_suffix(Suffix.ONNX.value)

        if self.onnx_model is None:
            self.convert_to_onnx()
            # Check before dereferencing ``.sess`` so a failed conversion gives a clear error.
            assert self.onnx_model is not None, "No onnx model detected in interface"
            sess: rt.InferenceSession = self.onnx_model.sess
            save_path.write_bytes(sess._model_bytes)  # pylint: disable=protected-access

        else:
            self.onnx_model.sess_to_path(save_path)

        metadata = _get_onnx_metadata(self, cast(rt.InferenceSession, self.onnx_model.sess))

        return metadata

    def convert_to_onnx(self, **kwargs: Any) -> None:
        """Converts model to onnx format.

        Does nothing when an onnx model is already attached.

        Args:
            kwargs:
                Unused by this base implementation.
        """
        from opsml.model.onnx import _OnnxModelConverter

        if self.onnx_model is not None:
            return None

        metadata = _OnnxModelConverter(self).convert_model()
        self.onnx_model = metadata.onnx_model

        return None

    def load_onnx_model(self, path: Path) -> None:
        """Load an onnxruntime session from a pathlib object into the onnx model.

        Args:
            path:
                Pathlib object

        Raises:
            AssertionError: if ``self.onnx_model`` has not been set.
        """
        from onnxruntime import InferenceSession

        assert self.onnx_model is not None, "No onnx model detected in interface"
        # Some older onnxruntime releases accept only ``str``/``bytes`` here and
        # reject ``os.PathLike`` objects, so pass a plain string.
        self.onnx_model.sess = InferenceSession(str(path))

    def save_sample_data(self, path: Path) -> None:
        """Serialize and save sample data to path.

        Args:
            path:
                Pathlib object
        """
        joblib.dump(self.sample_data, path)

    def load_sample_data(self, path: Path) -> None:
        """Load serialized sample data from path.

        Args:
            path:
                Pathlib object
        """

        self.sample_data = joblib.load(path)

    @classmethod
    def _get_sample_data(cls, sample_data: Any) -> Any:
        """Check sample data and returns one record to be used
        during type inference and ONNX conversion/validation.

        Args:
            sample_data:
                A sliceable object, or a list/tuple/dict of sliceable objects.

        Returns:
            Sample data with only one record. list, tuple and dict inputs keep
            their container type, and each element is sliced to its first record.
        """

        if isinstance(sample_data, list):
            return [data[0:1] for data in sample_data]

        if isinstance(sample_data, tuple):
            # Must be a real tuple: a bare generator expression is single-use,
            # is not a tuple, and cannot be pickled by joblib.
            return tuple(data[0:1] for data in sample_data)

        if isinstance(sample_data, dict):
            return {key: data[0:1] for key, data in sample_data.items()}

        return sample_data[0:1]

    def get_sample_prediction(self) -> SamplePrediction:
        """Run ``self.model.predict`` on the sample data.

        Dicts are first tried as keyword arguments and lists/tuples as
        positional arguments. If that call raises, the container is passed as
        a single argument instead.

        Returns:
            SamplePrediction with the prediction and its class name.

        Raises:
            AssertionError: if the model or the sample data is missing.
        """
        assert self.model is not None, "Model is not defined"
        assert self.sample_data is not None, "Sample data must be provided"

        if isinstance(self.sample_data, (pd.DataFrame, np.ndarray)):
            prediction = self.model.predict(self.sample_data)

        elif isinstance(self.sample_data, dict):
            try:
                prediction = self.model.predict(**self.sample_data)
            except Exception as _:  # pylint: disable=broad-except
                prediction = self.model.predict(self.sample_data)

        elif isinstance(self.sample_data, (list, tuple)):
            try:
                prediction = self.model.predict(*self.sample_data)
            except Exception as _:  # pylint: disable=broad-except
                prediction = self.model.predict(self.sample_data)

        else:
            prediction = self.model.predict(self.sample_data)

        prediction_type = get_class_name(prediction)

        return SamplePrediction(
            prediction_type,
            prediction,
        )

    @property
    def model_suffix(self) -> str:
        """Returns suffix for storage"""
        return Suffix.JOBLIB.value

    @property
    def data_suffix(self) -> str:
        """Returns suffix for storage"""
        return Suffix.JOBLIB.value

    @staticmethod
    def name() -> str:
        """Returns the interface name (``ModelInterface``)."""
        return ModelInterface.__name__
Usage docs: https://docs.pydantic.dev/2.6/concepts/models/
A base class for creating Pydantic models.
Attributes:
- __class_vars__: The names of classvars defined on the model.
- __private_attributes__: Metadata about the private attributes of the model.
- __signature__: The signature for instantiating the model.
- __pydantic_complete__: Whether model building is completed, or if there are still undefined fields.
- __pydantic_core_schema__: The pydantic-core schema used to build the SchemaValidator and SchemaSerializer.
- __pydantic_custom_init__: Whether the model has a custom `__init__` function.
- __pydantic_decorators__: Metadata containing the decorators defined on the model.
  This replaces `Model.__validators__` and `Model.__root_validators__` from Pydantic V1.
- __pydantic_generic_metadata__: Metadata for generic models; contains data used for a similar purpose to __args__, __origin__, __parameters__ in typing-module generics. May eventually be replaced by these.
- __pydantic_parent_namespace__: Parent namespace of the model, used for automatic rebuilding of models.
- __pydantic_post_init__: The name of the post-init method for the model, if defined.
- __pydantic_root_model__: Whether the model is a `RootModel`.
- __pydantic_serializer__: The pydantic-core SchemaSerializer used to dump instances of the model.
- __pydantic_validator__: The pydantic-core SchemaValidator used to validate instances of the model.
- __pydantic_extra__: An instance attribute with the values of extra fields from validation when `model_config['extra'] == 'allow'`.
- __pydantic_fields_set__: An instance attribute with the names of fields explicitly set.
- __pydantic_private__: Instance attribute with the values of private attributes set on the model instance.
model_config =
{'protected_namespaces': ('protect_',), 'arbitrary_types_allowed': True, 'validate_assignment': False, 'validate_default': True, 'extra': 'allow'}
@model_validator(mode='before')
@classmethod
def
check_model(cls, model_args: Dict[str, Any]) -> Dict[str, Any]:
@model_validator(mode="before")
@classmethod
def check_model(cls, model_args: Dict[str, Any]) -> Dict[str, Any]:
    """Reduce incoming sample data to a single record and store its type name.

    Arguments that carry a truthy ``modelcard_uid`` are passed through
    unchanged. Otherwise the sample-data entry is looked up by key, so a
    missing entry raises ``KeyError``.

    Args:
        model_args:
            Raw constructor arguments.

    Returns:
        The same mapping, updated in place when trimming happened.
    """
    if not model_args.get("modelcard_uid", False):
        trimmed = cls._get_sample_data(sample_data=model_args[CommonKwargs.SAMPLE_DATA.value])
        model_args[CommonKwargs.SAMPLE_DATA.value] = trimmed
        model_args[CommonKwargs.DATA_TYPE.value] = get_class_name(trimmed)
    return model_args
@field_validator('modelcard_uid', mode='before')
@classmethod
def
check_modelcard_uid(cls, modelcard_uid: str) -> str:
@field_validator("modelcard_uid", mode="before")
@classmethod
def check_modelcard_uid(cls, modelcard_uid: str) -> str:
    """Reject a non-empty ``modelcard_uid`` that does not parse as a UUID.

    Falsey values (e.g. the empty-string default) pass through untouched.

    Raises:
        ValueError: if a non-empty value is not a valid UUID string.
    """
    if modelcard_uid:
        try:
            UUID(modelcard_uid, version=4)  # card uids are uuid4
        except ValueError as exc:
            raise ValueError("ModelCard uid is not a valid uuid") from exc
    return modelcard_uid
def
save_model(self, path: pathlib.Path) -> None:
def save_model(self, path: Path) -> None:
    """Persist ``self.model`` to ``path`` with joblib (base implementation).

    Args:
        path:
            Destination file path

    Raises:
        AssertionError: if no model is attached.
    """
    model = self.model
    assert model is not None, "No model detected in interface"
    joblib.dump(model, path)
Saves model to path. The base implementation uses joblib.
Arguments:
- path: Pathlib object
def
load_model(self, path: pathlib.Path, **kwargs: Any) -> None:
def load_model(self, path: Path, **kwargs: Any) -> None:
    """Restore ``self.model`` from a joblib file.

    Args:
        path:
            Source file path
        kwargs:
            Extra options; ignored by this base implementation
    """
    loaded = joblib.load(path)
    self.model = loaded
Load model from pathlib object
Arguments:
- path: Pathlib object
- kwargs: Additional kwargs
def
save_onnx(self, path: pathlib.Path) -> opsml.types.model.ModelReturn:
def save_onnx(self, path: Path) -> ModelReturn:
    """Saves the onnx model, converting the model to onnx first if needed.

    The file is written to ``path`` with its suffix replaced by the ONNX
    suffix, regardless of whether a conversion was needed.

    Args:
        path:
            Path to save

    Returns:
        ModelReturn
    """
    import onnxruntime as rt

    from opsml.model.onnx import _get_onnx_metadata

    # Use one destination for both branches; previously the convert branch
    # wrote to ``path`` as-is while the other branch forced the ONNX suffix.
    save_path = path.with_suffix(Suffix.ONNX.value)

    if self.onnx_model is None:
        self.convert_to_onnx()
        # Check before dereferencing ``.sess`` so a failed conversion gives a clear error.
        assert self.onnx_model is not None, "No onnx model detected in interface"
        sess: rt.InferenceSession = self.onnx_model.sess
        save_path.write_bytes(sess._model_bytes)  # pylint: disable=protected-access

    else:
        self.onnx_model.sess_to_path(save_path)

    metadata = _get_onnx_metadata(self, cast(rt.InferenceSession, self.onnx_model.sess))

    return metadata
Saves the onnx model
Arguments:
- path: Path to save
Returns:
ModelReturn
def
convert_to_onnx(self, **kwargs: pathlib.Path) -> None:
def convert_to_onnx(self, **kwargs: Any) -> None:
    """Converts model to onnx format.

    Does nothing when an onnx model is already attached.

    Args:
        kwargs:
            Unused by this base implementation.
    """
    from opsml.model.onnx import _OnnxModelConverter

    if self.onnx_model is not None:
        return None

    metadata = _OnnxModelConverter(self).convert_model()
    self.onnx_model = metadata.onnx_model

    return None
Converts model to onnx format
def
load_onnx_model(self, path: pathlib.Path) -> None:
def load_onnx_model(self, path: Path) -> None:
    """Load an onnxruntime session from a pathlib object into the onnx model.

    Args:
        path:
            Pathlib object

    Raises:
        AssertionError: if ``self.onnx_model`` has not been set.
    """
    from onnxruntime import InferenceSession

    assert self.onnx_model is not None, "No onnx model detected in interface"
    # Some older onnxruntime releases accept only ``str``/``bytes`` here and
    # reject ``os.PathLike`` objects, so pass a plain string.
    self.onnx_model.sess = InferenceSession(str(path))
Load onnx model from pathlib object
Arguments:
- path: Pathlib object
def
save_sample_data(self, path: pathlib.Path) -> None:
def save_sample_data(self, path: Path) -> None:
    """Write ``self.sample_data`` to ``path`` using joblib.

    Args:
        path:
            Destination file path
    """
    data = self.sample_data
    joblib.dump(data, path)
Serialize and save sample data to path.
Arguments:
- path: Pathlib object
def
load_sample_data(self, path: pathlib.Path) -> None:
def load_sample_data(self, path: Path) -> None:
    """Load serialized sample data from path.

    Args:
        path:
            Pathlib object
    """

    self.sample_data = joblib.load(path)
Load serialized sample data from path.
Arguments:
- path: Pathlib object
def get_sample_prediction(self) -> SamplePrediction:
    """Run ``self.model.predict`` on the stored sample data.

    DataFrames/arrays are passed straight through. Dicts are first tried as
    keyword arguments and lists/tuples as positional arguments. If that call
    raises, the container is passed as a single argument instead.

    Returns:
        SamplePrediction holding the prediction and its class name.

    Raises:
        AssertionError: if the model or the sample data is missing.
    """
    assert self.model is not None, "Model is not defined"
    assert self.sample_data is not None, "Sample data must be provided"

    sample = self.sample_data
    unpackable = not isinstance(sample, (pd.DataFrame, np.ndarray)) and isinstance(
        sample, (dict, list, tuple)
    )

    if unpackable:
        try:
            if isinstance(sample, dict):
                prediction = self.model.predict(**sample)
            else:
                prediction = self.model.predict(*sample)
        except Exception:  # pylint: disable=broad-except
            prediction = self.model.predict(sample)
    else:
        prediction = self.model.predict(sample)

    return SamplePrediction(get_class_name(prediction), prediction)
model_suffix: str
@property
def model_suffix(self) -> str:
    """File suffix used when storing the model (joblib)."""
    joblib_suffix: str = Suffix.JOBLIB.value
    return joblib_suffix
Returns suffix for storage
data_suffix: str
@property
def data_suffix(self) -> str:
    """File suffix used when storing the sample data (joblib)."""
    joblib_suffix: str = Suffix.JOBLIB.value
    return joblib_suffix
Returns suffix for storage
model_fields =
{'model': FieldInfo(annotation=Union[Any, NoneType], required=False), 'sample_data': FieldInfo(annotation=Union[Any, NoneType], required=False), 'onnx_model': FieldInfo(annotation=Union[OnnxModel, NoneType], required=False), 'task_type': FieldInfo(annotation=str, required=False, default='undefined'), 'model_type': FieldInfo(annotation=str, required=False, default='undefined'), 'data_type': FieldInfo(annotation=str, required=False, default='undefined'), 'modelcard_uid': FieldInfo(annotation=str, required=False, default='')}
Inherited Members
- pydantic.main.BaseModel
- BaseModel
- model_extra
- model_fields_set
- model_construct
- model_copy
- model_dump
- model_dump_json
- model_json_schema
- model_parametrized_name
- model_post_init
- model_rebuild
- model_validate
- model_validate_json
- model_validate_strings
- dict
- json
- parse_obj
- parse_raw
- parse_file
- from_orm
- construct
- copy
- schema
- schema_json
- validate
- update_forward_refs