diff --git a/blockchain.go b/blockchain.go index 1874f45..b6efe0f 100644 --- a/blockchain.go +++ b/blockchain.go @@ -109,8 +109,24 @@ func (bc *Blockchain) AddBlock(block *Block) { blockInDb := b.Get(block.Hash) if blockInDb != nil { - blockData := block.Serialize() - b.Put(block.Hash, blockData) + return nil + } + + blockData := block.Serialize() + err := b.Put(block.Hash, blockData) + if err != nil { + log.Panic(err) + } + + lastHash := b.Get([]byte("l")) + lastBlockData := b.Get(lastHash) + lastBlock := DeserializeBlock(lastBlockData) + + if block.Height > lastBlock.Height { + err = b.Put([]byte("l"), block.Hash) + if err != nil { + log.Panic(err) + } } return nil @@ -192,6 +208,25 @@ func (bc *Blockchain) Iterator() *BlockchainIterator { return bci } +// GetBestHeight returns the height of the latest block +func (bc *Blockchain) GetBestHeight() int { + var lastBlock Block + + err := bc.db.View(func(tx *bolt.Tx) error { + b := tx.Bucket([]byte(blocksBucket)) + lastHash := b.Get([]byte("l")) + blockData := b.Get(lastHash) + lastBlock = *DeserializeBlock(blockData) + + return nil + }) + if err != nil { + log.Panic(err) + } + + return lastBlock.Height +} + // GetBlock finds a block by its hash and returns it func (bc *Blockchain) GetBlock(blockHash []byte) (Block, error) { var block Block diff --git a/server.go b/server.go index bfc9d2b..22fb2dd 100644 --- a/server.go +++ b/server.go @@ -11,12 +11,12 @@ import ( ) const protocol = "tcp" -const dnsNodeID = "3000" const nodeVersion = 1 const commandLength = 12 var nodeAddress string -var knownNodes []string +var knownNodes = []string{"localhost:3000"} +var blocksInTransit = [][]byte{} type addr struct { AddrList []string @@ -43,13 +43,10 @@ type inv struct { Items [][]byte } -type verack struct { -} - type verzion struct { - Version int - - AddrFrom string + Version int + BestHeight int + AddrFrom string } func commandToBytes(command string) []byte { @@ -136,22 +133,15 @@ func sendGetData(address, kind string, id []byte) { sendData(address, request) } -func sendVersion(addr string) { - payload := gobEncode(verzion{nodeVersion, nodeAddress}) +func sendVersion(addr string, bc *Blockchain) { + bestHeight := bc.GetBestHeight() + payload := gobEncode(verzion{nodeVersion, bestHeight, nodeAddress}) request := append(commandToBytes("version"), payload...) sendData(addr, request) } -func sendVrack(addr string) { - payload := gobEncode(verack{}) - - request := append(commandToBytes("verack"), payload...) - - sendData(addr, request) -} - func handleAddr(request []byte) { var buff bytes.Buffer var payload addr @@ -184,6 +174,16 @@ func handleBlock(request []byte, bc *Blockchain) { fmt.Println("Recevied a new block!") bc.AddBlock(block) + fmt.Printf("Added block %x\n", block.Hash) + fmt.Printf("Added block %d\n", block.Height) + + fmt.Println(blocksInTransit) + if len(blocksInTransit) > 0 { + blockHash := blocksInTransit[0] + sendGetData(payload.AddrFrom, "block", blockHash) + + blocksInTransit = blocksInTransit[1:] + } } func handleInv(request []byte, bc *Blockchain) { @@ -198,12 +198,20 @@ func handleInv(request []byte, bc *Blockchain) { } fmt.Printf("Recevied inventory with %d %s\n", len(payload.Items), payload.Type) - blocks := bc.GetBlockHashes() - if len(blocks) < len(payload.Items) { - for _, blockHash := range payload.Items { - sendGetData(payload.AddrFrom, "block", blockHash) + if payload.Type == "blocks" { + blocksInTransit = payload.Items + + blockHash := payload.Items[0] + sendGetData(payload.AddrFrom, "block", blockHash) + + newInTransit := [][]byte{} + for _, b := range blocksInTransit { + if bytes.Compare(b, blockHash) != 0 { + newInTransit = append(newInTransit, b) + } } + blocksInTransit = newInTransit } } @@ -243,7 +251,7 @@ func handleGetData(request []byte, bc *Blockchain) { } } -func handleVersion(request []byte) { +func handleVersion(request []byte, bc *Blockchain) { var buff bytes.Buffer var payload verzion @@ -254,9 +262,19 @@ func handleVersion(request []byte) { log.Panic(err) } - sendVrack(payload.AddrFrom) - sendAddr(payload.AddrFrom) - knownNodes = append(knownNodes, payload.AddrFrom) + myBestHeight := bc.GetBestHeight() + foreignerBestHeight := payload.BestHeight + + if myBestHeight < foreignerBestHeight { + sendGetBlocks(payload.AddrFrom) + } else { + sendVersion(payload.AddrFrom, bc) + } + + // sendAddr(payload.AddrFrom) + if !nodeIsKnown(payload.AddrFrom) { + knownNodes = append(knownNodes, payload.AddrFrom) + } } func handleConnection(conn net.Conn, bc *Blockchain) { @@ -279,9 +297,7 @@ func handleConnection(conn net.Conn, bc *Blockchain) { case "getdata": handleGetData(request, bc) case "version": - handleVersion(request) - case "verack": - // + handleVersion(request, bc) default: fmt.Println("Unknown command!") } @@ -298,12 +314,12 @@ func StartServer(nodeID string) { } defer ln.Close() - if nodeID != dnsNodeID { - sendVersion(fmt.Sprintf("localhost:%s", dnsNodeID)) - } - bc := NewBlockchain(nodeID) + if nodeAddress != knownNodes[0] { + sendVersion(knownNodes[0], bc) + } + for { conn, err := ln.Accept() if err != nil { @@ -324,3 +340,13 @@ func gobEncode(data interface{}) []byte { return buff.Bytes() } + +func nodeIsKnown(addr string) bool { + for _, node := range knownNodes { + if node == addr { + return true + } + } + + return false +}