opsml.model.interfaces.pytorch
import tempfile
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union, cast

import joblib
from pydantic import model_validator

from opsml.helpers.utils import OpsmlImportExceptions, get_class_name
from opsml.model.interfaces.base import (
    ModelInterface,
    SamplePrediction,
    get_model_args,
    get_processor_name,
)
from opsml.types import (
    CommonKwargs,
    ModelReturn,
    SaveName,
    Suffix,
    TorchOnnxArgs,
    TorchSaveArgs,
    TrainedModelType,
)

try:
    import torch

    ValidData = Union[torch.Tensor, Dict[str, torch.Tensor], List[torch.Tensor], Tuple[torch.Tensor]]

    class TorchModel(ModelInterface):
        """Model interface for Pytorch models.

        Args:
            model:
                Torch model
            preprocessor:
                Optional preprocessor
            sample_data:
                Sample data to be used for type inference and ONNX conversion/validation.
                This should match exactly what the model expects as input.
            save_args:
                Optional arguments for saving model. See `TorchSaveArgs` for supported arguments.
            task_type:
                Task type for model. Defaults to undefined.
            model_type:
                Optional model type. This is inferred automatically.
            preprocessor_name:
                Optional preprocessor name. This is inferred automatically if a
                preprocessor is provided.
            onnx_args:
                Optional arguments for ONNX conversion. See `TorchOnnxArgs` for supported arguments.

        Returns:
            TorchModel
        """

        model: Optional[torch.nn.Module] = None
        sample_data: Optional[ValidData] = None
        onnx_args: Optional[TorchOnnxArgs] = None
        save_args: TorchSaveArgs = TorchSaveArgs()
        preprocessor: Optional[Any] = None
        preprocessor_name: str = CommonKwargs.UNDEFINED.value

        @property
        def model_class(self) -> str:
            """Returns the trained-model type value used for PyTorch models"""
            return TrainedModelType.PYTORCH.value

        @classmethod
        def _get_sample_data(cls, sample_data: Any) -> Any:
            """Check sample data and returns one record to be used
            during type inference and ONNX conversion/validation.

            Args:
                sample_data:
                    A tensor, or a list, tuple or dict of sliceable values.

            Returns:
                Sample data with only one record (each value sliced with `[0:1]`)

            Raises:
                ValueError: if `sample_data` is none of the supported container types
            """
            if isinstance(sample_data, torch.Tensor):
                return sample_data[0:1]

            if isinstance(sample_data, list):
                return [data[0:1] for data in sample_data]

            if isinstance(sample_data, tuple):
                return tuple(data[0:1] for data in sample_data)

            if isinstance(sample_data, dict):
                return {key: value[0:1] for key, value in sample_data.items()}

            raise ValueError("Provided sample data is not a valid type")

        @model_validator(mode="before")
        @classmethod
        def check_model(cls, model_args: Dict[str, Any]) -> Dict[str, Any]:
            """Fill in inferred fields before pydantic validation runs.

            Args:
                model_args:
                    Raw keyword arguments supplied to the interface.

            Returns:
                `model_args`, unchanged when `modelcard_uid` is truthy, otherwise
                updated with model type, one-record sample data, data type and
                preprocessor name.

            Raises:
                ValueError: if sample data is missing or not a supported type
            """
            if model_args.get("modelcard_uid", False):
                return model_args

            # NOTE(review): get_model_args is defined elsewhere; only the first and
            # third elements of its result are used here.
            model, _, bases = get_model_args(model_args.get("model"))

            if any("torch" in base for base in bases):
                model_args[CommonKwargs.MODEL_TYPE.value] = model.__class__.__name__

            # .get() so that omitted sample data is reported by _get_sample_data as a
            # ValueError rather than leaking a bare KeyError out of the validator.
            sample_data = cls._get_sample_data(model_args.get(CommonKwargs.SAMPLE_DATA.value))
            model_args[CommonKwargs.SAMPLE_DATA.value] = sample_data
            model_args[CommonKwargs.DATA_TYPE.value] = get_class_name(sample_data)
            model_args[CommonKwargs.PREPROCESSOR_NAME.value] = get_processor_name(
                model_args.get(CommonKwargs.PREPROCESSOR.value),
            )

            return model_args

        def get_sample_prediction(self) -> SamplePrediction:
            """Run the model on the stored sample data.

            Dict samples are first passed as keyword arguments and list/tuple samples
            as positional arguments; if that call raises, the sample is passed as a
            single argument instead.

            Returns:
                SamplePrediction built from the prediction's class name and the prediction
            """
            assert self.model is not None, "Model is not defined"
            assert self.sample_data is not None, "Sample data must be provided"

            # test dict input
            if isinstance(self.sample_data, dict):
                try:
                    prediction = self.model(**self.sample_data)
                except Exception as _:  # pylint: disable=broad-except
                    prediction = self.model(self.sample_data)

            # test list and tuple inputs
            elif isinstance(self.sample_data, (list, tuple)):
                try:
                    prediction = self.model(*self.sample_data)
                except Exception as _:  # pylint: disable=broad-except
                    prediction = self.model(self.sample_data)

            # all others
            else:
                prediction = self.model(self.sample_data)

            prediction_type = get_class_name(prediction)

            return SamplePrediction(prediction_type, prediction)

        def save_model(self, path: Path) -> None:
            """Save pytorch model to path

            Saves only the state dict when `save_args.as_state_dict` is set,
            otherwise the whole model object.

            Args:
                path:
                    pathlib object
            """
            assert self.model is not None, "No model found"

            if self.save_args.as_state_dict:
                torch.save(self.model.state_dict(), path)
            else:
                torch.save(self.model, path)

        def load_model(self, path: Path, **kwargs: Any) -> None:
            """Load pytorch model from path

            Args:
                path:
                    pathlib object
                kwargs:
                    Only the entry keyed by `CommonKwargs.MODEL_ARCH` is read. When
                    present, the file is loaded as a state dict into that module,
                    which is then put in eval mode. Nothing is forwarded to
                    `torch.load`.
            """
            model_arch = kwargs.get(CommonKwargs.MODEL_ARCH.value)

            # NOTE(review): torch.load may unpickle arbitrary objects; only load
            # files from a trusted source.
            if model_arch is not None:
                model_arch.load_state_dict(torch.load(path))
                model_arch.eval()
                self.model = model_arch

            else:
                self.model = torch.load(path)

        def save_onnx(self, path: Path) -> ModelReturn:
            """Saves an onnx model

            Converts the model first when no onnx model is attached yet.

            Args:
                path:
                    Path to save model to

            Returns:
                ModelReturn
            """
            import onnxruntime as rt

            from opsml.model.onnx import _get_onnx_metadata

            if self.onnx_model is None:
                self.convert_to_onnx(**{"path": path})

            else:
                # save onnx model
                self.onnx_model.sess_to_path(path)

            # no need to save onnx to bytes since its done during onnx conversion
            assert self.onnx_model is not None, "No onnx model detected in interface"
            return _get_onnx_metadata(self, cast(rt.InferenceSession, self.onnx_model.sess))

        def _convert_to_onnx_inplace(self) -> None:
            """Convert to onnx model using temp dir"""
            with tempfile.TemporaryDirectory() as tmpdir:
                lpath = Path(tmpdir) / SaveName.ONNX_MODEL.value
                onnx_path = lpath.with_suffix(Suffix.ONNX.value)
                self.convert_to_onnx(**{"path": onnx_path})

        def convert_to_onnx(self, **kwargs: Path) -> None:
            """Converts model to onnx format

            Does nothing when an onnx model is already attached. Without a `path`
            keyword the conversion is written to a temporary directory.

            Args:
                kwargs:
                    Optional `path` to write the converted model to
            """
            # import packages for onnx conversion
            OpsmlImportExceptions.try_torchonnx_imports()
            if self.onnx_model is not None:
                return None

            from opsml.model.onnx.torch_converter import _PyTorchOnnxModel

            path: Optional[Path] = kwargs.get("path")

            if path is None:
                return self._convert_to_onnx_inplace()

            self.onnx_model = _PyTorchOnnxModel(self).convert_to_onnx(path=path)
            return None

        def save_preprocessor(self, path: Path) -> None:
            """Saves preprocessor to path if present. Base implementation uses joblib

            Args:
                path:
                    Pathlib object
            """
            assert self.preprocessor is not None, "No preprocessor detected in interface"
            joblib.dump(self.preprocessor, path)

        def load_preprocessor(self, path: Path) -> None:
            """Load preprocessor from pathlib object

            Args:
                path:
                    Pathlib object
            """
            self.preprocessor = joblib.load(path)

        @property
        def preprocessor_suffix(self) -> str:
            """Returns suffix for storage"""
            return Suffix.JOBLIB.value

        @property
        def model_suffix(self) -> str:
            """Returns suffix for storage"""
            return Suffix.PT.value

        @staticmethod
        def name() -> str:
            """Returns the interface class name"""
            return TorchModel.__name__

except ModuleNotFoundError:
    from opsml.model.interfaces.backups import TorchModelNoModule as TorchModel
class TorchModel(ModelInterface):
    """Model interface for Pytorch models.

    Args:
        model:
            Torch model
        preprocessor:
            Optional preprocessor
        sample_data:
            Sample data to be used for type inference and ONNX conversion/validation.
            This should match exactly what the model expects as input.
        save_args:
            Optional arguments for saving model. See `TorchSaveArgs` for supported arguments.
        task_type:
            Task type for model. Defaults to undefined.
        model_type:
            Optional model type. This is inferred automatically.
        preprocessor_name:
            Optional preprocessor name. This is inferred automatically if a
            preprocessor is provided.
        onnx_args:
            Optional arguments for ONNX conversion. See `TorchOnnxArgs` for supported arguments.

    Returns:
        TorchModel
    """

    model: Optional[torch.nn.Module] = None
    sample_data: Optional[
        Union[torch.Tensor, Dict[str, torch.Tensor], List[torch.Tensor], Tuple[torch.Tensor]]
    ] = None
    onnx_args: Optional[TorchOnnxArgs] = None
    save_args: TorchSaveArgs = TorchSaveArgs()
    preprocessor: Optional[Any] = None
    preprocessor_name: str = CommonKwargs.UNDEFINED.value

    @property
    def model_class(self) -> str:
        """Returns the trained-model type value used for PyTorch models"""
        return TrainedModelType.PYTORCH.value

    @classmethod
    def _get_sample_data(cls, sample_data: Any) -> Any:
        """Check sample data and returns one record to be used
        during type inference and ONNX conversion/validation.

        Args:
            sample_data:
                A tensor, or a list, tuple or dict of sliceable values.

        Returns:
            Sample data with only one record (each value sliced with `[0:1]`)

        Raises:
            ValueError: if `sample_data` is none of the supported container types
        """
        if isinstance(sample_data, torch.Tensor):
            return sample_data[0:1]

        if isinstance(sample_data, list):
            return [data[0:1] for data in sample_data]

        if isinstance(sample_data, tuple):
            return tuple(data[0:1] for data in sample_data)

        if isinstance(sample_data, dict):
            return {key: value[0:1] for key, value in sample_data.items()}

        raise ValueError("Provided sample data is not a valid type")

    @model_validator(mode="before")
    @classmethod
    def check_model(cls, model_args: Dict[str, Any]) -> Dict[str, Any]:
        """Fill in inferred fields before pydantic validation runs.

        Args:
            model_args:
                Raw keyword arguments supplied to the interface.

        Returns:
            `model_args`, unchanged when `modelcard_uid` is truthy, otherwise
            updated with model type, one-record sample data, data type and
            preprocessor name.

        Raises:
            ValueError: if sample data is missing or not a supported type
        """
        if model_args.get("modelcard_uid", False):
            return model_args

        # NOTE(review): get_model_args is defined elsewhere; only the first and
        # third elements of its result are used here.
        model, _, bases = get_model_args(model_args.get("model"))

        if any("torch" in base for base in bases):
            model_args[CommonKwargs.MODEL_TYPE.value] = model.__class__.__name__

        # .get() so that omitted sample data is reported by _get_sample_data as a
        # ValueError rather than leaking a bare KeyError out of the validator.
        sample_data = cls._get_sample_data(model_args.get(CommonKwargs.SAMPLE_DATA.value))
        model_args[CommonKwargs.SAMPLE_DATA.value] = sample_data
        model_args[CommonKwargs.DATA_TYPE.value] = get_class_name(sample_data)
        model_args[CommonKwargs.PREPROCESSOR_NAME.value] = get_processor_name(
            model_args.get(CommonKwargs.PREPROCESSOR.value),
        )

        return model_args

    def get_sample_prediction(self) -> SamplePrediction:
        """Run the model on the stored sample data.

        Dict samples are first passed as keyword arguments and list/tuple samples
        as positional arguments; if that call raises, the sample is passed as a
        single argument instead.

        Returns:
            SamplePrediction built from the prediction's class name and the prediction
        """
        assert self.model is not None, "Model is not defined"
        assert self.sample_data is not None, "Sample data must be provided"

        # test dict input
        if isinstance(self.sample_data, dict):
            try:
                prediction = self.model(**self.sample_data)
            except Exception as _:  # pylint: disable=broad-except
                prediction = self.model(self.sample_data)

        # test list and tuple inputs
        elif isinstance(self.sample_data, (list, tuple)):
            try:
                prediction = self.model(*self.sample_data)
            except Exception as _:  # pylint: disable=broad-except
                prediction = self.model(self.sample_data)

        # all others
        else:
            prediction = self.model(self.sample_data)

        prediction_type = get_class_name(prediction)

        return SamplePrediction(prediction_type, prediction)

    def save_model(self, path: Path) -> None:
        """Save pytorch model to path

        Saves only the state dict when `save_args.as_state_dict` is set,
        otherwise the whole model object.

        Args:
            path:
                pathlib object
        """
        assert self.model is not None, "No model found"

        if self.save_args.as_state_dict:
            torch.save(self.model.state_dict(), path)
        else:
            torch.save(self.model, path)

    def load_model(self, path: Path, **kwargs: Any) -> None:
        """Load pytorch model from path

        Args:
            path:
                pathlib object
            kwargs:
                Only the entry keyed by `CommonKwargs.MODEL_ARCH` is read. When
                present, the file is loaded as a state dict into that module,
                which is then put in eval mode. Nothing is forwarded to
                `torch.load`.
        """
        model_arch = kwargs.get(CommonKwargs.MODEL_ARCH.value)

        # NOTE(review): torch.load may unpickle arbitrary objects; only load
        # files from a trusted source.
        if model_arch is not None:
            model_arch.load_state_dict(torch.load(path))
            model_arch.eval()
            self.model = model_arch

        else:
            self.model = torch.load(path)

    def save_onnx(self, path: Path) -> ModelReturn:
        """Saves an onnx model

        Converts the model first when no onnx model is attached yet.

        Args:
            path:
                Path to save model to

        Returns:
            ModelReturn
        """
        import onnxruntime as rt

        from opsml.model.onnx import _get_onnx_metadata

        if self.onnx_model is None:
            self.convert_to_onnx(**{"path": path})

        else:
            # save onnx model
            self.onnx_model.sess_to_path(path)

        # no need to save onnx to bytes since its done during onnx conversion
        assert self.onnx_model is not None, "No onnx model detected in interface"
        return _get_onnx_metadata(self, cast(rt.InferenceSession, self.onnx_model.sess))

    def _convert_to_onnx_inplace(self) -> None:
        """Convert to onnx model using temp dir"""
        with tempfile.TemporaryDirectory() as tmpdir:
            lpath = Path(tmpdir) / SaveName.ONNX_MODEL.value
            onnx_path = lpath.with_suffix(Suffix.ONNX.value)
            self.convert_to_onnx(**{"path": onnx_path})

    def convert_to_onnx(self, **kwargs: Path) -> None:
        """Converts model to onnx format

        Does nothing when an onnx model is already attached. Without a `path`
        keyword the conversion is written to a temporary directory.

        Args:
            kwargs:
                Optional `path` to write the converted model to
        """
        # import packages for onnx conversion
        OpsmlImportExceptions.try_torchonnx_imports()
        if self.onnx_model is not None:
            return None

        from opsml.model.onnx.torch_converter import _PyTorchOnnxModel

        path: Optional[Path] = kwargs.get("path")

        if path is None:
            return self._convert_to_onnx_inplace()

        self.onnx_model = _PyTorchOnnxModel(self).convert_to_onnx(path=path)
        return None

    def save_preprocessor(self, path: Path) -> None:
        """Saves preprocessor to path if present. Base implementation uses joblib

        Args:
            path:
                Pathlib object
        """
        assert self.preprocessor is not None, "No preprocessor detected in interface"
        joblib.dump(self.preprocessor, path)

    def load_preprocessor(self, path: Path) -> None:
        """Load preprocessor from pathlib object

        Args:
            path:
                Pathlib object
        """
        self.preprocessor = joblib.load(path)

    @property
    def preprocessor_suffix(self) -> str:
        """Returns suffix for storage"""
        return Suffix.JOBLIB.value

    @property
    def model_suffix(self) -> str:
        """Returns suffix for storage"""
        return Suffix.PT.value

    @staticmethod
    def name() -> str:
        """Returns the interface class name"""
        return TorchModel.__name__
Model interface for Pytorch models.
Arguments:
- model: Torch model
- preprocessor: Optional preprocessor
- sample_data: Sample data to be used for type inference and ONNX conversion/validation. This should match exactly what the model expects as input.
- save_args: Optional arguments for saving model. See `TorchSaveArgs` for supported arguments.
- task_type: Task type for model. Defaults to undefined.
- model_type: Optional model type. This is inferred automatically.
- preprocessor_name: Optional preprocessor name. This is inferred automatically if a preprocessor is provided.
- onnx_args: Optional arguments for ONNX conversion. See `TorchOnnxArgs` for supported arguments.
Returns: TorchModel
sample_data: Union[torch.Tensor, Dict[str, torch.Tensor], List[torch.Tensor], Tuple[torch.Tensor], NoneType]
@model_validator(mode='before')
@classmethod
def
check_model(cls, model_args: Dict[str, Any]) -> Dict[str, Any]:
@model_validator(mode="before")
@classmethod
def check_model(cls, model_args: Dict[str, Any]) -> Dict[str, Any]:
    """Fill in inferred fields before pydantic validation runs.

    Args:
        model_args:
            Raw keyword arguments supplied to the interface.

    Returns:
        `model_args`, unchanged when `modelcard_uid` is truthy, otherwise
        updated with model type, one-record sample data, data type and
        preprocessor name.

    Raises:
        ValueError: if sample data is missing or not a supported type
    """
    if model_args.get("modelcard_uid", False):
        return model_args

    # NOTE(review): get_model_args is defined elsewhere; only the first and
    # third elements of its result are used here.
    model, _, bases = get_model_args(model_args.get("model"))

    if any("torch" in base for base in bases):
        model_args[CommonKwargs.MODEL_TYPE.value] = model.__class__.__name__

    # .get() so that omitted sample data is reported by _get_sample_data as a
    # ValueError rather than leaking a bare KeyError out of the validator.
    sample_data = cls._get_sample_data(model_args.get(CommonKwargs.SAMPLE_DATA.value))
    model_args[CommonKwargs.SAMPLE_DATA.value] = sample_data
    model_args[CommonKwargs.DATA_TYPE.value] = get_class_name(sample_data)
    model_args[CommonKwargs.PREPROCESSOR_NAME.value] = get_processor_name(
        model_args.get(CommonKwargs.PREPROCESSOR.value),
    )

    return model_args
def get_sample_prediction(self) -> SamplePrediction:
    """Run the model on the stored sample data and wrap the result.

    Dict samples are tried as keyword arguments and list/tuple samples as
    positional arguments; when that call raises, the sample is passed to the
    model as one argument instead. Any other sample is passed as one argument.

    Returns:
        SamplePrediction built from the prediction's class name and the prediction
    """
    assert self.model is not None, "Model is not defined"
    assert self.sample_data is not None, "Sample data must be provided"

    sample = self.sample_data

    # Decide how the sample would be unpacked; None means "do not unpack".
    if isinstance(sample, dict):
        positional, keyword = (), sample
    elif isinstance(sample, (list, tuple)):
        positional, keyword = sample, {}
    else:
        positional, keyword = None, None

    if positional is None:
        prediction = self.model(sample)
    else:
        try:
            prediction = self.model(*positional, **keyword)
        except Exception as _:  # pylint: disable=broad-except
            # unpacked call was rejected, so hand the container over as-is
            prediction = self.model(sample)

    return SamplePrediction(get_class_name(prediction), prediction)
def
save_model(self, path: pathlib.Path) -> None:
def save_model(self, path: Path) -> None:
    """Write the pytorch model to `path` with `torch.save`.

    Only the state dict is written when `save_args.as_state_dict` is set;
    otherwise the whole model object is written.

    Args:
        path:
            pathlib object
    """
    assert self.model is not None, "No model found"

    payload = self.model.state_dict() if self.save_args.as_state_dict else self.model
    torch.save(payload, path)
Save pytorch model to path
Arguments:
- path: pathlib object
def
load_model(self, path: pathlib.Path, **kwargs: Any) -> None:
def load_model(self, path: Path, **kwargs: Any) -> None:
    """Load pytorch model from path

    Args:
        path:
            pathlib object
        kwargs:
            Only the entry keyed by `CommonKwargs.MODEL_ARCH` is read. When
            present, the file is loaded as a state dict into that module,
            which is then put in eval mode. Nothing is forwarded to
            `torch.load`.
    """
    architecture = kwargs.get(CommonKwargs.MODEL_ARCH.value)

    # NOTE(review): torch.load may unpickle arbitrary objects; only load
    # files from a trusted source.
    if architecture is None:
        # no architecture supplied: the file holds the model object itself
        self.model = torch.load(path)
        return

    architecture.load_state_dict(torch.load(path))
    architecture.eval()
    self.model = architecture
Load pytorch model from path
Arguments:
- path: pathlib object
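- kwargs: Optional keyword arguments. Only the entry keyed by `CommonKwargs.MODEL_ARCH` is read: when provided, the saved state dict is loaded into that module. Nothing is forwarded to torch.load.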
def
save_onnx(self, path: pathlib.Path) -> opsml.types.model.ModelReturn:
def save_onnx(self, path: Path) -> ModelReturn:
    """Write the onnx model to `path` and return its metadata.

    When no onnx model is attached yet, the model is converted (with `path`
    as the target); otherwise the existing session is written to `path`.

    Args:
        path:
            Path to save model to

    Returns:
        ModelReturn
    """
    import onnxruntime as rt

    from opsml.model.onnx import _get_onnx_metadata

    if self.onnx_model is not None:
        # an onnx model already exists, so just persist its session
        self.onnx_model.sess_to_path(path)
    else:
        self.convert_to_onnx(path=path)

    # no need to save onnx to bytes since its done during onnx conversion
    assert self.onnx_model is not None, "No onnx model detected in interface"
    session = cast(rt.InferenceSession, self.onnx_model.sess)
    return _get_onnx_metadata(self, session)
Saves an onnx model
Arguments:
- path: Path to save model to
Returns:
ModelReturn
def
convert_to_onnx(self, **kwargs: pathlib.Path) -> None:
def convert_to_onnx(self, **kwargs: Path) -> None:
    """Converts model to onnx format

    Does nothing when an onnx model is already attached. Without a `path`
    keyword the conversion is delegated to `_convert_to_onnx_inplace`.

    Args:
        kwargs:
            Optional `path` to write the converted model to
    """
    # import packages for onnx conversion
    OpsmlImportExceptions.try_torchonnx_imports()
    if self.onnx_model is not None:
        return None

    from opsml.model.onnx.torch_converter import _PyTorchOnnxModel

    target: Optional[Path] = kwargs.get("path")

    if target is not None:
        self.onnx_model = _PyTorchOnnxModel(self).convert_to_onnx(path=target)
        return None

    return self._convert_to_onnx_inplace()
Converts model to onnx format
def
save_preprocessor(self, path: pathlib.Path) -> None:
def save_preprocessor(self, path: Path) -> None:
    """Dump the attached preprocessor to `path` with joblib.

    Args:
        path:
            Pathlib object
    """
    preprocessor = self.preprocessor
    assert preprocessor is not None, "No preprocessor detected in interface"
    joblib.dump(preprocessor, path)
Saves preprocessor to path if present. Base implementation uses joblib.
Arguments:
- path: Pathlib object
def
load_preprocessor(self, path: pathlib.Path) -> None:
def load_preprocessor(self, path: Path) -> None:
    """Read a preprocessor from `path` with joblib and attach it to the interface.

    Args:
        path:
            Pathlib object
    """
    loaded = joblib.load(path)
    self.preprocessor = loaded
Load preprocessor from pathlib object
Arguments:
- path: Pathlib object
preprocessor_suffix: str
@property
def preprocessor_suffix(self) -> str:
    """File suffix used when storing the preprocessor (`Suffix.JOBLIB`)"""
    suffix: str = Suffix.JOBLIB.value
    return suffix
Returns suffix for storage
model_suffix: str
@property
def model_suffix(self) -> str:
    """File suffix used when storing the model (`Suffix.PT`)"""
    suffix: str = Suffix.PT.value
    return suffix
Returns suffix for storage
model_config =
{'protected_namespaces': ('protect_',), 'arbitrary_types_allowed': True, 'validate_assignment': False, 'validate_default': True, 'extra': 'allow'}
model_fields =
{'model': FieldInfo(annotation=Union[Module, NoneType], required=False), 'sample_data': FieldInfo(annotation=Union[Tensor, Dict[str, Tensor], List[Tensor], Tuple[Tensor], NoneType], required=False), 'onnx_model': FieldInfo(annotation=Union[OnnxModel, NoneType], required=False), 'task_type': FieldInfo(annotation=str, required=False, default='undefined'), 'model_type': FieldInfo(annotation=str, required=False, default='undefined'), 'data_type': FieldInfo(annotation=str, required=False, default='undefined'), 'modelcard_uid': FieldInfo(annotation=str, required=False, default=''), 'onnx_args': FieldInfo(annotation=Union[TorchOnnxArgs, NoneType], required=False), 'save_args': FieldInfo(annotation=TorchSaveArgs, required=False, default=TorchSaveArgs(as_state_dict=False)), 'preprocessor': FieldInfo(annotation=Union[Any, NoneType], required=False), 'preprocessor_name': FieldInfo(annotation=str, required=False, default='undefined')}
Inherited Members
- pydantic.main.BaseModel
- BaseModel
- model_extra
- model_fields_set
- model_construct
- model_copy
- model_dump
- model_dump_json
- model_json_schema
- model_parametrized_name
- model_post_init
- model_rebuild
- model_validate
- model_validate_json
- model_validate_strings
- dict
- json
- parse_obj
- parse_raw
- parse_file
- from_orm
- construct
- copy
- schema
- schema_json
- validate
- update_forward_refs