@@ 78,18 78,19 @@ func main() {
flag.Parse()
+ var logger = log.New(os.Stderr, "", log.LstdFlags)
var args = flag.Args()
if len(args) != 2 {
- log.Fatalln(usage)
+ logger.Fatalln(usage)
}
var upstream, err = ParseAddress(args[0])
if err != nil {
- log.Fatalln(err)
+ logger.Fatalln(err)
}
downstream, err := ParseAddress(args[1])
if err != nil {
- log.Fatalln(err)
+ logger.Fatalln(err)
}
// Load metadata
@@ 101,19 102,19 @@ func main() {
var err error
meta, err = os.Open(*metafn)
if err != nil {
- log.Fatalln(err)
+ logger.Fatalln(err)
}
}
- log.Print("upstream", upstream, "downstream", downstream)
+ logger.Print("upstream", upstream, "downstream", downstream)
upconn, err := net.Dial(upstream.Network, upstream.Address)
if err != nil {
- log.Fatalln("error connecting to ", upstream, ": ", err.Error())
+ logger.Fatalln("error connecting to ", upstream, ": ", err.Error())
}
- log.Print("connected")
+ logger.Print("connected")
if _, err = handshake(upconn, meta); err != nil {
- log.Fatalln(err)
+ logger.Fatalln(err)
}
// var buf = make([]byte, 1000)
@@ 129,5 130,6 @@ func main() {
)
},
)
+ relay.Logger = logger
fmt.Println(relay.Run(context.Background()))
}
@@ 119,46 119,106 @@ const (
)
type Packet struct {
- Type int
- ConnectionId uint32
+ header struct {
+ ConnectionId uint32
+
+ Sequence uint32
+ Ack uint32
+ Control int16
- Payload bytes.Buffer
+ PayloadLen uint16
+ }
+
+ buffer []byte
+ Payload []byte
+}
+
+func NewPacket() *Packet {
+ var p = new(Packet)
+ p.buffer = make([]byte, MaxPayloadLength)
+ return p
}
-type packetHeader struct {
- ConnectionId uint32
-
- Sequence uint32
- Ack uint32
- Control int16
-
- PayloadLen uint16
+func (packet *Packet) Type() int {
+ var h = packet.header
+ switch {
+ case h.ConnectionId == 0:
+ return keepalivePacket
+ case h.Control == 0 && h.PayloadLen == 0:
+ return openPacket
+ case h.Control == 1:
+ return closePacket
+ default:
+ return forwardPacket
+ }
}
-func (packet *Packet) Write(w io.Writer, seq, ack uint32) (n int, err error) {
- var control int16
- if packet.Type == closePacket {
- control = 1
+func (packet *Packet) Size() int {
+ var size = binary.Size(&packet.header)
+ if packet.Payload != nil {
+ size += len(packet.Payload)
}
- var header = packetHeader{
- ConnectionId: packet.ConnectionId,
- Sequence: seq,
- Ack: ack,
- Control: control,
- PayloadLen: uint16(packet.Payload.Len()),
- }
- err = binary.Write(w, binary.BigEndian, &header)
+ return size
+}
+
+func (packet *Packet) ConnectionId() uint32 {
+ return packet.header.ConnectionId
+}
+
+func (packet *Packet) ReadFrom(reader io.Reader) (n int64, err error) {
+ err = binary.Read(reader, binary.BigEndian, &packet.header)
if err != nil {
return
}
- n += binary.Size(&header)
- var o int64
- o, err = io.CopyN(w, &packet.Payload, int64(packet.Payload.Len()))
- n += int(o)
+ n += int64(binary.Size(&packet.header))
+
+ var l = int64(packet.header.PayloadLen)
+ if l > 0 {
+ packet.Payload = packet.buffer[:l]
+ var x int
+ x, err = io.ReadFull(reader, packet.Payload)
+ if err != nil {
+ return
+ }
+
+ n += int64(x)
+ } else {
+ // nil == no payload
+ packet.Payload = nil
+ }
return
}
+func (packet *Packet) WriteTo(w io.Writer) (n int64, err error) {
+ err = binary.Write(w, binary.BigEndian, &packet.header)
+ if err != nil {
+ return
+ }
+ n += int64(binary.Size(&packet.header))
+ if packet.Payload != nil {
+ var o int
+ o, err = w.Write(packet.Payload)
+ if err != nil {
+ return
+ }
+
+ n += int64(o)
+ }
+
+ return
+}
+
+func (packet *Packet) NewPacketClose(id uint32) *Packet {
+ var p = NewPacket()
+ p.header.ConnectionId = id
+ p.header.Control = 1
+}
+
+func (packet *Packet) NewPacketForward(id uint32, reader io.Reader) *Packet {
+
+}
+
func (relay *Relay) newChannel(
ctx context.Context,
fail func(error),
@@ 170,13 230,15 @@ func (relay *Relay) newChannel(
// Try to connect to the local host
var conn, err = relay.connect(ctx)
if err != nil {
- // we can't connect to upstream, cancel everything
- fail(err)
+ relay.printf("can't connect to downstream, closing channel %d", id)
+ // we can't connect to downstream, close the connection.
+ var p Packet
+ p.Close(id)
+ output <- &p
return
- } else {
- relay.printf("Channel %d connected", id)
}
+ relay.printf("Channel %d connected", id)
var connClose = func() {
if err := conn.Close(); err != nil {
relay.printf(
@@ 200,17 262,15 @@ func (relay *Relay) newChannel(
connClose()
return
}
- var _, err = io.Copy(conn, buf)
+ var n, err = io.Copy(conn, buf)
+ relay.printf("wrote %d bytes to channel %d", n, id)
if err == io.EOF {
+ relay.printf("channel %d closed", id)
+ connClose()
// nothing else to read, we're done
return
} else if err != nil {
- relay.printf(
- "ERROR: while writing on channel %d: %s", id, err,
- )
- // FIXME seems a bit harsh to cancel everything just
- // because 1 upstream connection terminated badly...
- fail(err)
+ connClose()
return
}
}
@@ 220,35 280,34 @@ func (relay *Relay) newChannel(
// We read from conn and write it to output
go func() {
for {
+ relay.printf("reading from channel %d", id)
+ var buf = make([]byte, MaxPayloadLength)
+ var n, err = conn.Read(buf)
+ relay.printf("read %d bytes from channel %d", n, id)
+ if err == io.EOF {
+ relay.printf("channel %d closed", id)
+ connClose()
+ var p Packet
+ p.Close(id)
+ output <- &p
+ return
+ }
+ if err != nil {
+ connClose()
+ return
+ }
+ // fmt.Printf("BUF: %s\n", string(buf))
var packet = Packet{
Type: forwardPacket,
ConnectionId: id,
- }
- var _, err = io.CopyN(&packet.Payload, conn, MaxPayloadLength)
- if err == io.EOF {
- // connection closed we send the left-over traffic if needed,
- // then send a close KGP packet to the server and return
- if packet.Payload.Len() > 0 {
- output <- &packet
- }
- output <- &Packet{
- Type: closePacket,
- ConnectionId: id,
- }
- return
- }
- if err != nil {
- relay.printf(
- "WARNING: error while reading from channel %d: %s",
- id, err,
- )
- fail(err)
- return
+ Payload: *bytes.NewBuffer(buf[:n]),
}
select {
case <-ctx.Done():
- break
+ connClose()
+ return
case output <- &packet:
+ fmt.Println("packet sent upstream")
}
}
}()
@@ 309,15 368,15 @@ func (relay *Relay) readServer(
go func() {
for {
- var h packetHeader
- var err = binary.Read(relay.server, binary.BigEndian, &h)
+ var p = NewPacket()
+ _, err = p.ReadFrom(relay.server)
if err != nil {
fail(err)
close(output)
return
}
- ack.Store(h.Sequence)
- relay.rxBytes += uint64(binary.Size(&h))
+ ack.Store(packet.header.Sequence)
+ relay.rxBytes += 1 // # uint64(binary.Size(&h))
var packet = Packet{
ConnectionId: h.ConnectionId,
}
@@ 385,7 444,9 @@ func (relay *Relay) writeServer(
switch p.Type {
case openPacket, forwardPacket, closePacket:
// Forward the packet to the server
- _, err = p.Write(relay.server, seq, getAck())
+ var n int
+ n, err = p.Write(relay.server, seq, getAck())
+ relay.printf("wrote %d bytes upstream", n)
seq += 1
if seq == 0 {
seq = 1
@@ 441,6 502,7 @@ func (relay *Relay) Run(ctx context.Cont
// buffered channel to make sure the fail function doesn't block
var reportErr = make(chan error, 1)
var fail = func(err error) {
+ relay.printf("fail %v", err)
reportErr <- err
close(reportErr)
cancelFunc()
@@ 481,6 543,7 @@ func (relay *Relay) Run(ctx context.Cont
go keepAlive(done, p, output)
case openPacket:
var cid = p.ConnectionId
+ relay.printf("channel %d opened remotely", cid)
var c = relay.dispatcher.Open(cid)
// Verify if connection id exists for sanity check
if c == nil {
@@ 490,6 553,7 @@ func (relay *Relay) Run(ctx context.Cont
go relay.newChannel(cancelCtx, fail, cid, c, output, relay.connect)
case closePacket:
+ relay.printf("channel %d closed remotely", p.ConnectionId)
relay.dispatcher.Close(p.ConnectionId)
case forwardPacket:
var cid = p.ConnectionId
@@ 182,21 182,21 @@ func writeResponse(payload string, conne
* We don't write the open packet since the server opened the connection.
*/
- // 1st Write forward packet with payload
- // FIXME we should split if len(payload) > MaxPayloadLength
- var p = Packet{
- Type: forwardPacket,
- ConnectionId: connectionId,
- Payload: *bytes.NewBufferString(payload),
- }
- p.Write(&r, seq, ack)
- // Close connection
- p = Packet{
- Type: closePacket,
- ConnectionId: connectionId,
- }
- p.Write(&r, seq + 1, ack)
- return
+ // 1st Write forward packet with payload
+ // FIXME we should split if len(payload) > MaxPayloadLength
+ var p = Packet{
+ Type: forwardPacket,
+ ConnectionId: connectionId,
+ Payload: *bytes.NewBufferString(payload),
+ }
+ p.Write(&r, seq, ack)
+ // Close connection
+ p = Packet{
+ Type: closePacket,
+ ConnectionId: connectionId,
+ }
+ p.Write(&r, seq+1, ack)
+ return
}
func equalResponse(t *testing.T, result io.Reader, expected []byte) bool {