// Copyright 2016-2024 Hitachi Energy. All rights reserved.

import SearchParams from "@pg/common/build/models/SearchParams";
import { routes } from "core/app/components/AppRoutes";
import { NavigateFn } from "core/app/components/RouterProvider";
import * as d3 from "d3";
import dayjs from "dayjs";
import { IntlShape } from "react-intl";
import CorrelationService, {
  ICorrelationData,
  ICorrelationLineData
} from "../services/CorrelationService";

/**
 * A single point rendered by {@link ScatterPlot}.
 *
 * Depending on which axis carries numeric data, `valueX` / `valueY` hold
 * either a number or a `Date`.
 */
export interface IScatterPlotPoint {
  /** Asset id of the point; also used for tooltips and navigation. */
  name: string;
  /** "current" for the currently selected asset, otherwise null. */
  class: string;
  /** Risk label used as a CSS class on the dot ("" when unknown). */
  risk: string;
  /** Horizontal coordinate: a number or a Date depending on the X axis type. */
  valueX: number | Date;
  /** Vertical coordinate: a number or a Date depending on the Y axis type. */
  valueY: number | Date;
}

/** Wrapper around the raw data points of one axis, as passed to {@link ScatterPlot}. */
export interface IScatterPlotDataPoints {
  /** Raw per-asset data points for one axis. */
  DataPoints: Array<IDataPoint>;
}

/** One raw data point for an asset, as consumed by {@link ScatterPlot}. */
export interface IDataPoint {
  /** Identifier of the asset the value belongs to. */
  AssetId: string;
  /** Measured value; points with a missing value are not plotted. */
  Value: number;
  /** Timestamp string, parsed with dayjs when a time axis is used. */
  Date: string;
  /** Risk label of the asset; may be missing (treated as ""). */
  Risk: string;
}

/**
 * D3 scatter plot that renders asset data points into a container element,
 * with an optional correlation line (data from CorrelationService), hover
 * tooltips and click-to-navigate to the asset detail page.
 *
 * Each axis is numeric (linear scale) or temporal (time scale) depending on
 * which data sets are passed to the constructor:
 * - only X data -> X = numeric values, Y = dates
 * - only Y data -> X = dates, Y = numeric values
 * - both        -> X and Y = numeric values, joined by asset id
 */
class ScatterPlot {
  private container: HTMLDivElement;
  private labelX: string;
  private labelY: string;

  private readonly minHeight = 300;
  private readonly minWidth = 40;

  private height = 700;
  private width = 200;
  // `left` may be enlarged by drawYAxisAndAdjustMarginLeft to fit tick labels
  // and the rotated Y axis title.
  private margin = { top: 40, right: 80, bottom: 80, left: 80 };

  private data: IScatterPlotPoint[];
  private correlationData: ICorrelationData;
  private xAxisHasNumbers = false;
  private yAxisHasNumbers = false;

  /**
   * Converts the raw data sets into plot points and, when requested, computes
   * the correlation data. Nothing is rendered until {@link draw} is called.
   *
   * @param htmlContainer element that receives the <svg> and tooltip elements
   * @param assetId asset whose point is tagged with the "current" class
   * @param labelX X axis title
   * @param labelY Y axis title
   * @param scatterPlotYData numeric data for the Y axis, or null/undefined
   * @param scatterPlotXData numeric data for the X axis, or null/undefined
   * @param showCorrelation whether to compute (and later draw) the correlation line
   */
  constructor(
    htmlContainer: HTMLDivElement,
    assetId: string,
    labelX: string,
    labelY: string,
    scatterPlotYData: IScatterPlotDataPoints,
    scatterPlotXData: IScatterPlotDataPoints,
    showCorrelation: boolean
  ) {
    this.container = htmlContainer;
    this.labelX = labelX;
    this.labelY = labelY;

    // `!= null` treats undefined like null, so a missing data set is never
    // dereferenced.
    const hasXData = scatterPlotXData != null;
    const hasYData = scatterPlotYData != null;

    if (hasXData && hasYData) {
      this.xAxisHasNumbers = true;
      this.yAxisHasNumbers = true;
      this.data = this.formatValueToValue(
        scatterPlotYData.DataPoints,
        scatterPlotXData.DataPoints,
        assetId
      );
    } else if (hasXData) {
      this.xAxisHasNumbers = true;
      this.data = this.formatDateToValue(scatterPlotXData, assetId);
    } else if (hasYData) {
      this.yAxisHasNumbers = true;
      this.data = this.formatValueToDate(scatterPlotYData, assetId);
    } else {
      this.data = [];
    }

    if (this.data.length > 0 && showCorrelation) {
      this.correlationData = CorrelationService.compute(this.data);
    }
  }

  /**
   * Builds points with dates on X and numeric values on Y.
   * Points without a date or value are dropped.
   */
  private formatValueToDate(
    dataPoints: IScatterPlotDataPoints,
    currentAssetId: string
  ): IScatterPlotPoint[] {
    return dataPoints.DataPoints.filter(
      (d) => d.Date != null && d.Value != null
    ).map<IScatterPlotPoint>((d) => ({
      name: d.AssetId,
      valueX: dayjs(d.Date).toDate(),
      valueY: +d.Value,
      class: d.AssetId === currentAssetId ? "current" : null,
      risk: d.Risk ?? ""
    }));
  }

  /**
   * Builds points with numeric values on X and dates on Y.
   * Points without a date or value are dropped.
   */
  private formatDateToValue(
    dataPoints: IScatterPlotDataPoints,
    currentAssetId: string
  ): IScatterPlotPoint[] {
    return dataPoints.DataPoints.filter(
      (d) => d.Date != null && d.Value != null
    ).map<IScatterPlotPoint>((d) => ({
      name: d.AssetId,
      valueX: +d.Value,
      valueY: dayjs(d.Date).toDate(),
      class: d.AssetId === currentAssetId ? "current" : null,
      risk: d.Risk ?? ""
    }));
  }

  /**
   * Joins Y and X data points by asset id into numeric/numeric points.
   *
   * Y points drive the output order. For each asset, the first X point with
   * the same id is used. Assets without an X point, or with a missing value on
   * either side, are dropped rather than being plotted at 0 / NaN. Risk is
   * taken from the Y point.
   */
  private formatValueToValue(
    pointsY: IDataPoint[],
    pointsX: IDataPoint[],
    currentAssetId: string
  ): IScatterPlotPoint[] {
    // Index X points once (first occurrence wins, like Array.prototype.find)
    // to avoid an O(n * m) lookup per Y point.
    const xByAssetId = new Map<string, IDataPoint>();
    for (const xPoint of pointsX) {
      if (!xByAssetId.has(xPoint.AssetId)) {
        xByAssetId.set(xPoint.AssetId, xPoint);
      }
    }

    const points: IScatterPlotPoint[] = [];
    for (const yPoint of pointsY) {
      const xPoint = xByAssetId.get(yPoint.AssetId);
      if (xPoint?.Value == null || yPoint.Value == null) continue;
      points.push({
        name: yPoint.AssetId,
        valueX: +xPoint.Value,
        valueY: +yPoint.Value,
        class: yPoint.AssetId === currentAssetId ? "current" : null,
        risk: yPoint.Risk ?? ""
      });
    }
    return points;
  }

  /** Formats an axis/tooltip value as a number or a date depending on the axis type. */
  private static valueOrDateFormatter(
    value: number | Date,
    isNumericValue: boolean,
    intl: IntlShape
  ): string {
    return isNumericValue
      ? ScatterPlot.valueFormatter(value as number, intl)
      : ScatterPlot.dateFormatter(value as Date, intl);
  }

  /** Locale-formats a date with react-intl's default date format. */
  private static dateFormatter(date: Date, intl: IntlShape): string {
    return intl.formatDate(date);
  }

  /** Rounds to 8 decimal places (hides float noise) and locale-formats the number. */
  private static valueFormatter(d: number, intl: IntlShape): string {
    return intl.formatNumber(Math.round(d * 100000000.0) / 100000000.0);
  }

  /**
   * Formats the R² value for the correlation tooltip.
   *
   * Missing or non-finite values render as the localized "N/A". A value of 0
   * is legitimate (no linear relationship) and is shown as a number.
   */
  private static rSquaredFormatter(
    rSquaredValue: number,
    intl: IntlShape
  ): string {
    if (rSquaredValue == null || !Number.isFinite(rSquaredValue)) {
      return intl.formatMessage({
        defaultMessage: "N/A",
        id: "family_analytics.chart.r_squared_not_available_label"
      });
    }

    return ScatterPlot.valueFormatter(rSquaredValue, intl);
  }

  /** Escapes text so it can be safely interpolated into an `.html()` string. */
  private static escapeHtml(text: string): string {
    return String(text)
      .replace(/&/g, "&amp;")
      .replace(/</g, "&lt;")
      .replace(/>/g, "&gt;")
      .replace(/"/g, "&quot;")
      .replace(/'/g, "&#39;");
  }

  /** Removes everything this chart (or anything else) rendered into the container. */
  destroy() {
    d3.select(this.container).selectAll("*").remove();
  }

  /**
   * Renders axes, grid lines, data points, the optional correlation line and
   * the tooltip/click behavior into the container.
   *
   * Each call appends a new <svg> and new tooltip elements. Previous output is
   * not removed here; use {@link destroy} to clear the container.
   */
  public draw(intl: IntlShape, navigate: NavigateFn) {
    this.setChartDimensions();

    let y: d3.ScaleTime<number, number> | d3.ScaleLinear<number, number>;
    if (this.yAxisHasNumbers)
      y = d3
        .scaleLinear()
        .domain(
          d3.extent(this.data, (point: IScatterPlotPoint) => point.valueY)
        )
        .nice();
    else
      y = d3
        .scaleTime()
        .domain(
          d3.extent(this.data, (point: IScatterPlotPoint) => point.valueY)
        )
        .nice();

    y.range([this.height - this.margin.bottom, this.margin.top]);

    const svg = d3
      .select(this.container)
      .append("svg")
      .attr("width", this.width)
      .attr("height", this.height);

    // The Y axis must be drawn before the X scale is built: drawing it may
    // enlarge margin.left, which the X range depends on.
    svg.append("g").call(this.drawYAxisAndAdjustMarginLeft(y, intl));

    let x: d3.ScaleTime<number, number> | d3.ScaleLinear<number, number>;
    if (this.xAxisHasNumbers)
      x = d3
        .scaleLinear()
        .domain(
          d3.extent(this.data, (point: IScatterPlotPoint) => point.valueX)
        )
        .nice();
    else
      x = d3
        .scaleTime()
        .domain(
          d3.extent(this.data, (point: IScatterPlotPoint) => point.valueX)
        )
        .nice();

    x.range([this.margin.left, this.width - this.margin.right]);

    svg.append("g").call(this.drawXAxis(x, intl));

    svg.append("g").call(this.drawXGridLines(x));
    svg.append("g").call(this.drawYGridLines(y));

    // Full-size rect above the grid; presumably styled via CSS as the
    // hover/click surface — TODO confirm against the stylesheet.
    svg.append("rect").attr("width", this.width).attr("height", this.height);

    svg
      .append("g")
      .attr("class", "data")
      .selectAll("circle")
      .data(this.data)
      .join("circle")
      .attr("class", (d) => "asset-risk-dot " + d.risk.toLowerCase())
      .attr("cx", (d) => x(d.valueX))
      .attr("cy", (d) => y(d.valueY))
      .attr("data-class", (d) => d.class)
      .attr("r", 6);

    if (this.correlationData) {
      this.drawCorrelationLine(svg, x, y);
    }

    svg.call(this.addTooltipBehavior(x, y, intl, navigate));
  }

  /** Draws the correlation line between (xMin, yMin) and (xMax, yMax). */
  private drawCorrelationLine(
    svg: d3.Selection<SVGSVGElement, unknown, null, undefined>,
    x: d3.ScaleTime<number, number> | d3.ScaleLinear<number, number>,
    y: d3.ScaleTime<number, number> | d3.ScaleLinear<number, number>
  ) {
    const data = this.correlationData;
    svg
      .append("line")
      .attr("class", "correlation-line")
      .attr("x1", x(data.xMin))
      .attr("x2", x(data.xMax))
      .attr("y1", y(data.yMin))
      .attr("y2", y(data.yMax));
  }

  /** Sizes the chart to the container, clamped to the configured minimums. */
  private setChartDimensions() {
    this.width = Math.max(this.container.clientWidth, this.minWidth);
    this.height = Math.max(this.container.clientHeight, this.minHeight);
  }

  /**
   * Returns a `selection.each` callback that truncates a tspan's text, one
   * character at a time, and appends "..." until it fits within `width` pixels.
   */
  private wrapTextToWidth(width: number) {
    return function (this: SVGTSpanElement) {
      const self = d3.select(this);
      let textLength = self.node().getComputedTextLength();
      let newText = self.text();
      while (textLength > width && newText.length > 0) {
        newText = newText.slice(0, -1);
        self.text(newText + "...");
        textLength = self.node().getComputedTextLength();
      }
    };
  }

  /** Vertical grid lines: a label-less bottom axis with full-height ticks. */
  private drawXGridLines(
    x: d3.ScaleTime<number, number> | d3.ScaleLinear<number, number>
  ) {
    return (g: d3.Selection<SVGGElement, null, null, null>) =>
      g
        .attr("transform", `translate(0,${this.height - this.margin.bottom})`)
        .attr("class", "x grid")
        .call(
          d3
            .axisBottom(x)
            .ticks(6)
            .tickSize(-(this.height - this.margin.top - this.margin.bottom))
            .tickFormat(() => "")
        );
  }

  /** Horizontal grid lines: a label-less left axis with full-width ticks. */
  private drawYGridLines(
    y: d3.ScaleTime<number, number> | d3.ScaleLinear<number, number>
  ) {
    return (g: d3.Selection<SVGGElement, null, null, null>) =>
      g
        .attr("transform", `translate(${this.margin.left},0)`)
        .attr("class", "y grid")
        .call(
          d3
            .axisLeft(y)
            .ticks(5)
            .tickSize(-(this.width - this.margin.left - this.margin.right))
            .tickFormat(() => "")
        );
  }

  /** Draws the bottom axis with formatted ticks and a centered title below it. */
  private drawXAxis(
    x: d3.ScaleTime<number, number> | d3.ScaleLinear<number, number>,
    intl: IntlShape
  ) {
    return (g: d3.Selection<SVGGElement, null, null, null>) => {
      g.attr("transform", `translate(0,${this.height - this.margin.bottom})`)
        .attr("class", "x axis")
        .call(
          d3
            .axisBottom(x)
            .ticks(6)
            .tickFormat((value: number | Date) =>
              ScatterPlot.valueOrDateFormatter(
                value,
                this.xAxisHasNumbers,
                intl
              )
            )
        );

      const axisWidth = g.node().getBBox().width;

      const labelX = g.append("text").attr("class", "axis-label");
      labelX.append("tspan").text(this.labelX);

      const labelHeight = labelX.node().getBBox().height;

      const labelXPosition = this.margin.left + axisWidth / 2;
      const labelYPosition = this.margin.bottom / 2 + labelHeight;

      labelX.attr(
        "transform",
        `translate(${labelXPosition}, ${labelYPosition})`
      );
    };
  }

  /**
   * Draws the left axis and its rotated title (truncated to the axis height).
   * Enlarges margin.left when the tick labels plus title need more room, then
   * shifts the axis by the final left margin.
   */
  private drawYAxisAndAdjustMarginLeft(
    y: d3.ScaleTime<number, number> | d3.ScaleLinear<number, number>,
    intl: IntlShape
  ) {
    return (g: d3.Selection<SVGGElement, null, null, null>) => {
      g.attr("class", "y axis").call(
        d3
          .axisLeft(y)
          .ticks(5)
          .tickFormat((value: number | Date) =>
            ScatterPlot.valueOrDateFormatter(value, this.yAxisHasNumbers, intl)
          )
      );

      // Measured before translation so the bbox reflects tick labels only.
      const axisWidth = g.node().getBBox().width;
      const axisHeight = g.node().getBBox().height;

      const labelY = g.append("text").attr("class", "axis-label");

      labelY
        .append("tspan")
        .text(this.labelY)
        .each(this.wrapTextToWidth(axisHeight));

      const labelHeight = labelY.node().getBBox().height;
      const labelWidth = labelY.node().getBBox().width;

      const minMarginLeft = axisWidth + labelHeight + 40;
      if (minMarginLeft > this.margin.left) {
        this.margin.left = minMarginLeft;
      }

      const labelXPosition = (this.margin.left + axisWidth - labelHeight) / 2;
      const labelYPosition = (axisHeight - labelWidth) / 2 + this.margin.top;

      labelY.attr(
        "transform",
        `rotate(270)translate(${-labelYPosition}, ${-labelXPosition})`
      );

      g.attr("transform", `translate(${this.margin.left},0)`);
    };
  }

  /**
   * Returns a `selection.call` callback that wires up hover tooltips (nearest
   * point within 20px, otherwise correlation line within 5px) and click
   * navigation to the detail page of the nearest point's asset.
   */
  private addTooltipBehavior(
    x: d3.ScaleTime<number, number> | d3.ScaleLinear<number, number>,
    y: d3.ScaleTime<number, number> | d3.ScaleLinear<number, number>,
    intl: IntlShape,
    navigate: NavigateFn
  ) {
    const xAxisHasNumbers = this.xAxisHasNumbers;
    const yAxisHasNumbers = this.yAxisHasNumbers;
    const rSquared = this.correlationData?.rSquared;

    return (svg: d3.Selection<SVGSVGElement, {}, null, undefined>) => {
      // Pointer position in the <svg>'s own coordinate system (the space the
      // quadtrees are built in). Resolving against the svg instead of
      // event.target avoids offsets when the target sits in a translated <g>.
      const getMousePosition = (pointEvent: MouseEvent) =>
        d3.pointer(pointEvent, svg.node());

      const getClosestPointElement = (
        tree: d3.Quadtree<IScatterPlotPoint>,
        pointEvent: MouseEvent
      ) => {
        const [mouseX, mouseY] = getMousePosition(pointEvent);
        return tree.find(mouseX, mouseY, 20);
      };

      const getClosestLineElement = (
        tree: d3.Quadtree<ICorrelationLineData>,
        pointEvent: MouseEvent
      ) => {
        if (!tree) return null;
        const [mouseX, mouseY] = getMousePosition(pointEvent);
        return tree.find(mouseX, mouseY, 5);
      };

      const show = (
        highlight: d3.Selection<SVGCircleElement, {}, null, undefined>,
        tooltip: d3.Selection<HTMLDivElement, {}, null, undefined>
      ) => {
        highlight?.classed("c-hidden", false);
        tooltip.transition().duration(150).style("opacity", 0.9);
      };

      const hide = (
        highlight: d3.Selection<SVGCircleElement, {}, null, undefined>,
        tooltip: d3.Selection<HTMLDivElement, {}, null, undefined>
      ) => {
        highlight?.classed("c-hidden", true);
        tooltip.transition().duration(100).style("opacity", 0);
      };

      const showTooltip = function (pointEvent: MouseEvent) {
        const midPoint = window.innerWidth / 2;
        let style: string;
        let styleToRemove: string;
        let dist: string;

        // Anchor the tooltip on the side of the pointer facing the window
        // center so it stays on screen.
        const setStyles = function () {
          if (pointEvent.pageX > midPoint) {
            style = "right";
            styleToRemove = "left";
            dist = window.innerWidth - pointEvent.pageX + 20 + "px";
          } else {
            style = "left";
            styleToRemove = "right";
            dist = pointEvent.pageX + 20 + "px";
          }
        };

        const closest = getClosestPointElement(tree, pointEvent);

        if (closest != null) {
          hide(null, correlationTooltip);
          setStyles();
          highlight
            .attr("cx", x(closest.valueX))
            .attr("cy", y(closest.valueY))
            .attr(
              "class",
              "highlight asset-risk-dot " + closest.risk.toLowerCase()
            );
          tooltip
            .html(
              `<span>${ScatterPlot.escapeHtml(
                closest.name
              )}</span><br/> (${ScatterPlot.valueOrDateFormatter(
                closest.valueX,
                xAxisHasNumbers,
                intl
              )}; ${ScatterPlot.valueOrDateFormatter(
                closest.valueY,
                yAxisHasNumbers,
                intl
              )})`
            )
            .style(style, dist)
            .style("top", pointEvent.pageY - 38 + "px")
            .style(styleToRemove, "");
          show(highlight, tooltip);
        } else if (getClosestLineElement(lineTree, pointEvent)) {
          hide(highlight, tooltip);
          setStyles();

          correlationTooltip
            .html(
              `<span>${intl.formatMessage({
                defaultMessage: "Correlation",
                id: "family_analytics.chart.correlation_label"
              })}</span><br/>
              ${intl.formatMessage({
                defaultMessage: "R",
                id: "family_analytics.chart.r_squared_label"
              })}<sup>${intl.formatNumber(
                2
              )}</sup> = <span class="r-squared-value"> ${ScatterPlot.rSquaredFormatter(
                rSquared,
                intl
              )}</span>`
            )
            .style(style, dist)
            .style("top", pointEvent.pageY - 38 + "px")
            .style(styleToRemove, "");
          show(null, correlationTooltip);
        } else {
          hide(highlight, tooltip);
          hide(null, correlationTooltip);
        }
      };

      const handleClick = function (pointEvent: MouseEvent) {
        const closest = getClosestPointElement(tree, pointEvent);
        if (closest != null) {
          const searchParams = new SearchParams({ assetId: closest.name });
          navigate({
            pathname: routes.detailPage.pathname,
            search: searchParams.toString()
          });
        }
        pointEvent.stopPropagation();
      };

      const tooltip = d3
        .select(this.container)
        .append("div")
        .attr("class", "tooltip")
        .attr("data-qa", "point-tooltip")
        .style("opacity", 0);

      const correlationTooltip = d3
        .select(this.container)
        .append("div")
        .attr("class", "tooltip")
        .attr("data-qa", "correlation-tooltip")
        .style("opacity", 0);

      const highlight = svg
        .append("circle")
        .attr("r", 10)
        .attr("class", "highlight")
        .classed("c-hidden", true);

      // Spatial indexes in pixel space for nearest-point / nearest-line lookup.
      const tree = d3
        .quadtree<IScatterPlotPoint>()
        .x(function (d) {
          return x(d.valueX);
        })
        .y(function (d) {
          return y(d.valueY);
        })
        .addAll(this.data);

      let lineTree: d3.Quadtree<ICorrelationLineData>;
      if (this.correlationData) {
        lineTree = d3
          .quadtree<ICorrelationLineData>()
          .x(function (d) {
            return x(d.x);
          })
          .y(function (d) {
            return y(d.y);
          })
          .addAll(this.correlationData.correlationLineData);
      }

      svg.on("mousemove", showTooltip);
      svg.on("mouseover", (pointEvent: MouseEvent) => showTooltip(pointEvent));
      // Hide both tooltips; otherwise a visible correlation tooltip would
      // linger after the pointer leaves the chart.
      svg.on("mouseout", () => {
        hide(highlight, tooltip);
        hide(null, correlationTooltip);
      });

      svg.on("click", (pointEvent: MouseEvent) => handleClick(pointEvent));
    };
  }
}

export default ScatterPlot;
