opsml.model.challenger
# Copyright (c) Shipt, Inc.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from typing import Any, Dict, List, Optional, Union, cast

from pydantic import BaseModel, ConfigDict, ValidationInfo, field_validator

from opsml.cards.model import ModelCard
from opsml.cards.run import RunCard
from opsml.helpers.logging import ArtifactLogger
from opsml.registry.registry import CardRegistries
from opsml.types import CardInfo, Metric

logger = ArtifactLogger.get_logger()

# User interfaces should primarily be checked at runtime


class BattleReport(BaseModel):
    model_config = ConfigDict(arbitrary_types_allowed=True)
    champion_name: str
    champion_version: str
    champion_metric: Optional[Metric] = None
    challenger_metric: Optional[Metric] = None
    challenger_win: bool


MetricName = Union[str, List[str]]
MetricValue = Union[int, float, List[Union[int, float]]]


class ChallengeInputs(BaseModel):
    metric_name: MetricName
    metric_value: Optional[MetricValue] = None
    lower_is_better: Union[bool, List[bool]] = True

    @property
    def metric_names(self) -> List[str]:
        return cast(List[str], self.metric_name)

    @property
    def metric_values(self) -> List[Optional[Union[int, float]]]:
        return cast(List[Optional[Union[int, float]]], self.metric_value)

    @property
    def thresholds(self) -> List[bool]:
        return cast(List[bool], self.lower_is_better)

    @field_validator("metric_name")
    @classmethod
    def convert_name(cls, name: Union[List[str], str]) -> List[str]:
        if not isinstance(name, list):
            return [name]
        return name

    @field_validator("metric_value")
    @classmethod
    def convert_value(cls, value: Optional[MetricValue], info: ValidationInfo) -> List[Any]:
        data = info.data
        metric = cast(MetricName, data["metric_name"])
        nbr_metrics = len(metric)

        if value is not None:
            if not isinstance(value, list):
                metric_value = [value]
            else:
                metric_value = value
        else:
            metric_value = [None] * nbr_metrics  # type: ignore

        if len(metric_value) != nbr_metrics:
            raise ValueError("List of metric values must be the same length as metric names")

        return metric_value

    @field_validator("lower_is_better")
    @classmethod
    def convert_threshold(cls, threshold: Union[bool, List[bool]], info: ValidationInfo) -> List[bool]:
        data = info.data
        metric = cast(MetricName, data["metric_name"])
        nbr_metrics = len(metric)

        if not isinstance(threshold, list):
            _threshold = [threshold] * nbr_metrics
        else:
            _threshold = threshold

        if len(_threshold) != nbr_metrics:
            if len(_threshold) == 1:
                _threshold = _threshold * nbr_metrics
            else:
                raise ValueError("Length of lower_is_better must be the same length as number of metrics")

        return _threshold


class ModelChallenger:
    def __init__(self, challenger: ModelCard):
        """
        Instantiates ModelChallenger class

        Args:
            challenger:
                ModelCard of challenger

        """
        self._challenger = challenger
        self._challenger_metric: Optional[Metric] = None
        self._registries = CardRegistries()

    @property
    def challenger_metric(self) -> Metric:
        if self._challenger_metric is not None:
            return self._challenger_metric
        raise ValueError("Challenger metric not set")

    @challenger_metric.setter
    def challenger_metric(self, metric: Metric) -> None:
        self._challenger_metric = metric

    def _get_last_champion_record(self) -> Optional[Dict[str, Any]]:
        """Gets the previous champion record"""

        champion_records = self._registries.model.list_cards(
            name=self._challenger.name,
            repository=self._challenger.repository,
        )

        if not bool(champion_records):
            return None

        # indicates challenger has been registered
        if self._challenger.version is not None and len(champion_records) > 1:
            return champion_records[1]

        # account for cases where challenger is only model in registry
        champion_record = champion_records[0]
        if champion_record.get("version") == self._challenger.version:
            return None

        return champion_record

    def _get_runcard_metric(self, runcard_uid: str, metric_name: str) -> Metric:
        """
        Loads a RunCard from uid

        Args:
            runcard_uid:
                RunCard uid
            metric_name:
                Name of metric

        """
        runcard = cast(RunCard, self._registries.run.load_card(uid=runcard_uid))
        metric = runcard.get_metric(name=metric_name)

        if isinstance(metric, list):
            metric = metric[0]

        return metric

    def _battle(self, champion: CardInfo, champion_metric: Metric, lower_is_better: bool) -> BattleReport:
        """
        Runs a battle between champion and current challenger

        Args:
            champion:
                Champion record
            champion_metric:
                Champion metric from a runcard
            lower_is_better:
                Whether lower metric is preferred

        Returns:
            `BattleReport`

        """
        if lower_is_better:
            challenger_win = self.challenger_metric.value < champion_metric.value
        else:
            challenger_win = self.challenger_metric.value > champion_metric.value
        return BattleReport.model_construct(
            champion_name=str(champion.name),
            champion_version=str(champion.version),
            champion_metric=champion_metric,
            challenger_metric=self.challenger_metric.model_copy(deep=True),
            challenger_win=challenger_win,
        )

    def _battle_last_model_version(self, metric_name: str, lower_is_better: bool) -> BattleReport:
        """Compares the last champion model to the current challenger"""

        champion_record = self._get_last_champion_record()

        if champion_record is None:
            logger.info("No previous model found. Challenger wins")

            return BattleReport(
                champion_name="No model",
                champion_version="No version",
                challenger_win=True,
            )

        runcard_id = champion_record.get("runcard_uid")
        if runcard_id is None:
            raise ValueError(f"No RunCard is associated with champion: {champion_record}")

        champion_metric = self._get_runcard_metric(runcard_uid=runcard_id, metric_name=metric_name)

        return self._battle(
            champion=CardInfo(
                name=champion_record.get("name"),
                version=champion_record.get("version"),
            ),
            champion_metric=champion_metric,
            lower_is_better=lower_is_better,
        )

    def _battle_champions(
        self,
        champions: List[CardInfo],
        metric_name: str,
        lower_is_better: bool,
    ) -> List[BattleReport]:
        """Loops through and creates a `BattleReport` for each champion"""
        battle_reports = []

        for champion in champions:
            champion_record = self._registries.model.list_cards(
                info=champion,
            )

            if not bool(champion_record):
                raise ValueError(f"Champion model does not exist. {champion}")

            champion_card = champion_record[0]
            runcard_uid = champion_card.get("runcard_uid")
            if runcard_uid is None:
                raise ValueError(f"No RunCard associated with champion: {champion}")

            champion_metric = self._get_runcard_metric(
                runcard_uid=runcard_uid,
                metric_name=metric_name,
            )

            # update name, repository and version in case of None
            champion.name = champion.name or champion_card.get("name")
            champion.repository = champion.repository or champion_card.get("repository")
            champion.version = champion.version or champion_card.get("version")

            battle_reports.append(
                self._battle(
                    champion=champion,
                    champion_metric=champion_metric,
                    lower_is_better=lower_is_better,
                )
            )
        return battle_reports

    def challenge_champion(
        self,
        metric_name: MetricName,
        metric_value: Optional[MetricValue] = None,
        champions: Optional[List[CardInfo]] = None,
        lower_is_better: Union[bool, List[bool]] = True,
    ) -> Dict[str, List[BattleReport]]:
        """
        Challenges n champion models against the challenger model. If no champion is provided,
        the latest model version is used as a champion.

        Args:
            champions:
                Optional list of champion CardInfo
            metric_name:
                Name of metric to evaluate
            metric_value:
                Challenger metric value
            lower_is_better:
                Whether a lower metric value is better or not

        Returns:
            `BattleReport`
        """

        # validate inputs
        inputs = ChallengeInputs(
            metric_name=metric_name,
            metric_value=metric_value,
            lower_is_better=lower_is_better,
        )

        report_dict = {}

        for name, value, _lower_is_better in zip(
            inputs.metric_names,
            inputs.metric_values,
            inputs.thresholds,
        ):
            # get challenger metric
            if value is None:
                if self._challenger.metadata.runcard_uid is not None:
                    self.challenger_metric = self._get_runcard_metric(
                        self._challenger.metadata.runcard_uid, metric_name=name
                    )
                else:
                    raise ValueError("Challenger and champions must be associated with a registered RunCard")
            else:
                self.challenger_metric = Metric(name=name, value=value)

            if champions is None:
                report_dict[name] = [
                    self._battle_last_model_version(
                        metric_name=name,
                        lower_is_better=_lower_is_better,
                    )
                ]

            else:
                report_dict[name] = self._battle_champions(
                    champions=champions,
                    metric_name=name,
                    lower_is_better=_lower_is_better,
                )

        return report_dict
logger = <builtins.Logger object>
class BattleReport(pydantic.main.BaseModel):
model_fields = {
    'champion_name': FieldInfo(annotation=str, required=True),
    'champion_version': FieldInfo(annotation=str, required=True),
    'champion_metric': FieldInfo(annotation=Union[Metric, NoneType], required=False),
    'challenger_metric': FieldInfo(annotation=Union[Metric, NoneType], required=False),
    'challenger_win': FieldInfo(annotation=bool, required=True),
}
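A BattleReport summarizes one champion-versus-challenger comparison. The minimal sketch below constructs a report by hand purely to show the fields; in normal use these objects are returned by ModelChallenger.challenge_champion, and the names and metric values here are made up.

from opsml.model.challenger import BattleReport
from opsml.types import Metric

# Illustrative values only; reports are normally produced by ModelChallenger.
report = BattleReport(
    champion_name="my-model",
    champion_version="1.2.0",
    champion_metric=Metric(name="mae", value=2.5),
    challenger_metric=Metric(name="mae", value=2.1),
    challenger_win=True,
)

if report.challenger_win:
    print(f"Challenger beat {report.champion_name} v{report.champion_version}")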
MetricName = typing.Union[str, typing.List[str]]
MetricValue = typing.Union[int, float, typing.List[typing.Union[int, float]]]
class ChallengeInputs(pydantic.main.BaseModel):
@field_validator('metric_name')
@classmethod
def convert_name(cls, name: Union[List[str], str]) -> List[str]:
@field_validator('metric_value')
@classmethod
def convert_value(
    cls,
    value: Union[int, float, List[Union[int, float]], NoneType],
    info: pydantic_core.core_schema.ValidationInfo,
) -> List[Any]:
@field_validator('lower_is_better')
@classmethod
def convert_threshold(
    cls,
    threshold: Union[bool, List[bool]],
    info: pydantic_core.core_schema.ValidationInfo,
) -> List[bool]:
model_fields = {
    'metric_name': FieldInfo(annotation=Union[str, List[str]], required=True),
    'metric_value': FieldInfo(annotation=Union[int, float, List[Union[int, float]], NoneType], required=False),
    'lower_is_better': FieldInfo(annotation=Union[bool, List[bool]], required=False, default=True),
}
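ChallengeInputs is the internal validation model used by challenge_champion: its field validators broadcast scalar inputs to per-metric lists and reject mismatched lengths. A short sketch of that behaviour with arbitrary metric names:

from opsml.model.challenger import ChallengeInputs

# A single bool for lower_is_better is broadcast to every metric, and an
# explicit metric_value of None becomes a list of Nones of matching length.
inputs = ChallengeInputs(
    metric_name=["mae", "mape"],
    metric_value=None,
    lower_is_better=True,
)

print(inputs.metric_names)   # ['mae', 'mape']
print(inputs.metric_values)  # [None, None]
print(inputs.thresholds)     # [True, True]

# Mismatched lengths fail validation, wrapping the ValueError raised above:
# ChallengeInputs(metric_name=["mae", "mape"], metric_value=[1.0])  # raises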
class ModelChallenger:
ModelChallenger(challenger: opsml.cards.model.ModelCard)
Instantiates the ModelChallenger class.
Arguments:
- challenger: ModelCard of challenger
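A minimal construction sketch. challenger_card below is a hypothetical placeholder for a ModelCard you have already created or loaded; it should be linked to a RunCard (metadata.runcard_uid set) if you want challenger metrics pulled from the run automatically.

from opsml.model.challenger import ModelChallenger

# `challenger_card` is assumed to be an existing ModelCard (placeholder here).
challenger = ModelChallenger(challenger=challenger_card)

# challenger.challenger_metric is populated during challenge_champion();
# reading it before any battle raises ValueError("Challenger metric not set").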
def challenge_champion(
    self,
    metric_name: Union[str, List[str]],
    metric_value: Union[int, float, List[Union[int, float]], NoneType] = None,
    champions: Optional[List[opsml.types.card.CardInfo]] = None,
    lower_is_better: Union[bool, List[bool]] = True,
) -> Dict[str, List[BattleReport]]:
Challenges n champion models against the challenger model. If no champion is provided, the latest model version is used as a champion.
Arguments:
- champions: Optional list of champion CardInfo
- metric_name: Name of metric to evaluate
- metric_value: Challenger metric value
- lower_is_better: Whether a lower metric value is better or not
Returns:
- Dictionary mapping each metric name to its list of BattleReports
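An end-to-end sketch, assuming challenger_card is a hypothetical placeholder for a registered ModelCard whose RunCard logged an "mae" metric, and that the named champion exists in the model registry (names and versions are illustrative):

from opsml.model.challenger import ModelChallenger
from opsml.types import CardInfo

challenger = ModelChallenger(challenger=challenger_card)

# Compare against an explicit champion; omit `champions` to battle the
# previous registered version of the same model instead.
reports = challenger.challenge_champion(
    metric_name="mae",
    lower_is_better=True,
    champions=[CardInfo(name="my-model", version="1.2.0")],
)

for report in reports["mae"]:
    outcome = "challenger wins" if report.challenger_win else "champion holds"
    print(report.champion_name, report.champion_version, outcome)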