diff --git a/core/state/snapshot/difflayer.go b/core/state/snapshot/difflayer.go index c7a65e2a4b..0f7a4223fb 100644 --- a/core/state/snapshot/difflayer.go +++ b/core/state/snapshot/difflayer.go @@ -22,8 +22,6 @@ import ( "sync" "github.com/ethereum/go-ethereum/common" - "github.com/ethereum/go-ethereum/core/rawdb" - "github.com/ethereum/go-ethereum/ethdb" "github.com/ethereum/go-ethereum/log" "github.com/ethereum/go-ethereum/rlp" ) @@ -169,125 +167,6 @@ func (dl *diffLayer) Update(blockRoot common.Hash, accounts map[common.Hash][]by return newDiffLayer(dl, dl.number+1, blockRoot, accounts, storage) } -// Cap traverses downwards the diff tree until the number of allowed layers are -// crossed. All diffs beyond the permitted number are flattened downwards. If -// the layer limit is reached, memory cap is also enforced (but not before). The -// block numbers for the disk layer and first diff layer are returned for GC. -func (dl *diffLayer) Cap(layers int, memory uint64) (uint64, uint64) { - // Dive until we run out of layers or reach the persistent database - if layers > 2 { - // If we still have diff layers below, recurse - if parent, ok := dl.parent.(*diffLayer); ok { - return parent.Cap(layers-1, memory) - } - // Diff stack too shallow, return block numbers without modifications - return dl.parent.(*diskLayer).number, dl.number - } - // We're out of layers, flatten anything below, stopping if it's the disk or if - // the memory limit is not yet exceeded. - switch parent := dl.parent.(type) { - case *diskLayer: - return parent.number, dl.number - case *diffLayer: - // Flatten the parent into the grandparent. The flattening internally obtains a - // write lock on grandparent. - flattened := parent.flatten().(*diffLayer) - - dl.lock.Lock() - defer dl.lock.Unlock() - - dl.parent = flattened - if flattened.memory < memory { - diskNumber, _ := flattened.parent.Info() - return diskNumber, flattened.number - } - default: - panic(fmt.Sprintf("unknown data layer: %T", parent)) - } - // If the bottommost layer is larger than our memory cap, persist to disk - var ( - parent = dl.parent.(*diffLayer) - base = parent.parent.(*diskLayer) - batch = base.db.NewBatch() - ) - parent.lock.RLock() - defer parent.lock.RUnlock() - - // Start by temporarily deleting the current snapshot block marker. This - // ensures that in the case of a crash, the entire snapshot is invalidated. - rawdb.DeleteSnapshotBlock(batch) - - // Mark the original base as stale as we're going to create a new wrapper - base.lock.Lock() - if base.stale { - panic("parent disk layer is stale") // we've committed into the same base from two children, boo - } - base.stale = true - base.lock.Unlock() - - // Push all the accounts into the database - for hash, data := range parent.accountData { - if len(data) > 0 { - // Account was updated, push to disk - rawdb.WriteAccountSnapshot(batch, hash, data) - base.cache.Set(string(hash[:]), data) - - if batch.ValueSize() > ethdb.IdealBatchSize { - if err := batch.Write(); err != nil { - log.Crit("Failed to write account snapshot", "err", err) - } - batch.Reset() - } - } else { - // Account was deleted, remove all storage slots too - rawdb.DeleteAccountSnapshot(batch, hash) - base.cache.Set(string(hash[:]), nil) - - it := rawdb.IterateStorageSnapshots(base.db, hash) - for it.Next() { - if key := it.Key(); len(key) == 65 { // TODO(karalabe): Yuck, we should move this into the iterator - batch.Delete(key) - base.cache.Delete(string(key[1:])) - } - } - it.Release() - } - } - // Push all the storage slots into the database - for accountHash, storage := range parent.storageData { - for storageHash, data := range storage { - if len(data) > 0 { - rawdb.WriteStorageSnapshot(batch, accountHash, storageHash, data) - base.cache.Set(string(append(accountHash[:], storageHash[:]...)), data) - } else { - rawdb.DeleteStorageSnapshot(batch, accountHash, storageHash) - base.cache.Set(string(append(accountHash[:], storageHash[:]...)), nil) - } - } - if batch.ValueSize() > ethdb.IdealBatchSize { - if err := batch.Write(); err != nil { - log.Crit("Failed to write storage snapshot", "err", err) - } - batch.Reset() - } - } - // Update the snapshot block marker and write any remainder data - newBase := &diskLayer{ - root: parent.root, - number: parent.number, - cache: base.cache, - db: base.db, - journal: base.journal, - } - rawdb.WriteSnapshotBlock(batch, newBase.number, newBase.root) - if err := batch.Write(); err != nil { - log.Crit("Failed to write leftover snapshot", "err", err) - } - dl.parent = newBase - - return newBase.number, dl.number -} - // flatten pushes all data from this point downwards, flattening everything into // a single diff at the bottom. Since usually the lowermost diff is the largest, // the flattening bulds up from there in reverse. diff --git a/core/state/snapshot/difflayer_test.go b/core/state/snapshot/difflayer_test.go index 5a718c6171..5499f20161 100644 --- a/core/state/snapshot/difflayer_test.go +++ b/core/state/snapshot/difflayer_test.go @@ -18,15 +18,11 @@ package snapshot import ( "bytes" - "fmt" "math/big" "math/rand" "testing" - "time" - "github.com/allegro/bigcache" "github.com/ethereum/go-ethereum/common" - "github.com/ethereum/go-ethereum/core/rawdb" "github.com/ethereum/go-ethereum/rlp" ) @@ -192,113 +188,12 @@ func TestInsertAndMerge(t *testing.T) { } } -// TestCapTree tests some functionality regarding capping/flattening -func TestCapTree(t *testing.T) { - - var ( - storage = make(map[common.Hash]map[common.Hash][]byte) - ) - setAccount := func(accKey string) map[common.Hash][]byte { - return map[common.Hash][]byte{ - common.HexToHash(accKey): randomAccount(), - } - } - // the bottom-most layer, aside from the 'disk layer' - cache, _ := bigcache.NewBigCache(bigcache.Config{ // TODO(karalabe): dedup - Shards: 1, - LifeWindow: time.Hour, - MaxEntriesInWindow: 1 * 1024, - MaxEntrySize: 1, - HardMaxCacheSize: 1, - }) - - base := &diskLayer{ - journal: "", - db: rawdb.NewMemoryDatabase(), - cache: cache, - number: 0, - root: common.HexToHash("0x01"), - } - // The lowest difflayer - a1 := base.Update(common.HexToHash("0xa1"), setAccount("0xa1"), storage) - - a2 := a1.Update(common.HexToHash("0xa2"), setAccount("0xa2"), storage) - b2 := a1.Update(common.HexToHash("0xb2"), setAccount("0xb2"), storage) - - a3 := a2.Update(common.HexToHash("0xa3"), setAccount("0xa3"), storage) - b3 := b2.Update(common.HexToHash("0xb3"), setAccount("0xb3"), storage) - - checkExist := func(layer *diffLayer, key string) error { - accountKey := common.HexToHash(key) - data, _ := layer.Account(accountKey) - if data == nil { - return fmt.Errorf("expected %x to exist, got nil", accountKey) - } - return nil - } - shouldErr := func(layer *diffLayer, key string) error { - accountKey := common.HexToHash(key) - data, err := layer.Account(accountKey) - if err == nil { - return fmt.Errorf("expected error, got data %x", data) - } - return nil - } - - // check basics - if err := checkExist(b3, "0xa1"); err != nil { - t.Error(err) - } - if err := checkExist(b3, "0xb2"); err != nil { - t.Error(err) - } - if err := checkExist(b3, "0xb3"); err != nil { - t.Error(err) - } - // Now, merge the a-chain - diskNum, diffNum := a3.Cap(0, 1024) - if diskNum != 0 { - t.Errorf("disk layer err, got %d exp %d", diskNum, 0) - } - if diffNum != 2 { - t.Errorf("diff layer err, got %d exp %d", diffNum, 2) - } - // At this point, a2 got merged into a1. Thus, a1 is now modified, - // and as a1 is the parent of b2, b2 should no longer be able to iterate into parent - - // These should still be accessible - if err := checkExist(b3, "0xb2"); err != nil { - t.Error(err) - } - if err := checkExist(b3, "0xb3"); err != nil { - t.Error(err) - } - //b2ParentNum, _ := b2.parent.Info() - //if b2.parent.invalid == false - // t.Errorf("err, exp parent to be invalid, got %v", b2.parent, b2ParentNum) - //} - // But these would need iteration into the modified parent: - if err := shouldErr(b3, "0xa1"); err != nil { - t.Error(err) - } - if err := shouldErr(b3, "0xa2"); err != nil { - t.Error(err) - } - if err := shouldErr(b3, "0xa3"); err != nil { - t.Error(err) - } -} - type emptyLayer struct{} func (emptyLayer) Update(blockRoot common.Hash, accounts map[common.Hash][]byte, storage map[common.Hash]map[common.Hash][]byte) *diffLayer { panic("implement me") } -func (emptyLayer) Cap(layers int, memory uint64) (uint64, uint64) { - panic("implement me") -} - func (emptyLayer) Journal() error { panic("implement me") } @@ -403,7 +298,6 @@ func BenchmarkSearchSlot(b *testing.B) { // Without sorting and tracking accountlist // BenchmarkFlatten-6 300 5511511 ns/op func BenchmarkFlatten(b *testing.B) { - fill := func(parent snapshot, blocknum int) *diffLayer { accounts := make(map[common.Hash][]byte) storage := make(map[common.Hash]map[common.Hash][]byte) diff --git a/core/state/snapshot/disklayer.go b/core/state/snapshot/disklayer.go index a9839f01a8..50321f1540 100644 --- a/core/state/snapshot/disklayer.go +++ b/core/state/snapshot/disklayer.go @@ -126,12 +126,6 @@ func (dl *diskLayer) Update(blockHash common.Hash, accounts map[common.Hash][]by return newDiffLayer(dl, dl.number+1, blockHash, accounts, storage) } -// Cap traverses downwards the diff tree until the number of allowed layers are -// crossed. All diffs beyond the permitted number are flattened downwards. -func (dl *diskLayer) Cap(layers int, memory uint64) (uint64, uint64) { - return dl.number, dl.number -} - // Journal commits an entire diff hierarchy to disk into a single journal file. func (dl *diskLayer) Journal() error { // There's no journalling a disk layer diff --git a/core/state/snapshot/snapshot.go b/core/state/snapshot/snapshot.go index 6a21d57dcb..a181789779 100644 --- a/core/state/snapshot/snapshot.go +++ b/core/state/snapshot/snapshot.go @@ -73,11 +73,6 @@ type snapshot interface { // copying everything. Update(blockRoot common.Hash, accounts map[common.Hash][]byte, storage map[common.Hash]map[common.Hash][]byte) *diffLayer - // Cap traverses downwards the diff tree until the number of allowed layers are - // crossed. All diffs beyond the permitted number are flattened downwards. The - // block numbers for the disk layer and first diff layer are returned for GC. - Cap(layers int, memory uint64) (uint64, uint64) - // Journal commits an entire diff hierarchy to disk into a single journal file. // This is meant to be used during shutdown to persist the snapshot without // flattening everything down (bad for reorgs). @@ -169,11 +164,56 @@ func (st *SnapshotTree) Cap(blockRoot common.Hash, layers int, memory uint64) er if snap == nil { return fmt.Errorf("snapshot [%#x] missing", blockRoot) } + diff, ok := snap.(*diffLayer) + if !ok { + return fmt.Errorf("snapshot [%#x] is base layer", blockRoot) + } // Run the internal capping and discard all stale layers st.lock.Lock() defer st.lock.Unlock() - diskNumber, diffNumber := snap.Cap(layers, memory) + var ( + diskNumber uint64 + diffNumber uint64 + ) + // Flattening the bottom-most diff layer requires special casing since there's + // no child to rewire to the grandparent. In that case we can fake a temporary + // child for the capping and then remove it. + switch layers { + case 0: + // If full commit was requested, flatten the diffs and merge onto disk + diff.lock.RLock() + base := diffToDisk(diff.flatten().(*diffLayer)) + diff.lock.RUnlock() + + st.layers[base.root] = base + diskNumber, diffNumber = base.number, base.number + + case 1: + // If full flattening was requested, flatten the diffs but only merge if the + // memory limit was reached + var ( + bottom *diffLayer + base *diskLayer + ) + diff.lock.RLock() + bottom = diff.flatten().(*diffLayer) + if bottom.memory >= memory { + base = diffToDisk(bottom) + } + diff.lock.RUnlock() + + if base != nil { + st.layers[base.root] = base + diskNumber, diffNumber = base.number, base.number + } else { + st.layers[bottom.root] = bottom + diskNumber, diffNumber = bottom.parent.(*diskLayer).number, bottom.number + } + + default: + diskNumber, diffNumber = st.cap(diff, layers, memory) + } for root, snap := range st.layers { if number, _ := snap.Info(); number != diskNumber && number < diffNumber { delete(st.layers, root) @@ -182,6 +222,135 @@ func (st *SnapshotTree) Cap(blockRoot common.Hash, layers int, memory uint64) er return nil } +// cap traverses downwards the diff tree until the number of allowed layers are +// crossed. All diffs beyond the permitted number are flattened downwards. If +// the layer limit is reached, memory cap is also enforced (but not before). The +// block numbers for the disk layer and first diff layer are returned for GC. +func (st *SnapshotTree) cap(diff *diffLayer, layers int, memory uint64) (uint64, uint64) { + // Dive until we run out of layers or reach the persistent database + if layers > 2 { + // If we still have diff layers below, recurse + if parent, ok := diff.parent.(*diffLayer); ok { + return st.cap(parent, layers-1, memory) + } + // Diff stack too shallow, return block numbers without modifications + return diff.parent.(*diskLayer).number, diff.number + } + // We're out of layers, flatten anything below, stopping if it's the disk or if + // the memory limit is not yet exceeded. + switch parent := diff.parent.(type) { + case *diskLayer: + return parent.number, diff.number + + case *diffLayer: + // Flatten the parent into the grandparent. The flattening internally obtains a + // write lock on grandparent. + flattened := parent.flatten().(*diffLayer) + st.layers[flattened.root] = flattened + + diff.lock.Lock() + defer diff.lock.Unlock() + + diff.parent = flattened + if flattened.memory < memory { + diskNumber, _ := flattened.parent.Info() + return diskNumber, flattened.number + } + default: + panic(fmt.Sprintf("unknown data layer: %T", parent)) + } + // If the bottom-most layer is larger than our memory cap, persist to disk + bottom := diff.parent.(*diffLayer) + + bottom.lock.RLock() + base := diffToDisk(bottom) + bottom.lock.RUnlock() + + st.layers[base.root] = base + diff.parent = base + + return base.number, diff.number +} + +// diffToDisk merges a bottom-most diff into the persistent disk layer underneath +// it. The method will panic if called onto a non-bottom-most diff layer. +func diffToDisk(bottom *diffLayer) *diskLayer { + var ( + base = bottom.parent.(*diskLayer) + batch = base.db.NewBatch() + ) + // Start by temporarily deleting the current snapshot block marker. This + // ensures that in the case of a crash, the entire snapshot is invalidated. + rawdb.DeleteSnapshotBlock(batch) + + // Mark the original base as stale as we're going to create a new wrapper + base.lock.Lock() + if base.stale { + panic("parent disk layer is stale") // we've committed into the same base from two children, boo + } + base.stale = true + base.lock.Unlock() + + // Push all the accounts into the database + for hash, data := range bottom.accountData { + if len(data) > 0 { + // Account was updated, push to disk + rawdb.WriteAccountSnapshot(batch, hash, data) + base.cache.Set(string(hash[:]), data) + + if batch.ValueSize() > ethdb.IdealBatchSize { + if err := batch.Write(); err != nil { + log.Crit("Failed to write account snapshot", "err", err) + } + batch.Reset() + } + } else { + // Account was deleted, remove all storage slots too + rawdb.DeleteAccountSnapshot(batch, hash) + base.cache.Set(string(hash[:]), nil) + + it := rawdb.IterateStorageSnapshots(base.db, hash) + for it.Next() { + if key := it.Key(); len(key) == 65 { // TODO(karalabe): Yuck, we should move this into the iterator + batch.Delete(key) + base.cache.Delete(string(key[1:])) + } + } + it.Release() + } + } + // Push all the storage slots into the database + for accountHash, storage := range bottom.storageData { + for storageHash, data := range storage { + if len(data) > 0 { + rawdb.WriteStorageSnapshot(batch, accountHash, storageHash, data) + base.cache.Set(string(append(accountHash[:], storageHash[:]...)), data) + } else { + rawdb.DeleteStorageSnapshot(batch, accountHash, storageHash) + base.cache.Set(string(append(accountHash[:], storageHash[:]...)), nil) + } + } + if batch.ValueSize() > ethdb.IdealBatchSize { + if err := batch.Write(); err != nil { + log.Crit("Failed to write storage snapshot", "err", err) + } + batch.Reset() + } + } + // Update the snapshot block marker and write any remainder data + rawdb.WriteSnapshotBlock(batch, bottom.number, bottom.root) + if err := batch.Write(); err != nil { + log.Crit("Failed to write leftover snapshot", "err", err) + } + return &diskLayer{ + root: bottom.root, + number: bottom.number, + cache: base.cache, + db: base.db, + journal: base.journal, + } +} + // Journal commits an entire diff hierarchy to disk into a single journal file. // This is meant to be used during shutdown to persist the snapshot without // flattening everything down (bad for reorgs). diff --git a/core/state/snapshot/snapshot_test.go b/core/state/snapshot/snapshot_test.go index 903bd4a6f6..ecd39bf3e8 100644 --- a/core/state/snapshot/snapshot_test.go +++ b/core/state/snapshot/snapshot_test.go @@ -15,3 +15,289 @@ // along with the go-ethereum library. If not, see . package snapshot + +import ( + "fmt" + "testing" + "time" + + "github.com/allegro/bigcache" + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/core/rawdb" +) + +// Tests that if a disk layer becomes stale, no active external references will +// be returned with junk data. This version of the test flattens every diff layer +// to check internal corner case around the bottom-most memory accumulator. +func TestDiskLayerExternalInvalidationFullFlatten(t *testing.T) { + // Create an empty base layer and a snapshot tree out of it + cache, _ := bigcache.NewBigCache(bigcache.DefaultConfig(time.Minute)) + base := &diskLayer{ + db: rawdb.NewMemoryDatabase(), + root: common.HexToHash("0x01"), + cache: cache, + } + snaps := &SnapshotTree{ + layers: map[common.Hash]snapshot{ + base.root: base, + }, + } + // Retrieve a reference to the base and commit a diff on top + ref := snaps.Snapshot(base.root) + + accounts := map[common.Hash][]byte{ + common.HexToHash("0xa1"): randomAccount(), + } + storage := make(map[common.Hash]map[common.Hash][]byte) + if err := snaps.Update(common.HexToHash("0x02"), common.HexToHash("0x01"), accounts, storage); err != nil { + t.Fatalf("failed to create a diff layer: %v", err) + } + if n := len(snaps.layers); n != 2 { + t.Errorf("pre-cap layer count mismatch: have %d, want %d", n, 2) + } + // Commit the diff layer onto the disk and ensure it's persisted + if err := snaps.Cap(common.HexToHash("0x02"), 0, 0); err != nil { + t.Fatalf("failed to merge diff layer onto disk: %v", err) + } + // Since the base layer was modified, ensure that data retrievald on the external reference fail + if acc, err := ref.Account(common.HexToHash("0x01")); err != ErrSnapshotStale { + t.Errorf("stale reference returned account: %#x (err: %v)", acc, err) + } + if slot, err := ref.Storage(common.HexToHash("0xa1"), common.HexToHash("0xb1")); err != ErrSnapshotStale { + t.Errorf("stale reference returned storage slot: %#x (err: %v)", slot, err) + } + if n := len(snaps.layers); n != 1 { + t.Errorf("post-cap layer count mismatch: have %d, want %d", n, 1) + fmt.Println(snaps.layers) + } +} + +// Tests that if a disk layer becomes stale, no active external references will +// be returned with junk data. This version of the test retains the bottom diff +// layer to check the usual mode of operation where the accumulator is retained. +func TestDiskLayerExternalInvalidationPartialFlatten(t *testing.T) { + // Create an empty base layer and a snapshot tree out of it + cache, _ := bigcache.NewBigCache(bigcache.DefaultConfig(time.Minute)) + base := &diskLayer{ + db: rawdb.NewMemoryDatabase(), + root: common.HexToHash("0x01"), + cache: cache, + } + snaps := &SnapshotTree{ + layers: map[common.Hash]snapshot{ + base.root: base, + }, + } + // Retrieve a reference to the base and commit two diffs on top + ref := snaps.Snapshot(base.root) + + accounts := map[common.Hash][]byte{ + common.HexToHash("0xa1"): randomAccount(), + } + storage := make(map[common.Hash]map[common.Hash][]byte) + if err := snaps.Update(common.HexToHash("0x02"), common.HexToHash("0x01"), accounts, storage); err != nil { + t.Fatalf("failed to create a diff layer: %v", err) + } + if err := snaps.Update(common.HexToHash("0x03"), common.HexToHash("0x02"), accounts, storage); err != nil { + t.Fatalf("failed to create a diff layer: %v", err) + } + if n := len(snaps.layers); n != 3 { + t.Errorf("pre-cap layer count mismatch: have %d, want %d", n, 3) + } + // Commit the diff layer onto the disk and ensure it's persisted + if err := snaps.Cap(common.HexToHash("0x03"), 2, 0); err != nil { + t.Fatalf("failed to merge diff layer onto disk: %v", err) + } + // Since the base layer was modified, ensure that data retrievald on the external reference fail + if acc, err := ref.Account(common.HexToHash("0x01")); err != ErrSnapshotStale { + t.Errorf("stale reference returned account: %#x (err: %v)", acc, err) + } + if slot, err := ref.Storage(common.HexToHash("0xa1"), common.HexToHash("0xb1")); err != ErrSnapshotStale { + t.Errorf("stale reference returned storage slot: %#x (err: %v)", slot, err) + } + if n := len(snaps.layers); n != 2 { + t.Errorf("post-cap layer count mismatch: have %d, want %d", n, 2) + fmt.Println(snaps.layers) + } +} + +// Tests that if a diff layer becomes stale, no active external references will +// be returned with junk data. This version of the test flattens every diff layer +// to check internal corner case around the bottom-most memory accumulator. +func TestDiffLayerExternalInvalidationFullFlatten(t *testing.T) { + // Create an empty base layer and a snapshot tree out of it + cache, _ := bigcache.NewBigCache(bigcache.DefaultConfig(time.Minute)) + base := &diskLayer{ + db: rawdb.NewMemoryDatabase(), + root: common.HexToHash("0x01"), + cache: cache, + } + snaps := &SnapshotTree{ + layers: map[common.Hash]snapshot{ + base.root: base, + }, + } + // Commit two diffs on top and retrieve a reference to the bottommost + accounts := map[common.Hash][]byte{ + common.HexToHash("0xa1"): randomAccount(), + } + storage := make(map[common.Hash]map[common.Hash][]byte) + if err := snaps.Update(common.HexToHash("0x02"), common.HexToHash("0x01"), accounts, storage); err != nil { + t.Fatalf("failed to create a diff layer: %v", err) + } + if err := snaps.Update(common.HexToHash("0x03"), common.HexToHash("0x02"), accounts, storage); err != nil { + t.Fatalf("failed to create a diff layer: %v", err) + } + if n := len(snaps.layers); n != 3 { + t.Errorf("pre-cap layer count mismatch: have %d, want %d", n, 3) + } + ref := snaps.Snapshot(common.HexToHash("0x02")) + + // Flatten the diff layer into the bottom accumulator + if err := snaps.Cap(common.HexToHash("0x03"), 1, 1024*1024); err != nil { + t.Fatalf("failed to flatten diff layer into accumulator: %v", err) + } + // Since the accumulator diff layer was modified, ensure that data retrievald on the external reference fail + if acc, err := ref.Account(common.HexToHash("0x01")); err != ErrSnapshotStale { + t.Errorf("stale reference returned account: %#x (err: %v)", acc, err) + } + if slot, err := ref.Storage(common.HexToHash("0xa1"), common.HexToHash("0xb1")); err != ErrSnapshotStale { + t.Errorf("stale reference returned storage slot: %#x (err: %v)", slot, err) + } + if n := len(snaps.layers); n != 2 { + t.Errorf("post-cap layer count mismatch: have %d, want %d", n, 2) + fmt.Println(snaps.layers) + } +} + +// Tests that if a diff layer becomes stale, no active external references will +// be returned with junk data. This version of the test retains the bottom diff +// layer to check the usual mode of operation where the accumulator is retained. +func TestDiffLayerExternalInvalidationPartialFlatten(t *testing.T) { + // Create an empty base layer and a snapshot tree out of it + cache, _ := bigcache.NewBigCache(bigcache.DefaultConfig(time.Minute)) + base := &diskLayer{ + db: rawdb.NewMemoryDatabase(), + root: common.HexToHash("0x01"), + cache: cache, + } + snaps := &SnapshotTree{ + layers: map[common.Hash]snapshot{ + base.root: base, + }, + } + // Commit three diffs on top and retrieve a reference to the bottommost + accounts := map[common.Hash][]byte{ + common.HexToHash("0xa1"): randomAccount(), + } + storage := make(map[common.Hash]map[common.Hash][]byte) + if err := snaps.Update(common.HexToHash("0x02"), common.HexToHash("0x01"), accounts, storage); err != nil { + t.Fatalf("failed to create a diff layer: %v", err) + } + if err := snaps.Update(common.HexToHash("0x03"), common.HexToHash("0x02"), accounts, storage); err != nil { + t.Fatalf("failed to create a diff layer: %v", err) + } + if err := snaps.Update(common.HexToHash("0x04"), common.HexToHash("0x03"), accounts, storage); err != nil { + t.Fatalf("failed to create a diff layer: %v", err) + } + if n := len(snaps.layers); n != 4 { + t.Errorf("pre-cap layer count mismatch: have %d, want %d", n, 4) + } + ref := snaps.Snapshot(common.HexToHash("0x02")) + + // Flatten the diff layer into the bottom accumulator + if err := snaps.Cap(common.HexToHash("0x04"), 2, 1024*1024); err != nil { + t.Fatalf("failed to flatten diff layer into accumulator: %v", err) + } + // Since the accumulator diff layer was modified, ensure that data retrievald on the external reference fail + if acc, err := ref.Account(common.HexToHash("0x01")); err != ErrSnapshotStale { + t.Errorf("stale reference returned account: %#x (err: %v)", acc, err) + } + if slot, err := ref.Storage(common.HexToHash("0xa1"), common.HexToHash("0xb1")); err != ErrSnapshotStale { + t.Errorf("stale reference returned storage slot: %#x (err: %v)", slot, err) + } + if n := len(snaps.layers); n != 3 { + t.Errorf("post-cap layer count mismatch: have %d, want %d", n, 3) + fmt.Println(snaps.layers) + } +} + +// TestPostCapBasicDataAccess tests some functionality regarding capping/flattening. +func TestPostCapBasicDataAccess(t *testing.T) { + // setAccount is a helper to construct a random account entry and assign it to + // an account slot in a snapshot + setAccount := func(accKey string) map[common.Hash][]byte { + return map[common.Hash][]byte{ + common.HexToHash(accKey): randomAccount(), + } + } + // Create a starting base layer and a snapshot tree out of it + cache, _ := bigcache.NewBigCache(bigcache.DefaultConfig(time.Minute)) + base := &diskLayer{ + db: rawdb.NewMemoryDatabase(), + root: common.HexToHash("0x01"), + cache: cache, + } + snaps := &SnapshotTree{ + layers: map[common.Hash]snapshot{ + base.root: base, + }, + } + // The lowest difflayer + snaps.Update(common.HexToHash("0xa1"), common.HexToHash("0x01"), setAccount("0xa1"), nil) + snaps.Update(common.HexToHash("0xa2"), common.HexToHash("0xa1"), setAccount("0xa2"), nil) + snaps.Update(common.HexToHash("0xb2"), common.HexToHash("0xa1"), setAccount("0xb2"), nil) + + snaps.Update(common.HexToHash("0xa3"), common.HexToHash("0xa2"), setAccount("0xa3"), nil) + snaps.Update(common.HexToHash("0xb3"), common.HexToHash("0xb2"), setAccount("0xb3"), nil) + + // checkExist verifies if an account exiss in a snapshot + checkExist := func(layer *diffLayer, key string) error { + if data, _ := layer.Account(common.HexToHash(key)); data == nil { + return fmt.Errorf("expected %x to exist, got nil", common.HexToHash(key)) + } + return nil + } + // shouldErr checks that an account access errors as expected + shouldErr := func(layer *diffLayer, key string) error { + if data, err := layer.Account(common.HexToHash(key)); err == nil { + return fmt.Errorf("expected error, got data %x", data) + } + return nil + } + // check basics + snap := snaps.Snapshot(common.HexToHash("0xb3")).(*diffLayer) + + if err := checkExist(snap, "0xa1"); err != nil { + t.Error(err) + } + if err := checkExist(snap, "0xb2"); err != nil { + t.Error(err) + } + if err := checkExist(snap, "0xb3"); err != nil { + t.Error(err) + } + // Now, merge the a-chain + snaps.Cap(common.HexToHash("0xa3"), 0, 1024) + + // At this point, a2 got merged into a1. Thus, a1 is now modified, and as a1 is + // the parent of b2, b2 should no longer be able to iterate into parent. + + // These should still be accessible + if err := checkExist(snap, "0xb2"); err != nil { + t.Error(err) + } + if err := checkExist(snap, "0xb3"); err != nil { + t.Error(err) + } + // But these would need iteration into the modified parent + if err := shouldErr(snap, "0xa1"); err != nil { + t.Error(err) + } + if err := shouldErr(snap, "0xa2"); err != nil { + t.Error(err) + } + if err := shouldErr(snap, "0xa3"); err != nil { + t.Error(err) + } +}