This commit is contained in:
Alexis PIRES 2020-01-03 10:56:42 +01:00
parent 39f8fec129
commit 87f28cef6e
2 changed files with 27 additions and 17 deletions

29
conn.go
View File

@ -24,6 +24,10 @@ import (
"golang.org/x/sys/unix"
)
type Entity interface {
HandleResponse(netlink.Message)
}
// A Conn represents a netlink connection of the nftables family.
//
// All methods return their input, so that variables can be defined from string
@ -35,7 +39,7 @@ type Conn struct {
NetNS int // Network namespace netlink will interact with.
sync.Mutex
messages []netlink.Message
rules map[int]*Rule
entities map[int]Entity
err error
}
@ -44,7 +48,7 @@ func (cc *Conn) Flush() error {
cc.Lock()
defer func() {
cc.messages = nil
cc.rules = nil
cc.entities = nil
cc.Unlock()
}()
if len(cc.messages) == 0 {
@ -67,15 +71,15 @@ func (cc *Conn) Flush() error {
return fmt.Errorf("SendMessages: %w", err)
}
// Retrieving of seq number associated to rules
rulesBySeq := make(map[uint32]*Rule)
for i, rule := range cc.rules {
rulesBySeq[smsg[i].Header.Sequence] = rule
// Retrieving of seq number associated to entities
entitiesBySeq := make(map[uint32]Entity)
for i, e := range cc.entities {
entitiesBySeq[smsg[i].Header.Sequence] = e
}
// Search handle in netlink messages based on requests seq
echoedRules := 0
for len(cc.rules) > echoedRules {
echoedEntities := 0
for len(cc.entities) > echoedEntities {
rmsg, err := conn.Receive()
if err != nil {
@ -83,12 +87,9 @@ func (cc *Conn) Flush() error {
}
for _, msg := range rmsg {
if srule, ok := rulesBySeq[msg.Header.Sequence]; ok {
rrule, err := ruleFromMsg(msg)
if err == nil {
srule.Handle = rrule.Handle
echoedRules++
}
if e, ok := entitiesBySeq[msg.Header.Sequence]; ok {
e.HandleResponse(msg)
echoedEntities++
}
}

15
rule.go
View File

@ -130,11 +130,11 @@ func (cc *Conn) AddRule(r *Rule) *Rule {
Data: append(extraHeader(uint8(r.Table.Family), 0), msgData...),
})
if cc.rules == nil {
cc.rules = make(map[int]*Rule)
if cc.entities == nil {
cc.entities = make(map[int]Entity)
}
cc.rules[len(cc.messages)] = r
cc.entities[len(cc.messages)] = r
return r
}
@ -166,6 +166,15 @@ func (cc *Conn) DelRule(r *Rule) error {
return nil
}
// HandleResponse retrieves Handle in netlink response
func (r *Rule) HandleResponse(msg netlink.Message) {
rule, err := ruleFromMsg(msg)
if err == nil {
r.Handle = rule.Handle
}
}
func exprsFromMsg(b []byte) ([]expr.Any, error) {
ad, err := netlink.NewAttributeDecoder(b)
if err != nil {