add GetRule

This commit is contained in:
Michael Stapelberg 2018-06-23 21:12:14 +02:00
parent 0dd2e15e25
commit 1324f5d5a9
7 changed files with 304 additions and 1 deletions

View File

@ -26,6 +26,9 @@ import (
type ByteOrder interface {
PutUint16(v uint16) []byte
PutUint32(v uint32) []byte
PutUint64(v uint64) []byte
Uint32(b []byte) uint32
Uint64(b []byte) uint64
}
// NativeEndian is either little endian or big endian, depending on the native
@ -46,6 +49,20 @@ func (nativeEndian) PutUint32(v uint32) []byte {
return buf
}
func (nativeEndian) PutUint64(v uint64) []byte {
buf := make([]byte, 8)
natend.NativeEndian.PutUint64(buf, v)
return buf
}
func (nativeEndian) Uint32(b []byte) uint32 {
return natend.NativeEndian.Uint32(b)
}
func (nativeEndian) Uint64(b []byte) uint64 {
return natend.NativeEndian.Uint64(b)
}
// BigEndian is like binary.BigEndian, but allocates memory and returns byte
// slices, for convenience.
var BigEndian ByteOrder = &bigEndian{}
@ -63,3 +80,17 @@ func (bigEndian) PutUint32(v uint32) []byte {
binary.BigEndian.PutUint32(buf, v)
return buf
}
func (bigEndian) PutUint64(v uint64) []byte {
buf := make([]byte, 8)
binary.BigEndian.PutUint64(buf, v)
return buf
}
func (bigEndian) Uint32(b []byte) uint32 {
return binary.BigEndian.Uint32(b)
}
func (bigEndian) Uint64(b []byte) uint64 {
return binary.BigEndian.Uint64(b)
}

View File

@ -16,6 +16,8 @@
package expr
import (
"fmt"
"github.com/google/nftables/binaryutil"
"github.com/mdlayher/netlink"
"golang.org/x/sys/unix"
@ -26,9 +28,15 @@ func Marshal(e Any) ([]byte, error) {
return e.marshal()
}
// Unmarshal fills an expression from the specified byte slice.
func Unmarshal(data []byte, e Any) error {
return e.unmarshal(data)
}
// Any is an interface implemented by any expression type.
type Any interface {
marshal() ([]byte, error)
unmarshal([]byte) error
}
// MetaKey specifies which piece of meta information should be loaded. See also
@ -87,6 +95,23 @@ func (e *Meta) marshal() ([]byte, error) {
})
}
func (e *Meta) unmarshal(data []byte) error {
attrs, err := netlink.UnmarshalAttributes(data)
if err != nil {
return err
}
for _, attr := range attrs {
switch attr.Type {
case unix.NFTA_META_DREG:
e.Register = binaryutil.BigEndian.Uint32(attr.Data)
case unix.NFTA_META_KEY:
e.Key = MetaKey(binaryutil.BigEndian.Uint32(attr.Data))
}
}
return nil
}
// Masq (Masquerade) is a special case of SNAT, where the source address is
// automagically set to the address of the output interface. See also
// https://wiki.nftables.org/wiki-nftables/index.php/Performing_Network_Address_Translation_(NAT)#Masquerading
@ -99,6 +124,10 @@ func (e *Masq) marshal() ([]byte, error) {
})
}
func (e *Masq) unmarshal(data []byte) error {
return fmt.Errorf("not yet implemented")
}
// CmpOp specifies which type of comparison should be performed.
type CmpOp uint32
@ -139,3 +168,22 @@ func (e *Cmp) marshal() ([]byte, error) {
{Type: unix.NLA_F_NESTED | unix.NFTA_EXPR_DATA, Data: exprData},
})
}
func (e *Cmp) unmarshal(data []byte) error {
attrs, err := netlink.UnmarshalAttributes(data)
if err != nil {
return err
}
for _, attr := range attrs {
switch attr.Type {
case unix.NFTA_CMP_SREG:
e.Register = binaryutil.BigEndian.Uint32(attr.Data)
case unix.NFTA_CMP_OP:
e.Op = CmpOp(binaryutil.BigEndian.Uint32(attr.Data))
case unix.NFTA_CMP_DATA:
e.Data = attr.Data
}
}
return nil
}

View File

@ -15,6 +15,8 @@
package expr
import (
"fmt"
"github.com/google/nftables/binaryutil"
"github.com/mdlayher/netlink"
"golang.org/x/sys/unix"
@ -45,3 +47,7 @@ func (e *Immediate) marshal() ([]byte, error) {
{Type: unix.NLA_F_NESTED | unix.NFTA_EXPR_DATA, Data: data},
})
}
func (e *Immediate) unmarshal(data []byte) error {
return fmt.Errorf("not yet implemented")
}

View File

@ -15,6 +15,8 @@
package expr
import (
"fmt"
"github.com/google/nftables/binaryutil"
"github.com/mdlayher/netlink"
"golang.org/x/sys/unix"
@ -68,3 +70,7 @@ func (e *NAT) marshal() ([]byte, error) {
{Type: unix.NLA_F_NESTED | unix.NFTA_EXPR_DATA, Data: data},
})
}
func (e *NAT) unmarshal(data []byte) error {
return fmt.Errorf("not yet implemented")
}

View File

@ -15,6 +15,8 @@
package expr
import (
"fmt"
"github.com/google/nftables/binaryutil"
"github.com/mdlayher/netlink"
"golang.org/x/sys/unix"
@ -51,3 +53,7 @@ func (e *Payload) marshal() ([]byte, error) {
{Type: unix.NLA_F_NESTED | unix.NFTA_EXPR_DATA, Data: data},
})
}
func (e *Payload) unmarshal(data []byte) error {
return fmt.Errorf("not yet implemented")
}

View File

@ -18,6 +18,7 @@ package nftables
import (
"fmt"
"math"
"strings"
"github.com/google/nftables/binaryutil"
"github.com/google/nftables/expr"
@ -175,6 +176,82 @@ type Rule struct {
Exprs []expr.Any
}
var ruleHeaderType = netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_NEWRULE)
func stringFrom0(b []byte) string {
return strings.TrimSuffix(string(b), "\x00")
}
func exprsFromMsg(b []byte) ([]expr.Any, error) {
elems, err := netlink.UnmarshalAttributes(b)
if err != nil {
return nil, err
}
var exprs []expr.Any
for _, elem := range elems {
attrs, err := netlink.UnmarshalAttributes(elem.Data)
if err != nil {
return nil, err
}
var (
name string
data []byte
)
for _, attr := range attrs {
switch attr.Type {
case unix.NFTA_EXPR_NAME:
name = stringFrom0(attr.Data)
case unix.NFTA_EXPR_DATA:
data = attr.Data
}
}
var e expr.Any
switch name {
case "meta":
e = &expr.Meta{}
case "cmp":
e = &expr.Cmp{}
case "counter":
e = &expr.Counter{}
}
if e == nil {
// TODO: introduce an opaque expression type so that users know
// something is here.
continue // unsupported expression type
}
if err := expr.Unmarshal(data, e); err != nil {
return nil, err
}
exprs = append(exprs, e)
}
return exprs, nil
}
func ruleFromMsg(msg netlink.Message) (*Rule, error) {
if got, want := msg.Header.Type, ruleHeaderType; got != want {
return nil, fmt.Errorf("unexpected header type: got %v, want %v", got, want)
}
attrs, err := netlink.UnmarshalAttributes(msg.Data[4:])
if err != nil {
return nil, err
}
var r Rule
for _, attr := range attrs {
switch attr.Type {
case unix.NFTA_RULE_TABLE:
r.Table = &Table{Name: stringFrom0(attr.Data)}
case unix.NFTA_RULE_CHAIN:
r.Chain = &Chain{Name: stringFrom0(attr.Data)}
case unix.NFTA_RULE_EXPRESSIONS:
r.Exprs, err = exprsFromMsg(attr.Data)
if err != nil {
return nil, err
}
}
}
return &r, nil
}
// AddRule adds the specified Rule. See also
// https://wiki.nftables.org/wiki-nftables/index.php/Simple_rule_management
func (cc *Conn) AddRule(r *Rule) *Rule {
@ -194,7 +271,7 @@ func (cc *Conn) AddRule(r *Rule) *Rule {
cc.messages = append(cc.messages, netlink.Message{
Header: netlink.Header{
Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_NEWRULE),
Type: ruleHeaderType,
Flags: netlink.HeaderFlagsRequest | netlink.HeaderFlagsAcknowledge | netlink.HeaderFlagsCreate,
},
Data: append(extraHeader(uint8(r.Table.Family), 0), data...),
@ -290,3 +367,54 @@ func (cc *Conn) Flush() error {
return nil
}
// GetRule returns the rules in the specified table and chain.
func (cc *Conn) GetRule(t *Table, c *Chain) ([]*Rule, error) {
var conn *netlink.Conn
var err error
if cc.TestDial == nil {
conn, err = netlink.Dial(unix.NETLINK_NETFILTER, nil)
} else {
conn = nltest.Dial(cc.TestDial)
}
if err != nil {
return nil, err
}
defer conn.Close()
data, err := netlink.MarshalAttributes([]netlink.Attribute{
{Type: unix.NFTA_RULE_TABLE, Data: []byte(t.Name + "\x00")},
{Type: unix.NFTA_RULE_CHAIN, Data: []byte(c.Name + "\x00")},
})
if err != nil {
return nil, err
}
message := netlink.Message{
Header: netlink.Header{
Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_GETRULE),
Flags: netlink.HeaderFlagsRequest | netlink.HeaderFlagsAcknowledge | netlink.HeaderFlagsDump,
},
Data: append(extraHeader(uint8(t.Family), 0), data...),
}
if _, err := conn.SendMessages([]netlink.Message{message}); err != nil {
return nil, fmt.Errorf("SendMessages: %v", err)
}
reply, err := conn.Receive()
if err != nil {
return nil, fmt.Errorf("Receive: %v", err)
}
var rules []*Rule
for _, msg := range reply {
r, err := ruleFromMsg(msg)
if err != nil {
return nil, err
}
rules = append(rules, r)
}
return rules, nil
}

View File

@ -251,3 +251,81 @@ func TestConfigureNAT(t *testing.T) {
t.Fatal(err)
}
}
func TestGetRule(t *testing.T) {
// The want byte sequences come from stracing nft(8), e.g.:
// strace -f -v -x -s 2048 -eraw=sendto nft list chain ip filter forward
want := [][]byte{
[]byte{0x2, 0x0, 0x0, 0x0, 0xb, 0x0, 0x1, 0x0, 0x66, 0x69, 0x6c, 0x74, 0x65, 0x72, 0x0, 0x0, 0xa, 0x0, 0x2, 0x0, 0x69, 0x6e, 0x70, 0x75, 0x74, 0x0, 0x0, 0x0},
}
// The reply messages come from adding log.Printf("msgs: %#v", msgs) to
// (*github.com/mdlayher/netlink/Conn).receive
reply := [][]netlink.Message{
nil,
[]netlink.Message{netlink.Message{Header: netlink.Header{Length: 0x68, Type: 0xa06, Flags: 0x802, Sequence: 0x9acb0443, PID: 0xba38ef3c}, Data: []uint8{0x2, 0x0, 0x0, 0xc, 0xb, 0x0, 0x1, 0x0, 0x66, 0x69, 0x6c, 0x74, 0x65, 0x72, 0x0, 0x0, 0xc, 0x0, 0x2, 0x0, 0x66, 0x6f, 0x72, 0x77, 0x61, 0x72, 0x64, 0x0, 0xc, 0x0, 0x3, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x30, 0x0, 0x4, 0x0, 0x2c, 0x0, 0x1, 0x0, 0xc, 0x0, 0x1, 0x0, 0x63, 0x6f, 0x75, 0x6e, 0x74, 0x65, 0x72, 0x0, 0x1c, 0x0, 0x2, 0x0, 0xc, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x6d, 0x92, 0x20, 0x20, 0xc, 0x0, 0x2, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0xa, 0x48, 0xd9}}},
[]netlink.Message{netlink.Message{Header: netlink.Header{Length: 0x14, Type: 0x3, Flags: 0x2, Sequence: 0x9acb0443, PID: 0xba38ef3c}, Data: []uint8{0x0, 0x0, 0x0, 0x0}}},
}
c := &nftables.Conn{
TestDial: func(req []netlink.Message) ([]netlink.Message, error) {
for idx, msg := range req {
b, err := msg.MarshalBinary()
if err != nil {
t.Fatal(err)
}
if len(b) < 16 {
continue
}
b = b[16:]
if len(want) == 0 {
t.Errorf("no want entry for message %d: %x", idx, b)
continue
}
if got, want := b, want[0]; !bytes.Equal(got, want) {
t.Errorf("message %d: got %#v, want %#v", idx, got, want)
}
want = want[1:]
}
rep := reply[0]
reply = reply[1:]
return rep, nil
},
}
rules, err := c.GetRule(
&nftables.Table{
Family: nftables.TableFamilyIPv4,
Name: "filter",
},
&nftables.Chain{
Name: "input",
},
)
if err != nil {
t.Fatal(err)
}
if got, want := len(rules), 1; got != want {
t.Fatalf("unexpected number of rules: got %d, want %d", got, want)
}
rule := rules[0]
if got, want := len(rule.Exprs), 1; got != want {
t.Fatalf("unexpected number of exprs: got %d, want %d", got, want)
}
ce, ok := rule.Exprs[0].(*expr.Counter)
if !ok {
t.Fatalf("unexpected expression type: got %T, want *expr.Counter", rule.Exprs[0])
}
if got, want := ce.Packets, uint64(674009); got != want {
t.Errorf("unexpected number of packets: got %d, want %d", got, want)
}
if got, want := ce.Bytes, uint64(1838293024); got != want {
t.Errorf("unexpected number of bytes: got %d, want %d", got, want)
}
}