Luhn error checking

This commit is contained in:
Jakob Borg
2014-07-04 16:16:50 +02:00
parent c488179783
commit 3d7d4d845a
3 changed files with 70 additions and 11 deletions

View File

@@ -1,7 +1,10 @@
// Package luhn generates and validates Luhn mod N check digits. // Package luhn generates and validates Luhn mod N check digits.
package luhn package luhn
import "strings" import (
"fmt"
"strings"
)
// An alphabet is a string of N characters, representing the digits of a given // An alphabet is a string of N characters, representing the digits of a given
// base N. // base N.
@@ -13,13 +16,20 @@ var (
// Generate returns a check digit for the string s, which should be composed // Generate returns a check digit for the string s, which should be composed
// of characters from the Alphabet a. // of characters from the Alphabet a.
func (a Alphabet) Generate(s string) rune { func (a Alphabet) Generate(s string) (rune, error) {
if err:=a.check();err!=nil{
return 0,err
}
factor := 1 factor := 1
sum := 0 sum := 0
n := len(a) n := len(a)
for i := range s { for i := range s {
codepoint := strings.IndexByte(string(a), s[i]) codepoint := strings.IndexByte(string(a), s[i])
if codepoint == -1 {
return 0, fmt.Errorf("Digit %q not valid in alphabet %q", s[i], a)
}
addend := factor * codepoint addend := factor * codepoint
if factor == 2 { if factor == 2 {
factor = 1 factor = 1
@@ -31,13 +41,28 @@ func (a Alphabet) Generate(s string) rune {
} }
remainder := sum % n remainder := sum % n
checkCodepoint := (n - remainder) % n checkCodepoint := (n - remainder) % n
return rune(a[checkCodepoint]) return rune(a[checkCodepoint]), nil
} }
// Validate returns true if the last character of the string s is correct, for // Validate returns true if the last character of the string s is correct, for
// a string s composed of characters in the alphabet a. // a string s composed of characters in the alphabet a.
func (a Alphabet) Validate(s string) bool { func (a Alphabet) Validate(s string) bool {
t := s[:len(s)-1] t := s[:len(s)-1]
c := a.Generate(t) c, err := a.Generate(t)
if err != nil {
return false
}
return rune(s[len(s)-1]) == c return rune(s[len(s)-1]) == c
} }
// check returns an error if the given alphabet does not consist of unique characters
func (a Alphabet) check() error {
cm := make(map[byte]bool, len(a))
for i := range a {
if cm[a[i]] {
return fmt.Errorf("Digit %q non-unique in alphabet %q", a[i], a)
}
cm[a[i]] = true
}
return nil
}

View File

@@ -9,19 +9,43 @@ import (
func TestGenerate(t *testing.T) { func TestGenerate(t *testing.T) {
// Base 6 Luhn // Base 6 Luhn
a := luhn.Alphabet("abcdef") a := luhn.Alphabet("abcdef")
c := a.Generate("abcdef") c, err := a.Generate("abcdef")
if err != nil {
t.Fatal(err)
}
if c != 'e' { if c != 'e' {
t.Errorf("Incorrect check digit %c != e", c) t.Errorf("Incorrect check digit %c != e", c)
} }
// Base 10 Luhn // Base 10 Luhn
a = luhn.Alphabet("0123456789") a = luhn.Alphabet("0123456789")
c = a.Generate("7992739871") c, err = a.Generate("7992739871")
if err != nil {
t.Fatal(err)
}
if c != '3' { if c != '3' {
t.Errorf("Incorrect check digit %c != 3", c) t.Errorf("Incorrect check digit %c != 3", c)
} }
} }
func TestInvalidString(t *testing.T) {
a := luhn.Alphabet("ABC")
_, err := a.Generate("7992739871")
t.Log(err)
if err == nil {
t.Error("Unexpected nil error")
}
}
func TestBadAlphabet(t *testing.T) {
a := luhn.Alphabet("01234566789")
_, err := a.Generate("7992739871")
t.Log(err)
if err == nil {
t.Error("Unexpected nil error")
}
}
func TestValidate(t *testing.T) { func TestValidate(t *testing.T) {
a := luhn.Alphabet("abcdef") a := luhn.Alphabet("abcdef")
if !a.Validate("abcdefe") { if !a.Validate("abcdefe") {

View File

@@ -34,7 +34,11 @@ func NodeIDFromString(s string) (NodeID, error) {
func (n NodeID) String() string { func (n NodeID) String() string {
id := base32.StdEncoding.EncodeToString(n[:]) id := base32.StdEncoding.EncodeToString(n[:])
id = strings.Trim(id, "=") id = strings.Trim(id, "=")
id = luhnify(id) id, err := luhnify(id)
if err != nil {
// Should never happen
panic(err)
}
id = chunkify(id) id = chunkify(id)
return id return id
} }
@@ -84,7 +88,7 @@ func (n *NodeID) UnmarshalText(bs []byte) error {
} }
} }
func luhnify(s string) string { func luhnify(s string) (string, error) {
if len(s) != 52 { if len(s) != 52 {
panic("unsupported string length") panic("unsupported string length")
} }
@@ -92,10 +96,13 @@ func luhnify(s string) string {
res := make([]string, 0, 4) res := make([]string, 0, 4)
for i := 0; i < 4; i++ { for i := 0; i < 4; i++ {
p := s[i*13 : (i+1)*13] p := s[i*13 : (i+1)*13]
l := luhn.Base32.Generate(p) l, err := luhn.Base32.Generate(p)
if err != nil {
return "", err
}
res = append(res, fmt.Sprintf("%s%c", p, l)) res = append(res, fmt.Sprintf("%s%c", p, l))
} }
return res[0] + res[1] + res[2] + res[3] return res[0] + res[1] + res[2] + res[3], nil
} }
func unluhnify(s string) (string, error) { func unluhnify(s string) (string, error) {
@@ -106,7 +113,10 @@ func unluhnify(s string) (string, error) {
res := make([]string, 0, 4) res := make([]string, 0, 4)
for i := 0; i < 4; i++ { for i := 0; i < 4; i++ {
p := s[i*14 : (i+1)*14-1] p := s[i*14 : (i+1)*14-1]
l := luhn.Base32.Generate(p) l, err := luhn.Base32.Generate(p)
if err != nil {
return "", err
}
if g := fmt.Sprintf("%s%c", p, l); g != s[i*14:(i+1)*14] { if g := fmt.Sprintf("%s%c", p, l); g != s[i*14:(i+1)*14] {
log.Printf("%q; %q", g, s[i*14:(i+1)*14]) log.Printf("%q; %q", g, s[i*14:(i+1)*14])
return "", errors.New("check digit incorrect") return "", errors.New("check digit incorrect")