Implement a reference iterator #26

Merged
carlosmn merged 1 commits from ref-iter into master 2013-06-13 12:14:09 -05:00
3 changed files with 170 additions and 0 deletions

5
git.go
View File

@ -8,6 +8,7 @@ package git
import "C"
import (
"bytes"
"errors"
"unsafe"
"strings"
)
@ -18,6 +19,10 @@ const (
ENOTFOUND = C.GIT_ENOTFOUND
)
var (
ErrIterOver = errors.New("Iteration is over")
)
func init() {
C.git_threads_init()
}

View File

@ -111,3 +111,76 @@ func (v *Reference) Free() {
runtime.SetFinalizer(v, nil)
C.git_reference_free(v.ptr)
}
type ReferenceIterator struct {
ptr *C.git_reference_iterator
repo *Repository
}
// NewReferenceIterator creates a new iterator over reference names
func (repo *Repository) NewReferenceIterator() (*ReferenceIterator, error) {
var ptr *C.git_reference_iterator
ret := C.git_reference_iterator_new(&ptr, repo.ptr)
if ret < 0 {
return nil, LastError()
}
iter := &ReferenceIterator{repo: repo, ptr: ptr}
runtime.SetFinalizer(iter, (*ReferenceIterator).Free)
return iter, nil
}
// NewReferenceIteratorGlob creates an iterator over reference names
// that match the speicified glob. The glob is of the usual fnmatch
// type.
func (repo *Repository) NewReferenceIteratorGlob(glob string) (*ReferenceIterator, error) {
cstr := C.CString(glob)
defer C.free(unsafe.Pointer(cstr))
var ptr *C.git_reference_iterator
ret := C.git_reference_iterator_glob_new(&ptr, repo.ptr, cstr)
if ret < 0 {
return nil, LastError()
}
iter := &ReferenceIterator{repo: repo, ptr: ptr}
runtime.SetFinalizer(iter, (*ReferenceIterator).Free)
return iter, nil
}
// Next retrieves the next reference name. If the iteration is over,
// the returned error is git.ErrIterOver
func (v *ReferenceIterator) Next() (string, error) {
var ptr *C.char
ret := C.git_reference_next(&ptr, v.ptr)
if ret == ITEROVER {
return "", ErrIterOver
}
if ret < 0 {
return "", LastError()
}
return C.GoString(ptr), nil
}
// Create a channel from the iterator. You can use range on the
// returned channel to iterate over all the references. The channel
// will be closed in case any error is found.
func (v *ReferenceIterator) Iter() <-chan string {
ch := make(chan string)
go func() {
defer close(ch)
name, err := v.Next()
for err == nil {
ch <- name
name, err = v.Next()
}
}()
return ch
}
// Free the reference iterator
func (v *ReferenceIterator) Free() {
runtime.SetFinalizer(v, nil)
C.git_reference_iterator_free(v.ptr)
}

View File

@ -3,6 +3,7 @@ package git
import (
"os"
"runtime"
"sort"
"testing"
"time"
)
@ -71,6 +72,97 @@ func TestRefModification(t *testing.T) {
}
func TestIterator(t *testing.T) {
repo := createTestRepo(t)
defer os.RemoveAll(repo.Workdir())
loc, err := time.LoadLocation("Europe/Berlin")
checkFatal(t, err)
sig := &Signature{
Name: "Rand Om Hacker",
Email: "random@hacker.com",
When: time.Date(2013, 03, 06, 14, 30, 0, 0, loc),
}
idx, err := repo.Index()
checkFatal(t, err)
err = idx.AddByPath("README")
checkFatal(t, err)
treeId, err := idx.WriteTree()
checkFatal(t, err)
message := "This is a commit\n"
tree, err := repo.LookupTree(treeId)
checkFatal(t, err)
commitId, err := repo.CreateCommit("HEAD", sig, sig, message, tree)
checkFatal(t, err)
_, err = repo.CreateReference("refs/heads/one", commitId, true)
checkFatal(t, err)
_, err = repo.CreateReference("refs/heads/two", commitId, true)
checkFatal(t, err)
_, err = repo.CreateReference("refs/heads/three", commitId, true)
checkFatal(t, err)
iter, err := repo.NewReferenceIterator()
checkFatal(t, err)
var list []string
expected := []string{
"refs/heads/master",
"refs/heads/one",
"refs/heads/three",
"refs/heads/two",
}
// test some manual iteration
name, err := iter.Next()
for err == nil {
list = append(list, name)
name, err = iter.Next()
}
if err != ErrIterOver {
t.Fatal("Iteration not over")
}
sort.Strings(list)
compareStringList(t, expected, list)
// test the channel iteration
list = []string{}
iter, err = repo.NewReferenceIterator()
for name := range iter.Iter() {
list = append(list, name)
}
sort.Strings(list)
compareStringList(t, expected, list)
iter, err = repo.NewReferenceIteratorGlob("refs/heads/t*")
expected = []string{
"refs/heads/three",
"refs/heads/two",
}
list = []string{}
for name := range iter.Iter() {
list = append(list, name)
}
compareStringList(t, expected, list)
}
func compareStringList(t *testing.T, expected, actual []string) {
for i, v := range expected {
if actual[i] != v {
t.Fatalf("Bad list")
}
}
}
func checkRefType(t *testing.T, ref *Reference, kind int) {
if ref.Type() == kind {
return