Add function to process several records at once
2 files changed, 180 insertions(+), 85 deletions(-)

M db.go
M db_test.go
M db.go +48 -22
@@ 17,6 17,7 @@ type Db interface {
 	UpdateInvites(string, []string) error
 	Invited(string, string) (bool, error)
 	Error() error
+	ForEachRecord(func(row []string)) error
 }
 
 type DbStats struct {

          
@@ 26,6 27,7 @@ type DbStats struct {
 type db struct {
 	path string
 	err  error
+	lockV *flock.Flock
 }
 
 func OpenDb(path string) Db {

          
@@ 60,25 62,35 @@ func (d db) Error() error {
 	return d.err
 }
 
+// ForEachRecord calls a supplied function for each row in the database. If the
+// function returns false, the row is not persisted.
+func (d db) ForEachRecord(cb func (row []string)) error {
+	err := d.lock()
+	if err != nil {
+		return err
+	}
+	defer d.unlock()
+
+	rs, err := d.readDb()
+	if err != nil {
+		return err
+	}
+	for _, r := range rs {
+		cb(r)
+	}
+	return d.save(rs)
+}
+
 // persistInvites saves the invite list of the ICS UUID to the database.
 // It replaces any existing record for that UUID.
 func (d db) UpdateInvites(id string, as []string) error {
 	sas := strings.Join(as, ",")
 	// Lock the database
-	ln := d.path + ".lock"
-	lock := flock.New(ln)
-	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
-	defer cancel()
-	success, err := lock.TryLockContext(ctx, 100*time.Millisecond)
-	if !success {
-		msg := fmt.Sprintf("Database appears to be locked by another process. Try removing %s\n", ln)
-		return errors.New(msg)
-	}
+	err := d.lock()
 	if err != nil {
 		return err
 	}
-	defer lock.Unlock()
-	defer os.Remove(lock.Path())
+	defer d.unlock()
 
 	rs, err := d.readDb()
 	if err != nil {

          
@@ 107,20 119,11 @@ func (d db) UpdateInvites(id string, as 
 
 func (d db) Invited(id, email string) (bool, error) {
 	// Lock the database
-	ln := d.path + ".lock"
-	lock := flock.New(ln)
-	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
-	defer cancel()
-	success, err := lock.TryLockContext(ctx, 100*time.Millisecond)
-	if !success {
-		msg := fmt.Sprintf("Database appears to be locked by another process. Try removing %s\n", ln)
-		return false, errors.New(msg)
-	}
+	err := d.lock()
 	if err != nil {
 		return false, err
 	}
-	defer lock.Unlock()
-	defer os.Remove(lock.Path())
+	defer d.unlock()
 
 	rs, err := d.readDb()
 	if err != nil {

          
@@ 142,6 145,27 @@ func (d db) Invited(id, email string) (b
 	return false, nil
 }
 
+func (d *db) lock() error {
+	ln := d.path + ".lock"
+	d.lockV = flock.New(ln)
+	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
+	defer cancel()
+	success, err := d.lockV.TryLockContext(ctx, 100*time.Millisecond)
+	if !success {
+		msg := fmt.Sprintf("Database appears to be locked by another process. Try removing %s\n", d.lockV.Path())
+		return errors.New(msg)
+	}
+	return err
+}
+
+func (d *db) unlock() {
+	d.lockV.Unlock()
+	os.Remove(d.lockV.Path())
+	d.lockV = nil
+}
+
+// save persists rows of data, overwriting any existing data. It expects the
+// database to be already locked.
 func (d db) save(rs [][]string) error {
 	// Write the DB back out
 	foutn := d.path + ".tmp"

          
@@ 161,6 185,8 @@ func (d db) save(rs [][]string) error {
 	return nil
 }
 
+// readDb reads data out of a database. It expects the database to be already
+// locked.
 func (d db) readDb() ([][]string, error) {
 	// Read the DB
 	fin, err := os.Open(d.path)

          
M db_test.go +132 -63
@@ 4,6 4,9 @@ import (
 	"io/ioutil"
 	"os"
 	"reflect"
+	"sort"
+	"strconv"
+	"strings"
 	"testing"
 )
 

          
@@ 41,6 44,31 @@ func Test_db_Stats(t *testing.T) {
 	}
 }
 
+var lines []string = []string{
+		"0\ta@b.c",
+		"1\tw@v.u",
+		"2\tt@s.r,q@p.o",
+		"3\tq@r.s,a@b.c,d@e.f",
+		"4\tz@y.x",
+		"5\ta@b.c",
+		"6\ta@b.c",
+	}
+
+func setupTestDb(lns []string) (name string, err error) {
+	var rv string
+	for _, l := range lns {
+		rv += l + "\n"
+	}
+	dbName := "testdb"
+	fout, err := os.Create(dbName)
+	if err != nil {
+		return "", err
+	}
+	fout.WriteString(rv)
+	fout.Close()
+	return dbName, nil
+}
+
 func Test_db_UpdateInvites(t *testing.T) {
 	type args struct {
 		id string

          
@@ 48,80 76,74 @@ func Test_db_UpdateInvites(t *testing.T)
 	}
 	tests := []struct {
 		name    string
-		fields  []string
+		initDb    []string
 		args    []args
-		want    string
+		want    []string
 		wantErr bool
 	}{
 		{"add one new to new DB",
 			[]string{},
-			[]args{args{"1", []string{"a@b.c"}}},
-			"1\ta@b.c\n",
+			[]args{args{"0", []string{"a@b.c"}}},
+			lines[0:1],
 			false},
-		{"add several new to new DB",
+		{"add several new to new DB", 
 			[]string{},
 			[]args{
-				args{"1", []string{"a@b.c"}},
-				args{"2", []string{"d@e.f", "g@h.i"}},
-				args{"3", []string{"j@k.l"}},
+				args{"0", []string{"a@b.c"}},
+				args{"1", []string{"w@v.u"}},
+				args{"2", []string{"t@s.r,q@p.o"}},
 			},
-			"1\ta@b.c\n2\td@e.f,g@h.i\n3\tj@k.l\n",
+			lines[0:3],
 			false},
 		{"add one new to old DB",
-			[]string{
-				"A\tz@y.x",
-				"B\tw@v.u",
-				"C\tt@s.r,q@p.o",
-			},
-			[]args{args{"1", []string{"a@b.c"}}},
-			"A\tz@y.x\nB\tw@v.u\nC\tt@s.r,q@p.o\n1\ta@b.c\n",
+			lines[1:7],
+			[]args{args{"0", []string{"a@b.c"}}},
+			lines,
 			false},
 		{"add several new to old DB",
-			[]string{
-				"A\tz@y.x",
-				"B\tw@v.u",
-				"C\tt@s.r,q@p.o",
-			},
+			lines[3:7],
 			[]args{
-				args{"1", []string{"a@b.c"}},
-				args{"2", []string{"d@e.f", "g@h.i"}},
-				args{"3", []string{"j@k.l"}},
+				args{"0", []string{"a@b.c"}},
+				args{"1", []string{"w@v.u"}},
+				args{"2", []string{"t@s.r,q@p.o"}},
 			},
-			"A\tz@y.x\nB\tw@v.u\nC\tt@s.r,q@p.o\n1\ta@b.c\n2\td@e.f,g@h.i\n3\tj@k.l\n",
+			lines,
+			false},
+		{"add several new to old DB, no repeat",
+			lines[2:7],
+			[]args{
+				args{"0", []string{"a@b.c"}},
+				args{"1", []string{"w@v.u"}},
+				args{"2", []string{"t@s.r,q@p.o"}},
+			},
+			lines,
 			false},
 		{"change one",
+			lines,
+			[]args{args{"1", []string{"x@y.z"}}},
 			[]string{
-				"A\tz@y.x",
-				"1\tw@v.u",
-				"C\tt@s.r,q@p.o",
+				lines[0], "1\tx@y.z", lines[2],lines[3],lines[4],lines[5],lines[6],
 			},
-			[]args{args{"1", []string{"a@b.c"}}},
-			"A\tz@y.x\n1\ta@b.c\nC\tt@s.r,q@p.o\n",
 			false},
 		{"change several",
-			[]string{
-				"A\tz@y.x",
-				"3\tw@v.u",
-				"C\tt@s.r,q@p.o",
-				"2\tw@v.u",
-			},
+			lines[:5],
 			[]args{
 				args{"2", []string{"d@e.f", "g@h.i"}},
 				args{"3", []string{"j@k.l"}},
 			},
-			"A\tz@y.x\n3\tj@k.l\nC\tt@s.r,q@p.o\n2\td@e.f,g@h.i\n",
+			[]string{
+				lines[0], lines[1], "2\td@e.f,g@h.i", "3\tj@k.l",lines[4],
+			},
 			false},
 		{"change and add",
-			[]string{
-				"A\tz@y.x",
-				"3\tw@v.u",
-				"C\tt@s.r,q@p.o",
+			lines[:3],
+			[]args{
+				args{"1", []string{"x@y.z"}},
+				args{"3", []string{"j@k.l"}},
 			},
-			[]args{
-				args{"2", []string{"a@b.c"}},
-				args{"3", []string{"d@e.f"}},
+			[]string{
+				lines[0], "1\tx@y.z",lines[2],"3\tj@k.l",
 			},
-			"A\tz@y.x\n3\td@e.f\nC\tt@s.r,q@p.o\n2\ta@b.c\n",
 			false},
 	}
 	dbName := "testdb"

          
@@ 129,35 151,43 @@ func Test_db_UpdateInvites(t *testing.T)
 	for _, tt := range tests {
 		t.Run(tt.name, func(t *testing.T) {
 			// Set up the test DB
-			fout, err := os.Create(dbName)
+			dbName, err := setupTestDb(tt.initDb)
 			if err != nil {
 				t.Errorf("error setting up db %s", err)
 			}
-			// Clean up test data
 			defer os.Remove(dbName)
-			for _, row := range tt.fields {
-				fout.WriteString(row + "\n")
-			}
-			fout.Close()
 			d := OpenDb(dbName)
 			for _, inv := range tt.args {
 				if err := d.UpdateInvites(inv.id, inv.as); (err != nil) != tt.wantErr {
 					t.Errorf("db.UpdateInvites() %s error = %v, wantErr %v", tt.name, err, tt.wantErr)
 				}
 			}
-			if got, err := ioutil.ReadFile(dbName); (err != nil) || (string(got) != tt.want) {
-				t.Errorf("db.UpdateInvites() `%s`; error = %v\nwant:\n`%v`\ngot:\n`%v`", tt.name, err, tt.want, string(got))
+			got, err := ioutil.ReadFile(dbName)
+			if (err != nil) {
+				t.Errorf("`%s`: error = %v", tt.name, err)
+			}
+			gots := strings.Split(strings.TrimSpace(string(got)), "\n")
+			if len(gots) > len(tt.want) {
+				t.Errorf("`%s`; wants: %v, gots: %v", tt.name, tt.want, gots)
+			}
+			sort.Strings(gots)
+			sort.Strings(tt.want)
+			for i, g := range gots {
+				if g != tt.want[i] {
+					t.Errorf("db.UpdateInvites() `%s`; want: %v, got: %v\ntt: %+v", tt.name, tt.want, gots, tt)
+					break
+				}
 			}
 		})
 	}
 }
 
-// TODO test locking
+// TODO test locking (ensure locking works)
 // TODO test malformed DB
 
 func TestInvited(t *testing.T) {
 	type args struct {
-		id string
+		id    string
 		email string
 	}
 	tests := []struct {

          
@@ 166,23 196,19 @@ func TestInvited(t *testing.T) {
 		want    bool
 		wantErr bool
 	}{
-		{"yes invited one", args{"2", "w@v.u"}, true, false},
+		{"yes invited one", args{"4", "z@y.x"}, true, false},
 		{"not invited", args{"2", "a@b.c"}, false, false},
-		{"yes invited many", args{"3", "q@p.o"}, true, false},
+		{"yes invited many", args{"3", "d@e.f"}, true, false},
 	}
 
 	// Set up the test DB
-	dbName := "testdb"
-	fout, err := os.Create(dbName)
+	dbName, err := setupTestDb(lines)
 	if err != nil {
 		t.Errorf("error setting up db %s", err)
 	}
-	// Clean up test data
 	defer os.Remove(dbName)
-	fout.WriteString("1\tz@y.x\n2\tw@v.u\n3\tt@s.r,q@p.o\n4\tq@r.s,a@b.c,d@e.f\n")
-	fout.Close()
+	d := OpenDb(dbName)
 
-	d := OpenDb(dbName)
 	for _, tt := range tests {
 		t.Run(tt.name, func(t *testing.T) {
 			got, err := d.Invited(tt.args.id, tt.args.email)

          
@@ 196,3 222,46 @@ func TestInvited(t *testing.T) {
 		})
 	}
 }
+
+
+func Test_db_ForEachRecord(t *testing.T) {
+	var count int
+	var success bool
+	tests := []struct {
+		name    string
+		cb func([]string)
+		wantErr bool
+		wantTest func() bool
+	}{
+		{"count", func(row []string){ count++ }, false, func() bool { return count == len(lines) }},
+		{"index is id", func(row []string){
+			id, err := strconv.Atoi(row[0])
+			success = success && err == nil && id == count
+			count++
+		}, false, func() bool {
+			return success
+		}},
+
+	}
+
+	// Set up the test DB
+	dbName, err := setupTestDb(lines)
+	if err != nil {
+		t.Errorf("error setting up db %s", err)
+	}
+	defer os.Remove(dbName)
+	d := OpenDb(dbName)
+
+	for _, tt := range tests {
+		t.Run  (tt.name, func(t *testing.T) {
+			count = 0
+			success = true
+			if err := d.ForEachRecord(tt.cb); (err != nil) != tt.wantErr {
+				t.Errorf("db.ForEachRecord() %s error = %v, wantErr %v", tt.name, err, tt.wantErr)
+			}
+			if !tt.wantTest() {
+				t.Errorf("%s test failed", tt.name)
+			}
+		})
+	}
+}