import ctypes
import math
from typing import Any

import numpy as np
from osgeo import gdal
from numba import njit, prange

from ..common import constants
from ...urban_sprawl.common.common import Common


@njit(parallel=True, fastmath=True) # type: ignore
def calculate_single_cell(
        matrix: 'np.ndarray[Any, Any]',
        rows: int,
        columns: int,
        center_x: int,
        center_y: int,
        pixel_size: float,
        wcc: float,
        offset: int,
        constant: int,
        max_distance: float = 2000.0,
) -> float:
    """Compute the weighted mean dispersion term around a single cell.

    Scans the square window of ``offset`` cells around
    ``(center_x, center_y)``, clipped to the matrix bounds. For every cell
    equal to ``constant`` whose Euclidean distance from the centre is at
    most ``max_distance``, it adds ``sqrt(2 * d + 1) - 1`` to a running sum.

    Args:
        matrix: 2-D matrix indexed as ``matrix[x, y]``.
        rows: Upper bound (exclusive) for ``x`` indices into ``matrix``.
        columns: Upper bound (exclusive) for ``y`` indices into ``matrix``.
        center_x: ``x`` index of the centre cell.
        center_y: ``y`` index of the centre cell.
        pixel_size: Cell edge length; distance = cell distance * pixel_size
            (presumably metres -- confirm against ``Common.get_pixel_size``).
            Pass the real (possibly fractional) size; truncating it to int
            distorts every distance.
        wcc: Multiplicative weight applied to the mean.
        offset: Half-width of the search window, in cells.
        constant: Cell value that marks a cell to be counted.
        max_distance: Inclusive distance cut-off, in the same unit as
            ``pixel_size``. Defaults to the previously hard-coded 2000.

    Returns:
        ``(sum / count) * wcc`` over the matching cells, or
        ``constants.NO_DATA_VALUE`` when no cell in the window matches.
    """
    count = 0
    distance_sum = 0.0

    # Clip the search window to the matrix bounds.
    x_from = max(0, center_x - offset)
    x_to = min(rows, center_x + offset + 1)
    y_from = max(0, center_y - offset)
    y_to = min(columns, center_y + offset + 1)

    # Only the outer prange runs in parallel; Numba serialises nested
    # prange loops. count/distance_sum are treated as reductions.
    for x in prange(x_from, x_to):
        dx_sq = (center_x - x) ** 2

        for y in prange(y_from, y_to):
            if matrix[x, y] == constant:

                dy_sq = (center_y - y) ** 2

                distance = np.sqrt(dx_sq + dy_sq) * pixel_size

                # The square window reaches further along its diagonals;
                # this keeps only cells inside the circular horizon.
                if distance <= max_distance:
                    count += 1
                    distance_sum += np.sqrt((distance * 2) + 1) - 1

    if count > 0:
        return (distance_sum / count) * wcc
    return constants.NO_DATA_VALUE


class SiCalculator:
    """Compute per-cell dispersion values for the built-up cells of a clip.

    Neighbourhoods are searched in the full raster matrix. Results are
    written into a matrix shaped like the clipped matrix.
    """

    def __init__(
        self,
        raster: gdal.Dataset,
        clipped_raster_path: 'np.ndarray[Any, Any]',
        radius: int,
    ):
        """Prepare the calculator.

        Args:
            raster: Full raster; its matrix is searched for neighbours and
                its pixel size scales the distances.
            clipped_raster_path: Despite the name, this is a matrix (it is
                indexed as ``[x, y]`` in ``calculate``), not a file path.
            radius: Search radius, in the same unit as the pixel size;
                converted to a window half-width in cells.
        """
        self._matrix = Common.get_matrix(raster)
        self._clipped_matrix = clipped_raster_path

        self._radius = radius
        self._no_data_value = constants.NO_DATA_VALUE
        self._build_up_value = constants.BUILD_UP_VALUE

        self._pixel_size = Common.get_pixel_size(raster)
        self._wcc = self._calculate_wcc(self._pixel_size)

    @staticmethod
    def _calculate_wcc(pixel_size: float) -> float:
        """Return the pixel-size-dependent weight applied to each cell mean."""
        return math.sqrt(0.97428 * pixel_size + 1.046) - 0.996249

    def calculate(
        self, x_index: int, y_index: int
    ) -> 'np.ndarray[Any, Any]':
        """Compute the value of every built-up cell of the clipped matrix.

        Args:
            x_index: Offset added to clipped-matrix ``x`` indices to address
                the same cell in the full matrix.
            y_index: Offset added to clipped-matrix ``y`` indices to address
                the same cell in the full matrix.

        Returns:
            A float matrix shaped like the clipped matrix. Built-up cells
            hold their computed value; every other cell holds the no-data
            value.
        """
        shape = Common.get_shape(self._clipped_matrix)
        matrix_shape = Common.get_shape(self._matrix)

        result_matrix = np.full(
            shape=(shape.rows, shape.columns),
            fill_value=self._no_data_value,
            dtype=float,
        )

        # Window half-width in cells. int() guarantees a Python int for the
        # jitted kernel even if round() yields a NumPy float scalar.
        offset = int(round(self._radius / self._pixel_size))
        # Pass the real pixel size: int() truncated fractional sizes
        # (e.g. 0.5 -> 0), which made distances inconsistent with the
        # offset and wcc computed from the float size.
        pixel_size = float(self._pixel_size)

        # Visit only built-up cells, in row-major order like the original
        # nested loops. numba.prange outside a jitted function is just
        # range, so plain Python iteration is used here.
        built_up = (
            self._clipped_matrix[:shape.rows, :shape.columns]
            == self._build_up_value
        )
        xs, ys = np.nonzero(built_up)
        for x, y in zip(xs.tolist(), ys.tolist()):
            result_matrix[x, y] = calculate_single_cell(
                self._matrix,
                matrix_shape.rows,
                matrix_shape.columns,
                x + x_index,
                y + y_index,
                pixel_size,
                self._wcc,
                offset,
                self._build_up_value,
            )

        return result_matrix
