import { KaminoObligation, KaminoReserve, Position } from '@kamino-finance/klend-sdk';
import Decimal from 'decimal.js';
import { PublicKey } from '@solana/web3.js';
import { lamportsToNumberDecimal } from '../utils';
import { useReserveApys } from '../../hooks/lending/useReserveApys';
import { isJlpTokenMint } from '../isJlp';
import { isLST } from '../isLST';
import { includeYieldsInAPYStore } from '../../../stores/includeYieldsInAPYStore';
import { isSUsdETokenMint } from '../isSUsdE';
import { isKTokenSymbol } from '../isKTokenSymbol';

/**
 * Sums the USD value of `positions`.
 *
 * Each position is paired with the reserve whose liquidity mint equals the
 * position's `mintAddress`; positions without a matching reserve are ignored.
 * The raw `amount` is converted with `lamportsToNumberDecimal` using the
 * reserve's `stats.decimals`, then multiplied by the reserve's oracle market
 * price (a `NaN` price is treated as 0).
 *
 * @returns The summed value as a `Decimal` (0 for an empty list).
 */
function calcTotal({ positions, reserves }: { positions: Position[]; reserves: KaminoReserve[] }) {
  let total = new Decimal(0);
  for (const position of positions) {
    const mint = position.mintAddress;
    const matched = reserves.find((candidate) => candidate.getLiquidityMint().equals(mint));
    if (matched === undefined) {
      continue;
    }
    const tokens = lamportsToNumberDecimal(position.amount.toString(), matched.stats.decimals);
    const oraclePrice = matched.getOracleMarketPrice().toNumber() || 0;
    total = total.add(tokens.mul(oraclePrice));
  }
  return total;
}

/** Inputs for `getObligationNetApy`. */
interface GetObligationNetApyProps {
  /** Obligation whose `deposits` and `borrows` are evaluated. */
  obligation: KaminoObligation;
  /** Market reserves; only the values are used, matched to positions by liquidity mint. */
  reserves: Map<PublicKey, KaminoReserve>;
  /** Lookup (from `useReserveApys`) returning the APY figures for a reserve address. */
  getReserveApys: ReturnType<typeof useReserveApys>['getReserveApys'];
}

/**
 * Computes value-weighted supply and borrow APYs and the net APY for an obligation.
 *
 * Each deposit's APY is its reserve's `totalDepositApy`, plus one of these extras
 * depending on flags read from `includeYieldsInAPYStore`:
 *   - `extraSupplyApy` for LST / JLP / sUSDe mints;
 *   - `kTokenFeesApy` for kTokens.
 * Each borrow's APY is its reserve's `totalBorrowApy`. Positions are weighted by their
 * USD value (amount × oracle price) relative to the obligation's total supplied /
 * borrowed value. Positions whose mint has no matching reserve are skipped.
 *
 * @returns `{ borrowNetApy, supplyNetApy, netApy }` as plain numbers. All are 0 when
 *   `obligation` is falsy. `netApy` is 0 when net value (supply − borrow) is exactly
 *   zero, e.g. for an empty obligation, instead of NaN / ±Infinity.
 */
export function getObligationNetApy({ obligation, reserves, getReserveApys }: GetObligationNetApyProps) {
  if (!obligation) {
    return {
      borrowNetApy: 0,
      supplyNetApy: 0,
      netApy: 0,
    };
  }

  const {
    includeLiquidStakingYieldInAPY,
    includeMarketMakingYieldInAPY,
    includeJLPYieldInAPY,
    includeSUsdEYieldInAPY,
  } = includeYieldsInAPYStore;

  const reservesArr = Array.from(reserves.values());
  const deposits = Array.from(obligation.deposits.values());
  const borrows = Array.from(obligation.borrows.values());

  // Reserve backing a position, matched by liquidity mint (undefined when the market has none).
  const findReserve = (position: Position) =>
    reservesArr.find((r) => r.getLiquidityMint().equals(position.mintAddress));

  // USD value of a position: token amount (converted with the reserve's decimals) × oracle price.
  const positionValueUsd = (position: Position, reserve: KaminoReserve) => {
    const price = reserve.getOracleMarketPrice().toNumber() || 0;
    const amount = lamportsToNumberDecimal(position.amount.toString(), reserve.stats.decimals);
    return amount.mul(price);
  };

  // Supply APY for one deposit, optionally including underlying-asset yield per the store flags.
  const depositApy = (tokenMint: PublicKey, reserveApys: ReturnType<typeof getReserveApys>) => {
    const baseApy = reserveApys.totalDepositApy || 0;
    const includeExtraYield =
      (isLST(tokenMint) && includeLiquidStakingYieldInAPY) ||
      (isJlpTokenMint(tokenMint) && includeJLPYieldInAPY) ||
      (isSUsdETokenMint(tokenMint) && includeSUsdEYieldInAPY);
    if (includeExtraYield) {
      return baseApy + reserveApys.extraSupplyApy;
    }
    // NOTE(review): `isKTokenSymbol` receives the mint address string rather than a token
    // symbol — confirm this is what the helper expects.
    if (isKTokenSymbol(tokenMint.toString()) && includeMarketMakingYieldInAPY) {
      const { kTokenFeesApy } = reserveApys;
      return baseApy + (typeof kTokenFeesApy === 'number' ? kTokenFeesApy : kTokenFeesApy.toNumber());
    }
    return baseApy;
  };

  // Total USD values of the supply and borrow sides (the weighting denominators).
  const totalSupply = calcTotal({ positions: deposits, reserves: reservesArr });
  const totalBorrow = calcTotal({ positions: borrows, reserves: reservesArr });

  const weightedAverageSupplyAPY = totalSupply.gt(0)
    ? deposits.reduce((sum, position) => {
        const reserve = findReserve(position);
        if (!reserve) {
          return sum;
        }
        const reserveApys = getReserveApys(reserve.address);
        const proportion = positionValueUsd(position, reserve).div(totalSupply);
        return sum.add(proportion.mul(depositApy(position.mintAddress, reserveApys)));
      }, new Decimal(0))
    : new Decimal(0);

  const weightedAverageBorrowAPY = totalBorrow.gt(0)
    ? borrows.reduce((sum, position) => {
        const reserve = findReserve(position);
        if (!reserve) {
          return sum;
        }
        const borrowApy = getReserveApys(reserve.address).totalBorrowApy || 0;
        const proportion = positionValueUsd(position, reserve).div(totalBorrow);
        return sum.add(proportion.mul(borrowApy));
      }, new Decimal(0))
    : new Decimal(0);

  // Net APY = (totalSupply * supplyAPY - totalBorrow * borrowAPY) / (totalSupply - totalBorrow),
  // i.e. net interest earned as a fraction of net value. When net value is zero the division
  // would produce NaN (0/0) or ±Infinity, so report 0 instead.
  const netValue = totalSupply.minus(totalBorrow);
  const netAPY = netValue.isZero()
    ? new Decimal(0)
    : totalSupply.mul(weightedAverageSupplyAPY).minus(totalBorrow.mul(weightedAverageBorrowAPY)).div(netValue);

  return {
    borrowNetApy: weightedAverageBorrowAPY.toNumber(),
    supplyNetApy: weightedAverageSupplyAPY.toNumber(),
    netApy: netAPY.toNumber(),
  };
}

/**
 * Computes value-weighted supply and borrow APYs and the resulting net APY for an
 * arbitrary set of deposit and borrow positions.
 *
 * Positions are matched to `reserves` by liquidity mint; unmatched positions are
 * skipped. Each position is weighted by its USD value (amount × oracle price)
 * relative to the total supplied / borrowed value computed by `calcTotal`.
 *
 * @param extraSupplyApy Added to every deposit's `totalDepositApy` (default 0).
 * @param extraBorrowApy Added to every borrow's `totalBorrowApy` (default 0).
 * @returns
 *   - `supplyNetApy`: weighted supply APY including `extraSupplyApy`.
 *   - `supplyApy`: weighted supply APY without `extraSupplyApy`.
 *   - `borrowNetApy`: weighted borrow APY including `extraBorrowApy`.
 *   - `netApy`: (income − interest) / (supply − borrow), or 0 when net value is zero.
 *   - `totalSupply` / `totalBorrow`: USD totals as `Decimal`.
 */
export function getNetApy({
  deposits,
  borrows,
  reserves,
  extraSupplyApy = 0,
  extraBorrowApy = 0,
  getReserveApys,
}: {
  deposits: Position[];
  borrows: Position[];
  reserves: KaminoReserve[];
  extraSupplyApy?: number;
  extraBorrowApy?: number;
  getReserveApys: ReturnType<typeof useReserveApys>['getReserveApys'];
}) {
  const totalSupply = calcTotal({ positions: deposits, reserves });
  const totalBorrow = calcTotal({ positions: borrows, reserves });
  const netValue = totalSupply.minus(totalBorrow);

  // Reserve backing a position, matched by liquidity mint (undefined when none matches).
  const findReserve = (position: Position) =>
    reserves.find((r) => r.getLiquidityMint().equals(position.mintAddress));

  // USD value of a position: token amount (converted with the reserve's decimals) × oracle price.
  const positionValueUsd = (position: Position, reserve: KaminoReserve) => {
    const price = reserve.getOracleMarketPrice().toNumber() || 0;
    const amount = lamportsToNumberDecimal(position.amount.toString(), reserve.stats.decimals);
    return amount.mul(price);
  };

  // Both supply averages weight the same deposits by the same proportions and differ only
  // in whether `extraSupplyApy` is added, so they are accumulated in a single pass.
  let weightedAverageSupplyAPY = new Decimal(0);
  let weightedAverageSupplyAPYIncludingStakingApy = new Decimal(0);
  if (totalSupply.gt(0)) {
    for (const position of deposits) {
      // NOTE(review): deposits with a non-positive `marketValueRefreshed` are excluded here,
      // yet still counted in the `totalSupply` denominator — confirm this is intended.
      if (!new Decimal(position.marketValueRefreshed).gt(0)) {
        continue;
      }
      const reserve = findReserve(position);
      if (!reserve) {
        continue;
      }
      const reserveApys = getReserveApys(reserve.address);
      const proportion = positionValueUsd(position, reserve).div(totalSupply);
      const baseSupplyApy = reserveApys.totalDepositApy || 0;
      weightedAverageSupplyAPY = weightedAverageSupplyAPY.add(proportion.mul(baseSupplyApy));
      weightedAverageSupplyAPYIncludingStakingApy = weightedAverageSupplyAPYIncludingStakingApy.add(
        proportion.mul(baseSupplyApy + extraSupplyApy),
      );
    }
  }

  const weightedAverageBorrowAPY = totalBorrow.gt(0)
    ? borrows.reduce((sum, position) => {
        const reserve = findReserve(position);
        if (!reserve) {
          return sum;
        }
        const borrowApy = (getReserveApys(reserve.address).totalBorrowApy || 0) + extraBorrowApy;
        const proportion = positionValueUsd(position, reserve).div(totalBorrow);
        return sum.add(proportion.mul(borrowApy));
      }, new Decimal(0))
    : new Decimal(0);

  const totalIncome = totalSupply.mul(weightedAverageSupplyAPYIncludingStakingApy);
  const totalInterest = totalBorrow.mul(weightedAverageBorrowAPY);
  // Net interest as a fraction of net value; 0 when net value is zero to avoid NaN / ±Infinity.
  const netAPY = netValue.isZero() ? new Decimal(0) : totalIncome.minus(totalInterest).div(netValue);

  return {
    borrowNetApy: weightedAverageBorrowAPY.toNumber(),
    supplyNetApy: weightedAverageSupplyAPYIncludingStakingApy.toNumber(),
    netApy: netAPY.toNumber(),
    supplyApy: weightedAverageSupplyAPY.toNumber(),
    totalSupply,
    totalBorrow,
  };
}
