diff --git a/diff.go b/diff.go index e022b47..d9bceac 100644 --- a/diff.go +++ b/diff.go @@ -3,6 +3,7 @@ package git /* #include +extern void _go_git_apply_init_options(git_apply_options *options); extern int _go_git_diff_foreach(git_diff *diff, int eachFile, int eachHunk, int eachLine, void *payload); extern void _go_git_setup_diff_notify_callbacks(git_diff_options* opts); extern int _go_git_diff_blobs(git_blob *old, const char *old_path, git_blob *new, const char *new_path, git_diff_options *opts, int eachFile, int eachHunk, int eachLine, void *payload); @@ -847,3 +848,76 @@ func DiffBlobs(oldBlob *Blob, oldAsPath string, newBlob *Blob, newAsPath string, return nil } + +type ApplyOptions struct { + Version uint + Flags uint + // TODO: there are some more flags, not currently used +} + +func DefaultApplyOptions() (*ApplyOptions, error) { + opts := C.git_apply_options{} + + runtime.LockOSThread() + defer runtime.UnlockOSThread() + + C._go_git_apply_init_options(&opts) + + return applyOptionsFromC(&opts), nil +} + +func (a *ApplyOptions) toC() *C.git_apply_options { + if a == nil { + return nil + } + + opts := &C.git_apply_options{ + version: C.uint(a.Version), + flags: C.uint(a.Flags), + } + + return opts +} + +func applyOptionsFromC(opts *C.git_apply_options) *ApplyOptions { + return &ApplyOptions{ + Version: uint(opts.version), + Flags: uint(opts.flags), + } +} + +type GitApplyLocation int + +const ( + GitApplyLocationWorkdir GitApplyLocation = C.GIT_APPLY_LOCATION_WORKDIR + GitApplyLocationIndex GitApplyLocation = C.GIT_APPLY_LOCATION_INDEX + GitApplyLocationBoth GitApplyLocation = C.GIT_APPLY_LOCATION_BOTH +) + +func (v *Repository) ApplyDiff(diff *Diff, location GitApplyLocation, opts *ApplyOptions) error { + runtime.LockOSThread() + defer runtime.UnlockOSThread() + + ecode := C.git_apply(v.ptr, diff.ptr, C.git_apply_location_t(location), opts.toC()) + runtime.KeepAlive(v) + if ecode < 0 { + return MakeGitError(ecode) + } + + return nil +} + +func DiffFromBuffer(buffer []byte, repo *Repository) (*Diff, error) { + var diff *C.git_diff + + runtime.LockOSThread() + defer runtime.UnlockOSThread() + + ecode := C.git_diff_from_buffer(&diff, C.CString(string(buffer)), C.size_t(len(buffer))) + if ecode < 0 { + return nil, MakeGitError(ecode) + } + runtime.KeepAlive(diff) + + return newDiffFromC(diff, repo), nil +} diff --git a/diff_test.go b/diff_test.go index 6fbad51..1a6797f 100644 --- a/diff_test.go +++ b/diff_test.go @@ -236,3 +236,65 @@ func TestDiffBlobs(t *testing.T) { t.Fatalf("Bad number of lines iterated") } } + +func Test_ApplyDiff_Addfile(t *testing.T) { + repo := createTestRepo(t) + defer cleanupTestRepo(t, repo) + + seedTestRepo(t, repo) + + addFirstFileCommit, addFileTree := addAndGetTree(t, repo, "file1", `hello`) + addSecondFileCommit, addSecondFileTree := addAndGetTree(t, repo, "file2", `hello2`) + + diff, err := repo.DiffTreeToTree(addFileTree, addSecondFileTree, nil) + checkFatal(t, err) + + t.Run("check does not apply to current tree because file exists", func(t *testing.T) { + err = repo.ResetToCommit(addSecondFileCommit, ResetHard, &CheckoutOpts{}) + checkFatal(t, err) + + err = repo.ApplyDiff(diff, GitApplyLocationBoth, nil) + if err == nil { + t.Error("expecting applying patch to current repo to fail") + } + }) + + t.Run("check apply to correct commit", func(t *testing.T) { + err = repo.ResetToCommit(addFirstFileCommit, ResetHard, &CheckoutOpts{}) + checkFatal(t, err) + + err = repo.ApplyDiff(diff, GitApplyLocationBoth, nil) + checkFatal(t, err) + }) + + t.Run("check convert to raw buffer and apply", func(t *testing.T) { + err = repo.ResetToCommit(addFirstFileCommit, ResetHard, &CheckoutOpts{}) + checkFatal(t, err) + + raw, err := diff.ToBuf(DiffFormatPatch) + checkFatal(t, err) + + if len(raw) == 0 { + t.Error("empty diff created") + } + + diff2, err := DiffFromBuffer(raw, repo) + checkFatal(t, err) + + err = repo.ApplyDiff(diff2, GitApplyLocationBoth, nil) + checkFatal(t, err) + }) +} + +func addAndGetTree(t *testing.T, repo *Repository, filename string, content string) (*Commit, *Tree) { + commitId, err := commitSomething(repo, filename, content) + checkFatal(t, err) + + commit, err := repo.LookupCommit(commitId) + checkFatal(t, err) + + tree, err := commit.Tree() + checkFatal(t, err) + + return commit, tree +}