opsml.data.splitter
# pylint: disable=invalid-name
# Copyright (c) Shipt, Inc.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""Split pandas / polars / pyarrow / numpy data into labeled subsets.

A ``DataSplit`` describes one subset: a column filter, a row slice, or an
explicit list of row indices. ``DataSplitter.split`` picks the first
``DataSplitterBase`` subclass whose ``validate`` accepts the data type and
split definition, and lets it build the subset.
"""
import operator
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import pandas as pd
import polars as pl
import pyarrow as pa
from numpy.typing import NDArray
from pydantic import BaseModel, ConfigDict, field_validator

from opsml.types import AllowedDataType

# Maps ``DataSplit.inequality`` to the comparison applied between the split
# column and ``DataSplit.column_value``. ``None`` (no inequality given) means
# equality, which was the original default. Python's comparison operators work
# on both pandas Series (boolean mask) and polars expressions (filter expr).
_COMPARISON_OPERATORS: Dict[Optional[str], Callable[[Any, Any], Any]] = {
    None: operator.eq,
    "==": operator.eq,
    "!=": operator.ne,
    ">": operator.gt,
    ">=": operator.ge,
    "<": operator.lt,
    "<=": operator.le,
}


@dataclass
class Data:
    """Result of a split."""

    # Feature data; the whole split when no dependent vars are separated out.
    X: Any
    # Dependent-variable columns; left as None when dependent_vars is empty and
    # always None for the pyarrow and numpy splitters.
    y: Optional[Any] = None


class DataSplit(BaseModel):
    """Definition of one labeled data split.

    Which fields are used depends on the splitter chosen by ``DataSplitter``:

    * column split: ``column_name`` and ``column_value`` (optional ``inequality``)
    * row-slice split: ``start`` and ``stop`` (selected when ``start`` is set)
    * index split: ``indices``
    """

    model_config = ConfigDict(arbitrary_types_allowed=True)

    label: str
    column_name: Optional[str] = None
    column_value: Optional[Union[str, float, int, pd.Timestamp]] = None
    # One of "==", "!=", ">", ">=", "<", "<="; None means equality.
    inequality: Optional[str] = None
    start: Optional[int] = None
    stop: Optional[int] = None
    indices: Optional[List[int]] = None

    @field_validator("indices", mode="before")
    @classmethod
    def convert_to_list(cls, value: Optional[Any]) -> Optional[Any]:
        """Convert non-list iterables of indices (e.g. a range) to a list before validation."""

        if value is not None and not isinstance(value, list):
            value = list(value)

        return value

    @field_validator("inequality", mode="before")
    @classmethod
    def trim_whitespace(cls, value: Optional[Any]) -> Optional[Any]:
        """Strip surrounding whitespace from the inequality sign (e.g. ``" >= "``).

        Non-string input is passed through untouched so pydantic's own ``str``
        validation reports it, instead of failing here with an AttributeError.
        """

        if isinstance(value, str):
            value = value.strip()

        return value


class DataSplitterBase:
    """Base class for splitters.

    Subclasses implement ``create_split`` (build the subset) and the static
    ``validate`` (decide whether they handle a data type / split combination).
    """

    def __init__(
        self,
        split: DataSplit,
        dependent_vars: List[Union[int, str]],
    ):
        self.split = split
        self.dependent_vars = dependent_vars

    @property
    def column_name(self) -> str:
        if self.split.column_name is not None:
            return self.split.column_name

        raise ValueError("Column name was not provided")

    @property
    def column_value(self) -> Any:
        if self.split.column_value is not None:
            return self.split.column_value

        raise ValueError("Column value was not provided")

    @property
    def indices(self) -> List[int]:
        if self.split.indices is not None:
            return self.split.indices
        raise ValueError("List of indices was not provided")

    @property
    def start(self) -> int:
        if self.split.start is not None:
            return self.split.start
        raise ValueError("Start index was not provided")

    @property
    def stop(self) -> int:
        if self.split.stop is not None:
            return self.split.stop
        raise ValueError("Stop index was not provided")

    def get_x_cols(self, columns: List[str], dependent_vars: List[Union[str, int]]) -> List[str]:
        """Return ``columns`` without the string entries of ``dependent_vars``.

        Integer entries of ``dependent_vars`` are ignored. A new list is
        returned; ``columns`` itself is not modified.
        """
        excluded = {var for var in dependent_vars if isinstance(var, str)}
        return [column for column in columns if column not in excluded]

    def _compare(self, column: Any) -> Any:
        """Apply the split's inequality between ``column`` and ``column_value``.

        ``column`` is a pandas Series or a polars expression; the result is the
        corresponding boolean mask / filter expression.

        Raises:
            ValueError: if ``inequality`` is not a supported operator, or
                ``column_value`` was not provided.
        """
        comparison = _COMPARISON_OPERATORS.get(self.split.inequality)
        if comparison is None:
            supported = sorted(op for op in _COMPARISON_OPERATORS if op is not None)
            raise ValueError(f"Unsupported inequality {self.split.inequality!r}; expected one of {supported}")
        return comparison(column, self.column_value)

    def _split_polars(self, data: pl.DataFrame) -> Tuple[str, Data]:
        """Wrap a polars frame as ``Data``, moving ``dependent_vars`` into ``y`` when given."""
        if not self.dependent_vars:
            return self.split.label, Data(X=data)

        x_cols = self.get_x_cols(columns=data.columns, dependent_vars=self.dependent_vars)
        return self.split.label, Data(
            X=data.select(x_cols),
            y=data.select(self.dependent_vars),
        )

    def _split_pandas(self, data: pd.DataFrame) -> Tuple[str, Data]:
        """Wrap a pandas frame as ``Data``, moving ``dependent_vars`` into ``y`` when given."""
        if not self.dependent_vars:
            return self.split.label, Data(X=data)

        is_dependent = data.columns.isin(self.dependent_vars)
        return self.split.label, Data(
            X=data[data.columns[~is_dependent]],
            y=data[data.columns[is_dependent]],
        )

    def create_split(self, data: Any) -> Tuple[str, Data]:
        raise NotImplementedError

    @staticmethod
    def validate(data_type: str, split: DataSplit) -> bool:
        raise NotImplementedError


class PolarsColumnSplitter(DataSplitterBase):
    """Column splitter for Polars dataframe"""

    def create_split(self, data: pl.DataFrame) -> Tuple[str, Data]:
        """Keep the rows where ``column_name <inequality> column_value`` holds."""
        filtered = data.filter(self._compare(pl.col(self.column_name)))
        return self._split_polars(filtered)

    @staticmethod
    def validate(data_type: str, split: DataSplit) -> bool:
        return data_type == AllowedDataType.POLARS and split.column_name is not None


class PolarsIndexSplitter(DataSplitterBase):
    """Split Polars DataFrame by rows index"""

    def create_split(self, data: pl.DataFrame) -> Tuple[str, Data]:
        return self._split_polars(data[self.indices])

    @staticmethod
    def validate(data_type: str, split: DataSplit) -> bool:
        return data_type == AllowedDataType.POLARS and split.indices is not None


class PolarsRowsSplitter(DataSplitterBase):
    """Split Polars DataFrame by rows slice"""

    def create_split(self, data: pl.DataFrame) -> Tuple[str, Data]:
        return self._split_polars(data[self.start : self.stop])

    @staticmethod
    def validate(data_type: str, split: DataSplit) -> bool:
        return data_type == AllowedDataType.POLARS and split.start is not None


class PandasIndexSplitter(DataSplitterBase):
    """Split a pandas DataFrame by positional row indices."""

    def create_split(self, data: pd.DataFrame) -> Tuple[str, Data]:
        return self._split_pandas(data.iloc[self.indices])

    @staticmethod
    def validate(data_type: str, split: DataSplit) -> bool:
        return data_type == AllowedDataType.PANDAS and split.indices is not None


class PandasRowSplitter(DataSplitterBase):
    """Split a pandas DataFrame by a positional ``start:stop`` row slice."""

    def create_split(self, data: pd.DataFrame) -> Tuple[str, Data]:
        # .iloc keeps the slice positional for every index type; plain df[a:b]
        # is label-based on float indexes.
        return self._split_pandas(data.iloc[self.start : self.stop])

    @staticmethod
    def validate(data_type: str, split: DataSplit) -> bool:
        return data_type == AllowedDataType.PANDAS and split.start is not None


class PandasColumnSplitter(DataSplitterBase):
    """Split a pandas DataFrame by comparing one column against a value."""

    def create_split(self, data: pd.DataFrame) -> Tuple[str, Data]:
        """Keep the rows where ``column_name <inequality> column_value`` holds."""
        return self._split_pandas(data[self._compare(data[self.column_name])])

    @staticmethod
    def validate(data_type: str, split: DataSplit) -> bool:
        return data_type == AllowedDataType.PANDAS and split.column_name is not None


class PyArrowIndexSplitter(DataSplitterBase):
    """Split a pyarrow Table by row indices (no X/y separation)."""

    def create_split(self, data: pa.Table) -> Tuple[str, Data]:
        return self.split.label, Data(X=data.take(self.indices))

    @staticmethod
    def validate(data_type: str, split: DataSplit) -> bool:
        return data_type == AllowedDataType.PYARROW and split.indices is not None


class NumpyIndexSplitter(DataSplitterBase):
    """Split a numpy array by row indices (no X/y separation)."""

    def create_split(self, data: NDArray[Any]) -> Tuple[str, Data]:
        return self.split.label, Data(X=data[self.indices])

    @staticmethod
    def validate(data_type: str, split: DataSplit) -> bool:
        return data_type == AllowedDataType.NUMPY and split.indices is not None


class NumpyRowSplitter(DataSplitterBase):
    """Split a numpy array by a ``start:stop`` row slice (no X/y separation)."""

    def create_split(self, data: NDArray[Any]) -> Tuple[str, Data]:
        data_split = data[self.start : self.stop]
        return self.split.label, Data(X=data_split)

    @staticmethod
    def validate(data_type: str, split: DataSplit) -> bool:
        return data_type == AllowedDataType.NUMPY and split.start is not None


class DataSplitter:
    """Dispatch a ``DataSplit`` to the splitter that supports it."""

    @staticmethod
    def split(
        split: DataSplit,
        data: Union[pd.DataFrame, NDArray[Any], pl.DataFrame, pa.Table],
        data_type: str,
        dependent_vars: List[Union[int, str]],
    ) -> Tuple[str, Data]:
        """Split ``data`` according to ``split``.

        The first direct subclass of ``DataSplitterBase`` (in definition order)
        whose ``validate`` accepts ``data_type`` and ``split`` builds the subset.

        Args:
            split: Split definition.
            data: Data to split.
            data_type: Compared against ``AllowedDataType`` members by each splitter.
            dependent_vars: Columns moved into ``Data.y`` (pandas / polars splitters only).

        Returns:
            Tuple of the split label and the resulting ``Data``.

        Raises:
            ValueError: if no splitter supports the combination.
        """
        data_splitter = next(
            (
                splitter
                for splitter in DataSplitterBase.__subclasses__()
                if splitter.validate(data_type=data_type, split=split)
            ),
            None,
        )

        if data_splitter is None:
            raise ValueError("Failed to find data supporter that supports provided logic")

        return data_splitter(split=split, dependent_vars=dependent_vars).create_split(data=data)
@dataclass
class
Data:
class
DataSplit(pydantic.main.BaseModel):
24class DataSplit(BaseModel): 25 model_config = ConfigDict(arbitrary_types_allowed=True) 26 27 label: str 28 column_name: Optional[str] = None 29 column_value: Optional[Union[str, float, int, pd.Timestamp]] = None 30 inequality: Optional[str] = None 31 start: Optional[int] = None 32 stop: Optional[int] = None 33 indices: Optional[List[int]] = None 34 35 @field_validator("indices", mode="before") 36 @classmethod 37 def convert_to_list(cls, value: Optional[List[int]]) -> Optional[List[int]]: 38 """Pre to convert indices to list if not None""" 39 40 if value is not None and not isinstance(value, list): 41 value = list(value) 42 43 return value 44 45 @field_validator("inequality", mode="before") 46 @classmethod 47 def trim_whitespace(cls, value: str) -> str: 48 """Trims whitespace from inequality signs""" 49 50 if value is not None: 51 value = value.strip() 52 53 return value
Usage docs: https://docs.pydantic.dev/2.6/concepts/models/
A base class for creating Pydantic models.
Attributes:
- __class_vars__: The names of classvars defined on the model.
- __private_attributes__: Metadata about the private attributes of the model.
- __signature__: The signature for instantiating the model.
- __pydantic_complete__: Whether model building is completed, or if there are still undefined fields.
- __pydantic_core_schema__: The pydantic-core schema used to build the SchemaValidator and SchemaSerializer.
- __pydantic_custom_init__: Whether the model has a custom `__init__` function.
- __pydantic_decorators__: Metadata containing the decorators defined on the model. This replaces `Model.__validators__` and `Model.__root_validators__` from Pydantic V1.
- __pydantic_generic_metadata__: Metadata for generic models; contains data used for a similar purpose to `__args__`, `__origin__`, `__parameters__` in typing-module generics. May eventually be replaced by these.
- __pydantic_parent_namespace__: Parent namespace of the model, used for automatic rebuilding of models.
- __pydantic_post_init__: The name of the post-init method for the model, if defined.
- __pydantic_root_model__: Whether the model is a `RootModel`.
- __pydantic_serializer__: The pydantic-core SchemaSerializer used to dump instances of the model.
- __pydantic_validator__: The pydantic-core SchemaValidator used to validate instances of the model.
- __pydantic_extra__: An instance attribute with the values of extra fields from validation when `model_config['extra'] == 'allow'`.
- __pydantic_fields_set__: An instance attribute with the names of fields explicitly set.
- __pydantic_private__: Instance attribute with the values of private attributes set on the model instance.
@field_validator('indices', mode='before')
@classmethod
def
convert_to_list(cls, value: Optional[List[int]]) -> Optional[List[int]]:
35 @field_validator("indices", mode="before") 36 @classmethod 37 def convert_to_list(cls, value: Optional[List[int]]) -> Optional[List[int]]: 38 """Pre to convert indices to list if not None""" 39 40 if value is not None and not isinstance(value, list): 41 value = list(value) 42 43 return value
Before validation, convert `indices` to a list if it is provided as some other iterable (None is left unchanged).
@field_validator('inequality', mode='before')
@classmethod
def
trim_whitespace(cls, value: str) -> str:
45 @field_validator("inequality", mode="before") 46 @classmethod 47 def trim_whitespace(cls, value: str) -> str: 48 """Trims whitespace from inequality signs""" 49 50 if value is not None: 51 value = value.strip() 52 53 return value
Trims whitespace from inequality signs
model_fields =
{'label': FieldInfo(annotation=str, required=True), 'column_name': FieldInfo(annotation=Union[str, NoneType], required=False), 'column_value': FieldInfo(annotation=Union[str, float, int, Timestamp, NoneType], required=False), 'inequality': FieldInfo(annotation=Union[str, NoneType], required=False), 'start': FieldInfo(annotation=Union[int, NoneType], required=False), 'stop': FieldInfo(annotation=Union[int, NoneType], required=False), 'indices': FieldInfo(annotation=Union[List[int], NoneType], required=False)}
Inherited Members
- pydantic.main.BaseModel
- BaseModel
- model_extra
- model_fields_set
- model_construct
- model_copy
- model_dump
- model_dump_json
- model_json_schema
- model_parametrized_name
- model_post_init
- model_rebuild
- model_validate
- model_validate_json
- model_validate_strings
- dict
- json
- parse_obj
- parse_raw
- parse_file
- from_orm
- construct
- copy
- schema
- schema_json
- validate
- update_forward_refs
class
DataSplitterBase:
56class DataSplitterBase: 57 def __init__( 58 self, 59 split: DataSplit, 60 dependent_vars: List[Union[int, str]], 61 ): 62 self.split = split 63 self.dependent_vars = dependent_vars 64 65 @property 66 def column_name(self) -> str: 67 if self.split.column_name is not None: 68 return self.split.column_name 69 70 raise ValueError("Column name was not provided") 71 72 @property 73 def column_value(self) -> Any: 74 if self.split.column_value is not None: 75 return self.split.column_value 76 77 raise ValueError("Column value was not provided") 78 79 @property 80 def indices(self) -> List[int]: 81 if self.split.indices is not None: 82 return self.split.indices 83 raise ValueError("List of indices was not provided") 84 85 @property 86 def start(self) -> int: 87 if self.split.start is not None: 88 return self.split.start 89 raise ValueError("Start index was not provided") 90 91 @property 92 def stop(self) -> int: 93 if self.split.stop is not None: 94 return self.split.stop 95 raise ValueError("Stop index was not provided") 96 97 def get_x_cols(self, columns: List[str], dependent_vars: List[Union[str, int]]) -> List[str]: 98 for var in dependent_vars: 99 if isinstance(var, str): 100 columns.remove(var) 101 102 return columns 103 104 def create_split(self, data: Any) -> Tuple[str, Data]: 105 raise NotImplementedError 106 107 @staticmethod 108 def validate(data_type: str, split: DataSplit) -> bool: 109 raise NotImplementedError
DataSplitterBase( split: DataSplit, dependent_vars: List[Union[int, str]])
112class PolarsColumnSplitter(DataSplitterBase): 113 """Column splitter for Polars dataframe""" 114 115 def create_split(self, data: pl.DataFrame) -> Tuple[str, Data]: 116 if self.split.inequality is None: 117 data = data.filter(pl.col(self.column_name) == self.column_value) 118 119 elif self.split.inequality == ">": 120 data = data.filter(pl.col(self.column_name) > self.column_value) 121 122 elif self.split.inequality == ">=": 123 data = data.filter(pl.col(self.column_name) >= self.column_value) 124 125 elif self.split.inequality == "<": 126 data = data.filter(pl.col(self.column_name) < self.column_value) 127 128 else: 129 data = data.filter(pl.col(self.column_name) <= self.column_value) 130 131 if bool(self.dependent_vars): 132 x_cols = self.get_x_cols(columns=data.columns, dependent_vars=self.dependent_vars) 133 134 return self.split.label, Data( 135 X=data.select(x_cols), 136 y=data.select(self.dependent_vars), 137 ) 138 139 return self.split.label, Data(X=data) 140 141 @staticmethod 142 def validate(data_type: str, split: DataSplit) -> bool: 143 return data_type == AllowedDataType.POLARS and split.column_name is not None
Column splitter for Polars dataframe
115 def create_split(self, data: pl.DataFrame) -> Tuple[str, Data]: 116 if self.split.inequality is None: 117 data = data.filter(pl.col(self.column_name) == self.column_value) 118 119 elif self.split.inequality == ">": 120 data = data.filter(pl.col(self.column_name) > self.column_value) 121 122 elif self.split.inequality == ">=": 123 data = data.filter(pl.col(self.column_name) >= self.column_value) 124 125 elif self.split.inequality == "<": 126 data = data.filter(pl.col(self.column_name) < self.column_value) 127 128 else: 129 data = data.filter(pl.col(self.column_name) <= self.column_value) 130 131 if bool(self.dependent_vars): 132 x_cols = self.get_x_cols(columns=data.columns, dependent_vars=self.dependent_vars) 133 134 return self.split.label, Data( 135 X=data.select(x_cols), 136 y=data.select(self.dependent_vars), 137 ) 138 139 return self.split.label, Data(X=data)
Inherited Members
146class PolarsIndexSplitter(DataSplitterBase): 147 """Split Polars DataFrame by rows index""" 148 149 def create_split(self, data: pl.DataFrame) -> Tuple[str, Data]: 150 # slice 151 data = data[self.indices] 152 153 if bool(self.dependent_vars): 154 x_cols = self.get_x_cols(columns=data.columns, dependent_vars=self.dependent_vars) 155 156 return self.split.label, Data( 157 X=data.select(x_cols), 158 y=data.select(self.dependent_vars), 159 ) 160 161 return self.split.label, Data(X=data) 162 163 @staticmethod 164 def validate(data_type: str, split: DataSplit) -> bool: 165 return data_type == AllowedDataType.POLARS and split.indices is not None
Split Polars DataFrame by rows index
149 def create_split(self, data: pl.DataFrame) -> Tuple[str, Data]: 150 # slice 151 data = data[self.indices] 152 153 if bool(self.dependent_vars): 154 x_cols = self.get_x_cols(columns=data.columns, dependent_vars=self.dependent_vars) 155 156 return self.split.label, Data( 157 X=data.select(x_cols), 158 y=data.select(self.dependent_vars), 159 ) 160 161 return self.split.label, Data(X=data)
Inherited Members
168class PolarsRowsSplitter(DataSplitterBase): 169 """Split Polars DataFrame by rows slice""" 170 171 def create_split(self, data: pl.DataFrame) -> Tuple[str, Data]: 172 # slice 173 data = data[self.start : self.stop] 174 175 if bool(self.dependent_vars): 176 x_cols = self.get_x_cols(columns=data.columns, dependent_vars=self.dependent_vars) 177 178 return self.split.label, Data( 179 X=data.select(x_cols), 180 y=data.select(self.dependent_vars), 181 ) 182 183 return self.split.label, Data(X=data) 184 185 @staticmethod 186 def validate(data_type: str, split: DataSplit) -> bool: 187 return data_type == AllowedDataType.POLARS and split.start is not None
Split Polars DataFrame by rows slice
171 def create_split(self, data: pl.DataFrame) -> Tuple[str, Data]: 172 # slice 173 data = data[self.start : self.stop] 174 175 if bool(self.dependent_vars): 176 x_cols = self.get_x_cols(columns=data.columns, dependent_vars=self.dependent_vars) 177 178 return self.split.label, Data( 179 X=data.select(x_cols), 180 y=data.select(self.dependent_vars), 181 ) 182 183 return self.split.label, Data(X=data)
Inherited Members
190class PandasIndexSplitter(DataSplitterBase): 191 def create_split(self, data: pd.DataFrame) -> Tuple[str, Data]: 192 data = data.iloc[self.indices] 193 194 if bool(self.dependent_vars): 195 x = data[data.columns[~data.columns.isin(self.dependent_vars)]] 196 y = data[data.columns[data.columns.isin(self.dependent_vars)]] 197 198 return self.split.label, Data(X=x, y=y) 199 200 return self.split.label, Data(X=data) 201 202 @staticmethod 203 def validate(data_type: str, split: DataSplit) -> bool: 204 return data_type == AllowedDataType.PANDAS and split.indices is not None
191 def create_split(self, data: pd.DataFrame) -> Tuple[str, Data]: 192 data = data.iloc[self.indices] 193 194 if bool(self.dependent_vars): 195 x = data[data.columns[~data.columns.isin(self.dependent_vars)]] 196 y = data[data.columns[data.columns.isin(self.dependent_vars)]] 197 198 return self.split.label, Data(X=x, y=y) 199 200 return self.split.label, Data(X=data)
Inherited Members
207class PandasRowSplitter(DataSplitterBase): 208 def create_split(self, data: pd.DataFrame) -> Tuple[str, Data]: 209 # slice 210 data = data[self.start : self.stop] 211 212 if bool(self.dependent_vars): 213 x = data[data.columns[~data.columns.isin(self.dependent_vars)]] 214 y = data[data.columns[data.columns.isin(self.dependent_vars)]] 215 216 return self.split.label, Data(X=x, y=y) 217 218 return self.split.label, Data(X=data) 219 220 @staticmethod 221 def validate(data_type: str, split: DataSplit) -> bool: 222 return data_type == AllowedDataType.PANDAS and split.start is not None
208 def create_split(self, data: pd.DataFrame) -> Tuple[str, Data]: 209 # slice 210 data = data[self.start : self.stop] 211 212 if bool(self.dependent_vars): 213 x = data[data.columns[~data.columns.isin(self.dependent_vars)]] 214 y = data[data.columns[data.columns.isin(self.dependent_vars)]] 215 216 return self.split.label, Data(X=x, y=y) 217 218 return self.split.label, Data(X=data)
Inherited Members
225class PandasColumnSplitter(DataSplitterBase): 226 def create_split(self, data: pd.DataFrame) -> Tuple[str, Data]: 227 if self.split.inequality is None: 228 data = data[data[self.column_name] == self.column_value] 229 230 elif self.split.inequality == ">": 231 data = data[data[self.column_name] > self.column_value] 232 233 elif self.split.inequality == ">=": 234 data = data[data[self.column_name] >= self.column_value] 235 236 elif self.split.inequality == "<": 237 data = data[data[self.column_name] < self.column_value] 238 239 else: 240 data = data[data[self.column_name] <= self.column_value] 241 242 if bool(self.dependent_vars): 243 return self.split.label, Data( 244 X=data[data.columns[~data.columns.isin(self.dependent_vars)]], 245 y=data[data.columns[data.columns.isin(self.dependent_vars)]], 246 ) 247 248 data_split = Data(X=data) 249 return self.split.label, data_split 250 251 @staticmethod 252 def validate(data_type: str, split: DataSplit) -> bool: 253 return data_type == AllowedDataType.PANDAS and split.column_name is not None
226 def create_split(self, data: pd.DataFrame) -> Tuple[str, Data]: 227 if self.split.inequality is None: 228 data = data[data[self.column_name] == self.column_value] 229 230 elif self.split.inequality == ">": 231 data = data[data[self.column_name] > self.column_value] 232 233 elif self.split.inequality == ">=": 234 data = data[data[self.column_name] >= self.column_value] 235 236 elif self.split.inequality == "<": 237 data = data[data[self.column_name] < self.column_value] 238 239 else: 240 data = data[data[self.column_name] <= self.column_value] 241 242 if bool(self.dependent_vars): 243 return self.split.label, Data( 244 X=data[data.columns[~data.columns.isin(self.dependent_vars)]], 245 y=data[data.columns[data.columns.isin(self.dependent_vars)]], 246 ) 247 248 data_split = Data(X=data) 249 return self.split.label, data_split
Inherited Members
256class PyArrowIndexSplitter(DataSplitterBase): 257 def create_split(self, data: pa.Table) -> Tuple[str, Data]: 258 return self.split.label, Data(X=data.take(self.indices)) 259 260 @staticmethod 261 def validate(data_type: str, split: DataSplit) -> bool: 262 return data_type == AllowedDataType.PYARROW and split.indices is not None
Inherited Members
265class NumpyIndexSplitter(DataSplitterBase): 266 def create_split(self, data: NDArray[Any]) -> Tuple[str, Data]: 267 return self.split.label, Data(X=data[self.indices]) 268 269 @staticmethod 270 def validate(data_type: str, split: DataSplit) -> bool: 271 return data_type == AllowedDataType.NUMPY and split.indices is not None
def
create_split( self, data: numpy.ndarray[typing.Any, numpy.dtype[typing.Any]]) -> Tuple[str, Data]:
Inherited Members
274class NumpyRowSplitter(DataSplitterBase): 275 def create_split(self, data: NDArray[Any]) -> Tuple[str, Data]: 276 data_split = data[self.start : self.stop] 277 return self.split.label, Data(X=data_split) 278 279 @staticmethod 280 def validate(data_type: str, split: DataSplit) -> bool: 281 return data_type == AllowedDataType.NUMPY and split.start is not None
def
create_split( self, data: numpy.ndarray[typing.Any, numpy.dtype[typing.Any]]) -> Tuple[str, Data]:
Inherited Members
class
DataSplitter:
284class DataSplitter: 285 @staticmethod 286 def split( 287 split: DataSplit, 288 data: Union[pd.DataFrame, NDArray[Any], pl.DataFrame], 289 data_type: str, 290 dependent_vars: List[Union[int, str]], 291 ) -> Tuple[str, Data]: 292 data_splitter = next( 293 ( 294 data_splitter 295 for data_splitter in DataSplitterBase.__subclasses__() 296 if data_splitter.validate( 297 data_type=data_type, 298 split=split, 299 ) 300 ), 301 None, 302 ) 303 304 if data_splitter is not None: 305 return data_splitter( 306 split=split, 307 dependent_vars=dependent_vars, 308 ).create_split(data=data) 309 310 raise ValueError("Failed to find data supporter that supports provided logic")
@staticmethod
def
split( split: DataSplit, data: Union[pandas.core.frame.DataFrame, numpy.ndarray[Any, numpy.dtype[Any]], polars.dataframe.frame.DataFrame], data_type: str, dependent_vars: List[Union[int, str]]) -> Tuple[str, Data]:
285 @staticmethod 286 def split( 287 split: DataSplit, 288 data: Union[pd.DataFrame, NDArray[Any], pl.DataFrame], 289 data_type: str, 290 dependent_vars: List[Union[int, str]], 291 ) -> Tuple[str, Data]: 292 data_splitter = next( 293 ( 294 data_splitter 295 for data_splitter in DataSplitterBase.__subclasses__() 296 if data_splitter.validate( 297 data_type=data_type, 298 split=split, 299 ) 300 ), 301 None, 302 ) 303 304 if data_splitter is not None: 305 return data_splitter( 306 split=split, 307 dependent_vars=dependent_vars, 308 ).create_split(data=data) 309 310 raise ValueError("Failed to find data supporter that supports provided logic")