diff --git a/catterplot/core.py b/catterplot/core.py index 498ed41..ca3ba87 100644 --- a/catterplot/core.py +++ b/catterplot/core.py @@ -80,7 +80,7 @@ def catter(x, y, s=40, c=None, cat='random', alpha=1, ax=None, cmap=None, cats = np.random.randint(n_cats(), size=len(x)) else: try: - cats = np.ones(len(x)) * cat + cats = [cat] * len(x) except TypeError as e: raise TypeError('`cat` argument needs to be "random", a scalar, or match the input.', e)