trie: fix for range proof (#21107)
* trie: fix for range proof * trie: fix typo
This commit is contained in:
parent
81e9caed7d
commit
070a5e1252
|
@ -219,54 +219,69 @@ func unsetInternal(n node, left []byte, right []byte) error {
|
||||||
if len(left) != len(right) {
|
if len(left) != len(right) {
|
||||||
return errors.New("inconsistent edge path")
|
return errors.New("inconsistent edge path")
|
||||||
}
|
}
|
||||||
// Step down to the fork point
|
// Step down to the fork point. There are two scenarios can happen:
|
||||||
prefix, pos := prefixLen(left, right), 0
|
// - the fork point is a shortnode: the left proof MUST point to a
|
||||||
var parent node
|
// non-existent key and the key doesn't match with the shortnode
|
||||||
|
// - the fork point is a fullnode: the left proof can point to an
|
||||||
|
// existent key or not.
|
||||||
|
var (
|
||||||
|
pos = 0
|
||||||
|
parent node
|
||||||
|
)
|
||||||
|
findFork:
|
||||||
for {
|
for {
|
||||||
if pos >= prefix {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
switch rn := (n).(type) {
|
switch rn := (n).(type) {
|
||||||
case *shortNode:
|
case *shortNode:
|
||||||
|
// The right proof must point to an existent key.
|
||||||
if len(right)-pos < len(rn.Key) || !bytes.Equal(rn.Key, right[pos:pos+len(rn.Key)]) {
|
if len(right)-pos < len(rn.Key) || !bytes.Equal(rn.Key, right[pos:pos+len(rn.Key)]) {
|
||||||
return errors.New("invalid edge path")
|
return errors.New("invalid edge path")
|
||||||
}
|
}
|
||||||
|
rn.flags = nodeFlag{dirty: true}
|
||||||
// Special case, the non-existent proof points to the same path
|
// Special case, the non-existent proof points to the same path
|
||||||
// as the existent proof, but the path of existent proof is longer.
|
// as the existent proof, but the path of existent proof is longer.
|
||||||
// In this case, truncate the extra path(it should be recovered
|
// In this case, the fork point is this shortnode.
|
||||||
// by node insertion).
|
|
||||||
if len(left)-pos < len(rn.Key) || !bytes.Equal(rn.Key, left[pos:pos+len(rn.Key)]) {
|
if len(left)-pos < len(rn.Key) || !bytes.Equal(rn.Key, left[pos:pos+len(rn.Key)]) {
|
||||||
fn := parent.(*fullNode)
|
break findFork
|
||||||
fn.Children[left[pos-1]] = nil
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
rn.flags = nodeFlag{dirty: true}
|
|
||||||
parent = n
|
parent = n
|
||||||
n, pos = rn.Val, pos+len(rn.Key)
|
n, pos = rn.Val, pos+len(rn.Key)
|
||||||
case *fullNode:
|
case *fullNode:
|
||||||
|
leftnode, rightnode := rn.Children[left[pos]], rn.Children[right[pos]]
|
||||||
|
// The right proof must point to an existent key.
|
||||||
|
if rightnode == nil {
|
||||||
|
return errors.New("invalid edge path")
|
||||||
|
}
|
||||||
rn.flags = nodeFlag{dirty: true}
|
rn.flags = nodeFlag{dirty: true}
|
||||||
|
if leftnode != rightnode {
|
||||||
|
break findFork
|
||||||
|
}
|
||||||
parent = n
|
parent = n
|
||||||
n, pos = rn.Children[right[pos]], pos+1
|
n, pos = rn.Children[left[pos]], pos+1
|
||||||
default:
|
default:
|
||||||
panic(fmt.Sprintf("%T: invalid node: %v", n, n))
|
panic(fmt.Sprintf("%T: invalid node: %v", n, n))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
fn, ok := n.(*fullNode)
|
switch rn := n.(type) {
|
||||||
if !ok {
|
case *shortNode:
|
||||||
return errors.New("the fork point must be a fullnode")
|
if _, ok := rn.Val.(valueNode); ok {
|
||||||
|
parent.(*fullNode).Children[right[pos-1]] = nil
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return unset(rn, rn.Val, right[pos:], len(rn.Key), true)
|
||||||
|
case *fullNode:
|
||||||
|
for i := left[pos] + 1; i < right[pos]; i++ {
|
||||||
|
rn.Children[i] = nil
|
||||||
|
}
|
||||||
|
if err := unset(rn, rn.Children[left[pos]], left[pos:], 1, false); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := unset(rn, rn.Children[right[pos]], right[pos:], 1, true); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
default:
|
||||||
|
panic(fmt.Sprintf("%T: invalid node: %v", n, n))
|
||||||
}
|
}
|
||||||
// Find the fork point! Unset all intermediate references
|
|
||||||
for i := left[prefix] + 1; i < right[prefix]; i++ {
|
|
||||||
fn.Children[i] = nil
|
|
||||||
}
|
|
||||||
fn.flags = nodeFlag{dirty: true}
|
|
||||||
if err := unset(fn, fn.Children[left[prefix]], left[prefix:], 1, false); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if err := unset(fn, fn.Children[right[prefix]], right[prefix:], 1, true); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// unset removes all internal node references either the left most or right most.
|
// unset removes all internal node references either the left most or right most.
|
||||||
|
@ -314,8 +329,8 @@ func unset(parent node, child node, key []byte, pos int, removeLeft bool) error
|
||||||
// The key of fork shortnode is less than the
|
// The key of fork shortnode is less than the
|
||||||
// path(it doesn't belong to the range), keep
|
// path(it doesn't belong to the range), keep
|
||||||
// it with the cached hash available.
|
// it with the cached hash available.
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
if _, ok := cld.Val.(valueNode); ok {
|
if _, ok := cld.Val.(valueNode); ok {
|
||||||
fn := parent.(*fullNode)
|
fn := parent.(*fullNode)
|
||||||
|
|
|
@ -397,33 +397,35 @@ func TestAllElementsProof(t *testing.T) {
|
||||||
|
|
||||||
// TestSingleSideRangeProof tests the range starts from zero.
|
// TestSingleSideRangeProof tests the range starts from zero.
|
||||||
func TestSingleSideRangeProof(t *testing.T) {
|
func TestSingleSideRangeProof(t *testing.T) {
|
||||||
trie := new(Trie)
|
for i := 0; i < 64; i++ {
|
||||||
var entries entrySlice
|
trie := new(Trie)
|
||||||
for i := 0; i < 4096; i++ {
|
var entries entrySlice
|
||||||
value := &kv{randBytes(32), randBytes(20), false}
|
for i := 0; i < 4096; i++ {
|
||||||
trie.Update(value.k, value.v)
|
value := &kv{randBytes(32), randBytes(20), false}
|
||||||
entries = append(entries, value)
|
trie.Update(value.k, value.v)
|
||||||
}
|
entries = append(entries, value)
|
||||||
sort.Sort(entries)
|
}
|
||||||
|
sort.Sort(entries)
|
||||||
|
|
||||||
var cases = []int{0, 1, 50, 100, 1000, 2000, len(entries) - 1}
|
var cases = []int{0, 1, 50, 100, 1000, 2000, len(entries) - 1}
|
||||||
for _, pos := range cases {
|
for _, pos := range cases {
|
||||||
firstProof, lastProof := memorydb.New(), memorydb.New()
|
firstProof, lastProof := memorydb.New(), memorydb.New()
|
||||||
if err := trie.Prove(common.Hash{}.Bytes(), 0, firstProof); err != nil {
|
if err := trie.Prove(common.Hash{}.Bytes(), 0, firstProof); err != nil {
|
||||||
t.Fatalf("Failed to prove the first node %v", err)
|
t.Fatalf("Failed to prove the first node %v", err)
|
||||||
}
|
}
|
||||||
if err := trie.Prove(entries[pos].k, 0, lastProof); err != nil {
|
if err := trie.Prove(entries[pos].k, 0, lastProof); err != nil {
|
||||||
t.Fatalf("Failed to prove the first node %v", err)
|
t.Fatalf("Failed to prove the first node %v", err)
|
||||||
}
|
}
|
||||||
k := make([][]byte, 0)
|
k := make([][]byte, 0)
|
||||||
v := make([][]byte, 0)
|
v := make([][]byte, 0)
|
||||||
for i := 0; i <= pos; i++ {
|
for i := 0; i <= pos; i++ {
|
||||||
k = append(k, entries[i].k)
|
k = append(k, entries[i].k)
|
||||||
v = append(v, entries[i].v)
|
v = append(v, entries[i].v)
|
||||||
}
|
}
|
||||||
err := VerifyRangeProof(trie.Hash(), common.Hash{}.Bytes(), k, v, firstProof, lastProof)
|
err := VerifyRangeProof(trie.Hash(), common.Hash{}.Bytes(), k, v, firstProof, lastProof)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Expected no error, got %v", err)
|
t.Fatalf("Expected no error, got %v", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue