diff --git a/merkle_tree.go b/merkle_tree.go new file mode 100644 index 0000000..7a4156b --- /dev/null +++ b/merkle_tree.go @@ -0,0 +1,65 @@ +package main + +import ( + "crypto/sha256" +) + +// MerkleTree represent a Merkle tree +type MerkleTree struct { + RootNode *MerkleNode +} + +// MerkleNode represent a Merkle tree node +type MerkleNode struct { + Left *MerkleNode + Right *MerkleNode + Data []byte +} + +// NewMerkleTree creates a new Merkle tree from a sequence of data +func NewMerkleTree(data [][]byte) *MerkleTree { + var nodes []MerkleNode + + if len(data)%2 != 0 { + data = append(data, data[len(data)-1]) + } + + for _, datum := range data { + node := NewMerkleNode(nil, nil, datum) + nodes = append(nodes, *node) + } + + for i := 0; i < len(data)/2; i++ { + var newLevel []MerkleNode + + for j := 0; j < len(nodes); j += 2 { + node := NewMerkleNode(&nodes[j], &nodes[j+1], nil) + newLevel = append(newLevel, *node) + } + + nodes = newLevel + } + + mTree := MerkleTree{&nodes[0]} + + return &mTree +} + +// NewMerkleNode creates a new Merkle tree node +func NewMerkleNode(left, right *MerkleNode, data []byte) *MerkleNode { + mNode := MerkleNode{} + + if left == nil && right == nil { + hash := sha256.Sum256(data) + mNode.Data = hash[:] + } else { + prevHashes := append(left.Data, right.Data...) + hash := sha256.Sum256(prevHashes) + mNode.Data = hash[:] + } + + mNode.Left = left + mNode.Right = right + + return &mNode +} diff --git a/merkle_tree_test.go b/merkle_tree_test.go new file mode 100644 index 0000000..acff5ff --- /dev/null +++ b/merkle_tree_test.go @@ -0,0 +1,75 @@ +package main + +import ( + "encoding/hex" + "fmt" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestNewMerkleNode(t *testing.T) { + data := [][]byte{ + []byte("node1"), + []byte("node2"), + []byte("node3"), + } + + // Level 1 + + n1 := NewMerkleNode(nil, nil, data[0]) + n2 := NewMerkleNode(nil, nil, data[1]) + n3 := NewMerkleNode(nil, nil, data[2]) + n4 := NewMerkleNode(nil, nil, data[2]) + + // Level 2 + n5 := NewMerkleNode(n1, n2, nil) + n6 := NewMerkleNode(n3, n4, nil) + + // Level 3 + n7 := NewMerkleNode(n5, n6, nil) + + assert.Equal( + t, + "64b04b718d8b7c5b6fd17f7ec221945c034cfce3be4118da33244966150c4bd4", + hex.EncodeToString(n5.Data), + "Level 1 hash 1 is correct", + ) + assert.Equal( + t, + "08bd0d1426f87a78bfc2f0b13eccdf6f5b58dac6b37a7b9441c1a2fab415d76c", + hex.EncodeToString(n6.Data), + "Level 1 hash 2 is correct", + ) + assert.Equal( + t, + "4e3e44e55926330ab6c31892f980f8bfd1a6e910ff1ebc3f778211377f35227e", + hex.EncodeToString(n7.Data), + "Root hash is correct", + ) +} + +func TestNewMerkleTree(t *testing.T) { + data := [][]byte{ + []byte("node1"), + []byte("node2"), + []byte("node3"), + } + // Level 1 + n1 := NewMerkleNode(nil, nil, data[0]) + n2 := NewMerkleNode(nil, nil, data[1]) + n3 := NewMerkleNode(nil, nil, data[2]) + n4 := NewMerkleNode(nil, nil, data[2]) + + // Level 2 + n5 := NewMerkleNode(n1, n2, nil) + n6 := NewMerkleNode(n3, n4, nil) + + // Level 3 + n7 := NewMerkleNode(n5, n6, nil) + + rootHash := fmt.Sprintf("%x", n7.Data) + mTree := NewMerkleTree(data) + + assert.Equal(t, rootHash, fmt.Sprintf("%x", mTree.RootNode.Data), "Merkle tree root hash is correct") +}