diff --git a/dns.go b/dns.go index 4822012..ff3a109 100644 --- a/dns.go +++ b/dns.go @@ -3,6 +3,9 @@ package main import ( "github.com/miekg/dns" "fmt" + "time" + "log" + "net" ) const TSIG_FUDGE_SECONDS = 300 @@ -24,3 +27,51 @@ func (t *TSIGAlgorithm) UnmarshalFlag(value string) error { return nil } + +func query(query *dns.Msg) (*dns.Msg, error) { + clientConfig, err := dns.ClientConfigFromFile("/etc/resolv.conf") + if err != nil { + return nil, err + } + + timeout := time.Duration(clientConfig.Timeout * int(time.Second)) + + var client = new(dns.Client) + + client.DialTimeout = timeout + client.ReadTimeout = timeout + client.WriteTimeout = timeout + + for _, server := range clientConfig.Servers { + addr := net.JoinHostPort(server, "53") + + if answer, _, err := client.Exchange(query, addr); err != nil { + log.Printf("query %v: %v", server, err) + continue + } else { + return answer, nil + } + } + + return nil, fmt.Errorf("DNS query failed") +} + +// Discover likely master NS for zone +func discoverZoneServer(zone string) (string, error) { + var q = new(dns.Msg) + + q.SetQuestion(zone, dns.TypeSOA) + + r, err := query(q) + if err != nil { + return "", err + } + + for _, rr := range r.Answer { + if soa, ok := rr.(*dns.SOA); ok { + return soa.Ns, nil + } + } + + return "", fmt.Errorf("No SOA response") +} diff --git a/main.go b/main.go index 20ee05b..f502aec 100644 --- a/main.go +++ b/main.go @@ -62,5 +62,7 @@ func main() { // update if err := update.Update(addrs, options.Verbose); err != nil { log.Fatalf("update: %v", err) + } else { + log.Printf("update: ok") } } diff --git a/update.go b/update.go index edb84a9..778423f 100644 --- a/update.go +++ b/update.go @@ -20,13 +20,37 @@ type Update struct { } func (u *Update) Init(name string, zone string, server string) error { - u.name = dns.Fqdn(name) - u.zone = dns.Fqdn(zone) - - if _, _, err := net.SplitHostPort(server); err == nil { - u.server = server + if name == "" { + return fmt.Errorf("Missing name") } else { - u.server = net.JoinHostPort(server, "53") + u.name = dns.Fqdn(name) + } + + if zone == "" { + // guess + if labels := dns.Split(u.name); len(labels) > 1 { + u.zone = u.name[labels[1]:] + } else { + return fmt.Errorf("Missing zone") + } + } else { + u.zone = dns.Fqdn(zone) + } + + if server == "" { + if server, err := discoverZoneServer(u.zone); err != nil { + return fmt.Errorf("Failed to discver server") + } else { + log.Printf("discover server=%v", server) + + u.server = net.JoinHostPort(server, "53") + } + } else { + if _, _, err := net.SplitHostPort(server); err == nil { + u.server = server + } else { + u.server = net.JoinHostPort(server, "53") + } } return nil