protect cc.messages from racing (#75)

Signed-off-by: Serguei Bezverkhi <sbezverk@cisco.com>
This commit is contained in:
Serguei Bezverkhi 2019-11-14 10:22:42 -05:00 committed by Michael Stapelberg
parent 26aec69f06
commit 14f3137cde
7 changed files with 40 additions and 3 deletions

View File

@ -96,7 +96,8 @@ type Chain struct {
// AddChain adds the specified Chain. See also
// https://wiki.nftables.org/wiki-nftables/index.php/Configuring_chains#Adding_base_chains
func (cc *Conn) AddChain(c *Chain) *Chain {
cc.Lock()
defer cc.Unlock()
data := cc.marshalAttr([]netlink.Attribute{
{Type: unix.NFTA_CHAIN_TABLE, Data: []byte(c.Table.Name + "\x00")},
{Type: unix.NFTA_CHAIN_NAME, Data: []byte(c.Name + "\x00")},
@ -122,7 +123,6 @@ func (cc *Conn) AddChain(c *Chain) *Chain {
{Type: unix.NFTA_CHAIN_TYPE, Data: []byte(c.Type + "\x00")},
})...)
}
cc.messages = append(cc.messages, netlink.Message{
Header: netlink.Header{
Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_NEWCHAIN),
@ -137,6 +137,8 @@ func (cc *Conn) AddChain(c *Chain) *Chain {
// DelChain deletes the specified Chain. See also
// https://wiki.nftables.org/wiki-nftables/index.php/Configuring_chains#Deleting_chains
func (cc *Conn) DelChain(c *Chain) {
cc.Lock()
defer cc.Unlock()
data := cc.marshalAttr([]netlink.Attribute{
{Type: unix.NFTA_CHAIN_TABLE, Data: []byte(c.Table.Name + "\x00")},
{Type: unix.NFTA_CHAIN_NAME, Data: []byte(c.Name + "\x00")},
@ -154,6 +156,8 @@ func (cc *Conn) DelChain(c *Chain) {
// FlushChain removes all rules within the specified Chain. See also
// https://wiki.nftables.org/wiki-nftables/index.php/Configuring_chains#Flushing_chain
func (cc *Conn) FlushChain(c *Chain) {
cc.Lock()
defer cc.Unlock()
data := cc.marshalAttr([]netlink.Attribute{
{Type: unix.NFTA_RULE_TABLE, Data: []byte(c.Table.Name + "\x00")},
{Type: unix.NFTA_RULE_CHAIN, Data: []byte(c.Name + "\x00")},

10
conn.go
View File

@ -16,6 +16,7 @@ package nftables
import (
"fmt"
"sync"
"github.com/google/nftables/expr"
"github.com/mdlayher/netlink"
@ -32,12 +33,19 @@ import (
type Conn struct {
TestDial nltest.Func // for testing only; passed to nltest.Dial
NetNS int // Network namespace netlink will interact with.
sync.Mutex
messages []netlink.Message
err error
}
// Flush sends all buffered commands in a single batch to nftables.
func (cc *Conn) Flush() error {
cc.Lock()
defer cc.Unlock()
if len(cc.messages) == 0 {
// Messages were already programmed, returning nil
return nil
}
if cc.err != nil {
return cc.err // serialization error
}
@ -64,6 +72,8 @@ func (cc *Conn) Flush() error {
// FlushRuleset flushes the entire ruleset. See also
// https://wiki.nftables.org/wiki-nftables/index.php/Operations_at_ruleset_level
func (cc *Conn) FlushRuleset() {
cc.Lock()
defer cc.Unlock()
cc.messages = append(cc.messages, netlink.Message{
Header: netlink.Header{
Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_DELTABLE),

2
go.sum
View File

@ -23,4 +23,6 @@ golang.org/x/sys v0.0.0-20190826190057-c7b8b68b1456 h1:ng0gs1AKnRRuEMZoTLLlbOd+C
golang.org/x/sys v0.0.0-20190826190057-c7b8b68b1456/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20191029155521-f43be2a4598c h1:S/FtSvpNLtFBgjTqcKsRpsa6aVsI6iztaz1bQd9BJwE=
golang.org/x/sys v0.0.0-20191029155521-f43be2a4598c/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20191112214154-59a1497f0cea h1:Mz1TMnfJDRJLk8S8OPCoJYgrsp/Se/2TBre2+vwX128=
golang.org/x/sys v0.0.0-20191113150313-8ad342257130 h1:+sdNBpwFF05NvMnEyGynbOs/Gr2LQwORWEPKXuEXxzU=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=

2
obj.go
View File

@ -35,6 +35,8 @@ type Obj interface {
// AddObj adds the specified Obj. See also
// https://wiki.nftables.org/wiki-nftables/index.php/Stateful_objects
func (cc *Conn) AddObj(o Obj) Obj {
cc.Lock()
defer cc.Unlock()
data, err := o.marshal(true)
if err != nil {
cc.setErr(err)

View File

@ -83,6 +83,8 @@ func (cc *Conn) GetRule(t *Table, c *Chain) ([]*Rule, error) {
// AddRule adds the specified Rule
func (cc *Conn) AddRule(r *Rule) *Rule {
cc.Lock()
defer cc.Unlock()
exprAttrs := make([]netlink.Attribute, len(r.Exprs))
for idx, expr := range r.Exprs {
exprAttrs[idx] = netlink.Attribute{
@ -133,6 +135,8 @@ func (cc *Conn) AddRule(r *Rule) *Rule {
// DelRule deletes the specified Rule, rule's handle cannot be 0
func (cc *Conn) DelRule(r *Rule) error {
cc.Lock()
defer cc.Unlock()
data := cc.marshalAttr([]netlink.Attribute{
{Type: unix.NFTA_RULE_TABLE, Data: []byte(r.Table.Name + "\x00")},
{Type: unix.NFTA_RULE_CHAIN, Data: []byte(r.Chain.Name + "\x00")},

11
set.go
View File

@ -141,6 +141,8 @@ func decodeElement(d []byte) ([]byte, error) {
// SetAddElements applies data points to an nftables set.
func (cc *Conn) SetAddElements(s *Set, vals []SetElement) error {
cc.Lock()
defer cc.Unlock()
if s.Anonymous {
return errors.New("anonymous sets cannot be updated")
}
@ -240,6 +242,8 @@ func (s *Set) makeElemList(vals []SetElement) ([]netlink.Attribute, error) {
// AddSet adds the specified Set.
func (cc *Conn) AddSet(s *Set, vals []SetElement) error {
cc.Lock()
defer cc.Unlock()
// Based on nft implementation & linux source.
// Link: https://github.com/torvalds/linux/blob/49a57857aeea06ca831043acbb0fa5e0f50602fd/net/netfilter/nf_tables_api.c#L3395
// Another reference: https://git.netfilter.org/nftables/tree/src
@ -342,6 +346,8 @@ func (cc *Conn) AddSet(s *Set, vals []SetElement) error {
// DelSet deletes a specific set, along with all elements it contains.
func (cc *Conn) DelSet(s *Set) {
cc.Lock()
defer cc.Unlock()
data := cc.marshalAttr([]netlink.Attribute{
{Type: unix.NFTA_SET_TABLE, Data: []byte(s.Table.Name + "\x00")},
{Type: unix.NFTA_SET_NAME, Data: []byte(s.Name + "\x00")},
@ -357,6 +363,8 @@ func (cc *Conn) DelSet(s *Set) {
// SetDeleteElements deletes data points from an nftables set.
func (cc *Conn) SetDeleteElements(s *Set, vals []SetElement) error {
cc.Lock()
defer cc.Unlock()
if s.Anonymous {
return errors.New("anonymous sets cannot be updated")
}
@ -378,6 +386,8 @@ func (cc *Conn) SetDeleteElements(s *Set, vals []SetElement) error {
// FlushSet deletes all data points from an nftables set.
func (cc *Conn) FlushSet(s *Set) {
cc.Lock()
defer cc.Unlock()
data := cc.marshalAttr([]netlink.Attribute{
{Type: unix.NFTA_SET_TABLE, Data: []byte(s.Table.Name + "\x00")},
{Type: unix.NFTA_SET_NAME, Data: []byte(s.Name + "\x00")},
@ -408,7 +418,6 @@ func setsFromMsg(msg netlink.Message) (*Set, error) {
switch ad.Type() {
case unix.NFTA_SET_NAME:
set.Name = ad.String()
fmt.Printf("Discover set %s\n", set.Name)
case unix.NFTA_SET_ID:
set.ID = binary.BigEndian.Uint32(ad.Bytes())
case unix.NFTA_SET_FLAGS:

View File

@ -44,6 +44,8 @@ type Table struct {
// DelTable deletes a specific table, along with all chains/rules it contains.
func (cc *Conn) DelTable(t *Table) {
cc.Lock()
defer cc.Unlock()
data := cc.marshalAttr([]netlink.Attribute{
{Type: unix.NFTA_TABLE_NAME, Data: []byte(t.Name + "\x00")},
{Type: unix.NFTA_TABLE_FLAGS, Data: []byte{0, 0, 0, 0}},
@ -60,6 +62,8 @@ func (cc *Conn) DelTable(t *Table) {
// AddTable adds the specified Table. See also
// https://wiki.nftables.org/wiki-nftables/index.php/Configuring_tables
func (cc *Conn) AddTable(t *Table) *Table {
cc.Lock()
defer cc.Unlock()
data := cc.marshalAttr([]netlink.Attribute{
{Type: unix.NFTA_TABLE_NAME, Data: []byte(t.Name + "\x00")},
{Type: unix.NFTA_TABLE_FLAGS, Data: []byte{0, 0, 0, 0}},
@ -77,6 +81,8 @@ func (cc *Conn) AddTable(t *Table) *Table {
// FlushTable removes all rules in all chains within the specified Table. See also
// https://wiki.nftables.org/wiki-nftables/index.php/Configuring_tables#Flushing_tables
func (cc *Conn) FlushTable(t *Table) {
cc.Lock()
defer cc.Unlock()
data := cc.marshalAttr([]netlink.Attribute{
{Type: unix.NFTA_RULE_TABLE, Data: []byte(t.Name + "\x00")},
})