diff --git a/reference.go b/reference.go index b5f5e47..7b5e3c2 100644 --- a/reference.go +++ b/reference.go @@ -488,3 +488,42 @@ func ReferenceIsValidName(name string) bool { } return false } + +const ( + // This should match GIT_REFNAME_MAX in src/refs.h + _refnameMaxLength = C.size_t(1024) +) + +type ReferenceFormat uint + +const ( + ReferenceFormatNormal ReferenceFormat = C.GIT_REFERENCE_FORMAT_NORMAL + ReferenceFormatAllowOnelevel ReferenceFormat = C.GIT_REFERENCE_FORMAT_ALLOW_ONELEVEL + ReferenceFormatRefspecPattern ReferenceFormat = C.GIT_REFERENCE_FORMAT_REFSPEC_PATTERN + ReferenceFormatRefspecShorthand ReferenceFormat = C.GIT_REFERENCE_FORMAT_REFSPEC_SHORTHAND +) + +// ReferenceNormalizeName normalizes the reference name and checks validity. +// +// This will normalize the reference name by removing any leading slash '/' +// characters and collapsing runs of adjacent slashes between name components +// into a single slash. +// +// See git_reference_symbolic_create() for rules about valid names. +func ReferenceNormalizeName(name string, flags ReferenceFormat) (string, error) { + cname := C.CString(name) + defer C.free(unsafe.Pointer(cname)) + + buf := (*C.char)(C.malloc(_refnameMaxLength)) + defer C.free(unsafe.Pointer(buf)) + + runtime.LockOSThread() + defer runtime.UnlockOSThread() + + ecode := C.git_reference_normalize_name(buf, _refnameMaxLength, cname, C.uint(flags)) + if ecode < 0 { + return "", MakeGitError(ecode) + } + + return C.GoString(buf), nil +} diff --git a/reference_test.go b/reference_test.go index b6721e1..e42db41 100644 --- a/reference_test.go +++ b/reference_test.go @@ -224,6 +224,29 @@ func TestReferenceIsValidName(t *testing.T) { } } +func TestReferenceNormalizeName(t *testing.T) { + t.Parallel() + + ref, err := ReferenceNormalizeName("refs/heads//master", ReferenceFormatNormal) + checkFatal(t, err) + + if ref != "refs/heads/master" { + t.Errorf("ReferenceNormalizeName(%q) = %q; want %q", "refs/heads//master", ref, "refs/heads/master") + } + + ref, err = ReferenceNormalizeName("master", ReferenceFormatAllowOnelevel|ReferenceFormatRefspecShorthand) + checkFatal(t, err) + + if ref != "master" { + t.Errorf("ReferenceNormalizeName(%q) = %q; want %q", "master", ref, "master") + } + + ref, err = ReferenceNormalizeName("foo^", ReferenceFormatNormal) + if !IsErrorCode(err, ErrInvalidSpec) { + t.Errorf("foo^ should be invalid") + } +} + func compareStringList(t *testing.T, expected, actual []string) { for i, v := range expected { if actual[i] != v {