4 files changed, 155 insertions(+), 88 deletions(-)

M cmd/kgp_client/main.go
M dispatch.go
M kgp.go
M kgp_test.go
M cmd/kgp_client/main.go +10 -8
@@ 78,18 78,19 @@ func main() {
 
 	flag.Parse()
 
+	var logger = log.New(os.Stderr, "", log.LstdFlags)
 	var args = flag.Args()
 	if len(args) != 2 {
-		log.Fatalln(usage)
+		logger.Fatalln(usage)
 	}
 
 	var upstream, err = ParseAddress(args[0])
 	if err != nil {
-		log.Fatalln(err)
+		logger.Fatalln(err)
 	}
 	downstream, err := ParseAddress(args[1])
 	if err != nil {
-		log.Fatalln(err)
+		logger.Fatalln(err)
 	}
 
 	// Load metadata

          
@@ 101,19 102,19 @@ func main() {
 		var err error
 		meta, err = os.Open(*metafn)
 		if err != nil {
-			log.Fatalln(err)
+			logger.Fatalln(err)
 		}
 	}
-	log.Print("upstream", upstream, "downstream", downstream)
+	logger.Print("upstream", upstream, "downstream", downstream)
 
 	upconn, err := net.Dial(upstream.Network, upstream.Address)
 	if err != nil {
-		log.Fatalln("error connecting to ", upstream, ": ", err.Error())
+		logger.Fatalln("error connecting to ", upstream, ": ", err.Error())
 	}
-	log.Print("connected")
+	logger.Print("connected")
 
 	if _, err = handshake(upconn, meta); err != nil {
-		log.Fatalln(err)
+		logger.Fatalln(err)
 	}
 
 	// var buf = make([]byte, 1000)

          
@@ 129,5 130,6 @@ func main() {
 			)
 		},
 	)
+	relay.Logger = logger
 	fmt.Println(relay.Run(context.Background()))
 }

          
M dispatch.go +1 -0
@@ 7,6 7,7 @@ import (
 )
 
 type channels map[uint32]chan<- *bytes.Buffer
+
 //
 // A read-often & write-rarely optimized object to dispatch the packets to
 // their channel.

          
M kgp.go +129 -65
@@ 119,46 119,106 @@ const (
 )
 
 type Packet struct {
-	Type         int
-	ConnectionId uint32
+	header struct {
+		ConnectionId uint32
+
+		Sequence uint32
+		Ack      uint32
+		Control  int16
 
-	Payload bytes.Buffer
+		PayloadLen uint16
+	}
+
+	buffer  []byte
+	Payload []byte
+}
+
+func NewPacket() *Packet {
+	var p = new(Packet)
+	p.buffer = make([]byte, MaxPayloadLength)
+	return p
 }
 
-type packetHeader struct {
-	ConnectionId uint32
-
-	Sequence uint32
-	Ack      uint32
-	Control  int16
-
-	PayloadLen uint16
+func (packet *Packet) Type() int {
+	var h = packet.header
+	switch {
+	case h.ConnectionId == 0:
+		return keepalivePacket
+	case h.Control == 0 && h.PayloadLen == 0:
+		return openPacket
+	case h.Control == 1:
+		return closePacket
+	default:
+		return forwardPacket
+	}
 }
 
-func (packet *Packet) Write(w io.Writer, seq, ack uint32) (n int, err error) {
-	var control int16
-	if packet.Type == closePacket {
-		control = 1
+func (packet *Packet) Size() int {
+	var size = binary.Size(&packet.header)
+	if packet.Payload != nil {
+		size += len(packet.Payload)
 	}
-	var header = packetHeader{
-		ConnectionId: packet.ConnectionId,
-		Sequence:     seq,
-		Ack:          ack,
-		Control:      control,
-		PayloadLen:   uint16(packet.Payload.Len()),
-	}
-	err = binary.Write(w, binary.BigEndian, &header)
+	return size
+}
+
+func (packet *Packet) ConnectionId() uint32 {
+	return packet.header.ConnectionId
+}
+
+func (packet *Packet) ReadFrom(reader io.Reader) (n int64, err error) {
+	err = binary.Read(reader, binary.BigEndian, &packet.header)
 	if err != nil {
 		return
 	}
-	n += binary.Size(&header)
-	var o int64
-	o, err = io.CopyN(w, &packet.Payload, int64(packet.Payload.Len()))
-	n += int(o)
+	n += int64(binary.Size(&packet.header))
+
+	var l = int64(packet.header.PayloadLen)
+	if l > 0 {
+		packet.Payload = packet.buffer[:l]
+		var x int
+		x, err = io.ReadFull(reader, packet.Payload)
+		if err != nil {
+			return
+		}
+
+		n += int64(x)
+	} else {
+		// nil == no payload
+		packet.Payload = nil
+	}
 
 	return
 }
 
+func (packet *Packet) WriteTo(w io.Writer) (n int64, err error) {
+	err = binary.Write(w, binary.BigEndian, &packet.header)
+	if err != nil {
+		return
+	}
+	n += int64(binary.Size(&packet.header))
+	if packet.Payload != nil {
+		var o int
+		o, err = w.Write(packet.Payload)
+		if err != nil {
+			return
+		}
+
+		n += int64(o)
+	}
+
+	return
+}
+
+func (packet *Packet) NewPacketClose(id uint32) *Packet {
+	var p = NewPacket()
+	p.header.ConnectionId = id
+	p.header.Control = 1
+}
+
+func (packet *Packet) NewPacketForward(id uint32, reader io.Reader) *Packet {
+
+}
+
 func (relay *Relay) newChannel(
 	ctx context.Context,
 	fail func(error),

          
@@ 170,13 230,15 @@ func (relay *Relay) newChannel(
 	// Try to connect to the local host
 	var conn, err = relay.connect(ctx)
 	if err != nil {
-		// we can't connect to upstream, cancel everything
-		fail(err)
+		relay.printf("can't connect to downstream, closing channel %d", id)
+		// we can't connect to downstream, close the connection.
+		var p Packet
+		p.Close(id)
+		output <- &p
 		return
-	} else {
-		relay.printf("Channel %d connected", id)
 	}
 
+	relay.printf("Channel %d connected", id)
 	var connClose = func() {
 		if err := conn.Close(); err != nil {
 			relay.printf(

          
@@ 200,17 262,15 @@ func (relay *Relay) newChannel(
 					connClose()
 					return
 				}
-				var _, err = io.Copy(conn, buf)
+				var n, err = io.Copy(conn, buf)
+				relay.printf("wrote %d bytes to channel %d", n, id)
 				if err == io.EOF {
+					relay.printf("channel %d closed", id)
+					connClose()
 					// nothing else to read, we're done
 					return
 				} else if err != nil {
-					relay.printf(
-						"ERROR: while writing on channel %d: %s", id, err,
-					)
-					// FIXME seems a bit harsh to cancel everything just
-					// because 1 upstream connection terminated badly...
-					fail(err)
+					connClose()
 					return
 				}
 			}

          
@@ 220,35 280,34 @@ func (relay *Relay) newChannel(
 	// We read from conn and write it to output
 	go func() {
 		for {
+			relay.printf("reading from channel %d", id)
+			var buf = make([]byte, MaxPayloadLength)
+			var n, err = conn.Read(buf)
+			relay.printf("read %d bytes from channel %d", n, id)
+			if err == io.EOF {
+				relay.printf("channel %d closed", id)
+				connClose()
+				var p Packet
+				p.Close(id)
+				output <- &p
+				return
+			}
+			if err != nil {
+				connClose()
+				return
+			}
+			// fmt.Printf("BUF: %s\n", string(buf))
 			var packet = Packet{
 				Type:         forwardPacket,
 				ConnectionId: id,
-			}
-			var _, err = io.CopyN(&packet.Payload, conn, MaxPayloadLength)
-			if err == io.EOF {
-				// connection closed we send the left-over traffic if needed,
-				// then send a close KGP packet to the server and return
-				if packet.Payload.Len() > 0 {
-					output <- &packet
-				}
-				output <- &Packet{
-					Type:         closePacket,
-					ConnectionId: id,
-				}
-				return
-			}
-			if err != nil {
-				relay.printf(
-					"WARNING: error while reading from channel %d: %s",
-					id, err,
-				)
-				fail(err)
-				return
+				Payload:      *bytes.NewBuffer(buf[:n]),
 			}
 			select {
 			case <-ctx.Done():
-				break
+				connClose()
+				return
 			case output <- &packet:
+				fmt.Println("packet sent upstream")
 			}
 		}
 	}()

          
@@ 309,15 368,15 @@ func (relay *Relay) readServer(
 
 	go func() {
 		for {
-			var h packetHeader
-			var err = binary.Read(relay.server, binary.BigEndian, &h)
+			var p = NewPacket()
+			_, err = p.ReadFrom(relay.server)
 			if err != nil {
 				fail(err)
 				close(output)
 				return
 			}
-			ack.Store(h.Sequence)
-			relay.rxBytes += uint64(binary.Size(&h))
+			ack.Store(packet.header.Sequence)
+			relay.rxBytes += 1 // # uint64(binary.Size(&h))
 			var packet = Packet{
 				ConnectionId: h.ConnectionId,
 			}

          
@@ 385,7 444,9 @@ func (relay *Relay) writeServer(
 				switch p.Type {
 				case openPacket, forwardPacket, closePacket:
 					// Forward the packet to the server
-					_, err = p.Write(relay.server, seq, getAck())
+					var n int
+					n, err = p.Write(relay.server, seq, getAck())
+					relay.printf("wrote %d bytes upstream", n)
 					seq += 1
 					if seq == 0 {
 						seq = 1

          
@@ 441,6 502,7 @@ func (relay *Relay) Run(ctx context.Cont
 	// buffered channel to make sure the fail function doesn't block
 	var reportErr = make(chan error, 1)
 	var fail = func(err error) {
+		relay.printf("fail %v", err)
 		reportErr <- err
 		close(reportErr)
 		cancelFunc()

          
@@ 481,6 543,7 @@ func (relay *Relay) Run(ctx context.Cont
 				go keepAlive(done, p, output)
 			case openPacket:
 				var cid = p.ConnectionId
+				relay.printf("channel %d opened remotely", cid)
 				var c = relay.dispatcher.Open(cid)
 				// Verify if connection id exists for sanity check
 				if c == nil {

          
@@ 490,6 553,7 @@ func (relay *Relay) Run(ctx context.Cont
 
 				go relay.newChannel(cancelCtx, fail, cid, c, output, relay.connect)
 			case closePacket:
+				relay.printf("channel %d closed remotely", p.ConnectionId)
 				relay.dispatcher.Close(p.ConnectionId)
 			case forwardPacket:
 				var cid = p.ConnectionId

          
M kgp_test.go +15 -15
@@ 182,21 182,21 @@ func writeResponse(payload string, conne
 	 * We don't write the open packet since the server opened the connection.
 	 */
 
-	 // 1st Write forward packet with payload
-	 // FIXME we should split if len(payload) > MaxPayloadLength
-	 var p = Packet{
-		 Type:         forwardPacket,
-		 ConnectionId: connectionId,
-		 Payload:      *bytes.NewBufferString(payload),
-	 }
-	 p.Write(&r, seq, ack)
-	 // Close connection
-	 p = Packet{
-		 Type:         closePacket,
-		 ConnectionId: connectionId,
-	 }
-	 p.Write(&r, seq + 1, ack)
-	 return
+	// 1st Write forward packet with payload
+	// FIXME we should split if len(payload) > MaxPayloadLength
+	var p = Packet{
+		Type:         forwardPacket,
+		ConnectionId: connectionId,
+		Payload:      *bytes.NewBufferString(payload),
+	}
+	p.Write(&r, seq, ack)
+	// Close connection
+	p = Packet{
+		Type:         closePacket,
+		ConnectionId: connectionId,
+	}
+	p.Write(&r, seq+1, ack)
+	return
 }
 
 func equalResponse(t *testing.T, result io.Reader, expected []byte) bool {