diff --git a/clone.go b/clone.go index 4acf170..958e65d 100644 --- a/clone.go +++ b/clone.go @@ -56,11 +56,7 @@ func populateCloneOptions(ptr *C.git_clone_options, opts *CloneOptions) { } populateCheckoutOpts(&ptr.checkout_opts, opts.CheckoutOpts) populateRemoteCallbacks(&ptr.remote_callbacks, opts.RemoteCallbacks) - if opts.Bare { - ptr.bare = 1 - } else { - ptr.bare = 0 - } + ptr.bare = cbool(opts.Bare) if opts.RemoteCreateCallback != nil { ptr.remote_cb = opts.RemoteCreateCallback diff --git a/remote.go b/remote.go index 74ebe27..1ff9092 100644 --- a/remote.go +++ b/remote.go @@ -2,14 +2,16 @@ package git /* #include -#include extern void _go_git_setup_callbacks(git_remote_callbacks *callbacks); */ import "C" -import "unsafe" -import "runtime" +import ( + "unsafe" + "runtime" + "crypto/x509" +) type TransferProgress struct { TotalObjects uint @@ -43,6 +45,7 @@ type CompletionCallback func(RemoteCompletion) int type CredentialsCallback func(url string, username_from_url string, allowed_types CredType) (int, *Cred) type TransferProgressCallback func(stats TransferProgress) int type UpdateTipsCallback func(refname string, a *Oid, b *Oid) int +type CertificateCheckCallback func(cert *x509.Certificate, valid bool, hostname string) int type RemoteCallbacks struct { SidebandProgressCallback TransportMessageCallback @@ -50,10 +53,12 @@ type RemoteCallbacks struct { CredentialsCallback TransferProgressCallback UpdateTipsCallback + CertificateCheckCallback } type Remote struct { ptr *C.git_remote + callbacks RemoteCallbacks } func populateRemoteCallbacks(ptr *C.git_remote_callbacks, callbacks *RemoteCallbacks) { @@ -118,6 +123,32 @@ func updateTipsCallback(_refname *C.char, _a *C.git_oid, _b *C.git_oid, data uns return callbacks.UpdateTipsCallback(refname, a, b) } +//export certificateCheckCallback +func certificateCheckCallback(_cert *C.git_cert, _valid C.int, _host *C.char, data unsafe.Pointer) int { + callbacks := (*RemoteCallbacks)(data) + if callbacks.CertificateCheckCallback == nil { + return 0 + } + host := C.GoString(_host) + valid := _valid != 0 + + if _cert.cert_type == C.GIT_CERT_X509 { + ccert := (*C.git_cert_x509)(unsafe.Pointer(_cert)) + x509_certs, err := x509.ParseCertificates(C.GoBytes(ccert.data, C.int(ccert.len))) + if err != nil { + return C.GIT_EUSER; + } + + // we assume there's only one, which should hold true for any web server we want to talk to + return callbacks.CertificateCheckCallback(x509_certs[0], valid, host) + } + + cstr := C.CString("Unsupported certificate type") + C.giterr_set_str(C.GITERR_NET, cstr) + C.free(unsafe.Pointer(cstr)) + return ErrUser // we don't support anything else atm +} + func RemoteIsValidName(name string) bool { cname := C.CString(name) defer C.free(unsafe.Pointer(cname)) @@ -127,14 +158,11 @@ func RemoteIsValidName(name string) bool { return false } -func (r *Remote) SetCheckCert(check bool) { - C.git_remote_check_cert(r.ptr, cbool(check)) -} - func (r *Remote) SetCallbacks(callbacks *RemoteCallbacks) error { - var ccallbacks C.git_remote_callbacks + r.callbacks = *callbacks - populateRemoteCallbacks(&ccallbacks, callbacks) + var ccallbacks C.git_remote_callbacks + populateRemoteCallbacks(&ccallbacks, &r.callbacks) runtime.LockOSThread() defer runtime.UnlockOSThread() @@ -433,7 +461,11 @@ func (o *Remote) RefspecCount() uint { return uint(C.git_remote_refspec_count(o.ptr)) } -func (o *Remote) Fetch(sig *Signature, msg string) error { +// Fetch performs a fetch operation. refspecs specifies which refspecs +// to use for this fetch, use an empty list to use the refspecs from +// the configuration; sig and msg specify what to use for the reflog +// entries. Leave nil and "" to use defaults. +func (o *Remote) Fetch(refspecs []string, sig *Signature, msg string) error { var csig *C.git_signature = nil if sig != nil { @@ -441,14 +473,18 @@ func (o *Remote) Fetch(sig *Signature, msg string) error { defer C.free(unsafe.Pointer(csig)) } - var cmsg *C.char - if msg == "" { - cmsg = nil - } else { + var cmsg *C.char = nil + if msg != "" { cmsg = C.CString(msg) defer C.free(unsafe.Pointer(cmsg)) } - ret := C.git_remote_fetch(o.ptr, csig, cmsg) + + crefspecs := C.git_strarray{} + crefspecs.count = C.size_t(len(refspecs)) + crefspecs.strings = makeCStringsFromStrings(refspecs) + defer freeStrarray(&crefspecs) + + ret := C.git_remote_fetch(o.ptr, &crefspecs, csig, cmsg) if ret < 0 { return MakeGitError(ret) } diff --git a/remote_test.go b/remote_test.go index 7cef1ec..8021f71 100644 --- a/remote_test.go +++ b/remote_test.go @@ -1,6 +1,8 @@ package git import ( + "fmt" + "crypto/x509" "os" "testing" ) @@ -45,3 +47,35 @@ func TestListRemotes(t *testing.T) { compareStringList(t, expected, actual) } + + +func assertHostname(cert *x509.Certificate, valid bool, hostname string, t *testing.T) int { + fmt.Println("hostname", hostname) + if hostname != "github.com" { + t.Fatal("Hostname does not match") + return ErrUser + } + + return 0 +} + +func TestCertificateCheck(t *testing.T) { + repo := createTestRepo(t) + defer os.RemoveAll(repo.Workdir()) + defer repo.Free() + + remote, err := repo.CreateRemote("origin", "https://github.com/libgit2/TestGitRepository") + checkFatal(t, err) + + callbacks := RemoteCallbacks{ + CertificateCheckCallback: func (cert *x509.Certificate, valid bool, hostname string) int { + return assertHostname(cert, valid, hostname, t) + }, + } + + err = remote.SetCallbacks(&callbacks) + checkFatal(t, err) + err = remote.Fetch([]string{}, nil, "") + checkFatal(t, err) + fmt.Println("after Fetch()") +} diff --git a/vendor/libgit2 b/vendor/libgit2 index 89e05e2..e0383fa 160000 --- a/vendor/libgit2 +++ b/vendor/libgit2 @@ -1 +1 @@ -Subproject commit 89e05e2ab19ac452e84e0eaa2dfb8e07ac6839bf +Subproject commit e0383fa35f981c656043976a43c61bff059cb709 diff --git a/wrapper.c b/wrapper.c index 45c4358..15e11ce 100644 --- a/wrapper.c +++ b/wrapper.c @@ -70,14 +70,13 @@ void _go_git_setup_diff_notify_callbacks(git_diff_options *opts) { void _go_git_setup_callbacks(git_remote_callbacks *callbacks) { typedef int (*completion_cb)(git_remote_completion_type type, void *data); - typedef int (*credentials_cb)(git_cred **cred, const char *url, const char *username_from_url, unsigned int allowed_types, void *data); - typedef int (*transfer_progress_cb)(const git_transfer_progress *stats, void *data); typedef int (*update_tips_cb)(const char *refname, const git_oid *a, const git_oid *b, void *data); callbacks->sideband_progress = (git_transport_message_cb)sidebandProgressCallback; callbacks->completion = (completion_cb)completionCallback; - callbacks->credentials = (credentials_cb)credentialsCallback; - callbacks->transfer_progress = (transfer_progress_cb)transferProgressCallback; + callbacks->credentials = (git_cred_acquire_cb)credentialsCallback; + callbacks->transfer_progress = (git_transfer_progress_cb)transferProgressCallback; callbacks->update_tips = (update_tips_cb)updateTipsCallback; + callbacks->certificate_check = (git_transport_certificate_check_cb) certificateCheckCallback; } typedef int (*status_foreach_cb)(const char *ref, const char *msg, void *data);