src package
Subpackages
- src.rumboost package
- Submodules
- src.rumboost.datasets module
load_preprocess_Airplane(), load_preprocess_LPMC(), load_preprocess_MTMC(), load_preprocess_MTMC_all(), load_preprocess_Netherlands(), load_preprocess_Optima(), load_preprocess_Parking(), load_preprocess_SwissMetro(), load_preprocess_Telephone(), load_preprocess_Vaccines(), prepare_dataset(), stratified_group_k_fold()
- src.rumboost.metrics module
- src.rumboost.nested_cross_nested module
- src.rumboost.ordinal module
- src.rumboost.post_process module
- src.rumboost.rumboost module
CVRUMBoost, RUMBoost, RUMBoost.boosters, RUMBoost.valid_sets, RUMBoost.f_obj(), RUMBoost.f_obj_binary(), RUMBoost.f_obj_coral(), RUMBoost.f_obj_cross_nested(), RUMBoost.f_obj_full_hessian(), RUMBoost.f_obj_mse(), RUMBoost.f_obj_nest(), RUMBoost.f_obj_proportional_odds(), RUMBoost.model_from_string(), RUMBoost.model_to_string(), RUMBoost.multiply_grad_hess_by_data(), RUMBoost.predict(), RUMBoost.save_model(), rum_cv(), rum_train()
- src.rumboost.torch_functions module
- src.rumboost.utility_plotting module
- src.rumboost.utility_smoothing module
- src.rumboost.utils module
- Module contents
RUMBoost, RUMBoost.boosters, RUMBoost.valid_sets, RUMBoost.f_obj(), RUMBoost.f_obj_binary(), RUMBoost.f_obj_coral(), RUMBoost.f_obj_cross_nested(), RUMBoost.f_obj_full_hessian(), RUMBoost.f_obj_mse(), RUMBoost.f_obj_nest(), RUMBoost.f_obj_proportional_odds(), RUMBoost.model_from_string(), RUMBoost.model_to_string(), RUMBoost.multiply_grad_hess_by_data(), RUMBoost.predict(), RUMBoost.save_model(), rum_train()
- src.statistical_models package
- src.tastenet package
Submodules
src.constants module
src.helper module
src.hyperparameter_search module
- src.hyperparameter_search.objective(trial, model, func_int, func_params, dataset)[source]
Optuna objective function for the hyperparameter search.
- Parameters:
trial (optuna.Trial) – The current trial object.
model (str) – The model to train.
func_int (bool) – Whether to use functional intercept.
func_params (bool) – Whether to use functional parameters.
dataset (str) – The dataset to use.
src.main module
src.models_wrapper module
- class src.models_wrapper.DNN(**kwargs)[source]
Bases: object
Wrapper class for DNN model. Only implemented for classification and regression.
- build_dataloader(X_train: DataFrame, y_train: Series, X_valid: DataFrame = None, y_valid: Series = None)[source]
Builds and stores the data loaders for the training set and, if provided, the validation set.
- Parameters:
X_train (pd.DataFrame) – Training features.
y_train (pd.Series) – Training target variable.
X_valid (pd.DataFrame, optional) – Validation features. The default is None.
y_valid (pd.Series, optional) – Validation target variable. The default is None.
- fit(**kwargs) tuple[float, float] | tuple[float, None][source]
Fits the model to the training data.
- load_model(path: str)[source]
Loads the model from the specified path.
- Parameters:
path (str) – Path to load the model from.
- predict(X_test: DataFrame, utilities: bool = False, **kwargs) array[source]
Predicts the target variable for the test set.
- Parameters:
X_test (pd.DataFrame) – Test set.
utilities (bool) – Whether to predict utilities or not.
kwargs (dict) – Additional arguments.
- Returns:
preds (np.array) – Predicted target variable, as probabilities.
binary_preds (np.array) – The binary probabilities of the target being bigger than each level.
label_pred (np.array) – Predicted target variable, as labels.
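For ordinal targets, the three outputs of predict are linked. The sketch below derives binary_preds and label_pred from the class probabilities, assuming binary_preds[:, k] stores P(y > k) and that labels are the most probable level; both conventions, and the helper name ordinal_outputs_from_probs, are assumptions for illustration rather than the wrapper's actual code.

```python
import numpy as np


def ordinal_outputs_from_probs(preds: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
    """Hypothetical helper: derive (binary_preds, label_pred) from class probabilities.

    Assumes preds has shape (n_obs, n_levels) with rows summing to one,
    and that binary_preds[:, k] = P(y > k) for k = 0 .. n_levels - 2.
    """
    # P(y <= k) is the running sum of the class probabilities
    cum = np.cumsum(preds, axis=1)
    # P(y > k) = 1 - P(y <= k); the last level is dropped because P(y > K-1) = 0
    binary_preds = 1.0 - cum[:, :-1]
    # Assumed labelling rule: the most probable level for each observation
    label_pred = np.argmax(preds, axis=1)
    return binary_preds, label_pred


preds = np.array([[0.2, 0.5, 0.3],
                  [0.6, 0.3, 0.1]])
binary_preds, label_pred = ordinal_outputs_from_probs(preds)
```

Other labelling rules (for example counting how many P(y > k) exceed 0.5) are equally plausible for ordinal models; argmax is used here only because it is the simplest.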
- class src.models_wrapper.GBDT(**kwargs)[source]
Bases: object
Wrapper class for GBDT model. Only implemented for classification and regression.
- build_dataloader(X_train: DataFrame, y_train: Series, X_valid: DataFrame = None, y_valid: Series = None)[source]
Builds and stores the LightGBM dataset. There is no specific dataloader for GBDT.
- Parameters:
X_train (pd.DataFrame) – Training features.
y_train (pd.Series) – Training target variable.
X_valid (pd.DataFrame, optional) – Validation features. The default is None.
y_valid (pd.Series, optional) – Validation target variable. The default is None.
- fit(**kwargs) tuple[float, float] | tuple[float, None][source]
Fits the model to the training data.
- load_model(path: str)[source]
Loads the model from the specified path.
- Parameters:
path (str) – Path to load the model from.
- predict(X_test: DataFrame, utilities: bool = False, **kwargs) array[source]
Predicts the target variable for the test set.
- Parameters:
X_test (pd.DataFrame) – Test set.
utilities (bool) – Whether to predict utilities or probabilities. If True, returns raw utility values.
kwargs (dict) – Additional arguments.
- Returns:
preds (np.array) – Predicted target variable, as probabilities.
binary_preds (np.array) – The binary probabilities of the target being bigger than each level.
label_pred (np.array) – Predicted target variable, as labels.
- class src.models_wrapper.MixedEffect(**kwargs)[source]
Bases: object
Wrapper class for Mixed Effect model.
- build_dataloader(X_train: DataFrame, y_train: Series, X_val: DataFrame = None, y_val: Series = None) None[source]
Builds and stores the Biogeme database.
- Parameters:
X_train (pd.DataFrame) – Training features.
y_train (pd.Series) – Training target variable.
X_val (pd.DataFrame, optional) – Validation features. The default is None.
y_val (pd.Series, optional) – Validation target variable. The default is None.
- get_individual_parameters(on_train_set: bool) array[source]
Returns the individual-specific parameters.
- Parameters:
on_train_set (bool) – Whether to get the parameters for the training set or test set.
- Returns:
individual_params – DataFrame containing individual-specific parameters.
- Return type:
pd.DataFrame
- load_model(path: str, **kwargs) None[source]
Loads the model from the specified path.
- Parameters:
path (str) – Path to load the model from.
kwargs (dict) – Additional arguments, including:
- alt_spec_features (dict[int, list[str]]) – Dictionary mapping alternative IDs to lists of variable names.
- socio_demo_chars (list[str]) – List of socio-demographic characteristic names.
- num_classes (int) – Number of classes.
- predict(X_test: DataFrame, utilities: bool = False, **kwargs) array[source]
Predicts the target variable for the test set.
- Parameters:
X_test (pd.DataFrame) – Test set.
utilities (bool) – Whether to predict utilities or not.
kwargs (dict) – Additional arguments.
- Returns:
cel (np.array) – Cross-entropy loss on the test set.
binary_preds (np.array) – The binary probabilities of the target being bigger than each level.
label_pred (np.array) – Predicted target variable, as labels.
- class src.models_wrapper.RUMBoost(**kwargs)[source]
Bases: object
Wrapper class for RUMBoost model.
- build_dataloader(X_train: DataFrame, y_train: Series, X_valid: DataFrame = None, y_valid: Series = None)[source]
Builds and stores the LightGBM dataset. There is no specific dataloader for RUMBoost.
- Parameters:
X_train (pd.DataFrame) – Training features.
y_train (pd.Series) – Training target variable.
X_valid (pd.DataFrame, optional) – Validation features. The default is None.
y_valid (pd.Series, optional) – Validation target variable. The default is None.
- fit(**kwargs) tuple[float, float] | tuple[float, None][source]
Fits the model to the training data.
- load_model(path: str)[source]
Loads the model from the specified path.
- Parameters:
path (str) – Path to load the model from.
- predict(X_test: DataFrame, utilities: bool = False, **kwargs) array[source]
Predicts the target variable for the test set.
- Parameters:
X_test (pd.DataFrame) – Test set.
utilities (bool) – Whether to predict utilities or probabilities. If True, returns raw utility values.
kwargs (dict) – Additional arguments.
- Returns:
preds (np.array) – Predicted target variable, as probabilities.
binary_preds (np.array) – The binary probabilities of the target being bigger than each level.
label_pred (np.array) – Predicted target variable, as labels.
- class src.models_wrapper.TasteNet(**kwargs)[source]
Bases: object
Wrapper class for TasteNet model.
- build_dataloader(X_train: DataFrame, y_train: Series, X_valid: DataFrame = None, y_valid: Series = None)[source]
Builds and stores the data loaders for the training set and, if provided, the validation set.
- Parameters:
X_train (pd.DataFrame) – Training features.
y_train (pd.Series) – Training target variable.
X_valid (pd.DataFrame, optional) – Validation features. The default is None.
y_valid (pd.Series, optional) – Validation target variable. The default is None.
- fit(**kwargs) tuple[float, float] | tuple[float, None][source]
Fits the model to the training data.
- load_model(path: str)[source]
Loads the model from the specified path.
- Parameters:
path (str) – Path to load the model from.
- predict(X_test: DataFrame, utilities: bool = False, **kwargs) array[source]
Predicts the target variable for the test set.
- Parameters:
X_test (pd.DataFrame) – Test set.
utilities (bool) – Whether to predict utilities or not.
kwargs (dict) – Additional arguments.
- Returns:
preds (np.array) – Predicted target variable, as probabilities.
binary_preds (np.array) – The binary probabilities of the target being bigger than each level.
label_pred (np.array) – Predicted target variable, as labels.
src.parser module
src.plotter module
- src.plotter.plot_alt_spec_features(alt_spec_features: ~typing.List = {'LPMC': {0: ['dur_walking', 'distance', 'day_of_week', 'start_time_linear'], 1: ['dur_cycling', 'distance', 'day_of_week', 'start_time_linear'], 2: ['dur_pt_access', 'dur_pt_rail', 'dur_pt_bus', 'dur_pt_int_waiting', 'dur_pt_int_walking', 'pt_n_interchanges', 'cost_transit', 'distance', 'day_of_week', 'start_time_linear'], 3: ['dur_driving', 'cost_driving_fuel', 'congestion_charge', 'distance', 'driving_traffic_percent', 'day_of_week', 'start_time_linear']}, 'SwissMetro': {0: ['TRAIN_TT', 'TRAIN_HE', 'TRAIN_CO'], 1: ['SM_TT', 'SM_HE', 'SM_CO', 'SM_SEATS'], 2: ['CAR_TT', 'CAR_CO']}, 'easySHARE': {0: ['chronic_mod', 'nb_doctor_visits', 'maxgrip', 'daily_activities_index', 'instrumental_activities_index', 'mobilityind', 'lgmuscle', 'grossmotor', 'finemotor', 'recall_1', 'recall_2', 'bmi', 'sphus_fair', 'sphus_good', 'sphus_poor', 'sphus_very_good', 'hospitalised_last_year_yes', 'nursing_home_last_year_yes_permanently', 'nursing_home_last_year_yes_temporarily']}}, all_models: ~typing.Dict = {'RUMBoost': <class 'models_wrapper.RUMBoost'>, 'TasteNet': <class 'models_wrapper.TasteNet'>}, path_to_data: str = {'LPMC': '../data/LPMC/', 'SwissMetro': '../data/SwissMetro/train.pkl', 'easySHARE': '../data/easySHARE/easySHARE_preprocessed.csv'}, save_fig: bool = False, dataset: str = 'SwissMetro')[source]
Plot the alternative-specific features for the models, if trained without functional parameters.
- Parameters:
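alt_spec_features (List) – Alternative-specific features, given as a dict keyed by dataset name, where each entry maps utility indices to lists of feature names (see the default). They must be in the same order as for the training.
all_models (Dict) – Dictionary mapping model names to their wrapper classes.
path_to_data (str) – Path to the data, keyed by dataset name.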
save_fig (bool) – Whether to save the figure or not.
dataset (str) – Dataset to use. Default is “SwissMetro”.
- src.plotter.plot_ind_spec_constant(alt_spec_features: ~typing.List = {'LPMC': {0: ['dur_walking', 'distance', 'day_of_week', 'start_time_linear'], 1: ['dur_cycling', 'distance', 'day_of_week', 'start_time_linear'], 2: ['dur_pt_access', 'dur_pt_rail', 'dur_pt_bus', 'dur_pt_int_waiting', 'dur_pt_int_walking', 'pt_n_interchanges', 'cost_transit', 'distance', 'day_of_week', 'start_time_linear'], 3: ['dur_driving', 'cost_driving_fuel', 'congestion_charge', 'distance', 'driving_traffic_percent', 'day_of_week', 'start_time_linear']}, 'SwissMetro': {0: ['TRAIN_TT', 'TRAIN_HE', 'TRAIN_CO'], 1: ['SM_TT', 'SM_HE', 'SM_CO', 'SM_SEATS'], 2: ['CAR_TT', 'CAR_CO']}, 'easySHARE': {0: ['chronic_mod', 'nb_doctor_visits', 'maxgrip', 'daily_activities_index', 'instrumental_activities_index', 'mobilityind', 'lgmuscle', 'grossmotor', 'finemotor', 'recall_1', 'recall_2', 'bmi', 'sphus_fair', 'sphus_good', 'sphus_poor', 'sphus_very_good', 'hospitalised_last_year_yes', 'nursing_home_last_year_yes_permanently', 'nursing_home_last_year_yes_temporarily']}}, all_models: ~typing.Dict = {'RUMBoost': <class 'models_wrapper.RUMBoost'>, 'TasteNet': <class 'models_wrapper.TasteNet'>}, path_to_data: str = {'LPMC': '../data/LPMC/', 'SwissMetro': '../data/SwissMetro/train.pkl', 'easySHARE': '../data/easySHARE/easySHARE_preprocessed.csv'}, path_to_data_train: str = {'LPMC': '../data/LPMC/LPMC_train.csv', 'SwissMetro': '../data/SwissMetro/train.pkl', 'easySHARE': '../data/easySHARE/easySHARE_preprocessed_train.csv'}, save_fig: bool = False, feature_to_highlight: str = None, functional_params: bool = True, functional_intercept: bool = True, dataset: str = 'SwissMetro')[source]
Plot the individual-specific constant for the models. The model needs to be trained with functional parameters or functional intercept.
- Parameters:
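alt_spec_features (List) – Alternative-specific features, given as a dict keyed by dataset name, where each entry maps utility indices to lists of feature names (see the default). They must be in the same order as for the training.
all_models (Dict) – Dictionary mapping model names to their wrapper classes.
path_to_data (str) – Path to the data, keyed by dataset name.
path_to_data_train (str) – Path to the training data, keyed by dataset name.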
save_fig (bool) – Whether to save the figure or not.
feature_to_highlight (str) – Feature to highlight in the plot. If None, no feature is highlighted.
functional_params (bool) – If the model is trained with functional parameters.
functional_intercept (bool) – If the model is trained with functional intercept.
dataset (str) – Dataset to use. Default is “SwissMetro”.
src.run_models module
src.split_save_dataset module
src.synthetic_experiment module
- src.synthetic_experiment.add_simulated_choices(data: DataFrame, with_noise: bool = False) DataFrame[source]
Add simulated choices to the data based on the utility function.
- Parameters:
data (pd.DataFrame) – Data used for the synthetic experiment
with_noise (bool) – Whether to add noise to the utility values
- Returns:
data – Data with the simulated choices added.
- Return type:
pd.DataFrame
- src.synthetic_experiment.compute_prob(V: ndarray) ndarray[source]
Compute the probabilities for each alternative using the softmax function.
- Parameters:
V (np.ndarray) – The utility values for each alternative.
- Returns:
probs – The probabilities for each alternative.
- Return type:
np.ndarray
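A softmax over alternatives can be written in a numerically stable way by subtracting the row-wise maximum before exponentiating. This is a minimal sketch, assuming V has observations on rows and alternatives on columns; compute_prob_sketch is a hypothetical name, not the module's implementation.

```python
import numpy as np


def compute_prob_sketch(V: np.ndarray) -> np.ndarray:
    """Softmax over alternatives; V is assumed to have shape (n_obs, n_alternatives)."""
    # Subtracting the row-wise max leaves the probabilities unchanged
    # but prevents overflow in np.exp for large utilities.
    V_shifted = V - V.max(axis=1, keepdims=True)
    exp_V = np.exp(V_shifted)
    return exp_V / exp_V.sum(axis=1, keepdims=True)


V = np.array([[1.0, 2.0, 3.0],
              [1000.0, 1000.0, 1000.0]])
probs = compute_prob_sketch(V)
```

Without the shift, the second row would overflow to inf and produce NaN probabilities.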
- src.synthetic_experiment.create_dataset() DataFrame[source]
Create a pandas DataFrame from the synthetic data array.
- Returns:
The created DataFrame.
- Return type:
pd.DataFrame
- src.synthetic_experiment.create_functional_effects(x: ndarray, n_utility: int, n_socio_dem: int) ndarray[source]
Create functional effects for a given number of utilities and features per utility. The functional effects are bounded by [0,1] and use all socio-demographic characteristics. This function assumes that the socio-demographic characteristics are in the first columns of the input array.
- Parameters:
x (np.ndarray) – The input array containing the features.
n_utility (int) – The number of utility functions to create.
n_socio_dem (int) – The number of socio-demographic features.
- Returns:
The created functional effects.
- Return type:
np.ndarray
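The two constraints stated above (effects bounded in [0, 1], built from all socio-demographic columns at the front of x) can be met in many ways. The sketch below uses a logistic sigmoid of a random linear combination; this functional form, the seed parameter and the name create_functional_effects_sketch are all assumptions for illustration, not the effects used in the experiment.

```python
import numpy as np


def create_functional_effects_sketch(x, n_utility, n_socio_dem, seed=0):
    """Hypothetical sketch: one [0, 1]-bounded effect per utility function.

    Only the first n_socio_dem columns of x are used, as the docstring requires.
    The sigmoid-of-linear-combination form is an illustrative assumption.
    """
    rng = np.random.default_rng(seed)
    socio = x[:, :n_socio_dem]
    # One random weight vector per utility: shape (n_socio_dem, n_utility)
    weights = rng.normal(size=(n_socio_dem, n_utility))
    # The logistic sigmoid maps any real value into [0, 1]
    return 1.0 / (1.0 + np.exp(-(socio @ weights)))


x = np.random.default_rng(1).normal(size=(100, 6))
effects = create_functional_effects_sketch(x, n_utility=3, n_socio_dem=4)
```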
- src.synthetic_experiment.gather_functional_intercepts(data: DataFrame, model: RUMBoost | TasteNet | MixedEffect, socio_demo_characts: list[str] = ['f0', 'f1', 'f2', 'f3'], n_classes: int = 4, alt_normalised: int = 0, on_train_set: bool = True) ndarray[source]
Gather the learnt functional intercepts for the given model.
- Parameters:
data (pd.DataFrame) – Data used for the synthetic experiment
model (RUMBoost, TasteNet or MixedEffect) – The model used for the synthetic experiment.
socio_demo_characts (list[str], optional (default: ['f0', 'f1', 'f2', 'f3'])) – The socio-demographic characteristics used for the functional intercepts.
n_classes (int, optional (default: 4)) – The number of alternatives (classes) in the model.
alt_normalised (int, optional (default: 0)) – The alternative index for the normalised functional intercepts.
on_train_set (bool, optional (default: True)) – Whether to compute the functional intercepts on the training set or not. Only used for MixedEffect model.
- Returns:
functional_intercepts – The functional intercepts for the given features.
- Return type:
np.ndarray
- src.synthetic_experiment.generate_labels(probs: ndarray) ndarray[source]
Generate labels based on the probabilities.
- Parameters:
probs (np.ndarray) – The probabilities for each alternative and each observation.
- Returns:
labels – The generated labels for each alternative and each observation.
- Return type:
np.ndarray
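Drawing one label per observation from its choice probabilities can be done by inverse-CDF sampling: draw a uniform number and pick the first alternative whose cumulative probability exceeds it. A minimal sketch, assuming probs has shape (n_obs, n_alternatives); the function name and its seed parameter are hypothetical.

```python
import numpy as np


def generate_labels_sketch(probs: np.ndarray, seed=None) -> np.ndarray:
    """Draw one alternative index per row of probs (shape (n_obs, n_alternatives))."""
    rng = np.random.default_rng(seed)
    cum = np.cumsum(probs, axis=1)
    # One uniform draw per observation, as a column for broadcasting
    u = rng.random((probs.shape[0], 1))
    # Counting the cumulative probabilities below u gives the index of the
    # first alternative whose cumulative probability reaches u.
    labels = (u > cum).sum(axis=1)
    # Guard against round-off making cum[:, -1] slightly smaller than u
    return np.minimum(labels, probs.shape[1] - 1)


probs = np.array([[1.0, 0.0, 0.0],
                  [0.0, 0.0, 1.0],
                  [0.2, 0.5, 0.3]])
labels = generate_labels_sketch(probs, seed=0)
```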
- src.synthetic_experiment.generate_noise(mean: float, sd: float, n: tuple[int, ...]) ndarray[source]
Generate noise from a Gumbel distribution.
- Parameters:
mean (float) – The mean of the Gumbel distribution.
sd (float) – The standard deviation of the Gumbel distribution.
n (tuple) – The shape of the noise to generate.
- Returns:
noise – The generated noise.
- Return type:
np.ndarray
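A Gumbel distribution with location μ and scale β has mean μ + βγ (γ being the Euler–Mascheroni constant) and standard deviation βπ/√6, so a target mean and standard deviation can be converted into NumPy's loc/scale arguments. This is a sketch of that conversion under the assumption that the function works this way; generate_noise_sketch and its seed parameter are hypothetical.

```python
import numpy as np


def generate_noise_sketch(mean: float, sd: float, n: tuple, seed=None) -> np.ndarray:
    """Gumbel noise with the requested mean and standard deviation."""
    rng = np.random.default_rng(seed)
    # sd = scale * pi / sqrt(6)  =>  scale = sd * sqrt(6) / pi
    scale = sd * np.sqrt(6) / np.pi
    # mean = loc + scale * euler_gamma  =>  loc = mean - scale * euler_gamma
    loc = mean - scale * np.euler_gamma
    return rng.gumbel(loc=loc, scale=scale, size=n)


noise = generate_noise_sketch(mean=0.0, sd=1.0, n=(200_000, 3), seed=42)
```

Centring the noise at zero is convenient in simulations because only utility differences matter for choices.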
- src.synthetic_experiment.generate_x(n: int, k: int, n_socio_dem: int = 0, panel_factor: int = 1) ndarray[source]
Generate synthetic data.
- Parameters:
n (int) – The total number of samples.
k (int) – The total number of features.
n_socio_dem (int) – The number of socio-demographic features.
panel_factor (int) – The panel factor, i.e. the number of repeated trips per observation.
- Returns:
The generated synthetic data.
- Return type:
np.ndarray
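One way to honour panel_factor is to draw the socio-demographic columns once per individual and repeat them for each of that individual's trips, while trip-level features vary on every row. The sketch below assumes uniform [0, 1) features, places the socio-demographic columns first (as create_functional_effects expects), and requires n to be a multiple of panel_factor; all of this, and the name generate_x_sketch, is illustrative.

```python
import numpy as np


def generate_x_sketch(n, k, n_socio_dem=0, panel_factor=1, seed=None):
    """Hypothetical sketch: n rows, k features, socio-demographics fixed per individual."""
    if n % panel_factor != 0:
        raise ValueError("n must be a multiple of panel_factor")
    rng = np.random.default_rng(seed)
    n_individuals = n // panel_factor
    # Drawn once per individual, then repeated panel_factor times so that
    # all trips of the same individual share these values.
    socio = np.repeat(rng.random((n_individuals, n_socio_dem)), panel_factor, axis=0)
    # Trip-level features differ on every row
    trip = rng.random((n, k - n_socio_dem))
    # Socio-demographic columns come first
    return np.hstack([socio, trip])


x = generate_x_sketch(12, 5, n_socio_dem=2, panel_factor=3, seed=0)
```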
- src.synthetic_experiment.hyperparameter_search(model: str = 'RUMBoost') None[source]
Perform hyperparameter search for the models. This function is not implemented yet.
- Parameters:
model (str) – The model to train. Can be “RUMBoost” or “TasteNet”.
- src.synthetic_experiment.l1_distance(true_fct_intercept: ndarray, learnt_fct_intercept: ndarray) float[source]
Compute the L1 distance between the true and learnt functional intercepts.
- Parameters:
true_fct_intercept (np.ndarray) – The true functional intercepts.
learnt_fct_intercept (np.ndarray) – The learnt functional intercepts.
- Returns:
l1_distance – The L1 distance between the true and learnt functional intercepts.
- Return type:
float
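The L1 distance is the sum of absolute element-wise differences. A minimal sketch, assuming both arrays have the same shape and that a plain sum (rather than a mean) is wanted; l1_distance_sketch is a hypothetical name.

```python
import numpy as np


def l1_distance_sketch(true_fct_intercept, learnt_fct_intercept) -> float:
    # Sum of absolute element-wise differences
    diff = np.asarray(true_fct_intercept) - np.asarray(learnt_fct_intercept)
    return float(np.abs(diff).sum())


true_fi = np.array([[0.0, 1.0], [2.0, 3.0]])
learnt_fi = np.array([[0.0, 1.5], [1.0, 3.0]])
dist = l1_distance_sketch(true_fi, learnt_fi)
```

Dividing by the number of elements instead would make values comparable across datasets of different sizes.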
- src.synthetic_experiment.plot_alt_spec_features(alt_spec_features_list: list[str], save_fig: bool = True) None[source]
Plot the alternative-specific features for the models, if trained without functional parameters.
- Parameters:
alt_spec_features_list (list[str]) – List of alternative-specific features. They must be in the same order as for the training.
save_fig (bool) – Whether to save the figure or not.
- src.synthetic_experiment.plot_ind_spec_constant(save_fig: bool = True, functional_params: bool = False, functional_intercept: bool = True) None[source]
Plot the individual-specific constant for the models. The model needs to be trained with functional parameters or functional intercept.
- Parameters:
save_fig (bool) – Whether to save the figure or not.
functional_params (bool) – If the model is trained with functional parameters.
functional_intercept (bool) – If the model is trained with functional intercept.
- src.synthetic_experiment.run_experiment(args: Namespace) None[source]
Run the synthetic experiment with the given arguments.
- Parameters:
args (argparse.Namespace) – The command line arguments parsed by the parser.
- src.synthetic_experiment.utility_function(data: DataFrame, with_noise: bool = False) ndarray[source]
Create the utility function for the synthetic dataset.
- Parameters:
data (pd.DataFrame) – Data used for the synthetic experiment
with_noise (bool) – Whether to add noise to the utility values
- Returns:
V – The utility values for each alternative.
- Return type:
np.ndarray
src.train module
src.utils module
- src.utils.add_hyperparameters(rum_struct: List[Dict[str, Any]], hyperparameters: Dict[str, Any]) Dict[str, Any][source]
Add hyperparameters to a specific dict of rum structure.
- Parameters:
rum_struct (List[Dict[str, Any]]) – The rum structure to be modified.
hyperparameters (Dict[str, Any]) – The hyperparameters to be added to the rum structure.
- Returns:
rum_structure – The modified rum structure with the hyperparameters added.
- Return type:
List[Dict[str, Any]]
- src.utils.build_lgb_dataset(X: DataFrame, y: Series) Dataset[source]
Build the LightGBM dataset from the dataframe.
- Parameters:
X (pd.DataFrame) – The dataframe to be used.
y (pd.Series) – The target variable.
- Returns:
lgb_dataset – The LightGBM dataset.
- Return type:
lightgbm.Dataset
- src.utils.compute_metrics(preds: ndarray, binary_preds: ndarray, labels: ndarray, y_test: Series) tuple[float, float, float][source]
Compute the metrics for the model.
- Parameters:
preds (np.ndarray) – The predicted probabilities of the model.
binary_preds (np.ndarray) – The binary probabilities of the target being bigger than each level.
labels (np.ndarray) – The predicted labels of the model.
y_test (pd.Series) – The true target values of the test set.
- Returns:
mae_test (float) – The mean absolute error of the model.
loss_test (float) – The loss of the model.
emae_test (float) – The expected mean absolute error of the model.
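The three metrics can be sketched as follows, assuming the levels are encoded as integers 0 .. K-1, the loss is the cross-entropy of the observed level, and the expected MAE weights each possible distance |k − y| by the predicted probability of level k. These definitions are assumptions; the sketch also ignores binary_preds, and compute_metrics_sketch is a hypothetical name.

```python
import numpy as np


def compute_metrics_sketch(preds, labels, y_test):
    """Return (mae, loss, emae); levels are assumed to be integers 0 .. K-1."""
    y = np.asarray(y_test, dtype=int)
    n_obs, n_levels = preds.shape
    # MAE of the hard label predictions
    mae = np.abs(np.asarray(labels) - y).mean()
    # Cross-entropy: negative log-probability of the observed level
    loss = -np.log(preds[np.arange(n_obs), y]).mean()
    # Expected MAE: distance |k - y| weighted by P(level k)
    dist = np.abs(np.arange(n_levels)[None, :] - y[:, None])
    emae = (preds * dist).sum(axis=1).mean()
    return float(mae), float(loss), float(emae)


preds = np.array([[0.7, 0.2, 0.1],
                  [0.1, 0.3, 0.6]])
mae, loss, emae = compute_metrics_sketch(preds, labels=np.array([0, 2]), y_test=[0, 1])
```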
- src.utils.cross_entropy(y_true: ndarray, y_pred: ndarray) float[source]
Compute the cross entropy loss.
- Parameters:
y_true (np.ndarray) – The true labels.
y_pred (np.ndarray) – The predicted probabilities.
- Returns:
loss – The cross entropy loss.
- Return type:
float
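A minimal cross-entropy sketch, assuming y_true holds integer class indices and y_pred has shape (n_obs, n_classes) with probabilities; the clipping constant eps is an illustrative choice that keeps the loss finite when a model assigns zero probability to the observed class. cross_entropy_sketch is a hypothetical name.

```python
import numpy as np


def cross_entropy_sketch(y_true, y_pred, eps=1e-15):
    y_true = np.asarray(y_true, dtype=int)
    # Probability the model assigned to each observed class
    p = y_pred[np.arange(len(y_true)), y_true]
    # Clip to avoid log(0)
    p = np.clip(p, eps, 1.0)
    return float(-np.mean(np.log(p)))


loss = cross_entropy_sketch(np.array([0, 0]), np.array([[0.5, 0.5], [0.9, 0.1]]))
```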
- src.utils.generate_general_params(num_classes: int, **kwargs) Dict[str, Any][source]
Generate the general parameters for the rumboost model.
- Parameters:
num_classes (int) – The number of classes in the dataset.
kwargs (Dict[str, Any]) – Additional parameters used to update the general parameters. They must be parameters accepted by rumboost; see the rumboost documentation for details.
- Returns:
general_params – The general parameters for the rumboost model.
- Return type:
Dict[str, Any]
- src.utils.generate_ordinal_spec(model_type: str | None = 'proportional_odds', optim_interval: int | None = 20) Dict[str, Any][source]
Generate the ordinal specification for the rumboost model.
- Parameters:
model_type (str) – The type of the model. It can be either ‘proportional_odds’, ‘coral’ or ‘corn’. The default is ‘proportional_odds’.
optim_interval (int) – The optimisation interval at which thresholds are updated with scipy. The default is 20.
- Returns:
ordinal_spec – The ordinal specification for the rumboost model.
- Return type:
Dict[str, Any]
- src.utils.generate_rum_structure(alt_spec_features: Dict[str, List[str]] | None = None, socio_demo_chars: List[str] | None = None, functional_intercept: bool | None = False, functional_params: bool | None = False) List[Dict[str, Any]][source]
Generate the rum structure for the given dataset. Note that this code is written for a single alternative (i.e. a regression or ordinal regression problem).
- Parameters:
alt_spec_features (Optional[Dict[str, List[str]]]) – The alternative-specific features to be used in the rum structure. The dictionary keys are the utility indices and the values are the features to be used in the rum structure.
socio_demo_chars (Optional[List[str]]) – The socio-demographic characteristics to be used in the rum structure. They will represent the individual-specific constant learnt from the data.
functional_intercept (Optional[bool]) – Whether to use the functional intercept or not. The default is False.
functional_params (Optional[bool]) – Whether to use the functional parameters or not. The default is False.
- Returns:
rum_structure – The rum structure for the dataset.
- Return type:
List[Dict[str, Any]]
- src.utils.pkl_to_df(pkl_path: str) DataFrame[source]
Convert a pickle file to a pandas dataframe.
- Parameters:
pkl_path (str) – The path to the pickle file.
- Returns:
df – The dataframe containing the data from the pickle file.
- Return type:
pd.DataFrame
- src.utils.split_dataset(df: DataFrame, target: str, features: List[str], train_size: float = 0.8, val_size: float | None = None, random_state: int = 42, groups: Series | None = None, save_path: str | None = None) tuple[DataFrame, DataFrame, DataFrame, DataFrame] | tuple[DataFrame, DataFrame, DataFrame, DataFrame, DataFrame, DataFrame][source]
Split the dataset into train and test sets, and optionally a validation set.
- Parameters:
df (pd.DataFrame) – The dataframe to be used.
target (str) – The target variable.
features (List[str]) – The features to be used.
train_size (float) – The size of the training set, as a fraction of the total dataset. The default is 0.8.
val_size (Optional[float]) – The size of the validation set, as a fraction of the total dataset. The default is None, in which case no validation set is created.
random_state (int) – The random state to be used. The default is 42.
groups (Optional[pd.Series]) – Group labels used for stratified sampling. The default is None, in which case the data is split randomly; otherwise the data is split using the groups provided.
save_path (Optional[str]) – The path to save the train and test sets. The default is None.
- Returns:
The train and test sets. If val_size is provided, the train set will be split into train and validation sets. If groups is provided, the train and test sets will be split using stratified sampling. If groups is not provided, the train and test sets will be split randomly.
- Return type:
TrainTestSplit
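The random (groups=None) case can be sketched at the level of positional indices, which can then be applied with DataFrame.iloc. This assumes both train_size and val_size are fractions of the full dataset and that the test set takes the remainder; split_indices_sketch is a hypothetical helper, not the function's actual implementation.

```python
import numpy as np


def split_indices_sketch(n_obs, train_size=0.8, val_size=None, random_state=42):
    """Hypothetical sketch: random train/(val)/test split of positional indices."""
    rng = np.random.default_rng(random_state)
    idx = rng.permutation(n_obs)
    n_train = int(round(train_size * n_obs))
    n_val = int(round(val_size * n_obs)) if val_size is not None else 0
    train_idx = idx[:n_train]
    val_idx = idx[n_train:n_train + n_val]
    # Whatever is left after train and validation becomes the test set
    test_idx = idx[n_train + n_val:]
    if val_size is None:
        return train_idx, test_idx
    return train_idx, val_idx, test_idx


train_idx, test_idx = split_indices_sketch(100)
tr, va, te = split_indices_sketch(100, train_size=0.8, val_size=0.1)
```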
- src.utils.stratified_group_k_fold(X, y, groups, k, seed=None)[source]
Stratified Group K-Fold cross-validator. Provides train/test indices to split data into train/test sets.
- Parameters:
X (array-like of shape (n_samples, n_features)) – The input samples.
y (array-like of shape (n_samples,)) – The target values.
groups (array-like of shape (n_samples,)) – Group labels for the samples used while splitting the dataset into train/test set.
k (int) – Number of folds. Must be at least 2.
seed (int, optional) – Random seed for shuffling the data.
- Yields:
train (ndarray) – The training set indices for that split.
test (ndarray) – The testing set indices for that split.
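A common greedy heuristic for stratified group k-fold assigns whole groups to folds one at a time, always choosing the fold that keeps each class's share across folds most even; groups with the most skewed label histograms are placed first. The sketch below implements that heuristic; whether src.utils uses the same strategy is an assumption, and stratified_group_k_fold_sketch is a hypothetical name (X is accepted only for interface compatibility).

```python
from collections import defaultdict

import numpy as np


def stratified_group_k_fold_sketch(X, y, groups, k, seed=None):
    """Greedy stratified group k-fold; yields (train_idx, test_idx) per fold."""
    if k < 2:
        raise ValueError("k must be at least 2")
    y = np.asarray(y)
    groups = np.asarray(groups)
    classes, y_idx = np.unique(y, return_inverse=True)
    n_classes = len(classes)
    # Label histogram of every group and of the whole dataset
    counts_per_group = defaultdict(lambda: np.zeros(n_classes))
    for label, g in zip(y_idx, groups):
        counts_per_group[g][label] += 1
    total_counts = np.bincount(y_idx, minlength=n_classes)
    if len(counts_per_group) < k:
        raise ValueError("need at least k distinct groups")

    # Seeded shuffle so ties are broken randomly, then a stable sort that
    # places groups with the most skewed label histogram first.
    rng = np.random.default_rng(seed)
    items = list(counts_per_group.items())
    items = [items[i] for i in rng.permutation(len(items))]
    items.sort(key=lambda item: -np.std(item[1]))

    fold_counts = np.zeros((k, n_classes))
    fold_of_group = {}
    for g, counts in items:
        best_fold, best_score = 0, np.inf
        for fold in range(k):
            fold_counts[fold] += counts
            # Spread of each class's share across folds, averaged over classes
            score = np.mean(np.std(fold_counts / total_counts, axis=0))
            fold_counts[fold] -= counts
            if score < best_score:
                best_fold, best_score = fold, score
        fold_counts[best_fold] += counts
        fold_of_group[g] = best_fold

    sample_folds = np.array([fold_of_group[g] for g in groups])
    for fold in range(k):
        yield np.flatnonzero(sample_folds != fold), np.flatnonzero(sample_folds == fold)


groups = np.repeat(np.arange(20), 5)
y = groups % 2
X = np.zeros((100, 3))
folds = list(stratified_group_k_fold_sketch(X, y, groups, k=4, seed=0))
```

Because whole groups move together, fold sizes and class shares are only approximately balanced; scikit-learn's StratifiedGroupKFold offers a maintained alternative.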