# HG changeset patch # User Henry Precheur # Date 1490162424 25200 # Tue Mar 21 23:00:24 2017 -0700 # Node ID d7957e7a163ab22650dd5c56db846c5ae7b563fa # Parent 5e6c07f94dbcb7a610bf367ee7ebebe3ee7ffaf2 Refactor dispatcher diff --git a/cmd/kgpdemux/main.go b/cmd/kgpdemux/main.go --- a/cmd/kgpdemux/main.go +++ b/cmd/kgpdemux/main.go @@ -127,6 +127,7 @@ }, ) demux.Keepalive = *keepalive + fmt.Println(demux.Keepalive) demux.Logger = logger fmt.Println(demux.Run(ctx, lastSeenSeq+1)) } diff --git a/dispatch.go b/dispatch.go --- a/dispatch.go +++ b/dispatch.go @@ -1,69 +1,46 @@ package kgp import ( - "sync" "sync/atomic" ) -type channels map[uint32]chan<- []byte -// -// A read-often & write-rarely optimized object to dispatch the packets to -// their channel. -// -type dispatcher struct { - channels atomic.Value - mutex sync.Mutex -} - -func newDispatcher() *dispatcher { - var d dispatcher - d.channels.Store(make(channels)) - return &d -} - -func (d *dispatcher) Get(id uint32) chan<- []byte { - var channels = d.channels.Load().(channels) - return channels[id] +type channel struct { + input chan []byte + // Let the parent thread know the channel is closed + closed atomic.Value } -func (d *dispatcher) Len() int { - var channels = d.channels.Load().(channels) - return len(channels) -} -func (d *dispatcher) Open(id uint32) chan []byte { - d.mutex.Lock() // synchronize with other potential writers - defer d.mutex.Unlock() - var oldchan = d.channels.Load().(channels) - // FIXME channel already exists - if _, ok := oldchan[id]; ok { - return nil - } - var newchan = make(channels) - - for k, v := range oldchan { - newchan[k] = v // copy all data from the current object to the new one - } - var c = make(chan []byte) - newchan[id] = c - d.channels.Store(newchan) // atomically replace the current object with the new one +func newChannel() *channel { + var c = new(channel) + c.input = make(chan []byte) + c.closed.Store(false) return c } -func (d *dispatcher) Close(id uint32) { - d.mutex.Lock() // synchronize with other potential writers - defer d.mutex.Unlock() - var oldchan = d.channels.Load().(channels) - var newchan = make(channels) +// Returns true if the channel is still open +func (c *channel) forward(payload []byte) bool { + var closed = c.closed.Load().(bool) + if !closed { + c.input <- payload + } + return !closed +} - for k, v := range oldchan { - if k != id { - newchan[k] = v // copy all data from the current object to the new one - } else { - // close channel so we don't leak it - close(v) +func (c *channel) Close() { + // message back the main thread that the channel is closed + c.closed.Store(true) + // remove anything left in the channel if needed + for { + select { + case _ = <-c.input: + // discard + default: + break } } - d.channels.Store(newchan) // atomically replace the current object with the new one + + // close the channel + close(c.input) } diff --git a/kgp.go b/kgp.go --- a/kgp.go +++ b/kgp.go @@ -4,12 +4,12 @@ package kgp import ( - "time" "bytes" "context" "io" "log" "sync/atomic" + "time" ) type ConnectFunc func(context.Context) (io.ReadWriteCloser, error) @@ -38,7 +38,7 @@ server io.ReadWriter connect ConnectFunc - dispatcher *dispatcher + channels map[uint32]channel // traffic stats rxBytes, txBytes uint64 @@ -52,7 +52,7 @@ TxSize: 4, server: server, connect: connect, - dispatcher: newDispatcher(), + channels: make(map[uint32]channel), } } @@ -63,23 +63,27 @@ } func (demux *Demux) ChannelCount() int { - return demux.dispatcher.Len() + return len(demux.channels) } func (demux *Demux) newChannel( ctx context.Context, fail func(error), id uint32, - input <-chan []byte, + inputchan *channel, output chan<- *Packet, - connect ConnectFunc, ) { + // Try to connect to the local host var conn, err = demux.connect(ctx) if err != nil { - demux.printf("can't connect to downstream, closing channel %d", id) + demux.printf( + "can't connect to downstream, closing channel %d: %s", id, err, + ) // we can't connect to downstream, close the connection. output <- NewPacketClose(id) + // close our newly created channel since we won't be able to forward it + inputchan.Close() return } @@ -101,7 +105,7 @@ demux.printf("channel %d writer done", id) connClose() return - case buf := <-input: + case buf := <-inputchan.input: if buf == nil { // channel's closed, the server side has closed, therefor // we close the connection on upstream. @@ -191,7 +195,6 @@ // wait for packet to be sent } } - demux.printf("readServer done") }() return output @@ -232,7 +235,7 @@ if seq == 0 { seq = 1 } - case keepalivePacket: + case keepalivePacket, keepaliveAckPacket: // Leave the sequence packet alone, set ack p.header.Ack = seq _, err = p.WriteTo(demux.server) @@ -273,12 +276,14 @@ } } +// Sends and acknownledge keepalives func (demux *Demux) keepalive( done func() <-chan struct{}, output chan<- *Packet, ) chan<- *Packet { var input = make(chan *Packet) - // if keepalive == 0 we should not get any keepalive ack back, discard them if needed + // if keepalive == 0 we should not get any keepalive ack back + // discard them if needed just-in-case if demux.Keepalive == 0 { go func() { for { @@ -301,15 +306,16 @@ case <-done(): return case <-ticker.C: - { - // send keep-alive packet - // and wait for response - var k = NewPacket() - k.header.Sequence = seq - // send keepalive packet - output <- k - } + // send keep-alive packet + // and wait for response + var k = NewPacket() + k.header.Sequence = seq + k.Payload = []byte{'k'} + output <- k + demux.printf("sent keepalive %d", seq) + var t = time.Now() + demux.printf("wait for reply") // now wait for the reply var p = <-input if p.header.Sequence != seq { @@ -318,6 +324,7 @@ p.header.Sequence, p.header.Ack, ) } + demux.printf("got keepalive ack %d: %s", seq, time.Since(t)) seq += 1 if seq == 0 { seq = 1 @@ -359,6 +366,11 @@ var getAck = func() uint32 { return ack.Load().(uint32) } // output writes packets to the server var output = demux.writeServer(seq, done, fail, getAck) + // keepalive is were ack keepalives packets are written + var keepalive = demux.keepalive(done, output) + + // connection id -> thread-safe channel + var channels = make(map[uint32]*channel) // We each packets from the server and handle them depending on their type: // open connection: we need to open new connection @@ -384,6 +396,8 @@ switch p.Type() { case keepalivePacket: go keepaliveReply(done, p, output) + case keepaliveAckPacket: + go func() { keepalive <- p }() case closePacket: if bytes.Equal(p.Payload, []byte{'r'}) { // half-close, we'll no longer receive data from the server @@ -391,19 +405,25 @@ // sends us. Ignore it for now. demux.printf("half-close!") } else { - demux.printf("channel %d closed remotely", p.ConnectionId) - demux.dispatcher.Close(p.header.ConnectionId) + var c = channels[p.header.ConnectionId] + if c != nil { + close(c.input) + delete(channels, p.header.ConnectionId) + } } + case openPacket: + var cid = p.header.ConnectionId + demux.printf("channel %d opened remotely", cid) + var c = newChannel() + channels[cid] = c + go demux.newChannel(cancelCtx, fail, cid, c, output) case forwardPacket: var cid = p.header.ConnectionId - var c = demux.dispatcher.Get(cid) + var c = channels[cid] if c == nil { - demux.printf("channel %d opened remotely", cid) - var c = demux.dispatcher.Open(cid) - - go demux.newChannel(cancelCtx, fail, cid, c, output, demux.connect) + demux.printf("channel %d already closed", cid) } else { - c <- p.Payload + c.forward(p.Payload) } } } diff --git a/kgp_test.go b/kgp_test.go --- a/kgp_test.go +++ b/kgp_test.go @@ -222,7 +222,7 @@ 0, 0, 0, 1, // forward packet connection id 1 0, 0, 0, 1, // seq 1 0, 0, 0, 1, // ack 1, because the client started writing before we the - // request with seq was still at 1 + // request with seq was still at 1 0, 0, // control 0 0, 5, // len 5 119, 111, 114, 108, 100, // "hello" diff --git a/packet.go b/packet.go --- a/packet.go +++ b/packet.go @@ -8,6 +8,7 @@ const ( forwardPacket = iota + openPacket closePacket halfClosePacket keepalivePacket @@ -52,7 +53,11 @@ return closePacket } default: - return forwardPacket + if len(packet.Payload) == 0 { + return openPacket + } else { + return forwardPacket + } } } diff --git a/packet_test.go b/packet_test.go --- a/packet_test.go +++ b/packet_test.go @@ -44,12 +44,11 @@ if int(n) != len(expected) { t.Fatalf("%d != %d", n, len(expected)) } - if ( - other.header.ConnectionId != 1234 || + if other.header.ConnectionId != 1234 || other.header.Sequence != 1234 || other.header.Ack != 1234 || other.header.Control != 1234 || - !bytes.Equal(other.Payload, []byte("1234"))) { + !bytes.Equal(other.Payload, []byte("1234")) { t.Fatalf("bad packet %v", other) } }