import os
from typing import Dict, Optional

from qgis.core import (
    QgsCoordinateTransformContext,
    QgsFeature,
    QgsGeometry,
    QgsPoint,
    QgsPointXY,
    QgsVectorFileWriter,
    QgsVectorLayer,
)

from .landxml_reader import LandXMLReader
from .mesh_elements import MeshVertex


class VectorLayerWriter:
    """Convert a parsed LandXML surface into QGIS memory layers and save them as GeoPackages.

    Two in-memory layers are built on construction:

    * ``point_layer`` -- one PointZ feature per entry of ``landxml.all_points``, with
      attributes ``id`` (enumeration index) and ``z`` (the vertex ``z`` as a float).
    * ``polygon_layer`` -- one 2D polygon per entry of ``landxml.all_faces``, built from
      the x/y of the vertices referenced by ``face.points_ids``, with attribute ``id``
      (enumeration index).
    """

    def __init__(self, landxml: LandXMLReader, crs: str = "epsg:3857") -> None:
        """Build and populate both memory layers.

        Args:
            landxml: Reader providing ``all_points`` and ``all_faces``.
            crs: CRS identifier written into the memory-layer URIs (default ``"epsg:3857"``).
                Coordinates are copied verbatim and never reprojected, so this should
                describe the CRS the LandXML coordinates are already in.

        Raises:
            IndexError: If a face references a point id not present in ``landxml.all_points``.
        """
        self.landxml = landxml

        # Index vertices by id once so each face-vertex lookup is O(1) rather than a
        # scan over all points. setdefault keeps the first vertex for a duplicated id,
        # i.e. the same result a first-match linear search would give.
        self._points_by_id: Dict[int, MeshVertex] = {}
        for vertex in landxml.all_points:
            self._points_by_id.setdefault(vertex.id, vertex)

        self.point_layer = QgsVectorLayer(
            f"pointz?crs={crs}&field=id:integer&field=z:double", "point_layer", "memory"
        )
        self.polygon_layer = QgsVectorLayer(f"polygon?crs={crs}&field=id:integer", "polygon_layer", "memory")

        self._populate_layers()

    def _populate_layers(self) -> None:
        """Fill ``point_layer`` and ``polygon_layer`` from the LandXML vertices and faces.

        Raises:
            IndexError: If a face references an unknown point id.
        """
        point_features = []
        for i, point in enumerate(self.landxml.all_points):
            # point.z may not already be numeric; convert once and reuse the same value
            # for both the geometry and the double ``z`` attribute.
            z = float(point.z)
            feature = QgsFeature(self.point_layer.fields())
            # QGIS 3 has no QgsGeometry.fromPoint(); wrapping a QgsPoint keeps the Z value.
            feature.setGeometry(QgsGeometry(QgsPoint(float(point.x), float(point.y), z)))
            feature.setAttributes([i, z])
            point_features.append(feature)
        self.point_layer.dataProvider().addFeatures(point_features)
        self.point_layer.updateExtents()

        polygon_features = []
        for i, face in enumerate(self.landxml.all_faces):
            ring = []
            for point_id in face.points_ids:
                vertex = self._find_point(int(point_id))
                ring.append(QgsPointXY(float(vertex.x), float(vertex.y)))

            feature = QgsFeature(self.polygon_layer.fields())
            feature.setGeometry(QgsGeometry.fromPolygonXY([ring]))
            feature.setAttributes([i])
            polygon_features.append(feature)
        self.polygon_layer.dataProvider().addFeatures(polygon_features)
        self.polygon_layer.updateExtents()

    def write(self, folder: str, prefix: Optional[str] = None) -> None:
        """Save both layers as GeoPackage files inside *folder*.

        Files are named ``points.gpkg`` / ``polygons.gpkg``, or
        ``<prefix>_points.gpkg`` / ``<prefix>_polygons.gpkg`` when *prefix* is given.

        Args:
            folder: Destination directory.
            prefix: Optional file-name prefix.

        Raises:
            OSError: If QGIS reports an error while writing either file.
        """
        options = QgsVectorFileWriter.SaveVectorOptions()
        options.driverName = "GPKG"

        name_prefix = "" if prefix is None else f"{prefix}_"
        transform_context = QgsCoordinateTransformContext()

        for layer, filename in (
            (self.point_layer, f"{name_prefix}points.gpkg"),
            (self.polygon_layer, f"{name_prefix}polygons.gpkg"),
        ):
            path = os.path.join(folder, filename)
            # writeAsVectorFormatV3 reports failure through its return tuple
            # (error code, error message, ...) rather than by raising.
            result = QgsVectorFileWriter.writeAsVectorFormatV3(layer, path, transform_context, options)
            error, message = result[0], result[1]
            if error != QgsVectorFileWriter.NoError:
                raise OSError(f"Failed to write {path}: {message}")

    def _find_point(self, point_id: int) -> MeshVertex:
        """Return the vertex whose ``id`` equals *point_id*.

        Raises:
            IndexError: If no vertex in ``landxml.all_points`` has that id.
        """
        try:
            return self._points_by_id[point_id]
        except KeyError:
            raise IndexError(f"LandXML face references unknown point id {point_id}") from None
