add GetRule
This commit is contained in:
parent
0dd2e15e25
commit
1324f5d5a9
|
@ -26,6 +26,9 @@ import (
|
|||
type ByteOrder interface {
|
||||
PutUint16(v uint16) []byte
|
||||
PutUint32(v uint32) []byte
|
||||
PutUint64(v uint64) []byte
|
||||
Uint32(b []byte) uint32
|
||||
Uint64(b []byte) uint64
|
||||
}
|
||||
|
||||
// NativeEndian is either little endian or big endian, depending on the native
|
||||
|
@ -46,6 +49,20 @@ func (nativeEndian) PutUint32(v uint32) []byte {
|
|||
return buf
|
||||
}
|
||||
|
||||
func (nativeEndian) PutUint64(v uint64) []byte {
|
||||
buf := make([]byte, 8)
|
||||
natend.NativeEndian.PutUint64(buf, v)
|
||||
return buf
|
||||
}
|
||||
|
||||
func (nativeEndian) Uint32(b []byte) uint32 {
|
||||
return natend.NativeEndian.Uint32(b)
|
||||
}
|
||||
|
||||
func (nativeEndian) Uint64(b []byte) uint64 {
|
||||
return natend.NativeEndian.Uint64(b)
|
||||
}
|
||||
|
||||
// BigEndian is like binary.BigEndian, but allocates memory and returns byte
|
||||
// slices, for convenience.
|
||||
var BigEndian ByteOrder = &bigEndian{}
|
||||
|
@ -63,3 +80,17 @@ func (bigEndian) PutUint32(v uint32) []byte {
|
|||
binary.BigEndian.PutUint32(buf, v)
|
||||
return buf
|
||||
}
|
||||
|
||||
func (bigEndian) PutUint64(v uint64) []byte {
|
||||
buf := make([]byte, 8)
|
||||
binary.BigEndian.PutUint64(buf, v)
|
||||
return buf
|
||||
}
|
||||
|
||||
func (bigEndian) Uint32(b []byte) uint32 {
|
||||
return binary.BigEndian.Uint32(b)
|
||||
}
|
||||
|
||||
func (bigEndian) Uint64(b []byte) uint64 {
|
||||
return binary.BigEndian.Uint64(b)
|
||||
}
|
||||
|
|
48
expr/expr.go
48
expr/expr.go
|
@ -16,6 +16,8 @@
|
|||
package expr
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/google/nftables/binaryutil"
|
||||
"github.com/mdlayher/netlink"
|
||||
"golang.org/x/sys/unix"
|
||||
|
@ -26,9 +28,15 @@ func Marshal(e Any) ([]byte, error) {
|
|||
return e.marshal()
|
||||
}
|
||||
|
||||
// Unmarshal fills an expression from the specified byte slice.
|
||||
func Unmarshal(data []byte, e Any) error {
|
||||
return e.unmarshal(data)
|
||||
}
|
||||
|
||||
// Any is an interface implemented by any expression type.
|
||||
type Any interface {
|
||||
marshal() ([]byte, error)
|
||||
unmarshal([]byte) error
|
||||
}
|
||||
|
||||
// MetaKey specifies which piece of meta information should be loaded. See also
|
||||
|
@ -87,6 +95,23 @@ func (e *Meta) marshal() ([]byte, error) {
|
|||
})
|
||||
}
|
||||
|
||||
func (e *Meta) unmarshal(data []byte) error {
|
||||
attrs, err := netlink.UnmarshalAttributes(data)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, attr := range attrs {
|
||||
switch attr.Type {
|
||||
case unix.NFTA_META_DREG:
|
||||
e.Register = binaryutil.BigEndian.Uint32(attr.Data)
|
||||
case unix.NFTA_META_KEY:
|
||||
e.Key = MetaKey(binaryutil.BigEndian.Uint32(attr.Data))
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Masq (Masquerade) is a special case of SNAT, where the source address is
|
||||
// automagically set to the address of the output interface. See also
|
||||
// https://wiki.nftables.org/wiki-nftables/index.php/Performing_Network_Address_Translation_(NAT)#Masquerading
|
||||
|
@ -99,6 +124,10 @@ func (e *Masq) marshal() ([]byte, error) {
|
|||
})
|
||||
}
|
||||
|
||||
func (e *Masq) unmarshal(data []byte) error {
|
||||
return fmt.Errorf("not yet implemented")
|
||||
}
|
||||
|
||||
// CmpOp specifies which type of comparison should be performed.
|
||||
type CmpOp uint32
|
||||
|
||||
|
@ -139,3 +168,22 @@ func (e *Cmp) marshal() ([]byte, error) {
|
|||
{Type: unix.NLA_F_NESTED | unix.NFTA_EXPR_DATA, Data: exprData},
|
||||
})
|
||||
}
|
||||
|
||||
func (e *Cmp) unmarshal(data []byte) error {
|
||||
attrs, err := netlink.UnmarshalAttributes(data)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, attr := range attrs {
|
||||
switch attr.Type {
|
||||
case unix.NFTA_CMP_SREG:
|
||||
e.Register = binaryutil.BigEndian.Uint32(attr.Data)
|
||||
case unix.NFTA_CMP_OP:
|
||||
e.Op = CmpOp(binaryutil.BigEndian.Uint32(attr.Data))
|
||||
case unix.NFTA_CMP_DATA:
|
||||
e.Data = attr.Data
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -15,6 +15,8 @@
|
|||
package expr
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/google/nftables/binaryutil"
|
||||
"github.com/mdlayher/netlink"
|
||||
"golang.org/x/sys/unix"
|
||||
|
@ -45,3 +47,7 @@ func (e *Immediate) marshal() ([]byte, error) {
|
|||
{Type: unix.NLA_F_NESTED | unix.NFTA_EXPR_DATA, Data: data},
|
||||
})
|
||||
}
|
||||
|
||||
func (e *Immediate) unmarshal(data []byte) error {
|
||||
return fmt.Errorf("not yet implemented")
|
||||
}
|
||||
|
|
|
@ -15,6 +15,8 @@
|
|||
package expr
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/google/nftables/binaryutil"
|
||||
"github.com/mdlayher/netlink"
|
||||
"golang.org/x/sys/unix"
|
||||
|
@ -68,3 +70,7 @@ func (e *NAT) marshal() ([]byte, error) {
|
|||
{Type: unix.NLA_F_NESTED | unix.NFTA_EXPR_DATA, Data: data},
|
||||
})
|
||||
}
|
||||
|
||||
func (e *NAT) unmarshal(data []byte) error {
|
||||
return fmt.Errorf("not yet implemented")
|
||||
}
|
||||
|
|
|
@ -15,6 +15,8 @@
|
|||
package expr
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/google/nftables/binaryutil"
|
||||
"github.com/mdlayher/netlink"
|
||||
"golang.org/x/sys/unix"
|
||||
|
@ -51,3 +53,7 @@ func (e *Payload) marshal() ([]byte, error) {
|
|||
{Type: unix.NLA_F_NESTED | unix.NFTA_EXPR_DATA, Data: data},
|
||||
})
|
||||
}
|
||||
|
||||
func (e *Payload) unmarshal(data []byte) error {
|
||||
return fmt.Errorf("not yet implemented")
|
||||
}
|
||||
|
|
130
nftables.go
130
nftables.go
|
@ -18,6 +18,7 @@ package nftables
|
|||
import (
|
||||
"fmt"
|
||||
"math"
|
||||
"strings"
|
||||
|
||||
"github.com/google/nftables/binaryutil"
|
||||
"github.com/google/nftables/expr"
|
||||
|
@ -175,6 +176,82 @@ type Rule struct {
|
|||
Exprs []expr.Any
|
||||
}
|
||||
|
||||
var ruleHeaderType = netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_NEWRULE)
|
||||
|
||||
func stringFrom0(b []byte) string {
|
||||
return strings.TrimSuffix(string(b), "\x00")
|
||||
}
|
||||
|
||||
func exprsFromMsg(b []byte) ([]expr.Any, error) {
|
||||
elems, err := netlink.UnmarshalAttributes(b)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var exprs []expr.Any
|
||||
for _, elem := range elems {
|
||||
attrs, err := netlink.UnmarshalAttributes(elem.Data)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var (
|
||||
name string
|
||||
data []byte
|
||||
)
|
||||
for _, attr := range attrs {
|
||||
switch attr.Type {
|
||||
case unix.NFTA_EXPR_NAME:
|
||||
name = stringFrom0(attr.Data)
|
||||
case unix.NFTA_EXPR_DATA:
|
||||
data = attr.Data
|
||||
}
|
||||
}
|
||||
var e expr.Any
|
||||
switch name {
|
||||
case "meta":
|
||||
e = &expr.Meta{}
|
||||
case "cmp":
|
||||
e = &expr.Cmp{}
|
||||
case "counter":
|
||||
e = &expr.Counter{}
|
||||
}
|
||||
if e == nil {
|
||||
// TODO: introduce an opaque expression type so that users know
|
||||
// something is here.
|
||||
continue // unsupported expression type
|
||||
}
|
||||
if err := expr.Unmarshal(data, e); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
exprs = append(exprs, e)
|
||||
}
|
||||
return exprs, nil
|
||||
}
|
||||
|
||||
func ruleFromMsg(msg netlink.Message) (*Rule, error) {
|
||||
if got, want := msg.Header.Type, ruleHeaderType; got != want {
|
||||
return nil, fmt.Errorf("unexpected header type: got %v, want %v", got, want)
|
||||
}
|
||||
attrs, err := netlink.UnmarshalAttributes(msg.Data[4:])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var r Rule
|
||||
for _, attr := range attrs {
|
||||
switch attr.Type {
|
||||
case unix.NFTA_RULE_TABLE:
|
||||
r.Table = &Table{Name: stringFrom0(attr.Data)}
|
||||
case unix.NFTA_RULE_CHAIN:
|
||||
r.Chain = &Chain{Name: stringFrom0(attr.Data)}
|
||||
case unix.NFTA_RULE_EXPRESSIONS:
|
||||
r.Exprs, err = exprsFromMsg(attr.Data)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
}
|
||||
return &r, nil
|
||||
}
|
||||
|
||||
// AddRule adds the specified Rule. See also
|
||||
// https://wiki.nftables.org/wiki-nftables/index.php/Simple_rule_management
|
||||
func (cc *Conn) AddRule(r *Rule) *Rule {
|
||||
|
@ -194,7 +271,7 @@ func (cc *Conn) AddRule(r *Rule) *Rule {
|
|||
|
||||
cc.messages = append(cc.messages, netlink.Message{
|
||||
Header: netlink.Header{
|
||||
Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_NEWRULE),
|
||||
Type: ruleHeaderType,
|
||||
Flags: netlink.HeaderFlagsRequest | netlink.HeaderFlagsAcknowledge | netlink.HeaderFlagsCreate,
|
||||
},
|
||||
Data: append(extraHeader(uint8(r.Table.Family), 0), data...),
|
||||
|
@ -290,3 +367,54 @@ func (cc *Conn) Flush() error {
|
|||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetRule returns the rules in the specified table and chain.
|
||||
func (cc *Conn) GetRule(t *Table, c *Chain) ([]*Rule, error) {
|
||||
var conn *netlink.Conn
|
||||
var err error
|
||||
if cc.TestDial == nil {
|
||||
conn, err = netlink.Dial(unix.NETLINK_NETFILTER, nil)
|
||||
} else {
|
||||
conn = nltest.Dial(cc.TestDial)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
defer conn.Close()
|
||||
|
||||
data, err := netlink.MarshalAttributes([]netlink.Attribute{
|
||||
{Type: unix.NFTA_RULE_TABLE, Data: []byte(t.Name + "\x00")},
|
||||
{Type: unix.NFTA_RULE_CHAIN, Data: []byte(c.Name + "\x00")},
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
message := netlink.Message{
|
||||
Header: netlink.Header{
|
||||
Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_GETRULE),
|
||||
Flags: netlink.HeaderFlagsRequest | netlink.HeaderFlagsAcknowledge | netlink.HeaderFlagsDump,
|
||||
},
|
||||
Data: append(extraHeader(uint8(t.Family), 0), data...),
|
||||
}
|
||||
|
||||
if _, err := conn.SendMessages([]netlink.Message{message}); err != nil {
|
||||
return nil, fmt.Errorf("SendMessages: %v", err)
|
||||
}
|
||||
|
||||
reply, err := conn.Receive()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Receive: %v", err)
|
||||
}
|
||||
var rules []*Rule
|
||||
for _, msg := range reply {
|
||||
r, err := ruleFromMsg(msg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
rules = append(rules, r)
|
||||
}
|
||||
|
||||
return rules, nil
|
||||
}
|
||||
|
|
|
@ -251,3 +251,81 @@ func TestConfigureNAT(t *testing.T) {
|
|||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetRule(t *testing.T) {
|
||||
// The want byte sequences come from stracing nft(8), e.g.:
|
||||
// strace -f -v -x -s 2048 -eraw=sendto nft list chain ip filter forward
|
||||
|
||||
want := [][]byte{
|
||||
[]byte{0x2, 0x0, 0x0, 0x0, 0xb, 0x0, 0x1, 0x0, 0x66, 0x69, 0x6c, 0x74, 0x65, 0x72, 0x0, 0x0, 0xa, 0x0, 0x2, 0x0, 0x69, 0x6e, 0x70, 0x75, 0x74, 0x0, 0x0, 0x0},
|
||||
}
|
||||
|
||||
// The reply messages come from adding log.Printf("msgs: %#v", msgs) to
|
||||
// (*github.com/mdlayher/netlink/Conn).receive
|
||||
reply := [][]netlink.Message{
|
||||
nil,
|
||||
[]netlink.Message{netlink.Message{Header: netlink.Header{Length: 0x68, Type: 0xa06, Flags: 0x802, Sequence: 0x9acb0443, PID: 0xba38ef3c}, Data: []uint8{0x2, 0x0, 0x0, 0xc, 0xb, 0x0, 0x1, 0x0, 0x66, 0x69, 0x6c, 0x74, 0x65, 0x72, 0x0, 0x0, 0xc, 0x0, 0x2, 0x0, 0x66, 0x6f, 0x72, 0x77, 0x61, 0x72, 0x64, 0x0, 0xc, 0x0, 0x3, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x30, 0x0, 0x4, 0x0, 0x2c, 0x0, 0x1, 0x0, 0xc, 0x0, 0x1, 0x0, 0x63, 0x6f, 0x75, 0x6e, 0x74, 0x65, 0x72, 0x0, 0x1c, 0x0, 0x2, 0x0, 0xc, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x6d, 0x92, 0x20, 0x20, 0xc, 0x0, 0x2, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0xa, 0x48, 0xd9}}},
|
||||
[]netlink.Message{netlink.Message{Header: netlink.Header{Length: 0x14, Type: 0x3, Flags: 0x2, Sequence: 0x9acb0443, PID: 0xba38ef3c}, Data: []uint8{0x0, 0x0, 0x0, 0x0}}},
|
||||
}
|
||||
|
||||
c := &nftables.Conn{
|
||||
TestDial: func(req []netlink.Message) ([]netlink.Message, error) {
|
||||
for idx, msg := range req {
|
||||
b, err := msg.MarshalBinary()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(b) < 16 {
|
||||
continue
|
||||
}
|
||||
b = b[16:]
|
||||
if len(want) == 0 {
|
||||
t.Errorf("no want entry for message %d: %x", idx, b)
|
||||
continue
|
||||
}
|
||||
if got, want := b, want[0]; !bytes.Equal(got, want) {
|
||||
t.Errorf("message %d: got %#v, want %#v", idx, got, want)
|
||||
}
|
||||
want = want[1:]
|
||||
}
|
||||
rep := reply[0]
|
||||
reply = reply[1:]
|
||||
return rep, nil
|
||||
},
|
||||
}
|
||||
|
||||
rules, err := c.GetRule(
|
||||
&nftables.Table{
|
||||
Family: nftables.TableFamilyIPv4,
|
||||
Name: "filter",
|
||||
},
|
||||
&nftables.Chain{
|
||||
Name: "input",
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if got, want := len(rules), 1; got != want {
|
||||
t.Fatalf("unexpected number of rules: got %d, want %d", got, want)
|
||||
}
|
||||
|
||||
rule := rules[0]
|
||||
if got, want := len(rule.Exprs), 1; got != want {
|
||||
t.Fatalf("unexpected number of exprs: got %d, want %d", got, want)
|
||||
}
|
||||
|
||||
ce, ok := rule.Exprs[0].(*expr.Counter)
|
||||
if !ok {
|
||||
t.Fatalf("unexpected expression type: got %T, want *expr.Counter", rule.Exprs[0])
|
||||
}
|
||||
|
||||
if got, want := ce.Packets, uint64(674009); got != want {
|
||||
t.Errorf("unexpected number of packets: got %d, want %d", got, want)
|
||||
}
|
||||
|
||||
if got, want := ce.Bytes, uint64(1838293024); got != want {
|
||||
t.Errorf("unexpected number of bytes: got %d, want %d", got, want)
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue