opsml.data.splitter

  1# pylint: disable=invalid-name
  2# Copyright (c) Shipt, Inc.
  3# This source code is licensed under the MIT license found in the
  4# LICENSE file in the root directory of this source tree.
  5from dataclasses import dataclass
  6from typing import Any, List, Optional, Tuple, Union
  7
  8import pandas as pd
  9import polars as pl
 10import pyarrow as pa
 11from numpy.typing import NDArray
 12from pydantic import BaseModel, ConfigDict, field_validator
 13
 14from opsml.types import AllowedDataType
 15
 16
 17@dataclass
 18class Data:
 19    X: Any
 20    y: Optional[Any] = None
 21
 22
 23class DataSplit(BaseModel):
 24    model_config = ConfigDict(arbitrary_types_allowed=True)
 25
 26    label: str
 27    column_name: Optional[str] = None
 28    column_value: Optional[Union[str, float, int, pd.Timestamp]] = None
 29    inequality: Optional[str] = None
 30    start: Optional[int] = None
 31    stop: Optional[int] = None
 32    indices: Optional[List[int]] = None
 33
 34    @field_validator("indices", mode="before")
 35    @classmethod
 36    def convert_to_list(cls, value: Optional[List[int]]) -> Optional[List[int]]:
 37        """Pre to convert indices to list if not None"""
 38
 39        if value is not None and not isinstance(value, list):
 40            value = list(value)
 41
 42        return value
 43
 44    @field_validator("inequality", mode="before")
 45    @classmethod
 46    def trim_whitespace(cls, value: str) -> str:
 47        """Trims whitespace from inequality signs"""
 48
 49        if value is not None:
 50            value = value.strip()
 51
 52        return value
 53
 54
 55class DataSplitterBase:
 56    def __init__(
 57        self,
 58        split: DataSplit,
 59        dependent_vars: List[Union[int, str]],
 60    ):
 61        self.split = split
 62        self.dependent_vars = dependent_vars
 63
 64    @property
 65    def column_name(self) -> str:
 66        if self.split.column_name is not None:
 67            return self.split.column_name
 68
 69        raise ValueError("Column name was not provided")
 70
 71    @property
 72    def column_value(self) -> Any:
 73        if self.split.column_value is not None:
 74            return self.split.column_value
 75
 76        raise ValueError("Column value was not provided")
 77
 78    @property
 79    def indices(self) -> List[int]:
 80        if self.split.indices is not None:
 81            return self.split.indices
 82        raise ValueError("List of indices was not provided")
 83
 84    @property
 85    def start(self) -> int:
 86        if self.split.start is not None:
 87            return self.split.start
 88        raise ValueError("Start index was not provided")
 89
 90    @property
 91    def stop(self) -> int:
 92        if self.split.stop is not None:
 93            return self.split.stop
 94        raise ValueError("Stop index was not provided")
 95
 96    def get_x_cols(self, columns: List[str], dependent_vars: List[Union[str, int]]) -> List[str]:
 97        for var in dependent_vars:
 98            if isinstance(var, str):
 99                columns.remove(var)
100
101        return columns
102
103    def create_split(self, data: Any) -> Tuple[str, Data]:
104        raise NotImplementedError
105
106    @staticmethod
107    def validate(data_type: str, split: DataSplit) -> bool:
108        raise NotImplementedError
109
110
class PolarsColumnSplitter(DataSplitterBase):
    """Column splitter for a Polars DataFrame.

    Keeps the rows where ``column_name <inequality> column_value`` holds.
    A missing inequality (or ``"=="``) means equality; ``">"``, ``">="``, ``"<"``
    and ``"<="`` are also supported. Any other sign raises ValueError instead of
    being silently treated as ``"<="``.
    """

    def create_split(self, data: pl.DataFrame) -> Tuple[str, Data]:
        column = pl.col(self.column_name)
        value = self.column_value
        inequality = self.split.inequality

        if inequality is None or inequality == "==":
            condition = column == value
        elif inequality == ">":
            condition = column > value
        elif inequality == ">=":
            condition = column >= value
        elif inequality == "<":
            condition = column < value
        elif inequality == "<=":
            condition = column <= value
        else:
            raise ValueError(f"Unsupported inequality: {inequality!r}")

        data = data.filter(condition)

        if bool(self.dependent_vars):
            x_cols = self.get_x_cols(columns=data.columns, dependent_vars=self.dependent_vars)

            return self.split.label, Data(
                X=data.select(x_cols),
                y=data.select(self.dependent_vars),
            )

        return self.split.label, Data(X=data)

    @staticmethod
    def validate(data_type: str, split: DataSplit) -> bool:
        return data_type == AllowedDataType.POLARS and split.column_name is not None
143
144
class PolarsIndexSplitter(DataSplitterBase):
    """Split a Polars DataFrame by explicit row indices."""

    def create_split(self, data: pl.DataFrame) -> Tuple[str, Data]:
        subset = data[self.indices]
        label = self.split.label

        if not self.dependent_vars:
            return label, Data(X=subset)

        feature_cols = self.get_x_cols(columns=subset.columns, dependent_vars=self.dependent_vars)
        features = subset.select(feature_cols)
        targets = subset.select(self.dependent_vars)
        return label, Data(X=features, y=targets)

    @staticmethod
    def validate(data_type: str, split: DataSplit) -> bool:
        is_polars = data_type == AllowedDataType.POLARS
        return is_polars and split.indices is not None
165
166
class PolarsRowsSplitter(DataSplitterBase):
    """Split a Polars DataFrame by a row slice ``[start:stop]``.

    ``stop`` is optional. When it is None, the slice runs to the last row.
    """

    def create_split(self, data: pl.DataFrame) -> Tuple[str, Data]:
        # ``validate`` only requires ``start``, so read ``stop`` directly rather
        # than through the ``stop`` property (which raises when it is unset).
        data = data[self.start : self.split.stop]

        if bool(self.dependent_vars):
            x_cols = self.get_x_cols(columns=data.columns, dependent_vars=self.dependent_vars)

            return self.split.label, Data(
                X=data.select(x_cols),
                y=data.select(self.dependent_vars),
            )

        return self.split.label, Data(X=data)

    @staticmethod
    def validate(data_type: str, split: DataSplit) -> bool:
        return data_type == AllowedDataType.POLARS and split.start is not None
187
188
class PandasIndexSplitter(DataSplitterBase):
    """Split a pandas DataFrame by positional row indices (via ``iloc``)."""

    def create_split(self, data: pd.DataFrame) -> Tuple[str, Data]:
        rows = data.iloc[self.indices]
        label = self.split.label

        if not self.dependent_vars:
            return label, Data(X=rows)

        is_target = rows.columns.isin(self.dependent_vars)
        features = rows[rows.columns[~is_target]]
        targets = rows[rows.columns[is_target]]
        return label, Data(X=features, y=targets)

    @staticmethod
    def validate(data_type: str, split: DataSplit) -> bool:
        is_pandas = data_type == AllowedDataType.PANDAS
        return is_pandas and split.indices is not None
204
205
class PandasRowSplitter(DataSplitterBase):
    """Split a pandas DataFrame by a positional row slice ``[start:stop]``.

    ``stop`` is optional. When it is None, the slice runs to the last row.
    Slicing uses ``iloc``, so it is positional regardless of the index type,
    matching the Polars and NumPy row splitters.
    """

    def create_split(self, data: pd.DataFrame) -> Tuple[str, Data]:
        # ``validate`` only requires ``start``; read ``stop`` directly so a
        # start-only split does not hit the raising ``stop`` property.
        data = data.iloc[self.start : self.split.stop]

        if bool(self.dependent_vars):
            x = data[data.columns[~data.columns.isin(self.dependent_vars)]]
            y = data[data.columns[data.columns.isin(self.dependent_vars)]]

            return self.split.label, Data(X=x, y=y)

        return self.split.label, Data(X=data)

    @staticmethod
    def validate(data_type: str, split: DataSplit) -> bool:
        return data_type == AllowedDataType.PANDAS and split.start is not None
222
223
class PandasColumnSplitter(DataSplitterBase):
    """Column splitter for a pandas DataFrame.

    Keeps the rows where ``column_name <inequality> column_value`` holds.
    A missing inequality (or ``"=="``) means equality; ``">"``, ``">="``, ``"<"``
    and ``"<="`` are also supported. Any other sign raises ValueError instead of
    being silently treated as ``"<="``.
    """

    def create_split(self, data: pd.DataFrame) -> Tuple[str, Data]:
        column = data[self.column_name]
        value = self.column_value
        inequality = self.split.inequality

        if inequality is None or inequality == "==":
            mask = column == value
        elif inequality == ">":
            mask = column > value
        elif inequality == ">=":
            mask = column >= value
        elif inequality == "<":
            mask = column < value
        elif inequality == "<=":
            mask = column <= value
        else:
            raise ValueError(f"Unsupported inequality: {inequality!r}")

        data = data[mask]

        if bool(self.dependent_vars):
            return self.split.label, Data(
                X=data[data.columns[~data.columns.isin(self.dependent_vars)]],
                y=data[data.columns[data.columns.isin(self.dependent_vars)]],
            )

        return self.split.label, Data(X=data)

    @staticmethod
    def validate(data_type: str, split: DataSplit) -> bool:
        return data_type == AllowedDataType.PANDAS and split.column_name is not None
253
254
class PyArrowIndexSplitter(DataSplitterBase):
    """Split a pyarrow Table by row positions using ``Table.take``.

    Dependent variables are not separated: the whole selection is returned as ``X``.
    """

    def create_split(self, data: pa.Table) -> Tuple[str, Data]:
        label = self.split.label
        selected = data.take(self.indices)
        return label, Data(X=selected)

    @staticmethod
    def validate(data_type: str, split: DataSplit) -> bool:
        is_pyarrow = data_type == AllowedDataType.PYARROW
        return is_pyarrow and split.indices is not None
262
263
class NumpyIndexSplitter(DataSplitterBase):
    """Split a NumPy array by row positions (fancy indexing on the first axis).

    Dependent variables are not separated: the whole selection is returned as ``X``.
    """

    def create_split(self, data: NDArray[Any]) -> Tuple[str, Data]:
        label = self.split.label
        selected = data[self.indices]
        return label, Data(X=selected)

    @staticmethod
    def validate(data_type: str, split: DataSplit) -> bool:
        is_numpy = data_type == AllowedDataType.NUMPY
        return is_numpy and split.indices is not None
271
272
class NumpyRowSplitter(DataSplitterBase):
    """Split a NumPy array by a row slice ``[start:stop]`` on the first axis.

    ``stop`` is optional. When it is None, the slice runs to the last row.
    Dependent variables are not separated: the slice is returned as ``X``.
    """

    def create_split(self, data: NDArray[Any]) -> Tuple[str, Data]:
        # ``validate`` only requires ``start``; read ``stop`` directly so a
        # start-only split does not hit the raising ``stop`` property.
        data_split = data[self.start : self.split.stop]
        return self.split.label, Data(X=data_split)

    @staticmethod
    def validate(data_type: str, split: DataSplit) -> bool:
        return data_type == AllowedDataType.NUMPY and split.start is not None
281
282
class DataSplitter:
    """Entry point that routes a DataSplit to a matching splitter implementation."""

    @staticmethod
    def split(
        split: DataSplit,
        data: Union[pd.DataFrame, NDArray[Any], pl.DataFrame],
        data_type: str,
        dependent_vars: List[Union[int, str]],
    ) -> Tuple[str, Data]:
        """Apply ``split`` to ``data`` and return ``(label, Data)``.

        The direct subclasses of DataSplitterBase are tried in definition order.
        The first one whose ``validate`` accepts ``(data_type, split)`` is
        instantiated and its ``create_split`` result is returned.

        Raises:
            ValueError: If no splitter accepts the given data type and split.
        """
        for splitter_cls in DataSplitterBase.__subclasses__():
            if not splitter_cls.validate(data_type=data_type, split=split):
                continue

            splitter = splitter_cls(split=split, dependent_vars=dependent_vars)
            return splitter.create_split(data=data)

        raise ValueError("Failed to find data supporter that supports provided logic")
@dataclass
class Data:
18@dataclass
19class Data:
20    X: Any
21    y: Optional[Any] = None
Data(X: Any, y: Optional[Any] = None)
X: Any
y: Optional[Any] = None
class DataSplit(pydantic.main.BaseModel):
24class DataSplit(BaseModel):
25    model_config = ConfigDict(arbitrary_types_allowed=True)
26
27    label: str
28    column_name: Optional[str] = None
29    column_value: Optional[Union[str, float, int, pd.Timestamp]] = None
30    inequality: Optional[str] = None
31    start: Optional[int] = None
32    stop: Optional[int] = None
33    indices: Optional[List[int]] = None
34
35    @field_validator("indices", mode="before")
36    @classmethod
37    def convert_to_list(cls, value: Optional[List[int]]) -> Optional[List[int]]:
38        """Pre to convert indices to list if not None"""
39
40        if value is not None and not isinstance(value, list):
41            value = list(value)
42
43        return value
44
45    @field_validator("inequality", mode="before")
46    @classmethod
47    def trim_whitespace(cls, value: str) -> str:
48        """Trims whitespace from inequality signs"""
49
50        if value is not None:
51            value = value.strip()
52
53        return value

Usage docs: https://docs.pydantic.dev/2.6/concepts/models/

A base class for creating Pydantic models.

Attributes:
  • __class_vars__: The names of classvars defined on the model.
  • __private_attributes__: Metadata about the private attributes of the model.
  • __signature__: The signature for instantiating the model.
  • __pydantic_complete__: Whether model building is completed, or if there are still undefined fields.
  • __pydantic_core_schema__: The pydantic-core schema used to build the SchemaValidator and SchemaSerializer.
  • __pydantic_custom_init__: Whether the model has a custom __init__ function.
  • __pydantic_decorators__: Metadata containing the decorators defined on the model. This replaces Model.__validators__ and Model.__root_validators__ from Pydantic V1.
  • __pydantic_generic_metadata__: Metadata for generic models; contains data used for a similar purpose to __args__, __origin__, __parameters__ in typing-module generics. May eventually be replaced by these.
  • __pydantic_parent_namespace__: Parent namespace of the model, used for automatic rebuilding of models.
  • __pydantic_post_init__: The name of the post-init method for the model, if defined.
  • __pydantic_root_model__: Whether the model is a RootModel.
  • __pydantic_serializer__: The pydantic-core SchemaSerializer used to dump instances of the model.
  • __pydantic_validator__: The pydantic-core SchemaValidator used to validate instances of the model.
  • __pydantic_extra__: An instance attribute with the values of extra fields from validation when model_config['extra'] == 'allow'.
  • __pydantic_fields_set__: An instance attribute with the names of fields explicitly set.
  • __pydantic_private__: Instance attribute with the values of private attributes set on the model instance.
model_config = {'arbitrary_types_allowed': True}
label: str
column_name: Optional[str]
column_value: Union[str, float, int, pandas._libs.tslibs.timestamps.Timestamp, NoneType]
inequality: Optional[str]
start: Optional[int]
stop: Optional[int]
indices: Optional[List[int]]
@field_validator('indices', mode='before')
@classmethod
def convert_to_list(cls, value: Optional[List[int]]) -> Optional[List[int]]:
35    @field_validator("indices", mode="before")
36    @classmethod
37    def convert_to_list(cls, value: Optional[List[int]]) -> Optional[List[int]]:
38        """Pre to convert indices to list if not None"""
39
40        if value is not None and not isinstance(value, list):
41            value = list(value)
42
43        return value

Convert indices to a list if they are not None.

@field_validator('inequality', mode='before')
@classmethod
def trim_whitespace(cls, value: str) -> str:
45    @field_validator("inequality", mode="before")
46    @classmethod
47    def trim_whitespace(cls, value: str) -> str:
48        """Trims whitespace from inequality signs"""
49
50        if value is not None:
51            value = value.strip()
52
53        return value

Trims whitespace from inequality signs

model_fields = {'label': FieldInfo(annotation=str, required=True), 'column_name': FieldInfo(annotation=Union[str, NoneType], required=False), 'column_value': FieldInfo(annotation=Union[str, float, int, Timestamp, NoneType], required=False), 'inequality': FieldInfo(annotation=Union[str, NoneType], required=False), 'start': FieldInfo(annotation=Union[int, NoneType], required=False), 'stop': FieldInfo(annotation=Union[int, NoneType], required=False), 'indices': FieldInfo(annotation=Union[List[int], NoneType], required=False)}
model_computed_fields = {}
Inherited Members
pydantic.main.BaseModel
BaseModel
model_extra
model_fields_set
model_construct
model_copy
model_dump
model_dump_json
model_json_schema
model_parametrized_name
model_post_init
model_rebuild
model_validate
model_validate_json
model_validate_strings
dict
json
parse_obj
parse_raw
parse_file
from_orm
construct
copy
schema
schema_json
validate
update_forward_refs
class DataSplitterBase:
 56class DataSplitterBase:
 57    def __init__(
 58        self,
 59        split: DataSplit,
 60        dependent_vars: List[Union[int, str]],
 61    ):
 62        self.split = split
 63        self.dependent_vars = dependent_vars
 64
 65    @property
 66    def column_name(self) -> str:
 67        if self.split.column_name is not None:
 68            return self.split.column_name
 69
 70        raise ValueError("Column name was not provided")
 71
 72    @property
 73    def column_value(self) -> Any:
 74        if self.split.column_value is not None:
 75            return self.split.column_value
 76
 77        raise ValueError("Column value was not provided")
 78
 79    @property
 80    def indices(self) -> List[int]:
 81        if self.split.indices is not None:
 82            return self.split.indices
 83        raise ValueError("List of indices was not provided")
 84
 85    @property
 86    def start(self) -> int:
 87        if self.split.start is not None:
 88            return self.split.start
 89        raise ValueError("Start index was not provided")
 90
 91    @property
 92    def stop(self) -> int:
 93        if self.split.stop is not None:
 94            return self.split.stop
 95        raise ValueError("Stop index was not provided")
 96
 97    def get_x_cols(self, columns: List[str], dependent_vars: List[Union[str, int]]) -> List[str]:
 98        for var in dependent_vars:
 99            if isinstance(var, str):
100                columns.remove(var)
101
102        return columns
103
104    def create_split(self, data: Any) -> Tuple[str, Data]:
105        raise NotImplementedError
106
107    @staticmethod
108    def validate(data_type: str, split: DataSplit) -> bool:
109        raise NotImplementedError
DataSplitterBase( split: DataSplit, dependent_vars: List[Union[int, str]])
57    def __init__(
58        self,
59        split: DataSplit,
60        dependent_vars: List[Union[int, str]],
61    ):
62        self.split = split
63        self.dependent_vars = dependent_vars
split
dependent_vars
column_name: str
65    @property
66    def column_name(self) -> str:
67        if self.split.column_name is not None:
68            return self.split.column_name
69
70        raise ValueError("Column name was not provided")
column_value: Any
72    @property
73    def column_value(self) -> Any:
74        if self.split.column_value is not None:
75            return self.split.column_value
76
77        raise ValueError("Column value was not provided")
indices: List[int]
79    @property
80    def indices(self) -> List[int]:
81        if self.split.indices is not None:
82            return self.split.indices
83        raise ValueError("List of indices was not provided")
start: int
85    @property
86    def start(self) -> int:
87        if self.split.start is not None:
88            return self.split.start
89        raise ValueError("Start index was not provided")
stop: int
91    @property
92    def stop(self) -> int:
93        if self.split.stop is not None:
94            return self.split.stop
95        raise ValueError("Stop index was not provided")
def get_x_cols( self, columns: List[str], dependent_vars: List[Union[int, str]]) -> List[str]:
 97    def get_x_cols(self, columns: List[str], dependent_vars: List[Union[str, int]]) -> List[str]:
 98        for var in dependent_vars:
 99            if isinstance(var, str):
100                columns.remove(var)
101
102        return columns
def create_split(self, data: Any) -> Tuple[str, Data]:
104    def create_split(self, data: Any) -> Tuple[str, Data]:
105        raise NotImplementedError
@staticmethod
def validate(data_type: str, split: DataSplit) -> bool:
107    @staticmethod
108    def validate(data_type: str, split: DataSplit) -> bool:
109        raise NotImplementedError
class PolarsColumnSplitter(DataSplitterBase):
112class PolarsColumnSplitter(DataSplitterBase):
113    """Column splitter for Polars dataframe"""
114
115    def create_split(self, data: pl.DataFrame) -> Tuple[str, Data]:
116        if self.split.inequality is None:
117            data = data.filter(pl.col(self.column_name) == self.column_value)
118
119        elif self.split.inequality == ">":
120            data = data.filter(pl.col(self.column_name) > self.column_value)
121
122        elif self.split.inequality == ">=":
123            data = data.filter(pl.col(self.column_name) >= self.column_value)
124
125        elif self.split.inequality == "<":
126            data = data.filter(pl.col(self.column_name) < self.column_value)
127
128        else:
129            data = data.filter(pl.col(self.column_name) <= self.column_value)
130
131        if bool(self.dependent_vars):
132            x_cols = self.get_x_cols(columns=data.columns, dependent_vars=self.dependent_vars)
133
134            return self.split.label, Data(
135                X=data.select(x_cols),
136                y=data.select(self.dependent_vars),
137            )
138
139        return self.split.label, Data(X=data)
140
141    @staticmethod
142    def validate(data_type: str, split: DataSplit) -> bool:
143        return data_type == AllowedDataType.POLARS and split.column_name is not None

Column splitter for Polars dataframe

def create_split( self, data: polars.dataframe.frame.DataFrame) -> Tuple[str, Data]:
115    def create_split(self, data: pl.DataFrame) -> Tuple[str, Data]:
116        if self.split.inequality is None:
117            data = data.filter(pl.col(self.column_name) == self.column_value)
118
119        elif self.split.inequality == ">":
120            data = data.filter(pl.col(self.column_name) > self.column_value)
121
122        elif self.split.inequality == ">=":
123            data = data.filter(pl.col(self.column_name) >= self.column_value)
124
125        elif self.split.inequality == "<":
126            data = data.filter(pl.col(self.column_name) < self.column_value)
127
128        else:
129            data = data.filter(pl.col(self.column_name) <= self.column_value)
130
131        if bool(self.dependent_vars):
132            x_cols = self.get_x_cols(columns=data.columns, dependent_vars=self.dependent_vars)
133
134            return self.split.label, Data(
135                X=data.select(x_cols),
136                y=data.select(self.dependent_vars),
137            )
138
139        return self.split.label, Data(X=data)
@staticmethod
def validate(data_type: str, split: DataSplit) -> bool:
141    @staticmethod
142    def validate(data_type: str, split: DataSplit) -> bool:
143        return data_type == AllowedDataType.POLARS and split.column_name is not None
class PolarsIndexSplitter(DataSplitterBase):
146class PolarsIndexSplitter(DataSplitterBase):
147    """Split Polars DataFrame by rows index"""
148
149    def create_split(self, data: pl.DataFrame) -> Tuple[str, Data]:
150        # slice
151        data = data[self.indices]
152
153        if bool(self.dependent_vars):
154            x_cols = self.get_x_cols(columns=data.columns, dependent_vars=self.dependent_vars)
155
156            return self.split.label, Data(
157                X=data.select(x_cols),
158                y=data.select(self.dependent_vars),
159            )
160
161        return self.split.label, Data(X=data)
162
163    @staticmethod
164    def validate(data_type: str, split: DataSplit) -> bool:
165        return data_type == AllowedDataType.POLARS and split.indices is not None

Split a Polars DataFrame by row indices.

def create_split( self, data: polars.dataframe.frame.DataFrame) -> Tuple[str, Data]:
149    def create_split(self, data: pl.DataFrame) -> Tuple[str, Data]:
150        # slice
151        data = data[self.indices]
152
153        if bool(self.dependent_vars):
154            x_cols = self.get_x_cols(columns=data.columns, dependent_vars=self.dependent_vars)
155
156            return self.split.label, Data(
157                X=data.select(x_cols),
158                y=data.select(self.dependent_vars),
159            )
160
161        return self.split.label, Data(X=data)
@staticmethod
def validate(data_type: str, split: DataSplit) -> bool:
163    @staticmethod
164    def validate(data_type: str, split: DataSplit) -> bool:
165        return data_type == AllowedDataType.POLARS and split.indices is not None
class PolarsRowsSplitter(DataSplitterBase):
168class PolarsRowsSplitter(DataSplitterBase):
169    """Split Polars DataFrame by rows slice"""
170
171    def create_split(self, data: pl.DataFrame) -> Tuple[str, Data]:
172        # slice
173        data = data[self.start : self.stop]
174
175        if bool(self.dependent_vars):
176            x_cols = self.get_x_cols(columns=data.columns, dependent_vars=self.dependent_vars)
177
178            return self.split.label, Data(
179                X=data.select(x_cols),
180                y=data.select(self.dependent_vars),
181            )
182
183        return self.split.label, Data(X=data)
184
185    @staticmethod
186    def validate(data_type: str, split: DataSplit) -> bool:
187        return data_type == AllowedDataType.POLARS and split.start is not None

Split a Polars DataFrame by a row slice.

def create_split( self, data: polars.dataframe.frame.DataFrame) -> Tuple[str, Data]:
171    def create_split(self, data: pl.DataFrame) -> Tuple[str, Data]:
172        # slice
173        data = data[self.start : self.stop]
174
175        if bool(self.dependent_vars):
176            x_cols = self.get_x_cols(columns=data.columns, dependent_vars=self.dependent_vars)
177
178            return self.split.label, Data(
179                X=data.select(x_cols),
180                y=data.select(self.dependent_vars),
181            )
182
183        return self.split.label, Data(X=data)
@staticmethod
def validate(data_type: str, split: DataSplit) -> bool:
185    @staticmethod
186    def validate(data_type: str, split: DataSplit) -> bool:
187        return data_type == AllowedDataType.POLARS and split.start is not None
class PandasIndexSplitter(DataSplitterBase):
190class PandasIndexSplitter(DataSplitterBase):
191    def create_split(self, data: pd.DataFrame) -> Tuple[str, Data]:
192        data = data.iloc[self.indices]
193
194        if bool(self.dependent_vars):
195            x = data[data.columns[~data.columns.isin(self.dependent_vars)]]
196            y = data[data.columns[data.columns.isin(self.dependent_vars)]]
197
198            return self.split.label, Data(X=x, y=y)
199
200        return self.split.label, Data(X=data)
201
202    @staticmethod
203    def validate(data_type: str, split: DataSplit) -> bool:
204        return data_type == AllowedDataType.PANDAS and split.indices is not None
def create_split( self, data: pandas.core.frame.DataFrame) -> Tuple[str, Data]:
191    def create_split(self, data: pd.DataFrame) -> Tuple[str, Data]:
192        data = data.iloc[self.indices]
193
194        if bool(self.dependent_vars):
195            x = data[data.columns[~data.columns.isin(self.dependent_vars)]]
196            y = data[data.columns[data.columns.isin(self.dependent_vars)]]
197
198            return self.split.label, Data(X=x, y=y)
199
200        return self.split.label, Data(X=data)
@staticmethod
def validate(data_type: str, split: DataSplit) -> bool:
202    @staticmethod
203    def validate(data_type: str, split: DataSplit) -> bool:
204        return data_type == AllowedDataType.PANDAS and split.indices is not None
class PandasRowSplitter(DataSplitterBase):
207class PandasRowSplitter(DataSplitterBase):
208    def create_split(self, data: pd.DataFrame) -> Tuple[str, Data]:
209        # slice
210        data = data[self.start : self.stop]
211
212        if bool(self.dependent_vars):
213            x = data[data.columns[~data.columns.isin(self.dependent_vars)]]
214            y = data[data.columns[data.columns.isin(self.dependent_vars)]]
215
216            return self.split.label, Data(X=x, y=y)
217
218        return self.split.label, Data(X=data)
219
220    @staticmethod
221    def validate(data_type: str, split: DataSplit) -> bool:
222        return data_type == AllowedDataType.PANDAS and split.start is not None
def create_split( self, data: pandas.core.frame.DataFrame) -> Tuple[str, Data]:
208    def create_split(self, data: pd.DataFrame) -> Tuple[str, Data]:
209        # slice
210        data = data[self.start : self.stop]
211
212        if bool(self.dependent_vars):
213            x = data[data.columns[~data.columns.isin(self.dependent_vars)]]
214            y = data[data.columns[data.columns.isin(self.dependent_vars)]]
215
216            return self.split.label, Data(X=x, y=y)
217
218        return self.split.label, Data(X=data)
@staticmethod
def validate(data_type: str, split: DataSplit) -> bool:
220    @staticmethod
221    def validate(data_type: str, split: DataSplit) -> bool:
222        return data_type == AllowedDataType.PANDAS and split.start is not None
class PandasColumnSplitter(DataSplitterBase):
225class PandasColumnSplitter(DataSplitterBase):
226    def create_split(self, data: pd.DataFrame) -> Tuple[str, Data]:
227        if self.split.inequality is None:
228            data = data[data[self.column_name] == self.column_value]
229
230        elif self.split.inequality == ">":
231            data = data[data[self.column_name] > self.column_value]
232
233        elif self.split.inequality == ">=":
234            data = data[data[self.column_name] >= self.column_value]
235
236        elif self.split.inequality == "<":
237            data = data[data[self.column_name] < self.column_value]
238
239        else:
240            data = data[data[self.column_name] <= self.column_value]
241
242        if bool(self.dependent_vars):
243            return self.split.label, Data(
244                X=data[data.columns[~data.columns.isin(self.dependent_vars)]],
245                y=data[data.columns[data.columns.isin(self.dependent_vars)]],
246            )
247
248        data_split = Data(X=data)
249        return self.split.label, data_split
250
251    @staticmethod
252    def validate(data_type: str, split: DataSplit) -> bool:
253        return data_type == AllowedDataType.PANDAS and split.column_name is not None
def create_split( self, data: pandas.core.frame.DataFrame) -> Tuple[str, Data]:
226    def create_split(self, data: pd.DataFrame) -> Tuple[str, Data]:
227        if self.split.inequality is None:
228            data = data[data[self.column_name] == self.column_value]
229
230        elif self.split.inequality == ">":
231            data = data[data[self.column_name] > self.column_value]
232
233        elif self.split.inequality == ">=":
234            data = data[data[self.column_name] >= self.column_value]
235
236        elif self.split.inequality == "<":
237            data = data[data[self.column_name] < self.column_value]
238
239        else:
240            data = data[data[self.column_name] <= self.column_value]
241
242        if bool(self.dependent_vars):
243            return self.split.label, Data(
244                X=data[data.columns[~data.columns.isin(self.dependent_vars)]],
245                y=data[data.columns[data.columns.isin(self.dependent_vars)]],
246            )
247
248        data_split = Data(X=data)
249        return self.split.label, data_split
@staticmethod
def validate(data_type: str, split: DataSplit) -> bool:
251    @staticmethod
252    def validate(data_type: str, split: DataSplit) -> bool:
253        return data_type == AllowedDataType.PANDAS and split.column_name is not None
class PyArrowIndexSplitter(DataSplitterBase):
256class PyArrowIndexSplitter(DataSplitterBase):
257    def create_split(self, data: pa.Table) -> Tuple[str, Data]:
258        return self.split.label, Data(X=data.take(self.indices))
259
260    @staticmethod
261    def validate(data_type: str, split: DataSplit) -> bool:
262        return data_type == AllowedDataType.PYARROW and split.indices is not None
def create_split(self, data: pyarrow.lib.Table) -> Tuple[str, Data]:
257    def create_split(self, data: pa.Table) -> Tuple[str, Data]:
258        return self.split.label, Data(X=data.take(self.indices))
@staticmethod
def validate(data_type: str, split: DataSplit) -> bool:
260    @staticmethod
261    def validate(data_type: str, split: DataSplit) -> bool:
262        return data_type == AllowedDataType.PYARROW and split.indices is not None
class NumpyIndexSplitter(DataSplitterBase):
265class NumpyIndexSplitter(DataSplitterBase):
266    def create_split(self, data: NDArray[Any]) -> Tuple[str, Data]:
267        return self.split.label, Data(X=data[self.indices])
268
269    @staticmethod
270    def validate(data_type: str, split: DataSplit) -> bool:
271        return data_type == AllowedDataType.NUMPY and split.indices is not None
def create_split( self, data: numpy.ndarray[typing.Any, numpy.dtype[typing.Any]]) -> Tuple[str, Data]:
266    def create_split(self, data: NDArray[Any]) -> Tuple[str, Data]:
267        return self.split.label, Data(X=data[self.indices])
@staticmethod
def validate(data_type: str, split: DataSplit) -> bool:
269    @staticmethod
270    def validate(data_type: str, split: DataSplit) -> bool:
271        return data_type == AllowedDataType.NUMPY and split.indices is not None
class NumpyRowSplitter(DataSplitterBase):
274class NumpyRowSplitter(DataSplitterBase):
275    def create_split(self, data: NDArray[Any]) -> Tuple[str, Data]:
276        data_split = data[self.start : self.stop]
277        return self.split.label, Data(X=data_split)
278
279    @staticmethod
280    def validate(data_type: str, split: DataSplit) -> bool:
281        return data_type == AllowedDataType.NUMPY and split.start is not None
def create_split( self, data: numpy.ndarray[typing.Any, numpy.dtype[typing.Any]]) -> Tuple[str, Data]:
275    def create_split(self, data: NDArray[Any]) -> Tuple[str, Data]:
276        data_split = data[self.start : self.stop]
277        return self.split.label, Data(X=data_split)
@staticmethod
def validate(data_type: str, split: DataSplit) -> bool:
279    @staticmethod
280    def validate(data_type: str, split: DataSplit) -> bool:
281        return data_type == AllowedDataType.NUMPY and split.start is not None
class DataSplitter:
284class DataSplitter:
285    @staticmethod
286    def split(
287        split: DataSplit,
288        data: Union[pd.DataFrame, NDArray[Any], pl.DataFrame],
289        data_type: str,
290        dependent_vars: List[Union[int, str]],
291    ) -> Tuple[str, Data]:
292        data_splitter = next(
293            (
294                data_splitter
295                for data_splitter in DataSplitterBase.__subclasses__()
296                if data_splitter.validate(
297                    data_type=data_type,
298                    split=split,
299                )
300            ),
301            None,
302        )
303
304        if data_splitter is not None:
305            return data_splitter(
306                split=split,
307                dependent_vars=dependent_vars,
308            ).create_split(data=data)
309
310        raise ValueError("Failed to find data supporter that supports provided logic")
@staticmethod
def split( split: DataSplit, data: Union[pandas.core.frame.DataFrame, numpy.ndarray[Any, numpy.dtype[Any]], polars.dataframe.frame.DataFrame], data_type: str, dependent_vars: List[Union[int, str]]) -> Tuple[str, Data]:
285    @staticmethod
286    def split(
287        split: DataSplit,
288        data: Union[pd.DataFrame, NDArray[Any], pl.DataFrame],
289        data_type: str,
290        dependent_vars: List[Union[int, str]],
291    ) -> Tuple[str, Data]:
292        data_splitter = next(
293            (
294                data_splitter
295                for data_splitter in DataSplitterBase.__subclasses__()
296                if data_splitter.validate(
297                    data_type=data_type,
298                    split=split,
299                )
300            ),
301            None,
302        )
303
304        if data_splitter is not None:
305            return data_splitter(
306                split=split,
307                dependent_vars=dependent_vars,
308            ).create_split(data=data)
309
310        raise ValueError("Failed to find data supporter that supports provided logic")