import React, { useEffect, useMemo, useRef } from 'react';
import * as d3 from 'd3';

/** Space, in pixels, reserved on each side of the plotting area inside the SVG. */
const MARGIN = {
  top: 30,
  right: 30,
  bottom: 40,
  left: 50,
};

/** Default `desiredBucketCount` used when the `dataBucketing` prop is omitted. */
const BUCKET_COUNT = 70;

/** Horizontal gap, in pixels, left between adjacent bars. */
const BUCKET_PADDING = 1;

/** Props accepted by the `Histogram` component. */
interface HistogramProps {
  /** Total SVG width in pixels, margins included. */
  width: number;
  /** Total SVG height in pixels, margins included. */
  height: number;
  /** Raw values to be bucketed into bars. */
  data: number[];
  /** Fill color applied to every bar. */
  fillColor?: string;
  /**
   * Bucketing strategy.
   * - `hasUniformBuckets: true`: `desiredBucketCount` is a count hint passed to
   *   the x scale's `ticks()`.
   * - `hasUniformBuckets: false`: `desiredBucketCount` acts as a step; thresholds
   *   are the integers below the rounded domain max that are multiples of it.
   */
  dataBucketing?: {
    hasUniformBuckets: boolean;
    desiredBucketCount: number;
  };
}

/**
 * Renders a histogram of `data` as an SVG bar chart.
 *
 * Bars are rendered by React; both axes are drawn imperatively by d3 into a
 * dedicated `<g>` element whenever the scales or plot height change.
 *
 * @param width - Total SVG width in pixels, margins included.
 * @param height - Total SVG height in pixels, margins included.
 * @param data - Values to bucket. The x domain is `[0, max(data)]`, so values
 *   below 0 lie outside the bin domain and are not counted.
 * @param fillColor - Fill color for every bar.
 * @param dataBucketing - When `hasUniformBuckets` is true, `desiredBucketCount`
 *   is passed to `xScale.ticks()` as a count hint. Otherwise thresholds are the
 *   non-negative integers below the rounded domain max that are multiples of
 *   `desiredBucketCount`.
 */
export const Histogram = ({
  width,
  height,
  data,
  fillColor = 'dodgerblue',
  dataBucketing = {
    hasUniformBuckets: true,
    desiredBucketCount: BUCKET_COUNT,
  },
}: HistogramProps) => {
  const axesRef = useRef<SVGGElement>(null);
  const boundsWidth = width - MARGIN.right - MARGIN.left;
  const boundsHeight = height - MARGIN.top - MARGIN.bottom;

  // Memo dependencies use the primitive fields, not the object itself: the
  // default `dataBucketing` value is a fresh object on every render.
  const { hasUniformBuckets, desiredBucketCount } = dataBucketing;

  const xScale = useMemo(() => {
    const max = d3.max(data) ?? 0;
    // NOTE(review): the range starts at 10px rather than 0, presumably to
    // inset the first bar from the y axis — confirm intent.
    return d3.scaleLinear().domain([0, max]).range([10, boundsWidth]);
  }, [data, boundsWidth]);

  const buckets = useMemo(() => {
    const [domainMin, domainMax] = xScale.domain() as [number, number];
    const bucketGenerator = d3
      .bin()
      .value((d) => d)
      .domain([domainMin, domainMax])
      .thresholds(() => {
        if (hasUniformBuckets) {
          return xScale.ticks(desiredBucketCount);
        }
        // Clamp at 0: with all-negative data the domain max is negative and
        // `Array(n)` would throw a RangeError.
        const maxBucketBound = Math.max(0, Math.round(domainMax));
        return Array.from(Array(maxBucketBound).keys()).filter(
          (x) => x % desiredBucketCount === 0
        );
      });
    return bucketGenerator(data);
  }, [data, xScale, hasUniformBuckets, desiredBucketCount]);

  const yScale = useMemo(() => {
    const max = d3.max(buckets, (bucket) => bucket.length) ?? 0;
    return d3.scaleLinear().domain([0, max]).range([boundsHeight, 0]).nice();
  }, [buckets, boundsHeight]);

  // Render both axes using d3.js, not React: React never manages children of
  // the axes <g>, so d3 can clear and redraw it freely.
  useEffect(() => {
    const axesGroup = axesRef.current;
    if (!axesGroup) {
      return;
    }
    const svgElement = d3.select(axesGroup);
    svgElement.selectAll('*').remove();

    const xAxisGenerator = d3.axisBottom(xScale);
    svgElement
      .append('g')
      .attr('transform', 'translate(0,' + boundsHeight + ')')
      .call(xAxisGenerator);

    const yAxisGenerator = d3.axisLeft(yScale);
    svgElement.append('g').call(yAxisGenerator);
  }, [xScale, yScale, boundsHeight]);

  const allRects = buckets.map((bucket, i) => {
    if (bucket.x0 === undefined || bucket.x1 === undefined) {
      // `null` renders nothing and, unlike an un-keyed fragment, does not
      // trigger React's missing-key warning.
      return null;
    }
    // Clamp so a bucket narrower than BUCKET_PADDING does not yield a
    // negative (invalid) SVG width.
    const barWidth = Math.max(
      0,
      xScale(bucket.x1) - xScale(bucket.x0) - BUCKET_PADDING
    );
    return (
      <rect
        key={i}
        fill={fillColor}
        x={xScale(bucket.x0) + BUCKET_PADDING / 2}
        width={barWidth}
        y={yScale(bucket.length)}
        height={boundsHeight - yScale(bucket.length)}
      />
    );
  });

  return (
    <svg width={width} height={height}>
      <g
        width={boundsWidth}
        height={boundsHeight}
        transform={`translate(${[MARGIN.left, MARGIN.top].join(',')})`}
      >
        {allRects}
      </g>
      <g
        width={boundsWidth}
        height={boundsHeight}
        ref={axesRef}
        transform={`translate(${[MARGIN.left, MARGIN.top].join(',')})`}
      />
    </svg>
  );
};
