import pyqtgraph as pg

from openlog.datamodel.assay.generic_assay import GenericAssay
from openlog.gui.assay_visualization.categorical_symbology import CategoricalSymbology


class CategoricalScatterPlotItem(pg.ScatterPlotItem):
    """
    Scatter plot item whose point brushes are chosen by category.

    Colors come from a CategoricalSymbology. The item also keeps a list of
    single-brush legend items, one per category present in the displayed
    column, whose brushes follow the current symbology.
    """

    def __init__(
        self,
        assay: GenericAssay,
        column: str,
        symbology: CategoricalSymbology,
        *args,
        **kargs,
    ):
        """
        :param assay: assay providing values through get_all_values
        :param column: name of the assay column displayed by this item
        :param symbology: category -> color mapping used for point and legend brushes
        :param args: forwarded to pg.ScatterPlotItem
        :param kargs: forwarded to pg.ScatterPlotItem
        """
        super().__init__(*args, **kargs)

        self.assay = assay
        self.column = column
        self.symbology = symbology
        (x, cat) = self.assay.get_all_values(self.column)
        self.x = x.astype(float)
        # Categories are compared as strings (legend labels and color lookups).
        self.cat = cat.astype(str)
        # Must be built after self.symbology / self.cat are set: it reads both.
        self.legend_items = self.get_legend_items()

        self.set_symbology(self.symbology)

    def set_symbology(self, symbology: CategoricalSymbology) -> None:
        """
        Set a color by category.

        Point brushes are computed from the values of ``symbology.color_col``
        (fetched with ``remove_none=True``), then legend brushes are refreshed.

        :param symbology: new category -> color mapping. It also becomes
            ``self.symbology`` so later legend generation uses the same mapping.
        """
        # Keep both attributes in sync: get_legend_items reads self.symbology,
        # while symbology_config is kept for existing readers of that attribute.
        self.symbology = symbology
        self.symbology_config = symbology

        # For now display all values available
        _, color_values = self.assay.get_all_values(
            symbology.color_col, remove_none=True
        )
        # NOTE(review): color values are fetched with remove_none=True while
        # __init__ fetches self.column without it; confirm the brush count
        # always matches the number of plotted points.
        brushes = [symbology.get_color(str(cat)) for cat in color_values]

        self.setBrush(brushes)
        self.set_legend_brush(symbology)

    @classmethod
    def generate_single_brush_item(cls, brush: str = "") -> pg.ScatterPlotItem:
        """
        Given a color, return a pg.ScatterPlotItem with single brush.

        :param brush: value passed to ScatterPlotItem.setBrush
        :return: a new, data-less ScatterPlotItem with size 15, used as a
            legend swatch
        """
        res = pg.ScatterPlotItem()
        res.setBrush(brush)
        res.setSize(15)
        return res

    def get_legend_items(self) -> list:
        """
        Return a list of (item, category_name) pairs for legend displaying.

        Only categories present in self.cat are considered. Pairs are sorted
        by category name so the legend order is stable between runs; iterating
        a bare set of strings depends on per-process hash randomization.
        """
        return [
            (self.generate_single_brush_item(brush=self.symbology.get_color(cat)), cat)
            for cat in sorted(set(self.cat))
        ]

    def set_legend_brush(self, symbology: CategoricalSymbology) -> None:
        """
        Update legend item brushes according to the given symbology.

        :param symbology: category -> color mapping applied to every
            (item, category) pair in self.legend_items
        """
        for item, cat in self.legend_items:
            item.setBrush(symbology.get_color(cat))
