// Copyright (c) 2018 XDPoSChain
// Ported to go-ethereum for XDC compatibility

package hooks

import (
	"errors"
	"fmt"
	"math/big"
	"time"

	"github.com/ethereum/go-ethereum/common"
	"github.com/ethereum/go-ethereum/consensus"
	"github.com/ethereum/go-ethereum/consensus/XDPoS"
	"github.com/ethereum/go-ethereum/core"
	"github.com/ethereum/go-ethereum/core/rawdb"
	"github.com/ethereum/go-ethereum/core/state"
	"github.com/ethereum/go-ethereum/core/tracing"
	"github.com/ethereum/go-ethereum/core/types"
	"github.com/ethereum/go-ethereum/eth/util"
	"github.com/ethereum/go-ethereum/log"
	"github.com/ethereum/go-ethereum/params"
	"github.com/holiman/uint256"
)

// AttachConsensusV2Hooks installs the checkpoint reward hook on the XDPoS
// wrapper (adaptor.HookReward) and, when the V2 engine is present, points the
// V2 engine's own reward hook at the same implementation.
//
// CC-4 cleanup (audit v3): the wrapper-level HookPenaltyTIPSigning hook that
// previously lived here was dead code for V2 blocks — the wrapper consults that
// field only on V1 paths (xdpos.go:724-729 and :963-970), and V2 blocks bypass
// those paths via verifyHeaderV2 → EngineV2.VerifyHeaderWithParents. The live
// V2 penalty hook is wired on adaptor.EngineV2.HookPenalty in eth/backend.go
// (~lines 315-433), matching v2.6.8 eth/hooks/engine_v2_hooks.go shape. Only the
// wrapper-level HookReward (consumed by the Finalize path) remains here.
func AttachConsensusV2Hooks(adaptor *XDPoS.XDPoS, bc *core.BlockChain, chainConfig *params.ChainConfig) {
	// HookReward distributes the reward for a checkpoint block:
	//   1. counts signing activity over the reward window (GetSigningTxCount),
	//   2. splits the inflation-adjusted chain reward across those signers,
	//   3. splits each signer's share among its reward holders
	//      (XDPoS.GetRewardBalancesRate, reading owner/voter data from
	//      parentState) and credits the holders in stateBlock.
	// The returned map carries "signers" and "rewards" for the caller.
	adaptor.HookReward = func(chain consensus.ChainHeaderReader, stateBlock *state.StateDB, parentState *state.StateDB, header *types.Header) (map[string]interface{}, error) {
		number := header.Number.Uint64()
		rCheckpoint := chainConfig.XDPoS.RewardCheckpoint
		if rCheckpoint == 0 {
			rCheckpoint = chainConfig.XDPoS.Epoch
		}
		foundationWalletAddr := chainConfig.XDPoS.FoudationWalletAddr
		if foundationWalletAddr == (common.Address{}) {
			log.Error("Foundation Wallet Address is empty", "error", foundationWalletAddr)
			return nil, errors.New("foundation wallet address is empty")
		}
		rewards := make(map[string]interface{})

		// Skip reward on the first V2 block — aligned with v2.6.8.
		// V2 config may be nil for pre-v2 only networks or early sync.
		if v2 := chainConfig.XDPoS.V2; v2 != nil && v2.SwitchBlock != nil && number == v2.SwitchBlock.Uint64()+1 {
			return rewards, nil
		}
		// Nothing to reward until a full reward checkpoint has passed
		// (number > rCheckpoint also implies number > 0).
		if number <= rCheckpoint {
			return rewards, nil
		}

		start := time.Now()

		// Chain reward with inflation applied.
		chainReward := new(big.Int).Mul(new(big.Int).SetUint64(chainConfig.XDPoS.Reward), new(big.Int).SetUint64(params.Ether))
		chainReward = util.RewardInflation(nil, chainReward, number, common.BlocksPerYear)

		totalSigner := new(uint64)
		signers, err := GetSigningTxCount(adaptor, bc, header, chainConfig, totalSigner)
		log.Debug("Time Get Signers", "block", number, "time", common.PrettyDuration(time.Since(start)))
		if err != nil {
			// log.Crit terminates the process: a reward block must never be
			// finalized from a partial signer count. The return only makes it
			// explicit that nothing below runs on error.
			log.Crit("Fail to get signers for reward checkpoint", "error", err)
			return nil, err
		}
		rewards["signers"] = signers

		rewardSigners, err := XDPoS.CalculateRewardForSigner(chainReward, signers, *totalSigner)
		if err != nil {
			log.Crit("Fail to calculate reward for signers", "error", err)
			return nil, err
		}

		// Reward each signer's holders.
		voterResults := make(map[common.Address]interface{})
		if len(signers) > 0 {
			for signer, calcReward := range rewardSigners {
				// Use parentState for reading owner/voter info (matches v2.6.8).
				holderRewards, err := XDPoS.GetRewardBalancesRate(foundationWalletAddr, parentState, signer, calcReward, number)
				if err != nil {
					log.Crit("Fail to calculate reward for holders.", "error", err)
					return nil, err
				}
				for holder, reward := range holderRewards {
					// v2.6.8 uses AddBalance without a tracing reason.
					// MustFromBig panics if reward does not fit in 256 bits.
					stateBlock.AddBalance(holder, uint256.MustFromBig(reward), tracing.BalanceChangeUnspecified)
					log.Debug("HookReward: AddBalance", "block", number, "holder", holder, "reward", reward)
				}
				voterResults[signer] = holderRewards
			}
		}
		rewards["rewards"] = voterResults
		log.Debug("Time Calculated HookReward", "block", number, "time", common.PrettyDuration(time.Since(start)))
		return rewards, nil
	}

	// Also wire the V2 engine's own reward hook. The wrapper's Finalize delegates
	// V2 blocks to EngineV2.Finalize (xdpos.go:1088), which checks x.HookReward
	// (engine.go:626) — *not* the wrapper's c.HookReward. Without this assignment,
	// V2 epoch-switch reward distribution is silently skipped at block 3600, 4500, …
	// (every V2 epoch boundary). Canonical: XDPoSChain eth/hooks/engine_v2_hooks.go:270.
	if adaptor.EngineV2 != nil {
		adaptor.EngineV2.SetHookReward(func(chain consensus.ChainReader, statedb *state.StateDB, parentState *state.StateDB, header *types.Header) (map[string]interface{}, error) {
			return adaptor.HookReward(chain, statedb, parentState, header)
		})
	}
}

// GetSigningTxCount counts, per masternode, the blocks it signed inside the
// reward window that ends at header. It adds the number of (masternode, block)
// credits to *totalSigner.
//
// The window depends on the consensus version of header:
//   - V1 (no V2 config, or number <= V2.SwitchBlock): with R = RewardCheckpoint
//     (falling back to Epoch), the window is [number-2R+1, number-R]. The
//     masternodes are read from the checkpoint header at number-2R.
//   - V2: the window is the blocks strictly between the two most recent
//     epoch-switch headers before header. The switch at V2.SwitchBlock+1 does
//     not count as a boundary. The masternodes are read from the older switch
//     header.
//
// Inside the window, a block is eligible when its number is a multiple of
// common.MergeSignRange, or when it predates TIP2019. Each masternode is
// credited at most once per eligible block. Signing txs for a block are
// included in later blocks, so every ancestor from number-1 down to the window
// start is scanned.
//
// An empty map is returned for genesis, for V1 blocks before the first full
// window, and for V2 blocks whose walk reaches genesis before finding two
// epoch switches. An error is returned when an ancestor header or block body
// on the scanned path cannot be loaded, or when epoch-switch detection fails.
func GetSigningTxCount(c *XDPoS.XDPoS, chain *core.BlockChain, header *types.Header, chainConfig *params.ChainConfig, totalSigner *uint64) (map[common.Address]*XDPoS.RewardLog, error) {
	number := header.Number.Uint64()
	signers := make(map[common.Address]*XDPoS.RewardLog)
	if number == 0 {
		return signers, nil
	}
	if totalSigner == nil {
		// Callers that only need the per-signer map may pass nil.
		totalSigner = new(uint64)
	}

	// Blocks up to and including V2.SwitchBlock are V1 blocks.
	isV1Block := true
	if v2 := chainConfig.XDPoS.V2; v2 != nil && v2.SwitchBlock != nil && number > v2.SwitchBlock.Uint64() {
		isV1Block = false
	}

	var (
		startBlockNumber, endBlockNumber uint64
		prevCheckpoint                   uint64 // V1 only: checkpoint that opened the window
		masternodes                      []common.Address
	)
	if isV1Block {
		rCheckpoint := chainConfig.XDPoS.RewardCheckpoint
		if rCheckpoint == 0 {
			rCheckpoint = chainConfig.XDPoS.Epoch
		}
		if rCheckpoint == 0 || number < 2*rCheckpoint {
			// No complete reward window exists yet, and number-2*rCheckpoint
			// would underflow.
			return signers, nil
		}
		// For block 1800 with R=900: prevCheckpoint=0, window 1..900.
		prevCheckpoint = number - 2*rCheckpoint
		startBlockNumber = prevCheckpoint + 1
		endBlockNumber = prevCheckpoint + rCheckpoint
	} else {
		var (
			found     bool
			windowErr error
		)
		startBlockNumber, endBlockNumber, masternodes, found, windowErr = signingWindowV2(c, chain, header, chainConfig)
		if windowErr != nil {
			return nil, windowErr
		}
		if !found {
			log.Debug("GetSigningTxCount: no complete V2 reward window", "number", number)
			return signers, nil
		}
	}

	log.Debug("GetSigningTxCount starting", "currentBlock", number, "scanFrom", startBlockNumber, "scanTo", endBlockNumber, "isV1", isV1Block)

	refs, hashes, oldest, err := collectSigningRefs(c, chain, chainConfig, header, startBlockNumber)
	if err != nil {
		return nil, err
	}

	if isV1Block {
		// Masternodes come from the checkpoint that *opened* the reward window —
		// prevCheckpoint = number - 2*rCheckpoint, NOT number - rCheckpoint.
		// V1 stores the *next* epoch's masternodes in each checkpoint's extraData,
		// so block 900 holds the masternodes for epoch 901..1800, not for the
		// epoch (1..900) we're rewarding at block 1800. Canonical:
		// XDPoSChain contracts/utils.go:357 — chain.GetHeader(parent, prevCheckpoint).
		// oldest is the header at startBlockNumber, so its parent is the
		// checkpoint on this block's own ancestry (not merely the canonical one).
		checkpoint := chain.GetHeader(oldest.ParentHash, prevCheckpoint)
		if checkpoint == nil {
			return nil, fmt.Errorf("GetSigningTxCount: missing checkpoint header at %d (%x)", prevCheckpoint, oldest.ParentHash)
		}
		masternodes = c.GetMasternodesFromCheckpointHeader(checkpoint, prevCheckpoint, chainConfig.XDPoS.Epoch)
	}

	tallyRewardSigners(chainConfig, startBlockNumber, endBlockNumber, masternodes, refs, hashes, signers, totalSigner)

	log.Info("Calculate reward at checkpoint", "startBlock", startBlockNumber, "endBlock", endBlockNumber, "totalSigners", *totalSigner, "uniqueSigners", len(signers))
	return signers, nil
}

// signingWindowV2 walks back from header over epoch-switch headers. It returns
// the block range strictly between the two most recent epoch switches before
// header, together with the masternodes recorded in the older switch header.
// The switch at V2.SwitchBlock+1 is not treated as a boundary. found is false
// when the walk reaches genesis before two boundaries are seen.
func signingWindowV2(c *XDPoS.XDPoS, chain *core.BlockChain, header *types.Header, chainConfig *params.ChainConfig) (start, end uint64, masternodes []common.Address, found bool, err error) {
	const signEpochCount, rewardEpochCount = 1, 2

	var switchBlockPlusOne uint64
	if v2 := chainConfig.XDPoS.V2; v2 != nil && v2.SwitchBlock != nil {
		switchBlockPlusOne = v2.SwitchBlock.Uint64() + 1
	}

	h := header
	epochCount := 0
	for n := header.Number.Uint64(); n > 0; n-- {
		i := n - 1
		parentHash := h.ParentHash
		if h = chain.GetHeader(parentHash, i); h == nil {
			return 0, 0, nil, false, fmt.Errorf("GetSigningTxCount: missing header at %d (%x)", i, parentHash)
		}
		isEpochSwitch, _, switchErr := c.IsEpochSwitch(h)
		if switchErr != nil {
			return 0, 0, nil, false, switchErr
		}
		if !isEpochSwitch || i == switchBlockPlusOne {
			continue
		}
		epochCount++
		switch epochCount {
		case signEpochCount:
			end = h.Number.Uint64() - 1
		case rewardEpochCount:
			// Dispatcher routes V1 headers to header.Extra parse and V2
			// headers to GetMasternodesFromEpochSwitchHeader. Required for
			// the first V2 reward block, whose 2-epoch walk-back lands on
			// a V1 checkpoint. Canonical: XDPoSChain XDPoS.go:449.
			masternodes = c.GetMasternodesFromCheckpointHeader(h, h.Number.Uint64(), chainConfig.XDPoS.Epoch)
			return h.Number.Uint64() + 1, end, masternodes, true, nil
		}
	}
	// Reached genesis without a complete window (canonical returns no signers).
	return 0, 0, nil, false, nil
}

// collectSigningRefs walks the ancestors of header from number-1 down to start
// (inclusive) and loads each block's signing transactions. They come from the
// engine cache; on a miss the block is read from the database and passed to
// CacheSigner (TIPSigning) or CacheData (pre-TIPSigning, with raw receipts).
//
// It returns:
//   - refs: each referenced block hash (the last 32 bytes of a signing tx's
//     data) mapped to the recovered senders;
//   - hashes: each visited block number mapped to its hash;
//   - oldest: the header at start, or header itself if nothing was visited.
func collectSigningRefs(c *XDPoS.XDPoS, chain *core.BlockChain, chainConfig *params.ChainConfig, header *types.Header, start uint64) (refs map[common.Hash][]common.Address, hashes map[uint64]common.Hash, oldest *types.Header, err error) {
	refs = make(map[common.Hash][]common.Address)
	hashes = make(map[uint64]common.Hash)

	h := header
	for n := header.Number.Uint64(); n > start; n-- {
		i := n - 1
		parentHash := h.ParentHash
		if h = chain.GetHeader(parentHash, i); h == nil {
			// A gap would silently under-count signers and corrupt rewards.
			return nil, nil, nil, fmt.Errorf("GetSigningTxCount: missing header at %d (%x)", i, parentHash)
		}
		// Header.Hash re-encodes and re-hashes on every call; compute it once.
		hash := h.Hash()
		hashes[i] = hash

		txs, cacheHit := c.GetCachedSigningTxs(hash)
		if !cacheHit {
			block := rawdb.ReadBlock(c.GetDb(), hash, i)
			if block == nil {
				return nil, nil, nil, fmt.Errorf("GetSigningTxCount: missing block body at %d (%x)", i, hash)
			}
			if chainConfig.IsTIPSigning(h.Number) {
				txs = c.CacheSigner(hash, block.Transactions())
			} else {
				receipts := rawdb.ReadRawReceipts(c.GetDb(), hash, i)
				txs = c.CacheData(h, block.Transactions(), receipts)
			}
		}
		log.Trace("GetSigningTxCount: scanned block", "number", i, "hash", hash, "signingTxs", len(txs), "cacheHit", cacheHit)

		// Recover senders with the signer rules in force at the scanned block.
		signer := types.MakeSigner(chainConfig, h.Number, h.Time)
		for _, tx := range txs {
			payload := tx.Data()
			if len(payload) < 36 {
				if len(payload) > 0 {
					log.Debug("GetSigningTxCount: signing tx data too short", "block", i, "dataLen", len(payload), "txHash", tx.Hash())
				}
				continue
			}
			from, senderErr := types.Sender(signer, tx)
			if senderErr != nil {
				log.Warn("GetSigningTxCount: cannot recover signing tx sender", "block", i, "txHash", tx.Hash(), "err", senderErr)
				continue
			}
			ref := common.BytesToHash(payload[len(payload)-32:])
			refs[ref] = append(refs[ref], from)
		}
	}
	return refs, hashes, h, nil
}

// tallyRewardSigners credits signers for every eligible block in [start, end].
// Eligible blocks are those at a multiple of common.MergeSignRange, plus every
// block before TIP2019. Each masternode that signed an eligible block gets one
// Sign for it (duplicates collapse), and *totalSigner grows by one per credit.
// Signers that are not in masternodes are ignored.
func tallyRewardSigners(chainConfig *params.ChainConfig, start, end uint64, masternodes []common.Address, refs map[common.Hash][]common.Address, hashes map[uint64]common.Hash, signers map[common.Address]*XDPoS.RewardLog, totalSigner *uint64) {
	isMasternode := make(map[common.Address]struct{}, len(masternodes))
	for _, mn := range masternodes {
		isMasternode[mn] = struct{}{}
	}

	for i := start; i <= end; i++ {
		if i%common.MergeSignRange != 0 && chainConfig.IsTIP2019(new(big.Int).SetUint64(i)) {
			continue
		}
		hash, ok := hashes[i]
		if !ok {
			continue
		}
		addrs := refs[hash]
		if len(addrs) == 0 {
			continue
		}
		credited := make(map[common.Address]struct{}, len(addrs))
		for _, addr := range addrs {
			if _, ok := isMasternode[addr]; ok {
				credited[addr] = struct{}{}
			}
		}
		if len(credited) == 0 {
			log.Debug("GetSigningTxCount: no masternode among block signers", "block", i, "addrs", len(addrs), "masternodes", len(masternodes))
			continue
		}
		for addr := range credited {
			if rl, ok := signers[addr]; ok {
				rl.Sign++
			} else {
				signers[addr] = &XDPoS.RewardLog{Sign: 1, Reward: new(big.Int)}
			}
			*totalSigner++
		}
	}
}
