From c5295d3538f62961d57f4a1f8d42c8accb1a2c70 Mon Sep 17 00:00:00 2001
From: lhchavez <lhchavez@lhchavez.com>
Date: Sun, 13 Dec 2020 10:35:34 -0800
Subject: [PATCH] Support more MergeBase functions (#720)

This change adds support for MergeBaseMany, MergeBasesMany, and
MergeBaseOctopus.

(cherry picked from commit 698ddfb4ac4d8d7d66f68e36ceabcabc5426002b)
---
 merge.go      | 76 +++++++++++++++++++++++++++++++++++++++++++++++++--
 merge_test.go | 61 +++++++++++++++++++++++++++++++++++++++++
 2 files changed, 134 insertions(+), 3 deletions(-)

diff --git a/merge.go b/merge.go
index 19bfd87..4b87168 100644
--- a/merge.go
+++ b/merge.go
@@ -334,7 +334,7 @@ func (r *Repository) MergeBases(one, two *Oid) ([]*Oid, error) {
 	runtime.KeepAlive(one)
 	runtime.KeepAlive(two)
 	if ret < 0 {
-		return make([]*Oid, 0), MakeGitError(ret)
+		return nil, MakeGitError(ret)
 	}
 
 	oids := make([]*Oid, coids.count)
@@ -353,8 +353,78 @@ func (r *Repository) MergeBases(one, two *Oid) ([]*Oid, error) {
 	return oids, nil
 }
 
-//TODO: int git_merge_base_many(git_oid *out, git_repository *repo, size_t length, const git_oid input_array[]);
-//TODO: GIT_EXTERN(int) git_merge_base_octopus(git_oid *out,git_repository *repo,size_t length,const git_oid input_array[]);
+// MergeBaseMany finds a merge base given a list of commits.
+func (r *Repository) MergeBaseMany(oids []*Oid) (*Oid, error) {
+	coids := make([]C.git_oid, len(oids))
+	for i := 0; i < len(oids); i++ {
+		coids[i] = *oids[i].toC()
+	}
+
+	runtime.LockOSThread()
+	defer runtime.UnlockOSThread()
+
+	var oid C.git_oid
+	ret := C.git_merge_base_many(&oid, r.ptr, C.size_t(len(oids)), &coids[0])
+	runtime.KeepAlive(r)
+	runtime.KeepAlive(coids)
+	if ret < 0 {
+		return nil, MakeGitError(ret)
+	}
+	return newOidFromC(&oid), nil
+}
+
+// MergeBasesMany finds all merge bases given a list of commits.
+func (r *Repository) MergeBasesMany(oids []*Oid) ([]*Oid, error) {
+	inCoids := make([]C.git_oid, len(oids))
+	for i := 0; i < len(oids); i++ {
+		inCoids[i] = *oids[i].toC()
+	}
+
+	runtime.LockOSThread()
+	defer runtime.UnlockOSThread()
+
+	var outCoids C.git_oidarray
+	ret := C.git_merge_bases_many(&outCoids, r.ptr, C.size_t(len(oids)), &inCoids[0])
+	runtime.KeepAlive(r)
+	runtime.KeepAlive(inCoids)
+	if ret < 0 {
+		return nil, MakeGitError(ret)
+	}
+
+	outOids := make([]*Oid, outCoids.count)
+	hdr := reflect.SliceHeader{
+		Data: uintptr(unsafe.Pointer(outCoids.ids)),
+		Len:  int(outCoids.count),
+		Cap:  int(outCoids.count),
+	}
+	goSlice := *(*[]C.git_oid)(unsafe.Pointer(&hdr))
+
+	for i, cid := range goSlice {
+		outOids[i] = newOidFromC(&cid)
+	}
+
+	return outOids, nil
+}
+
+// MergeBaseOctopus finds a merge base in preparation for an octopus merge.
+func (r *Repository) MergeBaseOctopus(oids []*Oid) (*Oid, error) {
+	coids := make([]C.git_oid, len(oids))
+	for i := 0; i < len(oids); i++ {
+		coids[i] = *oids[i].toC()
+	}
+
+	runtime.LockOSThread()
+	defer runtime.UnlockOSThread()
+
+	var oid C.git_oid
+	ret := C.git_merge_base_octopus(&oid, r.ptr, C.size_t(len(oids)), &coids[0])
+	runtime.KeepAlive(r)
+	runtime.KeepAlive(coids)
+	if ret < 0 {
+		return nil, MakeGitError(ret)
+	}
+	return newOidFromC(&oid), nil
+}
 
 type MergeFileResult struct {
 	Automergeable bool
diff --git a/merge_test.go b/merge_test.go
index 319bef3..d49d07c 100644
--- a/merge_test.go
+++ b/merge_test.go
@@ -163,6 +163,15 @@ func TestMergeBase(t *testing.T) {
 	if mergeBase.Cmp(commitAId) != 0 {
 		t.Fatalf("unexpected merge base")
 	}
+}
+
+func TestMergeBases(t *testing.T) {
+	t.Parallel()
+	repo := createTestRepo(t)
+	defer cleanupTestRepo(t, repo)
+
+	commitAId, _ := seedTestRepo(t, repo)
+	commitBId, _ := appendCommit(t, repo)
 
 	mergeBases, err := repo.MergeBases(commitAId, commitBId)
 	checkFatal(t, err)
@@ -176,6 +185,58 @@ func TestMergeBase(t *testing.T) {
 	}
 }
 
+func TestMergeBaseMany(t *testing.T) {
+	t.Parallel()
+	repo := createTestRepo(t)
+	defer cleanupTestRepo(t, repo)
+
+	commitAId, _ := seedTestRepo(t, repo)
+	commitBId, _ := appendCommit(t, repo)
+
+	mergeBase, err := repo.MergeBaseMany([]*Oid{commitAId, commitBId})
+	checkFatal(t, err)
+
+	if mergeBase.Cmp(commitAId) != 0 {
+		t.Fatalf("unexpected merge base")
+	}
+}
+
+func TestMergeBasesMany(t *testing.T) {
+	t.Parallel()
+	repo := createTestRepo(t)
+	defer cleanupTestRepo(t, repo)
+
+	commitAId, _ := seedTestRepo(t, repo)
+	commitBId, _ := appendCommit(t, repo)
+
+	mergeBases, err := repo.MergeBasesMany([]*Oid{commitAId, commitBId})
+	checkFatal(t, err)
+
+	if len(mergeBases) != 1 {
+		t.Fatalf("expected merge bases len to be 1, got %v", len(mergeBases))
+	}
+
+	if mergeBases[0].Cmp(commitAId) != 0 {
+		t.Fatalf("unexpected merge base")
+	}
+}
+
+func TestMergeBaseOctopus(t *testing.T) {
+	t.Parallel()
+	repo := createTestRepo(t)
+	defer cleanupTestRepo(t, repo)
+
+	commitAId, _ := seedTestRepo(t, repo)
+	commitBId, _ := appendCommit(t, repo)
+
+	mergeBase, err := repo.MergeBaseOctopus([]*Oid{commitAId, commitBId})
+	checkFatal(t, err)
+
+	if mergeBase.Cmp(commitAId) != 0 {
+		t.Fatalf("unexpected merge base")
+	}
+}
+
 func compareBytes(t *testing.T, expected, actual []byte) {
 	for i, v := range expected {
 		if actual[i] != v {
-- 
2.45.2