import numpy as np

from .._serializable import Serializable


class Masker(Serializable):
    """This is the superclass of all maskers.

    ``_standardize_mask`` reads a ``shape`` attribute that this base class
    does not define, so concrete subclasses must provide it (either as a
    value or as a callable that accepts the model inputs).
    """

    def __call__(self, mask, *args):
        """Maskers are callable objects that accept the same inputs as the model plus a binary mask.

        The base implementation does nothing and returns ``None``; concrete
        maskers override it.
        """

    def _standardize_mask(self, mask, *args):
        """This allows users to pass True/False as short hand masks.

        Parameters
        ----------
        mask : bool, numpy.bool_, or any other mask value
            A scalar boolean is treated as shorthand for "all True" / "all
            False" and expanded into an explicit boolean array. Any other
            value is returned unchanged.
        *args
            The model inputs; forwarded to ``self.shape`` when it is callable.

        Returns
        -------
        The original ``mask`` if it is not a scalar boolean; otherwise a 1-D
        boolean numpy array of length ``shape[1]`` filled with ``mask``.
        """
        # Only scalar booleans are shorthand. ``np.bool_`` is accepted too:
        # numpy reductions (e.g. ``arr.all()``) return it instead of ``bool``,
        # and an identity test like ``mask is True`` is False for ``np.True_``.
        if not isinstance(mask, (bool, np.bool_)):
            return mask

        shape = self.shape(*args) if callable(self.shape) else self.shape

        # The second entry of ``shape`` gives the length of the expanded mask.
        return np.full(shape[1], bool(mask), dtype=bool)
