from matplotlib.figure import Figure
from generative.core import Asset, FileAsset, generative_function, GenerativeType
from enum import Enum

import neuralfoil
import numpy as np

from pydantic import Field


class ParsecParams(GenerativeType):
    """PARSEC airfoil parametrisation: http://pubs.sciepub.com/ajme/2/4/1. Defaults to NACA0012."""

    # Coordinates are in units where each generated surface runs from x=0 (leading
    # edge) to x=1 (trailing edge). Angle fields are in degrees; they are converted
    # to radians when the surfaces are built.

    # Leading-edge radius: each surface's first polynomial coefficient is set to
    # +sqrt(2 * radius) (upper) or -sqrt(2 * radius) (lower).
    leading_edge_radius: float = Field(ge=0.001, le=0.999, default=0.0155)
    # Upper-surface crest: position where the surface is forced to have zero slope,
    # and the second derivative y'' imposed at that point.
    upper_crest_x_location: float = Field(ge=0.001, le=0.999, default=0.2966)
    upper_crest_y_location: float = Field(ge=0, le=0.4, default=0.06002)
    upper_crest_curvature: float = Field(ge=-30, le=30, default=-0.4515)
    # Lower-surface crest: same three quantities as for the upper surface.
    lower_crest_x_location: float = Field(ge=0.001, le=0.999, default=0.2966)
    lower_crest_y_location: float = Field(ge=-0.4, le=0, default=-0.06002)
    lower_crest_curvature: float = Field(ge=-30, le=30, default=0.4515)
    # Height of the trailing-edge midpoint; the upper/lower trailing-edge points sit
    # half of `trailing_edge_thickness` above/below it.
    trailing_edge_vertical_offset: float = Field(ge=-0.3, le=0.3, default=0)
    # Vertical gap between the upper and lower trailing-edge points.
    trailing_edge_thickness: float = Field(ge=0.0, le=0.2, default=0.0025)
    # Trailing-edge slopes are imposed as
    #   upper: tan(angle - wedge / 2), lower: tan(angle + wedge / 2).
    trailing_edge_angle_degrees: float = Field(ge=-30, le=30, default=0)
    trailing_edge_wedge_angle_degrees: float = Field(ge=0.001, lt=90, default=12.89)


class FlowConditions(GenerativeType):
    # Freestream conditions for the aerodynamic analysis.

    # Forwarded to neuralfoil as `alpha`. NOTE(review): neuralfoil is believed to
    # expect degrees here — confirm against its documentation.
    angle_of_attack: float = 0.0
    # Forwarded to neuralfoil as `Re`.
    reynolds_number: float = Field(gt=0.0, default=1_000_000)
    # NOTE(review): `neuralfoil_analysis` in this file does not forward this value,
    # so it currently has no effect on the results — confirm whether that is intended.
    mach: float = Field(gt=0.0, default=0.4)


class Config(GenerativeType):
    # Discretisation and model settings for the `airfoil` generative function.

    # Points generated on each surface, including the shared leading-edge point.
    # At least two are required so that each surface has a point beyond the leading
    # edge; smaller values would otherwise fail deep in the geometry code with an
    # opaque IndexError/ValueError instead of a clear validation error here.
    n_points_per_side: int = Field(ge=2, default=100)
    # Forwarded unchanged to neuralfoil as `model_size`.
    model_size: str = "large"


class AnalysisOutputs(GenerativeType):
    # Aerodynamic results populated by `neuralfoil_analysis` from neuralfoil's output.

    # neuralfoil "CL".
    lift_coefficient: float
    # neuralfoil "CD".
    drag_coefficient: float
    # neuralfoil "CM".
    moment_coefficient: float
    # CL / CD, computed locally from the two values above.
    lift_drag_ratio: float
    # neuralfoil "analysis_confidence".
    analysis_confidence: float


class Outputs(GenerativeType):
    # Result returned by the `airfoil` generative function.

    # Rendered airfoil outline (written as an SVG file asset by `plot_airfoil`).
    plot: Asset
    # Aerodynamic coefficients predicted by neuralfoil.
    analysis: AnalysisOutputs


@generative_function
def airfoil(airfoil_params: ParsecParams, conditions: FlowConditions, config: Config) -> Outputs:
    """Build a PARSEC airfoil, validate it, render it and analyse it with neuralfoil.

    Raises:
        ValueError: if the generated upper and lower surfaces overlap.
    """
    coords = parsec_airfoil(airfoil_params, config.n_points_per_side)
    # Reject self-intersecting geometry before spending time on plotting or analysis.
    check_surface_overlap(coords)
    # Keyword arguments are evaluated left to right: plot first, then analysis.
    return Outputs(
        plot=plot_airfoil(coords),
        analysis=neuralfoil_analysis(coords, conditions, model_size=config.model_size),
    )


def plot_airfoil(airfoil_coords):
    """Draw the airfoil outline and save it as a transparent SVG file asset.

    Args:
        airfoil_coords: sequence of ``(x, y)`` points tracing the outline.

    Returns:
        The ``FileAsset`` the SVG was written to.
    """
    # Repeat the first point at the end so the drawn line closes the trailing-edge gap.
    outline = [*airfoil_coords, airfoil_coords[0]]
    xs = [point[0] for point in outline]
    ys = [point[1] for point in outline]

    fig = Figure()
    ax = fig.subplots()
    ax.plot(xs, ys, linewidth=2.5, solid_capstyle="round", color="#ECEEED")
    ax.set_aspect("equal", "box")
    ax.axis("off")

    svg_asset = FileAsset(extension="svg")
    fig.savefig(svg_asset.path, transparent=True)
    return svg_asset


def degrees_to_radians(value: float) -> float:
    """Convert an angle from degrees to radians."""
    scaled = value * np.pi
    return scaled / 180.0


class Surface(Enum):
    """Selects which side of the airfoil a PARSEC surface is built for."""

    UPPER = 0
    LOWER = 1


def _parsec_airfoil_surface_coords(params, surface, n_points):
    """Produce coordinates of one PARSEC surface running from LE (x=0) to TE (x=1).

    The surface is ``y(x) = sum_{j=0..5} a_j * x**(j + 0.5)``. The six coefficients
    ``a_j`` solve a 6x6 linear system, one row per PARSEC constraint:

    0. ``y(1)``    equals this surface's trailing-edge height,
    1. ``y(x_c)``  equals the crest height,
    2. ``y'(1)``   equals ``tan`` of this surface's trailing-edge angle,
    3. ``y'(x_c)`` is zero (the crest is an extremum),
    4. ``y''(x_c)`` equals the crest curvature parameter,
    5. ``a_0``     equals ``+sqrt(2 r_le)`` (upper) or ``-sqrt(2 r_le)`` (lower).

    Angle parameters are given in degrees and converted to radians here.

    Args:
        params: PARSEC parameters (``ParsecParams``-like object).
        surface: ``Surface.UPPER`` or ``Surface.LOWER``.
        n_points: number of points on this surface, including both end points.

    Returns:
        List of ``(x, y)`` pairs ordered from leading edge to trailing edge.

    Raises:
        numpy.linalg.LinAlgError: if the constraint matrix is singular.
    """
    # Half-integer exponents 0.5, 1.5, ..., 5.5 of the PARSEC polynomial terms.
    exponents = np.arange(6, dtype=np.float64) + 0.5

    # Use cosine spacing to group points at the curvy bits of the airfoil
    x_coords = 0.5 * (1 - np.cos(np.linspace(0, np.pi, n_points)))

    # Basis matrix: x_matrix[i, j] = x_coords[i] ** (j + 0.5)
    x_matrix = x_coords[:, np.newaxis] ** exponents

    if surface == Surface.UPPER:
        crest_x_loc = params.upper_crest_x_location
        crest_y_loc = params.upper_crest_y_location
        crest_curv = params.upper_crest_curvature
        te_y_coord = params.trailing_edge_vertical_offset + 0.5 * params.trailing_edge_thickness
        te_angle = degrees_to_radians(
            params.trailing_edge_angle_degrees - 0.5 * params.trailing_edge_wedge_angle_degrees
        )
        le_fac = 1.0
    else:
        crest_x_loc = params.lower_crest_x_location
        crest_y_loc = params.lower_crest_y_location
        crest_curv = params.lower_crest_curvature
        te_y_coord = params.trailing_edge_vertical_offset - 0.5 * params.trailing_edge_thickness
        te_angle = degrees_to_radians(
            params.trailing_edge_angle_degrees + 0.5 * params.trailing_edge_wedge_angle_degrees
        )
        le_fac = -1.0

    # Constraint matrix, one row per condition listed in the docstring. Derivative rows
    # use d/dx x**p = p x**(p-1) and d2/dx2 x**p = p (p-1) x**(p-2). Computing them from
    # `exponents` (rather than a hand-typed table) keeps every coefficient correct, e.g.
    # the x**3.5 curvature term is 3.5 * 2.5 = 8.75.
    c = np.array(
        [
            np.ones(6),
            crest_x_loc**exponents,
            exponents,
            exponents * crest_x_loc ** (exponents - 1.0),
            exponents * (exponents - 1.0) * crest_x_loc ** (exponents - 2.0),
            [1.0, 0.0, 0.0, 0.0, 0.0, 0.0],
        ],
        dtype=np.float64,
    )

    le_term = le_fac * (2 * params.leading_edge_radius) ** 0.5

    b = np.array(
        [te_y_coord, crest_y_loc, np.tan(te_angle), 0.0, crest_curv, le_term], dtype=np.float64
    )

    a = np.linalg.solve(c, b)
    y_coords = x_matrix @ a

    return list(zip(x_coords, y_coords, strict=True))


def parsec_airfoil(params, n_points_per_side):
    """Produce airfoil from PARSEC parameters (http://pubs.sciepub.com/ajme/2/4/1).

    Returns a single list of ``(x, y)`` points ordered anti-clockwise: from the upper
    trailing edge forward to the leading edge, then back along the lower surface to
    the lower trailing edge. The shared leading-edge point appears only once.
    """
    upper, lower = (
        _parsec_airfoil_surface_coords(params, side, n_points_per_side)
        for side in (Surface.UPPER, Surface.LOWER)
    )
    # Upper surface is walked TE -> LE; the lower surface's first point duplicates the LE.
    return upper[::-1] + lower[1:]


def neuralfoil_analysis(airfoil_coords, flow_conditions, model_size):
    """Predict aerodynamic coefficients for an airfoil outline with neuralfoil.

    Args:
        airfoil_coords: sequence of ``(x, y)`` points tracing the outline.
        flow_conditions: object providing ``angle_of_attack`` and ``reynolds_number``.
        model_size: neuralfoil model size name, forwarded unchanged.

    Returns:
        ``AnalysisOutputs`` holding neuralfoil's CL, CD, CM and analysis confidence,
        plus the locally computed CL / CD.
    """
    # NOTE(review): flow_conditions.mach is not passed to neuralfoil, so it has no
    # effect on these results — confirm whether a compressibility correction is intended.
    predictions = neuralfoil.get_aero_from_coordinates(
        coordinates=np.array(airfoil_coords),
        alpha=flow_conditions.angle_of_attack,
        Re=flow_conditions.reynolds_number,
        model_size=model_size,
    )
    cl = predictions.get("CL")
    cd = predictions.get("CD")
    return AnalysisOutputs(
        lift_coefficient=cl,
        drag_coefficient=cd,
        moment_coefficient=predictions.get("CM"),
        lift_drag_ratio=cl / cd,
        analysis_confidence=predictions.get("analysis_confidence"),
    )


def _separate_surfaces(coordinates):
    """Separate upper and lower surfaces, returning tuple of (upper surface, lower surface)"""
    try:
        leading_edge_idx = next(i for i, (x, _) in enumerate(coordinates) if x == 0.0)
    except StopIteration as e:
        raise ValueError("No leading edge found") from e

    s1 = coordinates[leading_edge_idx:]
    s2 = list(reversed(coordinates[: leading_edge_idx + 1]))
    return (s1, s2) if s1[1][1] > s2[1][1] else (s2, s1)


def check_surface_overlap(airfoil_coords):
    """Raise `ValueError` if the upper and lower surfaces overlap at any point. Return `None`.

    Each lower-surface point is compared against the upper surface linearly
    interpolated at the same x; a lower point more than ``1e-6`` above the upper
    surface counts as an overlap.
    """
    tol = 1e-6
    upper_surface, lower_surface = _separate_surfaces(airfoil_coords)
    upper_xy = np.asarray(upper_surface, dtype=np.float64)
    lower_xy = np.asarray(lower_surface, dtype=np.float64)
    # np.interp needs increasing sample x-values; the upper surface runs LE -> TE.
    upper_at_lower_x = np.interp(lower_xy[:, 0], upper_xy[:, 0], upper_xy[:, 1])
    if np.any(lower_xy[:, 1] - tol > upper_at_lower_x):
        raise ValueError("Lower and upper surfaces overlap")