import { useCallback } from 'react';
import Decimal from 'decimal.js';

import { useQueryClient } from '@tanstack/react-query';
import useEnv from './useEnv';
import { QUERYKEYS } from '../constants/queryKeys';
import { useWeb3Client } from './useWeb3Client';
import { isNativeSolMint, isSOLMint } from '../utils/tokens';
import { PublicKeyAddress } from '../types/strategies';
import { useNativeBalanceQuery } from '../../features/UnwrapSolNotification/hooks/useNativeBalanceQuery';
import { SOL_MINTS } from '../constants/tokens';
import { DECIMALS_SOL } from '../constants/math';
import { lamportsToNumberDecimal } from '../utils';

/** Options accepted by `useBalances`. */
export interface UseBalancesOptions {
  /**
   * When `true`, native SOL and wSOL balances are reported separately
   * (each mint resolves to its own balance). When `false` or omitted,
   * a lookup for either SOL mint returns the sum of both balances.
   */
  isSolWsolSeparated?: boolean;
  /**
   * NOTE(review): not read by `useBalances` in the visible part of this file;
   * presumably token program ids of the liquidity mints — confirm whether
   * this option is still needed by callers.
   */
  liquidityMintTokenProgramIds?: string[];
}

// Shared empty options object used as the default argument of `useBalances`,
// so an omitted `options` resolves to the same object reference on every call.
const DEFAULT_PROPS: UseBalancesOptions = {};

/**
 * Hook exposing synchronous lookups of the connected wallet's token balances.
 *
 * Token balances are read with `queryClient.getQueryData`, i.e. a snapshot of
 * whatever is in the React Query cache at call time; this hook does not fetch
 * token balances itself. The native SOL balance comes from
 * `useNativeBalanceQuery` and is converted from lamports with `DECIMALS_SOL`.
 *
 * @param options - Lookup options. Only `isSolWsolSeparated` is read here:
 *   when it is falsy, a lookup for any SOL mint returns wSOL + native SOL.
 * @returns `getBalanceByTokenMint` (single mint) and `getBalancesByTokenMint`
 *   (array of mints, results in the same order). Both yield `undefined` for a
 *   mint when the web3 client or wallet is missing, or when nothing is cached.
 */
export function useBalances(options: UseBalancesOptions = DEFAULT_PROPS) {
  const { walletPublicKey } = useEnv();
  const { web3client } = useWeb3Client();
  const queryClient = useQueryClient();
  const { data: nativeSolBalanceLamports } = useNativeBalanceQuery();

  // `options` always has a value (default parameter), so no optional chaining is needed.
  const isSolWsolCombined = !options.isSolWsolSeparated;

  // Resolves the balance for a SOL mint (native or wrapped), honouring `isSolWsolSeparated`.
  const getSolWsolValue = useCallback(
    (tokenMint: PublicKeyAddress): Decimal | undefined => {
      if (!web3client || !walletPublicKey) {
        return undefined;
      }

      // The cached balance under SOL_MINTS[1] is treated as the wSOL balance.
      const wSolBalance = queryClient.getQueryData<Decimal>(QUERYKEYS.getTokenBalance(walletPublicKey, SOL_MINTS[1]));

      // NOTE(review): `nativeSolBalanceLamports` is forwarded as-is, even while the native
      // balance query has no data yet — confirm `lamportsToNumberDecimal` tolerates `undefined`.
      if (isSolWsolCombined) {
        // A missing wSOL cache entry counts as zero in the combined total.
        return new Decimal(wSolBalance ?? 0).add(lamportsToNumberDecimal(nativeSolBalanceLamports, DECIMALS_SOL));
      }

      return isNativeSolMint(tokenMint) ? lamportsToNumberDecimal(nativeSolBalanceLamports, DECIMALS_SOL) : wSolBalance;
    },
    [isSolWsolCombined, nativeSolBalanceLamports, queryClient, walletPublicKey, web3client]
  );

  const getBalanceByTokenMint = useCallback(
    (tokenMint: PublicKeyAddress): Decimal | undefined => {
      if (!web3client || !walletPublicKey) {
        return undefined;
      }

      const mintAddress = tokenMint.toString();

      // SOL and wSOL are resolved together so they can be combined or separated per `options`.
      if (isSOLMint(mintAddress)) {
        return getSolWsolValue(tokenMint);
      }

      // Any other token: return whatever balance is currently cached for this wallet and mint.
      return queryClient.getQueryData<Decimal>(QUERYKEYS.getTokenBalance(walletPublicKey, mintAddress));
    },
    [getSolWsolValue, queryClient, walletPublicKey, web3client]
  );

  const getBalancesByTokenMint = useCallback(
    (mints: Array<PublicKeyAddress>): (Decimal | undefined)[] => mints.map((mint) => getBalanceByTokenMint(mint)),
    [getBalanceByTokenMint]
  );

  return {
    getBalanceByTokenMint,
    getBalancesByTokenMint,
  };
}
