opsml.model.challenger

  1# Copyright (c) Shipt, Inc.
  2# This source code is licensed under the MIT license found in the
  3# LICENSE file in the root directory of this source tree.
  4from typing import Any, Dict, List, Optional, Union, cast
  5
  6from pydantic import BaseModel, ConfigDict, ValidationInfo, field_validator
  7
  8from opsml.cards.model import ModelCard
  9from opsml.cards.run import RunCard
 10from opsml.helpers.logging import ArtifactLogger
 11from opsml.registry.registry import CardRegistries
 12from opsml.types import CardInfo, Metric
 13
 14logger = ArtifactLogger.get_logger()
 15
 16# User interfaces should primarily be checked at runtime
 17
 18
 19class BattleReport(BaseModel):
 20    model_config = ConfigDict(arbitrary_types_allowed=True)
 21    champion_name: str
 22    champion_version: str
 23    champion_metric: Optional[Metric] = None
 24    challenger_metric: Optional[Metric] = None
 25    challenger_win: bool
 26
 27
 28MetricName = Union[str, List[str]]
 29MetricValue = Union[int, float, List[Union[int, float]]]
 30
 31
 32class ChallengeInputs(BaseModel):
 33    metric_name: MetricName
 34    metric_value: Optional[MetricValue] = None
 35    lower_is_better: Union[bool, List[bool]] = True
 36
 37    @property
 38    def metric_names(self) -> List[str]:
 39        return cast(List[str], self.metric_name)
 40
 41    @property
 42    def metric_values(self) -> List[Optional[Union[int, float]]]:
 43        return cast(List[Optional[Union[int, float]]], self.metric_value)
 44
 45    @property
 46    def thresholds(self) -> List[bool]:
 47        return cast(List[bool], self.lower_is_better)
 48
 49    @field_validator("metric_name")
 50    @classmethod
 51    def convert_name(cls, name: Union[List[str], str]) -> List[str]:
 52        if not isinstance(name, list):
 53            return [name]
 54        return name
 55
 56    @field_validator("metric_value")
 57    @classmethod
 58    def convert_value(cls, value: Optional[MetricValue], info: ValidationInfo) -> List[Any]:
 59        data = info.data
 60        metric = cast(MetricName, data["metric_name"])
 61        nbr_metrics = len(metric)
 62
 63        if value is not None:
 64            if not isinstance(value, list):
 65                metric_value = [value]
 66            else:
 67                metric_value = value
 68        else:
 69            metric_value = [None] * nbr_metrics  # type: ignore
 70
 71        if len(metric_value) != nbr_metrics:
 72            raise ValueError("List of metric values must be the same length as metric names")
 73
 74        return metric_value
 75
 76    @field_validator("lower_is_better")
 77    @classmethod
 78    def convert_threshold(cls, threshold: Union[bool, List[bool]], info: ValidationInfo) -> List[bool]:
 79        data = info.data
 80        metric = cast(MetricName, data["metric_name"])
 81        nbr_metrics = len(metric)
 82
 83        if not isinstance(threshold, list):
 84            _threshold = [threshold] * nbr_metrics
 85        else:
 86            _threshold = threshold
 87
 88        if len(_threshold) != nbr_metrics:
 89            if len(_threshold) == 1:
 90                _threshold = _threshold * nbr_metrics
 91            else:
 92                raise ValueError("Length of lower_is_better must match the number of metrics")
 93
 94        return _threshold
 95
 96
 97class ModelChallenger:
 98    def __init__(self, challenger: ModelCard):
 99        """
100        Instantiates ModelChallenger class
101        Instantiates the ModelChallenger class
102        Args:
103            challenger:
104                ModelCard of challenger
105
106        """
107        self._challenger = challenger
108        self._challenger_metric: Optional[Metric] = None
109        self._registries = CardRegistries()
110
111    @property
112    def challenger_metric(self) -> Metric:
113        if self._challenger_metric is not None:
114            return self._challenger_metric
115        raise ValueError("Challenger metric not set")
116
117    @challenger_metric.setter
118    def challenger_metric(self, metric: Metric) -> None:
119        self._challenger_metric = metric
120
121    def _get_last_champion_record(self) -> Optional[Dict[str, Any]]:
122        """Gets the previous champion record"""
123
124        champion_records = self._registries.model.list_cards(
125            name=self._challenger.name,
126            repository=self._challenger.repository,
127        )
128
129        if not champion_records:
130            return None
131
132        # challenger is already registered; the previous version is the champion
133        if self._challenger.version is not None and len(champion_records) > 1:
134            return champion_records[1]
135
136        # account for cases where challenger is only model in registry
137        champion_record = champion_records[0]
138        if champion_record.get("version") == self._challenger.version:
139            return None
140
141        return champion_record
142
143    def _get_runcard_metric(self, runcard_uid: str, metric_name: str) -> Metric:
144        """
145        Loads a RunCard by uid and returns the named metric
146
147        Args:
148            runcard_uid:
149                RunCard uid
150            metric_name:
151                Name of metric
152
153        """
154        runcard = cast(RunCard, self._registries.run.load_card(uid=runcard_uid))
155        metric = runcard.get_metric(name=metric_name)
156
157        if isinstance(metric, list):
158            metric = metric[0]
159
160        return metric
161
162    def _battle(self, champion: CardInfo, champion_metric: Metric, lower_is_better: bool) -> BattleReport:
163        """
164        Runs a battle between champion and current challenger
165
166        Args:
167            champion:
168                Champion record
169            champion_metric:
170                Champion metric from a runcard
171            lower_is_better:
172                Whether lower metric is preferred
173
174        Returns:
175            `BattleReport`
176
177        """
178        if lower_is_better:
179            challenger_win = self.challenger_metric.value < champion_metric.value
180        else:
181            challenger_win = self.challenger_metric.value > champion_metric.value
182        return BattleReport.model_construct(
183            champion_name=str(champion.name),
184            champion_version=str(champion.version),
185            champion_metric=champion_metric,
186            challenger_metric=self.challenger_metric.model_copy(deep=True),
187            challenger_win=challenger_win,
188        )
189
190    def _battle_last_model_version(self, metric_name: str, lower_is_better: bool) -> BattleReport:
191        """Compares the last champion model to the current challenger"""
192
193        champion_record = self._get_last_champion_record()
194
195        if champion_record is None:
196            logger.info("No previous model found. Challenger wins")
197
198            return BattleReport(
199                champion_name="No model",
200                champion_version="No version",
201                challenger_win=True,
202            )
203
204        runcard_id = champion_record.get("runcard_uid")
205        if runcard_id is None:
206            raise ValueError(f"No RunCard is associated with champion: {champion_record}")
207
208        champion_metric = self._get_runcard_metric(runcard_uid=runcard_id, metric_name=metric_name)
209
210        return self._battle(
211            champion=CardInfo(
212                name=champion_record.get("name"),
213                version=champion_record.get("version"),
214            ),
215            champion_metric=champion_metric,
216            lower_is_better=lower_is_better,
217        )
218
219    def _battle_champions(
220        self,
221        champions: List[CardInfo],
222        metric_name: str,
223        lower_is_better: bool,
224    ) -> List[BattleReport]:
225        """Loops through and creates a `BattleReport` for each champion"""
226        battle_reports = []
227
228        for champion in champions:
229            champion_record = self._registries.model.list_cards(
230                info=champion,
231            )
232
233            if not champion_record:
234                raise ValueError(f"Champion model does not exist. {champion}")
235
236            champion_card = champion_record[0]
237            runcard_uid = champion_card.get("runcard_uid")
238            if runcard_uid is None:
239                raise ValueError(f"No RunCard associated with champion: {champion}")
240
241            champion_metric = self._get_runcard_metric(
242                runcard_uid=runcard_uid,
243                metric_name=metric_name,
244            )
245
246            # fill in name, repository and version if any were not provided
247            champion.name = champion.name or champion_card.get("name")
248            champion.repository = champion.repository or champion_card.get("repository")
249            champion.version = champion.version or champion_card.get("version")
250
251            battle_reports.append(
252                self._battle(
253                    champion=champion,
254                    champion_metric=champion_metric,
255                    lower_is_better=lower_is_better,
256                )
257            )
258        return battle_reports
259
260    def challenge_champion(
261        self,
262        metric_name: MetricName,
263        metric_value: Optional[MetricValue] = None,
264        champions: Optional[List[CardInfo]] = None,
265        lower_is_better: Union[bool, List[bool]] = True,
266    ) -> Dict[str, List[BattleReport]]:
267        """
268        Challenges n champion models against the challenger model. If no champion is provided,
269        the latest model version is used as a champion.
270
271        Args:
272            metric_name:
273                Name of metric to evaluate
274            metric_value:
275                Challenger metric value. If not provided, the value is read from the challenger's RunCard
276            champions:
277                Optional list of champion CardInfo. If not provided, the latest registered model version is used
278            lower_is_better:
279                Whether a lower metric value is better or not
280
281        Returns:
282            Dictionary mapping each metric name to a list of `BattleReport`s
283        """
284
285        # validate inputs
286        inputs = ChallengeInputs(
287            metric_name=metric_name,
288            metric_value=metric_value,
289            lower_is_better=lower_is_better,
290        )
291
292        report_dict = {}
293
294        for name, value, _lower_is_better in zip(
295            inputs.metric_names,
296            inputs.metric_values,
297            inputs.thresholds,
298        ):
299            # get challenger metric
300            if value is None:
301                if self._challenger.metadata.runcard_uid is not None:
302                    self.challenger_metric = self._get_runcard_metric(
303                        self._challenger.metadata.runcard_uid, metric_name=name
304                    )
305                else:
306                    raise ValueError("Challenger and champions must be associated with a registered RunCard")
307            else:
308                self.challenger_metric = Metric(name=name, value=value)
309
310            if champions is None:
311                report_dict[name] = [
312                    self._battle_last_model_version(
313                        metric_name=name,
314                        lower_is_better=_lower_is_better,
315                    )
316                ]
317
318            else:
319                report_dict[name] = self._battle_champions(
320                    champions=champions,
321                    metric_name=name,
322                    lower_is_better=_lower_is_better,
323                )
324
325        return report_dict
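
Example usage (a minimal sketch, not part of the module source): challenge the latest registered version of a model on a single metric. It assumes `challenger_card` is a ModelCard that has already been registered and whose linked RunCard logged an "mae" metric; the metric name and card are illustrative.

from opsml.model.challenger import ModelChallenger

# hypothetical setup: a registered ModelCard whose RunCard logged "mae"
challenger = ModelChallenger(challenger=challenger_card)

# no metric_value is supplied, so the value is read from the challenger's RunCard;
# with no champions given, the latest registered model version is the champion
reports = challenger.challenge_champion(metric_name="mae", lower_is_better=True)

for report in reports["mae"]:
    print(report.champion_name, report.champion_version, report.challenger_win)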
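
A second sketch under the same assumptions (card names and versions are hypothetical): challenge explicit champions on multiple metrics, supplying the challenger's values directly. A single lower_is_better flag is broadcast to every metric by ChallengeInputs, and each metric name maps to one BattleReport per champion.

from opsml.types import CardInfo

champions = [
    CardInfo(name="my-model", repository="my-repo", version="1.2.0"),
    CardInfo(name="my-model", repository="my-repo", version="1.3.0"),
]

# two metrics with two challenger values; the single boolean applies to both
reports = challenger.challenge_champion(
    metric_name=["mae", "rmse"],
    metric_value=[0.41, 0.58],
    champions=champions,
    lower_is_better=True,
)

rmse_reports = reports["rmse"]  # one BattleReport per champion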