refactor packet handling
3 files changed, 137 insertions(+), 145 deletions(-)

M dispatch.go
M kgp.go
M kgp_test.go
M dispatch.go +4 -5
@@ 1,12 1,11 @@ 
 package kgp
 
 import (
-	"bytes"
 	"sync"
 	"sync/atomic"
 )
 
-type channels map[uint32]chan<- *bytes.Buffer
+type channels map[uint32]chan<- []byte
 
 //
 // A read-often & write-rarely optimized object to dispatch the packets to

          
@@ 23,7 22,7 @@ func newDispatcher() *dispatcher {
 	return &d
 }
 
-func (d *dispatcher) Get(id uint32) chan<- *bytes.Buffer {
+func (d *dispatcher) Get(id uint32) chan<- []byte {
 	var channels = d.channels.Load().(channels)
 	return channels[id]
 }

          
@@ 33,7 32,7 @@ func (d *dispatcher) Len() int {
 	return len(channels)
 }
 
-func (d *dispatcher) Open(id uint32) chan *bytes.Buffer {
+func (d *dispatcher) Open(id uint32) chan []byte {
 	d.mutex.Lock() // synchronize with other potential writers
 	defer d.mutex.Unlock()
 	var oldchan = d.channels.Load().(channels)

          
@@ 46,7 45,7 @@ func (d *dispatcher) Open(id uint32) cha
 	for k, v := range oldchan {
 		newchan[k] = v // copy all data from the current object to the new one
 	}
-	var c = make(chan *bytes.Buffer)
+	var c = make(chan []byte)
 	newchan[id] = c
 	d.channels.Store(newchan) // atomically replace the current object with the new one
 	return c

          
M kgp.go +70 -80
@@ 125,8 125,6 @@ type Packet struct {
 		Sequence uint32
 		Ack      uint32
 		Control  int16
-
-		PayloadLen uint16
 	}
 
 	buffer  []byte

          
@@ 144,7 142,7 @@ func (packet *Packet) Type() int {
 	switch {
 	case h.ConnectionId == 0:
 		return keepalivePacket
-	case h.Control == 0 && h.PayloadLen == 0:
+	case h.Control == 0 && len(packet.Payload) == 0:
 		return openPacket
 	case h.Control == 1:
 		return closePacket

          
@@ 153,12 151,12 @@ func (packet *Packet) Type() int {
 	}
 }
 
-func (packet *Packet) Size() int {
+func (packet *Packet) Size() uint64 {
 	var size = binary.Size(&packet.header)
 	if packet.Payload != nil {
 		size += len(packet.Payload)
 	}
-	return size
+	return uint64(size)
 }
 
 func (packet *Packet) ConnectionId() uint32 {

          
@@ 172,19 170,19 @@ func (packet *Packet) ReadFrom(reader io
 	}
 	n += int64(binary.Size(&packet.header))
 
-	var l = int64(packet.header.PayloadLen)
-	if l > 0 {
-		packet.Payload = packet.buffer[:l]
+	var payloadlen int16
+	if err = binary.Read(reader, binary.BigEndian, &payloadlen); err != nil {
+		return
+	}
+
+	packet.Payload = packet.buffer[:payloadlen]
+	if payloadlen > 0 {
 		var x int
 		x, err = io.ReadFull(reader, packet.Payload)
+		n += int64(x)
 		if err != nil {
 			return
 		}
-
-		n += int64(x)
-	} else {
-		// nil == no payload
-		packet.Payload = nil
 	}
 
 	return

          
@@ 196,7 194,14 @@ func (packet *Packet) WriteTo(w io.Write
 		return
 	}
 	n += int64(binary.Size(&packet.header))
-	if packet.Payload != nil {
+
+	var payloadlen = int16(len(packet.Payload))
+	if err = binary.Write(w, binary.BigEndian, &payloadlen); err != nil {
+		return
+	}
+	n += int64(binary.Size(&payloadlen))
+
+	if payloadlen > 0 {
 		var o int
 		o, err = w.Write(packet.Payload)
 		if err != nil {

          
@@ 209,21 214,23 @@ func (packet *Packet) WriteTo(w io.Write
 	return
 }
 
-func (packet *Packet) NewPacketClose(id uint32) *Packet {
+func (packet *Packet) SetSeqAck(seq, ack uint32) {
+	packet.header.Sequence = seq
+	packet.header.Ack = ack
+}
+
+func NewPacketClose(id uint32) *Packet {
 	var p = NewPacket()
 	p.header.ConnectionId = id
 	p.header.Control = 1
-}
-
-func (packet *Packet) NewPacketForward(id uint32, reader io.Reader) *Packet {
-
+	return p
 }
 
 func (relay *Relay) newChannel(
 	ctx context.Context,
 	fail func(error),
 	id uint32,
-	input <-chan *bytes.Buffer,
+	input <-chan []byte,
 	output chan<- *Packet,
 	connect ConnectFunc,
 ) {

          
@@ 232,9 239,7 @@ func (relay *Relay) newChannel(
 	if err != nil {
 		relay.printf("can't connect to downstream, closing channel %d", id)
 		// we can't connect to downstream, close the connection.
-		var p Packet
-		p.Close(id)
-		output <- &p
+		output <- NewPacketClose(id)
 		return
 	}
 

          
@@ 262,7 267,7 @@ func (relay *Relay) newChannel(
 					connClose()
 					return
 				}
-				var n, err = io.Copy(conn, buf)
+				var n, err = conn.Write(buf)
 				relay.printf("wrote %d bytes to channel %d", n, id)
 				if err == io.EOF {
 					relay.printf("channel %d closed", id)

          
@@ 281,33 286,30 @@ func (relay *Relay) newChannel(
 	go func() {
 		for {
 			relay.printf("reading from channel %d", id)
-			var buf = make([]byte, MaxPayloadLength)
-			var n, err = conn.Read(buf)
+			var p = NewPacket()
+			var n, err = conn.Read(p.buffer)
 			relay.printf("read %d bytes from channel %d", n, id)
 			if err == io.EOF {
 				relay.printf("channel %d closed", id)
 				connClose()
-				var p Packet
-				p.Close(id)
-				output <- &p
+				output <- NewPacketClose(id)
 				return
 			}
 			if err != nil {
+				relay.printf(
+					"error reading on channel %d: %s",
+					id, err.Error(),
+				)
 				connClose()
 				return
 			}
-			// fmt.Printf("BUF: %s\n", string(buf))
-			var packet = Packet{
-				Type:         forwardPacket,
-				ConnectionId: id,
-				Payload:      *bytes.NewBuffer(buf[:n]),
-			}
+			p.Payload = p.buffer[:n]
+			p.header.ConnectionId = id
 			select {
 			case <-ctx.Done():
 				connClose()
 				return
-			case output <- &packet:
-				fmt.Println("packet sent upstream")
+			case output <- p:
 			}
 		}
 	}()

          
@@ 368,49 370,33 @@ func (relay *Relay) readServer(
 
 	go func() {
 		for {
+			relay.printf("readServer begin")
 			var p = NewPacket()
-			_, err = p.ReadFrom(relay.server)
-			if err != nil {
+			relay.printf("readServer ready to read")
+			var n, err = p.ReadFrom(relay.server)
+			relay.printf("read %d bytes from upstream", n)
+			if err == io.EOF {
+				close(output)
+				return
+			} else if err != nil {
 				fail(err)
 				close(output)
 				return
 			}
-			ack.Store(packet.header.Sequence)
-			relay.rxBytes += 1 // # uint64(binary.Size(&h))
-			var packet = Packet{
-				ConnectionId: h.ConnectionId,
-			}
-			switch {
-			case h.ConnectionId == 0:
-				packet.Type = keepalivePacket
-			case h.Control == 0 && h.PayloadLen == 0:
-				packet.Type = openPacket
-			case h.Control == 1:
-				packet.Type = closePacket
-				// FIXME read half-close
-			default:
-				packet.Type = forwardPacket
-			}
-
-			if h.PayloadLen > 0 {
-				var n int64
-				n, err = io.CopyN(&packet.Payload, relay.server, int64(h.PayloadLen))
-				if err != nil {
-					fail(err)
-					close(output)
-					return
-				}
-				relay.rxBytes += uint64(n)
-			}
+			ack.Store(p.header.Sequence)
+			relay.rxBytes += p.Size()
 			relay.rxPackets += 1
 			// Send the new packet to output and check if we're done
 			select {
 			case <-done():
+				close(output)
 				return
-			case output <- &packet:
+			case output <- p:
 				// wait for packet to be sent
+				relay.printf("send packet to output")
 			}
 		}
+		relay.printf("readServer done")
 	}()
 
 	return output

          
@@ 441,11 427,13 @@ func (relay *Relay) writeServer(
 				return
 			case p := <-input:
 				var err error
-				switch p.Type {
+				switch p.Type() {
 				case openPacket, forwardPacket, closePacket:
 					// Forward the packet to the server
-					var n int
-					n, err = p.Write(relay.server, seq, getAck())
+					var n int64
+					p.header.Sequence = seq
+					p.header.Ack = getAck()
+					n, err = p.WriteTo(relay.server)
 					relay.printf("wrote %d bytes upstream", n)
 					seq += 1
 					if seq == 0 {

          
@@ 453,7 441,9 @@ func (relay *Relay) writeServer(
 					}
 				case keepalivePacket:
 					// Forward the packet to the server without seq/ack
-					_, err = p.Write(relay.server, 0, 0)
+					p.header.Sequence = 0
+					p.header.Ack = 0
+					_, err = p.WriteTo(relay.server)
 				}
 				if err != nil {
 					fail(err)

          
@@ 470,13 460,13 @@ func keepAlive(
 	p *Packet,
 	output chan<- *Packet,
 ) {
-	if bytes.Equal(p.Payload.Bytes(), []byte{'k'}) {
+	if bytes.Equal(p.Payload, []byte{'k'}) {
 		// We got a keep-alive request, let's send back a reply. We
 		// do that in a goroutine to not block the main loop.
 		go func() {
-			var reply = &Packet{
-				Payload: *bytes.NewBuffer([]byte{'a'}),
-			}
+			var reply = NewPacket()
+			reply.header.Ack = p.header.Sequence
+
 			select {
 			case <-done():
 			case output <- reply:

          
@@ 538,11 528,11 @@ func (relay *Relay) Run(ctx context.Cont
 			if p == nil {
 				break
 			}
-			switch p.Type {
+			switch p.Type() {
 			case keepalivePacket:
 				go keepAlive(done, p, output)
 			case openPacket:
-				var cid = p.ConnectionId
+				var cid = p.header.ConnectionId
 				relay.printf("channel %d opened remotely", cid)
 				var c = relay.dispatcher.Open(cid)
 				// Verify if connection id exists for sanity check

          
@@ 554,15 544,15 @@ func (relay *Relay) Run(ctx context.Cont
 				go relay.newChannel(cancelCtx, fail, cid, c, output, relay.connect)
 			case closePacket:
 				relay.printf("channel %d closed remotely", p.ConnectionId)
-				relay.dispatcher.Close(p.ConnectionId)
+				relay.dispatcher.Close(p.header.ConnectionId)
 			case forwardPacket:
-				var cid = p.ConnectionId
+				var cid = p.header.ConnectionId
 				var c = relay.dispatcher.Get(cid)
 				if c == nil {
 					relay.printf("ERROR: channel %d doesn't exists", cid)
 					break
 				} else {
-					c <- &p.Payload
+					c <- p.Payload
 				}
 			}
 		}

          
M kgp_test.go +63 -60
@@ 2,13 2,12 @@ package kgp
 
 import (
 	"bytes"
+	"context"
 	"io"
 	"io/ioutil"
 	"net"
 	"testing"
 	"time"
-
-	"golang.org/x/net/context"
 )
 
 const authMeta = `{"username": "u", "access_key": "k"}`

          
@@ 111,12 110,12 @@ waitloop:
 func equalBuffer(t *testing.T, result io.Reader, expected []byte) bool {
 	var r, err = ioutil.ReadAll(result)
 	if err != nil {
-		t.Errorf("Couldn't read clientOutput: %s", err)
+		t.Fatalf("Couldn't read clientOutput: %s", err)
 		return false
 	}
 
 	if !bytes.Equal(r, expected) {
-		t.Errorf("%v != %v", r, expected)
+		t.Fatalf("%v != %v", r, expected)
 		return false
 	}
 

          
@@ 125,14 124,16 @@ func equalBuffer(t *testing.T, result io
 
 func TestRelayRequest(t *testing.T) {
 	var inputReader, inputWriter = io.Pipe()
+	var nullReader, _ = io.Pipe()
 	var server = recorder{
 		Input:  inputReader,
 		Output: ioutil.Discard,
 	}
-	var clientOutput bytes.Buffer
+	var clientOutput = bytes.NewBuffer(nil)
+	// we use a pipe to avoid returning EOF right away
 	var client = recorder{
-		Input:  &bytes.Buffer{},
-		Output: &clientOutput,
+		Input:  nullReader,
+		Output: clientOutput,
 	}
 
 	var connect = func(context.Context) (io.ReadWriteCloser, error) {

          
@@ 143,38 144,44 @@ func TestRelayRequest(t *testing.T) {
 	var ctx, cancel = context.WithTimeout(context.Background(), time.Second)
 	// We execute the client in the background while relay.Run() is running
 	go func() {
+		var p Packet
+
 		// open connection 1
-		var p = Packet{
-			Type:         openPacket,
-			ConnectionId: 1,
-		}
-		p.Write(inputWriter, 1, 0)
-		// Write "hello world"
-		p = Packet{
-			Type:         forwardPacket,
-			ConnectionId: 1,
-			Payload:      *bytes.NewBufferString("hello"),
+		p.header.ConnectionId = 1
+		p.SetSeqAck(1, 0)
+		_, err := p.WriteTo(inputWriter)
+		if err != nil {
+			t.Fatal(err)
 		}
-		p.Write(inputWriter, 2, 0)
-		// Close connection
-		p = Packet{
-			Type:         closePacket,
-			ConnectionId: 1,
+
+		// Write "hello world"
+		p.Payload = []byte("hello")
+		p.SetSeqAck(2, 0)
+		_, err = p.WriteTo(inputWriter)
+		if err != nil {
+			t.Fatal(err)
 		}
-		p.Write(inputWriter, 3, 0)
 
-		// Wait for the channel to close then check the result
+		// Close connection
+		p.header.Control = 1
+		p.Payload = nil
+		p.SetSeqAck(3, 0)
+		_, err = p.WriteTo(inputWriter)
+		if err != nil {
+			t.Fatal(err)
+		}
+
+		// Wait for all channels to close
 		waitChannelClose(ctx.Done, relay)
-
-		// Now let's read what the client has received
-		var expected = bytes.NewBufferString("hello").Bytes()
-		equalBuffer(t, &clientOutput, expected)
+		// Cancel relay to check result
 		cancel()
 	}()
 	var err = relay.Run(ctx)
 	if err != context.Canceled {
 		t.Errorf("relay.Run errored: %s", err)
 	}
+	var expected = []byte("hello")
+	equalBuffer(t, clientOutput, expected)
 }
 
 func writeResponse(payload string, connectionId, seq, ack uint32) (r bytes.Buffer) {

          
@@ 184,18 191,17 @@ func writeResponse(payload string, conne
 
 	// 1st Write forward packet with payload
 	// FIXME we should split if len(payload) > MaxPayloadLength
-	var p = Packet{
-		Type:         forwardPacket,
-		ConnectionId: connectionId,
-		Payload:      *bytes.NewBufferString(payload),
-	}
-	p.Write(&r, seq, ack)
+	var p Packet
+	p.header.ConnectionId = connectionId
+	p.Payload = []byte(payload)
+	p.SetSeqAck(seq, ack)
+	p.WriteTo(&r)
+
 	// Close connection
-	p = Packet{
-		Type:         closePacket,
-		ConnectionId: connectionId,
-	}
-	p.Write(&r, seq+1, ack)
+	p.Payload = []byte{}
+	p.header.Control = 1
+	p.SetSeqAck(seq+1, ack)
+	p.WriteTo(&r)
 	return
 }
 

          
@@ 245,37 251,34 @@ func TestRelayRequestResponse(t *testing
 	// We execute the client in the background while relay.Run() is running
 	go func() {
 		// open connection 1
-		var p = Packet{
-			Type:         openPacket,
-			ConnectionId: 1,
-		}
-		p.Write(inputWriter, 1, 0)
+		var p Packet
+		p.header.ConnectionId = 1
+		p.SetSeqAck(1, 0)
+		p.WriteTo(inputWriter)
+
 		// Write "hello"
-		p = Packet{
-			Type:         forwardPacket,
-			ConnectionId: 1,
-			Payload:      *bytes.NewBufferString("hello"),
-		}
-		p.Write(inputWriter, 2, 0)
+		p.Payload = []byte("hello")
+		p.SetSeqAck(2, 0)
+		p.WriteTo(inputWriter)
+
 		// Close connection
-		p = Packet{
-			Type:         closePacket,
-			ConnectionId: 1,
-		}
-		p.Write(inputWriter, 3, 0)
+		p.Payload = []byte{}
+		p.header.Control = 1
+		p.SetSeqAck(3, 0)
+		p.WriteTo(inputWriter)
 
 		// Wait for the channel to close then check the result
 		waitChannelClose(ctx.Done, relay)
 
-		// Now let's read what the client has received
-		var expected = bytes.NewBufferString("hello").Bytes()
-		equalBuffer(t, clientOutput, expected)
-		expected = bytes.NewBufferString("world").Bytes()
-		equalResponse(t, serverOutput, expected)
 		cancel()
 	}()
 	var err = relay.Run(ctx)
 	if err != context.Canceled {
 		t.Errorf("relay.Run errored: %s", err)
 	}
+
+	var expected = []byte("hello")
+	equalBuffer(t, clientOutput, expected)
+	expected = []byte("world")
+	equalResponse(t, serverOutput, []byte("world"))
 }