opsml.model.interfaces.pytorch_lightning
from pathlib import Path
from typing import Any, Dict, Optional

from pydantic import model_validator

from opsml.helpers.utils import OpsmlImportExceptions, get_class_name
from opsml.model.interfaces.base import (
    SamplePrediction,
    get_model_args,
    get_processor_name,
)
from opsml.model.interfaces.pytorch import TorchModel
from opsml.types import CommonKwargs, Suffix, TorchOnnxArgs, TrainedModelType

try:
    from lightning import LightningModule, Trainer

    class LightningModel(TorchModel):
        """Model interface for PyTorch Lightning models.

        Args:
            model:
                PyTorch Lightning ``Trainer`` whose ``model`` attribute holds the
                LightningModule to register.
            preprocessor:
                Optional preprocessor
            sample_data:
                Sample data to be used for type inference.
                This should match exactly what the model expects as input.
            task_type:
                Task type for model. Defaults to undefined.
            model_type:
                Optional model type. This is inferred automatically.
            preprocessor_name:
                Optional preprocessor name. This is inferred automatically if a
                preprocessor is provided.

        Returns:
            LightningModel
        """

        model: Optional[Trainer] = None  # type: ignore[assignment]
        onnx_args: Optional[TorchOnnxArgs] = None

        @property
        def model_class(self) -> str:
            """Trained-model-type identifier for this interface (PYTORCH_LIGHTNING)."""
            return TrainedModelType.PYTORCH_LIGHTNING.value

        @model_validator(mode="before")
        @classmethod
        def check_model(cls, model_args: Dict[str, Any]) -> Dict[str, Any]:
            """Fill in inferred metadata from the raw constructor arguments.

            Returns ``model_args`` untouched when ``modelcard_uid`` is truthy
            (presumably an interface rebuilt from an existing card -- confirm).

            Args:
                model_args:
                    Raw keyword arguments passed to the interface.

            Returns:
                ``model_args`` with model type, normalized sample data, data type
                and preprocessor name set.
            """
            model = model_args.get("model")

            if model_args.get("modelcard_uid", False):
                return model_args

            model, module, bases = get_model_args(model)

            # Object defined in lightning.pytorch (presumably the Trainer itself):
            # record the class name of the module it wraps.
            if "lightning.pytorch" in module:
                model_args[CommonKwargs.MODEL_TYPE.value] = model.model.__class__.__name__

            # A class deriving from a lightning.pytorch base is tagged generically.
            for base in bases:
                if "lightning.pytorch" in base:
                    model_args[CommonKwargs.MODEL_TYPE.value] = "subclass"

            sample_data = cls._get_sample_data(sample_data=model_args[CommonKwargs.SAMPLE_DATA.value])
            model_args[CommonKwargs.SAMPLE_DATA.value] = sample_data
            model_args[CommonKwargs.DATA_TYPE.value] = get_class_name(sample_data)
            model_args[CommonKwargs.PREPROCESSOR_NAME.value] = get_processor_name(
                model_args.get(CommonKwargs.PREPROCESSOR.value),
            )

            return model_args

        def get_sample_prediction(self) -> SamplePrediction:
            """Run the trainer's model on ``sample_data`` and describe the output.

            Dict inputs are first tried as keyword arguments and list/tuple inputs
            as positional arguments; if that call raises, the container is passed
            as a single argument instead.

            Returns:
                SamplePrediction built from the output's class name and the output.
            """
            assert self.model is not None, "Trainer is not defined"
            assert self.sample_data is not None, "Sample data must be provided"

            trainer_model = self.model.model
            assert trainer_model is not None, "No model provided to trainer"

            # test dict input
            if isinstance(self.sample_data, dict):
                try:
                    prediction = trainer_model(**self.sample_data)
                except Exception as _:  # pylint: disable=broad-except
                    prediction = trainer_model(self.sample_data)

            # test list and tuple inputs
            elif isinstance(self.sample_data, (list, tuple)):
                try:
                    prediction = trainer_model(*self.sample_data)
                except Exception as _:  # pylint: disable=broad-except
                    prediction = trainer_model(self.sample_data)

            # all others
            else:
                prediction = trainer_model(self.sample_data)

            prediction_type = get_class_name(prediction)

            return SamplePrediction(prediction_type, prediction)

        def save_model(self, path: Path) -> None:
            """Write a checkpoint of the wrapped Trainer to ``path`` via ``save_checkpoint``."""
            assert self.model is not None, "No model detected in interface"
            self.model.save_checkpoint(path)

        def load_model(self, path: Path, **kwargs: Any) -> None:
            """Load lightning model from path.

            Args:
                path:
                    Location of the saved checkpoint.
                kwargs:
                    If ``CommonKwargs.MODEL_ARCH.value`` is present it must be a
                    ``LightningModule`` subclass and is used to restore the
                    checkpoint via ``load_from_checkpoint``; the remaining kwargs
                    are forwarded to that call. Otherwise ``torch.load`` is used.

            Raises:
                ValueError: If loading fails for any reason (cause is chained).
            """
            arch_key = CommonKwargs.MODEL_ARCH.value
            model_arch = kwargs.get(arch_key)

            try:
                if model_arch is not None:
                    # attempt to load checkpoint into model
                    assert issubclass(
                        model_arch, LightningModule
                    ), "Model architecture must be a subclass of LightningModule"

                    # The architecture is an opsml control argument, not a model
                    # hyperparameter. Lightning forwards extra kwargs of
                    # load_from_checkpoint to the model's __init__, so drop it here.
                    checkpoint_kwargs = dict(kwargs)
                    checkpoint_kwargs.pop(arch_key, None)
                    self.model = model_arch.load_from_checkpoint(checkpoint_path=path, **checkpoint_kwargs)

                else:
                    # load via torch
                    import torch

                    # NOTE(review): save_model writes via Trainer.save_checkpoint, so
                    # this likely yields the raw checkpoint object rather than a
                    # Trainer -- confirm this is the intended fallback.
                    self.model = torch.load(path)

            except Exception as exc:
                raise ValueError(f"Unable to load pytorch lightning model: {exc}") from exc

        def convert_to_onnx(self, **kwargs: Path) -> None:
            """Converts model to ONNX.

            Does nothing when ``onnx_model`` is already set. Without a ``path``
            keyword, delegates to ``_convert_to_onnx_inplace``; otherwise uses the
            Lightning ONNX converter with that path.
            """
            # import packages for onnx conversion
            OpsmlImportExceptions.try_torchonnx_imports()

            if self.onnx_model is not None:
                return None

            from opsml.model.onnx.torch_converter import _PyTorchLightningOnnxModel

            path: Optional[Path] = kwargs.get("path")
            if path is None:
                return self._convert_to_onnx_inplace()

            self.onnx_model = _PyTorchLightningOnnxModel(self).convert_to_onnx(**{"path": path})
            return None

        @property
        def model_suffix(self) -> str:
            """Returns suffix for storage"""
            return Suffix.CKPT.value

        @staticmethod
        def name() -> str:
            """Name of this interface class."""
            return LightningModel.__name__

except ModuleNotFoundError:
    from opsml.model.interfaces.backups import LightningModelNoModule as LightningModel
class LightningModel(TorchModel):
    """Model interface for PyTorch Lightning models.

    Args:
        model:
            PyTorch Lightning ``Trainer`` whose ``model`` attribute holds the
            LightningModule to register.
        preprocessor:
            Optional preprocessor
        sample_data:
            Sample data to be used for type inference.
            This should match exactly what the model expects as input.
        task_type:
            Task type for model. Defaults to undefined.
        model_type:
            Optional model type. This is inferred automatically.
        preprocessor_name:
            Optional preprocessor name. This is inferred automatically if a
            preprocessor is provided.

    Returns:
        LightningModel
    """

    model: Optional[Trainer] = None  # type: ignore[assignment]
    onnx_args: Optional[TorchOnnxArgs] = None

    @property
    def model_class(self) -> str:
        """Trained-model-type identifier for this interface (PYTORCH_LIGHTNING)."""
        return TrainedModelType.PYTORCH_LIGHTNING.value

    @model_validator(mode="before")
    @classmethod
    def check_model(cls, model_args: Dict[str, Any]) -> Dict[str, Any]:
        """Fill in inferred metadata from the raw constructor arguments.

        Returns ``model_args`` untouched when ``modelcard_uid`` is truthy
        (presumably an interface rebuilt from an existing card -- confirm).

        Args:
            model_args:
                Raw keyword arguments passed to the interface.

        Returns:
            ``model_args`` with model type, normalized sample data, data type
            and preprocessor name set.
        """
        model = model_args.get("model")

        if model_args.get("modelcard_uid", False):
            return model_args

        model, module, bases = get_model_args(model)

        # Object defined in lightning.pytorch (presumably the Trainer itself):
        # record the class name of the module it wraps.
        if "lightning.pytorch" in module:
            model_args[CommonKwargs.MODEL_TYPE.value] = model.model.__class__.__name__

        # A class deriving from a lightning.pytorch base is tagged generically.
        for base in bases:
            if "lightning.pytorch" in base:
                model_args[CommonKwargs.MODEL_TYPE.value] = "subclass"

        sample_data = cls._get_sample_data(sample_data=model_args[CommonKwargs.SAMPLE_DATA.value])
        model_args[CommonKwargs.SAMPLE_DATA.value] = sample_data
        model_args[CommonKwargs.DATA_TYPE.value] = get_class_name(sample_data)
        model_args[CommonKwargs.PREPROCESSOR_NAME.value] = get_processor_name(
            model_args.get(CommonKwargs.PREPROCESSOR.value),
        )

        return model_args

    def get_sample_prediction(self) -> SamplePrediction:
        """Run the trainer's model on ``sample_data`` and describe the output.

        Dict inputs are first tried as keyword arguments and list/tuple inputs
        as positional arguments; if that call raises, the container is passed
        as a single argument instead.

        Returns:
            SamplePrediction built from the output's class name and the output.
        """
        assert self.model is not None, "Trainer is not defined"
        assert self.sample_data is not None, "Sample data must be provided"

        trainer_model = self.model.model
        assert trainer_model is not None, "No model provided to trainer"

        # test dict input
        if isinstance(self.sample_data, dict):
            try:
                prediction = trainer_model(**self.sample_data)
            except Exception as _:  # pylint: disable=broad-except
                prediction = trainer_model(self.sample_data)

        # test list and tuple inputs
        elif isinstance(self.sample_data, (list, tuple)):
            try:
                prediction = trainer_model(*self.sample_data)
            except Exception as _:  # pylint: disable=broad-except
                prediction = trainer_model(self.sample_data)

        # all others
        else:
            prediction = trainer_model(self.sample_data)

        prediction_type = get_class_name(prediction)

        return SamplePrediction(prediction_type, prediction)

    def save_model(self, path: Path) -> None:
        """Write a checkpoint of the wrapped Trainer to ``path`` via ``save_checkpoint``."""
        assert self.model is not None, "No model detected in interface"
        self.model.save_checkpoint(path)

    def load_model(self, path: Path, **kwargs: Any) -> None:
        """Load lightning model from path.

        Args:
            path:
                Location of the saved checkpoint.
            kwargs:
                If ``CommonKwargs.MODEL_ARCH.value`` is present it must be a
                ``LightningModule`` subclass and is used to restore the
                checkpoint via ``load_from_checkpoint``; the remaining kwargs
                are forwarded to that call. Otherwise ``torch.load`` is used.

        Raises:
            ValueError: If loading fails for any reason (cause is chained).
        """
        arch_key = CommonKwargs.MODEL_ARCH.value
        model_arch = kwargs.get(arch_key)

        try:
            if model_arch is not None:
                # attempt to load checkpoint into model
                assert issubclass(
                    model_arch, LightningModule
                ), "Model architecture must be a subclass of LightningModule"

                # The architecture is an opsml control argument, not a model
                # hyperparameter. Lightning forwards extra kwargs of
                # load_from_checkpoint to the model's __init__, so drop it here.
                checkpoint_kwargs = dict(kwargs)
                checkpoint_kwargs.pop(arch_key, None)
                self.model = model_arch.load_from_checkpoint(checkpoint_path=path, **checkpoint_kwargs)

            else:
                # load via torch
                import torch

                # NOTE(review): save_model writes via Trainer.save_checkpoint, so
                # this likely yields the raw checkpoint object rather than a
                # Trainer -- confirm this is the intended fallback.
                self.model = torch.load(path)

        except Exception as exc:
            raise ValueError(f"Unable to load pytorch lightning model: {exc}") from exc

    def convert_to_onnx(self, **kwargs: Path) -> None:
        """Converts model to ONNX.

        Does nothing when ``onnx_model`` is already set. Without a ``path``
        keyword, delegates to ``_convert_to_onnx_inplace``; otherwise uses the
        Lightning ONNX converter with that path.
        """
        # import packages for onnx conversion
        OpsmlImportExceptions.try_torchonnx_imports()

        if self.onnx_model is not None:
            return None

        from opsml.model.onnx.torch_converter import _PyTorchLightningOnnxModel

        path: Optional[Path] = kwargs.get("path")
        if path is None:
            return self._convert_to_onnx_inplace()

        self.onnx_model = _PyTorchLightningOnnxModel(self).convert_to_onnx(**{"path": path})
        return None

    @property
    def model_suffix(self) -> str:
        """Returns suffix for storage"""
        return Suffix.CKPT.value

    @staticmethod
    def name() -> str:
        """Name of this interface class."""
        return LightningModel.__name__
Model interface for PyTorch Lightning models.
Arguments:
- model: PyTorch Lightning `Trainer` whose `model` attribute holds the LightningModule
- preprocessor: Optional preprocessor
- sample_data: Sample data to be used for type inference. This should match exactly what the model expects as input.
- task_type: Task type for model. Defaults to undefined.
- model_type: Optional model type. This is inferred automatically.
- preprocessor_name: Optional preprocessor name. This is inferred automatically if a preprocessor is provided.
Returns: LightningModel
@model_validator(mode='before')
@classmethod
def
check_model(cls, model_args: Dict[str, Any]) -> Dict[str, Any]:
@model_validator(mode="before")
@classmethod
def check_model(cls, model_args: Dict[str, Any]) -> Dict[str, Any]:
    """Populate inferred metadata on the raw interface arguments.

    When ``modelcard_uid`` is truthy the arguments are returned as-is
    (presumably a reload from an existing card -- confirm).

    Args:
        model_args:
            Raw keyword arguments given to the interface.

    Returns:
        The same mapping, with model type, normalized sample data, data type
        and preprocessor name written into it.
    """
    raw_model = model_args.get("model")
    if model_args.get("modelcard_uid", False):
        return model_args

    resolved, module_path, base_names = get_model_args(raw_model)
    type_key = CommonKwargs.MODEL_TYPE.value

    # Defined inside lightning.pytorch: use the wrapped module's class name.
    if "lightning.pytorch" in module_path:
        model_args[type_key] = resolved.model.__class__.__name__

    # Any lightning.pytorch ancestor marks a user-defined subclass.
    lightning_bases = [name for name in base_names if "lightning.pytorch" in name]
    if lightning_bases:
        model_args[type_key] = "subclass"

    data_key = CommonKwargs.SAMPLE_DATA.value
    normalized = cls._get_sample_data(sample_data=model_args[data_key])
    model_args[data_key] = normalized
    model_args[CommonKwargs.DATA_TYPE.value] = get_class_name(normalized)

    preprocessor = model_args.get(CommonKwargs.PREPROCESSOR.value)
    model_args[CommonKwargs.PREPROCESSOR_NAME.value] = get_processor_name(preprocessor)

    return model_args
def get_sample_prediction(self) -> SamplePrediction:
    """Call the trainer's model on the stored sample data.

    Containers are unpacked first (keywords for a dict, positionals for a
    list/tuple); if that call raises, the container itself is passed as a
    single argument. Any other input is passed directly.

    Returns:
        SamplePrediction holding the output's class name and the output.
    """
    assert self.model is not None, "Trainer is not defined"
    assert self.sample_data is not None, "Sample data must be provided"

    net = self.model.model
    assert net is not None, "No model provided to trainer"

    data = self.sample_data
    if isinstance(data, (dict, list, tuple)):
        try:
            output = net(**data) if isinstance(data, dict) else net(*data)
        except Exception:  # pylint: disable=broad-except
            output = net(data)
    else:
        output = net(data)

    return SamplePrediction(get_class_name(output), output)
def
save_model(self, path: pathlib.Path) -> None:
def save_model(self, path: Path) -> None:
    """Persist the wrapped trainer by calling its ``save_checkpoint`` with ``path``.

    Args:
        path: Destination for the checkpoint file.
    """
    trainer = self.model
    assert trainer is not None, "No model detected in interface"
    trainer.save_checkpoint(path)
Save a checkpoint of the wrapped PyTorch Lightning `Trainer` to path (via `Trainer.save_checkpoint`).
Arguments:
- path: `pathlib.Path` to write the checkpoint to
def
load_model(self, path: pathlib.Path, **kwargs: Any) -> None:
def load_model(self, path: Path, **kwargs: Any) -> None:
    """Load lightning model from path.

    Args:
        path:
            Location of the saved checkpoint.
        kwargs:
            If ``CommonKwargs.MODEL_ARCH.value`` is present it must be a
            ``LightningModule`` subclass and is used to restore the checkpoint
            via ``load_from_checkpoint``; the remaining kwargs are forwarded to
            that call. Otherwise ``torch.load`` is used.

    Raises:
        ValueError: If loading fails for any reason (cause is chained).
    """
    arch_key = CommonKwargs.MODEL_ARCH.value
    model_arch = kwargs.get(arch_key)

    try:
        if model_arch is not None:
            # attempt to load checkpoint into model
            assert issubclass(
                model_arch, LightningModule
            ), "Model architecture must be a subclass of LightningModule"

            # The architecture is an opsml control argument, not a model
            # hyperparameter. Lightning forwards extra kwargs of
            # load_from_checkpoint to the model's __init__, so drop it here.
            checkpoint_kwargs = dict(kwargs)
            checkpoint_kwargs.pop(arch_key, None)
            self.model = model_arch.load_from_checkpoint(checkpoint_path=path, **checkpoint_kwargs)

        else:
            # load via torch
            import torch

            # NOTE(review): save_model writes via Trainer.save_checkpoint, so this
            # likely yields the raw checkpoint object rather than a Trainer --
            # confirm this is the intended fallback.
            self.model = torch.load(path)

    except Exception as exc:
        raise ValueError(f"Unable to load pytorch lightning model: {exc}") from exc
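Load a PyTorch Lightning model from path. If a model architecture (a LightningModule subclass) is supplied in kwargs, the checkpoint is restored with its `load_from_checkpoint`; otherwise the file is read with `torch.load`. Any failure is re-raised as `ValueError`.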
def
convert_to_onnx(self, **kwargs: pathlib.Path) -> None:
def convert_to_onnx(self, **kwargs: Path) -> None:
    """Produce an ONNX version of the model unless one already exists.

    Without a ``path`` keyword the work is handed to
    ``_convert_to_onnx_inplace`` (its result is returned); with one, the
    Lightning ONNX converter is run against that path and its result is
    stored on ``onnx_model``.
    """
    # Fail early if the torch/onnx conversion dependencies are unavailable.
    OpsmlImportExceptions.try_torchonnx_imports()

    if self.onnx_model is None:
        from opsml.model.onnx.torch_converter import _PyTorchLightningOnnxModel

        save_path: Optional[Path] = kwargs.get("path")
        if save_path is not None:
            converter = _PyTorchLightningOnnxModel(self)
            self.onnx_model = converter.convert_to_onnx(path=save_path)
        else:
            return self._convert_to_onnx_inplace()

    return None
Converts model to ONNX. Does nothing if an ONNX model is already present.
model_suffix: str
@property
def model_suffix(self) -> str:
    """File suffix used when storing this model (the checkpoint suffix)."""
    checkpoint_suffix = Suffix.CKPT
    return checkpoint_suffix.value
Returns suffix for storage
model_config =
{'protected_namespaces': ('protect_',), 'arbitrary_types_allowed': True, 'validate_assignment': False, 'validate_default': True, 'extra': 'allow'}
model_fields =
{'model': FieldInfo(annotation=Union[Trainer, NoneType], required=False), 'sample_data': FieldInfo(annotation=Union[Tensor, Dict[str, Tensor], List[Tensor], Tuple[Tensor], NoneType], required=False), 'onnx_model': FieldInfo(annotation=Union[OnnxModel, NoneType], required=False), 'task_type': FieldInfo(annotation=str, required=False, default='undefined'), 'model_type': FieldInfo(annotation=str, required=False, default='undefined'), 'data_type': FieldInfo(annotation=str, required=False, default='undefined'), 'modelcard_uid': FieldInfo(annotation=str, required=False, default=''), 'onnx_args': FieldInfo(annotation=Union[TorchOnnxArgs, NoneType], required=False), 'save_args': FieldInfo(annotation=TorchSaveArgs, required=False, default=TorchSaveArgs(as_state_dict=False)), 'preprocessor': FieldInfo(annotation=Union[Any, NoneType], required=False), 'preprocessor_name': FieldInfo(annotation=str, required=False, default='undefined')}
Inherited Members
- pydantic.main.BaseModel
- BaseModel
- model_extra
- model_fields_set
- model_construct
- model_copy
- model_dump
- model_dump_json
- model_json_schema
- model_parametrized_name
- model_post_init
- model_rebuild
- model_validate
- model_validate_json
- model_validate_strings
- dict
- json
- parse_obj
- parse_raw
- parse_file
- from_orm
- construct
- copy
- schema
- schema_json
- validate
- update_forward_refs