# Sebastian Raschka 2014-2024
# mlxtend Machine Learning Library Extensions
#
# A function for computing standard, binary, per-class, and balanced accuracy
#
# License: BSD 3 clause
import numpy as np


def _compute_metric(target_temp, predicted_temp, normalize=True):
    # Fraction (if normalize) or raw count of positions where both arrays agree
    if normalize:
        return (target_temp == predicted_temp).mean()
    return (target_temp == predicted_temp).sum()


def accuracy_score(
    y_target, y_predicted, method="standard", pos_label=1, normalize=True
):
    """General accuracy function for supervised learning.
    Parameters
    ------------
    y_target : array-like, shape=[n_values]
        True class labels or target values.

    y_predicted : array-like, shape=[n_values]
        Predicted class labels or target values.

    method : str, 'standard' by default.
        The chosen method for accuracy computation.
        If set to 'standard', computes overall accuracy.
        If set to 'binary', computes accuracy for class pos_label.
        If set to 'average', computes the accuracy of each class
        (one-vs-rest) and averages these per-class accuracies.
        If set to 'balanced', computes the scikit-learn-style balanced
        accuracy, i.e., the average of the per-class recall values.

    pos_label : str or int, 1 by default.
        The class whose accuracy score is to be reported.
        Used only when `method` is set to 'binary'.

    normalize : bool, True by default.
        If True, returns the fraction of correctly classified samples.
        If False, returns the number of correctly classified samples.
        Applies only to the 'standard' and 'binary' methods.

    Returns
    ------------
    score : float
        The computed accuracy (or count, if `normalize=False`).

    Examples
    -----------
    For usage examples, please see
    https://rasbt.github.io/mlxtend/user_guide/evaluate/accuracy_score/
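
    A minimal doctest-style example using the default 'standard' method:

    >>> import numpy as np
    >>> y_targ = np.array([1, 1, 0, 1])
    >>> y_pred = np.array([1, 0, 0, 1])
    >>> print(accuracy_score(y_targ, y_pred))
    0.75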

    """

    target_temp = np.asarray(y_target)
    predicted_temp = np.asarray(y_predicted)
    unique_labels = np.unique(target_temp)

    if len(y_target) != len(y_predicted):
        raise ValueError(
            "`y_target` and `y_predicted` don't have the same number of elements."
        )
    if method == "standard":
        return _compute_metric(target_temp, predicted_temp, normalize)

    elif method == "binary":
        if pos_label not in unique_labels:
            raise AttributeError("Chosen value of pos_label doesn't exist")

        target_temp = np.where(target_temp == pos_label, 1, 0)
        predicted_temp = np.where(predicted_temp == pos_label, 1, 0)
        return _compute_metric(target_temp, predicted_temp, normalize)

    elif method == "average":
        return sum(
            [
                _compute_metric(
                    np.where(target_temp != lab, 1, 0),
                    np.where(predicted_temp != lab, 1, 0),
                )
                for lab in unique_labels
            ]
        ) / float(unique_labels.shape[0])

    elif method == "balanced":
        all_class_acc = []
        for c in np.unique(y_target):
            positive_labels = y_target == c
            class_acc = np.mean((y_predicted == c)[positive_labels])
            all_class_acc.append(class_acc)

        return np.mean(all_class_acc)

    else:
        raise ValueError(
            '`method` must be "standard", "average",'
            ' "balanced", or "binary". Got "%s".' % method
        )

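
if __name__ == "__main__":
    # Minimal usage sketch (not part of the library API): exercises each
    # `method` on a small hand-made label vector.
    y_targ = [1, 1, 1, 0, 0, 2, 0, 3]
    y_pred = [1, 0, 1, 0, 0, 2, 1, 3]

    print(accuracy_score(y_targ, y_pred))                    # standard: 0.75
    print(accuracy_score(y_targ, y_pred, method="binary"))   # class 1 only: 0.75
    print(accuracy_score(y_targ, y_pred, method="average"))  # per-class avg: 0.875
    print(accuracy_score(y_targ, y_pred, method="balanced")) # mean recall: ~0.8333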