package core

import (
	"math/big"
	"testing"

	"github.com/ethereum/go-ethereum/common"
	"github.com/ethereum/go-ethereum/core/rawdb"
)

// TestXdcStateRootCache_BasicOperations caches one remote→local state-root
// mapping and checks it can be read back by block number and by remote root,
// first while the in-memory maps are populated and then again after all of
// those maps have been reset (exercising the database fallback).
func TestXdcStateRootCache_BasicOperations(t *testing.T) {
	db := rawdb.NewMemoryDatabase()

	// Reset cache state: empty maps backed by a fresh in-memory database.
	xdcStateRootCache.Lock()
	xdcStateRootCache.blockRoots = make(map[uint64]common.Hash)
	xdcStateRootCache.remoteToLocal = make(map[common.Hash]common.Hash)
	xdcStateRootCache.blockToRemote = make(map[uint64]common.Hash)
	xdcStateRootCache.db = db
	xdcStateRootCache.initialized = true
	xdcStateRootCache.Unlock()

	// Test caching state roots
	block1 := uint64(1800)
	remote1 := common.HexToHash("0x1111111111111111111111111111111111111111111111111111111111111111")
	local1 := common.HexToHash("0x2222222222222222222222222222222222222222222222222222222222222222")

	if err := XdcCacheStateRoot(block1, local1, remote1); err != nil {
		t.Fatal(err)
	}

	// Verify retrieval by block number
	if root, ok := XdcGetCachedStateRoot(block1); !ok || root != local1 {
		t.Errorf("Failed to retrieve cached root by block number: got %v, want %v", root, local1)
	}

	// Verify retrieval by remote root
	if root, ok := XdcFindCachedRootForRemote(remote1); !ok || root != local1 {
		t.Errorf("Failed to retrieve cached root by remote root: got %v, want %v", root, local1)
	}

	// Verify DB fallback by resetting every in-memory map and reading again.
	// blockToRemote is reset as well, matching the persistence test.
	// Otherwise a lookup could still be answered from the map populated above
	// instead of from the database.
	xdcStateRootCache.Lock()
	xdcStateRootCache.blockRoots = make(map[uint64]common.Hash)
	xdcStateRootCache.remoteToLocal = make(map[common.Hash]common.Hash)
	xdcStateRootCache.blockToRemote = make(map[uint64]common.Hash)
	xdcStateRootCache.Unlock()

	if root, ok := XdcGetCachedStateRoot(block1); !ok || root != local1 {
		t.Errorf("DB fallback failed for block root: got %v, want %v", root, local1)
	}
	if root, ok := XdcFindCachedRootForRemote(remote1); !ok || root != local1 {
		t.Errorf("DB fallback failed for remote mapping: got %v, want %v", root, local1)
	}
}

// TestXdcStateRootCache_SkipIdenticalRoots asserts that when the local and
// remote roots handed to XdcCacheStateRoot are the same hash, no entry becomes
// retrievable by block number afterwards.
func TestXdcStateRootCache_SkipIdenticalRoots(t *testing.T) {
	memDB := rawdb.NewMemoryDatabase()

	// Start from empty maps and a fresh in-memory database.
	xdcStateRootCache.Lock()
	xdcStateRootCache.blockRoots = make(map[uint64]common.Hash)
	xdcStateRootCache.remoteToLocal = make(map[common.Hash]common.Hash)
	xdcStateRootCache.blockToRemote = make(map[uint64]common.Hash)
	xdcStateRootCache.db = memDB
	xdcStateRootCache.initialized = true
	xdcStateRootCache.Unlock()

	const blockNum uint64 = 100
	matching := common.HexToHash("0x1111111111111111111111111111111111111111111111111111111111111111")

	// Local == remote: nothing should be recorded for this block.
	if err := XdcCacheStateRoot(blockNum, matching, matching); err != nil {
		t.Fatal(err)
	}

	_, cached := XdcGetCachedStateRoot(blockNum)
	if cached {
		t.Error("Identical roots should not be cached")
	}
}

// TestXdcStateRootCache_Persistence stores several remote→local mappings,
// discards all in-memory maps, and then expects every mapping to still be
// readable by block number and by remote root through the database.
func TestXdcStateRootCache_Persistence(t *testing.T) {
	memDB := rawdb.NewMemoryDatabase()

	// Start from empty maps and a fresh in-memory database.
	xdcStateRootCache.Lock()
	xdcStateRootCache.blockRoots = make(map[uint64]common.Hash)
	xdcStateRootCache.remoteToLocal = make(map[common.Hash]common.Hash)
	xdcStateRootCache.blockToRemote = make(map[uint64]common.Hash)
	xdcStateRootCache.db = memDB
	xdcStateRootCache.initialized = true
	xdcStateRootCache.Unlock()

	type rootEntry struct {
		number     uint64
		remoteRoot common.Hash
		localRoot  common.Hash
	}
	fixtures := []rootEntry{
		{number: 1800, remoteRoot: common.HexToHash("0x1111"), localRoot: common.HexToHash("0x2222")},
		{number: 2700, remoteRoot: common.HexToHash("0x3333"), localRoot: common.HexToHash("0x4444")},
		{number: 3600, remoteRoot: common.HexToHash("0x5555"), localRoot: common.HexToHash("0x6666")},
	}

	for i := range fixtures {
		f := fixtures[i]
		if err := XdcCacheStateRoot(f.number, f.localRoot, f.remoteRoot); err != nil {
			t.Fatal(err)
		}
	}

	// Throw away the in-memory view so the lookups below must go to the DB.
	xdcStateRootCache.Lock()
	xdcStateRootCache.blockRoots = make(map[uint64]common.Hash)
	xdcStateRootCache.remoteToLocal = make(map[common.Hash]common.Hash)
	xdcStateRootCache.blockToRemote = make(map[uint64]common.Hash)
	xdcStateRootCache.Unlock()

	for i := range fixtures {
		f := fixtures[i]

		got, ok := XdcGetCachedStateRoot(f.number)
		if !ok || got != f.localRoot {
			t.Errorf("Failed to load entry for block %d: got %v, want %v", f.number, got, f.localRoot)
		}

		got, ok = XdcFindCachedRootForRemote(f.remoteRoot)
		if !ok || got != f.localRoot {
			t.Errorf("Failed to load remote→local mapping for block %d: got %v, want %v", f.number, got, f.localRoot)
		}
	}
}

// TestXdcStateRootCache_BackwardScan caches roots at blocks 1800, 2700, 3600
// and 4500, then checks XdcBackwardScanForValidRoot.
//
// Scanning backward from a start block should report the nearest cached block
// below it along with that block's cached local root. A scan whose range
// contains no cached block should report not-found. The second argument
// appears to bound how far back the scan looks (5000/2000 reaches 4500,
// 1500/500 reaches nothing); NOTE(review): confirm against the implementation.
func TestXdcStateRootCache_BackwardScan(t *testing.T) {
	db := rawdb.NewMemoryDatabase()

	xdcStateRootCache.Lock()
	xdcStateRootCache.blockRoots = make(map[uint64]common.Hash)
	xdcStateRootCache.remoteToLocal = make(map[common.Hash]common.Hash)
	xdcStateRootCache.blockToRemote = make(map[uint64]common.Hash)
	xdcStateRootCache.db = db
	xdcStateRootCache.initialized = true
	xdcStateRootCache.Unlock()

	// Every cached block uses the same remote→local pair, so these are built
	// once rather than on each loop iteration.
	remote := common.HexToHash("0x1111")
	local := common.HexToHash("0x2222")

	// Add some cached entries
	cachedBlocks := []uint64{1800, 2700, 3600, 4500}
	for _, block := range cachedBlocks {
		if err := XdcCacheStateRoot(block, local, remote); err != nil {
			t.Fatal(err)
		}
	}

	// Scan backward from block 5000 - should find 4500 with its local root.
	block, root, found := XdcBackwardScanForValidRoot(5000, 2000)
	switch {
	case !found:
		t.Error("Backward scan should have found cached block")
	case block != 4500:
		t.Errorf("Backward scan found wrong block: got %d, want 4500", block)
	case root == (common.Hash{}):
		t.Error("Backward scan returned empty root")
	case root != local:
		t.Errorf("Backward scan returned wrong root: got %v, want %v", root, local)
	}

	// Scan from block 3000 - should find 2700 with its local root.
	block, root, found = XdcBackwardScanForValidRoot(3000, 1500)
	switch {
	case !found:
		t.Error("Backward scan should have found cached block")
	case block != 2700:
		t.Errorf("Backward scan found wrong block: got %d, want 2700", block)
	case root != local:
		t.Errorf("Backward scan returned wrong root: got %v, want %v", root, local)
	}

	// Scan with no results
	if _, _, found := XdcBackwardScanForValidRoot(1500, 500); found {
		t.Error("Backward scan should not have found any cached block")
	}
}

// TestXdcStateRootCache_Eviction caches blocks 1..totalBlocks and calls
// evictOldest(evictCount) directly. It then checks three things:
//   - the in-memory map shrinks by evictCount;
//   - exactly blocks 1..evictCount leave memory while every later block stays;
//   - an evicted block is still readable via the database fallback.
func TestXdcStateRootCache_Eviction(t *testing.T) {
	const (
		totalBlocks = 150
		evictCount  = 50
	)

	db := rawdb.NewMemoryDatabase()

	xdcStateRootCache.Lock()
	xdcStateRootCache.blockRoots = make(map[uint64]common.Hash)
	xdcStateRootCache.remoteToLocal = make(map[common.Hash]common.Hash)
	xdcStateRootCache.blockToRemote = make(map[uint64]common.Hash)
	xdcStateRootCache.db = db
	xdcStateRootCache.initialized = true
	xdcStateRootCache.Unlock()

	// All blocks share one remote→local pair, so build the hashes once.
	remote := common.BigToHash(common.Big1)
	local := common.BigToHash(common.Big2)
	for i := uint64(1); i <= totalBlocks; i++ {
		if err := XdcCacheStateRoot(i, local, remote); err != nil {
			t.Fatal(err)
		}
	}

	// Manually trigger eviction (normally happens at xdcStateRootCacheSize).
	// evictOldest is invoked with the write lock held; presumably it expects
	// the caller to hold it (NOTE(review): confirm in the implementation).
	xdcStateRootCache.Lock()
	evictOldest(evictCount)
	count := len(xdcStateRootCache.blockRoots)
	xdcStateRootCache.Unlock()

	if want := totalBlocks - evictCount; count != want {
		t.Errorf("After eviction, cache should have %d entries, got %d", want, count)
	}

	// Check every block, not just the range endpoints. Blocks 1..evictCount
	// must be gone from memory (DB fallback will still find them); blocks
	// after that must remain. Collect results under the read lock and report
	// them afterwards.
	var notEvicted, lost []uint64
	xdcStateRootCache.RLock()
	for b := uint64(1); b <= totalBlocks; b++ {
		_, inMem := xdcStateRootCache.blockRoots[b]
		switch {
		case b <= evictCount && inMem:
			notEvicted = append(notEvicted, b)
		case b > evictCount && !inMem:
			lost = append(lost, b)
		}
	}
	xdcStateRootCache.RUnlock()

	if len(notEvicted) > 0 {
		t.Errorf("Blocks %v should have been evicted from memory", notEvicted)
	}
	if len(lost) > 0 {
		t.Errorf("Blocks %v should still be cached in memory", lost)
	}

	// Verify DB fallback still works for evicted entries
	if root, ok := XdcGetCachedStateRoot(1); !ok {
		t.Error("Block 1 should still be available via DB fallback")
	} else if root != local {
		t.Errorf("DB fallback returned wrong root for block 1: got %v", root)
	}
}

// TestXdcStateRootCache_Stats caches ten blocks (1800..1809), each with a
// distinct remote root and a shared local root. It then expects
// XdcCacheStats to report 10 for both the "blockRoots" and "remoteToLocal"
// counters.
func TestXdcStateRootCache_Stats(t *testing.T) {
	memDB := rawdb.NewMemoryDatabase()

	// Start from empty maps and a fresh in-memory database.
	xdcStateRootCache.Lock()
	xdcStateRootCache.blockRoots = make(map[uint64]common.Hash)
	xdcStateRootCache.remoteToLocal = make(map[common.Hash]common.Hash)
	xdcStateRootCache.blockToRemote = make(map[uint64]common.Hash)
	xdcStateRootCache.db = memDB
	xdcStateRootCache.initialized = true
	xdcStateRootCache.Unlock()

	sharedLocal := common.HexToHash("0x2222")
	for n := int64(1800); n < 1810; n++ {
		// The remote root is derived from the block number so each is unique.
		if err := XdcCacheStateRoot(uint64(n), sharedLocal, common.BigToHash(big.NewInt(n))); err != nil {
			t.Fatal(err)
		}
	}

	stats := XdcCacheStats()

	expectTen := func(name string, got interface{}) {
		t.Helper()
		if v, ok := got.(int); !ok || v != 10 {
			t.Errorf("Stats %s incorrect: got %v, want 10", name, got)
		}
	}
	expectTen("blockRoots", stats["blockRoots"])
	expectTen("remoteToLocal", stats["remoteToLocal"])
}
