import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split

from .._explainer import Explainer


class Maple(Explainer):
    """Simply wraps MAPLE into the common SHAP interface.

    Parameters
    ----------
    model : function
        User supplied function that takes a matrix of samples (# samples x # features) and
        computes the output of the model for those samples. The output can be a vector
        (# samples) or a matrix (# samples x # model outputs).

    data : numpy.array
        The background dataset.
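
    Examples
    --------
    A minimal usage sketch; ``model_fn`` (a prediction function) and ``X`` (a background data
    matrix) are placeholder names, not part of this module::

        explainer = Maple(model_fn, X)
        attributions = explainer.attributions(X[:10], multiply_by_input=True)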
    """

    def __init__(self, model, data):
        self.model = model

        if isinstance(data, pd.DataFrame):
            data = data.values
        self.data = data
        self.data_mean = self.data.mean(0)

        out = self.model(data)
        if len(out.shape) == 1:
            self.out_dim = 1
            self.flat_out = True
        else:
            self.out_dim = out.shape[1]
            self.flat_out = False

        X_train, X_valid, y_train, y_valid = train_test_split(data, out, test_size=0.2, random_state=0)
        self.explainer = MAPLE(X_train, y_train, X_valid, y_valid)

    def attributions(self, X, multiply_by_input=False):
        """Compute the MAPLE coef attributions.

        Parameters
        ----------
        multiply_by_input : bool
            If true, this multiplies the learned coefficients by the mean-centered input. This makes these
            values roughly comparable to SHAP values.
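
        Returns
        -------
        numpy.array or list
            The local linear coefficients for each sample with shape (# samples x # features), or a
            list of such arrays (one per model output) when the model output is a matrix.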
        """
        if isinstance(X, pd.DataFrame):
            X = X.values

        out = [np.zeros(X.shape) for j in range(self.out_dim)]
        for i in range(X.shape[0]):
            exp = self.explainer.explain(X[i])["coefs"]
            out[0][i, :] = exp[1:]
            if multiply_by_input:
                out[0][i, :] = out[0][i, :] * (X[i] - self.data_mean)

        return out[0] if self.flat_out else out


class TreeMaple(Explainer):
    """Simply tree MAPLE into the common SHAP interface.

    Parameters
    ----------
    model : sklearn.ensemble.GradientBoostingRegressor or sklearn.ensemble.RandomForestRegressor
        A fitted scikit-learn tree ensemble regressor whose predictions are explained. Other
        tree-based models are not yet supported.

    data : numpy.array
        The background dataset.
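
    Examples
    --------
    A minimal usage sketch; ``rf`` (a fitted ``RandomForestRegressor``) and ``X`` (its training
    data) are placeholder names, not part of this module::

        explainer = TreeMaple(rf, X)
        attributions = explainer.attributions(X[:10])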
    """

    def __init__(self, model, data):
        self.model = model

        # Match on the class name so the check works across scikit-learn versions
        if str(type(model)).endswith("GradientBoostingRegressor'>"):
            fe_type = "gbrt"  # value expected by MAPLE for gradient boosted trees
        # elif str(type(model)).endswith("DecisionTreeClassifier'>"):
        #     pass
        elif str(type(model)).endswith("RandomForestRegressor'>"):
            fe_type = "rf"
        # elif str(type(model)).endswith("sklearn.ensemble.forest.RandomForestClassifier'>"):
        #     pass
        # elif str(type(model)).endswith("xgboost.sklearn.XGBRegressor'>"):
        #     pass
        # elif str(type(model)).endswith("xgboost.sklearn.XGBClassifier'>"):
        #     pass
        else:
            raise NotImplementedError(
                "The passed model is not yet supported by TreeMapleExplainer: " + str(type(model))
            )

        if isinstance(data, pd.DataFrame):
            data = data.values
        self.data = data
        self.data_mean = self.data.mean(0)

        out = self.model.predict(data[0:1])
        if len(out.shape) == 1:
            self.out_dim = 1
            self.flat_out = True
        else:
            self.out_dim = out.shape[1]
            self.flat_out = False

        # _, X_valid, _, y_valid = train_test_split(data, self.model.predict(data), test_size=0.2, random_state=0)
        preds = self.model.predict(data)
        self.explainer = MAPLE(data, preds, data, preds, fe=self.model, fe_type=fe_type)

    def attributions(self, X, multiply_by_input=False):
        """Compute the MAPLE coef attributions.

        Parameters
        ----------
        multiply_by_input : bool
            If true, this multiplies the learned coefficients by the mean-centered input. This makes these
            values roughly comparable to SHAP values.
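
        Returns
        -------
        numpy.array or list
            The local linear coefficients for each sample with shape (# samples x # features), or a
            list of such arrays (one per model output) when the model output is a matrix.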
        """
        if isinstance(X, pd.DataFrame):
            X = X.values

        out = [np.zeros(X.shape) for j in range(self.out_dim)]
        for i in range(X.shape[0]):
            exp = self.explainer.explain(X[i])["coefs"]
            out[0][i, :] = exp[1:]
            if multiply_by_input:
                out[0][i, :] = out[0][i, :] * (X[i] - self.data_mean)

        return out[0] if self.flat_out else out


#################################################
# The code below was authored by Gregory Plumb and is
# from: https://github.com/GDPlumb/MAPLE/blob/master/Code/MAPLE.py
# It has been copied here to allow for benchmark comparisons. Please see
# the original repo for the latest version, supporting material, and citations.
#################################################

# Notes:
# -  Assumes any required data normalization has already been done
# -  Can pass Y (desired response) instead of MR (model fit to Y) to make fitting MAPLE to datasets easy

import numpy as np
from sklearn.ensemble import GradientBoostingRegressor, RandomForestRegressor
from sklearn.linear_model import Ridge
from sklearn.metrics import mean_squared_error


class MAPLE:
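    """Model Agnostic Supervised Local Explanations (MAPLE).

    Fits a forest ensemble to the model response (or reuses a supplied one), ranks features by the
    impurity reduction of each tree's root split, keeps the number of top features that minimizes
    validation RMSE, and explains each query point with a ridge regression whose sample weights
    count how often a training point shares a leaf with the query (SILO weights).

    A minimal standalone sketch; ``X_tr``, ``y_tr``, ``X_va``, and ``y_va`` are placeholder arrays,
    not defined in this module::

        maple = MAPLE(X_tr, y_tr, X_va, y_va)
        explanation = maple.explain(X_va[0])  # dict with "weights", "coefs", and "pred"
    """
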
    def __init__(
        self,
        X_train,
        MR_train,
        X_val,
        MR_val,
        fe_type="rf",
        fe=None,
        n_estimators=200,
        max_features=0.5,
        min_samples_leaf=10,
        regularization=0.001,
    ):
        # Features and the target model response
        self.X_train = X_train
        self.MR_train = MR_train
        self.X_val = X_val
        self.MR_val = MR_val

        # Forest Ensemble Parameters
        self.n_estimators = n_estimators
        self.max_features = max_features
        self.min_samples_leaf = min_samples_leaf

        # Local Linear Model Parameters
        self.regularization = regularization

        # Data parameters
        num_features = X_train.shape[1]
        self.num_features = num_features
        num_train = X_train.shape[0]
        self.num_train = num_train
        num_val = X_val.shape[0]

        # Fit a Forest Ensemble to the model response
        if fe is None:
            if fe_type == "rf":
                fe = RandomForestRegressor(
                    n_estimators=n_estimators, min_samples_leaf=min_samples_leaf, max_features=max_features
                )
            elif fe_type == "gbrt":
                fe = GradientBoostingRegressor(
                    n_estimators=n_estimators,
                    min_samples_leaf=min_samples_leaf,
                    max_features=max_features,
                    max_depth=None,
                )
            else:
                print("Unknown FE type ", fe)
                import sys

                sys.exit(0)
            fe.fit(X_train, MR_train)
        else:
            self.n_estimators = n_estimators = len(fe.estimators_)
        self.fe = fe

        train_leaf_ids = fe.apply(X_train)
        self.train_leaf_ids = train_leaf_ids

        val_leaf_ids_list = fe.apply(X_val)

        # Compute the feature importances: Non-normalized @ Root
        scores = np.zeros(num_features)
        if fe_type == "rf":
            for i in range(n_estimators):
                splits = fe[i].tree_.feature  # -2 indicates leaf, index 0 is root
                if splits[0] != -2:
                    scores[splits[0]] += fe[i].tree_.impurity[0]  # impurity reduction not normalized per tree
        elif fe_type == "gbrt":
            for i in range(n_estimators):
                splits = fe[i, 0].tree_.feature  # -2 indicates leaf, index 0 is root
                if splits[0] != -2:
                    scores[splits[0]] += fe[i, 0].tree_.impurity[0]  # impurity reduction not normalized per tree
        self.feature_scores = scores
        mostImpFeats = np.argsort(-scores)

        # Find the number of features to use for MAPLE
        retain_best = 0
        rmse_best = np.inf
        for retain in range(1, num_features + 1):
            # Drop less important features for local regression
            X_train_p = np.delete(X_train, mostImpFeats[retain:], axis=1)
            X_val_p = np.delete(X_val, mostImpFeats[retain:], axis=1)

            lr_predictions = np.empty([num_val], dtype=float)

            for i in range(num_val):
                weights = self.training_point_weights(val_leaf_ids_list[i])

                # Local linear model
                lr_model = Ridge(alpha=regularization)
                lr_model.fit(X_train_p, MR_train, weights)
                lr_predictions[i] = lr_model.predict(X_val_p[i].reshape(1, -1))

            rmse_curr = np.sqrt(mean_squared_error(lr_predictions, MR_val))

            if rmse_curr < rmse_best:
                rmse_best = rmse_curr
                retain_best = retain

        self.retain = retain_best
        self.X = np.delete(X_train, mostImpFeats[retain_best:], axis=1)

    def training_point_weights(self, instance_leaf_ids):
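        """Compute SILO-style weights over the training points: for each tree, every training point
        that falls in the same leaf as the instance gets 1 / (leaf size) added to its weight."""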
        weights = np.zeros(self.num_train)
        for i in range(self.n_estimators):
            # Get the PNNs for each tree (ones with the same leaf_id)
            PNNs_Leaf_Node = np.where(self.train_leaf_ids[:, i] == instance_leaf_ids[i])[0]
            if len(PNNs_Leaf_Node) > 0:  # SML: added this to fix degenerate cases
                weights[PNNs_Leaf_Node] += 1.0 / len(PNNs_Leaf_Node)
        return weights

    def explain(self, x):
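        """Fit a weighted ridge model around ``x`` and return a dict with the training point weights,
        the local coefficients (intercept first, one entry per original feature, zeros for dropped
        features), and the local prediction at ``x``."""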
        x = x.reshape(1, -1)

        mostImpFeats = np.argsort(-self.feature_scores)
        x_p = np.delete(x, mostImpFeats[self.retain :], axis=1)

        curr_leaf_ids = self.fe.apply(x)[0]
        weights = self.training_point_weights(curr_leaf_ids)

        # Local linear model
        lr_model = Ridge(alpha=self.regularization)
        lr_model.fit(self.X, self.MR_train, weights)

        # Get the model coefficients
        coefs = np.zeros(self.num_features + 1)
        coefs[0] = lr_model.intercept_
        coefs[np.sort(mostImpFeats[0 : self.retain]) + 1] = lr_model.coef_

        # Get the prediction at this point
        prediction = lr_model.predict(x_p.reshape(1, -1))

        out = {}
        out["weights"] = weights
        out["coefs"] = coefs
        out["pred"] = prediction

        return out

    def predict(self, X):
        n = X.shape[0]
        pred = np.zeros(n)
        for i in range(n):
            exp = self.explain(X[i, :])
            pred[i] = exp["pred"][0]
        return pred

    # Make the predictions based on the forest ensemble (either random forest or gradient boosted regression tree) instead of MAPLE
    def predict_fe(self, X):
        return self.fe.predict(X)

    # Make the predictions based on SILO (no feature selection) instead of MAPLE
    def predict_silo(self, X):
        n = X.shape[0]
        pred = np.zeros(n)
        # The body of this loop is similar to explain(), but it does not use the features
        # selected by MAPLE and returns less information
        for i in range(n):
            x = X[i, :].reshape(1, -1)

            curr_leaf_ids = self.fe.apply(x)[0]
            weights = self.training_point_weights(curr_leaf_ids)

            # Local linear model
            lr_model = Ridge(alpha=self.regularization)
            lr_model.fit(self.X_train, self.MR_train, weights)

            pred[i] = lr_model.predict(x)[0]

        return pred
