4 files changed, 112 insertions(+), 120 deletions(-)

M wormhole/handler.py
M wormhole/logger.py
M wormhole/proxy.py
M wormhole/server.py
M wormhole/handler.py +11 -17
@@ 1,6 1,6 @@ 
 import asyncio
 from socket import TCP_NODELAY
-from logger import get_logger
+from logger import Logger
 from tools import get_content_length
 from tools import get_host_and_port
 

          
@@ 8,11 8,11 @@ from tools import get_host_and_port
 async def relay_stream(
     stream_reader, stream_writer, ident, return_first_line=False
 ):
-    logger = get_logger()
+    logger = Logger().get_logger()
     first_line = None
     while True:
         try:
-            line = await stream_reader.read(1024)
+            line = await stream_reader.read(4096)
             if len(line) == 0:
                 break
             stream_writer.write(line)

          
@@ 21,9 21,7 @@ async def relay_stream(
                 ex.__class__.__name__,
                 " ".join([str(arg) for arg in ex.args]),
             )
-            logger.debug(
-                f"[{ident['id']}][{ident['client']}]: {error_message}"
-            )
+            logger.debug(f"[{ident['id']}][{ident['client']}]: {error_message}")
             break
         else:
             if return_first_line and first_line is None:

          
@@ 37,7 35,7 @@ async def process_https(
     response_code = 200
     error_message = None
     host, port = get_host_and_port(uri)
-    logger = get_logger()
+    logger = Logger().get_logger()
     try:
         req_reader, req_writer = await asyncio.open_connection(
             host, port, ssl=False

          
@@ 51,14 49,10 @@ async def process_https(
         )
 
         tasks = [
-            asyncio.ensure_future(
-                relay_stream(client_reader, req_writer, ident)
-            ),
-            asyncio.ensure_future(
-                relay_stream(req_reader, client_writer, ident)
-            ),
+            relay_stream(client_reader, req_writer, ident),
+            relay_stream(req_reader, client_writer, ident),
         ]
-        await asyncio.wait(tasks)
+        await asyncio.gather(*tasks)
     except Exception as ex:
         response_code = 502
         error_message = "%s: %s" % (

          
@@ 151,7 145,7 @@ async def process_http(
     if response_code is None:
         response_code = int(response_status.decode("ascii").split(" ")[1])
 
-    logger = get_logger()
+    logger = Logger().get_logger()
     if error_message is None:
         logger.info(
             f"[{ident['id']}][{ident['client']}]: "

          
@@ 165,7 159,7 @@ async def process_http(
 
 
 async def process_request(client_reader, max_retry, ident):
-    logger = get_logger()
+    logger = Logger().get_logger()
     request_line = ""
     headers = []
     header = ""

          
@@ 190,7 184,7 @@ async def process_request(client_reader,
 
         content_length = get_content_length(header)
         while len(payload) < content_length:
-            payload += await client_reader.read(1024)
+            payload += await client_reader.read(4096)
     except Exception as ex:
         name = ex.__class__.__name__
         args = " ".join([str(arg) for arg in ex.args])

          
M wormhole/logger.py +31 -16
@@ 12,26 12,39 @@ class ContextFilter(logging.Filter):
         return True
 
 
-logger = None
+class Singleton(type):
+    _instances = {}
+
+    def __call__(cls, *args, **kwargs):
+        if cls not in cls._instances:
+            cls._instances[cls] = super(Singleton, cls).__call__(
+                *args, **kwargs
+            )
+        return cls._instances[cls]
 
 
-def get_logger(syslog_host=None, syslog_port=514, verbose=0):
-    unix_format = "%(asctime)s %(name)s[%(process)d]: %(message)s"
-    net_format = "%(asctime)s %(hostname)s %(name)s[%(process)d]: %(message)s"
-    date_format = "%b %d %H:%M:%S"
+class Logger(metaclass=Singleton):
+    def __init__(self, syslog_host=None, syslog_port=514, verbose=0):
+        unix_format = "%(asctime)s %(name)s[%(process)d]: %(message)s"
+        net_format = (
+            "%(asctime)s %(hostname)s %(name)s[%(process)d]: %(message)s"
+        )
+        date_format = "%b %d %H:%M:%S"
 
-    global logger
-    if logger is None:
-        logging.basicConfig(
-            level=logging.INFO,
-            format="%(asctime)s %(name)s[%(process)d]: %(message)s",
-        )
+        self.logger = logging.getLogger("wormhole")
+        self.logger.setLevel(logging.INFO)
         logging.getLogger("asyncio").setLevel(logging.CRITICAL)
-        logger = logging.getLogger("wormhole")
         if verbose >= 1:
-            logger.setLevel(logging.DEBUG)
+            self.logger.setLevel(logging.DEBUG)
         if verbose >= 2:
             logging.getLogger("asyncio").setLevel(logging.DEBUG)
+
+        # Add console handler
+        console_handler = logging.StreamHandler()
+        console_formatter = logging.Formatter(unix_format, datefmt=date_format)
+        console_handler.setFormatter(console_formatter)
+        self.logger.addHandler(console_handler)
+
         if syslog_host and syslog_host != "DISABLED":
             if syslog_host.startswith("/") and os.path.exists(syslog_host):
                 syslog = logging.handlers.SysLogHandler(

          
@@ 39,11 52,13 @@ def get_logger(syslog_host=None, syslog_
                 )
                 formatter = logging.Formatter(unix_format, datefmt=date_format)
             else:
-                logger.addFilter(ContextFilter())
+                self.logger.addFilter(ContextFilter())
                 syslog = logging.handlers.SysLogHandler(
                     address=(syslog_host, syslog_port),
                 )
                 formatter = logging.Formatter(net_format, datefmt=date_format)
             syslog.setFormatter(formatter)
-            logger.addHandler(syslog)
-    return logger
+            self.logger.addHandler(syslog)
+
+    def get_logger(self):
+        return self.logger

          
M wormhole/proxy.py +35 -24
@@ 14,11 14,27 @@ sys.path.insert(0, Path(os.path.realpath
 import asyncio
 from argparse import ArgumentParser
 from license import LICENSE
-from logger import get_logger
-from server import start_wormhole_server
+from logger import Logger
 from version import VERSION
 
 
+def start_server(host, port, authentication):
+    from server import start_wormhole_server
+
+    loop = asyncio.get_event_loop()
+    loop.run_until_complete(start_wormhole_server(host, port, authentication))
+    loop.run_forever()
+
+
+def check_uvloop():
+    try:
+        import uvloop
+    except ImportError:
+        return False
+    else:
+        return True
+
+
 def main():
     """CLI frontend function.  It takes command line options e.g. host,
     port and provides `--help` message.

          
@@ 81,31 97,26 @@ def main():
     if not (1 <= args.port <= 65535):
         parser.error("port must be 1-65535")
 
-    logger = get_logger(args.syslog_host, args.syslog_port, args.verbose)
-    try:
+    logger = Logger(
+        args.syslog_host, args.syslog_port, args.verbose
+    ).get_logger()
+    if args.verbose:
+        logger.debug(
+            f"[000000][{args.host}]: Using {'uvloop' if check_uvloop() else 'default event loop'}."
+        )
+
+    if check_uvloop():
         import uvloop
-    except ImportError:
-        pass
-    else:
-        logger.debug(f"[000000][{args.host}]: Using event loop from uvloop.")
+
         asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
-    loop = asyncio.new_event_loop()
+
+    start_server(args.host, args.port, args.authentication)
+
+
+if __name__ == "__main__":
     try:
-        loop.run_until_complete(
-            start_wormhole_server(
-                args.host,
-                args.port,
-                args.authentication,
-            )
-        )
-        loop.run_forever()
+        exit(main())
     except OSError:
         pass
     except KeyboardInterrupt:
-        print("bye")
-    finally:
-        loop.close()
-
-
-if __name__ == "__main__":
-    exit(main())
+        print("\nbye")

          
M wormhole/server.py +35 -63
@@ 1,5 1,6 @@ 
 import asyncio
 import functools
+import socket
 import sys
 from time import time
 from authentication import get_ident

          
@@ 7,45 8,27 @@ from authentication import verify
 from handler import process_http
 from handler import process_https
 from handler import process_request
-from logger import get_logger
-
+from logger import Logger
 
 MAX_RETRY = 3
 if sys.platform == "win32":
     import win32file
 
-    MAX_TASKS = win32file._getmaxstdio()
+    FREE_TASKS = asyncio.Semaphore(
+        int(0.9 * win32file._getmaxstdio()))
 else:
     import resource
 
-    MAX_TASKS = resource.getrlimit(resource.RLIMIT_NOFILE)[0]
-
-
-wormhole_semaphore = None
-
+    FREE_TASKS = asyncio.Semaphore(
+        int(0.9 * resource.getrlimit(resource.RLIMIT_NOFILE)[0])
+    )
+MAX_TASKS = FREE_TASKS._value
 
-def get_wormhole_semaphore():
-    max_wormholes = int(0.9 * MAX_TASKS)  # Use only 90% of open files limit.
-    global wormhole_semaphore
-    if wormhole_semaphore is None:
-        wormhole_semaphore = asyncio.Semaphore(max_wormholes)
-    return wormhole_semaphore
-
-
-def debug_wormhole_semaphore(client_reader, client_writer):
-    global wormhole_semaphore
-    ident = get_ident(client_reader, client_writer)
-    available = wormhole_semaphore._value
-    logger = get_logger()
-    logger.debug(
-        f"[{ident['id']}][{ident['client']}]: "
-        "Resource available: "
-        f"{100 * available / MAX_TASKS:.2f}% ({available}/{MAX_TASKS})"
-    )
+clients = dict()
 
 
 async def process_wormhole(client_reader, client_writer, auth):
-    logger = get_logger()
+    logger = Logger().get_logger()
     ident = get_ident(client_reader, client_writer)
 
     request_line, headers, payload = await process_request(

          
@@ 53,8 36,7 @@ async def process_wormhole(client_reader
     )
     if not request_line:
         logger.debug(
-            f"[{ident['id']}][{ident['client']}]: "
-            "!!! Task reject (empty request)"
+            f"[{ident['id']}][{ident['client']}]: !!! Task reject (empty request)"
         )
         return
 

          
@@ 66,8 48,7 @@ async def process_wormhole(client_reader
         request_method, uri, http_version = request_fields
     else:
         logger.debug(
-            f"[{ident['id']}][{ident['client']}]: "
-            "!!! Task reject (invalid request)"
+            f"[{ident['id']}][{ident['client']}]: !!! Task reject (invalid request)"
         )
         return
 

          
@@ 75,47 56,34 @@ async def process_wormhole(client_reader
         user_ident = await verify(client_reader, client_writer, headers, auth)
         if user_ident is None:
             logger.info(
-                f"[{ident['id']}][{ident['client']}]: "
-                f"{request_method} 407 {uri}"
+                f"[{ident['id']}][{ident['client']}]: {request_method} 407 {uri}"
             )
             return
         ident = user_ident
 
     if request_method == "CONNECT":
-        async with get_wormhole_semaphore():
-            debug_wormhole_semaphore(client_reader, client_writer)
+        async with FREE_TASKS:
+            logger.debug(
+                f"[{ident['id']}][{ident['client']}]: {FREE_TASKS._value}/{MAX_TASKS} Resource available"
+            )
             return await process_https(
                 client_reader, client_writer, request_method, uri, ident
             )
     else:
-        async with get_wormhole_semaphore():
-            debug_wormhole_semaphore(client_reader, client_writer)
+        async with FREE_TASKS:
+            logger.debug(
+                f"[{ident['id']}][{ident['client']}]: {FREE_TASKS._value}/{MAX_TASKS} Resource available"
+            )
             return await process_http(
-                client_writer,
-                request_method,
-                uri,
-                http_version,
-                headers,
-                payload,
-                ident,
+                client_writer, request_method, uri, http_version, headers, payload, ident,
             )
 
 
-async def limit_wormhole(client_reader, client_writer, auth):
-    async with get_wormhole_semaphore():
-        debug_wormhole_semaphore(client_reader, client_writer)
-        await process_wormhole(client_reader, client_writer, auth)
-        debug_wormhole_semaphore(client_reader, client_writer)
-
-
-clients = dict()
-
-
-def accept_client(client_reader, client_writer, auth):
-    logger = get_logger()
+async def accept_client(client_reader, client_writer, auth):
+    logger = Logger().get_logger()
     ident = get_ident(client_reader, client_writer)
     task = asyncio.ensure_future(
-        limit_wormhole(client_reader, client_writer, auth)
+        process_wormhole(client_reader, client_writer, auth)
     )
     global clients
     clients[task] = (client_reader, client_writer)

          
@@ 125,8 93,7 @@ def accept_client(client_reader, client_
         del clients[task]
         client_writer.close()
         logger.debug(
-            f"[{ident['id']}][{ident['client']}]: "
-            f"Connection closed ({time() - started_time:.5f} seconds)"
+            f"[{ident['id']}][{ident['client']}]: Connection closed ({time() - started_time:.5f} seconds)"
         )
 
     logger.debug(f"[{ident['id']}][{ident['client']}]: Connection started")

          
@@ 134,14 101,19 @@ def accept_client(client_reader, client_
 
 
 async def start_wormhole_server(host, port, auth):
-    logger = get_logger()
+    logger = Logger().get_logger()
     try:
         accept = functools.partial(accept_client, auth=auth)
-        server = await asyncio.start_server(accept, host, port)
+        # Check if the host string contains an IPv6 address
+        is_ipv6 = ":" in host
+        if is_ipv6:
+            family = socket.AF_INET6
+        else:
+            family = socket.AF_INET
+        server = await asyncio.start_server(accept, host, port, family=family)
     except OSError as ex:
         logger.critical(
-            f"[000000][{host}]: "
-            f"!!! Failed to bind server at [{host}:{port}]: {ex.args[1]}"
+            f"[000000][{host}]: !!! Failed to bind server at [{host}:{port}]: {ex.args[1]}"
         )
         raise
     else: