The big Callback type adjustment of 2020

This change makes all callbacks that can fail return an `error`. This
makes things a lot more idiomatic.
This commit is contained in:
lhchavez 2020-12-01 19:11:41 -08:00
parent 70e5e419cf
commit 5def02a589
13 changed files with 77 additions and 89 deletions

View File

@ -7,7 +7,6 @@ extern void _go_git_populate_checkout_callbacks(git_checkout_options *opts);
*/ */
import "C" import "C"
import ( import (
"errors"
"os" "os"
"runtime" "runtime"
"unsafe" "unsafe"
@ -49,8 +48,8 @@ const (
CheckoutUpdateSubmodulesIfChanged CheckoutStrategy = C.GIT_CHECKOUT_UPDATE_SUBMODULES_IF_CHANGED // Recursively checkout submodules if HEAD moved in super repo (NOT IMPLEMENTED) CheckoutUpdateSubmodulesIfChanged CheckoutStrategy = C.GIT_CHECKOUT_UPDATE_SUBMODULES_IF_CHANGED // Recursively checkout submodules if HEAD moved in super repo (NOT IMPLEMENTED)
) )
type CheckoutNotifyCallback func(why CheckoutNotifyType, path string, baseline, target, workdir DiffFile) ErrorCode type CheckoutNotifyCallback func(why CheckoutNotifyType, path string, baseline, target, workdir DiffFile) error
type CheckoutProgressCallback func(path string, completed, total uint) ErrorCode type CheckoutProgressCallback func(path string, completed, total uint)
type CheckoutOptions struct { type CheckoutOptions struct {
Strategy CheckoutStrategy // Default will be a dry run Strategy CheckoutStrategy // Default will be a dry run
@ -116,9 +115,9 @@ func checkoutNotifyCallback(
if data.options.NotifyCallback == nil { if data.options.NotifyCallback == nil {
return C.int(ErrorCodeOK) return C.int(ErrorCodeOK)
} }
ret := data.options.NotifyCallback(CheckoutNotifyType(why), path, baseline, target, workdir) err := data.options.NotifyCallback(CheckoutNotifyType(why), path, baseline, target, workdir)
if ret < 0 { if err != nil {
*data.errorTarget = errors.New(ErrorCode(ret).String()) *data.errorTarget = err
return C.int(ErrorCodeUser) return C.int(ErrorCodeUser)
} }
return C.int(ErrorCodeOK) return C.int(ErrorCodeOK)

View File

@ -7,12 +7,11 @@ extern void _go_git_populate_clone_callbacks(git_clone_options *opts);
*/ */
import "C" import "C"
import ( import (
"errors"
"runtime" "runtime"
"unsafe" "unsafe"
) )
type RemoteCreateCallback func(repo *Repository, name, url string) (*Remote, ErrorCode) type RemoteCreateCallback func(repo *Repository, name, url string) (*Remote, error)
type CloneOptions struct { type CloneOptions struct {
*CheckoutOpts *CheckoutOpts
@ -71,9 +70,10 @@ func remoteCreateCallback(
panic("invalid remote create callback") panic("invalid remote create callback")
} }
remote, ret := data.options.RemoteCreateCallback(repo, name, url) remote, err := data.options.RemoteCreateCallback(repo, name, url)
if ret < 0 {
*data.errorTarget = errors.New(ErrorCode(ret).String()) if err != nil {
*data.errorTarget = err
return C.int(ErrorCodeUser) return C.int(ErrorCodeUser)
} }
if remote == nil { if remote == nil {

View File

@ -49,15 +49,9 @@ func TestCloneWithCallback(t *testing.T) {
opts := CloneOptions{ opts := CloneOptions{
Bare: true, Bare: true,
RemoteCreateCallback: func(r *Repository, name, url string) (*Remote, ErrorCode) { RemoteCreateCallback: func(r *Repository, name, url string) (*Remote, error) {
testPayload += 1 testPayload += 1
return r.Remotes.Create(REMOTENAME, url)
remote, err := r.Remotes.Create(REMOTENAME, url)
if err != nil {
return nil, ErrorCodeGeneric
}
return remote, ErrorCodeOK
}, },
} }

View File

@ -230,7 +230,11 @@ func SubmoduleVisitor(csub unsafe.Pointer, name *C.char, handle unsafe.Pointer)
if !ok { if !ok {
panic("invalid submodule visitor callback") panic("invalid submodule visitor callback")
} }
return (C.int)(callback(sub, C.GoString(name))) err := callback(sub, C.GoString(name))
if err != nil {
return C.int(ErrorCodeUser)
}
return C.int(ErrorCodeOK)
} }
// tree.go // tree.go
@ -239,9 +243,13 @@ func SubmoduleVisitor(csub unsafe.Pointer, name *C.char, handle unsafe.Pointer)
func CallbackGitTreeWalk(_root *C.char, entry *C.git_tree_entry, ptr unsafe.Pointer) C.int { func CallbackGitTreeWalk(_root *C.char, entry *C.git_tree_entry, ptr unsafe.Pointer) C.int {
root := C.GoString(_root) root := C.GoString(_root)
if callback, ok := pointerHandles.Get(ptr).(TreeWalkCallback); ok { callback, ok := pointerHandles.Get(ptr).(TreeWalkCallback)
return C.int(callback(root, newTreeEntry(entry))) if !ok {
} else {
panic("invalid treewalk callback") panic("invalid treewalk callback")
} }
err := callback(root, newTreeEntry(entry))
if err != nil {
return C.int(ErrorCodeUser)
}
return C.int(ErrorCodeOK)
} }

View File

@ -10,13 +10,12 @@ extern int _go_git_index_remove_all(git_index*, const git_strarray*, void*);
*/ */
import "C" import "C"
import ( import (
"errors"
"fmt" "fmt"
"runtime" "runtime"
"unsafe" "unsafe"
) )
type IndexMatchedPathCallback func(string, string) int type IndexMatchedPathCallback func(string, string) error
type indexMatchedPathCallbackData struct { type indexMatchedPathCallbackData struct {
callback IndexMatchedPathCallback callback IndexMatchedPathCallback
errorTarget *error errorTarget *error
@ -343,9 +342,9 @@ func indexMatchedPathCallback(cPath, cMatchedPathspec *C.char, payload unsafe.Po
panic("invalid matched path callback") panic("invalid matched path callback")
} }
ret := data.callback(C.GoString(cPath), C.GoString(cMatchedPathspec)) err := data.callback(C.GoString(cPath), C.GoString(cMatchedPathspec))
if ret < 0 { if err != nil {
*data.errorTarget = errors.New(ErrorCode(ret).String()) *data.errorTarget = err
return C.int(ErrorCodeUser) return C.int(ErrorCodeUser)
} }

View File

@ -223,9 +223,9 @@ func TestIndexAddAllCallback(t *testing.T) {
checkFatal(t, err) checkFatal(t, err)
cbPath := "" cbPath := ""
err = idx.AddAll([]string{}, IndexAddDefault, func(p, mP string) int { err = idx.AddAll([]string{}, IndexAddDefault, func(p, mP string) error {
cbPath = p cbPath = p
return 0 return nil
}) })
checkFatal(t, err) checkFatal(t, err)
if cbPath != "README" { if cbPath != "README" {

View File

@ -33,9 +33,9 @@ func TestIndexerOutOfOrder(t *testing.T) {
defer os.RemoveAll(tmpPath) defer os.RemoveAll(tmpPath)
var finalStats TransferProgress var finalStats TransferProgress
idx, err := NewIndexer(tmpPath, nil, func(stats TransferProgress) ErrorCode { idx, err := NewIndexer(tmpPath, nil, func(stats TransferProgress) error {
finalStats = stats finalStats = stats
return ErrorCodeOK return nil
}) })
checkFatal(t, err) checkFatal(t, err)
defer idx.Free() defer idx.Free()

View File

@ -167,9 +167,9 @@ func TestOdbWritepack(t *testing.T) {
checkFatal(t, err) checkFatal(t, err)
var finalStats TransferProgress var finalStats TransferProgress
writepack, err := odb.NewWritePack(func(stats TransferProgress) ErrorCode { writepack, err := odb.NewWritePack(func(stats TransferProgress) error {
finalStats = stats finalStats = stats
return ErrorCodeOK return nil
}) })
checkFatal(t, err) checkFatal(t, err)
defer writepack.Free() defer writepack.Free()

View File

@ -69,15 +69,15 @@ const (
ConnectDirectionPush ConnectDirection = C.GIT_DIRECTION_PUSH ConnectDirectionPush ConnectDirection = C.GIT_DIRECTION_PUSH
) )
type TransportMessageCallback func(str string) ErrorCode type TransportMessageCallback func(str string) error
type CompletionCallback func(RemoteCompletion) ErrorCode type CompletionCallback func(RemoteCompletion) error
type CredentialsCallback func(url string, username_from_url string, allowed_types CredentialType) (*Credential, error) type CredentialsCallback func(url string, username_from_url string, allowed_types CredentialType) (*Credential, error)
type TransferProgressCallback func(stats TransferProgress) ErrorCode type TransferProgressCallback func(stats TransferProgress) error
type UpdateTipsCallback func(refname string, a *Oid, b *Oid) ErrorCode type UpdateTipsCallback func(refname string, a *Oid, b *Oid) error
type CertificateCheckCallback func(cert *Certificate, valid bool, hostname string) ErrorCode type CertificateCheckCallback func(cert *Certificate, valid bool, hostname string) error
type PackbuilderProgressCallback func(stage int32, current, total uint32) ErrorCode type PackbuilderProgressCallback func(stage int32, current, total uint32) error
type PushTransferProgressCallback func(current, total uint32, bytes uint) ErrorCode type PushTransferProgressCallback func(current, total uint32, bytes uint) error
type PushUpdateReferenceCallback func(refname, status string) ErrorCode type PushUpdateReferenceCallback func(refname, status string) error
type RemoteCallbacks struct { type RemoteCallbacks struct {
SidebandProgressCallback TransportMessageCallback SidebandProgressCallback TransportMessageCallback
@ -329,10 +329,8 @@ func sidebandProgressCallback(errorMessage **C.char, _str *C.char, _len C.int, h
if data.callbacks.SidebandProgressCallback == nil { if data.callbacks.SidebandProgressCallback == nil {
return C.int(ErrorCodeOK) return C.int(ErrorCodeOK)
} }
str := C.GoStringN(_str, _len) err := data.callbacks.SidebandProgressCallback(C.GoStringN(_str, _len))
ret := data.callbacks.SidebandProgressCallback(str) if err != nil {
if ret < 0 {
err := errors.New(ErrorCode(ret).String())
if data.errorTarget != nil { if data.errorTarget != nil {
*data.errorTarget = err *data.errorTarget = err
} }
@ -342,14 +340,13 @@ func sidebandProgressCallback(errorMessage **C.char, _str *C.char, _len C.int, h
} }
//export completionCallback //export completionCallback
func completionCallback(errorMessage **C.char, completion_type C.git_remote_completion_type, handle unsafe.Pointer) C.int { func completionCallback(errorMessage **C.char, completionType C.git_remote_completion_type, handle unsafe.Pointer) C.int {
data := pointerHandles.Get(handle).(*remoteCallbacksData) data := pointerHandles.Get(handle).(*remoteCallbacksData)
if data.callbacks.CompletionCallback == nil { if data.callbacks.CompletionCallback == nil {
return C.int(ErrorCodeOK) return C.int(ErrorCodeOK)
} }
ret := data.callbacks.CompletionCallback(RemoteCompletion(completion_type)) err := data.callbacks.CompletionCallback(RemoteCompletion(completionType))
if ret < 0 { if err != nil {
err := errors.New(ErrorCode(ret).String())
if data.errorTarget != nil { if data.errorTarget != nil {
*data.errorTarget = err *data.errorTarget = err
} }
@ -396,9 +393,8 @@ func transferProgressCallback(errorMessage **C.char, stats *C.git_transfer_progr
if data.callbacks.TransferProgressCallback == nil { if data.callbacks.TransferProgressCallback == nil {
return C.int(ErrorCodeOK) return C.int(ErrorCodeOK)
} }
ret := data.callbacks.TransferProgressCallback(newTransferProgressFromC(stats)) err := data.callbacks.TransferProgressCallback(newTransferProgressFromC(stats))
if ret < 0 { if err != nil {
err := errors.New(ErrorCode(ret).String())
if data.errorTarget != nil { if data.errorTarget != nil {
*data.errorTarget = err *data.errorTarget = err
} }
@ -422,9 +418,8 @@ func updateTipsCallback(
refname := C.GoString(_refname) refname := C.GoString(_refname)
a := newOidFromC(_a) a := newOidFromC(_a)
b := newOidFromC(_b) b := newOidFromC(_b)
ret := data.callbacks.UpdateTipsCallback(refname, a, b) err := data.callbacks.UpdateTipsCallback(refname, a, b)
if ret < 0 { if err != nil {
err := errors.New(ErrorCode(ret).String())
if data.errorTarget != nil { if data.errorTarget != nil {
*data.errorTarget = err *data.errorTarget = err
} }
@ -489,9 +484,8 @@ func certificateCheckCallback(
return setCallbackError(errorMessage, err) return setCallbackError(errorMessage, err)
} }
ret := data.callbacks.CertificateCheckCallback(&cert, valid, host) err := data.callbacks.CertificateCheckCallback(&cert, valid, host)
if ret < 0 { if err != nil {
err := errors.New(ErrorCode(ret).String())
if data.errorTarget != nil { if data.errorTarget != nil {
*data.errorTarget = err *data.errorTarget = err
} }
@ -507,9 +501,8 @@ func packProgressCallback(errorMessage **C.char, stage C.int, current, total C.u
return C.int(ErrorCodeOK) return C.int(ErrorCodeOK)
} }
ret := data.callbacks.PackProgressCallback(int32(stage), uint32(current), uint32(total)) err := data.callbacks.PackProgressCallback(int32(stage), uint32(current), uint32(total))
if ret < 0 { if err != nil {
err := errors.New(ErrorCode(ret).String())
if data.errorTarget != nil { if data.errorTarget != nil {
*data.errorTarget = err *data.errorTarget = err
} }
@ -525,9 +518,8 @@ func pushTransferProgressCallback(errorMessage **C.char, current, total C.uint,
return C.int(ErrorCodeOK) return C.int(ErrorCodeOK)
} }
ret := data.callbacks.PushTransferProgressCallback(uint32(current), uint32(total), uint(bytes)) err := data.callbacks.PushTransferProgressCallback(uint32(current), uint32(total), uint(bytes))
if ret < 0 { if err != nil {
err := errors.New(ErrorCode(ret).String())
if data.errorTarget != nil { if data.errorTarget != nil {
*data.errorTarget = err *data.errorTarget = err
} }
@ -543,9 +535,8 @@ func pushUpdateReferenceCallback(errorMessage **C.char, refname, status *C.char,
return C.int(ErrorCodeOK) return C.int(ErrorCodeOK)
} }
ret := data.callbacks.PushUpdateReferenceCallback(C.GoString(refname), C.GoString(status)) err := data.callbacks.PushUpdateReferenceCallback(C.GoString(refname), C.GoString(status))
if ret < 0 { if err != nil {
err := errors.New(ErrorCode(ret).String())
if data.errorTarget != nil { if data.errorTarget != nil {
*data.errorTarget = err *data.errorTarget = err
} }

View File

@ -38,13 +38,13 @@ func TestListRemotes(t *testing.T) {
compareStringList(t, expected, actual) compareStringList(t, expected, actual)
} }
func assertHostname(cert *Certificate, valid bool, hostname string, t *testing.T) ErrorCode { func assertHostname(cert *Certificate, valid bool, hostname string, t *testing.T) error {
if hostname != "github.com" { if hostname != "github.com" {
t.Fatal("Hostname does not match") t.Fatal("hostname does not match")
return ErrorCodeUser return errors.New("hostname does not match")
} }
return ErrorCodeOK return nil
} }
func TestCertificateCheck(t *testing.T) { func TestCertificateCheck(t *testing.T) {
@ -58,7 +58,7 @@ func TestCertificateCheck(t *testing.T) {
options := FetchOptions{ options := FetchOptions{
RemoteCallbacks: RemoteCallbacks{ RemoteCallbacks: RemoteCallbacks{
CertificateCheckCallback: func(cert *Certificate, valid bool, hostname string) ErrorCode { CertificateCheckCallback: func(cert *Certificate, valid bool, hostname string) error {
return assertHostname(cert, valid, hostname, t) return assertHostname(cert, valid, hostname, t)
}, },
}, },
@ -479,14 +479,13 @@ func TestRemoteSSH(t *testing.T) {
certificateCheckCallbackCalled := false certificateCheckCallbackCalled := false
fetchOpts := FetchOptions{ fetchOpts := FetchOptions{
RemoteCallbacks: RemoteCallbacks{ RemoteCallbacks: RemoteCallbacks{
CertificateCheckCallback: func(cert *Certificate, valid bool, hostname string) ErrorCode { CertificateCheckCallback: func(cert *Certificate, valid bool, hostname string) error {
hostkeyFingerprint := fmt.Sprintf("%x", cert.Hostkey.HashMD5[:]) hostkeyFingerprint := fmt.Sprintf("%x", cert.Hostkey.HashMD5[:])
if hostkeyFingerprint != publicKeyFingerprint { if hostkeyFingerprint != publicKeyFingerprint {
t.Logf("server hostkey %q, want %q", hostkeyFingerprint, publicKeyFingerprint) return fmt.Errorf("server hostkey %q, want %q", hostkeyFingerprint, publicKeyFingerprint)
return ErrorCodeAuth
} }
certificateCheckCallbackCalled = true certificateCheckCallbackCalled = true
return ErrorCodeOK return nil
}, },
CredentialsCallback: func(url, username string, allowedTypes CredentialType) (*Credential, error) { CredentialsCallback: func(url, username string, allowedTypes CredentialType) (*Credential, error) {
if allowedTypes&(CredentialTypeSSHKey|CredentialTypeSSHCustom|CredentialTypeSSHMemory) != 0 { if allowedTypes&(CredentialTypeSSHKey|CredentialTypeSSHCustom|CredentialTypeSSHMemory) != 0 {

View File

@ -8,7 +8,6 @@ extern int _go_git_visit_submodule(git_repository *repo, void *fct);
import "C" import "C"
import ( import (
"errors"
"runtime" "runtime"
"unsafe" "unsafe"
) )
@ -111,7 +110,7 @@ func (c *SubmoduleCollection) Lookup(name string) (*Submodule, error) {
} }
// SubmoduleCallback is a function that is called for every submodule found in SubmoduleCollection.Foreach. // SubmoduleCallback is a function that is called for every submodule found in SubmoduleCollection.Foreach.
type SubmoduleCallback func(sub *Submodule, name string) int type SubmoduleCallback func(sub *Submodule, name string) error
type submoduleCallbackData struct { type submoduleCallbackData struct {
callback SubmoduleCallback callback SubmoduleCallback
errorTarget *error errorTarget *error
@ -126,9 +125,9 @@ func submoduleCallback(csub unsafe.Pointer, name *C.char, handle unsafe.Pointer)
panic("invalid submodule visitor callback") panic("invalid submodule visitor callback")
} }
ret := data.callback(sub, C.GoString(name)) err := data.callback(sub, C.GoString(name))
if ret < 0 { if err != nil {
*data.errorTarget = errors.New(ErrorCode(ret).String()) *data.errorTarget = err
return C.int(ErrorCodeUser) return C.int(ErrorCodeUser)
} }

View File

@ -15,9 +15,9 @@ func TestSubmoduleForeach(t *testing.T) {
checkFatal(t, err) checkFatal(t, err)
i := 0 i := 0
err = repo.Submodules.Foreach(func(sub *Submodule, name string) int { err = repo.Submodules.Foreach(func(sub *Submodule, name string) error {
i++ i++
return 0 return nil
}) })
checkFatal(t, err) checkFatal(t, err)

View File

@ -8,7 +8,6 @@ extern int _go_git_treewalk(git_tree *tree, git_treewalk_mode mode, void *ptr);
import "C" import "C"
import ( import (
"errors"
"runtime" "runtime"
"unsafe" "unsafe"
) )
@ -121,7 +120,7 @@ func (t *Tree) EntryCount() uint64 {
return uint64(num) return uint64(num)
} }
type TreeWalkCallback func(string, *TreeEntry) int type TreeWalkCallback func(string, *TreeEntry) error
type treeWalkCallbackData struct { type treeWalkCallbackData struct {
callback TreeWalkCallback callback TreeWalkCallback
errorTarget *error errorTarget *error
@ -134,9 +133,9 @@ func treeWalkCallback(_root *C.char, entry *C.git_tree_entry, ptr unsafe.Pointer
panic("invalid treewalk callback") panic("invalid treewalk callback")
} }
ret := data.callback(C.GoString(_root), newTreeEntry(entry)) err := data.callback(C.GoString(_root), newTreeEntry(entry))
if ret < 0 { if err != nil {
*data.errorTarget = errors.New(ErrorCode(ret).String()) *data.errorTarget = err
return C.int(ErrorCodeUser) return C.int(ErrorCodeUser)
} }