import math

from qgis.core import (
    QgsPointXY,
    QgsMultiBandColorRenderer,
    QgsPalettedRasterRenderer,
    QgsSingleBandGrayRenderer,
    QgsSingleBandPseudoColorRenderer
)

from .vertex import Vertex, VertexList
from .utils import make_numeric


class MeshFactory():
    """Collect vertices/faces from QGIS layers and write them as Wavefront OBJ.

    State (reset by ``clear``):

    * ``vertices`` -- a ``VertexList``; ``add_vertex`` returns the id that is
      stored in ``faces`` and written verbatim into OBJ ``f`` records
      (presumably 1-based, as OBJ requires -- confirm in ``vertex.py``).
    * ``faces`` -- list of faces, each a list of vertex ids.
    * ``faces_rgba`` -- initialised to ``False``; not read in this class.
    * ``x_min``, ``y_min``, ``x_scale``, ``y_scale`` -- transform metadata
      emitted in the OBJ header by ``generate_metadata_lines``.
    * ``decimals_xy``, ``decimals_z`` -- default rounding used by
      ``write_obj``; 0 means "round to an integer".
    """

    def __init__(self):
        self.clear()

    def __str__(self):
        return f'MeshFactory[v: {len(self.vertices)}, f: {len(self.faces)}]'

    def clear(self):
        """Reset the mesh and all transform/rounding settings to defaults."""
        self.vertices = VertexList()
        self.faces = []
        self.faces_rgba = False
        self.x_min = self.y_min = 0
        self.x_scale = self.y_scale = 1
        self.decimals_xy = 0
        self.decimals_z = 3

    def generate_metadata_lines(self):
        """Return the header lines (without the ``#`` prefix) for OBJ output."""
        return [
            'This output is generated with the 3D IO plugin for QGIS',
            f'x_min: {self.x_min}',
            f'y_min: {self.y_min}',
            f'x_scale: {self.x_scale}',
            f'y_scale: {self.y_scale}',
        ]

    # ------------------------------------------------------------------ #
    # Writing
    # ------------------------------------------------------------------ #

    @staticmethod
    def _round_digits(decimals):
        """Map a decimals setting to the ``ndigits`` argument of ``round``.

        0 becomes ``None`` so that ``round`` returns an int (written without a
        trailing ``.0``); any other value is passed through unchanged.
        """
        return None if decimals == 0 else decimals

    def _write_header(self, out_file):
        """Write the commented metadata header followed by a blank line."""
        for line in self.generate_metadata_lines():
            out_file.write(f'# {line}\n')
        out_file.write('\n')

    def _write_faces(self, out_file):
        """Write one ``f`` record per face, preceded by a blank line, if any."""
        if self.faces:
            out_file.write('\n')
            for face in self.faces:
                face_ids = ' '.join(str(v) for v in face)
                out_file.write(f'f {face_ids}\n')

    def write_obj(self, fn, decimals_xy=None, decimals_z=None, max_rgb=255):
        """Write vertices (with optional per-vertex RGB) and faces to ``fn``.

        Args:
            fn: output path; the file is overwritten.
            decimals_xy: if given, stored in ``self.decimals_xy`` before
                writing; x/y are rounded to ``self.decimals_xy`` digits.
            decimals_z: if given, stored in ``self.decimals_z`` before writing;
                z is rounded to ``self.decimals_z`` digits.
            max_rgb: 1 divides the stored colour values by 255 and rounds them
                to 3 decimals; any other value writes them unchanged.
        """
        if decimals_xy is not None:
            self.decimals_xy = decimals_xy
        if decimals_z is not None:
            self.decimals_z = decimals_z
        dec_xy = self._round_digits(self.decimals_xy)
        dec_z = self._round_digits(self.decimals_z)

        with open(fn, 'w') as out_file:
            self._write_header(out_file)

            for v in self.vertices:
                obj_line = f'v {round(v.x, dec_xy)} {round(v.y, dec_xy)} {round(v.z, dec_z)}'
                if v.has_rgb():
                    rgb = (v.r, v.g, v.b)
                    if max_rgb == 1:
                        rgb = tuple(round(c / 255, 3) for c in rgb)
                    obj_line = f'{obj_line} {rgb[0]} {rgb[1]} {rgb[2]}'
                out_file.write(f'{obj_line}\n')

            self._write_faces(out_file)

    def write_obj2(self, fn, scale_xy=1, scale_z=1, move_to_origin=False, decimals_xy=None, decimals_z=None):
        """Write a shifted and scaled copy of the mesh to ``fn`` as OBJ.

        Each vertex is written as ``((x - dx) * scale_xy, (y - dy) * scale_xy,
        (z - dz) * scale_z)``. When ``move_to_origin`` is true, ``(dx, dy, dz)``
        are the first three entries of ``self.vertices.get_bounds()``
        (presumably the minimum corner -- confirm in ``vertex.py``); otherwise
        they are 0. ``x_min``/``y_min`` and both xy scales are updated so the
        header records the transform. Vertex colours are not written. Nothing
        is written (no file is created) when the mesh has no vertices.

        Args:
            decimals_xy: rounding for written x/y; ``None`` leaves them
                unrounded, 0 rounds to integers.
            decimals_z: same as ``decimals_xy`` but for z.
        """
        if len(self.vertices) == 0:
            return

        if move_to_origin:
            bounds = self.vertices.get_bounds()
            dx, dy, dz = bounds[0], bounds[1], bounds[2]
        else:
            dx = dy = dz = 0

        self.x_min = dx
        self.y_min = dy
        self.x_scale = self.y_scale = scale_xy

        def fmt(value, decimals):
            # None means "write the value as computed".
            if decimals is None:
                return value
            return round(value, self._round_digits(decimals))

        with open(fn, 'w') as out_file:
            self._write_header(out_file)

            for v in self.vertices:
                x = fmt((v.x - dx) * scale_xy, decimals_xy)
                y = fmt((v.y - dy) * scale_xy, decimals_xy)
                z = fmt((v.z - dz) * scale_z, decimals_z)
                out_file.write(f'v {x} {y} {z}\n')

            # Shifting/scaling does not renumber vertices, so faces stay valid.
            self._write_faces(out_file)

    # ------------------------------------------------------------------ #
    # Raster import
    # ------------------------------------------------------------------ #

    @staticmethod
    def _normalize(value, v_min, v_max):
        """Map ``value`` linearly from [v_min, v_max] to [0, 1].

        Returns 0.0 when the range is empty (constant band) instead of
        dividing by zero.
        """
        span = v_max - v_min
        if span == 0:
            return 0.0
        return (value - v_min) / span

    @staticmethod
    def _enhance(enhancement, value):
        """Apply a QGIS contrast enhancement to a raw band value.

        A renderer may have no enhancement configured (``None``). In that case
        the raw value is truncated to int and clamped to 0..255.
        """
        if enhancement is None:
            return max(0, min(255, int(value)))
        return enhancement.enhanceContrast(value)

    def _make_color_sampler(self, layer, provider):
        """Build a per-pixel colour sampler for the layer's current renderer.

        Returns a callable ``sample(point) -> (keep, rgb)``, or ``None`` when
        the renderer type is not supported. ``keep`` is False only for a
        4-band multi-band-colour layer whose 4th (alpha) band is invalid or
        not > 0; such pixels get no vertex. ``rgb`` is an ``(r, g, b)`` tuple,
        or ``None`` when no colour could be determined.
        """
        renderer = layer.renderer()
        enhance = self._enhance

        if isinstance(renderer, QgsMultiBandColorRenderer):
            bands = (renderer.redBand(), renderer.greenBand(), renderer.blueBand())
            enhancements = (
                renderer.redContrastEnhancement(),
                renderer.greenContrastEnhancement(),
                renderer.blueContrastEnhancement(),
            )
            # A 4th band is treated as alpha.
            band_a = 4 if layer.bandCount() == 4 else None

            def sample_multiband(point):
                if band_a is not None:
                    val_a, valid = provider.sample(point, band_a)
                    if not (valid and val_a > 0):
                        return False, None
                rgb = []
                for band, enhancement in zip(bands, enhancements):
                    val, valid = provider.sample(point, band)
                    if not valid:
                        return True, None
                    rgb.append(enhance(enhancement, val))
                return True, tuple(rgb)

            return sample_multiband

        if isinstance(renderer, QgsPalettedRasterRenderer):
            band = renderer.band()
            palette = {}
            for item in renderer.classes():
                color = item.color.getRgb()
                if len(color) == 4 and color[3] == 0:
                    continue  # fully transparent class: leave it uncoloured
                palette[item.value] = (color[0], color[1], color[2])

            def sample_paletted(point):
                val, valid = provider.sample(point, band)
                if not valid:
                    return True, None
                return True, palette.get(val)

            return sample_paletted

        if isinstance(renderer, QgsSingleBandGrayRenderer):
            band = renderer.grayBand()
            enhancement = renderer.contrastEnhancement()
            # A truthy gradient is treated as WhiteToBlack (inverted grey).
            invert = bool(renderer.gradient())

            def sample_gray(point):
                val, valid = provider.sample(point, band)
                if not valid:
                    return True, None
                grey = enhance(enhancement, val)
                if invert:
                    grey = 255 - grey
                return True, (grey, grey, grey)

            return sample_gray

        if isinstance(renderer, QgsSingleBandPseudoColorRenderer):
            band = renderer.band()
            shader = renderer.shader()

            def sample_pseudocolor(point):
                val, valid = provider.sample(point, band)
                if not valid:
                    return True, None
                ok, red, green, blue, _alpha = shader.shade(val)
                return True, ((red, green, blue) if ok else None)

            return sample_pseudocolor

        return None

    @staticmethod
    def _build_faces(vertex_ids, face_type):
        """Build faces for a grid of vertex ids (``None`` marks a hole).

        ``vertex_ids[iy][ix]`` is the id of the vertex at row ``iy``, column
        ``ix``. For each cell, the "upper" corners come from row ``iy + 1`` and
        the "lower" corners from row ``iy``.

        face_type: 0 (or less) = none; 1 = two triangles per complete cell;
        2 = one quad per complete cell; 3 = quads for complete cells.
        Types 1 and 3 also emit one triangle for a cell with exactly three
        present corners.
        """
        faces = []
        if face_type <= 0:
            return faces
        for lower, upper in zip(vertex_ids, vertex_ids[1:]):
            for ix in range(len(lower) - 1):
                ll_id, lr_id = lower[ix], lower[ix + 1]
                ul_id, ur_id = upper[ix], upper[ix + 1]
                present = [i for i in (ul_id, ur_id, lr_id, ll_id) if i is not None]
                if len(present) == 4:
                    if face_type == 1:
                        faces.append([ul_id, ur_id, ll_id])
                        faces.append([ur_id, lr_id, ll_id])
                    else:
                        faces.append([ul_id, ur_id, lr_id, ll_id])
                elif len(present) == 3 and face_type in (1, 3):
                    faces.append(present)
        return faces

    def from_raster(self, layer, normalize_xy=True, normalize_z=True, band_z=1, parse_rgb=True, face_type=1):
        """Replace the mesh with one vertex per valid pixel of a raster layer.

        Every pixel centre is sampled from band ``band_z``. Pixels with an
        invalid sample, or dropped by the colour sampler (transparent alpha),
        get no vertex and leave a hole in the grid.

        Args:
            layer: QGIS raster layer.
            normalize_xy: if true, x/y are the pixel column/row indices. The
                extent minimum is stored in ``x_min``/``y_min`` and the inverse
                pixel size in ``x_scale``/``y_scale``. Otherwise x/y are the map
                coordinates of the pixel centre.
            normalize_z: if true, z is rescaled to 0..1 with the band's min/max
                statistics (0.0 for a constant band).
            band_z: band number used for z (QGIS numbers bands from 1).
            parse_rgb: if true, every vertex gets a colour derived from the
                layer's renderer (multi-band colour, paletted, single-band grey
                or pseudo-colour). The colour is black when none is available.
            face_type: 0 = none, 1 = triangles, 2 = quads, 3 = quads plus
                triangles for cells with one missing corner.
        """
        self.clear()

        dx = layer.rasterUnitsPerPixelX()
        dy = layer.rasterUnitsPerPixelY()
        extent = layer.extent()
        x_min = extent.xMinimum()
        y_min = extent.yMinimum()

        provider = layer.dataProvider()
        band_z_stats = provider.bandStatistics(band_z)
        z_min = band_z_stats.minimumValue
        z_max = band_z_stats.maximumValue

        # Precision and transform metadata are identical for every pixel.
        self.decimals_z = 5
        if normalize_xy:
            self.decimals_xy = 0
            self.x_min = x_min
            self.y_min = y_min
            self.x_scale = 1 / dx
            self.y_scale = 1 / dy
        else:
            self.decimals_xy = 3

        sample_color = self._make_color_sampler(layer, provider) if parse_rgb else None

        vertex_ids = []
        for iy in range(layer.height()):
            # Row iy is sampled at y_min + (iy + 0.5) * dy, so y grows with iy.
            py = y_min + (0.5 + iy) * dy
            vertex_id_row = []
            for ix in range(layer.width()):
                px = x_min + (0.5 + ix) * dx
                point = QgsPointXY(px, py)

                z, keep = provider.sample(point, band_z)
                rgb = None
                if keep and sample_color is not None:
                    keep, rgb = sample_color(point)
                if not keep:
                    vertex_id_row.append(None)
                    continue

                if normalize_z:
                    z = self._normalize(z, z_min, z_max)
                if normalize_xy:
                    vertex = Vertex(ix, iy, z)
                else:
                    vertex = Vertex(px, py, z)
                if parse_rgb:
                    # Black when no colour could be determined for the pixel.
                    vertex.set_rgb(*(rgb if rgb is not None else (0, 0, 0)))
                vertex_id_row.append(self.vertices.add_vertex(vertex))
            vertex_ids.append(vertex_id_row)

        self.vertex_ids = vertex_ids
        self.faces.extend(self._build_faces(vertex_ids, face_type))

    # ------------------------------------------------------------------ #
    # Vector import
    # ------------------------------------------------------------------ #

    def from_point_layer(self, layer, z_field):
        """Replace the mesh with one vertex per point feature of ``layer``.

        z comes from the point geometry when ``z_field`` is None, otherwise
        from that attribute. It is passed through ``make_numeric``, and
        features whose z ends up None are skipped.

        Returns:
            dict with the counts ``num_points_without_z`` and
            ``num_points_exported``.
        """
        self.clear()

        result = {'num_points_without_z': 0, 'num_points_exported': 0}

        for feat in layer.getFeatures():
            pnt = feat.geometry().get()
            z = pnt.z() if z_field is None else feat[z_field]
            z = make_numeric(z)

            if z is None:
                result['num_points_without_z'] += 1
                continue

            self.vertices.add_vertex(Vertex(pnt.x(), pnt.y(), z))
            result['num_points_exported'] += 1

        return result
