From e476fb19bc0d55fb120772c175e438e21a0cefb1 Mon Sep 17 00:00:00 2001
From: Vitalik Buterin
Date: Sat, 21 Dec 2013 21:20:30 -0500
Subject: [PATCH] Added independent RLP and trie files

---
Reviewer notes (text here is ignored by git-am and is not part of the
commit message):
 * Line structure and Python indentation were reconstructed from a
   whitespace-collapsed copy of this patch (4-space indentation assumed).
   Per-file added/removed line counts match the hunk headers and the
   diffstat below.
 * trie.py: the debug prints in __get_state, __insert_state and
   __delete_state tested the unbound name `debug`, which raises NameError
   on the first get/update. They now test `self.debug`. The edit is made
   in place, so hunk sizes are unchanged; the `index` line for trie.py
   still carries the original blob id.
 * Not changed: rlp.decode('') raises IndexError, because decode() indexes
   the empty list returned by __decode(''). The `if not curnode` checks in
   trie.py are therefore not reached when DB.get() returns ''.
 * The From: header carries no e-mail address in the copy this was
   restored from; it is left as found.

 processblock.py |   4 +-
 rlp.py          |  42 +++++++------
 trie.py         | 155 ++++++++++++++++++++++++++++++++++++++++++++++++
 trietest.py     |  25 ++++++++
 4 files changed, 206 insertions(+), 20 deletions(-)
 create mode 100644 trie.py
 create mode 100644 trietest.py

diff --git a/processblock.py b/processblock.py
index 1be89ef022..4fb7700890 100644
--- a/processblock.py
+++ b/processblock.py
@@ -11,13 +11,13 @@ scriptcode_map = {
     0x15: 'MOD',
     0x16: 'SMOD',
     0x17: 'EXP',
+    0x18: 'NEG',
     0x20: 'LT',
     0x21: 'LE',
     0x22: 'GT',
     0x23: 'GE',
     0x24: 'EQ',
-    0x25: 'NEG',
-    0x26: 'NOT',
+    0x25: 'NOT',
     0x30: 'SHA256',
     0x31: 'RIPEMD-160',
     0x32: 'ECMUL',
diff --git a/rlp.py b/rlp.py
index 3b2f16c15a..708c770d3e 100644
--- a/rlp.py
+++ b/rlp.py
@@ -14,37 +14,43 @@ def to_binary(n,L=None): return ''.join([chr(x) for x in to_binary_array(n,L)])
 
 def from_binary(b):
     if len(b) == 0: return 0
-    else: return ord(from_binary(b[:-1])) * 256 + b[-1]
+    else: return from_binary(b[:-1]) * 256 + ord(b[-1])
 
 def num_to_var_int(n):
-    if n < 253: s = chr(n)
-    else if n < 2**16: s = [253] + list(reversed(to_binary_array(n,2)))
-    else if n < 2**32: s = [254] + list(reversed(to_binary_array(n,4)))
-    else if n < 2**64: s = [255] + list(reversed(to_binary_array(n,8)))
-    else raise Exception("number too big")
+    if n < 253: s = [n]
+    elif n < 2**16: s = [253] + list(to_binary_array(n,2))
+    elif n < 2**32: s = [254] + list(to_binary_array(n,4))
+    elif n < 2**64: s = [255] + list(to_binary_array(n,8))
+    else: raise Exception("number too big")
     return ''.join([chr(x) for x in s])
 
-def decode(s):
+def __decode(s):
     o = []
-    index = 0
+    index = [0]
     def read_var_int():
-        si = ord(s[index])
-        index += 1
-        if si < 253: return s[index - 1]
+        si = ord(s[index[0]])
+        index[0] += 1
+        if si < 253: return si
         elif si == 253: read = 2
         elif si == 254: read = 4
         elif si == 255: read = 8
-        index += read
-        return from_binary(s[index-read:index])
-    while index < len(s):
+        index[0] += read
+        return from_binary(s[index[0]-read:index[0]])
+    while index[0] < len(s):
+        tp = s[index[0]]
+        index[0] += 1
         L = read_var_int()
-        o.append(s[index:index+L])
+        item = s[index[0]:index[0]+L]
+        if tp == '\x00': o.append(item)
+        else: o.append(__decode(item))
+        index[0] += L
     return o
 
+def decode(s): return __decode(s)[0]
+
 def encode(s):
     if isinstance(s,(int,long)): return encode(to_binary(s))
-    if isinstance(s,str): return num_to_var_int(len(s))+s
+    if isinstance(s,str): return '\x00'+num_to_var_int(len(s))+s
     else:
         x = ''.join([encode(x) for x in s])
-        return num_to_var_int(len(s))+s
-
+        return '\x01'+num_to_var_int(len(x))+x
diff --git a/trie.py b/trie.py
new file mode 100644
index 0000000000..2d435744ee
--- /dev/null
+++ b/trie.py
@@ -0,0 +1,155 @@
+import leveldb
+import rlp
+import hashlib
+
+def sha256(x): return hashlib.sha256(x).digest()
+
+class DB():
+    def __init__(self,dbfile): self.db = leveldb.LevelDB(dbfile)
+    def get(self,key):
+        try: return self.db.Get(key)
+        except: return ''
+    def put(self,key,value): return self.db.Put(key,value)
+    def delete(self,key): return self.db.Delete(key)
+
+class Trie():
+    def __init__(self,db,root='',debug=False):
+        self.root = root
+        self.db = DB(db)
+        self.debug = debug
+
+    def __encode_key(self,key):
+        term = 1 if key[-1] == 16 else 0
+        oddlen = (len(key) - term) % 2
+        prefix = ('0' if oddlen else '')
+        main = ''.join(['0123456789abcdef'[x] for x in key[:len(key)-term]])
+        return chr(2 * term + oddlen) + (prefix+main).decode('hex')
+
+    def __decode_key(self,key):
+        o = ['0123456789abcdef'.find(x) for x in key[1:].encode('hex')]
+        if key[0] == '\x01' or key[0] == '\x03': o = o[1:]
+        if key[0] == '\x02' or key[0] == '\x03': o.append(16)
+        return o
+
+    def __get_state(self,node,key):
+        if self.debug: print 'nk',node.encode('hex'),key
+        if len(key) == 0 or not node:
+            return node
+        curnode = rlp.decode(self.db.get(node))
+        if self.debug: print 'cn', curnode
+        if not curnode:
+            raise Exception("node not found in database")
+        elif len(curnode) == 2:
+            (k2,v2) = curnode
+            k2 = self.__decode_key(k2)
+            if len(key) >= len(k2) and k2 == key[:len(k2)]:
+                return self.__get_state(v2,key[len(k2):])
+            else:
+                return ''
+        elif len(curnode) == 17:
+            return self.__get_state(curnode[key[0]],key[1:])
+
+    def __put(self,node):
+        rlpnode = rlp.encode(node)
+        h = sha256(rlpnode)
+        self.db.put(h,rlpnode)
+        return h
+
+    def __update_state(self,node,key,value):
+        if value != '': return self.__insert_state(node,key,value)
+        else: return self.__delete_state(node,key)
+
+    def __insert_state(self,node,key,value):
+        if self.debug: print 'ink', node.encode('hex'), key
+        if len(key) == 0:
+            return value
+        else:
+            if not node:
+                newnode = [ self.__encode_key(key), value ]
+                return self.__put(newnode)
+            curnode = rlp.decode(self.db.get(node))
+            if self.debug: print 'icn', curnode
+            if not curnode:
+                raise Exception("node not found in database")
+            if len(curnode) == 2:
+                (k2, v2) = curnode
+                k2 = self.__decode_key(k2)
+                if key == k2:
+                    newnode = [ self.__encode_key(key), value ]
+                    return self.__put(newnode)
+                else:
+                    i = 0
+                    while key[:i+1] == k2[:i+1] and i < len(k2): i += 1
+                    if i == len(k2):
+                        newhash3 = self.__insert_state(v2,key[i:],value)
+                    else:
+                        newnode1 = self.__insert_state('',key[i+1:],value)
+                        newnode2 = self.__insert_state('',k2[i+1:],v2)
+                        newnode3 = [ '' ] * 17
+                        newnode3[key[i]] = newnode1
+                        newnode3[k2[i]] = newnode2
+                        newhash3 = self.__put(newnode3)
+                    if i == 0:
+                        return newhash3
+                    else:
+                        newnode4 = [ self.__encode_key(key[:i]), newhash3 ]
+                        return self.__put(newnode4)
+            else:
+                newnode = [ curnode[i] for i in range(17) ]
+                newnode[key[0]] = self.__insert_state(curnode[key[0]],key[1:],value)
+                return self.__put(newnode)
+
+    def __delete_state(self,node,key):
+        if self.debug: print 'dnk', node.encode('hex'), key
+        if len(key) == 0 or not node:
+            return ''
+        else:
+            curnode = rlp.decode(self.db.get(node))
+            if not curnode:
+                raise Exception("node not found in database")
+            if self.debug: print 'dcn', curnode
+            if len(curnode) == 2:
+                (k2, v2) = curnode
+                k2 = self.__decode_key(k2)
+                if key == k2:
+                    return ''
+                elif key[:len(k2)] == k2:
+                    newhash = self.__delete_state(v2,key[len(k2):])
+                    childnode = rlp.decode(self.db.get(newhash))
+                    if len(childnode) == 2:
+                        newkey = k2 + self.__decode_key(childnode[0])
+                        newnode = [ self.__encode_key(newkey), childnode[1] ]
+                    else:
+                        newnode = [ curnode[0], newhash ]
+                    return self.__put(newnode)
+                else: return node
+            else:
+                newnode = [ curnode[i] for i in range(17) ]
+                newnode[key[0]] = self.__delete_state(newnode[key[0]],key[1:])
+                onlynode = -1
+                for i in range(17):
+                    if newnode[i]:
+                        if onlynode == -1: onlynode = i
+                        else: onlynode = -2
+                if onlynode >= 0:
+                    childnode = rlp.decode(self.db.get(newnode[onlynode]))
+                    if not childnode:
+                        raise Exception("?????")
+                    if len(childnode) == 17:
+                        newnode2 = [ key[0], newnode[onlynode] ]
+                    elif len(childnode) == 2:
+                        newkey = [onlynode] + self.__decode_key(childnode[0])
+                        newnode2 = [ self.__encode_key(newkey), childnode[1] ]
+                else:
+                    newnode2 = newnode
+                return self.__put(newnode2)
+
+    def get(self,key):
+        key2 = ['0123456789abcdef'.find(x) for x in key.encode('hex')] + [16]
+        return self.__get_state(self.root,key2)
+
+    def update(self,key,value):
+        if not isinstance(key,str) or not isinstance(value,str):
+            raise Exception("Key and value must be strings")
+        key2 = ['0123456789abcdef'.find(x) for x in key.encode('hex')] + [16]
+        self.root = self.__update_state(self.root,key2,value)
diff --git a/trietest.py b/trietest.py
new file mode 100644
index 0000000000..538c99a0f4
--- /dev/null
+++ b/trietest.py
@@ -0,0 +1,25 @@
+from trie import Trie
+import random
+
+def genkey():
+    L = random.randrange(30)
+    if random.randrange(5) == 0: return ''
+    return ''.join([random.choice('1234579qetyiasdfghjklzxcvbnm') for x in range(L)])
+
+t = Trie('/tmp/'+genkey())
+
+def trie_test():
+    o = {}
+    for i in range(60):
+        key, value = genkey(), genkey()
+        if value: print "setting key: '"+key+"', value: '"+value+"'"
+        else: print "deleting key: '"+key+"'"
+        o[key] = value
+        t.update(key,value)
+    for k in o.keys():
+        v1 = o[k]
+        v2 = t.get(k)
+        print v1,v2
+        if v1 != v2: raise Exception("incorrect!")
+
+