import { useMemo } from 'react';
import { useAtom } from 'jotai';
import * as d3 from 'rockerbox_d3_legacy_clone';
import { IndexGridTree } from '@rockerbox/styleguide';

import { spendFormatter } from '../../../utils/valueFormatter';
import { tiersAtom, columnsAtom } from '../../../atoms';

/**
 * Keeps only the entries whose `depth` property is falsy (missing, 0, ...).
 *
 * @param {Array<Object>} values - entries to filter
 * @returns {Array<Object>} a new array without the entries that have a truthy `depth`
 */
const f = values => {
  const hasNoDepth = entry => !entry.depth;
  return values.filter(hasNoDepth);
};

/** Sums `key` over the entries that `f` keeps (those whose `depth` is falsy). */
const sumTopLevel = (values, key) => d3.sum(f(values), row => row[key]);

/** Returns a fresh style object for a right-aligned, wrapping cell. */
const numericCellStyle = () => ({ textAlign: 'right', whiteSpace: 'normal' });

/**
 * Builds the column definitions for the tree grid.
 *
 * The result always starts with the "Channel" grouping column. Each segment in
 * `currentSegments` then contributes a conversion, spend and CPA column, plus
 * revenue and ROAS columns when `include_revenue` is truthy. Every segment
 * after the first also contributes a conversion-rate column relative to the
 * previous segment. A final "Overall Conv Rate" column (last segment over
 * first segment) is appended when at least one segment exists.
 *
 * @param {Object} tierColorMap - passed to `IndexGridTree.NameCell`
 * @param {string} conversionKey - suffix of the per-segment conversion field (`<filter_id>_<conversionKey>`)
 * @param {string} currencyCode - passed to `spendFormatter`
 * @param {string} revenueKey - suffix of the per-segment revenue field (`<filter_id>_<revenueKey>`)
 * @param {Array<Object>} [currentSegments] - objects with `filter_id`, `action_name` and `include_revenue`
 * @param {*} tiersSummaryTotals - attached as `totals` to every conversion-rate column
 * @returns {Array<Object>} column definitions
 */
export const getTreeGridColumns = (tierColorMap, conversionKey, currencyCode, revenueKey, currentSegments, tiersSummaryTotals) => {
  const segments = currentSegments || [];
  const conversionKeyOf = segment => `${segment.filter_id}_${conversionKey}`;

  const segmentColumns = segments.flatMap((segment, i) => {
    const { filter_id, action_name, include_revenue } = segment;
    const conversionKeyName = conversionKeyOf(segment);
    const spendKey = `${filter_id}_spend`;

    const columns = [
      {
        id: `${filter_id}_conversion_key`,
        display: `${action_name}`,
        key: conversionKeyName,
        as: IndexGridTree.NumberCellTwoDecimals,
        reducer: values => sumTopLevel(values, conversionKeyName),
        summaryLabel: 'Total',
        style: numericCellStyle(),
      },
      {
        id: `${filter_id}_spend`,
        display: `${action_name} Spend`,
        key: spendKey,
        as: IndexGridTree.SpendCell(spendFormatter(currencyCode)),
        reducer: values => sumTopLevel(values, spendKey),
        summaryLabel: 'Total',
        style: numericCellStyle(),
      },
      {
        id: `${filter_id}_cpa`,
        display: `${action_name} CPA`,
        key: `${filter_id}_cpa`,
        as: IndexGridTree.CpaCell(conversionKeyName, spendFormatter(currencyCode), spendKey),
        reducer: values => {
          const spend = sumTopLevel(values, spendKey);
          // NOTE(review): non-zero spend with a conversion sum of 0 yields Infinity here;
          // kept as-is, confirm the cell renderer copes with it.
          return spend ? spend / sumTopLevel(values, conversionKeyName) : 0;
        },
        style: numericCellStyle(),
      },
    ];

    if (include_revenue) {
      const revenueKeyName = `${filter_id}_${revenueKey}`;
      columns.push(
        {
          id: `${filter_id}_revenue_key`,
          display: `${action_name} Revenue`,
          key: revenueKeyName,
          as: IndexGridTree.SpendCell(spendFormatter(currencyCode)),
          reducer: values => sumTopLevel(values, revenueKeyName),
          summaryLabel: 'Total',
          style: numericCellStyle(),
        },
        {
          id: `${filter_id}_roas`,
          display: `${action_name} ROAS`,
          // Keyed per segment, like the CPA column, so that several revenue
          // segments do not all share a single 'roas' key.
          key: `${filter_id}_roas`,
          as: IndexGridTree.RoasCell(revenueKeyName, spendFormatter(currencyCode), spendKey),
          reducer: values => {
            const spend = sumTopLevel(values, spendKey);
            return spend ? sumTopLevel(values, revenueKeyName) / spend : 0;
          },
          style: numericCellStyle(),
        },
      );
    }

    // The first segment has no previous step to compute a rate against.
    if (i === 0) return columns;

    const previousSegment = segments[i - 1];
    columns.push({
      id: `${filter_id}_conversion_key_rate`,
      display: `${previousSegment.action_name} to ${action_name} Conv Rate`,
      key: conversionKeyName,
      denominatorKey: conversionKeyOf(previousSegment),
      totals: tiersSummaryTotals,
      as: IndexGridTree.PercentageCellConversionRate,
      reducer: values => sumTopLevel(values, conversionKeyName),
      style: numericCellStyle(),
      showPercent: true,
      sortable: 0,
    });

    return columns;
  });

  const overallConvRate = [];
  if (segments.length) {
    const lastConversionKey = conversionKeyOf(segments[segments.length - 1]);
    overallConvRate.push({
      id: 'overall_conversion_key_rate',
      display: 'Overall Conv Rate',
      key: lastConversionKey,
      denominatorKey: conversionKeyOf(segments[0]),
      totals: tiersSummaryTotals,
      as: IndexGridTree.PercentageCellConversionRate,
      reducer: values => sumTopLevel(values, lastConversionKey),
      style: numericCellStyle(),
      showPercent: true,
      sortable: 0,
    });
  }

  return [
    {
      id: 'group',
      display: 'Channel',
      key: 'group',
      groupBy: ['tier_1', 'tier_2', 'tier_3', 'tier_4', 'tier_5'],
      as: IndexGridTree.NameCell(tierColorMap),
    },
    ...segmentColumns,
    ...overallConvRate,
  ];
};

/**
 * Hook that exposes the tree-grid columns together with the persisted column
 * selection and tier list.
 *
 * All columns come from `getTreeGridColumns`. The default selection is every
 * column whose id contains neither 'spend' nor 'roas'. The selected columns
 * follow the ids stored in `columnsAtom`, in stored order; the default
 * selection is used when nothing is stored or none of the stored ids matches.
 *
 * @param {Array} tiersData - when missing or empty, `selectedColumns` is an empty array
 * @param {*} tiersSummaryTotals - forwarded to `getTreeGridColumns`
 * @param {Object} tierColorMap - forwarded to `getTreeGridColumns`
 * @param {string} conversionKey - forwarded to `getTreeGridColumns`
 * @param {string} currencyCode - forwarded to `getTreeGridColumns`
 * @param {string} revenueKey - forwarded to `getTreeGridColumns`
 * @param {Array<Object>} currentSegments - forwarded to `getTreeGridColumns`
 * @returns {{tiers: Array, setTiers: Function, allColumns: Array, selectedColumns: Array, setColumns: Function}}
 */
export const useTreeColumns = (tiersData, tiersSummaryTotals, tierColorMap, conversionKey, currencyCode, revenueKey, currentSegments) => {
  const [savedColumnIds, setSavedColumnIds] = useAtom(columnsAtom);
  const [tiers, setTiers] = useAtom(tiersAtom);

  const allColumns = useMemo(() => (
    getTreeGridColumns(tierColorMap, conversionKey, currencyCode, revenueKey, currentSegments, tiersSummaryTotals)
  ), [tierColorMap, conversionKey, currencyCode, revenueKey, currentSegments, tiersSummaryTotals]);

  const defaultColumns = useMemo(() => {
    // Columns whose id contains one of these fragments are not part of the default selection.
    const hiddenByDefault = ['spend', 'roas'];
    return allColumns.filter(({ id }) => !hiddenByDefault.some(fragment => id.includes(fragment)));
  }, [allColumns]);

  const selectedColumns = useMemo(() => {
    // Without data there is nothing to show.
    if (!tiersData || tiersData.length === 0) return [];

    // A missing (null/undefined) or empty saved selection both fall back to the
    // defaults; previously a nullish value reached `.flatMap` below and threw.
    if (!savedColumnIds || savedColumnIds.length === 0) return defaultColumns;

    // Order matters: follow the saved order and drop ids that no longer exist.
    const savedColumns = savedColumnIds
      .map(savedId => allColumns.find(column => column.id === savedId))
      .filter(Boolean);

    // If none of the saved ids matches a column, fall back to the default columns.
    return savedColumns.length === 0 ? defaultColumns : savedColumns;
  }, [savedColumnIds, allColumns, defaultColumns, tiersData]);

  // Persists only the ids of the chosen columns, in the given order.
  const setColumns = selectedCols => {
    setSavedColumnIds(selectedCols.map(({ id }) => id));
  };

  return {
    tiers: tiers.length ? tiers : ['tier_1', 'tier_2', 'tier_3', 'tier_4', 'tier_5'],
    setTiers,
    allColumns,
    selectedColumns,
    setColumns,
  };
};
