Fix tests

Rename Relay to Demux
3 files changed, 148 insertions(+), 103 deletions(-)

M cmd/kgp_client/main.go
M kgp.go
M kgp_test.go
M cmd/kgp_client/main.go +5 -5
@@ 1,6 1,7 @@ 
 package main
 
 import (
+	"context"
 	"flag"
 	"fmt"
 	"io"

          
@@ 9,7 10,6 @@ import (
 	"os"
 	"os/signal"
 	"strings"
-	"context"
 
 	"bitbucket.org/henry/kgp"
 )

          
@@ 109,13 109,13 @@ func main() {
 	go func() {
 		var sigchan = make(chan os.Signal, 1)
 		signal.Notify(sigchan, os.Interrupt)
-		<-sigchan // wait for signal to come in
+		<-sigchan                  // wait for signal to come in
 		signal.Reset(os.Interrupt) // Let go handle further signals
 		logger.Print("stopping traffic")
 		cancel()
 	}()
 
-	var relay = kgp.NewRelay(
+	var demux = kgp.NewDemux(
 		upconn,
 		func(c context.Context) (io.ReadWriteCloser, error) {
 			return (&net.Dialer{}).DialContext(

          
@@ 123,6 123,6 @@ func main() {
 			)
 		},
 	)
-	relay.Logger = logger
-	fmt.Println(relay.Run(ctx, lastSeenSeq+1))
+	demux.Logger = logger
+	fmt.Println(demux.Run(ctx, lastSeenSeq+1))
 }

          
M kgp.go +55 -49
@@ 91,7 91,7 @@ func ReadAnnouncement(reader io.Reader) 
 }
 
 // Return the last seen packet from the last client connected to this tunnel.
-// Use this as the sequence number + 1 with Relay.Run()
+// Use this as the sequence number + 1 with Demux.Run()
 func Handshake(conn io.ReadWriter, meta io.Reader) (lastSeenSeq uint32, err error) {
 	var my = Announcement{
 		Metadata: meta,

          
@@ 120,17 120,18 @@ func Handshake(conn io.ReadWriter, meta 
 type ConnectFunc func(context.Context) (io.ReadWriteCloser, error)
 
 //
-// Relay between server and the remote with the connect function.
+// Demux KGP connection from server and local clients via the Demux.Connect()
+// function.
 //
 // Once created you can tweak the Rx/Tx size with: RxSize & TxSize. Also you
 // can log what's happening by setting Logger:
 //
-// 		var r = NewRelay(...)
+// 		var r = NewDemux(...)
 //		r.RxSize = 1
 //		r.TxSize = 100
 //		r.Logger = log.New(os.Stderr, "kgp_relay", 0)
 //
-type Relay struct {
+type Demux struct {
 	RxSize int
 	TxSize int
 

          
@@ 146,9 147,9 @@ type Relay struct {
 	rxPackets, txPackets uint64
 }
 
-// Creates a new relay with RxSize & TxSize set to 4.
-func NewRelay(server io.ReadWriter, connect ConnectFunc) *Relay {
-	return &Relay{
+// Creates a new demuxer with RxSize & TxSize set to 4.
+func NewDemux(server io.ReadWriter, connect ConnectFunc) *Demux {
+	return &Demux{
 		RxSize:     4,
 		TxSize:     4,
 		server:     server,

          
@@ 157,13 158,17 @@ func NewRelay(server io.ReadWriter, conn
 	}
 }
 
-func (relay *Relay) printf(format string, v ...interface{}) {
-	if relay.Logger != nil {
-		relay.Logger.Printf(format, v...)
+func (demux *Demux) printf(format string, v ...interface{}) {
+	if demux.Logger != nil {
+		demux.Logger.Printf(format, v...)
 	}
 }
 
-func (relay *Relay) newChannel(
+func (demux *Demux) ChannelCount() int {
+	return demux.dispatcher.Len()
+}
+
+func (demux *Demux) newChannel(
 	ctx context.Context,
 	fail func(error),
 	id uint32,

          
@@ 172,19 177,19 @@ func (relay *Relay) newChannel(
 	connect ConnectFunc,
 ) {
 	// Try to connect to the local host
-	var conn, err = relay.connect(ctx)
+	var conn, err = demux.connect(ctx)
 	if err != nil {
-		relay.printf("can't connect to downstream, closing channel %d", id)
+		demux.printf("can't connect to downstream, closing channel %d", id)
 		// we can't connect to downstream, close the connection.
 		output <- NewPacketClose(id)
 		return
 	}
 
-	relay.printf("Channel %d connected", id)
+	demux.printf("Channel %d connected", id)
 
 	var connClose = func() {
 		if err := conn.Close(); err != nil {
-			relay.printf(
+			demux.printf(
 				"WARNING: error while closing channel %d: %s", id, err,
 			)
 		}

          
@@ 195,25 200,25 @@ func (relay *Relay) newChannel(
 		for {
 			select {
 			case <-ctx.Done():
-				relay.printf("channel %d writer done", id)
+				demux.printf("channel %d writer done", id)
 				connClose()
 				return
 			case buf := <-input:
 				if buf == nil {
 					// channel's closed, the server side has closed, therefor
 					// we close the connection on upstream.
-					relay.printf("channel %d closed remotely", id)
+					demux.printf("channel %d closed remotely", id)
 					connClose()
 					return
 				}
 				var n, err = conn.Write(buf)
-				relay.printf("wrote %d bytes to channel %d", n, id)
+				demux.printf("wrote %d bytes to channel %d", n, id)
 				if err == io.EOF {
-					relay.printf("channel %d closed", id)
+					demux.printf("channel %d closed", id)
 					// nothing else to read, we're done
 					return
 				} else if err != nil {
-					relay.printf("ERROR: writing to channel %d: %s", id, err)
+					demux.printf("ERROR: writing to channel %d: %s", id, err)
 					connClose()
 					return
 				}

          
@@ 226,14 231,14 @@ func (relay *Relay) newChannel(
 		for {
 			var p = NewPacket()
 			var n, err = conn.Read(p.buffer)
-			relay.printf("read %d bytes from channel %d", n, id)
+			demux.printf("read %d bytes from channel %d", n, id)
 			if err == io.EOF {
-				relay.printf("channel %d closed", id)
+				demux.printf("channel %d closed", id)
 				output <- NewPacketClose(id)
 				return
 			}
 			if err != nil {
-				relay.printf(
+				demux.printf(
 					"ERROR: reading from channel %d: %s",
 					id, err.Error(),
 				)

          
@@ 255,20 260,20 @@ func (relay *Relay) newChannel(
 //
 // Read packets from server and put them in the returned channel
 //
-func (relay *Relay) readServer(
+func (demux *Demux) readServer(
 	done func() <-chan struct{},
 	fail func(error),
 	ack *atomic.Value,
 ) <-chan *Packet {
-	var output = make(chan *Packet, relay.RxSize)
+	var output = make(chan *Packet, demux.RxSize)
 
 	go func() {
 		for {
 			var p = NewPacket()
-			var _, err = p.ReadFrom(relay.server)
-			// relay.printf("read %d bytes from upstream", n)
+			var _, err = p.ReadFrom(demux.server)
+			// demux.printf("read %d bytes from upstream", n)
 			if err == io.EOF {
-				relay.printf("upstream closed connection")
+				demux.printf("upstream closed connection")
 				close(output)
 				return
 			} else if err != nil {

          
@@ 277,8 282,9 @@ func (relay *Relay) readServer(
 				return
 			}
 			ack.Store(p.header.Sequence)
-			relay.rxBytes += p.Size()
-			relay.rxPackets += 1
+			fmt.Println(ack.Load())
+			demux.rxBytes += p.Size()
+			demux.rxPackets += 1
 			// Send the new packet to output and check if we're done
 			select {
 			case <-done():

          
@@ 288,7 294,7 @@ func (relay *Relay) readServer(
 				// wait for packet to be sent
 			}
 		}
-		relay.printf("readServer done")
+		demux.printf("readServer done")
 	}()
 
 	return output

          
@@ 297,14 303,14 @@ func (relay *Relay) readServer(
 /*
  * Forward data from input to conn
  */
-func (relay *Relay) writeServer(
+func (demux *Demux) writeServer(
 	seq uint32,
 	done func() <-chan struct{},
 	fail func(error),
 	getAck func() uint32,
 ) chan<- *Packet {
 	// output writes packets to the server
-	var input = make(chan *Packet, relay.TxSize)
+	var input = make(chan *Packet, demux.TxSize)
 
 	go func() {
 		// We continously read what the channels send us on output and write it to

          
@@ 323,8 329,8 @@ func (relay *Relay) writeServer(
 					var n int64
 					p.header.Sequence = seq
 					p.header.Ack = getAck()
-					n, err = p.WriteTo(relay.server)
-					relay.printf("wrote %d bytes upstream", n)
+					n, err = p.WriteTo(demux.server)
+					demux.printf("wrote %d bytes upstream", n)
 					seq += 1
 					if seq == 0 {
 						seq = 1

          
@@ 333,11 339,11 @@ func (relay *Relay) writeServer(
 					// Forward the packet to the server as is
 					p.header.Ack = seq
 					// var n int64
-					_, err = p.WriteTo(relay.server)
-					// relay.printf("wrote %d bytes upstream for keepalive", n)
+					_, err = p.WriteTo(demux.server)
+					// demux.printf("wrote %d bytes upstream for keepalive", n)
 				}
 				if err != nil {
-					relay.printf("error writing packet %s", err)
+					demux.printf("error writing packet %s", err)
 					fail(err)
 				}
 			}

          
@@ 372,7 378,7 @@ func keepAlive(
 	}
 }
 
-func (relay *Relay) Run(ctx context.Context, seq uint32) error {
+func (demux *Demux) Run(ctx context.Context, seq uint32) error {
 	// Keeps track of the last sequence number from the server
 	var ack atomic.Value
 	ack.Store(uint32(0))

          
@@ 385,7 391,7 @@ func (relay *Relay) Run(ctx context.Cont
 	// buffered channel to make sure the fail function doesn't block
 	var reportErr = make(chan error, 1)
 	var fail = func(err error) {
-		relay.printf("fail %v", err)
+		demux.printf("fail %v", err)
 		reportErr <- err
 		close(reportErr)
 		cancelFunc()

          
@@ 393,10 399,10 @@ func (relay *Relay) Run(ctx context.Cont
 	var done = cancelCtx.Done
 
 	// input reads packets from the server
-	var input = relay.readServer(done, fail, &ack)
+	var input = demux.readServer(done, fail, &ack)
 	var getAck = func() uint32 { return ack.Load().(uint32) }
 	// output writes packets to the server
-	var output = relay.writeServer(seq, done, fail, getAck)
+	var output = demux.writeServer(seq, done, fail, getAck)
 
 	// We each packets from the server and handle them depending on their type:
 	// 	   open connection: we need to open new connection

          
@@ 427,19 433,19 @@ func (relay *Relay) Run(ctx context.Cont
 					// half-close, we'll no longer receive data from the server
 					// side, but still need to forward what the client side
 					// sends us. Ignore it for now.
-					relay.printf("half-close!")
+					demux.printf("half-close!")
 				} else {
-					relay.printf("channel %d closed remotely", p.ConnectionId)
-					relay.dispatcher.Close(p.header.ConnectionId)
+					demux.printf("channel %d closed remotely", p.ConnectionId)
+					demux.dispatcher.Close(p.header.ConnectionId)
 				}
 			case forwardPacket:
 				var cid = p.header.ConnectionId
-				var c = relay.dispatcher.Get(cid)
+				var c = demux.dispatcher.Get(cid)
 				if c == nil {
-					relay.printf("channel %d opened remotely", cid)
-					var c = relay.dispatcher.Open(cid)
+					demux.printf("channel %d opened remotely", cid)
+					var c = demux.dispatcher.Open(cid)
 
-					go relay.newChannel(cancelCtx, fail, cid, c, output, relay.connect)
+					go demux.newChannel(cancelCtx, fail, cid, c, output, demux.connect)
 				} else {
 					c <- p.Payload
 				}

          
M kgp_test.go +88 -49
@@ 1,11 1,14 @@ 
 package kgp
 
 import (
+	"log"
+	"os"
 	"bytes"
 	"context"
 	"io"
 	"io/ioutil"
 	"net"
+	"strings"
 	"testing"
 	"time"
 )

          
@@ 89,19 92,19 @@ func (rec recorder) Close() error {
 	return nil
 }
 
-func waitChannelClose(
-	done func() <-chan struct{},
-	relay *Relay,
-) {
-waitloop:
+func waitChannelCountEqual(
+	n int,
+	demux *Demux,
+	done func() <-chan struct{}, // just-in-case demux stops
+) bool {
 	for {
 		select {
 		case <-done():
-			break waitloop
+			return false
 		default:
 			time.Sleep(time.Millisecond)
-			if relay.dispatcher.Len() == 0 {
-				break waitloop
+			if demux.ChannelCount() == n {
+				return true
 			}
 		}
 	}

          
@@ 110,12 113,12 @@ waitloop:
 func equalBuffer(t *testing.T, result io.Reader, expected []byte) bool {
 	var r, err = ioutil.ReadAll(result)
 	if err != nil {
-		t.Fatalf("Couldn't read clientOutput: %s", err)
+		t.Errorf("Couldn't read clientOutput: %s", err)
 		return false
 	}
 
 	if !bytes.Equal(r, expected) {
-		t.Fatalf("%v != %v", r, expected)
+		t.Errorf("\n%v\n!=\n%v", r, expected)
 		return false
 	}
 

          
@@ 140,7 143,7 @@ func TestRelayRequest(t *testing.T) {
 		return client, nil
 	}
 
-	var relay = NewRelay(server, connect)
+	var demux = NewDemux(server, connect)
 	var ctx, cancel = context.WithTimeout(context.Background(), time.Second)
 	// We execute the client in the background while relay.Run() is running
 	go func() {

          
@@ 172,16 175,21 @@ func TestRelayRequest(t *testing.T) {
 		}
 
 		// Wait for all channels to close
-		waitChannelClose(ctx.Done, relay)
+		waitChannelCountEqual(0, demux, ctx.Done)
+
 		// Cancel relay to check result
 		cancel()
 	}()
-	var err = relay.Run(ctx, 1)
+	var err = demux.Run(ctx, 1)
+	// uncomment for more logging
+	// demux.Logger = log.New(os.Stderr, "", 0)
 	if err != context.Canceled {
 		t.Errorf("relay.Run errored: %s", err)
 	}
 	var expected = []byte("hello")
-	equalBuffer(t, clientOutput, expected)
+	if !equalBuffer(t, clientOutput, expected) {
+		t.Fatalf("failed")
+	}
 }
 
 func writeResponse(payload string, connectionId, seq, ack uint32) (r bytes.Buffer) {

          
@@ 225,18 233,17 @@ func equalResponse(t *testing.T, result 
 }
 
 func TestRelayRequestResponse(t *testing.T) {
-	var inputReader, inputWriter = io.Pipe()
-	var serverOutput = &bytes.Buffer{}
+	var requestReader, requestWriter = io.Pipe()
+	var responseReader, responseWriter = io.Pipe()
 	var server = recorder{
-		Input:  inputReader,
-		Output: serverOutput,
+		Input:  requestReader,
+		Output: responseWriter,
 	}
 
-	// Client's response is world
-	var clientInput = bytes.NewBufferString("world")
 	var clientOutput = &bytes.Buffer{}
 	var client = recorder{
-		Input:  clientInput,
+		// Client's response is world
+		Input:  strings.NewReader("world"),
 		Output: clientOutput,
 	}
 

          
@@ 244,41 251,73 @@ func TestRelayRequestResponse(t *testing
 		return client, nil
 	}
 
-	var relay = NewRelay(server, connect)
+	var demux = NewDemux(server, connect)
 	// Test has up to 1 sec to complete
 	var ctx, cancel = context.WithTimeout(context.Background(), time.Second)
 
-	// We execute the client in the background while relay.Run() is running
+	// Execute relay in the background
 	go func() {
-		// open connection 1
-		var p Packet
-		p.header.ConnectionId = 1
-		p.SetSeqAck(1, 0)
-		p.WriteTo(inputWriter)
+		// uncomment for more logging
+		// demux.Logger = log.New(os.Stderr, "", 0)
+		var err = demux.Run(ctx, 1)
+
+		if err != context.Canceled {
+			t.Fatalf("relay.Run errored: %s", err)
+		}
+	}()
 
-		// Write "hello"
-		p.Payload = []byte("hello")
-		p.SetSeqAck(2, 0)
-		p.WriteTo(inputWriter)
+	// open connection 1
+	var p Packet
+	p.header.ConnectionId = 1
+	p.SetSeqAck(1, 0)
+	p.WriteTo(requestWriter)
+
+	// Wait for the channel to open
+	waitChannelCountEqual(1, demux, ctx.Done)
 
-		// Close connection
-		p.Payload = []byte{}
-		p.header.Control = 1
-		p.SetSeqAck(3, 0)
-		p.WriteTo(inputWriter)
+	// Write "hello"
+	p.Payload = []byte("hello")
+	p.SetSeqAck(2, 0)
+	p.WriteTo(requestWriter)
 
-		// Wait for the channel to close then check the result
-		waitChannelClose(ctx.Done, relay)
-
-		cancel()
-	}()
-	var err = relay.Run(ctx, 1)
-	if err != context.Canceled {
-		t.Errorf("relay.Run errored: %s", err)
+	// we know ack will 2 because don't send the close connection until we've
+	// read this.
+	var expected = []byte{
+		0, 0, 0, 1, // forward packet connection id 1
+		0, 0, 0, 1, // seq 1
+		0, 0, 0, 1, // ack 1, because the client started writing before we the
+		            // request with seq was still at 1
+		0, 0, // control 0
+		0, 5, // len 5
+		119, 111, 114, 108, 100, // "hello"
+		0, 0, 0, 1, // close packet connection id 1
+		0, 0, 0, 2, // seq 2
+		0, 0, 0, 2, // ack 2
+		0, 1, // control 1
+		0, 0, // len 0
 	}
 
-	var expected = []byte("hello")
-	equalBuffer(t, clientOutput, expected)
-	expected = []byte("world")
-	equalResponse(t, serverOutput, []byte("world"))
+	var buf = make([]byte, len(expected))
+	if _, err := io.ReadFull(responseReader, buf); err != nil {
+		t.Fatal("ReadFull failed %s", err)
+	}
+
+	if !bytes.Equal(buf, expected) {
+		t.Fatalf("%v != %v", buf, expected)
+	}
+
+	// Close connection
+	p.Payload = []byte{}
+	p.header.Control = 1
+	p.SetSeqAck(3, 0)
+	p.WriteTo(requestWriter)
+
+	// Wait for the channel to close then check the result
+	waitChannelCountEqual(0, demux, ctx.Done)
+
+	cancel() // shutdown relay
+
+	if !equalBuffer(t, clientOutput, []byte("hello")) {
+		t.Fatalf("hello failed")
+	}
 }