import numpy as np
import pandas as pd
from pyqtgraph import mkPen

from openlog.datamodel.assay.generic_assay import GenericAssay
from openlog.gui.assay_visualization.extended.bar_symbology import BarSymbology
from openlog.gui.pyqtgraph.SvgBarGraphItem import SvgBarGraphItem


class CategoricalBarGraphItem(SvgBarGraphItem):
    """Bar graph item displaying an interval assay column as categorical bars.

    Each bar covers one interval of the assay column (starting at the interval
    start, with a height equal to the interval length). Bar pattern, scale and
    color are looked up through a BarSymbology from the assay columns it names.
    """

    # Constant bar width (x extent) applied to every bar.
    WIDTH = 10

    def __init__(self, assay: GenericAssay, column: str, symbology: BarSymbology):
        """
        Override SvgBarGraphItem to display assay data as categorical series

        Args:
            assay: GenericAssay
            column: assay column name (should have an extended data extent and categorical series type)
            symbology: BarSymbology for category display
        """
        self._assay = assay
        self._column = column
        self.symbology_config = symbology

        # For now display all values available. Only the interval array is needed
        # to lay out the bars; the column values themselves are not used here.
        (x, _) = assay.get_all_values(column, remove_none=True)
        x = x.astype(float)

        if assay.get_dimension(x.shape) == 2:
            # Each row of x is an interval: column 0 is the bar start and the
            # difference between columns 1 and 0 is the bar height.
            y0_val = x[:, 0]
            height_val = x[:, 1] - x[:, 0]
            bar_count = len(x)
        else:
            # Not interval data: draw no bars. All geometry arrays share the same
            # (zero) length so the item never mixes arrays of different sizes.
            y0_val = np.empty(0)
            height_val = np.empty(0)
            bar_count = 0

        x0_val = np.full(bar_count, 0)
        width_val = np.full(bar_count, self.WIDTH)

        super().__init__(height=height_val, width=width_val, x0=x0_val, y0=y0_val)
        self.set_symbology(self.symbology_config)

    def _column_values(self, column: str, cache: dict):
        """Return the values of an assay column, reading it at most once per cache.

        Args:
            column: assay column name to read
            cache: dict of already-read values keyed by column name; updated in place

        Returns:
            The value array returned by ``get_all_values(column, remove_none=True)``.
        """
        if column not in cache:
            (_, values) = self._assay.get_all_values(column, remove_none=True)
            cache[column] = values
        return cache[column]

    def set_symbology(self, symbology: BarSymbology):
        """
        Define used bar symbology and refresh per-bar pattern, scale and color.

        Args:
            symbology: (BarSymbology) provides the pattern, scale and color column
                names and the category -> pattern file / scale / color lookups
        """
        self.symbology_config = symbology
        config = self.symbology_config

        # Values read during this call, keyed by column name, so a column used by
        # several symbology attributes is fetched from the assay only once.
        # NOTE(review): each column is read with remove_none=True independently;
        # if that drops different rows per column, styles could be misaligned with
        # the bars built in __init__ — confirm get_all_values semantics.
        cache = {}

        self.svg_files = [
            config.get_pattern_file(str(cat))
            for cat in self._column_values(config.pattern_col, cache)
        ]

        self.svg_sizes = [
            config.get_scale(str(cat))
            for cat in self._column_values(config.scale_col, cache)
        ]

        brushes = [
            config.get_color(str(cat))
            for cat in self._column_values(config.color_col, cache)
        ]
        self.setOpts(brushes=brushes)
        self.update_brush_texture()

    def setPen(self, *args, **kargs):
        """
        Override the pen setter so the pen option is always reset to None.

        Any arguments (a :class:`QtGui.QPen` or arguments accepted by
        :func:`pyqtgraph.mkPen() <pyqtgraph.mkPen>`) are accepted for API
        compatibility but ignored.
        """
        self.setOpts(pen=None)
