The big Callback type adjustment of 2020

This change makes all callbacks that can fail return an `error`. This
makes things a lot more idiomatic.
This commit is contained in:
lhchavez 2020-12-01 19:11:41 -08:00
parent 70e5e419cf
commit 5def02a589
13 changed files with 77 additions and 89 deletions

View File

@ -7,7 +7,6 @@ extern void _go_git_populate_checkout_callbacks(git_checkout_options *opts);
*/
import "C"
import (
"errors"
"os"
"runtime"
"unsafe"
@ -49,8 +48,8 @@ const (
CheckoutUpdateSubmodulesIfChanged CheckoutStrategy = C.GIT_CHECKOUT_UPDATE_SUBMODULES_IF_CHANGED // Recursively checkout submodules if HEAD moved in super repo (NOT IMPLEMENTED)
)
type CheckoutNotifyCallback func(why CheckoutNotifyType, path string, baseline, target, workdir DiffFile) ErrorCode
type CheckoutProgressCallback func(path string, completed, total uint) ErrorCode
type CheckoutNotifyCallback func(why CheckoutNotifyType, path string, baseline, target, workdir DiffFile) error
type CheckoutProgressCallback func(path string, completed, total uint)
type CheckoutOptions struct {
Strategy CheckoutStrategy // Default will be a dry run
@ -116,9 +115,9 @@ func checkoutNotifyCallback(
if data.options.NotifyCallback == nil {
return C.int(ErrorCodeOK)
}
ret := data.options.NotifyCallback(CheckoutNotifyType(why), path, baseline, target, workdir)
if ret < 0 {
*data.errorTarget = errors.New(ErrorCode(ret).String())
err := data.options.NotifyCallback(CheckoutNotifyType(why), path, baseline, target, workdir)
if err != nil {
*data.errorTarget = err
return C.int(ErrorCodeUser)
}
return C.int(ErrorCodeOK)

View File

@ -7,12 +7,11 @@ extern void _go_git_populate_clone_callbacks(git_clone_options *opts);
*/
import "C"
import (
"errors"
"runtime"
"unsafe"
)
type RemoteCreateCallback func(repo *Repository, name, url string) (*Remote, ErrorCode)
type RemoteCreateCallback func(repo *Repository, name, url string) (*Remote, error)
type CloneOptions struct {
*CheckoutOpts
@ -71,9 +70,10 @@ func remoteCreateCallback(
panic("invalid remote create callback")
}
remote, ret := data.options.RemoteCreateCallback(repo, name, url)
if ret < 0 {
*data.errorTarget = errors.New(ErrorCode(ret).String())
remote, err := data.options.RemoteCreateCallback(repo, name, url)
if err != nil {
*data.errorTarget = err
return C.int(ErrorCodeUser)
}
if remote == nil {

View File

@ -49,15 +49,9 @@ func TestCloneWithCallback(t *testing.T) {
opts := CloneOptions{
Bare: true,
RemoteCreateCallback: func(r *Repository, name, url string) (*Remote, ErrorCode) {
RemoteCreateCallback: func(r *Repository, name, url string) (*Remote, error) {
testPayload += 1
remote, err := r.Remotes.Create(REMOTENAME, url)
if err != nil {
return nil, ErrorCodeGeneric
}
return remote, ErrorCodeOK
return r.Remotes.Create(REMOTENAME, url)
},
}

View File

@ -230,7 +230,11 @@ func SubmoduleVisitor(csub unsafe.Pointer, name *C.char, handle unsafe.Pointer)
if !ok {
panic("invalid submodule visitor callback")
}
return (C.int)(callback(sub, C.GoString(name)))
err := callback(sub, C.GoString(name))
if err != nil {
return C.int(ErrorCodeUser)
}
return C.int(ErrorCodeOK)
}
// tree.go
@ -239,9 +243,13 @@ func SubmoduleVisitor(csub unsafe.Pointer, name *C.char, handle unsafe.Pointer)
func CallbackGitTreeWalk(_root *C.char, entry *C.git_tree_entry, ptr unsafe.Pointer) C.int {
root := C.GoString(_root)
if callback, ok := pointerHandles.Get(ptr).(TreeWalkCallback); ok {
return C.int(callback(root, newTreeEntry(entry)))
} else {
callback, ok := pointerHandles.Get(ptr).(TreeWalkCallback)
if !ok {
panic("invalid treewalk callback")
}
err := callback(root, newTreeEntry(entry))
if err != nil {
return C.int(ErrorCodeUser)
}
return C.int(ErrorCodeOK)
}

View File

@ -10,13 +10,12 @@ extern int _go_git_index_remove_all(git_index*, const git_strarray*, void*);
*/
import "C"
import (
"errors"
"fmt"
"runtime"
"unsafe"
)
type IndexMatchedPathCallback func(string, string) int
type IndexMatchedPathCallback func(string, string) error
type indexMatchedPathCallbackData struct {
callback IndexMatchedPathCallback
errorTarget *error
@ -343,9 +342,9 @@ func indexMatchedPathCallback(cPath, cMatchedPathspec *C.char, payload unsafe.Po
panic("invalid matched path callback")
}
ret := data.callback(C.GoString(cPath), C.GoString(cMatchedPathspec))
if ret < 0 {
*data.errorTarget = errors.New(ErrorCode(ret).String())
err := data.callback(C.GoString(cPath), C.GoString(cMatchedPathspec))
if err != nil {
*data.errorTarget = err
return C.int(ErrorCodeUser)
}

View File

@ -223,9 +223,9 @@ func TestIndexAddAllCallback(t *testing.T) {
checkFatal(t, err)
cbPath := ""
err = idx.AddAll([]string{}, IndexAddDefault, func(p, mP string) int {
err = idx.AddAll([]string{}, IndexAddDefault, func(p, mP string) error {
cbPath = p
return 0
return nil
})
checkFatal(t, err)
if cbPath != "README" {

View File

@ -33,9 +33,9 @@ func TestIndexerOutOfOrder(t *testing.T) {
defer os.RemoveAll(tmpPath)
var finalStats TransferProgress
idx, err := NewIndexer(tmpPath, nil, func(stats TransferProgress) ErrorCode {
idx, err := NewIndexer(tmpPath, nil, func(stats TransferProgress) error {
finalStats = stats
return ErrorCodeOK
return nil
})
checkFatal(t, err)
defer idx.Free()

View File

@ -167,9 +167,9 @@ func TestOdbWritepack(t *testing.T) {
checkFatal(t, err)
var finalStats TransferProgress
writepack, err := odb.NewWritePack(func(stats TransferProgress) ErrorCode {
writepack, err := odb.NewWritePack(func(stats TransferProgress) error {
finalStats = stats
return ErrorCodeOK
return nil
})
checkFatal(t, err)
defer writepack.Free()

View File

@ -69,15 +69,15 @@ const (
ConnectDirectionPush ConnectDirection = C.GIT_DIRECTION_PUSH
)
type TransportMessageCallback func(str string) ErrorCode
type CompletionCallback func(RemoteCompletion) ErrorCode
type TransportMessageCallback func(str string) error
type CompletionCallback func(RemoteCompletion) error
type CredentialsCallback func(url string, username_from_url string, allowed_types CredentialType) (*Credential, error)
type TransferProgressCallback func(stats TransferProgress) ErrorCode
type UpdateTipsCallback func(refname string, a *Oid, b *Oid) ErrorCode
type CertificateCheckCallback func(cert *Certificate, valid bool, hostname string) ErrorCode
type PackbuilderProgressCallback func(stage int32, current, total uint32) ErrorCode
type PushTransferProgressCallback func(current, total uint32, bytes uint) ErrorCode
type PushUpdateReferenceCallback func(refname, status string) ErrorCode
type TransferProgressCallback func(stats TransferProgress) error
type UpdateTipsCallback func(refname string, a *Oid, b *Oid) error
type CertificateCheckCallback func(cert *Certificate, valid bool, hostname string) error
type PackbuilderProgressCallback func(stage int32, current, total uint32) error
type PushTransferProgressCallback func(current, total uint32, bytes uint) error
type PushUpdateReferenceCallback func(refname, status string) error
type RemoteCallbacks struct {
SidebandProgressCallback TransportMessageCallback
@ -329,10 +329,8 @@ func sidebandProgressCallback(errorMessage **C.char, _str *C.char, _len C.int, h
if data.callbacks.SidebandProgressCallback == nil {
return C.int(ErrorCodeOK)
}
str := C.GoStringN(_str, _len)
ret := data.callbacks.SidebandProgressCallback(str)
if ret < 0 {
err := errors.New(ErrorCode(ret).String())
err := data.callbacks.SidebandProgressCallback(C.GoStringN(_str, _len))
if err != nil {
if data.errorTarget != nil {
*data.errorTarget = err
}
@ -342,14 +340,13 @@ func sidebandProgressCallback(errorMessage **C.char, _str *C.char, _len C.int, h
}
//export completionCallback
func completionCallback(errorMessage **C.char, completion_type C.git_remote_completion_type, handle unsafe.Pointer) C.int {
func completionCallback(errorMessage **C.char, completionType C.git_remote_completion_type, handle unsafe.Pointer) C.int {
data := pointerHandles.Get(handle).(*remoteCallbacksData)
if data.callbacks.CompletionCallback == nil {
return C.int(ErrorCodeOK)
}
ret := data.callbacks.CompletionCallback(RemoteCompletion(completion_type))
if ret < 0 {
err := errors.New(ErrorCode(ret).String())
err := data.callbacks.CompletionCallback(RemoteCompletion(completionType))
if err != nil {
if data.errorTarget != nil {
*data.errorTarget = err
}
@ -396,9 +393,8 @@ func transferProgressCallback(errorMessage **C.char, stats *C.git_transfer_progr
if data.callbacks.TransferProgressCallback == nil {
return C.int(ErrorCodeOK)
}
ret := data.callbacks.TransferProgressCallback(newTransferProgressFromC(stats))
if ret < 0 {
err := errors.New(ErrorCode(ret).String())
err := data.callbacks.TransferProgressCallback(newTransferProgressFromC(stats))
if err != nil {
if data.errorTarget != nil {
*data.errorTarget = err
}
@ -422,9 +418,8 @@ func updateTipsCallback(
refname := C.GoString(_refname)
a := newOidFromC(_a)
b := newOidFromC(_b)
ret := data.callbacks.UpdateTipsCallback(refname, a, b)
if ret < 0 {
err := errors.New(ErrorCode(ret).String())
err := data.callbacks.UpdateTipsCallback(refname, a, b)
if err != nil {
if data.errorTarget != nil {
*data.errorTarget = err
}
@ -489,9 +484,8 @@ func certificateCheckCallback(
return setCallbackError(errorMessage, err)
}
ret := data.callbacks.CertificateCheckCallback(&cert, valid, host)
if ret < 0 {
err := errors.New(ErrorCode(ret).String())
err := data.callbacks.CertificateCheckCallback(&cert, valid, host)
if err != nil {
if data.errorTarget != nil {
*data.errorTarget = err
}
@ -507,9 +501,8 @@ func packProgressCallback(errorMessage **C.char, stage C.int, current, total C.u
return C.int(ErrorCodeOK)
}
ret := data.callbacks.PackProgressCallback(int32(stage), uint32(current), uint32(total))
if ret < 0 {
err := errors.New(ErrorCode(ret).String())
err := data.callbacks.PackProgressCallback(int32(stage), uint32(current), uint32(total))
if err != nil {
if data.errorTarget != nil {
*data.errorTarget = err
}
@ -525,9 +518,8 @@ func pushTransferProgressCallback(errorMessage **C.char, current, total C.uint,
return C.int(ErrorCodeOK)
}
ret := data.callbacks.PushTransferProgressCallback(uint32(current), uint32(total), uint(bytes))
if ret < 0 {
err := errors.New(ErrorCode(ret).String())
err := data.callbacks.PushTransferProgressCallback(uint32(current), uint32(total), uint(bytes))
if err != nil {
if data.errorTarget != nil {
*data.errorTarget = err
}
@ -543,9 +535,8 @@ func pushUpdateReferenceCallback(errorMessage **C.char, refname, status *C.char,
return C.int(ErrorCodeOK)
}
ret := data.callbacks.PushUpdateReferenceCallback(C.GoString(refname), C.GoString(status))
if ret < 0 {
err := errors.New(ErrorCode(ret).String())
err := data.callbacks.PushUpdateReferenceCallback(C.GoString(refname), C.GoString(status))
if err != nil {
if data.errorTarget != nil {
*data.errorTarget = err
}

View File

@ -38,13 +38,13 @@ func TestListRemotes(t *testing.T) {
compareStringList(t, expected, actual)
}
func assertHostname(cert *Certificate, valid bool, hostname string, t *testing.T) ErrorCode {
func assertHostname(cert *Certificate, valid bool, hostname string, t *testing.T) error {
if hostname != "github.com" {
t.Fatal("Hostname does not match")
return ErrorCodeUser
t.Fatal("hostname does not match")
return errors.New("hostname does not match")
}
return ErrorCodeOK
return nil
}
func TestCertificateCheck(t *testing.T) {
@ -58,7 +58,7 @@ func TestCertificateCheck(t *testing.T) {
options := FetchOptions{
RemoteCallbacks: RemoteCallbacks{
CertificateCheckCallback: func(cert *Certificate, valid bool, hostname string) ErrorCode {
CertificateCheckCallback: func(cert *Certificate, valid bool, hostname string) error {
return assertHostname(cert, valid, hostname, t)
},
},
@ -479,14 +479,13 @@ func TestRemoteSSH(t *testing.T) {
certificateCheckCallbackCalled := false
fetchOpts := FetchOptions{
RemoteCallbacks: RemoteCallbacks{
CertificateCheckCallback: func(cert *Certificate, valid bool, hostname string) ErrorCode {
CertificateCheckCallback: func(cert *Certificate, valid bool, hostname string) error {
hostkeyFingerprint := fmt.Sprintf("%x", cert.Hostkey.HashMD5[:])
if hostkeyFingerprint != publicKeyFingerprint {
t.Logf("server hostkey %q, want %q", hostkeyFingerprint, publicKeyFingerprint)
return ErrorCodeAuth
return fmt.Errorf("server hostkey %q, want %q", hostkeyFingerprint, publicKeyFingerprint)
}
certificateCheckCallbackCalled = true
return ErrorCodeOK
return nil
},
CredentialsCallback: func(url, username string, allowedTypes CredentialType) (*Credential, error) {
if allowedTypes&(CredentialTypeSSHKey|CredentialTypeSSHCustom|CredentialTypeSSHMemory) != 0 {

View File

@ -8,7 +8,6 @@ extern int _go_git_visit_submodule(git_repository *repo, void *fct);
import "C"
import (
"errors"
"runtime"
"unsafe"
)
@ -111,7 +110,7 @@ func (c *SubmoduleCollection) Lookup(name string) (*Submodule, error) {
}
// SubmoduleCallback is a function that is called for every submodule found in SubmoduleCollection.Foreach.
type SubmoduleCallback func(sub *Submodule, name string) int
type SubmoduleCallback func(sub *Submodule, name string) error
type submoduleCallbackData struct {
callback SubmoduleCallback
errorTarget *error
@ -126,9 +125,9 @@ func submoduleCallback(csub unsafe.Pointer, name *C.char, handle unsafe.Pointer)
panic("invalid submodule visitor callback")
}
ret := data.callback(sub, C.GoString(name))
if ret < 0 {
*data.errorTarget = errors.New(ErrorCode(ret).String())
err := data.callback(sub, C.GoString(name))
if err != nil {
*data.errorTarget = err
return C.int(ErrorCodeUser)
}

View File

@ -15,9 +15,9 @@ func TestSubmoduleForeach(t *testing.T) {
checkFatal(t, err)
i := 0
err = repo.Submodules.Foreach(func(sub *Submodule, name string) int {
err = repo.Submodules.Foreach(func(sub *Submodule, name string) error {
i++
return 0
return nil
})
checkFatal(t, err)

View File

@ -8,7 +8,6 @@ extern int _go_git_treewalk(git_tree *tree, git_treewalk_mode mode, void *ptr);
import "C"
import (
"errors"
"runtime"
"unsafe"
)
@ -121,7 +120,7 @@ func (t *Tree) EntryCount() uint64 {
return uint64(num)
}
type TreeWalkCallback func(string, *TreeEntry) int
type TreeWalkCallback func(string, *TreeEntry) error
type treeWalkCallbackData struct {
callback TreeWalkCallback
errorTarget *error
@ -134,9 +133,9 @@ func treeWalkCallback(_root *C.char, entry *C.git_tree_entry, ptr unsafe.Pointer
panic("invalid treewalk callback")
}
ret := data.callback(C.GoString(_root), newTreeEntry(entry))
if ret < 0 {
*data.errorTarget = errors.New(ErrorCode(ret).String())
err := data.callback(C.GoString(_root), newTreeEntry(entry))
if err != nil {
*data.errorTarget = err
return C.int(ErrorCodeUser)
}