diff --git a/core/rawdb/table.go b/core/rawdb/table.go index 6610b7f5a2..092341fbfe 100644 --- a/core/rawdb/table.go +++ b/core/rawdb/table.go @@ -113,13 +113,21 @@ func (t *table) NewIterator() ethdb.Iterator { // database content starting at a particular initial key (or after, if it does // not exist). func (t *table) NewIteratorWithStart(start []byte) ethdb.Iterator { - return t.db.NewIteratorWithStart(start) + iter := t.db.NewIteratorWithStart(append([]byte(t.prefix), start...)) + return &tableIterator{ + iter: iter, + prefix: t.prefix, + } } // NewIteratorWithPrefix creates a binary-alphabetical iterator over a subset // of database content with a particular key prefix. func (t *table) NewIteratorWithPrefix(prefix []byte) ethdb.Iterator { - return t.db.NewIteratorWithPrefix(append([]byte(t.prefix), prefix...)) + iter := t.db.NewIteratorWithPrefix(append([]byte(t.prefix), prefix...)) + return &tableIterator{ + iter: iter, + prefix: t.prefix, + } } // Stat returns a particular internal stat of the database. @@ -198,7 +206,69 @@ func (b *tableBatch) Reset() { b.batch.Reset() } +// tableReplayer is a wrapper around a batch replayer which truncates +// the added prefix. +type tableReplayer struct { + w ethdb.KeyValueWriter + prefix string +} + +// Put implements the interface KeyValueWriter. +func (r *tableReplayer) Put(key []byte, value []byte) error { + trimmed := key[len(r.prefix):] + return r.w.Put(trimmed, value) +} + +// Delete implements the interface KeyValueWriter. +func (r *tableReplayer) Delete(key []byte) error { + trimmed := key[len(r.prefix):] + return r.w.Delete(trimmed) +} + // Replay replays the batch contents. func (b *tableBatch) Replay(w ethdb.KeyValueWriter) error { - return b.batch.Replay(w) + return b.batch.Replay(&tableReplayer{w: w, prefix: b.prefix}) +} + +// tableIterator is a wrapper around a database iterator that prefixes each key access +// with a pre-configured string. +type tableIterator struct { + iter ethdb.Iterator + prefix string +} + +// Next moves the iterator to the next key/value pair. It returns whether the +// iterator is exhausted. +func (iter *tableIterator) Next() bool { + return iter.iter.Next() +} + +// Error returns any accumulated error. Exhausting all the key/value pairs +// is not considered to be an error. +func (iter *tableIterator) Error() error { + return iter.iter.Error() +} + +// Key returns the key of the current key/value pair, or nil if done. The caller +// should not modify the contents of the returned slice, and its contents may +// change on the next call to Next. +func (iter *tableIterator) Key() []byte { + key := iter.iter.Key() + if key == nil { + return nil + } + return key[len(iter.prefix):] +} + +// Value returns the value of the current key/value pair, or nil if done. The +// caller should not modify the contents of the returned slice, and its contents +// may change on the next call to Next. +func (iter *tableIterator) Value() []byte { + return iter.iter.Value() +} + +// Release releases associated resources. Release should always succeed and can +// be called multiple times without causing error. +func (iter *tableIterator) Release() { + iter.iter.Release() } diff --git a/core/rawdb/table_test.go b/core/rawdb/table_test.go new file mode 100644 index 0000000000..977eb4bf05 --- /dev/null +++ b/core/rawdb/table_test.go @@ -0,0 +1,143 @@ +// Copyright 2020 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package rawdb + +import ( + "bytes" + "testing" +) + +func TestTableDatabase(t *testing.T) { testTableDatabase(t, "prefix") } +func TestEmptyPrefixTableDatabase(t *testing.T) { testTableDatabase(t, "") } + +type testReplayer struct { + puts [][]byte + dels [][]byte +} + +func (r *testReplayer) Put(key []byte, value []byte) error { + r.puts = append(r.puts, key) + return nil +} + +func (r *testReplayer) Delete(key []byte) error { + r.dels = append(r.dels, key) + return nil +} + +func testTableDatabase(t *testing.T, prefix string) { + db := NewTable(NewMemoryDatabase(), prefix) + + var entries = []struct { + key []byte + value []byte + }{ + {[]byte{0x01, 0x02}, []byte{0x0a, 0x0b}}, + {[]byte{0x03, 0x04}, []byte{0x0c, 0x0d}}, + {[]byte{0x05, 0x06}, []byte{0x0e, 0x0f}}, + + {[]byte{0xff, 0xff, 0x01}, []byte{0x1a, 0x1b}}, + {[]byte{0xff, 0xff, 0x02}, []byte{0x1c, 0x1d}}, + {[]byte{0xff, 0xff, 0x03}, []byte{0x1e, 0x1f}}, + } + + // Test Put/Get operation + for _, entry := range entries { + db.Put(entry.key, entry.value) + } + for _, entry := range entries { + got, err := db.Get(entry.key) + if err != nil { + t.Fatalf("Failed to get value: %v", err) + } + if !bytes.Equal(got, entry.value) { + t.Fatalf("Value mismatch: want=%v, got=%v", entry.value, got) + } + } + + // Test batch operation + db = NewTable(NewMemoryDatabase(), prefix) + batch := db.NewBatch() + for _, entry := range entries { + batch.Put(entry.key, entry.value) + } + batch.Write() + for _, entry := range entries { + got, err := db.Get(entry.key) + if err != nil { + t.Fatalf("Failed to get value: %v", err) + } + if !bytes.Equal(got, entry.value) { + t.Fatalf("Value mismatch: want=%v, got=%v", entry.value, got) + } + } + + // Test batch replayer + r := &testReplayer{} + batch.Replay(r) + for index, entry := range entries { + got := r.puts[index] + if !bytes.Equal(got, entry.key) { + t.Fatalf("Key mismatch: want=%v, got=%v", entry.key, got) + } + } + + // Test iterators + iter := db.NewIterator() + var index int + for iter.Next() { + key, value := iter.Key(), iter.Value() + if !bytes.Equal(key, entries[index].key) { + t.Fatalf("Key mismatch: want=%v, got=%v", entries[index].key, key) + } + if !bytes.Equal(value, entries[index].value) { + t.Fatalf("Value mismatch: want=%v, got=%v", entries[index].value, value) + } + index += 1 + } + iter.Release() + + // Test iterators with prefix + iter = db.NewIteratorWithPrefix([]byte{0xff, 0xff}) + index = 3 + for iter.Next() { + key, value := iter.Key(), iter.Value() + if !bytes.Equal(key, entries[index].key) { + t.Fatalf("Key mismatch: want=%v, got=%v", entries[index].key, key) + } + if !bytes.Equal(value, entries[index].value) { + t.Fatalf("Value mismatch: want=%v, got=%v", entries[index].value, value) + } + index += 1 + } + iter.Release() + + // Test iterators with start point + iter = db.NewIteratorWithStart([]byte{0xff, 0xff, 0x02}) + index = 4 + for iter.Next() { + key, value := iter.Key(), iter.Value() + if !bytes.Equal(key, entries[index].key) { + t.Fatalf("Key mismatch: want=%v, got=%v", entries[index].key, key) + } + if !bytes.Equal(value, entries[index].value) { + t.Fatalf("Value mismatch: want=%v, got=%v", entries[index].value, value) + } + index += 1 + } + iter.Release() +}