# Sebastian Raschka 2014-2024
# mlxtend Machine Learning Library Extensions
#
# Author: Sebastian Raschka <sebastianraschka.com>
#
# License: BSD 3 clause


import numpy as np
from sklearn.model_selection import train_test_split


class RandomHoldoutSplit(object):
    """Train/Validation set splitter for sklearn's GridSearchCV etc.

    Provides train/validation set indices to split a dataset
    into train/validation sets using random indices.

    Parameters
    ----------
    valid_size : float (default: 0.5)
        Proportion of examples that being assigned as
        validation examples. 1-`valid_size` will then automatically
        be assigned as training set examples.
    random_seed : int (default: None)
        The random seed for splitting the data
        into training and validation set partitions.
    stratify : bool (default: False)
        True or False, whether to perform a stratified
        split or not

    Examples
    -----------
    For usage examples, please see
    https://rasbt.github.io/mlxtend/user_guide/evaluate/RandomHoldoutSplit/


    """

    def __init__(self, valid_size=0.5, random_seed=None, stratify=False):
        self.valid_size = valid_size
        self.random_seed = random_seed
        self.stratify = stratify

    def split(self, X, y, groups=None):
        """Generate indices to split data into training and test set.

        Parameters
        ----------
        X : array-like, shape (num_examples, num_features)
            Training data, where num_examples is the number of
            training examples and num_features is the number of features.

        y : array-like, shape (num_examples,)
            The target variable for supervised learning problems.
            Stratification is done based on the y labels.

        groups : object
            Always ignored, exists for compatibility.

        Yields
        ------
        train_index : ndarray
            The training set indices for that split.

        valid_index : ndarray
            The validation set indices for that split.
        """
        ind = np.arange(X.shape[0])
        if self.stratify:
            train_index, valid_index, _, _ = train_test_split(
                ind,
                y,
                test_size=self.valid_size,
                shuffle=True,
                stratify=y,
                random_state=self.random_seed,
            )

        else:
            train_index, valid_index, _, _ = train_test_split(
                ind,
                y,
                test_size=self.valid_size,
                shuffle=True,
                stratify=y,
                random_state=self.random_seed,
            )

        for i in range(1):
            yield train_index, valid_index

    def get_n_splits(self, X=None, y=None, groups=None):
        """Returns the number of splitting iterations in the cross-validator

        Parameters
        ----------
        X : object
            Always ignored, exists for compatibility.

        y : object
            Always ignored, exists for compatibility.

        groups : object
            Always ignored, exists for compatibility.

        Returns
        -------
        n_splits : 1
            Returns the number of splitting iterations in the cross-validator.
            Always returns 1.
        """
        return 1


class PredefinedHoldoutSplit(object):
    """Train/Validation set splitter for sklearn's GridSearchCV etc.

    Uses user-specified train/validation set indices to split a dataset
    into train/validation sets using user-defined or random
    indices.

    Parameters
    ----------
    valid_indices : array-like, shape (num_examples,)
        Indices of the training examples in the training set
        to be used for validation. All other indices in the
        training set are used to for a training subset
        for model fitting.

    Examples
    -----------
    For usage examples, please see
    https://rasbt.github.io/mlxtend/user_guide/evaluate/PredefinedHoldoutSplit/

    """

    def __init__(self, valid_indices):
        self.valid_indices = valid_indices

    def split(self, X, y, groups=None):
        """Generate indices to split data into training and test set.

        Parameters
        ----------
        X : array-like, shape (num_examples, num_features)
            Training data, where num_examples is the number of examples
            and num_features is the number of features.

        y : array-like, shape (num_examples,)
            The target variable for supervised learning problems.
            Stratification is done based on the y labels.

        groups : object
            Always ignored, exists for compatibility.

        Yields
        ------
        train_index : ndarray
            The training set indices for that split.

        valid_index : ndarray
            The validation set indices for that split.
        """

        ind = np.arange(X.shape[0])
        train_mask = np.ones(X.shape[0], dtype=np.bool_)
        train_mask[self.valid_indices] = False
        valid_mask = np.where(train_mask, False, True)

        for i in range(1):
            yield ind[train_mask], ind[valid_mask]

    def get_n_splits(self, X=None, y=None, groups=None):
        """Returns the number of splitting iterations in the cross-validator

        Parameters
        ----------
        X : object
            Always ignored, exists for compatibility.

        y : object
            Always ignored, exists for compatibility.

        groups : object
            Always ignored, exists for compatibility.

        Returns
        -------
        n_splits : 1
            Returns the number of splitting iterations in the cross-validator.
            Always returns 1.
        """
        return 1
