opsml.model.interfaces.pytorch

  1import tempfile
  2from pathlib import Path
  3from typing import Any, Dict, List, Optional, Tuple, Union, cast
  4
  5import joblib
  6from pydantic import model_validator
  7
  8from opsml.helpers.utils import OpsmlImportExceptions, get_class_name
  9from opsml.model.interfaces.base import (
 10    ModelInterface,
 11    SamplePrediction,
 12    get_model_args,
 13    get_processor_name,
 14)
 15from opsml.types import (
 16    CommonKwargs,
 17    ModelReturn,
 18    SaveName,
 19    Suffix,
 20    TorchOnnxArgs,
 21    TorchSaveArgs,
 22    TrainedModelType,
 23)
 24
 25try:
 26    import torch
 27
 28    ValidData = Union[torch.Tensor, Dict[str, torch.Tensor], List[torch.Tensor], Tuple[torch.Tensor]]
 29
 30    class TorchModel(ModelInterface):
 31        """Model interface for Pytorch models.
 32
 33        Args:
 34            model:
 35                Torch model
 36            preprocessor:
 37                Optional preprocessor
 38            sample_data:
 39                Sample data to be used for type inference and ONNX conversion/validation.
 40                This should match exactly what the model expects as input.
 41            save_args:
 42                Optional arguments for saving model. See `TorchSaveArgs` for supported arguments.
 43            task_type:
 44                Task type for model. Defaults to undefined.
 45            model_type:
 46                Optional model type. This is inferred automatically.
 47            preprocessor_name:
 48                Optional preprocessor name. This is inferred automatically if a
 49                preprocessor is provided.
 50            onnx_args:
 51                Optional arguments for ONNX conversion. See `TorchOnnxArgs` for supported arguments.
 52
 53        Returns:
 54        TorchModel
 55        """
 56
 57        model: Optional[torch.nn.Module] = None
 58        sample_data: Optional[
 59            Union[torch.Tensor, Dict[str, torch.Tensor], List[torch.Tensor], Tuple[torch.Tensor]]
 60        ] = None
 61        onnx_args: Optional[TorchOnnxArgs] = None
 62        save_args: TorchSaveArgs = TorchSaveArgs()
 63        preprocessor: Optional[Any] = None
 64        preprocessor_name: str = CommonKwargs.UNDEFINED.value
 65
 66        @property
 67        def model_class(self) -> str:
 68            return TrainedModelType.PYTORCH.value
 69
 70        @classmethod
 71        def _get_sample_data(cls, sample_data: Any) -> Any:
 72            """Check sample data and returns one record to be used
 73            during type inference and ONNX conversion/validation.
 74
 75            Returns:
 76                Sample data with only one record
 77            """
 78            if isinstance(sample_data, torch.Tensor):
 79                return sample_data[0:1]
 80
 81            if isinstance(sample_data, list):
 82                return [data[0:1] for data in sample_data]
 83
 84            if isinstance(sample_data, tuple):
 85                return tuple(data[0:1] for data in sample_data)
 86
 87            if isinstance(sample_data, dict):
 88                sample_dict = {}
 89                for key, value in sample_data.items():
 90                    sample_dict[key] = value[0:1]
 91                return sample_dict
 92
 93            raise ValueError("Provided sample data is not a valid type")
 94
 95        @model_validator(mode="before")
 96        @classmethod
 97        def check_model(cls, model_args: Dict[str, Any]) -> Dict[str, Any]:
 98            model = model_args.get("model")
 99
100            if model_args.get("modelcard_uid", False):
101                return model_args
102
103            model, _, bases = get_model_args(model)
104
105            for base in bases:
106                if "torch" in base:
107                    model_args[CommonKwargs.MODEL_TYPE.value] = model.__class__.__name__
108
109            sample_data = cls._get_sample_data(model_args[CommonKwargs.SAMPLE_DATA.value])
110            model_args[CommonKwargs.SAMPLE_DATA.value] = sample_data
111            model_args[CommonKwargs.DATA_TYPE.value] = get_class_name(sample_data)
112            model_args[CommonKwargs.PREPROCESSOR_NAME.value] = get_processor_name(
113                model_args.get(CommonKwargs.PREPROCESSOR.value),
114            )
115
116            return model_args
117
118        def get_sample_prediction(self) -> SamplePrediction:
119            assert self.model is not None, "Model is not defined"
120            assert self.sample_data is not None, "Sample data must be provided"
121
122            # test dict input
123            if isinstance(self.sample_data, dict):
124                try:
125                    prediction = self.model(**self.sample_data)
126                except Exception as _:  # pylint: disable=broad-except
127                    prediction = self.model(self.sample_data)
128
129            # test list and tuple inputs
130            elif isinstance(self.sample_data, (list, tuple)):
131                try:
132                    prediction = self.model(*self.sample_data)
133                except Exception as _:  # pylint: disable=broad-except
134                    prediction = self.model(self.sample_data)
135
136            # all others
137            else:
138                prediction = self.model(self.sample_data)
139
140            prediction_type = get_class_name(prediction)
141
142            return SamplePrediction(prediction_type, prediction)
143
144        def save_model(self, path: Path) -> None:
145            """Save pytorch model to path
146
147            Args:
148                path:
149                    pathlib object
150            """
151            assert self.model is not None, "No model found"
152
153            if self.save_args.as_state_dict:
154                torch.save(self.model.state_dict(), path)
155            else:
156                torch.save(self.model, path)
157
158        def load_model(self, path: Path, **kwargs: Any) -> None:
159            """Load pytorch model from path
160
161            Args:
162                path:
163                    pathlib object
164                kwargs:
165                    Only the CommonKwargs.MODEL_ARCH entry is read (module that receives the state dict); nothing is forwarded to torch.load
166            """
167            model_arch = kwargs.get(CommonKwargs.MODEL_ARCH.value)
168
169            if model_arch is not None:
170                model_arch.load_state_dict(torch.load(path))
171                model_arch.eval()
172                self.model = model_arch
173
174            else:
175                self.model = torch.load(path)
176
177        def save_onnx(self, path: Path) -> ModelReturn:
178            """Saves an onnx model
179
180            Args:
181                path:
182                    Path to save model to
183
184            Returns:
185                ModelReturn
186            """
187            import onnxruntime as rt
188
189            from opsml.model.onnx import _get_onnx_metadata
190
191            if self.onnx_model is None:
192                self.convert_to_onnx(**{"path": path})
193
194            else:
195                # save onnx model
196                self.onnx_model.sess_to_path(path)
197
198            # no need to save onnx to bytes since it's done during onnx conversion
199            assert self.onnx_model is not None, "No onnx model detected in interface"
200            return _get_onnx_metadata(self, cast(rt.InferenceSession, self.onnx_model.sess))
201
202        def _convert_to_onnx_inplace(self) -> None:
203            """Convert to onnx model using temp dir"""
204            with tempfile.TemporaryDirectory() as tmpdir:
205                lpath = Path(tmpdir) / SaveName.ONNX_MODEL.value
206                onnx_path = lpath.with_suffix(Suffix.ONNX.value)
207                self.convert_to_onnx(**{"path": onnx_path})
208
209        def convert_to_onnx(self, **kwargs: Path) -> None:
210            # import packages for onnx conversion
211            OpsmlImportExceptions.try_torchonnx_imports()
212            if self.onnx_model is not None:
213                return None
214
215            from opsml.model.onnx.torch_converter import _PyTorchOnnxModel
216
217            path: Optional[Path] = kwargs.get("path")
218
219            if path is None:
220                return self._convert_to_onnx_inplace()
221
222            self.onnx_model = _PyTorchOnnxModel(self).convert_to_onnx(path=path)
223            return None
224
225        def save_preprocessor(self, path: Path) -> None:
226            """Saves preprocessor to path (a preprocessor must be present). Base implementation uses Joblib
227
228            Args:
229                path:
230                    Pathlib object
231            """
232            assert self.preprocessor is not None, "No preprocessor detected in interface"
233            joblib.dump(self.preprocessor, path)
234
235        def load_preprocessor(self, path: Path) -> None:
236            """Load preprocessor from pathlib object
237
238            Args:
239                path:
240                    Pathlib object
241            """
242            self.preprocessor = joblib.load(path)
243
244        @property
245        def preprocessor_suffix(self) -> str:
246            """Returns suffix for storage"""
247            return Suffix.JOBLIB.value
248
249        @property
250        def model_suffix(self) -> str:
251            """Returns suffix for storage"""
252            return Suffix.PT.value
253
254        @staticmethod
255        def name() -> str:
256            return TorchModel.__name__
257
258except ModuleNotFoundError:
259    from opsml.model.interfaces.backups import TorchModelNoModule as TorchModel
class TorchModel(opsml.model.interfaces.base.ModelInterface):
 31    class TorchModel(ModelInterface):
 32        """Model interface for Pytorch models.
 33
 34        Args:
 35            model:
 36                Torch model
 37            preprocessor:
 38                Optional preprocessor
 39            sample_data:
 40                Sample data to be used for type inference and ONNX conversion/validation.
 41                This should match exactly what the model expects as input.
 42            save_args:
 43                Optional arguments for saving model. See `TorchSaveArgs` for supported arguments.
 44            task_type:
 45                Task type for model. Defaults to undefined.
 46            model_type:
 47                Optional model type. This is inferred automatically.
 48            preprocessor_name:
 49                Optional preprocessor name. This is inferred automatically if a
 50                preprocessor is provided.
 51            onnx_args:
 52                Optional arguments for ONNX conversion. See `TorchOnnxArgs` for supported arguments.
 53
 54        Returns:
 55        TorchModel
 56        """
 57
 58        model: Optional[torch.nn.Module] = None
 59        sample_data: Optional[
 60            Union[torch.Tensor, Dict[str, torch.Tensor], List[torch.Tensor], Tuple[torch.Tensor]]
 61        ] = None
 62        onnx_args: Optional[TorchOnnxArgs] = None
 63        save_args: TorchSaveArgs = TorchSaveArgs()
 64        preprocessor: Optional[Any] = None
 65        preprocessor_name: str = CommonKwargs.UNDEFINED.value
 66
 67        @property
 68        def model_class(self) -> str:
 69            return TrainedModelType.PYTORCH.value
 70
 71        @classmethod
 72        def _get_sample_data(cls, sample_data: Any) -> Any:
 73            """Check sample data and returns one record to be used
 74            during type inference and ONNX conversion/validation.
 75
 76            Returns:
 77                Sample data with only one record
 78            """
 79            if isinstance(sample_data, torch.Tensor):
 80                return sample_data[0:1]
 81
 82            if isinstance(sample_data, list):
 83                return [data[0:1] for data in sample_data]
 84
 85            if isinstance(sample_data, tuple):
 86                return tuple(data[0:1] for data in sample_data)
 87
 88            if isinstance(sample_data, dict):
 89                sample_dict = {}
 90                for key, value in sample_data.items():
 91                    sample_dict[key] = value[0:1]
 92                return sample_dict
 93
 94            raise ValueError("Provided sample data is not a valid type")
 95
 96        @model_validator(mode="before")
 97        @classmethod
 98        def check_model(cls, model_args: Dict[str, Any]) -> Dict[str, Any]:
 99            model = model_args.get("model")
100
101            if model_args.get("modelcard_uid", False):
102                return model_args
103
104            model, _, bases = get_model_args(model)
105
106            for base in bases:
107                if "torch" in base:
108                    model_args[CommonKwargs.MODEL_TYPE.value] = model.__class__.__name__
109
110            sample_data = cls._get_sample_data(model_args[CommonKwargs.SAMPLE_DATA.value])
111            model_args[CommonKwargs.SAMPLE_DATA.value] = sample_data
112            model_args[CommonKwargs.DATA_TYPE.value] = get_class_name(sample_data)
113            model_args[CommonKwargs.PREPROCESSOR_NAME.value] = get_processor_name(
114                model_args.get(CommonKwargs.PREPROCESSOR.value),
115            )
116
117            return model_args
118
119        def get_sample_prediction(self) -> SamplePrediction:
120            assert self.model is not None, "Model is not defined"
121            assert self.sample_data is not None, "Sample data must be provided"
122
123            # test dict input
124            if isinstance(self.sample_data, dict):
125                try:
126                    prediction = self.model(**self.sample_data)
127                except Exception as _:  # pylint: disable=broad-except
128                    prediction = self.model(self.sample_data)
129
130            # test list and tuple inputs
131            elif isinstance(self.sample_data, (list, tuple)):
132                try:
133                    prediction = self.model(*self.sample_data)
134                except Exception as _:  # pylint: disable=broad-except
135                    prediction = self.model(self.sample_data)
136
137            # all others
138            else:
139                prediction = self.model(self.sample_data)
140
141            prediction_type = get_class_name(prediction)
142
143            return SamplePrediction(prediction_type, prediction)
144
145        def save_model(self, path: Path) -> None:
146            """Save pytorch model to path
147
148            Args:
149                path:
150                    pathlib object
151            """
152            assert self.model is not None, "No model found"
153
154            if self.save_args.as_state_dict:
155                torch.save(self.model.state_dict(), path)
156            else:
157                torch.save(self.model, path)
158
159        def load_model(self, path: Path, **kwargs: Any) -> None:
160            """Load pytorch model from path
161
162            Args:
163                path:
164                    pathlib object
165                kwargs:
166                    Only the CommonKwargs.MODEL_ARCH entry is read (module that receives the state dict); nothing is forwarded to torch.load
167            """
168            model_arch = kwargs.get(CommonKwargs.MODEL_ARCH.value)
169
170            if model_arch is not None:
171                model_arch.load_state_dict(torch.load(path))
172                model_arch.eval()
173                self.model = model_arch
174
175            else:
176                self.model = torch.load(path)
177
178        def save_onnx(self, path: Path) -> ModelReturn:
179            """Saves an onnx model
180
181            Args:
182                path:
183                    Path to save model to
184
185            Returns:
186                ModelReturn
187            """
188            import onnxruntime as rt
189
190            from opsml.model.onnx import _get_onnx_metadata
191
192            if self.onnx_model is None:
193                self.convert_to_onnx(**{"path": path})
194
195            else:
196                # save onnx model
197                self.onnx_model.sess_to_path(path)
198
199            # no need to save onnx to bytes since it's done during onnx conversion
200            assert self.onnx_model is not None, "No onnx model detected in interface"
201            return _get_onnx_metadata(self, cast(rt.InferenceSession, self.onnx_model.sess))
202
203        def _convert_to_onnx_inplace(self) -> None:
204            """Convert to onnx model using temp dir"""
205            with tempfile.TemporaryDirectory() as tmpdir:
206                lpath = Path(tmpdir) / SaveName.ONNX_MODEL.value
207                onnx_path = lpath.with_suffix(Suffix.ONNX.value)
208                self.convert_to_onnx(**{"path": onnx_path})
209
210        def convert_to_onnx(self, **kwargs: Path) -> None:
211            # import packages for onnx conversion
212            OpsmlImportExceptions.try_torchonnx_imports()
213            if self.onnx_model is not None:
214                return None
215
216            from opsml.model.onnx.torch_converter import _PyTorchOnnxModel
217
218            path: Optional[Path] = kwargs.get("path")
219
220            if path is None:
221                return self._convert_to_onnx_inplace()
222
223            self.onnx_model = _PyTorchOnnxModel(self).convert_to_onnx(path=path)
224            return None
225
226        def save_preprocessor(self, path: Path) -> None:
227            """Saves preprocessor to path (a preprocessor must be present). Base implementation uses Joblib
228
229            Args:
230                path:
231                    Pathlib object
232            """
233            assert self.preprocessor is not None, "No preprocessor detected in interface"
234            joblib.dump(self.preprocessor, path)
235
236        def load_preprocessor(self, path: Path) -> None:
237            """Load preprocessor from pathlib object
238
239            Args:
240                path:
241                    Pathlib object
242            """
243            self.preprocessor = joblib.load(path)
244
245        @property
246        def preprocessor_suffix(self) -> str:
247            """Returns suffix for storage"""
248            return Suffix.JOBLIB.value
249
250        @property
251        def model_suffix(self) -> str:
252            """Returns suffix for storage"""
253            return Suffix.PT.value
254
255        @staticmethod
256        def name() -> str:
257            return TorchModel.__name__

Model interface for Pytorch models.

Arguments:
  • model: Torch model
  • preprocessor: Optional preprocessor
  • sample_data: Sample data to be used for type inference and ONNX conversion/validation. This should match exactly what the model expects as input.
  • save_args: Optional arguments for saving model. See TorchSaveArgs for supported arguments.
  • task_type: Task type for model. Defaults to undefined.
  • model_type: Optional model type. This is inferred automatically.
  • preprocessor_name: Optional preprocessor name. This is inferred automatically if a preprocessor is provided.
  • onnx_args: Optional arguments for ONNX conversion. See TorchOnnxArgs for supported arguments.

Returns: TorchModel

model: Optional[torch.nn.modules.module.Module]
sample_data: Union[torch.Tensor, Dict[str, torch.Tensor], List[torch.Tensor], Tuple[torch.Tensor], NoneType]
onnx_args: Optional[opsml.types.model.TorchOnnxArgs]
save_args: opsml.types.model.TorchSaveArgs
preprocessor: Optional[Any]
preprocessor_name: str
model_class: str
67        @property
68        def model_class(self) -> str:
69            return TrainedModelType.PYTORCH.value
@model_validator(mode='before')
@classmethod
def check_model(cls, model_args: Dict[str, Any]) -> Dict[str, Any]:
 96        @model_validator(mode="before")
 97        @classmethod
 98        def check_model(cls, model_args: Dict[str, Any]) -> Dict[str, Any]:
 99            model = model_args.get("model")
100
101            if model_args.get("modelcard_uid", False):
102                return model_args
103
104            model, _, bases = get_model_args(model)
105
106            for base in bases:
107                if "torch" in base:
108                    model_args[CommonKwargs.MODEL_TYPE.value] = model.__class__.__name__
109
110            sample_data = cls._get_sample_data(model_args[CommonKwargs.SAMPLE_DATA.value])
111            model_args[CommonKwargs.SAMPLE_DATA.value] = sample_data
112            model_args[CommonKwargs.DATA_TYPE.value] = get_class_name(sample_data)
113            model_args[CommonKwargs.PREPROCESSOR_NAME.value] = get_processor_name(
114                model_args.get(CommonKwargs.PREPROCESSOR.value),
115            )
116
117            return model_args
def get_sample_prediction(self) -> opsml.model.interfaces.base.SamplePrediction:
119        def get_sample_prediction(self) -> SamplePrediction:
120            assert self.model is not None, "Model is not defined"
121            assert self.sample_data is not None, "Sample data must be provided"
122
123            # test dict input
124            if isinstance(self.sample_data, dict):
125                try:
126                    prediction = self.model(**self.sample_data)
127                except Exception as _:  # pylint: disable=broad-except
128                    prediction = self.model(self.sample_data)
129
130            # test list and tuple inputs
131            elif isinstance(self.sample_data, (list, tuple)):
132                try:
133                    prediction = self.model(*self.sample_data)
134                except Exception as _:  # pylint: disable=broad-except
135                    prediction = self.model(self.sample_data)
136
137            # all others
138            else:
139                prediction = self.model(self.sample_data)
140
141            prediction_type = get_class_name(prediction)
142
143            return SamplePrediction(prediction_type, prediction)
def save_model(self, path: pathlib.Path) -> None:
145        def save_model(self, path: Path) -> None:
146            """Save pytorch model to path
147
148            Args:
149                path:
150                    pathlib object
151            """
152            assert self.model is not None, "No model found"
153
154            if self.save_args.as_state_dict:
155                torch.save(self.model.state_dict(), path)
156            else:
157                torch.save(self.model, path)

Save pytorch model to path

Arguments:
  • path: pathlib object
def load_model(self, path: pathlib.Path, **kwargs: Any) -> None:
159        def load_model(self, path: Path, **kwargs: Any) -> None:
160            """Load pytorch model from path
161
162            Args:
163                path:
164                    pathlib object
165                kwargs:
166                    Only the CommonKwargs.MODEL_ARCH entry is read (module that receives the state dict); nothing is forwarded to torch.load
167            """
168            model_arch = kwargs.get(CommonKwargs.MODEL_ARCH.value)
169
170            if model_arch is not None:
171                model_arch.load_state_dict(torch.load(path))
172                model_arch.eval()
173                self.model = model_arch
174
175            else:
176                self.model = torch.load(path)

Load pytorch model from path

Arguments:
  • path: pathlib object
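  • kwargs: Optional keyword arguments. Only the CommonKwargs.MODEL_ARCH entry is read: if given, the saved state dict is loaded into that module and it is put in eval mode; otherwise the object returned by torch.load(path) becomes the model. Nothing is forwarded to torch.load.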
def save_onnx(self, path: pathlib.Path) -> opsml.types.model.ModelReturn:
178        def save_onnx(self, path: Path) -> ModelReturn:
179            """Saves an onnx model
180
181            Args:
182                path:
183                    Path to save model to
184
185            Returns:
186                ModelReturn
187            """
188            import onnxruntime as rt
189
190            from opsml.model.onnx import _get_onnx_metadata
191
192            if self.onnx_model is None:
193                self.convert_to_onnx(**{"path": path})
194
195            else:
196                # save onnx model
197                self.onnx_model.sess_to_path(path)
198
199            # no need to save onnx to bytes since it's done during onnx conversion
200            assert self.onnx_model is not None, "No onnx model detected in interface"
201            return _get_onnx_metadata(self, cast(rt.InferenceSession, self.onnx_model.sess))

Saves an onnx model

Arguments:
  • path: Path to save model to
Returns:

ModelReturn

def convert_to_onnx(self, **kwargs: pathlib.Path) -> None:
210        def convert_to_onnx(self, **kwargs: Path) -> None:
211            # import packages for onnx conversion
212            OpsmlImportExceptions.try_torchonnx_imports()
213            if self.onnx_model is not None:
214                return None
215
216            from opsml.model.onnx.torch_converter import _PyTorchOnnxModel
217
218            path: Optional[Path] = kwargs.get("path")
219
220            if path is None:
221                return self._convert_to_onnx_inplace()
222
223            self.onnx_model = _PyTorchOnnxModel(self).convert_to_onnx(path=path)
224            return None

Converts model to onnx format

def save_preprocessor(self, path: pathlib.Path) -> None:
226        def save_preprocessor(self, path: Path) -> None:
227            """Saves preprocessor to path (a preprocessor must be present). Base implementation uses Joblib
228
229            Args:
230                path:
231                    Pathlib object
232            """
233            assert self.preprocessor is not None, "No preprocessor detected in interface"
234            joblib.dump(self.preprocessor, path)

Saves preprocessor to path (a preprocessor must be present). Base implementation uses Joblib

Arguments:
  • path: Pathlib object
def load_preprocessor(self, path: pathlib.Path) -> None:
236        def load_preprocessor(self, path: Path) -> None:
237            """Load preprocessor from pathlib object
238
239            Args:
240                path:
241                    Pathlib object
242            """
243            self.preprocessor = joblib.load(path)

Load preprocessor from pathlib object

Arguments:
  • path: Pathlib object
preprocessor_suffix: str
245        @property
246        def preprocessor_suffix(self) -> str:
247            """Returns suffix for storage"""
248            return Suffix.JOBLIB.value

Returns suffix for storage

model_suffix: str
250        @property
251        def model_suffix(self) -> str:
252            """Returns suffix for storage"""
253            return Suffix.PT.value

Returns suffix for storage

@staticmethod
def name() -> str:
255        @staticmethod
256        def name() -> str:
257            return TorchModel.__name__
model_config = {'protected_namespaces': ('protect_',), 'arbitrary_types_allowed': True, 'validate_assignment': False, 'validate_default': True, 'extra': 'allow'}
model_fields = {'model': FieldInfo(annotation=Union[Module, NoneType], required=False), 'sample_data': FieldInfo(annotation=Union[Tensor, Dict[str, Tensor], List[Tensor], Tuple[Tensor], NoneType], required=False), 'onnx_model': FieldInfo(annotation=Union[OnnxModel, NoneType], required=False), 'task_type': FieldInfo(annotation=str, required=False, default='undefined'), 'model_type': FieldInfo(annotation=str, required=False, default='undefined'), 'data_type': FieldInfo(annotation=str, required=False, default='undefined'), 'modelcard_uid': FieldInfo(annotation=str, required=False, default=''), 'onnx_args': FieldInfo(annotation=Union[TorchOnnxArgs, NoneType], required=False), 'save_args': FieldInfo(annotation=TorchSaveArgs, required=False, default=TorchSaveArgs(as_state_dict=False)), 'preprocessor': FieldInfo(annotation=Union[Any, NoneType], required=False), 'preprocessor_name': FieldInfo(annotation=str, required=False, default='undefined')}
model_computed_fields = {}
Inherited Members
pydantic.main.BaseModel
BaseModel
model_extra
model_fields_set
model_construct
model_copy
model_dump
model_dump_json
model_json_schema
model_parametrized_name
model_post_init
model_rebuild
model_validate
model_validate_json
model_validate_strings
dict
json
parse_obj
parse_raw
parse_file
from_orm
construct
copy
schema
schema_json
validate
update_forward_refs
opsml.model.interfaces.base.ModelInterface
onnx_model
task_type
model_type
data_type
modelcard_uid
check_modelcard_uid
load_onnx_model
save_sample_data
load_sample_data
data_suffix