M cmd/kgp_client/main.go +5 -5
@@ 1,6 1,7 @@
package main
import (
+ "context"
"flag"
"fmt"
"io"
@@ 9,7 10,6 @@ import (
"os"
"os/signal"
"strings"
- "context"
"bitbucket.org/henry/kgp"
)
@@ 109,13 109,13 @@ func main() {
go func() {
var sigchan = make(chan os.Signal, 1)
signal.Notify(sigchan, os.Interrupt)
- <-sigchan // wait for signal to come in
+ <-sigchan // wait for signal to come in
signal.Reset(os.Interrupt) // Let go handle further signals
logger.Print("stopping traffic")
cancel()
}()
- var relay = kgp.NewRelay(
+ var demux = kgp.NewDemux(
upconn,
func(c context.Context) (io.ReadWriteCloser, error) {
return (&net.Dialer{}).DialContext(
@@ 123,6 123,6 @@ func main() {
)
},
)
- relay.Logger = logger
- fmt.Println(relay.Run(ctx, lastSeenSeq+1))
+ demux.Logger = logger
+ fmt.Println(demux.Run(ctx, lastSeenSeq+1))
}
M kgp.go +55 -49
@@ 91,7 91,7 @@ func ReadAnnouncement(reader io.Reader)
}
// Return the last seen packet from the last client connected to this tunnel.
-// Use this as the sequence number + 1 with Relay.Run()
+// Use this as the sequence number + 1 with Demux.Run()
func Handshake(conn io.ReadWriter, meta io.Reader) (lastSeenSeq uint32, err error) {
var my = Announcement{
Metadata: meta,
@@ 120,17 120,18 @@ func Handshake(conn io.ReadWriter, meta
type ConnectFunc func(context.Context) (io.ReadWriteCloser, error)
//
-// Relay between server and the remote with the connect function.
+// Demux KGP connection from server and local clients via the Demux.Connect()
+// function.
//
// Once created you can tweak the Rx/Tx size with: RxSize & TxSize. Also you
// can log what's happening by setting Logger:
//
-// var r = NewRelay(...)
+// var r = NewDemux(...)
// r.RxSize = 1
// r.TxSize = 100
// r.Logger = log.New(os.Stderr, "kgp_relay", 0)
//
-type Relay struct {
+type Demux struct {
RxSize int
TxSize int
@@ 146,9 147,9 @@ type Relay struct {
rxPackets, txPackets uint64
}
-// Creates a new relay with RxSize & TxSize set to 4.
-func NewRelay(server io.ReadWriter, connect ConnectFunc) *Relay {
- return &Relay{
+// Creates a new demuxer with RxSize & TxSize set to 4.
+func NewDemux(server io.ReadWriter, connect ConnectFunc) *Demux {
+ return &Demux{
RxSize: 4,
TxSize: 4,
server: server,
@@ 157,13 158,17 @@ func NewRelay(server io.ReadWriter, conn
}
}
-func (relay *Relay) printf(format string, v ...interface{}) {
- if relay.Logger != nil {
- relay.Logger.Printf(format, v...)
+func (demux *Demux) printf(format string, v ...interface{}) {
+ if demux.Logger != nil {
+ demux.Logger.Printf(format, v...)
}
}
-func (relay *Relay) newChannel(
+func (demux *Demux) ChannelCount() int {
+ return demux.dispatcher.Len()
+}
+
+func (demux *Demux) newChannel(
ctx context.Context,
fail func(error),
id uint32,
@@ 172,19 177,19 @@ func (relay *Relay) newChannel(
connect ConnectFunc,
) {
// Try to connect to the local host
- var conn, err = relay.connect(ctx)
+ var conn, err = demux.connect(ctx)
if err != nil {
- relay.printf("can't connect to downstream, closing channel %d", id)
+ demux.printf("can't connect to downstream, closing channel %d", id)
// we can't connect to downstream, close the connection.
output <- NewPacketClose(id)
return
}
- relay.printf("Channel %d connected", id)
+ demux.printf("Channel %d connected", id)
var connClose = func() {
if err := conn.Close(); err != nil {
- relay.printf(
+ demux.printf(
"WARNING: error while closing channel %d: %s", id, err,
)
}
@@ 195,25 200,25 @@ func (relay *Relay) newChannel(
for {
select {
case <-ctx.Done():
- relay.printf("channel %d writer done", id)
+ demux.printf("channel %d writer done", id)
connClose()
return
case buf := <-input:
if buf == nil {
// channel's closed, the server side has closed, therefor
// we close the connection on upstream.
- relay.printf("channel %d closed remotely", id)
+ demux.printf("channel %d closed remotely", id)
connClose()
return
}
var n, err = conn.Write(buf)
- relay.printf("wrote %d bytes to channel %d", n, id)
+ demux.printf("wrote %d bytes to channel %d", n, id)
if err == io.EOF {
- relay.printf("channel %d closed", id)
+ demux.printf("channel %d closed", id)
// nothing else to read, we're done
return
} else if err != nil {
- relay.printf("ERROR: writing to channel %d: %s", id, err)
+ demux.printf("ERROR: writing to channel %d: %s", id, err)
connClose()
return
}
@@ 226,14 231,14 @@ func (relay *Relay) newChannel(
for {
var p = NewPacket()
var n, err = conn.Read(p.buffer)
- relay.printf("read %d bytes from channel %d", n, id)
+ demux.printf("read %d bytes from channel %d", n, id)
if err == io.EOF {
- relay.printf("channel %d closed", id)
+ demux.printf("channel %d closed", id)
output <- NewPacketClose(id)
return
}
if err != nil {
- relay.printf(
+ demux.printf(
"ERROR: reading from channel %d: %s",
id, err.Error(),
)
@@ 255,20 260,20 @@ func (relay *Relay) newChannel(
//
// Read packets from server and put them in the returned channel
//
-func (relay *Relay) readServer(
+func (demux *Demux) readServer(
done func() <-chan struct{},
fail func(error),
ack *atomic.Value,
) <-chan *Packet {
- var output = make(chan *Packet, relay.RxSize)
+ var output = make(chan *Packet, demux.RxSize)
go func() {
for {
var p = NewPacket()
- var _, err = p.ReadFrom(relay.server)
- // relay.printf("read %d bytes from upstream", n)
+ var _, err = p.ReadFrom(demux.server)
+ // demux.printf("read %d bytes from upstream", n)
if err == io.EOF {
- relay.printf("upstream closed connection")
+ demux.printf("upstream closed connection")
close(output)
return
} else if err != nil {
@@ 277,8 282,9 @@ func (relay *Relay) readServer(
return
}
ack.Store(p.header.Sequence)
- relay.rxBytes += p.Size()
- relay.rxPackets += 1
+ fmt.Println(ack.Load())
+ demux.rxBytes += p.Size()
+ demux.rxPackets += 1
// Send the new packet to output and check if we're done
select {
case <-done():
@@ 288,7 294,7 @@ func (relay *Relay) readServer(
// wait for packet to be sent
}
}
- relay.printf("readServer done")
+ demux.printf("readServer done")
}()
return output
@@ 297,14 303,14 @@ func (relay *Relay) readServer(
/*
* Forward data from input to conn
*/
-func (relay *Relay) writeServer(
+func (demux *Demux) writeServer(
seq uint32,
done func() <-chan struct{},
fail func(error),
getAck func() uint32,
) chan<- *Packet {
// output writes packets to the server
- var input = make(chan *Packet, relay.TxSize)
+ var input = make(chan *Packet, demux.TxSize)
go func() {
// We continously read what the channels send us on output and write it to
@@ 323,8 329,8 @@ func (relay *Relay) writeServer(
var n int64
p.header.Sequence = seq
p.header.Ack = getAck()
- n, err = p.WriteTo(relay.server)
- relay.printf("wrote %d bytes upstream", n)
+ n, err = p.WriteTo(demux.server)
+ demux.printf("wrote %d bytes upstream", n)
seq += 1
if seq == 0 {
seq = 1
@@ 333,11 339,11 @@ func (relay *Relay) writeServer(
// Forward the packet to the server as is
p.header.Ack = seq
// var n int64
- _, err = p.WriteTo(relay.server)
- // relay.printf("wrote %d bytes upstream for keepalive", n)
+ _, err = p.WriteTo(demux.server)
+ // demux.printf("wrote %d bytes upstream for keepalive", n)
}
if err != nil {
- relay.printf("error writing packet %s", err)
+ demux.printf("error writing packet %s", err)
fail(err)
}
}
@@ 372,7 378,7 @@ func keepAlive(
}
}
-func (relay *Relay) Run(ctx context.Context, seq uint32) error {
+func (demux *Demux) Run(ctx context.Context, seq uint32) error {
// Keeps track of the last sequence number from the server
var ack atomic.Value
ack.Store(uint32(0))
@@ 385,7 391,7 @@ func (relay *Relay) Run(ctx context.Cont
// buffered channel to make sure the fail function doesn't block
var reportErr = make(chan error, 1)
var fail = func(err error) {
- relay.printf("fail %v", err)
+ demux.printf("fail %v", err)
reportErr <- err
close(reportErr)
cancelFunc()
@@ 393,10 399,10 @@ func (relay *Relay) Run(ctx context.Cont
var done = cancelCtx.Done
// input reads packets from the server
- var input = relay.readServer(done, fail, &ack)
+ var input = demux.readServer(done, fail, &ack)
var getAck = func() uint32 { return ack.Load().(uint32) }
// output writes packets to the server
- var output = relay.writeServer(seq, done, fail, getAck)
+ var output = demux.writeServer(seq, done, fail, getAck)
// We each packets from the server and handle them depending on their type:
// open connection: we need to open new connection
@@ 427,19 433,19 @@ func (relay *Relay) Run(ctx context.Cont
// half-close, we'll no longer receive data from the server
// side, but still need to forward what the client side
// sends us. Ignore it for now.
- relay.printf("half-close!")
+ demux.printf("half-close!")
} else {
- relay.printf("channel %d closed remotely", p.ConnectionId)
- relay.dispatcher.Close(p.header.ConnectionId)
+ demux.printf("channel %d closed remotely", p.ConnectionId)
+ demux.dispatcher.Close(p.header.ConnectionId)
}
case forwardPacket:
var cid = p.header.ConnectionId
- var c = relay.dispatcher.Get(cid)
+ var c = demux.dispatcher.Get(cid)
if c == nil {
- relay.printf("channel %d opened remotely", cid)
- var c = relay.dispatcher.Open(cid)
+ demux.printf("channel %d opened remotely", cid)
+ var c = demux.dispatcher.Open(cid)
- go relay.newChannel(cancelCtx, fail, cid, c, output, relay.connect)
+ go demux.newChannel(cancelCtx, fail, cid, c, output, demux.connect)
} else {
c <- p.Payload
}
M kgp_test.go +88 -49
@@ 1,11 1,14 @@
package kgp
import (
+ "log"
+ "os"
"bytes"
"context"
"io"
"io/ioutil"
"net"
+ "strings"
"testing"
"time"
)
@@ 89,19 92,19 @@ func (rec recorder) Close() error {
return nil
}
-func waitChannelClose(
- done func() <-chan struct{},
- relay *Relay,
-) {
-waitloop:
+func waitChannelCountEqual(
+ n int,
+ demux *Demux,
+ done func() <-chan struct{}, // just-in-case demux stops
+) bool {
for {
select {
case <-done():
- break waitloop
+ return false
default:
time.Sleep(time.Millisecond)
- if relay.dispatcher.Len() == 0 {
- break waitloop
+ if demux.ChannelCount() == n {
+ return true
}
}
}
@@ 110,12 113,12 @@ waitloop:
func equalBuffer(t *testing.T, result io.Reader, expected []byte) bool {
var r, err = ioutil.ReadAll(result)
if err != nil {
- t.Fatalf("Couldn't read clientOutput: %s", err)
+ t.Errorf("Couldn't read clientOutput: %s", err)
return false
}
if !bytes.Equal(r, expected) {
- t.Fatalf("%v != %v", r, expected)
+ t.Errorf("\n%v\n!=\n%v", r, expected)
return false
}
@@ 140,7 143,7 @@ func TestRelayRequest(t *testing.T) {
return client, nil
}
- var relay = NewRelay(server, connect)
+ var demux = NewDemux(server, connect)
var ctx, cancel = context.WithTimeout(context.Background(), time.Second)
// We execute the client in the background while relay.Run() is running
go func() {
@@ 172,16 175,21 @@ func TestRelayRequest(t *testing.T) {
}
// Wait for all channels to close
- waitChannelClose(ctx.Done, relay)
+ waitChannelCountEqual(0, demux, ctx.Done)
+
// Cancel relay to check result
cancel()
}()
- var err = relay.Run(ctx, 1)
+ var err = demux.Run(ctx, 1)
+ // uncomment for more logging
+ // demux.Logger = log.New(os.Stderr, "", 0)
if err != context.Canceled {
t.Errorf("relay.Run errored: %s", err)
}
var expected = []byte("hello")
- equalBuffer(t, clientOutput, expected)
+ if !equalBuffer(t, clientOutput, expected) {
+ t.Fatalf("failed")
+ }
}
func writeResponse(payload string, connectionId, seq, ack uint32) (r bytes.Buffer) {
@@ 225,18 233,17 @@ func equalResponse(t *testing.T, result
}
func TestRelayRequestResponse(t *testing.T) {
- var inputReader, inputWriter = io.Pipe()
- var serverOutput = &bytes.Buffer{}
+ var requestReader, requestWriter = io.Pipe()
+ var responseReader, responseWriter = io.Pipe()
var server = recorder{
- Input: inputReader,
- Output: serverOutput,
+ Input: requestReader,
+ Output: responseWriter,
}
- // Client's response is world
- var clientInput = bytes.NewBufferString("world")
var clientOutput = &bytes.Buffer{}
var client = recorder{
- Input: clientInput,
+ // Client's response is world
+ Input: strings.NewReader("world"),
Output: clientOutput,
}
@@ 244,41 251,73 @@ func TestRelayRequestResponse(t *testing
return client, nil
}
- var relay = NewRelay(server, connect)
+ var demux = NewDemux(server, connect)
// Test has up to 1 sec to complete
var ctx, cancel = context.WithTimeout(context.Background(), time.Second)
- // We execute the client in the background while relay.Run() is running
+ // Execute relay in the background
go func() {
- // open connection 1
- var p Packet
- p.header.ConnectionId = 1
- p.SetSeqAck(1, 0)
- p.WriteTo(inputWriter)
+ // uncomment for more logging
+ // demux.Logger = log.New(os.Stderr, "", 0)
+ var err = demux.Run(ctx, 1)
+
+ if err != context.Canceled {
+ t.Fatalf("relay.Run errored: %s", err)
+ }
+ }()
- // Write "hello"
- p.Payload = []byte("hello")
- p.SetSeqAck(2, 0)
- p.WriteTo(inputWriter)
+ // open connection 1
+ var p Packet
+ p.header.ConnectionId = 1
+ p.SetSeqAck(1, 0)
+ p.WriteTo(requestWriter)
+
+ // Wait for the channel to open
+ waitChannelCountEqual(1, demux, ctx.Done)
- // Close connection
- p.Payload = []byte{}
- p.header.Control = 1
- p.SetSeqAck(3, 0)
- p.WriteTo(inputWriter)
+ // Write "hello"
+ p.Payload = []byte("hello")
+ p.SetSeqAck(2, 0)
+ p.WriteTo(requestWriter)
- // Wait for the channel to close then check the result
- waitChannelClose(ctx.Done, relay)
-
- cancel()
- }()
- var err = relay.Run(ctx, 1)
- if err != context.Canceled {
- t.Errorf("relay.Run errored: %s", err)
+ // we know ack will 2 because don't send the close connection until we've
+ // read this.
+ var expected = []byte{
+ 0, 0, 0, 1, // forward packet connection id 1
+ 0, 0, 0, 1, // seq 1
+ 0, 0, 0, 1, // ack 1, because the client started writing before we the
+ // request with seq was still at 1
+ 0, 0, // control 0
+ 0, 5, // len 5
+ 119, 111, 114, 108, 100, // "hello"
+ 0, 0, 0, 1, // close packet connection id 1
+ 0, 0, 0, 2, // seq 2
+ 0, 0, 0, 2, // ack 2
+ 0, 1, // control 1
+ 0, 0, // len 0
}
- var expected = []byte("hello")
- equalBuffer(t, clientOutput, expected)
- expected = []byte("world")
- equalResponse(t, serverOutput, []byte("world"))
+ var buf = make([]byte, len(expected))
+ if _, err := io.ReadFull(responseReader, buf); err != nil {
+ t.Fatal("ReadFull failed %s", err)
+ }
+
+ if !bytes.Equal(buf, expected) {
+ t.Fatalf("%v != %v", buf, expected)
+ }
+
+ // Close connection
+ p.Payload = []byte{}
+ p.header.Control = 1
+ p.SetSeqAck(3, 0)
+ p.WriteTo(requestWriter)
+
+ // Wait for the channel to close then check the result
+ waitChannelCountEqual(0, demux, ctx.Done)
+
+ cancel() // shutdown relay
+
+ if !equalBuffer(t, clientOutput, []byte("hello")) {
+ t.Fatalf("hello failed")
+ }
}