opsml.model.interfaces.pytorch_lightning

from pathlib import Path
from typing import Any, Dict, Optional

from pydantic import model_validator

from opsml.helpers.utils import OpsmlImportExceptions, get_class_name
from opsml.model.interfaces.base import (
    SamplePrediction,
    get_model_args,
    get_processor_name,
)
from opsml.model.interfaces.pytorch import TorchModel
from opsml.types import CommonKwargs, Suffix, TorchOnnxArgs, TrainedModelType

try:
    from lightning import LightningModule, Trainer

    class LightningModel(TorchModel):
        """Model interface for Pytorch Lightning models.

        Args:
            model:
                Torch lightning model
            preprocessor:
                Optional preprocessor
            sample_data:
                Sample data to be used for type inference.
                This should match exactly what the model expects as input.
            task_type:
                Task type for model. Defaults to undefined.
            model_type:
                Optional model type. This is inferred automatically.
            preprocessor_name:
                Optional preprocessor. This is inferred automatically if a
                preprocessor is provided.

        Returns:
            LightningModel
        """

        model: Optional[Trainer] = None  # type: ignore[assignment]
        onnx_args: Optional[TorchOnnxArgs] = None

        @property
        def model_class(self) -> str:
            return TrainedModelType.PYTORCH_LIGHTNING.value

        @model_validator(mode="before")
        @classmethod
        def check_model(cls, model_args: Dict[str, Any]) -> Dict[str, Any]:
            model = model_args.get("model")

            if model_args.get("modelcard_uid", False):
                return model_args

            model, module, bases = get_model_args(model)

            if "lightning.pytorch" in module:
                model_args[CommonKwargs.MODEL_TYPE.value] = model.model.__class__.__name__

            for base in bases:
                if "lightning.pytorch" in base:
                    model_args[CommonKwargs.MODEL_TYPE.value] = "subclass"

            sample_data = cls._get_sample_data(sample_data=model_args[CommonKwargs.SAMPLE_DATA.value])
            model_args[CommonKwargs.SAMPLE_DATA.value] = sample_data
            model_args[CommonKwargs.DATA_TYPE.value] = get_class_name(sample_data)
            model_args[CommonKwargs.PREPROCESSOR_NAME.value] = get_processor_name(
                model_args.get(CommonKwargs.PREPROCESSOR.value),
            )

            return model_args

        def get_sample_prediction(self) -> SamplePrediction:
            assert self.model is not None, "Trainer is not defined"
            assert self.sample_data is not None, "Sample data must be provided"

            trainer_model = self.model.model
            assert trainer_model is not None, "No model provided to trainer"

            # test dict input
            if isinstance(self.sample_data, dict):
                try:
                    prediction = trainer_model(**self.sample_data)
                except Exception as _:  # pylint: disable=broad-except
                    prediction = trainer_model(self.sample_data)

            # test list and tuple inputs
            elif isinstance(self.sample_data, (list, tuple)):
                try:
                    prediction = trainer_model(*self.sample_data)
                except Exception as _:  # pylint: disable=broad-except
                    prediction = trainer_model(self.sample_data)

            # all others
            else:
                prediction = trainer_model(self.sample_data)

            prediction_type = get_class_name(prediction)

            return SamplePrediction(prediction_type, prediction)

        def save_model(self, path: Path) -> None:
            assert self.model is not None, "No model detected in interface"
            self.model.save_checkpoint(path)

        def load_model(self, path: Path, **kwargs: Any) -> None:
            """Load lightning model from path"""

            model_arch = kwargs.get(CommonKwargs.MODEL_ARCH.value)

            try:
                if model_arch is not None:
                    # attempt to load checkpoint into model
                    assert issubclass(
                        model_arch, LightningModule
                    ), "Model architecture must be a subclass of LightningModule"
                    self.model = model_arch.load_from_checkpoint(checkpoint_path=path, **kwargs)

                else:
                    # load via torch
                    import torch

                    self.model = torch.load(path)

            except Exception as exc:
                raise ValueError(f"Unable to load pytorch lightning model: {exc}") from exc

        def convert_to_onnx(self, **kwargs: Path) -> None:
            """Converts model to onnx"""
            # import packages for onnx conversion
            OpsmlImportExceptions.try_torchonnx_imports()

            if self.onnx_model is not None:
                return None

            from opsml.model.onnx.torch_converter import _PyTorchLightningOnnxModel

            path: Optional[Path] = kwargs.get("path")
            if path is None:
                return self._convert_to_onnx_inplace()

            self.onnx_model = _PyTorchLightningOnnxModel(self).convert_to_onnx(**{"path": path})
            return None

        @property
        def model_suffix(self) -> str:
            """Returns suffix for storage"""
            return Suffix.CKPT.value

        @staticmethod
        def name() -> str:
            return LightningModel.__name__

except ModuleNotFoundError:
    from opsml.model.interfaces.backups import LightningModelNoModule as LightningModel
class LightningModel(opsml.model.interfaces.pytorch.TorchModel):

Model interface for Pytorch Lightning models.

Arguments:
  • model: Torch lightning model
  • preprocessor: Optional preprocessor
  • sample_data: Sample data to be used for type inference. This should match exactly what the model expects as input.
  • task_type: Task type for model. Defaults to undefined.
  • model_type: Optional model type. This is inferred automatically.
  • preprocessor_name: Optional preprocessor. This is inferred automatically if a preprocessor is provided.

Returns: LightningModel
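
A minimal construction sketch (assumptions, not opsml documentation): the interface wraps a fitted lightning Trainer rather than a bare LightningModule, and sample_data should be a single input batch. RegressionModule, the random tensors, and the trainer settings below are hypothetical.

import lightning as L
import torch
from torch.utils.data import DataLoader, TensorDataset

from opsml.model.interfaces.pytorch_lightning import LightningModel

class RegressionModule(L.LightningModule):  # hypothetical module for illustration
    def __init__(self) -> None:
        super().__init__()
        self.layer = torch.nn.Linear(10, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return torch.nn.functional.mse_loss(self(x), y)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)

X, y = torch.randn(64, 10), torch.randn(64, 1)
loader = DataLoader(TensorDataset(X, y), batch_size=16)

trainer = L.Trainer(max_epochs=1, logger=False, enable_checkpointing=False)
trainer.fit(RegressionModule(), loader)

# Wrap the Trainer (trainer.model is the fitted LightningModule) plus one sample batch
interface = LightningModel(model=trainer, sample_data=X[:1])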

model: Optional[lightning.pytorch.trainer.trainer.Trainer]
onnx_args: Optional[opsml.types.model.TorchOnnxArgs]
model_class: str
@model_validator(mode='before')
@classmethod
def check_model(cls, model_args: Dict[str, Any]) -> Dict[str, Any]:
def get_sample_prediction(self) -> opsml.model.interfaces.base.SamplePrediction:
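
Continuing the construction sketch above (a hedged usage note, not part of the library docs): the call runs the Trainer's LightningModule on sample_data and wraps the output; the prediction_type attribute name is inferred from how SamplePrediction is constructed in the source.

pred = interface.get_sample_prediction()   # forward pass on the stored sample batch
print(pred.prediction_type)                # e.g. "torch.Tensor"; attribute name assumed from the source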
def save_model(self, path: pathlib.Path) -> None:

Save pytorch model to path

Arguments:
  • path: pathlib object
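
A short saving sketch under the same assumptions as the construction example; the artifacts/ directory and filename are hypothetical, and model_suffix shows the expected .ckpt extension.

from pathlib import Path

ckpt_path = Path("artifacts") / f"lightning_model{interface.model_suffix}"  # ".ckpt"
ckpt_path.parent.mkdir(parents=True, exist_ok=True)
interface.save_model(ckpt_path)  # delegates to Trainer.save_checkpoint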
def load_model(self, path: pathlib.Path, **kwargs: Any) -> None:

Load lightning model from path
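
A hedged reload sketch reusing ckpt_path and the hypothetical RegressionModule from the sketches above: passing the LightningModule subclass under CommonKwargs.MODEL_ARCH routes through load_from_checkpoint, while omitting it falls back to torch.load on the checkpoint file.

from opsml.types import CommonKwargs

# Load the checkpoint back into the module class; self.model becomes a LightningModule
interface.load_model(ckpt_path, **{CommonKwargs.MODEL_ARCH.value: RegressionModule})

# Without a model architecture, the raw checkpoint object is loaded via torch.load
# interface.load_model(ckpt_path)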

def convert_to_onnx(self, **kwargs: pathlib.Path) -> None:

Converts model to onnx
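
A hedged conversion sketch: onnx_args supplies options for the torch.onnx export, and the input/output names here are illustrative values, not requirements. Without a path kwarg the conversion happens in place and populates onnx_model.

from opsml.types import TorchOnnxArgs

interface.onnx_args = TorchOnnxArgs(input_names=["inputs"], output_names=["outputs"])
interface.convert_to_onnx()
assert interface.onnx_model is not None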

model_suffix: str

Returns suffix for storage

@staticmethod
def name() -> str:
model_config = {
    'protected_namespaces': ('protect_',),
    'arbitrary_types_allowed': True,
    'validate_assignment': False,
    'validate_default': True,
    'extra': 'allow',
}
model_fields = {
    'model': FieldInfo(annotation=Union[Trainer, NoneType], required=False),
    'sample_data': FieldInfo(annotation=Union[Tensor, Dict[str, Tensor], List[Tensor], Tuple[Tensor], NoneType], required=False),
    'onnx_model': FieldInfo(annotation=Union[OnnxModel, NoneType], required=False),
    'task_type': FieldInfo(annotation=str, required=False, default='undefined'),
    'model_type': FieldInfo(annotation=str, required=False, default='undefined'),
    'data_type': FieldInfo(annotation=str, required=False, default='undefined'),
    'modelcard_uid': FieldInfo(annotation=str, required=False, default=''),
    'onnx_args': FieldInfo(annotation=Union[TorchOnnxArgs, NoneType], required=False),
    'save_args': FieldInfo(annotation=TorchSaveArgs, required=False, default=TorchSaveArgs(as_state_dict=False)),
    'preprocessor': FieldInfo(annotation=Union[Any, NoneType], required=False),
    'preprocessor_name': FieldInfo(annotation=str, required=False, default='undefined'),
}
model_computed_fields = {}
Inherited Members
    pydantic.main.BaseModel
        BaseModel, model_extra, model_fields_set, model_construct, model_copy, model_dump,
        model_dump_json, model_json_schema, model_parametrized_name, model_post_init,
        model_rebuild, model_validate, model_validate_json, model_validate_strings, dict,
        json, parse_obj, parse_raw, parse_file, from_orm, construct, copy, schema,
        schema_json, validate, update_forward_refs
    opsml.model.interfaces.pytorch.TorchModel
        sample_data, save_args, preprocessor, preprocessor_name, save_onnx,
        save_preprocessor, load_preprocessor, preprocessor_suffix
    opsml.model.interfaces.base.ModelInterface
        onnx_model, task_type, model_type, data_type, modelcard_uid, check_modelcard_uid,
        load_onnx_model, save_sample_data, load_sample_data, data_suffix