4 files changed, 222 insertions(+), 220 deletions(-)

M cmd/kgp_client/main.go
M kgp.go
M kgp_test.go
A => packet.go
M cmd/kgp_client/main.go +15 -24
@@ 7,8 7,8 @@ import (
 	"log"
 	"net"
 	"os"
+	"os/signal"
 	"strings"
-
 	"context"
 
 	"bitbucket.org/henry/kgp"

          
@@ 51,21 51,6 @@ func ParseAddress(s string) (a Address, 
 	return
 }
 
-func handshake(conn net.Conn, meta io.Reader) (a kgp.Announcement, err error) {
-	var my = kgp.Announcement{
-		Metadata: meta,
-	}
-	var n int
-	if n, err = my.Write(conn); err != nil {
-		fmt.Println("wrote", n)
-		return
-	}
-
-	a, err = kgp.ReadAnnouncement(conn)
-	log.Print(a)
-	return
-}
-
 const usage = `usage: kgp_client upstream downstream
 
   Example: upstream KGP server is tunnel.example.com:443 and the downstream

          
@@ 105,22 90,28 @@ func main() {
 			logger.Fatalln(err)
 		}
 	}
-	logger.Print("upstream", upstream, "downstream", downstream)
 
 	upconn, err := net.Dial(upstream.Network, upstream.Address)
 	if err != nil {
 		logger.Fatalln("error connecting to ", upstream, ": ", err.Error())
 	}
-	logger.Print("connected")
+	logger.Print("connected to ", args[0])
 
-	if _, err = handshake(upconn, meta); err != nil {
+	var lastSeenSeq uint32
+	if lastSeenSeq, err = kgp.Handshake(upconn, meta); err != nil {
 		logger.Fatalln(err)
 	}
+	if lastSeenSeq != 0 {
+		logger.Print("resuming traffic from packet #", lastSeenSeq)
+	}
 
-	// var buf = make([]byte, 1000)
-	// var n int
-	// n, err = upconn.Read(buf)
-	// fmt.Println(n, err, buf)
+	var ctx, cancel = context.WithCancel(context.Background())
+	go func() {
+		var sigchan = make(chan os.Signal, 1)
+		signal.Notify(sigchan, os.Interrupt)
+		<-sigchan // wait for signal to come in
+		cancel()
+	}()
 
 	var relay = kgp.NewRelay(
 		upconn,

          
@@ 131,5 122,5 @@ func main() {
 		},
 	)
 	relay.Logger = logger
-	fmt.Println(relay.Run(context.Background()))
+	fmt.Println(relay.Run(ctx, lastSeenSeq+1))
 }

          
M kgp.go +77 -194
@@ 4,20 4,16 @@ 
 package kgp
 
 import (
-	"bufio"
 	"bytes"
+	"context"
 	"encoding/binary"
+	"encoding/json"
 	"fmt"
 	"io"
 	"log"
 	"sync/atomic"
-
-	"context"
 )
 
-// That's from the spec, no idea why it's the limit
-const MaxPayloadLength = 30 * 1024
-
 type UUID [16]byte
 
 type Announcement struct {

          
@@ 94,137 90,77 @@ func ReadAnnouncement(reader io.Reader) 
 	return
 }
 
-func initialHandShake(
-	conn io.ReadWriter,
-	metadata io.Reader,
-) (err error) {
-	var a = Announcement{Metadata: metadata}
-	_, err = a.Write(conn)
-	if err != nil {
+// Return the last seen packet from the last client connected to this tunnel.
+// Use this as the sequence number + 1 with Relay.Run()
+func Handshake(conn io.ReadWriter, meta io.Reader) (lastSeenSeq uint32, err error) {
+	var my = Announcement{
+		Metadata: meta,
+	}
+	if _, err = my.Write(conn); err != nil {
 		return
 	}
-	var reader = bufio.NewReader(conn)
-	_, err = ReadAnnouncement(reader)
+
+	a, err := ReadAnnouncement(conn)
 	if err != nil {
 		return
 	}
-	return
-}
 
-const (
-	openPacket = iota
-	forwardPacket
-	closePacket
-	keepalivePacket
-)
-
-type Packet struct {
-	header struct {
-		ConnectionId uint32
-
-		Sequence uint32
-		Ack      uint32
-		Control  int16
+	var doc struct {
+		// We have to use float64 otherwise it won't parse
+		Last_seen_seq float64
 	}
-
-	buffer  []byte
-	Payload []byte
-}
-
-func NewPacket() *Packet {
-	var p = new(Packet)
-	p.buffer = make([]byte, MaxPayloadLength)
-	return p
-}
-
-func (packet *Packet) Type() int {
-	var h = packet.header
-	switch {
-	case h.ConnectionId == 0:
-		return keepalivePacket
-	case h.Control == 0 && len(packet.Payload) == 0:
-		return openPacket
-	case h.Control == 1:
-		return closePacket
-	default:
-		return forwardPacket
-	}
-}
-
-func (packet *Packet) Size() uint64 {
-	var x uint16
-	var size = binary.Size(&packet.header) + binary.Size(x)
-	if packet.Payload != nil {
-		size += len(packet.Payload)
-	}
-	return uint64(size)
-}
-
-func (packet *Packet) ConnectionId() uint32 {
-	return packet.header.ConnectionId
-}
-
-func (packet *Packet) ReadFrom(reader io.Reader) (n int64, err error) {
-	err = binary.Read(reader, binary.BigEndian, &packet.header)
-	if err != nil {
+	if err = json.NewDecoder(a.Metadata).Decode(&doc); err != nil {
 		return
 	}
-	n += int64(binary.Size(&packet.header))
-
-	var payloadlen uint16
-	if err = binary.Read(reader, binary.BigEndian, &payloadlen); err != nil {
-		return
-	}
-
-	packet.Payload = packet.buffer[:payloadlen]
-	if payloadlen > 0 {
-		var x int
-		x, err = io.ReadFull(reader, packet.Payload)
-		n += int64(x)
-		if err != nil {
-			return
-		}
-	}
+	lastSeenSeq = uint32(doc.Last_seen_seq)
 
 	return
 }
 
-func (packet *Packet) WriteTo(w io.Writer) (n int64, err error) {
-	err = binary.Write(w, binary.BigEndian, &packet.header)
-	if err != nil {
-		return
-	}
-	n += int64(binary.Size(&packet.header))
+type ConnectFunc func(context.Context) (io.ReadWriteCloser, error)
 
-	var payloadlen = uint16(len(packet.Payload))
-	if err = binary.Write(w, binary.BigEndian, &payloadlen); err != nil {
-		return
-	}
-	n += int64(binary.Size(&payloadlen))
+//
+// Relay between server and the remote with the connect function.
+//
+// Once created you can tweak the Rx/Tx size with: RxSize & TxSize. Also you
+// can log what's happening by setting Logger:
+//
+// 		var r = NewRelay(...)
+//		r.RxSize = 1
+//		r.TxSize = 100
+//		r.Logger = log.New(os.Stderr, "kgp_relay", 0)
+//
+type Relay struct {
+	RxSize int
+	TxSize int
 
-	if payloadlen > 0 {
-		var o int
-		o, err = w.Write(packet.Payload)
-		if err != nil {
-			return
-		}
+	Logger *log.Logger
+
+	server  io.ReadWriter
+	connect ConnectFunc
 
-		n += int64(o)
-	}
+	dispatcher *dispatcher
 
-	return
+	// traffic stats
+	rxBytes, txBytes     uint64
+	rxPackets, txPackets uint64
 }
 
-func (packet *Packet) SetSeqAck(seq, ack uint32) {
-	packet.header.Sequence = seq
-	packet.header.Ack = ack
+// Creates a new relay with RxSize & TxSize set to 4.
+func NewRelay(server io.ReadWriter, connect ConnectFunc) *Relay {
+	return &Relay{
+		RxSize:     4,
+		TxSize:     4,
+		server:     server,
+		connect:    connect,
+		dispatcher: newDispatcher(),
+	}
 }
 
-func NewPacketClose(id uint32) *Packet {
-	var p = NewPacket()
-	p.header.ConnectionId = id
-	p.header.Control = 1
-	return p
+func (relay *Relay) printf(format string, v ...interface{}) {
+	if relay.Logger != nil {
+		relay.Logger.Printf(format, v...)
+	}
 }
 
 func (relay *Relay) newChannel(

          
@@ 258,7 194,6 @@ func (relay *Relay) newChannel(
 		for {
 			select {
 			case <-ctx.Done():
-				// We're done, close the connection
 				connClose()
 				return
 			case buf, ok := <-input:

          
@@ 286,7 221,6 @@ func (relay *Relay) newChannel(
 	// We read from conn and write it to output
 	go func() {
 		for {
-			relay.printf("reading from channel %d", id)
 			var p = NewPacket()
 			var n, err = conn.Read(p.buffer)
 			relay.printf("read %d bytes from channel %d", n, id)

          
@@ 298,7 232,7 @@ func (relay *Relay) newChannel(
 			}
 			if err != nil {
 				relay.printf(
-					"error reading on channel %d: %s",
+					"error reading from channel %d: %s",
 					id, err.Error(),
 				)
 				connClose()

          
@@ 316,49 250,6 @@ func (relay *Relay) newChannel(
 	}()
 }
 
-type ConnectFunc func(context.Context) (io.ReadWriteCloser, error)
-
-//
-// Relay between server and the remote with the connect function.
-//
-// Once created you can tweak the Rx/Tx size with: RxSize & TxSize. Also you
-// can log what's happening by setting Logger:
-//
-// 		var r = NewRelay(...)
-//		r.RxSize = 1
-//		r.TxSize = 100
-//		r.Logger = log.New(os.Stderr, "kgp_relay", 0)
-//
-type Relay struct {
-	RxSize int
-	TxSize int
-
-	Logger *log.Logger
-
-	server  io.ReadWriter
-	connect ConnectFunc
-
-	dispatcher *dispatcher
-
-	// traffic stats
-	rxBytes, txBytes     uint64
-	rxPackets, txPackets uint64
-}
-
-func (relay *Relay) printf(format string, v ...interface{}) {
-	if relay.Logger != nil {
-		relay.Logger.Printf(format, v...)
-	}
-}
-
-func NewRelay(server io.ReadWriter, connect ConnectFunc) *Relay {
-	return &Relay{
-		server:     server,
-		connect:    connect,
-		dispatcher: newDispatcher(),
-	}
-}
-
 //
 // Read packets from server and put them in the returned channel
 //

          
@@ 371,9 262,7 @@ func (relay *Relay) readServer(
 
 	go func() {
 		for {
-			relay.printf("readServer begin")
 			var p = NewPacket()
-			relay.printf("readServer ready to read")
 			var n, err = p.ReadFrom(relay.server)
 			relay.printf("read %d bytes from upstream", n)
 			if err == io.EOF {

          
@@ 394,7 283,6 @@ func (relay *Relay) readServer(
 				return
 			case output <- p:
 				// wait for packet to be sent
-				relay.printf("send packet to output")
 			}
 		}
 		relay.printf("readServer done")

          
@@ 407,6 295,7 @@ func (relay *Relay) readServer(
  * Forward data from input to conn
  */
 func (relay *Relay) writeServer(
+	seq uint32,
 	done func() <-chan struct{},
 	fail func(error),
 	getAck func() uint32,

          
@@ 419,17 308,14 @@ func (relay *Relay) writeServer(
 		// server.
 		//
 		// conn will be closed by the caller once we're done.
-		var seq uint32 = 1
-
 		for {
 			select {
 			case <-done():
-				relay.printf("writeServer done")
 				return
 			case p := <-input:
 				var err error
 				switch p.Type() {
-				case openPacket, forwardPacket, closePacket:
+				case forwardPacket, closePacket:
 					// Forward the packet to the server
 					var n int64
 					p.header.Sequence = seq

          
@@ 441,10 327,11 @@ func (relay *Relay) writeServer(
 						seq = 1
 					}
 				case keepalivePacket:
-					// Forward the packet to the server without seq/ack
-					p.header.Sequence = 0
-					p.header.Ack = 0
-					_, err = p.WriteTo(relay.server)
+					// Forward the packet to the server as is
+					p.header.Ack = seq
+					var n int64
+					n, err = p.WriteTo(relay.server)
+					relay.printf("wrote %d bytes upstream for keepalive", n)
 				}
 				if err != nil {
 					fail(err)

          
@@ 466,7 353,8 @@ func keepAlive(
 		// do that in a goroutine to not block the main loop.
 		go func() {
 			var reply = NewPacket()
-			reply.header.Ack = p.header.Sequence
+			reply.header.Sequence = p.header.Sequence
+			reply.Payload = []byte{'a'}
 
 			select {
 			case <-done():

          
@@ 480,7 368,7 @@ func keepAlive(
 	}
 }
 
-func (relay *Relay) Run(ctx context.Context) error {
+func (relay *Relay) Run(ctx context.Context, seq uint32) error {
 	// Keeps track of the last sequence number from the server
 	var ack atomic.Value
 	ack.Store(uint32(0))

          
@@ 502,11 390,9 @@ func (relay *Relay) Run(ctx context.Cont
 
 	// input reads packets from the server
 	var input = relay.readServer(done, fail, &ack)
+	var getAck = func() uint32 { return ack.Load().(uint32) }
 	// output writes packets to the server
-	var output = relay.writeServer(
-		done, fail,
-		func() uint32 { return ack.Load().(uint32) },
-	)
+	var output = relay.writeServer(seq, done, fail, getAck)
 
 	// We each packets from the server and handle them depending on their type:
 	// 	   open connection: we need to open new connection

          
@@ 532,26 418,23 @@ func (relay *Relay) Run(ctx context.Cont
 			switch p.Type() {
 			case keepalivePacket:
 				go keepAlive(done, p, output)
-			case openPacket:
-				var cid = p.header.ConnectionId
-				relay.printf("channel %d opened remotely", cid)
-				var c = relay.dispatcher.Open(cid)
-				// Verify if connection id exists for sanity check
-				if c == nil {
-					relay.printf("ERROR: channel %d already exists", cid)
-					break
+			case closePacket:
+				if bytes.Equal(p.Payload, []byte{'r'}) {
+					// half-close, we'll no longer receive data from the server
+					// side, but still need to forward what the client side
+					// sends us. Ignore it for now.
+				} else {
+					relay.printf("channel %d closed remotely", p.ConnectionId)
+					relay.dispatcher.Close(p.header.ConnectionId)
 				}
-
-				go relay.newChannel(cancelCtx, fail, cid, c, output, relay.connect)
-			case closePacket:
-				relay.printf("channel %d closed remotely", p.ConnectionId)
-				relay.dispatcher.Close(p.header.ConnectionId)
 			case forwardPacket:
 				var cid = p.header.ConnectionId
 				var c = relay.dispatcher.Get(cid)
 				if c == nil {
-					relay.printf("ERROR: channel %d doesn't exists", cid)
-					break
+					relay.printf("channel %d opened remotely", cid)
+					var c = relay.dispatcher.Open(cid)
+
+					go relay.newChannel(cancelCtx, fail, cid, c, output, relay.connect)
 				} else {
 					c <- p.Payload
 				}

          
M kgp_test.go +2 -2
@@ 176,7 176,7 @@ func TestRelayRequest(t *testing.T) {
 		// Cancel relay to check result
 		cancel()
 	}()
-	var err = relay.Run(ctx)
+	var err = relay.Run(ctx, 1)
 	if err != context.Canceled {
 		t.Errorf("relay.Run errored: %s", err)
 	}

          
@@ 272,7 272,7 @@ func TestRelayRequestResponse(t *testing
 
 		cancel()
 	}()
-	var err = relay.Run(ctx)
+	var err = relay.Run(ctx, 1)
 	if err != context.Canceled {
 		t.Errorf("relay.Run errored: %s", err)
 	}

          
A => packet.go +128 -0
@@ 0,0 1,128 @@ 
+package kgp
+
+import (
+	"bytes"
+	"encoding/binary"
+	"io"
+)
+
+const (
+	forwardPacket = iota
+	closePacket
+	halfClosePacket
+	keepalivePacket
+)
+
+type Packet struct {
+	header struct {
+		ConnectionId uint32
+
+		Sequence uint32
+		Ack      uint32
+		Control  int16
+	}
+
+	buffer  []byte
+	Payload []byte
+}
+
+// That's from the spec, no idea why it's the limit
+const MaxPayloadLength = 30 * 1024
+
+func NewPacket() *Packet {
+	var p = new(Packet)
+	p.buffer = make([]byte, MaxPayloadLength)
+	return p
+}
+
+func (packet *Packet) Type() int {
+	var h = packet.header
+	switch {
+	case h.ConnectionId == 0:
+		return keepalivePacket
+	case h.Control == 1:
+		if bytes.Equal(packet.Payload, []byte{'r'}) {
+			return halfClosePacket
+		} else {
+			return closePacket
+		}
+	default:
+		return forwardPacket
+	}
+}
+
+func (packet *Packet) Size() uint64 {
+	var x uint16
+	var size = binary.Size(&packet.header) + binary.Size(x)
+	if packet.Payload != nil {
+		size += len(packet.Payload)
+	}
+	return uint64(size)
+}
+
+func (packet *Packet) ConnectionId() uint32 {
+	return packet.header.ConnectionId
+}
+
+func (packet *Packet) ReadFrom(reader io.Reader) (n int64, err error) {
+	err = binary.Read(reader, binary.BigEndian, &packet.header)
+	if err != nil {
+		return
+	}
+	n += int64(binary.Size(&packet.header))
+
+	var payloadlen uint16
+	if err = binary.Read(reader, binary.BigEndian, &payloadlen); err != nil {
+		return
+	}
+
+	packet.Payload = packet.buffer[:payloadlen]
+	if payloadlen > 0 {
+		var x int
+		x, err = io.ReadFull(reader, packet.Payload)
+		n += int64(x)
+		if err != nil {
+			return
+		}
+	}
+
+	return
+}
+
+func (packet *Packet) WriteTo(w io.Writer) (n int64, err error) {
+	err = binary.Write(w, binary.BigEndian, &packet.header)
+	if err != nil {
+		return
+	}
+	n += int64(binary.Size(&packet.header))
+
+	var payloadlen = uint16(len(packet.Payload))
+	if err = binary.Write(w, binary.BigEndian, &payloadlen); err != nil {
+		return
+	}
+	n += int64(binary.Size(&payloadlen))
+
+	if payloadlen > 0 {
+		var o int
+		o, err = w.Write(packet.Payload)
+		if err != nil {
+			return
+		}
+
+		n += int64(o)
+	}
+
+	return
+}
+
+func (packet *Packet) SetSeqAck(seq, ack uint32) {
+	packet.header.Sequence = seq
+	packet.header.Ack = ack
+}
+
+func NewPacketClose(id uint32) *Packet {
+	var p = NewPacket()
+	p.header.ConnectionId = id
+	p.header.Control = 1
+	return p
+}