"""Summary plots of SHAP values across a whole dataset."""

from __future__ import annotations

import warnings
from typing import Literal

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scipy.cluster
import scipy.sparse
import scipy.spatial
from matplotlib.figure import Figure
from packaging import version
from scipy.stats import gaussian_kde

from .. import Explanation
from ..utils import safe_isinstance
from ..utils._exceptions import DimensionError
from . import colors
from ._labels import labels
from ._utils import (
    convert_color,
    convert_ordering,
    get_sort_order,
    merge_nodes,
    sort_inds,
)

# TODO: simplify this when we drop support for matplotlib 3.9
# Choose the keyword spelling for a horizontal layout according to the installed
# matplotlib: ``orientation="horizontal"`` from 3.10 onwards, ``vert=False`` before.
if version.parse("3.10") <= version.parse(matplotlib.__version__):
    ORIENTATION_KWARG = {"orientation": "horizontal"}
else:
    ORIENTATION_KWARG = {"vert": False}  # type: ignore[dict-item]


# TODO: Add support for hclustering based explanations where we sort the leaf order by magnitude and then show the dendrogram to the left
def beeswarm(
    shap_values: Explanation,
    max_display: int | None = 10,
    order=Explanation.abs.mean(0),  # type: ignore
    clustering=None,
    cluster_threshold=0.5,
    color=None,
    axis_color="#333333",
    alpha: float = 1.0,
    ax: plt.Axes | None = None,
    show: bool = True,
    log_scale: bool = False,
    color_bar: bool = True,
    s: float = 16,
    plot_size: Literal["auto"] | float | tuple[float, float] | None = "auto",
    color_bar_label: str = labels["FEATURE_VALUE"],
    group_remaining_features: bool = True,
):
    """Create a SHAP beeswarm plot, colored by feature values when they are provided.

    Parameters
    ----------
    shap_values : Explanation
        This is an :class:`.Explanation` object containing a matrix of SHAP values
        (# samples x # features).

    max_display : int or None
        How many top features to include in the plot (default is 10). If ``None``,
        all features are displayed.

    order : OpChain or array-like
        The feature ordering, resolved through ``convert_ordering``. Defaults to
        ``Explanation.abs.mean(0)``.

    clustering : array-like, None or False
        Partition tree (a matrix with 4 columns) used to keep clustered features
        adjacent. If ``None``, the ``clustering`` attribute of ``shap_values`` is
        used when it is present and identical for every sample. Pass ``False`` to
        disable clustering entirely.

    cluster_threshold : float
        Cophenetic-distance threshold (default 0.5). Features connected in the
        partition tree at a distance less than or equal to this value are not
        split apart when the display is truncated to ``max_display`` rows.

    color : color, colormap or None
        Color or colormap for the points, passed through ``convert_color``. If
        ``None``, ``colors.red_blue`` is used when feature values are available
        and ``colors.blue_rgb`` otherwise.

    axis_color : color
        Color applied to the axis ticks and tick labels.

    alpha : float
        Opacity of the markers, forwarded to
        :external+mpl:func:`matplotlib.pyplot.scatter`.

    ax: matplotlib Axes
        Axes object to draw the plot onto, otherwise uses the current Axes.

    show : bool
        Whether :external+mpl:func:`matplotlib.pyplot.show()` is called before returning.
        Setting this to ``False`` allows the plot to be customized further
        after it has been created, returning the current axis via
        :external+mpl:func:`matplotlib.pyplot.gca()`.

    log_scale : bool
        If ``True``, the x-axis uses a ``"symlog"`` scale.

    color_bar : bool
        Whether to draw the color bar (legend).

    s : float
        What size to make the markers. For further information, see ``s`` in
        :external+mpl:func:`matplotlib.pyplot.scatter`.

    plot_size : "auto" (default), float, (float, float), or None
        What size to make the plot. By default, the size is auto-scaled based on the
        number of features that are being displayed. Passing a single float will cause
        each row to be that many inches high. Passing a pair of floats will scale the
        plot by that number of inches. If ``None`` is passed, then the size of the
        current figure will be left unchanged. If ``ax`` is not ``None``, then passing
        ``plot_size`` will raise a :exc:`ValueError`.

    color_bar_label : str
        Label drawn next to the color bar.

    group_remaining_features: bool
        If there are more features than ``max_display``, then plot a row representing
        the sum of SHAP values of all remaining features. Default True.

    Returns
    -------
    ax: matplotlib Axes
        Returns the :external+mpl:class:`~matplotlib.axes.Axes` object with the plot drawn onto it. Only
        returned if ``show=False``.

    Raises
    ------
    TypeError
        If ``shap_values`` is not an :class:`.Explanation`.
    ValueError
        If ``shap_values`` is not two-dimensional, if both ``ax`` and ``plot_size``
        are given, or if the clustering is not a 4-column partition tree.
    DimensionError
        If the shape of the SHAP values does not match the shape of the feature data.

    Examples
    --------
    See `beeswarm plot examples <https://shap.readthedocs.io/en/latest/example_notebooks/api_examples/plots/beeswarm.html>`_.

    """
    if not isinstance(shap_values, Explanation):
        emsg = "The beeswarm plot requires an `Explanation` object as the `shap_values` argument."
        raise TypeError(emsg)

    sv_shape = shap_values.shape
    if len(sv_shape) == 1:
        emsg = (
            "The beeswarm plot does not support plotting a single instance, please pass "
            "an explanation matrix with many instances!"
        )
        raise ValueError(emsg)
    elif len(sv_shape) > 2:
        emsg = (
            "The beeswarm plot does not support plotting explanations with instances that have more than one dimension!"
        )
        raise ValueError(emsg)

    # NOTE: plot_size defaults to "auto", so callers supplying `ax` must pass plot_size=None explicitly
    if ax and plot_size:
        emsg = (
            "The beeswarm plot does not support passing an axis and adjusting the plot size. "
            "To adjust the size of the plot, set plot_size to None and adjust the size on the original figure the axes was part of"
        )
        raise ValueError(emsg)

    shap_exp = shap_values
    # we make a copy here, because later there are places that might modify this array
    values = np.copy(shap_exp.values)
    features = shap_exp.data
    if scipy.sparse.issparse(features):
        features = features.toarray()
    feature_names = shap_exp.feature_names
    # if out_names is None: # TODO: waiting for slicer support
    #     out_names = shap_exp.output_names

    order = convert_ordering(order, values)

    # default color:
    if color is None:
        if features is not None:
            color = colors.red_blue
        else:
            color = colors.blue_rgb
    color = convert_color(color)

    idx2cat = None
    # convert from a DataFrame or other types
    if isinstance(features, pd.DataFrame):
        if feature_names is None:
            feature_names = features.columns
        # feature index to category flag
        idx2cat = features.dtypes.astype(str).isin(["object", "category"]).tolist()
        features = features.values
    elif isinstance(features, list):
        # a plain list is treated as shorthand for the feature names
        if feature_names is None:
            feature_names = features
        features = None
    elif (features is not None) and len(features.shape) == 1 and feature_names is None:
        # a 1-D array without separate names is likewise treated as the feature names
        feature_names = features
        features = None

    num_features = values.shape[1]

    if features is not None:
        shape_msg = "The shape of the shap_values matrix does not match the shape of the provided data matrix."
        if num_features - 1 == features.shape[1]:
            shape_msg += (
                " Perhaps the extra column in the shap_values matrix is the "
                "constant offset? If so, just pass shap_values[:,:-1]."
            )
            raise DimensionError(shape_msg)
        if num_features != features.shape[1]:
            raise DimensionError(shape_msg)

    if feature_names is None:
        feature_names = np.array([labels["FEATURE"] % str(i) for i in range(num_features)])

    if ax is None:
        ax = plt.gca()
    fig = ax.get_figure()
    assert isinstance(fig, Figure)  # type narrowing for mypy

    if log_scale:
        ax.set_xscale("symlog")

    if clustering is None:
        # only use the Explanation's clustering when it is the same for every sample
        partition_tree = getattr(shap_values, "clustering", None)
        if partition_tree is not None and partition_tree.var(0).sum() == 0:
            partition_tree = partition_tree[0]
        else:
            partition_tree = None
    elif clustering is False:
        partition_tree = None
    else:
        partition_tree = clustering

    if partition_tree is not None:
        if partition_tree.shape[1] != 4:
            emsg = (
                "The clustering provided by the Explanation object does not seem to "
                "be a partition tree (which is all shap.plots.beeswarm supports)!"
            )
            raise ValueError(emsg)

    # FIXME: introduce beeswarm interaction values as a separate function `beeswarm_interaction()` (?)
    #   In the meantime, users can use the `shap.summary_plot()` function
    #   (the interaction-value plotting code lives in `summary_legacy` in this module).

    # determine how many top features we will plot
    if max_display is None:
        max_display = len(feature_names)
    num_features = min(max_display, len(feature_names))

    # iteratively merge nodes until we can cut off the smallest feature values to stay within
    # num_features without breaking a cluster tree
    orig_inds = [[i] for i in range(len(feature_names))]
    orig_values = values.copy()
    while True:
        feature_order = convert_ordering(order, Explanation(np.abs(values)))
        if partition_tree is not None:
            # compute the leaf order if we were to show (and so have the ordering respect) the whole partition tree
            # NOTE(review): the full (samples x features) matrix is passed to `sort_inds` and
            # `merge_nodes` below; confirm those helpers accept 2-D input rather than a
            # per-feature vector.
            clust_order = sort_inds(partition_tree, np.abs(values))

            # now relax the requirement to match the partition tree ordering for connections above cluster_threshold
            dist = scipy.spatial.distance.squareform(scipy.cluster.hierarchy.cophenet(partition_tree))
            feature_order = get_sort_order(dist, clust_order, cluster_threshold, feature_order)

            # if the last feature we can display is connected in a tree the next feature then we can't just cut
            # off the feature ordering, so we need to merge some tree nodes and then try again.
            if (
                max_display < len(feature_order)
                and dist[feature_order[max_display - 1], feature_order[max_display - 2]] <= cluster_threshold
            ):
                # `merge_nodes` yields a single (ind1, ind2) pair, so fold column ind2 into
                # ind1 exactly once per pass; the enclosing `while` loop re-evaluates the
                # ordering and merges again if it is still needed.
                partition_tree, ind1, ind2 = merge_nodes(np.abs(values), partition_tree)
                values[:, ind1] += values[:, ind2]
                values = np.delete(values, ind2, 1)
                orig_inds[ind1] += orig_inds[ind2]
                del orig_inds[ind2]
            else:
                break
        else:
            break

    # here we build our feature names, accounting for the fact that some features might be merged together
    feature_inds = feature_order[:max_display]
    feature_names_new = []
    for inds in orig_inds:
        if len(inds) == 1:
            feature_names_new.append(feature_names[inds[0]])
        elif len(inds) <= 2:
            feature_names_new.append(" + ".join([feature_names[i] for i in inds]))
        else:
            # name a larger merged group after its member with the largest mean |SHAP| value
            max_ind = np.argmax(np.abs(orig_values).mean(0)[inds])
            feature_names_new.append(f"{feature_names[inds[max_ind]]} + {len(inds) - 1} other features")
    feature_names = feature_names_new

    # see how many individual (vs. grouped at the end) features we are plotting
    include_grouped_remaining = num_features < len(values[0]) and group_remaining_features
    if include_grouped_remaining:
        num_cut = np.sum([len(orig_inds[feature_order[i]]) for i in range(num_features - 1, len(values[0]))])
        # the last displayed row becomes the sum of every feature that did not get its own row
        values[:, feature_order[num_features - 1]] = np.sum(
            [values[:, feature_order[i]] for i in range(num_features - 1, len(values[0]))], 0
        )

    # build our y-tick labels
    yticklabels = [feature_names[i] for i in feature_inds]
    if include_grouped_remaining:
        yticklabels[-1] = f"Sum of {num_cut} other features"

    row_height = 0.4
    if plot_size == "auto":
        fig.set_size_inches(8, min(len(feature_order), max_display) * row_height + 1.5)
    elif isinstance(plot_size, (list, tuple)):
        fig.set_size_inches(plot_size[0], plot_size[1])
    elif plot_size is not None:
        fig.set_size_inches(8, min(len(feature_order), max_display) * plot_size + 1.5)
    ax.axvline(x=0, color="#999999", zorder=-1)

    # make the beeswarm dots
    for pos, i in enumerate(reversed(feature_inds)):
        ax.axhline(y=pos, color="#cccccc", lw=0.5, dashes=(1, 5), zorder=-1)
        shaps = values[:, i]
        fvalues = None if features is None else features[:, i]
        # shuffle the samples (SHAP values and feature values together) before plotting
        f_inds = np.arange(len(shaps))
        np.random.shuffle(f_inds)
        if fvalues is not None:
            fvalues = fvalues[f_inds]
        shaps = shaps[f_inds]
        colored_feature = True
        try:
            if idx2cat is not None and idx2cat[i]:  # check categorical feature
                colored_feature = False
            else:
                fvalues = np.array(fvalues, dtype=np.float64)  # make sure this can be numeric
        except Exception:
            # best effort: features that cannot be converted to floats are drawn uncolored
            colored_feature = False
        N = len(shaps)
        nbins = 100
        # bin the SHAP values along x; points sharing a bin are stacked vertically
        quant = np.round(nbins * (shaps - np.min(shaps)) / (np.max(shaps) - np.min(shaps) + 1e-8))
        # tiny random jitter breaks ties within a bin
        inds_ = np.argsort(quant + np.random.randn(N) * 1e-6)
        layer = 0
        last_bin = -1
        ys = np.zeros(N)
        for ind in inds_:
            if quant[ind] != last_bin:
                layer = 0
            # alternate above/below the row's centre line: 0, +1, -1, +2, -2, ...
            ys[ind] = np.ceil(layer / 2) * ((layer % 2) * 2 - 1)
            layer += 1
            last_bin = quant[ind]
        # rescale the stack heights relative to the row height
        ys *= 0.9 * (row_height / np.max(ys + 1))

        if safe_isinstance(color, "matplotlib.colors.Colormap") and fvalues is not None and colored_feature is True:
            # trim the color range, but prevent the color range from collapsing
            vmin = np.nanpercentile(fvalues, 5)
            vmax = np.nanpercentile(fvalues, 95)
            if vmin == vmax:
                vmin = np.nanpercentile(fvalues, 1)
                vmax = np.nanpercentile(fvalues, 99)
                if vmin == vmax:
                    vmin = np.min(fvalues)
                    vmax = np.max(fvalues)
            if vmin > vmax:  # fixes rare numerical precision issues
                vmin = vmax

            if features is not None and features.shape[0] != len(shaps):
                emsg = "Feature and SHAP matrices must have the same number of rows!"
                raise DimensionError(emsg)

            # plot the nan fvalues in the interaction feature as grey
            nan_mask = np.isnan(fvalues)
            ax.scatter(
                shaps[nan_mask],
                pos + ys[nan_mask],
                color="#777777",
                s=s,
                alpha=alpha,
                linewidth=0,
                zorder=3,
                rasterized=len(shaps) > 500,
            )

            # plot the non-nan fvalues colored by the trimmed feature value
            cvals = fvalues[np.invert(nan_mask)].astype(np.float64)
            cvals_imp = cvals.copy()
            cvals_imp[np.isnan(cvals)] = (vmin + vmax) / 2.0
            cvals[cvals_imp > vmax] = vmax
            cvals[cvals_imp < vmin] = vmin
            ax.scatter(
                shaps[np.invert(nan_mask)],
                pos + ys[np.invert(nan_mask)],
                cmap=color,
                vmin=vmin,
                vmax=vmax,
                s=s,
                c=cvals,
                alpha=alpha,
                linewidth=0,
                zorder=3,
                rasterized=len(shaps) > 500,
            )
        else:
            # Use a row-local variable here: overwriting `color` would strip the colormap
            # from every later row and suppress the color bar below.
            point_color = color
            if safe_isinstance(color, "matplotlib.colors.Colormap") and hasattr(color, "colors"):
                point_color = color.colors
            ax.scatter(
                shaps,
                pos + ys,
                s=s,
                alpha=alpha,
                linewidth=0,
                zorder=3,
                color=point_color if colored_feature else "#777777",
                rasterized=len(shaps) > 500,
            )

    # draw the color bar
    if safe_isinstance(color, "matplotlib.colors.Colormap") and color_bar and features is not None:
        import matplotlib.cm as cm

        m = cm.ScalarMappable(cmap=color)
        m.set_array([0, 1])
        cb = fig.colorbar(m, ax=ax, ticks=[0, 1], aspect=80)
        cb.set_ticklabels([labels["FEATURE_VALUE_LOW"], labels["FEATURE_VALUE_HIGH"]])
        cb.set_label(color_bar_label, size=12, labelpad=0)
        cb.ax.tick_params(labelsize=11, length=0)
        cb.set_alpha(1)
        cb.outline.set_visible(False)  # type: ignore

    ax.xaxis.set_ticks_position("bottom")
    ax.yaxis.set_ticks_position("none")
    ax.spines["right"].set_visible(False)
    ax.spines["top"].set_visible(False)
    ax.spines["left"].set_visible(False)
    ax.tick_params(color=axis_color, labelcolor=axis_color)
    # rows were drawn bottom-up from the reversed feature order, so reverse the labels to match
    ax.set_yticks(range(len(feature_inds)), list(reversed(yticklabels)), fontsize=13)
    ax.tick_params("y", length=20, width=0.5, which="major")
    ax.tick_params("x", labelsize=11)
    ax.set_ylim(-1, len(feature_inds))
    ax.set_xlabel(labels["VALUE"], fontsize=13)
    if show:
        plt.show()
    else:
        return ax


def shorten_text(text, length_limit):
    """Return ``text`` unchanged if it fits within ``length_limit``, else a truncated copy.

    When ``text`` is longer than ``length_limit``, its first ``length_limit - 3``
    characters are kept and ``"..."`` is appended.
    """
    if len(text) <= length_limit:
        return text
    keep = length_limit - 3
    return text[:keep] + "..."


def is_color_map(color):
    """Return whether ``color`` is an instance of ``matplotlib.colors.Colormap``.

    The check is delegated to ``safe_isinstance`` using the class's dotted path.
    """
    # The result was previously discarded, so the function always returned None.
    return safe_isinstance(color, "matplotlib.colors.Colormap")


# TODO: remove unused title argument / use title argument
# TODO: Add support for hclustering based explanations where we sort the leaf order by magnitude and then show the dendrogram to the left
def summary_legacy(
    shap_values,
    features=None,
    feature_names=None,
    max_display=None,
    plot_type=None,
    color=None,
    axis_color="#333333",
    title=None,
    alpha=1,
    show=True,
    sort=True,
    color_bar=True,
    plot_size="auto",
    layered_violin_max_num_bins=20,
    class_names=None,
    class_inds=None,
    color_bar_label=labels["FEATURE_VALUE"],
    cmap=colors.red_blue,
    show_values_in_legend: bool = False,
    use_log_scale: bool = False,
    rng: np.random.Generator | None = None,
):
    """Create a SHAP beeswarm plot, colored by feature values when they are provided.

    Parameters
    ----------
    shap_values : numpy.array
        For single output explanations this is a matrix of SHAP values (# samples x # features).
        For multi-output explanations this is a list of such matrices of SHAP values.

    features : numpy.array or pandas.DataFrame or list
        Matrix of feature values (# samples x # features) or a feature_names list as shorthand

    feature_names : list
        Names of the features (length # features)

    max_display : int
        How many top features to include in the plot (default is 20, or 7 for interaction plots)

    plot_type : "dot" (default for single output), "bar" (default for multi-output), "violin",
        or "compact_dot".
        What type of summary plot to produce. Note that "compact_dot" is only used for
        SHAP interaction values.

    plot_size : "auto" (default), float, (float, float), or None
        What size to make the plot. By default the size is auto-scaled based on the number of
        features that are being displayed. Passing a single float will cause each row to be that
        many inches high. Passing a pair of floats will scale the plot by that
        number of inches. If None is passed then the size of the current figure will be left
        unchanged.

    show_values_in_legend: bool
        Flag to print the mean of the SHAP values in the multi-output bar plot. Set to False
        by default.
    rng : `numpy.random.Generator`, optional
        Pseudorandom number generator state. When `rng` is None,
        the legacy behavior of using global NumPy random state will be
        used. Types other than `numpy.random.Generator` are
        passed to `numpy.random.default_rng` to instantiate a ``Generator``.

    """
    # handle randomization machinery in conformance with SPEC 7
    if rng is not None:
        rng = np.random.default_rng(rng)
    else:
        global_seed_set = np.random.mtrand._rand._bit_generator._seed_seq is None  # type: ignore
        if global_seed_set:
            msg = (
                "The NumPy global RNG was seeded by calling `np.random.seed`. "
                "In a future version this function will no longer use the global RNG. "
                "Pass `rng` explicitly to opt-in to the new behaviour and silence this warning."
            )
            warnings.warn(msg, FutureWarning, stacklevel=2)

    # support passing an explanation object
    if str(type(shap_values)).endswith("Explanation'>"):
        shap_exp = shap_values
        shap_values = shap_exp.values
        if features is None:
            features = shap_exp.data
        if feature_names is None:
            feature_names = shap_exp.feature_names

        # Revert back to list for multi-output explanations.
        if len(shap_exp.base_values.shape) == 2 and shap_exp.base_values.shape[1] > 2:
            shap_values = [shap_values[:, :, i] for i in range(shap_exp.base_values.shape[1])]

        # if out_names is None: # TODO: waiting for slicer support of this
        #     out_names = shap_exp.output_names

    multi_class = False
    if isinstance(shap_values, list):
        multi_class = True
        if plot_type is None:
            plot_type = "bar"  # default for multi-output explanations
        assert plot_type == "bar", "Only plot_type = 'bar' is supported for multi-output explanations!"
    else:
        if plot_type is None:
            plot_type = "dot"  # default for single output explanations
        assert len(shap_values.shape) != 1, "Summary plots need a matrix of shap_values, not a vector."
        # revert the shape of the shap_values matrix for multi-output explanations to list of matrices
        if len(shap_values.shape) == 3 and shap_values.shape[2] > 2 and plot_type == "bar":
            shap_values = [shap_values[:, :, i] for i in range(shap_values.shape[2])]
            multi_class = True

    # default color:
    if color is None:
        if plot_type == "layered_violin":
            color = "coolwarm"
        elif multi_class:

            def color(i):
                return colors.red_blue_circle(i / len(shap_values))
        else:
            color = colors.blue_rgb

    idx2cat = None
    # convert from a DataFrame or other types
    if isinstance(features, pd.DataFrame):
        if feature_names is None:
            feature_names = features.columns
        # feature index to category flag
        idx2cat = features.dtypes.astype(str).isin(["object", "category"]).tolist()
        features = features.values
    elif isinstance(features, list):
        if feature_names is None:
            feature_names = features
        features = None
    elif (features is not None) and len(features.shape) == 1 and feature_names is None:
        feature_names = features
        features = None

    num_features = shap_values[0].shape[1] if multi_class else shap_values.shape[1]

    if features is not None:
        shape_msg = "The shape of the shap_values matrix does not match the shape of the provided data matrix."
        if num_features - 1 == features.shape[1]:
            raise ValueError(
                shape_msg + " Perhaps the extra column in the shap_values matrix is the "
                "constant offset? Of so just pass shap_values[:,:-1]."
            )
        else:
            assert num_features == features.shape[1], shape_msg

    if feature_names is None:
        feature_names = np.array([labels["FEATURE"] % str(i) for i in range(num_features)])

    if use_log_scale:
        plt.xscale("symlog")

    # plotting SHAP interaction values
    if not multi_class and len(shap_values.shape) == 3:
        if plot_type == "compact_dot":
            new_shap_values = shap_values.reshape(shap_values.shape[0], -1)
            new_features = np.tile(features, (1, 1, features.shape[1])).reshape(features.shape[0], -1)

            new_feature_names = []
            for c1 in feature_names:
                for c2 in feature_names:
                    if c1 == c2:
                        new_feature_names.append(c1)
                    else:
                        new_feature_names.append(c1 + "* - " + c2)

            return summary_legacy(
                new_shap_values,
                new_features,
                new_feature_names,
                max_display=max_display,
                plot_type="dot",
                color=color,
                axis_color=axis_color,
                title=title,
                alpha=alpha,
                show=show,
                sort=sort,
                color_bar=color_bar,
                plot_size=plot_size,
                class_names=class_names,
                color_bar_label="*" + color_bar_label,
            )

        if max_display is None:
            max_display = 7
        else:
            max_display = min(len(feature_names), max_display)

        sort_inds = np.argsort(-np.abs(shap_values.sum(1)).sum(0))

        # get plotting limits
        delta = 1.0 / (shap_values.shape[1] ** 2)
        slow = np.nanpercentile(shap_values, delta)
        shigh = np.nanpercentile(shap_values, 100 - delta)
        v = max(abs(slow), abs(shigh))
        slow = -v
        shigh = v

        plt.figure(figsize=(1.5 * max_display + 1, 0.8 * max_display + 1))
        plt.subplot(1, max_display, 1)
        proj_shap_values = shap_values[:, sort_inds[0], sort_inds]
        proj_shap_values[:, 1:] *= 2  # because off diag effects are split in half
        summary_legacy(
            proj_shap_values,
            features[:, sort_inds] if features is not None else None,
            feature_names=np.array(feature_names)[sort_inds].tolist(),
            sort=False,
            show=False,
            color_bar=False,
            plot_size=None,
            max_display=max_display,
        )
        plt.xlim((slow, shigh))
        plt.xlabel("")
        title_length_limit = 11
        plt.title(shorten_text(feature_names[sort_inds[0]], title_length_limit))
        for i in range(1, min(len(sort_inds), max_display)):
            ind = sort_inds[i]
            plt.subplot(1, max_display, i + 1)
            proj_shap_values = shap_values[:, ind, sort_inds]
            proj_shap_values *= 2
            proj_shap_values[:, i] /= 2  # because only off diag effects are split in half
            summary_legacy(
                proj_shap_values,
                features[:, sort_inds] if features is not None else None,
                sort=False,
                feature_names=["" for i in range(len(feature_names))],
                show=False,
                color_bar=False,
                plot_size=None,
                max_display=max_display,
            )
            plt.xlim((slow, shigh))
            plt.xlabel("")
            if i == min(len(sort_inds), max_display) // 2:
                plt.xlabel(labels["INTERACTION_VALUE"])
            plt.title(shorten_text(feature_names[ind], title_length_limit))
        plt.tight_layout(pad=0, w_pad=0, h_pad=0.0)
        plt.subplots_adjust(hspace=0, wspace=0.1)
        if show:
            plt.show()
        return

    if max_display is None:
        max_display = 20

    if sort:
        # order features by the sum of their effect magnitudes
        if multi_class:
            feature_order = np.argsort(np.sum(np.mean(np.abs(shap_values), axis=1), axis=0))
        else:
            feature_order = np.argsort(np.sum(np.abs(shap_values), axis=0))
        feature_order = feature_order[-min(max_display, len(feature_order)) :]
    else:
        feature_order = np.flip(np.arange(min(max_display, num_features)), 0)

    row_height = 0.4
    if plot_size == "auto":
        plt.gcf().set_size_inches(8, len(feature_order) * row_height + 1.5)
    elif type(plot_size) in (list, tuple):
        plt.gcf().set_size_inches(plot_size[0], plot_size[1])
    elif plot_size is not None:
        plt.gcf().set_size_inches(8, len(feature_order) * plot_size + 1.5)
    plt.axvline(x=0, color="#999999", zorder=-1)

    if plot_type == "dot":
        for pos, i in enumerate(feature_order):
            plt.axhline(y=pos, color="#cccccc", lw=0.5, dashes=(1, 5), zorder=-1)
            shaps = shap_values[:, i]
            values = None if features is None else features[:, i]
            inds = np.arange(len(shaps))
            if rng is None:
                np.random.shuffle(inds)
            else:
                rng.shuffle(inds)
            if values is not None:
                values = values[inds]
            shaps = shaps[inds]
            colored_feature = True
            try:
                if idx2cat is not None and idx2cat[i]:  # check categorical feature
                    colored_feature = False
                else:
                    values = np.array(values, dtype=np.float64)  # make sure this can be numeric
            except Exception:
                colored_feature = False
            N = len(shaps)
            # hspacing = (np.max(shaps) - np.min(shaps)) / 200
            # curr_bin = []
            nbins = 100
            quant = np.round(nbins * (shaps - np.min(shaps)) / (np.max(shaps) - np.min(shaps) + 1e-8))
            if rng is None:
                tmp_x = np.random.randn(N)
            else:
                tmp_x = rng.standard_normal(N)
            inds = np.argsort(quant + tmp_x * 1e-6)
            layer = 0
            last_bin = -1
            ys = np.zeros(N)
            for ind in inds:
                if quant[ind] != last_bin:
                    layer = 0
                ys[ind] = np.ceil(layer / 2) * ((layer % 2) * 2 - 1)
                layer += 1
                last_bin = quant[ind]
            ys *= 0.9 * (row_height / np.max(ys + 1))

            if features is not None and colored_feature:
                # trim the color range, but prevent the color range from collapsing
                vmin = np.nanpercentile(values, 5)
                vmax = np.nanpercentile(values, 95)
                if vmin == vmax:
                    vmin = np.nanpercentile(values, 1)
                    vmax = np.nanpercentile(values, 99)
                    if vmin == vmax:
                        vmin = np.min(values)
                        vmax = np.max(values)
                if vmin > vmax:  # fixes rare numerical precision issues
                    vmin = vmax

                assert features.shape[0] == len(shaps), "Feature and SHAP matrices must have the same number of rows!"

                # plot the nan values in the interaction feature as grey
                nan_mask = np.isnan(values)
                plt.scatter(
                    shaps[nan_mask],
                    pos + ys[nan_mask],
                    color="#777777",
                    s=16,
                    alpha=alpha,
                    linewidth=0,
                    zorder=3,
                    rasterized=len(shaps) > 500,
                )

                # plot the non-nan values colored by the trimmed feature value
                cvals = values[np.invert(nan_mask)].astype(np.float64)
                cvals_imp = cvals.copy()
                cvals_imp[np.isnan(cvals)] = (vmin + vmax) / 2.0
                cvals[cvals_imp > vmax] = vmax
                cvals[cvals_imp < vmin] = vmin
                plt.scatter(
                    shaps[np.invert(nan_mask)],
                    pos + ys[np.invert(nan_mask)],
                    cmap=cmap,
                    vmin=vmin,
                    vmax=vmax,
                    s=16,
                    c=cvals,
                    alpha=alpha,
                    linewidth=0,
                    zorder=3,
                    rasterized=len(shaps) > 500,
                )
            else:
                plt.scatter(
                    shaps,
                    pos + ys,
                    s=16,
                    alpha=alpha,
                    linewidth=0,
                    zorder=3,
                    color=color if colored_feature else "#777777",
                    rasterized=len(shaps) > 500,
                )

    elif plot_type == "violin":
        for pos in range(len(feature_order)):
            plt.axhline(y=pos, color="#cccccc", lw=0.5, dashes=(1, 5), zorder=-1)

        if features is not None:
            global_low = np.nanpercentile(shap_values[:, : len(feature_names)].flatten(), 1)
            global_high = np.nanpercentile(shap_values[:, : len(feature_names)].flatten(), 99)
            for pos, i in enumerate(feature_order):
                shaps = shap_values[:, i]
                shap_min, shap_max = np.min(shaps), np.max(shaps)
                shap_max_min = shap_max - shap_min
                xs = np.linspace(np.min(shaps) - shap_max_min * 0.2, np.max(shaps) + shap_max_min * 0.2, 100)
                if np.std(shaps) < (global_high - global_low) / 100:
                    if rng is None:
                        tmp_y = np.random.randn(len(shaps))
                    else:
                        tmp_y = rng.standard_normal(len(shaps))
                    ds = gaussian_kde(shaps + tmp_y * (global_high - global_low) / 100)(xs)
                else:
                    ds = gaussian_kde(shaps)(xs)
                ds /= np.max(ds) * 3

                values = features[:, i]
                # window_size = max(10, len(values) // 20)
                smooth_values = np.zeros(len(xs) - 1)
                sort_inds = np.argsort(shaps)
                trailing_pos = 0
                leading_pos = 0
                running_sum = 0
                back_fill = 0
                for j in range(len(xs) - 1):
                    while leading_pos < len(shaps) and xs[j] >= shaps[sort_inds[leading_pos]]:
                        running_sum += values[sort_inds[leading_pos]]
                        leading_pos += 1
                        if leading_pos - trailing_pos > 20:
                            running_sum -= values[sort_inds[trailing_pos]]
                            trailing_pos += 1
                    if leading_pos - trailing_pos > 0:
                        smooth_values[j] = running_sum / (leading_pos - trailing_pos)
                        for k in range(back_fill):
                            smooth_values[j - k - 1] = smooth_values[j]
                    else:
                        back_fill += 1

                vmin = np.nanpercentile(values, 5)
                vmax = np.nanpercentile(values, 95)
                if vmin == vmax:
                    vmin = np.nanpercentile(values, 1)
                    vmax = np.nanpercentile(values, 99)
                    if vmin == vmax:
                        vmin = np.min(values)
                        vmax = np.max(values)

                # plot the nan values in the interaction feature as grey
                nan_mask = np.isnan(values)
                plt.scatter(
                    shaps[nan_mask],
                    np.ones(shap_values[nan_mask].shape[0]) * pos,
                    color="#777777",
                    s=9,
                    alpha=alpha,
                    linewidth=0,
                    zorder=1,
                )
                # plot the non-nan values colored by the trimmed feature value
                cvals = values[np.invert(nan_mask)].astype(np.float64)
                cvals_imp = cvals.copy()
                cvals_imp[np.isnan(cvals)] = (vmin + vmax) / 2.0
                cvals[cvals_imp > vmax] = vmax
                cvals[cvals_imp < vmin] = vmin
                plt.scatter(
                    shaps[np.invert(nan_mask)],
                    np.ones(shap_values[np.invert(nan_mask)].shape[0]) * pos,
                    cmap=cmap,
                    vmin=vmin,
                    vmax=vmax,
                    s=9,
                    c=cvals,
                    alpha=alpha,
                    linewidth=0,
                    zorder=1,
                )
                # smooth_values -= nxp.nanpercentile(smooth_values, 5)
                # smooth_values /= np.nanpercentile(smooth_values, 95)
                smooth_values -= vmin
                if vmax - vmin > 0:
                    smooth_values /= vmax - vmin
                for i in range(len(xs) - 1):
                    if ds[i] > 0.05 or ds[i + 1] > 0.05:
                        plt.fill_between(
                            [xs[i], xs[i + 1]],
                            [pos + ds[i], pos + ds[i + 1]],
                            [pos - ds[i], pos - ds[i + 1]],
                            color=colors.red_blue_no_bounds(smooth_values[i]),
                            zorder=2,
                        )

        else:
            parts = plt.violinplot(
                shap_values[:, feature_order],
                range(len(feature_order)),
                points=200,
                **ORIENTATION_KWARG,  # type: ignore[arg-type]
                widths=0.7,
                showmeans=False,
                showextrema=False,
                showmedians=False,
            )

            for pc in parts["bodies"]:  # type:ignore
                pc.set_facecolor(color)
                pc.set_edgecolor("none")
                pc.set_alpha(alpha)

    elif plot_type == "layered_violin":  # courtesy of @kodonnell
        num_x_points = 200
        bins = (
            np.linspace(0, features.shape[0], layered_violin_max_num_bins + 1).round(0).astype("int")
        )  # the indices of the feature data corresponding to each bin
        shap_min, shap_max = np.min(shap_values), np.max(shap_values)
        x_points = np.linspace(shap_min, shap_max, num_x_points)

        # loop through each feature and plot:
        for pos, ind in enumerate(feature_order):
            # decide how to handle: if #unique <= layered_violin_max_num_bins then split by unique value, otherwise use bins/percentiles.
            # to keep the code simpler, in the case of uniques, we just adjust the bins to align with the unique counts.
            feature = features[:, ind]
            unique, counts = np.unique(feature, return_counts=True)
            if unique.shape[0] <= layered_violin_max_num_bins:
                order = np.argsort(unique)
                thesebins = np.cumsum(counts[order])
                thesebins = np.insert(thesebins, 0, 0)
            else:
                thesebins = bins
            nbins = thesebins.shape[0] - 1
            # order the feature data so we can apply percentiling
            order = np.argsort(feature)
            # x axis is located at y0 = pos, with pos being there for offset
            # y0 = np.ones(num_x_points) * pos
            # calculate kdes:
            ys = np.zeros((nbins, num_x_points))
            for i in range(nbins):
                # get shap values in this bin:
                shaps = shap_values[order[thesebins[i] : thesebins[i + 1]], ind]
                # if there's only one element, then we can't estimate a KDE for this bin, so warn and skip it
                if shaps.shape[0] == 1:
                    warnings.warn(
                        f"Not enough data in bin #{i} for feature {feature_names[ind]}, so it'll be ignored."
                        " Try increasing the number of records to plot."
                    )
                    # to ignore it, just set it to the previous y-values (so the area between them will be zero). Note ys is already 0, so there's
                    # nothing to do if i == 0
                    if i > 0:
                        ys[i, :] = ys[i - 1, :]
                    continue
                # save kde of them: note that we add a tiny bit of gaussian noise to avoid singular matrix errors
                if rng is None:
                    tmp_z = np.random.normal(loc=0, scale=0.001, size=shaps.shape[0])
                else:
                    tmp_z = rng.normal(loc=0, scale=0.001, size=shaps.shape[0])
                ys[i, :] = gaussian_kde(shaps + tmp_z)(x_points)
                # scale it up so that the 'size' of each y represents the size of the bin. For continuous data this will
                # do nothing, but when we've gone with the unique option, this will matter - e.g. if 99% are male and 1%
                # female, we want the 1% to appear a lot smaller.
                size = thesebins[i + 1] - thesebins[i]
                bin_size_if_even = features.shape[0] / nbins
                relative_bin_size = size / bin_size_if_even
                ys[i, :] *= relative_bin_size
            # now plot 'em. We don't plot the individual strips, as this can leave whitespace between them.
            # instead, we plot the full kde, then remove outer strip and plot over it, etc., to ensure no
            # whitespace
            ys = np.cumsum(ys, axis=0)
            width = 0.8
            scale = ys.max() * 2 / width  # 2 is here as we plot both sides of x axis
            for i in range(nbins - 1, -1, -1):
                y = ys[i, :] / scale
                c = (
                    plt.get_cmap(color)(i / (nbins - 1)) if color in plt.colormaps else color
                )  # if color is a cmap, use it, otherwise use a color
                plt.fill_between(x_points, pos - y, pos + y, facecolor=c, edgecolor="face")
        plt.xlim(shap_min, shap_max)

    elif not multi_class and plot_type == "bar":
        feature_inds = feature_order[:max_display]
        y_pos = np.arange(len(feature_inds))
        global_shap_values = np.abs(shap_values).mean(0)
        plt.barh(y_pos, global_shap_values[feature_inds], 0.7, align="center", color=color)
        plt.yticks(y_pos, fontsize=13)
        plt.gca().set_yticklabels([feature_names[i] for i in feature_inds])

    elif multi_class and plot_type == "bar":
        if class_names is None:
            class_names = ["Class " + str(i) for i in range(len(shap_values))]
        feature_inds = feature_order[:max_display]
        y_pos = np.arange(len(feature_inds))
        left_pos = np.zeros(len(feature_inds))

        if class_inds is None:
            class_inds = np.argsort([-np.abs(shap_values[i]).mean() for i in range(len(shap_values))])
        elif class_inds == "original":
            class_inds = range(len(shap_values))

        if show_values_in_legend:
            # Get the smallest decimal place of the first significant digit
            # to print on the legend. The legend will print ('n_decimal'+1)
            # decimal places.
            # Set to 1 if the smallest number is bigger than 1.
            smallest_shap = np.min(np.abs(shap_values).mean((1, 2)))
            if smallest_shap > 1:
                n_decimals = 1
            else:
                n_decimals = int(-np.floor(np.log10(smallest_shap)))

        for i, ind in enumerate(class_inds):
            global_shap_values = np.abs(shap_values[ind]).mean(0)
            if show_values_in_legend:
                label = f"{class_names[ind]} ({np.round(np.mean(global_shap_values), (n_decimals + 1))})"
            else:
                label = class_names[ind]
            plt.barh(
                y_pos, global_shap_values[feature_inds], 0.7, left=left_pos, align="center", color=color(i), label=label
            )
            left_pos += global_shap_values[feature_inds]
        plt.yticks(y_pos, fontsize=13)
        plt.gca().set_yticklabels([feature_names[i] for i in feature_inds])
        plt.legend(frameon=False, fontsize=12)

    # draw the color bar
    if (
        color_bar
        and features is not None
        and plot_type != "bar"
        and (plot_type != "layered_violin" or color in plt.colormaps)
    ):
        import matplotlib.cm as cm

        m = cm.ScalarMappable(cmap=cmap if plot_type != "layered_violin" else plt.get_cmap(color))
        m.set_array([0, 1])
        cb = plt.colorbar(m, ax=plt.gca(), ticks=[0, 1], aspect=80)
        cb.set_ticklabels([labels["FEATURE_VALUE_LOW"], labels["FEATURE_VALUE_HIGH"]])
        cb.set_label(color_bar_label, size=12, labelpad=0)
        cb.ax.tick_params(labelsize=11, length=0)
        cb.set_alpha(1)
        cb.outline.set_visible(False)  # type: ignore
    #         bbox = cb.ax.get_window_extent().transformed(plt.gcf().dpi_scale_trans.inverted())
    #         cb.ax.set_aspect((bbox.height - 0.9) * 20)
    # cb.draw_all()

    plt.gca().xaxis.set_ticks_position("bottom")
    plt.gca().yaxis.set_ticks_position("none")
    plt.gca().spines["right"].set_visible(False)
    plt.gca().spines["top"].set_visible(False)
    plt.gca().spines["left"].set_visible(False)
    plt.gca().tick_params(color=axis_color, labelcolor=axis_color)
    plt.yticks(range(len(feature_order)), [feature_names[i] for i in feature_order], fontsize=13)
    if plot_type != "bar":
        plt.gca().tick_params("y", length=20, width=0.5, which="major")
    plt.gca().tick_params("x", labelsize=11)
    plt.ylim(-1, len(feature_order))
    if plot_type == "bar":
        plt.xlabel(labels["GLOBAL_VALUE"], fontsize=13)
    else:
        plt.xlabel(labels["VALUE"], fontsize=13)
    plt.tight_layout()
    if show:
        plt.show()
