From 931f187301d0c262a4ecdded891e4fed9387b4e4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mart=C3=ADn=20Nieto?= Date: Tue, 21 May 2013 15:14:26 +0200 Subject: [PATCH] Implement a reference iterator Wrap the reference iterators, and provide a Iter() function to get them through a channel. --- git.go | 5 +++ reference.go | 73 +++++++++++++++++++++++++++++++++++++ reference_test.go | 92 +++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 170 insertions(+) diff --git a/git.go b/git.go index fdc640a..19f4a32 100644 --- a/git.go +++ b/git.go @@ -8,6 +8,7 @@ package git import "C" import ( "bytes" + "errors" "unsafe" "strings" ) @@ -18,6 +19,10 @@ const ( ENOTFOUND = C.GIT_ENOTFOUND ) +var ( + ErrIterOver = errors.New("Iteration is over") +) + func init() { C.git_threads_init() } diff --git a/reference.go b/reference.go index 820d166..93ab7de 100644 --- a/reference.go +++ b/reference.go @@ -111,3 +111,76 @@ func (v *Reference) Free() { runtime.SetFinalizer(v, nil) C.git_reference_free(v.ptr) } + +type ReferenceIterator struct { + ptr *C.git_reference_iterator + repo *Repository +} + +// NewReferenceIterator creates a new iterator over reference names +func (repo *Repository) NewReferenceIterator() (*ReferenceIterator, error) { + var ptr *C.git_reference_iterator + ret := C.git_reference_iterator_new(&ptr, repo.ptr) + if ret < 0 { + return nil, LastError() + } + + iter := &ReferenceIterator{repo: repo, ptr: ptr} + runtime.SetFinalizer(iter, (*ReferenceIterator).Free) + return iter, nil +} + +// NewReferenceIteratorGlob creates an iterator over reference names +// that match the speicified glob. The glob is of the usual fnmatch +// type. +func (repo *Repository) NewReferenceIteratorGlob(glob string) (*ReferenceIterator, error) { + cstr := C.CString(glob) + defer C.free(unsafe.Pointer(cstr)) + var ptr *C.git_reference_iterator + ret := C.git_reference_iterator_glob_new(&ptr, repo.ptr, cstr) + if ret < 0 { + return nil, LastError() + } + + iter := &ReferenceIterator{repo: repo, ptr: ptr} + runtime.SetFinalizer(iter, (*ReferenceIterator).Free) + return iter, nil +} + +// Next retrieves the next reference name. If the iteration is over, +// the returned error is git.ErrIterOver +func (v *ReferenceIterator) Next() (string, error) { + var ptr *C.char + ret := C.git_reference_next(&ptr, v.ptr) + if ret == ITEROVER { + return "", ErrIterOver + } + if ret < 0 { + return "", LastError() + } + + return C.GoString(ptr), nil +} + +// Create a channel from the iterator. You can use range on the +// returned channel to iterate over all the references. The channel +// will be closed in case any error is found. +func (v *ReferenceIterator) Iter() <-chan string { + ch := make(chan string) + go func() { + defer close(ch) + name, err := v.Next() + for err == nil { + ch <- name + name, err = v.Next() + } + }() + + return ch +} + +// Free the reference iterator +func (v *ReferenceIterator) Free() { + runtime.SetFinalizer(v, nil) + C.git_reference_iterator_free(v.ptr) +} diff --git a/reference_test.go b/reference_test.go index 8043833..a03f638 100644 --- a/reference_test.go +++ b/reference_test.go @@ -3,6 +3,7 @@ package git import ( "os" "runtime" + "sort" "testing" "time" ) @@ -71,6 +72,97 @@ func TestRefModification(t *testing.T) { } +func TestIterator(t *testing.T) { + repo := createTestRepo(t) + defer os.RemoveAll(repo.Workdir()) + + loc, err := time.LoadLocation("Europe/Berlin") + checkFatal(t, err) + sig := &Signature{ + Name: "Rand Om Hacker", + Email: "random@hacker.com", + When: time.Date(2013, 03, 06, 14, 30, 0, 0, loc), + } + + idx, err := repo.Index() + checkFatal(t, err) + err = idx.AddByPath("README") + checkFatal(t, err) + treeId, err := idx.WriteTree() + checkFatal(t, err) + + message := "This is a commit\n" + tree, err := repo.LookupTree(treeId) + checkFatal(t, err) + commitId, err := repo.CreateCommit("HEAD", sig, sig, message, tree) + checkFatal(t, err) + + _, err = repo.CreateReference("refs/heads/one", commitId, true) + checkFatal(t, err) + + _, err = repo.CreateReference("refs/heads/two", commitId, true) + checkFatal(t, err) + + _, err = repo.CreateReference("refs/heads/three", commitId, true) + checkFatal(t, err) + + iter, err := repo.NewReferenceIterator() + checkFatal(t, err) + + var list []string + expected := []string{ + "refs/heads/master", + "refs/heads/one", + "refs/heads/three", + "refs/heads/two", + } + + // test some manual iteration + name, err := iter.Next() + for err == nil { + list = append(list, name) + name, err = iter.Next() + } + if err != ErrIterOver { + t.Fatal("Iteration not over") + } + + + sort.Strings(list) + compareStringList(t, expected, list) + + // test the channel iteration + list = []string{} + iter, err = repo.NewReferenceIterator() + for name := range iter.Iter() { + list = append(list, name) + } + + sort.Strings(list) + compareStringList(t, expected, list) + + iter, err = repo.NewReferenceIteratorGlob("refs/heads/t*") + expected = []string{ + "refs/heads/three", + "refs/heads/two", + } + + list = []string{} + for name := range iter.Iter() { + list = append(list, name) + } + + compareStringList(t, expected, list) +} + +func compareStringList(t *testing.T, expected, actual []string) { + for i, v := range expected { + if actual[i] != v { + t.Fatalf("Bad list") + } + } +} + func checkRefType(t *testing.T, ref *Reference, kind int) { if ref.Type() == kind { return -- 2.45.2