@@ 4,12 4,12 @@
package kgp
import (
- "time"
"bytes"
"context"
"io"
"log"
"sync/atomic"
+ "time"
)
type ConnectFunc func(context.Context) (io.ReadWriteCloser, error)
@@ 38,7 38,7 @@ type Demux struct {
server io.ReadWriter
connect ConnectFunc
- dispatcher *dispatcher
+ channels map[uint32]channel
// traffic stats
rxBytes, txBytes uint64
@@ 52,7 52,7 @@ func NewDemux(server io.ReadWriter, conn
TxSize: 4,
server: server,
connect: connect,
- dispatcher: newDispatcher(),
+ channels: make(map[uint32]channel),
}
}
@@ 63,23 63,27 @@ func (demux *Demux) printf(format string
}
func (demux *Demux) ChannelCount() int {
- return demux.dispatcher.Len()
+ return len(demux.channels)
}
func (demux *Demux) newChannel(
ctx context.Context,
fail func(error),
id uint32,
- input <-chan []byte,
+ inputchan *channel,
output chan<- *Packet,
- connect ConnectFunc,
) {
+
// Try to connect to the local host
var conn, err = demux.connect(ctx)
if err != nil {
- demux.printf("can't connect to downstream, closing channel %d", id)
+ demux.printf(
+ "can't connect to downstream, closing channel %d: %s", id, err,
+ )
// we can't connect to downstream, close the connection.
output <- NewPacketClose(id)
+ // close our newly created channel since we won't be able to forward it
+ inputchan.Close()
return
}
@@ 101,7 105,7 @@ func (demux *Demux) newChannel(
demux.printf("channel %d writer done", id)
connClose()
return
- case buf := <-input:
+ case buf := <-inputchan.input:
if buf == nil {
// channel's closed, the server side has closed, therefor
// we close the connection on upstream.
@@ 191,7 195,6 @@ func (demux *Demux) readServer(
// wait for packet to be sent
}
}
- demux.printf("readServer done")
}()
return output
@@ 232,7 235,7 @@ func (demux *Demux) writeServer(
if seq == 0 {
seq = 1
}
- case keepalivePacket:
+ case keepalivePacket, keepaliveAckPacket:
// Leave the sequence packet alone, set ack
p.header.Ack = seq
_, err = p.WriteTo(demux.server)
@@ 273,12 276,14 @@ func keepaliveReply(
}
}
+// Sends and acknownledge keepalives
func (demux *Demux) keepalive(
done func() <-chan struct{},
output chan<- *Packet,
) chan<- *Packet {
var input = make(chan *Packet)
- // if keepalive == 0 we should not get any keepalive ack back, discard them if needed
+ // if keepalive == 0 we should not get any keepalive ack back
+ // discard them if needed just-in-case
if demux.Keepalive == 0 {
go func() {
for {
@@ 301,15 306,16 @@ func (demux *Demux) keepalive(
case <-done():
return
case <-ticker.C:
- {
- // send keep-alive packet
- // and wait for response
- var k = NewPacket()
- k.header.Sequence = seq
- // send keepalive packet
- output <- k
- }
+ // send keep-alive packet
+ // and wait for response
+ var k = NewPacket()
+ k.header.Sequence = seq
+ k.Payload = []byte{'k'}
+ output <- k
+ demux.printf("sent keepalive %d", seq)
+ var t = time.Now()
+ demux.printf("wait for reply")
// now wait for the reply
var p = <-input
if p.header.Sequence != seq {
@@ 318,6 324,7 @@ func (demux *Demux) keepalive(
p.header.Sequence, p.header.Ack,
)
}
+ demux.printf("got keepalive ack %d: %s", seq, time.Since(t))
seq += 1
if seq == 0 {
seq = 1
@@ 359,6 366,11 @@ func (demux *Demux) Run(ctx context.Cont
var getAck = func() uint32 { return ack.Load().(uint32) }
// output writes packets to the server
var output = demux.writeServer(seq, done, fail, getAck)
+ // keepalive is were ack keepalives packets are written
+ var keepalive = demux.keepalive(done, output)
+
+ // connection id -> thread-safe channel
+ var channels = make(map[uint32]*channel)
// We each packets from the server and handle them depending on their type:
// open connection: we need to open new connection
@@ 384,6 396,8 @@ func (demux *Demux) Run(ctx context.Cont
switch p.Type() {
case keepalivePacket:
go keepaliveReply(done, p, output)
+ case keepaliveAckPacket:
+ go func() { keepalive <- p }()
case closePacket:
if bytes.Equal(p.Payload, []byte{'r'}) {
// half-close, we'll no longer receive data from the server
@@ 391,19 405,25 @@ func (demux *Demux) Run(ctx context.Cont
// sends us. Ignore it for now.
demux.printf("half-close!")
} else {
- demux.printf("channel %d closed remotely", p.ConnectionId)
- demux.dispatcher.Close(p.header.ConnectionId)
+ var c = channels[p.header.ConnectionId]
+ if c != nil {
+ close(c.input)
+ delete(channels, p.header.ConnectionId)
+ }
}
+ case openPacket:
+ var cid = p.header.ConnectionId
+ demux.printf("channel %d opened remotely", cid)
+ var c = newChannel()
+ channels[cid] = c
+ go demux.newChannel(cancelCtx, fail, cid, c, output)
case forwardPacket:
var cid = p.header.ConnectionId
- var c = demux.dispatcher.Get(cid)
+ var c = channels[cid]
if c == nil {
- demux.printf("channel %d opened remotely", cid)
- var c = demux.dispatcher.Open(cid)
-
- go demux.newChannel(cancelCtx, fail, cid, c, output, demux.connect)
+ demux.printf("channel %d already closed", cid)
} else {
- c <- p.Payload
+ c.forward(p.Payload)
}
}
}
@@ 222,7 222,7 @@ func TestRelayRequestResponse(t *testing
0, 0, 0, 1, // forward packet connection id 1
0, 0, 0, 1, // seq 1
0, 0, 0, 1, // ack 1, because the client started writing before we the
- // request with seq was still at 1
+ // request with seq was still at 1
0, 0, // control 0
0, 5, // len 5
119, 111, 114, 108, 100, // "hello"