From 5def02a589a2c1653f4bb515fdec290361a222be Mon Sep 17 00:00:00 2001 From: lhchavez Date: Tue, 1 Dec 2020 19:11:41 -0800 Subject: [PATCH] The big Callback type adjustment of 2020 This change makes all callbacks that can fail return an `error`. This makes things a lot more idiomatic. --- checkout.go | 11 ++++----- clone.go | 10 ++++---- clone_test.go | 10 ++------ deprecated.go | 16 +++++++++---- index.go | 9 ++++---- index_test.go | 4 ++-- indexer_test.go | 4 ++-- odb_test.go | 4 ++-- remote.go | 59 ++++++++++++++++++++--------------------------- remote_test.go | 17 +++++++------- submodule.go | 9 ++++---- submodule_test.go | 4 ++-- tree.go | 9 ++++---- 13 files changed, 77 insertions(+), 89 deletions(-) diff --git a/checkout.go b/checkout.go index ebf7c31..89841a8 100644 --- a/checkout.go +++ b/checkout.go @@ -7,7 +7,6 @@ extern void _go_git_populate_checkout_callbacks(git_checkout_options *opts); */ import "C" import ( - "errors" "os" "runtime" "unsafe" @@ -49,8 +48,8 @@ const ( CheckoutUpdateSubmodulesIfChanged CheckoutStrategy = C.GIT_CHECKOUT_UPDATE_SUBMODULES_IF_CHANGED // Recursively checkout submodules if HEAD moved in super repo (NOT IMPLEMENTED) ) -type CheckoutNotifyCallback func(why CheckoutNotifyType, path string, baseline, target, workdir DiffFile) ErrorCode -type CheckoutProgressCallback func(path string, completed, total uint) ErrorCode +type CheckoutNotifyCallback func(why CheckoutNotifyType, path string, baseline, target, workdir DiffFile) error +type CheckoutProgressCallback func(path string, completed, total uint) type CheckoutOptions struct { Strategy CheckoutStrategy // Default will be a dry run @@ -116,9 +115,9 @@ func checkoutNotifyCallback( if data.options.NotifyCallback == nil { return C.int(ErrorCodeOK) } - ret := data.options.NotifyCallback(CheckoutNotifyType(why), path, baseline, target, workdir) - if ret < 0 { - *data.errorTarget = errors.New(ErrorCode(ret).String()) + err := data.options.NotifyCallback(CheckoutNotifyType(why), path, baseline, target, workdir) + if err != nil { + *data.errorTarget = err return C.int(ErrorCodeUser) } return C.int(ErrorCodeOK) diff --git a/clone.go b/clone.go index b02a43e..276e753 100644 --- a/clone.go +++ b/clone.go @@ -7,12 +7,11 @@ extern void _go_git_populate_clone_callbacks(git_clone_options *opts); */ import "C" import ( - "errors" "runtime" "unsafe" ) -type RemoteCreateCallback func(repo *Repository, name, url string) (*Remote, ErrorCode) +type RemoteCreateCallback func(repo *Repository, name, url string) (*Remote, error) type CloneOptions struct { *CheckoutOpts @@ -71,9 +70,10 @@ func remoteCreateCallback( panic("invalid remote create callback") } - remote, ret := data.options.RemoteCreateCallback(repo, name, url) - if ret < 0 { - *data.errorTarget = errors.New(ErrorCode(ret).String()) + remote, err := data.options.RemoteCreateCallback(repo, name, url) + + if err != nil { + *data.errorTarget = err return C.int(ErrorCodeUser) } if remote == nil { diff --git a/clone_test.go b/clone_test.go index acfbbcb..8814dd0 100644 --- a/clone_test.go +++ b/clone_test.go @@ -49,15 +49,9 @@ func TestCloneWithCallback(t *testing.T) { opts := CloneOptions{ Bare: true, - RemoteCreateCallback: func(r *Repository, name, url string) (*Remote, ErrorCode) { + RemoteCreateCallback: func(r *Repository, name, url string) (*Remote, error) { testPayload += 1 - - remote, err := r.Remotes.Create(REMOTENAME, url) - if err != nil { - return nil, ErrorCodeGeneric - } - - return remote, ErrorCodeOK + return r.Remotes.Create(REMOTENAME, url) }, } diff --git a/deprecated.go b/deprecated.go index 587fd0e..5e69a51 100644 --- a/deprecated.go +++ b/deprecated.go @@ -230,7 +230,11 @@ func SubmoduleVisitor(csub unsafe.Pointer, name *C.char, handle unsafe.Pointer) if !ok { panic("invalid submodule visitor callback") } - return (C.int)(callback(sub, C.GoString(name))) + err := callback(sub, C.GoString(name)) + if err != nil { + return C.int(ErrorCodeUser) + } + return C.int(ErrorCodeOK) } // tree.go @@ -239,9 +243,13 @@ func SubmoduleVisitor(csub unsafe.Pointer, name *C.char, handle unsafe.Pointer) func CallbackGitTreeWalk(_root *C.char, entry *C.git_tree_entry, ptr unsafe.Pointer) C.int { root := C.GoString(_root) - if callback, ok := pointerHandles.Get(ptr).(TreeWalkCallback); ok { - return C.int(callback(root, newTreeEntry(entry))) - } else { + callback, ok := pointerHandles.Get(ptr).(TreeWalkCallback) + if !ok { panic("invalid treewalk callback") } + err := callback(root, newTreeEntry(entry)) + if err != nil { + return C.int(ErrorCodeUser) + } + return C.int(ErrorCodeOK) } diff --git a/index.go b/index.go index 48c922c..dcb2780 100644 --- a/index.go +++ b/index.go @@ -10,13 +10,12 @@ extern int _go_git_index_remove_all(git_index*, const git_strarray*, void*); */ import "C" import ( - "errors" "fmt" "runtime" "unsafe" ) -type IndexMatchedPathCallback func(string, string) int +type IndexMatchedPathCallback func(string, string) error type indexMatchedPathCallbackData struct { callback IndexMatchedPathCallback errorTarget *error @@ -343,9 +342,9 @@ func indexMatchedPathCallback(cPath, cMatchedPathspec *C.char, payload unsafe.Po panic("invalid matched path callback") } - ret := data.callback(C.GoString(cPath), C.GoString(cMatchedPathspec)) - if ret < 0 { - *data.errorTarget = errors.New(ErrorCode(ret).String()) + err := data.callback(C.GoString(cPath), C.GoString(cMatchedPathspec)) + if err != nil { + *data.errorTarget = err return C.int(ErrorCodeUser) } diff --git a/index_test.go b/index_test.go index 5fa3f9f..aea5c19 100644 --- a/index_test.go +++ b/index_test.go @@ -223,9 +223,9 @@ func TestIndexAddAllCallback(t *testing.T) { checkFatal(t, err) cbPath := "" - err = idx.AddAll([]string{}, IndexAddDefault, func(p, mP string) int { + err = idx.AddAll([]string{}, IndexAddDefault, func(p, mP string) error { cbPath = p - return 0 + return nil }) checkFatal(t, err) if cbPath != "README" { diff --git a/indexer_test.go b/indexer_test.go index 70b9f76..1566f97 100644 --- a/indexer_test.go +++ b/indexer_test.go @@ -33,9 +33,9 @@ func TestIndexerOutOfOrder(t *testing.T) { defer os.RemoveAll(tmpPath) var finalStats TransferProgress - idx, err := NewIndexer(tmpPath, nil, func(stats TransferProgress) ErrorCode { + idx, err := NewIndexer(tmpPath, nil, func(stats TransferProgress) error { finalStats = stats - return ErrorCodeOK + return nil }) checkFatal(t, err) defer idx.Free() diff --git a/odb_test.go b/odb_test.go index ed5c24c..2684851 100644 --- a/odb_test.go +++ b/odb_test.go @@ -167,9 +167,9 @@ func TestOdbWritepack(t *testing.T) { checkFatal(t, err) var finalStats TransferProgress - writepack, err := odb.NewWritePack(func(stats TransferProgress) ErrorCode { + writepack, err := odb.NewWritePack(func(stats TransferProgress) error { finalStats = stats - return ErrorCodeOK + return nil }) checkFatal(t, err) defer writepack.Free() diff --git a/remote.go b/remote.go index 275d4d9..e312a3a 100644 --- a/remote.go +++ b/remote.go @@ -69,15 +69,15 @@ const ( ConnectDirectionPush ConnectDirection = C.GIT_DIRECTION_PUSH ) -type TransportMessageCallback func(str string) ErrorCode -type CompletionCallback func(RemoteCompletion) ErrorCode +type TransportMessageCallback func(str string) error +type CompletionCallback func(RemoteCompletion) error type CredentialsCallback func(url string, username_from_url string, allowed_types CredentialType) (*Credential, error) -type TransferProgressCallback func(stats TransferProgress) ErrorCode -type UpdateTipsCallback func(refname string, a *Oid, b *Oid) ErrorCode -type CertificateCheckCallback func(cert *Certificate, valid bool, hostname string) ErrorCode -type PackbuilderProgressCallback func(stage int32, current, total uint32) ErrorCode -type PushTransferProgressCallback func(current, total uint32, bytes uint) ErrorCode -type PushUpdateReferenceCallback func(refname, status string) ErrorCode +type TransferProgressCallback func(stats TransferProgress) error +type UpdateTipsCallback func(refname string, a *Oid, b *Oid) error +type CertificateCheckCallback func(cert *Certificate, valid bool, hostname string) error +type PackbuilderProgressCallback func(stage int32, current, total uint32) error +type PushTransferProgressCallback func(current, total uint32, bytes uint) error +type PushUpdateReferenceCallback func(refname, status string) error type RemoteCallbacks struct { SidebandProgressCallback TransportMessageCallback @@ -329,10 +329,8 @@ func sidebandProgressCallback(errorMessage **C.char, _str *C.char, _len C.int, h if data.callbacks.SidebandProgressCallback == nil { return C.int(ErrorCodeOK) } - str := C.GoStringN(_str, _len) - ret := data.callbacks.SidebandProgressCallback(str) - if ret < 0 { - err := errors.New(ErrorCode(ret).String()) + err := data.callbacks.SidebandProgressCallback(C.GoStringN(_str, _len)) + if err != nil { if data.errorTarget != nil { *data.errorTarget = err } @@ -342,14 +340,13 @@ func sidebandProgressCallback(errorMessage **C.char, _str *C.char, _len C.int, h } //export completionCallback -func completionCallback(errorMessage **C.char, completion_type C.git_remote_completion_type, handle unsafe.Pointer) C.int { +func completionCallback(errorMessage **C.char, completionType C.git_remote_completion_type, handle unsafe.Pointer) C.int { data := pointerHandles.Get(handle).(*remoteCallbacksData) if data.callbacks.CompletionCallback == nil { return C.int(ErrorCodeOK) } - ret := data.callbacks.CompletionCallback(RemoteCompletion(completion_type)) - if ret < 0 { - err := errors.New(ErrorCode(ret).String()) + err := data.callbacks.CompletionCallback(RemoteCompletion(completionType)) + if err != nil { if data.errorTarget != nil { *data.errorTarget = err } @@ -396,9 +393,8 @@ func transferProgressCallback(errorMessage **C.char, stats *C.git_transfer_progr if data.callbacks.TransferProgressCallback == nil { return C.int(ErrorCodeOK) } - ret := data.callbacks.TransferProgressCallback(newTransferProgressFromC(stats)) - if ret < 0 { - err := errors.New(ErrorCode(ret).String()) + err := data.callbacks.TransferProgressCallback(newTransferProgressFromC(stats)) + if err != nil { if data.errorTarget != nil { *data.errorTarget = err } @@ -422,9 +418,8 @@ func updateTipsCallback( refname := C.GoString(_refname) a := newOidFromC(_a) b := newOidFromC(_b) - ret := data.callbacks.UpdateTipsCallback(refname, a, b) - if ret < 0 { - err := errors.New(ErrorCode(ret).String()) + err := data.callbacks.UpdateTipsCallback(refname, a, b) + if err != nil { if data.errorTarget != nil { *data.errorTarget = err } @@ -489,9 +484,8 @@ func certificateCheckCallback( return setCallbackError(errorMessage, err) } - ret := data.callbacks.CertificateCheckCallback(&cert, valid, host) - if ret < 0 { - err := errors.New(ErrorCode(ret).String()) + err := data.callbacks.CertificateCheckCallback(&cert, valid, host) + if err != nil { if data.errorTarget != nil { *data.errorTarget = err } @@ -507,9 +501,8 @@ func packProgressCallback(errorMessage **C.char, stage C.int, current, total C.u return C.int(ErrorCodeOK) } - ret := data.callbacks.PackProgressCallback(int32(stage), uint32(current), uint32(total)) - if ret < 0 { - err := errors.New(ErrorCode(ret).String()) + err := data.callbacks.PackProgressCallback(int32(stage), uint32(current), uint32(total)) + if err != nil { if data.errorTarget != nil { *data.errorTarget = err } @@ -525,9 +518,8 @@ func pushTransferProgressCallback(errorMessage **C.char, current, total C.uint, return C.int(ErrorCodeOK) } - ret := data.callbacks.PushTransferProgressCallback(uint32(current), uint32(total), uint(bytes)) - if ret < 0 { - err := errors.New(ErrorCode(ret).String()) + err := data.callbacks.PushTransferProgressCallback(uint32(current), uint32(total), uint(bytes)) + if err != nil { if data.errorTarget != nil { *data.errorTarget = err } @@ -543,9 +535,8 @@ func pushUpdateReferenceCallback(errorMessage **C.char, refname, status *C.char, return C.int(ErrorCodeOK) } - ret := data.callbacks.PushUpdateReferenceCallback(C.GoString(refname), C.GoString(status)) - if ret < 0 { - err := errors.New(ErrorCode(ret).String()) + err := data.callbacks.PushUpdateReferenceCallback(C.GoString(refname), C.GoString(status)) + if err != nil { if data.errorTarget != nil { *data.errorTarget = err } diff --git a/remote_test.go b/remote_test.go index 9660a3f..05395b3 100644 --- a/remote_test.go +++ b/remote_test.go @@ -38,13 +38,13 @@ func TestListRemotes(t *testing.T) { compareStringList(t, expected, actual) } -func assertHostname(cert *Certificate, valid bool, hostname string, t *testing.T) ErrorCode { +func assertHostname(cert *Certificate, valid bool, hostname string, t *testing.T) error { if hostname != "github.com" { - t.Fatal("Hostname does not match") - return ErrorCodeUser + t.Fatal("hostname does not match") + return errors.New("hostname does not match") } - return ErrorCodeOK + return nil } func TestCertificateCheck(t *testing.T) { @@ -58,7 +58,7 @@ func TestCertificateCheck(t *testing.T) { options := FetchOptions{ RemoteCallbacks: RemoteCallbacks{ - CertificateCheckCallback: func(cert *Certificate, valid bool, hostname string) ErrorCode { + CertificateCheckCallback: func(cert *Certificate, valid bool, hostname string) error { return assertHostname(cert, valid, hostname, t) }, }, @@ -479,14 +479,13 @@ func TestRemoteSSH(t *testing.T) { certificateCheckCallbackCalled := false fetchOpts := FetchOptions{ RemoteCallbacks: RemoteCallbacks{ - CertificateCheckCallback: func(cert *Certificate, valid bool, hostname string) ErrorCode { + CertificateCheckCallback: func(cert *Certificate, valid bool, hostname string) error { hostkeyFingerprint := fmt.Sprintf("%x", cert.Hostkey.HashMD5[:]) if hostkeyFingerprint != publicKeyFingerprint { - t.Logf("server hostkey %q, want %q", hostkeyFingerprint, publicKeyFingerprint) - return ErrorCodeAuth + return fmt.Errorf("server hostkey %q, want %q", hostkeyFingerprint, publicKeyFingerprint) } certificateCheckCallbackCalled = true - return ErrorCodeOK + return nil }, CredentialsCallback: func(url, username string, allowedTypes CredentialType) (*Credential, error) { if allowedTypes&(CredentialTypeSSHKey|CredentialTypeSSHCustom|CredentialTypeSSHMemory) != 0 { diff --git a/submodule.go b/submodule.go index 673cf5f..0fdaa12 100644 --- a/submodule.go +++ b/submodule.go @@ -8,7 +8,6 @@ extern int _go_git_visit_submodule(git_repository *repo, void *fct); import "C" import ( - "errors" "runtime" "unsafe" ) @@ -111,7 +110,7 @@ func (c *SubmoduleCollection) Lookup(name string) (*Submodule, error) { } // SubmoduleCallback is a function that is called for every submodule found in SubmoduleCollection.Foreach. -type SubmoduleCallback func(sub *Submodule, name string) int +type SubmoduleCallback func(sub *Submodule, name string) error type submoduleCallbackData struct { callback SubmoduleCallback errorTarget *error @@ -126,9 +125,9 @@ func submoduleCallback(csub unsafe.Pointer, name *C.char, handle unsafe.Pointer) panic("invalid submodule visitor callback") } - ret := data.callback(sub, C.GoString(name)) - if ret < 0 { - *data.errorTarget = errors.New(ErrorCode(ret).String()) + err := data.callback(sub, C.GoString(name)) + if err != nil { + *data.errorTarget = err return C.int(ErrorCodeUser) } diff --git a/submodule_test.go b/submodule_test.go index fa2e98c..09ddae5 100644 --- a/submodule_test.go +++ b/submodule_test.go @@ -15,9 +15,9 @@ func TestSubmoduleForeach(t *testing.T) { checkFatal(t, err) i := 0 - err = repo.Submodules.Foreach(func(sub *Submodule, name string) int { + err = repo.Submodules.Foreach(func(sub *Submodule, name string) error { i++ - return 0 + return nil }) checkFatal(t, err) diff --git a/tree.go b/tree.go index 14fe7e4..b1aeaa7 100644 --- a/tree.go +++ b/tree.go @@ -8,7 +8,6 @@ extern int _go_git_treewalk(git_tree *tree, git_treewalk_mode mode, void *ptr); import "C" import ( - "errors" "runtime" "unsafe" ) @@ -121,7 +120,7 @@ func (t *Tree) EntryCount() uint64 { return uint64(num) } -type TreeWalkCallback func(string, *TreeEntry) int +type TreeWalkCallback func(string, *TreeEntry) error type treeWalkCallbackData struct { callback TreeWalkCallback errorTarget *error @@ -134,9 +133,9 @@ func treeWalkCallback(_root *C.char, entry *C.git_tree_entry, ptr unsafe.Pointer panic("invalid treewalk callback") } - ret := data.callback(C.GoString(_root), newTreeEntry(entry)) - if ret < 0 { - *data.errorTarget = errors.New(ErrorCode(ret).String()) + err := data.callback(C.GoString(_root), newTreeEntry(entry)) + if err != nil { + *data.errorTarget = err return C.int(ErrorCodeUser) }