diff --git a/blob.go b/blob.go index 5a33bd8..b1fc78a 100644 --- a/blob.go +++ b/blob.go @@ -60,8 +60,13 @@ type BlobCallbackData struct { } //export blobChunkCb -func blobChunkCb(buffer *C.char, maxLen C.size_t, payload unsafe.Pointer) int { - data := (*BlobCallbackData)(payload) +func blobChunkCb(buffer *C.char, maxLen C.size_t, handle unsafe.Pointer) int { + payload := pointerHandles.Get(handle) + data, ok := payload.(*BlobCallbackData) + if !ok { + panic("could not retrieve blob callback data") + } + goBuf, err := data.Callback(int(maxLen)) if err == io.EOF { return 0 @@ -83,8 +88,12 @@ func (repo *Repository) CreateBlobFromChunks(hintPath string, callback BlobChunk defer C.free(unsafe.Pointer(chintPath)) } oid := C.git_oid{} + payload := &BlobCallbackData{Callback: callback} - ecode := C._go_git_blob_create_fromchunks(&oid, repo.ptr, chintPath, unsafe.Pointer(payload)) + handle := pointerHandles.Track(payload) + defer pointerHandles.Untrack(handle) + + ecode := C._go_git_blob_create_fromchunks(&oid, repo.ptr, chintPath, handle) if payload.Error != nil { return nil, payload.Error } diff --git a/diff.go b/diff.go index 23469f2..5e03175 100644 --- a/diff.go +++ b/diff.go @@ -265,7 +265,11 @@ func (diff *Diff) ForEach(cbFile DiffForEachFileCallback, detail DiffDetail) err data := &diffForEachData{ FileCallback: cbFile, } - ecode := C._go_git_diff_foreach(diff.ptr, 1, intHunks, intLines, unsafe.Pointer(data)) + + handle := pointerHandles.Track(data) + defer pointerHandles.Untrack(handle) + + ecode := C._go_git_diff_foreach(diff.ptr, 1, intHunks, intLines, handle) if ecode < 0 { return data.Error } @@ -273,8 +277,12 @@ func (diff *Diff) ForEach(cbFile DiffForEachFileCallback, detail DiffDetail) err } //export diffForEachFileCb -func diffForEachFileCb(delta *C.git_diff_delta, progress C.float, payload unsafe.Pointer) int { - data := (*diffForEachData)(payload) +func diffForEachFileCb(delta *C.git_diff_delta, progress C.float, handle unsafe.Pointer) int { + payload := pointerHandles.Get(handle) + data, ok := payload.(*diffForEachData) + if !ok { + panic("could not retrieve data for handle") + } data.HunkCallback = nil if data.FileCallback != nil { @@ -292,8 +300,12 @@ func diffForEachFileCb(delta *C.git_diff_delta, progress C.float, payload unsafe type DiffForEachHunkCallback func(DiffHunk) (DiffForEachLineCallback, error) //export diffForEachHunkCb -func diffForEachHunkCb(delta *C.git_diff_delta, hunk *C.git_diff_hunk, payload unsafe.Pointer) int { - data := (*diffForEachData)(payload) +func diffForEachHunkCb(delta *C.git_diff_delta, hunk *C.git_diff_hunk, handle unsafe.Pointer) int { + payload := pointerHandles.Get(handle) + data, ok := payload.(*diffForEachData) + if !ok { + panic("could not retrieve data for handle") + } data.LineCallback = nil if data.HunkCallback != nil { @@ -311,9 +323,12 @@ func diffForEachHunkCb(delta *C.git_diff_delta, hunk *C.git_diff_hunk, payload u type DiffForEachLineCallback func(DiffLine) error //export diffForEachLineCb -func diffForEachLineCb(delta *C.git_diff_delta, hunk *C.git_diff_hunk, line *C.git_diff_line, payload unsafe.Pointer) int { - - data := (*diffForEachData)(payload) +func diffForEachLineCb(delta *C.git_diff_delta, hunk *C.git_diff_hunk, line *C.git_diff_line, handle unsafe.Pointer) int { + payload := pointerHandles.Get(handle) + data, ok := payload.(*diffForEachData) + if !ok { + panic("could not retrieve data for handle") + } err := data.LineCallback(diffLineFromC(delta, hunk, line)) if err != nil { @@ -479,9 +494,15 @@ type diffNotifyData struct { } //export diffNotifyCb -func diffNotifyCb(_diff_so_far unsafe.Pointer, delta_to_add *C.git_diff_delta, matched_pathspec *C.char, payload unsafe.Pointer) int { +func diffNotifyCb(_diff_so_far unsafe.Pointer, delta_to_add *C.git_diff_delta, matched_pathspec *C.char, handle unsafe.Pointer) int { diff_so_far := (*C.git_diff)(_diff_so_far) - data := (*diffNotifyData)(payload) + + payload := pointerHandles.Get(handle) + data, ok := payload.(*diffNotifyData) + if !ok { + panic("could not retrieve data for handle") + } + if data != nil { if data.Diff == nil { data.Diff = newDiffFromC(diff_so_far) @@ -507,6 +528,7 @@ func diffOptionsToC(opts *DiffOptions) (copts *C.git_diff_options, notifyData *d notifyData = &diffNotifyData{ Callback: opts.NotifyCallback, } + if opts.Pathspec != nil { cpathspec.count = C.size_t(len(opts.Pathspec)) cpathspec.strings = makeCStringsFromStrings(opts.Pathspec) @@ -527,7 +549,7 @@ func diffOptionsToC(opts *DiffOptions) (copts *C.git_diff_options, notifyData *d if opts.NotifyCallback != nil { C._go_git_setup_diff_notify_callbacks(copts) - copts.notify_payload = unsafe.Pointer(notifyData) + copts.notify_payload = pointerHandles.Track(notifyData) } } return @@ -539,6 +561,9 @@ func freeDiffOptions(copts *C.git_diff_options) { freeStrarray(&cpathspec) C.free(unsafe.Pointer(copts.old_prefix)) C.free(unsafe.Pointer(copts.new_prefix)) + if copts.notify_payload != nil { + pointerHandles.Untrack(copts.notify_payload) + } } } diff --git a/git.go b/git.go index 9496d2d..4f1a65e 100644 --- a/git.go +++ b/git.go @@ -93,7 +93,11 @@ var ( ErrInvalid = errors.New("Invalid state for operation") ) +var pointerHandles *HandleList + func init() { + pointerHandles = NewHandleList() + C.git_libgit2_init() // This is not something we should be doing, as we may be diff --git a/handles.go b/handles.go new file mode 100644 index 0000000..ec62a48 --- /dev/null +++ b/handles.go @@ -0,0 +1,84 @@ +package git + +import ( + "fmt" + "sync" + "unsafe" +) + +type HandleList struct { + sync.RWMutex + // stores the Go pointers + handles []interface{} + // indicates which indices are in use + set map[int]bool +} + +func NewHandleList() *HandleList { + return &HandleList{ + handles: make([]interface{}, 5), + set: make(map[int]bool), + } +} + +// findUnusedSlot finds the smallest-index empty space in our +// list. You must only run this function while holding a write lock. +func (v *HandleList) findUnusedSlot() int { + for i := 1; i < len(v.handles); i++ { + isUsed := v.set[i] + if !isUsed { + return i + } + } + + // reaching here means we've run out of entries so append and + // return the new index, which is equal to the old length. + slot := len(v.handles) + v.handles = append(v.handles, nil) + + return slot +} + +// Track adds the given pointer to the list of pointers to track and +// returns a pointer value which can be passed to C as an opaque +// pointer. +func (v *HandleList) Track(pointer interface{}) unsafe.Pointer { + v.Lock() + + slot := v.findUnusedSlot() + v.handles[slot] = pointer + v.set[slot] = true + + v.Unlock() + + return unsafe.Pointer(&slot) +} + +// Untrack stops tracking the pointer given by the handle +func (v *HandleList) Untrack(handle unsafe.Pointer) { + slot := *(*int)(handle) + + v.Lock() + + v.handles[slot] = nil + delete(v.set, slot) + + v.Unlock() +} + +// Get retrieves the pointer from the given handle +func (v *HandleList) Get(handle unsafe.Pointer) interface{} { + slot := *(*int)(handle) + + v.RLock() + + if _, ok := v.set[slot]; !ok { + panic(fmt.Sprintf("invalid pointer handle: %p", handle)) + } + + ptr := v.handles[slot] + + v.RUnlock() + + return ptr +} diff --git a/index.go b/index.go index 1a899f1..c1bfb74 100644 --- a/index.go +++ b/index.go @@ -162,16 +162,17 @@ func (v *Index) AddAll(pathspecs []string, flags IndexAddOpts, callback IndexMat runtime.LockOSThread() defer runtime.UnlockOSThread() - var cb *IndexMatchedPathCallback + var handle unsafe.Pointer if callback != nil { - cb = &callback + handle = pointerHandles.Track(callback) + defer pointerHandles.Untrack(handle) } ret := C._go_git_index_add_all( v.ptr, &cpathspecs, C.uint(flags), - unsafe.Pointer(cb), + handle, ) if ret < 0 { return MakeGitError(ret) @@ -188,15 +189,16 @@ func (v *Index) UpdateAll(pathspecs []string, callback IndexMatchedPathCallback) runtime.LockOSThread() defer runtime.UnlockOSThread() - var cb *IndexMatchedPathCallback + var handle unsafe.Pointer if callback != nil { - cb = &callback + handle = pointerHandles.Track(callback) + defer pointerHandles.Untrack(handle) } ret := C._go_git_index_update_all( v.ptr, &cpathspecs, - unsafe.Pointer(cb), + handle, ) if ret < 0 { return MakeGitError(ret) @@ -213,15 +215,16 @@ func (v *Index) RemoveAll(pathspecs []string, callback IndexMatchedPathCallback) runtime.LockOSThread() defer runtime.UnlockOSThread() - var cb *IndexMatchedPathCallback + var handle unsafe.Pointer if callback != nil { - cb = &callback + handle = pointerHandles.Track(callback) + defer pointerHandles.Untrack(handle) } ret := C._go_git_index_remove_all( v.ptr, &cpathspecs, - unsafe.Pointer(cb), + handle, ) if ret < 0 { return MakeGitError(ret) @@ -231,8 +234,11 @@ func (v *Index) RemoveAll(pathspecs []string, callback IndexMatchedPathCallback) //export indexMatchedPathCallback func indexMatchedPathCallback(cPath, cMatchedPathspec *C.char, payload unsafe.Pointer) int { - callback := (*IndexMatchedPathCallback)(payload) - return (*callback)(C.GoString(cPath), C.GoString(cMatchedPathspec)) + if callback, ok := pointerHandles.Get(payload).(IndexMatchedPathCallback); ok { + return callback(C.GoString(cPath), C.GoString(cMatchedPathspec)) + } else { + panic("invalid matched path callback") + } } func (v *Index) RemoveByPath(path string) error { diff --git a/odb.go b/odb.go index ba03860..6b21329 100644 --- a/odb.go +++ b/odb.go @@ -98,8 +98,12 @@ type foreachData struct { } //export odbForEachCb -func odbForEachCb(id *C.git_oid, payload unsafe.Pointer) int { - data := (*foreachData)(payload) +func odbForEachCb(id *C.git_oid, handle unsafe.Pointer) int { + data, ok := pointerHandles.Get(handle).(*foreachData) + + if !ok { + panic("could not retrieve handle") + } err := data.callback(newOidFromC(id)) if err != nil { @@ -119,7 +123,10 @@ func (v *Odb) ForEach(callback OdbForEachCallback) error { runtime.LockOSThread() defer runtime.UnlockOSThread() - ret := C._go_git_odb_foreach(v.ptr, unsafe.Pointer(&data)) + handle := pointerHandles.Track(&data) + defer pointerHandles.Untrack(handle) + + ret := C._go_git_odb_foreach(v.ptr, handle) if ret == C.GIT_EUSER { return data.err } else if ret < 0 { diff --git a/odb_test.go b/odb_test.go index 55ed297..2fb6840 100644 --- a/odb_test.go +++ b/odb_test.go @@ -81,7 +81,7 @@ func TestOdbForeach(t *testing.T) { checkFatal(t, err) if count != expect { - t.Fatalf("Expected %v objects, got %v") + t.Fatalf("Expected %v objects, got %v", expect, count) } expect = 1 diff --git a/packbuilder.go b/packbuilder.go index 54a8390..4dc352c 100644 --- a/packbuilder.go +++ b/packbuilder.go @@ -110,8 +110,13 @@ type packbuilderCbData struct { } //export packbuilderForEachCb -func packbuilderForEachCb(buf unsafe.Pointer, size C.size_t, payload unsafe.Pointer) int { - data := (*packbuilderCbData)(payload) +func packbuilderForEachCb(buf unsafe.Pointer, size C.size_t, handle unsafe.Pointer) int { + payload := pointerHandles.Get(handle) + data, ok := payload.(*packbuilderCbData) + if !ok { + panic("could not get packbuilder CB data") + } + slice := C.GoBytes(buf, C.int(size)) err := data.callback(slice) @@ -130,11 +135,13 @@ func (pb *Packbuilder) ForEach(callback PackbuilderForeachCallback) error { callback: callback, err: nil, } + handle := pointerHandles.Track(&data) + defer pointerHandles.Untrack(handle) runtime.LockOSThread() defer runtime.UnlockOSThread() - err := C._go_git_packbuilder_foreach(pb.ptr, unsafe.Pointer(&data)) + err := C._go_git_packbuilder_foreach(pb.ptr, handle) if err == C.GIT_EUSER { return data.err } diff --git a/submodule.go b/submodule.go index 7c6c922..3882462 100644 --- a/submodule.go +++ b/submodule.go @@ -97,17 +97,24 @@ func (repo *Repository) LookupSubmodule(name string) (*Submodule, error) { type SubmoduleCbk func(sub *Submodule, name string) int //export SubmoduleVisitor -func SubmoduleVisitor(csub unsafe.Pointer, name *C.char, cfct unsafe.Pointer) C.int { +func SubmoduleVisitor(csub unsafe.Pointer, name *C.char, handle unsafe.Pointer) C.int { sub := &Submodule{(*C.git_submodule)(csub)} - fct := *(*SubmoduleCbk)(cfct) - return (C.int)(fct(sub, C.GoString(name))) + + if callback, ok := pointerHandles.Get(handle).(SubmoduleCbk); ok { + return (C.int)(callback(sub, C.GoString(name))) + } else { + panic("invalid submodule visitor callback") + } } func (repo *Repository) ForeachSubmodule(cbk SubmoduleCbk) error { runtime.LockOSThread() defer runtime.UnlockOSThread() - ret := C._go_git_visit_submodule(repo.ptr, unsafe.Pointer(&cbk)) + handle := pointerHandles.Track(cbk) + defer pointerHandles.Untrack(handle) + + ret := C._go_git_visit_submodule(repo.ptr, handle) if ret < 0 { return MakeGitError(ret) } diff --git a/tree.go b/tree.go index c18d02a..aad2c8d 100644 --- a/tree.go +++ b/tree.go @@ -90,22 +90,28 @@ func (t Tree) EntryCount() uint64 { type TreeWalkCallback func(string, *TreeEntry) int //export CallbackGitTreeWalk -func CallbackGitTreeWalk(_root unsafe.Pointer, _entry unsafe.Pointer, ptr unsafe.Pointer) C.int { - root := C.GoString((*C.char)(_root)) +func CallbackGitTreeWalk(_root *C.char, _entry unsafe.Pointer, ptr unsafe.Pointer) C.int { + root := C.GoString(_root) entry := (*C.git_tree_entry)(_entry) - callback := *(*TreeWalkCallback)(ptr) - return C.int(callback(root, newTreeEntry(entry))) + if callback, ok := pointerHandles.Get(ptr).(TreeWalkCallback); ok { + return C.int(callback(root, newTreeEntry(entry))) + } else { + panic("invalid treewalk callback") + } } func (t Tree) Walk(callback TreeWalkCallback) error { runtime.LockOSThread() defer runtime.UnlockOSThread() + ptr := pointerHandles.Track(callback) + defer pointerHandles.Untrack(ptr) + err := C._go_git_treewalk( t.cast_ptr, C.GIT_TREEWALK_PRE, - unsafe.Pointer(&callback), + ptr, ) if err < 0 {