Refactor dispatcher
6 files changed, 86 insertions(+), 84 deletions(-)

M cmd/kgpdemux/main.go
M dispatch.go
M kgp.go
M kgp_test.go
M packet.go
M packet_test.go
M cmd/kgpdemux/main.go +1 -0
@@ 127,6 127,7 @@ func main() {
 		},
 	)
 	demux.Keepalive = *keepalive
+	fmt.Println(demux.Keepalive)
 	demux.Logger = logger
 	fmt.Println(demux.Run(ctx, lastSeenSeq+1))
 }

          
M dispatch.go +29 -52
@@ 1,69 1,46 @@ 
 package kgp
 
 import (
-	"sync"
 	"sync/atomic"
 )
 
-type channels map[uint32]chan<- []byte
 
-//
-// A read-often & write-rarely optimized object to dispatch the packets to
-// their channel.
-//
-type dispatcher struct {
-	channels atomic.Value
-	mutex    sync.Mutex
-}
-
-func newDispatcher() *dispatcher {
-	var d dispatcher
-	d.channels.Store(make(channels))
-	return &d
-}
-
-func (d *dispatcher) Get(id uint32) chan<- []byte {
-	var channels = d.channels.Load().(channels)
-	return channels[id]
+type channel struct {
+	input chan []byte
+	// Let the parent thread know the channel is closed
+	closed atomic.Value
 }
 
-func (d *dispatcher) Len() int {
-	var channels = d.channels.Load().(channels)
-	return len(channels)
-}
 
-func (d *dispatcher) Open(id uint32) chan []byte {
-	d.mutex.Lock() // synchronize with other potential writers
-	defer d.mutex.Unlock()
-	var oldchan = d.channels.Load().(channels)
-	// FIXME channel already exists
-	if _, ok := oldchan[id]; ok {
-		return nil
-	}
-	var newchan = make(channels)
-
-	for k, v := range oldchan {
-		newchan[k] = v // copy all data from the current object to the new one
-	}
-	var c = make(chan []byte)
-	newchan[id] = c
-	d.channels.Store(newchan) // atomically replace the current object with the new one
+func newChannel() *channel {
+	var c = new(channel)
+	c.input = make(chan []byte)
+	c.closed.Store(false)
 	return c
 }
 
-func (d *dispatcher) Close(id uint32) {
-	d.mutex.Lock() // synchronize with other potential writers
-	defer d.mutex.Unlock()
-	var oldchan = d.channels.Load().(channels)
-	var newchan = make(channels)
+// Returns true if the channel is still open
+func (c *channel) forward(payload []byte) bool {
+	var closed = c.closed.Load().(bool)
+	if !closed {
+		c.input <- payload
+	}
+	return !closed
+}
 
-	for k, v := range oldchan {
-		if k != id {
-			newchan[k] = v // copy all data from the current object to the new one
-		} else {
-			// close channel so we don't leak it
-			close(v)
+func (c *channel) Close() {
+	// message back the main thread that the channel is closed
+	c.closed.Store(true)
+	// remove anything left in the channel if needed
+	for {
+		select {
+		case _ = <-c.input:
+			// discard
+		default:
+			break
 		}
 	}
-	d.channels.Store(newchan) // atomically replace the current object with the new one
+
+	// close the channel
+	close(c.input)
 }

          
M kgp.go +47 -27
@@ 4,12 4,12 @@ 
 package kgp
 
 import (
-	"time"
 	"bytes"
 	"context"
 	"io"
 	"log"
 	"sync/atomic"
+	"time"
 )
 
 type ConnectFunc func(context.Context) (io.ReadWriteCloser, error)

          
@@ 38,7 38,7 @@ type Demux struct {
 	server  io.ReadWriter
 	connect ConnectFunc
 
-	dispatcher *dispatcher
+	channels	map[uint32]channel
 
 	// traffic stats
 	rxBytes, txBytes     uint64

          
@@ 52,7 52,7 @@ func NewDemux(server io.ReadWriter, conn
 		TxSize:     4,
 		server:     server,
 		connect:    connect,
-		dispatcher: newDispatcher(),
+		channels:	make(map[uint32]channel),
 	}
 }
 

          
@@ 63,23 63,27 @@ func (demux *Demux) printf(format string
 }
 
 func (demux *Demux) ChannelCount() int {
-	return demux.dispatcher.Len()
+	return len(demux.channels)
 }
 
 func (demux *Demux) newChannel(
 	ctx context.Context,
 	fail func(error),
 	id uint32,
-	input <-chan []byte,
+	inputchan *channel,
 	output chan<- *Packet,
-	connect ConnectFunc,
 ) {
+
 	// Try to connect to the local host
 	var conn, err = demux.connect(ctx)
 	if err != nil {
-		demux.printf("can't connect to downstream, closing channel %d", id)
+		demux.printf(
+			"can't connect to downstream, closing channel %d: %s", id, err,
+		)
 		// we can't connect to downstream, close the connection.
 		output <- NewPacketClose(id)
+		// close our newly created channel since we won't be able to forward it
+		inputchan.Close()
 		return
 	}
 

          
@@ 101,7 105,7 @@ func (demux *Demux) newChannel(
 				demux.printf("channel %d writer done", id)
 				connClose()
 				return
-			case buf := <-input:
+			case buf := <-inputchan.input:
 				if buf == nil {
 					// channel's closed, the server side has closed, therefor
 					// we close the connection on upstream.

          
@@ 191,7 195,6 @@ func (demux *Demux) readServer(
 				// wait for packet to be sent
 			}
 		}
-		demux.printf("readServer done")
 	}()
 
 	return output

          
@@ 232,7 235,7 @@ func (demux *Demux) writeServer(
 					if seq == 0 {
 						seq = 1
 					}
-				case keepalivePacket:
+				case keepalivePacket, keepaliveAckPacket:
 					// Leave the sequence packet alone, set ack
 					p.header.Ack = seq
 					_, err = p.WriteTo(demux.server)

          
@@ 273,12 276,14 @@ func keepaliveReply(
 	}
 }
 
+// Sends and acknownledge keepalives
 func (demux *Demux) keepalive(
 	done func() <-chan struct{},
 	output chan<- *Packet,
 ) chan<- *Packet {
 	var input = make(chan *Packet)
-	// if keepalive == 0 we should not get any keepalive ack back, discard them if needed
+	// if keepalive == 0 we should not get any keepalive ack back
+	// discard them if needed just-in-case
 	if demux.Keepalive == 0 {
 		go func() {
 			for {

          
@@ 301,15 306,16 @@ func (demux *Demux) keepalive(
 			case <-done():
 				return
 			case <-ticker.C:
-				{
-					// send keep-alive packet
-					// and wait for response
-					var k = NewPacket()
-					k.header.Sequence = seq
-					// send keepalive packet
-					output <- k
-				}
+				// send keep-alive packet
+				// and wait for response
+				var k = NewPacket()
+				k.header.Sequence = seq
+				k.Payload = []byte{'k'}
+				output <- k
+				demux.printf("sent keepalive %d", seq)
+				var t = time.Now()
 
+				demux.printf("wait for reply")
 				// now wait for the reply
 				var p = <-input
 				if p.header.Sequence != seq {

          
@@ 318,6 324,7 @@ func (demux *Demux) keepalive(
 						p.header.Sequence, p.header.Ack,
 					)
 				}
+				demux.printf("got keepalive ack %d: %s", seq, time.Since(t))
 				seq += 1
 				if seq == 0 {
 					seq = 1

          
@@ 359,6 366,11 @@ func (demux *Demux) Run(ctx context.Cont
 	var getAck = func() uint32 { return ack.Load().(uint32) }
 	// output writes packets to the server
 	var output = demux.writeServer(seq, done, fail, getAck)
+	// keepalive is were ack keepalives packets are written
+	var keepalive = demux.keepalive(done, output)
+
+	// connection id -> thread-safe channel
+	var channels = make(map[uint32]*channel)
 
 	// We each packets from the server and handle them depending on their type:
 	// 	   open connection: we need to open new connection

          
@@ 384,6 396,8 @@ func (demux *Demux) Run(ctx context.Cont
 			switch p.Type() {
 			case keepalivePacket:
 				go keepaliveReply(done, p, output)
+			case keepaliveAckPacket:
+				go func() { keepalive <- p }()
 			case closePacket:
 				if bytes.Equal(p.Payload, []byte{'r'}) {
 					// half-close, we'll no longer receive data from the server

          
@@ 391,19 405,25 @@ func (demux *Demux) Run(ctx context.Cont
 					// sends us. Ignore it for now.
 					demux.printf("half-close!")
 				} else {
-					demux.printf("channel %d closed remotely", p.ConnectionId)
-					demux.dispatcher.Close(p.header.ConnectionId)
+					var c = channels[p.header.ConnectionId]
+					if c != nil {
+						close(c.input)
+						delete(channels, p.header.ConnectionId)
+					}
 				}
+			case openPacket:
+				var cid = p.header.ConnectionId
+				demux.printf("channel %d opened remotely", cid)
+				var c = newChannel()
+				channels[cid] = c
+				go demux.newChannel(cancelCtx, fail, cid, c, output)
 			case forwardPacket:
 				var cid = p.header.ConnectionId
-				var c = demux.dispatcher.Get(cid)
+				var c = channels[cid]
 				if c == nil {
-					demux.printf("channel %d opened remotely", cid)
-					var c = demux.dispatcher.Open(cid)
-
-					go demux.newChannel(cancelCtx, fail, cid, c, output, demux.connect)
+					demux.printf("channel %d already closed", cid)
 				} else {
-					c <- p.Payload
+					c.forward(p.Payload)
 				}
 			}
 		}

          
M kgp_test.go +1 -1
@@ 222,7 222,7 @@ func TestRelayRequestResponse(t *testing
 		0, 0, 0, 1, // forward packet connection id 1
 		0, 0, 0, 1, // seq 1
 		0, 0, 0, 1, // ack 1, because the client started writing before we the
-		            // request with seq was still at 1
+		// request with seq was still at 1
 		0, 0, // control 0
 		0, 5, // len 5
 		119, 111, 114, 108, 100, // "hello"

          
M packet.go +6 -1
@@ 8,6 8,7 @@ import (
 
 const (
 	forwardPacket = iota
+	openPacket
 	closePacket
 	halfClosePacket
 	keepalivePacket

          
@@ 52,7 53,11 @@ func (packet *Packet) Type() int {
 			return closePacket
 		}
 	default:
-		return forwardPacket
+		if len(packet.Payload) == 0 {
+			return openPacket
+		} else {
+			return forwardPacket
+		}
 	}
 }
 

          
M packet_test.go +2 -3
@@ 44,12 44,11 @@ func TestPacket(t *testing.T) {
 	if int(n) != len(expected) {
 		t.Fatalf("%d != %d", n, len(expected))
 	}
-	if (
-		other.header.ConnectionId != 1234 ||
+	if other.header.ConnectionId != 1234 ||
 		other.header.Sequence != 1234 ||
 		other.header.Ack != 1234 ||
 		other.header.Control != 1234 ||
-		!bytes.Equal(other.Payload, []byte("1234"))) {
+		!bytes.Equal(other.Payload, []byte("1234")) {
 		t.Fatalf("bad packet %v", other)
 	}
 }