# HG changeset patch # User Henry Precheur # Date 1489963043 25200 # Sun Mar 19 15:37:23 2017 -0700 # Node ID f4faa75b678e656437064aee29f39851c1bd50c7 # Parent 5c956ad44ab67a51981e95a2b1d704dd98da8a4a Clean-ups diff --git a/cmd/kgp_client/main.go b/cmd/kgp_client/main.go --- a/cmd/kgp_client/main.go +++ b/cmd/kgp_client/main.go @@ -7,8 +7,8 @@ "log" "net" "os" + "os/signal" "strings" - "context" "bitbucket.org/henry/kgp" @@ -51,21 +51,6 @@ return } -func handshake(conn net.Conn, meta io.Reader) (a kgp.Announcement, err error) { - var my = kgp.Announcement{ - Metadata: meta, - } - var n int - if n, err = my.Write(conn); err != nil { - fmt.Println("wrote", n) - return - } - - a, err = kgp.ReadAnnouncement(conn) - log.Print(a) - return -} - const usage = `usage: kgp_client upstream downstream Example: upstream KGP server is tunnel.example.com:443 and the downstream @@ -105,22 +90,28 @@ logger.Fatalln(err) } } - logger.Print("upstream", upstream, "downstream", downstream) upconn, err := net.Dial(upstream.Network, upstream.Address) if err != nil { logger.Fatalln("error connecting to ", upstream, ": ", err.Error()) } - logger.Print("connected") + logger.Print("connected to ", args[0]) - if _, err = handshake(upconn, meta); err != nil { + var lastSeenSeq uint32 + if lastSeenSeq, err = kgp.Handshake(upconn, meta); err != nil { logger.Fatalln(err) } + if lastSeenSeq != 0 { + logger.Print("resuming traffic from packet #", lastSeenSeq) + } - // var buf = make([]byte, 1000) - // var n int - // n, err = upconn.Read(buf) - // fmt.Println(n, err, buf) + var ctx, cancel = context.WithCancel(context.Background()) + go func() { + var sigchan = make(chan os.Signal, 1) + signal.Notify(sigchan, os.Interrupt) + <-sigchan // wait for signal to come in + cancel() + }() var relay = kgp.NewRelay( upconn, @@ -131,5 +122,5 @@ }, ) relay.Logger = logger - fmt.Println(relay.Run(context.Background())) + fmt.Println(relay.Run(ctx, lastSeenSeq+1)) } diff --git a/kgp.go b/kgp.go --- a/kgp.go +++ b/kgp.go @@ -4,20 +4,16 @@ package kgp import ( - "bufio" "bytes" + "context" "encoding/binary" + "encoding/json" "fmt" "io" "log" "sync/atomic" - - "context" ) -// That's from the spec, no idea why it's the limit -const MaxPayloadLength = 30 * 1024 - type UUID [16]byte type Announcement struct { @@ -94,137 +90,77 @@ return } -func initialHandShake( - conn io.ReadWriter, - metadata io.Reader, -) (err error) { - var a = Announcement{Metadata: metadata} - _, err = a.Write(conn) - if err != nil { +// Return the last seen packet from the last client connected to this tunnel. +// Use this as the sequence number + 1 with Relay.Run() +func Handshake(conn io.ReadWriter, meta io.Reader) (lastSeenSeq uint32, err error) { + var my = Announcement{ + Metadata: meta, + } + if _, err = my.Write(conn); err != nil { return } - var reader = bufio.NewReader(conn) - _, err = ReadAnnouncement(reader) + + a, err := ReadAnnouncement(conn) if err != nil { return } - return -} -const ( - openPacket = iota - forwardPacket - closePacket - keepalivePacket -) - -type Packet struct { - header struct { - ConnectionId uint32 - - Sequence uint32 - Ack uint32 - Control int16 + var doc struct { + // We have to use float64 otherwise it won't parse + Last_seen_seq float64 } - - buffer []byte - Payload []byte -} - -func NewPacket() *Packet { - var p = new(Packet) - p.buffer = make([]byte, MaxPayloadLength) - return p -} - -func (packet *Packet) Type() int { - var h = packet.header - switch { - case h.ConnectionId == 0: - return keepalivePacket - case h.Control == 0 && len(packet.Payload) == 0: - return openPacket - case h.Control == 1: - return closePacket - default: - return forwardPacket - } -} - -func (packet *Packet) Size() uint64 { - var x uint16 - var size = binary.Size(&packet.header) + binary.Size(x) - if packet.Payload != nil { - size += len(packet.Payload) - } - return uint64(size) -} - -func (packet *Packet) ConnectionId() uint32 { - return packet.header.ConnectionId -} - -func (packet *Packet) ReadFrom(reader io.Reader) (n int64, err error) { - err = binary.Read(reader, binary.BigEndian, &packet.header) - if err != nil { + if err = json.NewDecoder(a.Metadata).Decode(&doc); err != nil { return } - n += int64(binary.Size(&packet.header)) - - var payloadlen uint16 - if err = binary.Read(reader, binary.BigEndian, &payloadlen); err != nil { - return - } - - packet.Payload = packet.buffer[:payloadlen] - if payloadlen > 0 { - var x int - x, err = io.ReadFull(reader, packet.Payload) - n += int64(x) - if err != nil { - return - } - } + lastSeenSeq = uint32(doc.Last_seen_seq) return } -func (packet *Packet) WriteTo(w io.Writer) (n int64, err error) { - err = binary.Write(w, binary.BigEndian, &packet.header) - if err != nil { - return - } - n += int64(binary.Size(&packet.header)) +type ConnectFunc func(context.Context) (io.ReadWriteCloser, error) - var payloadlen = uint16(len(packet.Payload)) - if err = binary.Write(w, binary.BigEndian, &payloadlen); err != nil { - return - } - n += int64(binary.Size(&payloadlen)) +// +// Relay between server and the remote with the connect function. +// +// Once created you can tweak the Rx/Tx size with: RxSize & TxSize. Also you +// can log what's happening by setting Logger: +// +// var r = NewRelay(...) +// r.RxSize = 1 +// r.TxSize = 100 +// r.Logger = log.New(os.Stderr, "kgp_relay", 0) +// +type Relay struct { + RxSize int + TxSize int - if payloadlen > 0 { - var o int - o, err = w.Write(packet.Payload) - if err != nil { - return - } + Logger *log.Logger + + server io.ReadWriter + connect ConnectFunc - n += int64(o) - } + dispatcher *dispatcher - return + // traffic stats + rxBytes, txBytes uint64 + rxPackets, txPackets uint64 } -func (packet *Packet) SetSeqAck(seq, ack uint32) { - packet.header.Sequence = seq - packet.header.Ack = ack +// Creates a new relay with RxSize & TxSize set to 4. +func NewRelay(server io.ReadWriter, connect ConnectFunc) *Relay { + return &Relay{ + RxSize: 4, + TxSize: 4, + server: server, + connect: connect, + dispatcher: newDispatcher(), + } } -func NewPacketClose(id uint32) *Packet { - var p = NewPacket() - p.header.ConnectionId = id - p.header.Control = 1 - return p +func (relay *Relay) printf(format string, v ...interface{}) { + if relay.Logger != nil { + relay.Logger.Printf(format, v...) + } } func (relay *Relay) newChannel( @@ -258,7 +194,6 @@ for { select { case <-ctx.Done(): - // We're done, close the connection connClose() return case buf, ok := <-input: @@ -286,7 +221,6 @@ // We read from conn and write it to output go func() { for { - relay.printf("reading from channel %d", id) var p = NewPacket() var n, err = conn.Read(p.buffer) relay.printf("read %d bytes from channel %d", n, id) @@ -298,7 +232,7 @@ } if err != nil { relay.printf( - "error reading on channel %d: %s", + "error reading from channel %d: %s", id, err.Error(), ) connClose() @@ -316,49 +250,6 @@ }() } -type ConnectFunc func(context.Context) (io.ReadWriteCloser, error) - -// -// Relay between server and the remote with the connect function. -// -// Once created you can tweak the Rx/Tx size with: RxSize & TxSize. Also you -// can log what's happening by setting Logger: -// -// var r = NewRelay(...) -// r.RxSize = 1 -// r.TxSize = 100 -// r.Logger = log.New(os.Stderr, "kgp_relay", 0) -// -type Relay struct { - RxSize int - TxSize int - - Logger *log.Logger - - server io.ReadWriter - connect ConnectFunc - - dispatcher *dispatcher - - // traffic stats - rxBytes, txBytes uint64 - rxPackets, txPackets uint64 -} - -func (relay *Relay) printf(format string, v ...interface{}) { - if relay.Logger != nil { - relay.Logger.Printf(format, v...) - } -} - -func NewRelay(server io.ReadWriter, connect ConnectFunc) *Relay { - return &Relay{ - server: server, - connect: connect, - dispatcher: newDispatcher(), - } -} - // // Read packets from server and put them in the returned channel // @@ -371,9 +262,7 @@ go func() { for { - relay.printf("readServer begin") var p = NewPacket() - relay.printf("readServer ready to read") var n, err = p.ReadFrom(relay.server) relay.printf("read %d bytes from upstream", n) if err == io.EOF { @@ -394,7 +283,6 @@ return case output <- p: // wait for packet to be sent - relay.printf("send packet to output") } } relay.printf("readServer done") @@ -407,6 +295,7 @@ * Forward data from input to conn */ func (relay *Relay) writeServer( + seq uint32, done func() <-chan struct{}, fail func(error), getAck func() uint32, @@ -419,17 +308,14 @@ // server. // // conn will be closed by the caller once we're done. - var seq uint32 = 1 - for { select { case <-done(): - relay.printf("writeServer done") return case p := <-input: var err error switch p.Type() { - case openPacket, forwardPacket, closePacket: + case forwardPacket, closePacket: // Forward the packet to the server var n int64 p.header.Sequence = seq @@ -441,10 +327,11 @@ seq = 1 } case keepalivePacket: - // Forward the packet to the server without seq/ack - p.header.Sequence = 0 - p.header.Ack = 0 - _, err = p.WriteTo(relay.server) + // Forward the packet to the server as is + p.header.Ack = seq + var n int64 + n, err = p.WriteTo(relay.server) + relay.printf("wrote %d bytes upstream for keepalive", n) } if err != nil { fail(err) @@ -466,7 +353,8 @@ // do that in a goroutine to not block the main loop. go func() { var reply = NewPacket() - reply.header.Ack = p.header.Sequence + reply.header.Sequence = p.header.Sequence + reply.Payload = []byte{'a'} select { case <-done(): @@ -480,7 +368,7 @@ } } -func (relay *Relay) Run(ctx context.Context) error { +func (relay *Relay) Run(ctx context.Context, seq uint32) error { // Keeps track of the last sequence number from the server var ack atomic.Value ack.Store(uint32(0)) @@ -502,11 +390,9 @@ // input reads packets from the server var input = relay.readServer(done, fail, &ack) + var getAck = func() uint32 { return ack.Load().(uint32) } // output writes packets to the server - var output = relay.writeServer( - done, fail, - func() uint32 { return ack.Load().(uint32) }, - ) + var output = relay.writeServer(seq, done, fail, getAck) // We each packets from the server and handle them depending on their type: // open connection: we need to open new connection @@ -532,26 +418,23 @@ switch p.Type() { case keepalivePacket: go keepAlive(done, p, output) - case openPacket: - var cid = p.header.ConnectionId - relay.printf("channel %d opened remotely", cid) - var c = relay.dispatcher.Open(cid) - // Verify if connection id exists for sanity check - if c == nil { - relay.printf("ERROR: channel %d already exists", cid) - break + case closePacket: + if bytes.Equal(p.Payload, []byte{'r'}) { + // half-close, we'll no longer receive data from the server + // side, but still need to forward what the client side + // sends us. Ignore it for now. + } else { + relay.printf("channel %d closed remotely", p.ConnectionId) + relay.dispatcher.Close(p.header.ConnectionId) } - - go relay.newChannel(cancelCtx, fail, cid, c, output, relay.connect) - case closePacket: - relay.printf("channel %d closed remotely", p.ConnectionId) - relay.dispatcher.Close(p.header.ConnectionId) case forwardPacket: var cid = p.header.ConnectionId var c = relay.dispatcher.Get(cid) if c == nil { - relay.printf("ERROR: channel %d doesn't exists", cid) - break + relay.printf("channel %d opened remotely", cid) + var c = relay.dispatcher.Open(cid) + + go relay.newChannel(cancelCtx, fail, cid, c, output, relay.connect) } else { c <- p.Payload } diff --git a/kgp_test.go b/kgp_test.go --- a/kgp_test.go +++ b/kgp_test.go @@ -176,7 +176,7 @@ // Cancel relay to check result cancel() }() - var err = relay.Run(ctx) + var err = relay.Run(ctx, 1) if err != context.Canceled { t.Errorf("relay.Run errored: %s", err) } @@ -272,7 +272,7 @@ cancel() }() - var err = relay.Run(ctx) + var err = relay.Run(ctx, 1) if err != context.Canceled { t.Errorf("relay.Run errored: %s", err) } diff --git a/packet.go b/packet.go new file mode 100644 --- /dev/null +++ b/packet.go @@ -0,0 +1,128 @@ +package kgp + +import ( + "bytes" + "encoding/binary" + "io" +) + +const ( + forwardPacket = iota + closePacket + halfClosePacket + keepalivePacket +) + +type Packet struct { + header struct { + ConnectionId uint32 + + Sequence uint32 + Ack uint32 + Control int16 + } + + buffer []byte + Payload []byte +} + +// That's from the spec, no idea why it's the limit +const MaxPayloadLength = 30 * 1024 + +func NewPacket() *Packet { + var p = new(Packet) + p.buffer = make([]byte, MaxPayloadLength) + return p +} + +func (packet *Packet) Type() int { + var h = packet.header + switch { + case h.ConnectionId == 0: + return keepalivePacket + case h.Control == 1: + if bytes.Equal(packet.Payload, []byte{'r'}) { + return halfClosePacket + } else { + return closePacket + } + default: + return forwardPacket + } +} + +func (packet *Packet) Size() uint64 { + var x uint16 + var size = binary.Size(&packet.header) + binary.Size(x) + if packet.Payload != nil { + size += len(packet.Payload) + } + return uint64(size) +} + +func (packet *Packet) ConnectionId() uint32 { + return packet.header.ConnectionId +} + +func (packet *Packet) ReadFrom(reader io.Reader) (n int64, err error) { + err = binary.Read(reader, binary.BigEndian, &packet.header) + if err != nil { + return + } + n += int64(binary.Size(&packet.header)) + + var payloadlen uint16 + if err = binary.Read(reader, binary.BigEndian, &payloadlen); err != nil { + return + } + + packet.Payload = packet.buffer[:payloadlen] + if payloadlen > 0 { + var x int + x, err = io.ReadFull(reader, packet.Payload) + n += int64(x) + if err != nil { + return + } + } + + return +} + +func (packet *Packet) WriteTo(w io.Writer) (n int64, err error) { + err = binary.Write(w, binary.BigEndian, &packet.header) + if err != nil { + return + } + n += int64(binary.Size(&packet.header)) + + var payloadlen = uint16(len(packet.Payload)) + if err = binary.Write(w, binary.BigEndian, &payloadlen); err != nil { + return + } + n += int64(binary.Size(&payloadlen)) + + if payloadlen > 0 { + var o int + o, err = w.Write(packet.Payload) + if err != nil { + return + } + + n += int64(o) + } + + return +} + +func (packet *Packet) SetSeqAck(seq, ack uint32) { + packet.header.Sequence = seq + packet.header.Ack = ack +} + +func NewPacketClose(id uint32) *Packet { + var p = NewPacket() + p.header.ConnectionId = id + p.header.Control = 1 + return p +}