M wormhole/handler.py +11 -17
@@ 1,6 1,6 @@
import asyncio
from socket import TCP_NODELAY
-from logger import get_logger
+from logger import Logger
from tools import get_content_length
from tools import get_host_and_port
@@ 8,11 8,11 @@ from tools import get_host_and_port
async def relay_stream(
stream_reader, stream_writer, ident, return_first_line=False
):
- logger = get_logger()
+ logger = Logger().get_logger()
first_line = None
while True:
try:
- line = await stream_reader.read(1024)
+ line = await stream_reader.read(4096)
if len(line) == 0:
break
stream_writer.write(line)
@@ 21,9 21,7 @@ async def relay_stream(
ex.__class__.__name__,
" ".join([str(arg) for arg in ex.args]),
)
- logger.debug(
- f"[{ident['id']}][{ident['client']}]: {error_message}"
- )
+ logger.debug(f"[{ident['id']}][{ident['client']}]: {error_message}")
break
else:
if return_first_line and first_line is None:
@@ 37,7 35,7 @@ async def process_https(
response_code = 200
error_message = None
host, port = get_host_and_port(uri)
- logger = get_logger()
+ logger = Logger().get_logger()
try:
req_reader, req_writer = await asyncio.open_connection(
host, port, ssl=False
@@ 51,14 49,10 @@ async def process_https(
)
tasks = [
- asyncio.ensure_future(
- relay_stream(client_reader, req_writer, ident)
- ),
- asyncio.ensure_future(
- relay_stream(req_reader, client_writer, ident)
- ),
+ relay_stream(client_reader, req_writer, ident),
+ relay_stream(req_reader, client_writer, ident),
]
- await asyncio.wait(tasks)
+ await asyncio.gather(*tasks)
except Exception as ex:
response_code = 502
error_message = "%s: %s" % (
@@ 151,7 145,7 @@ async def process_http(
if response_code is None:
response_code = int(response_status.decode("ascii").split(" ")[1])
- logger = get_logger()
+ logger = Logger().get_logger()
if error_message is None:
logger.info(
f"[{ident['id']}][{ident['client']}]: "
@@ 165,7 159,7 @@ async def process_http(
async def process_request(client_reader, max_retry, ident):
- logger = get_logger()
+ logger = Logger().get_logger()
request_line = ""
headers = []
header = ""
@@ 190,7 184,7 @@ async def process_request(client_reader,
content_length = get_content_length(header)
while len(payload) < content_length:
- payload += await client_reader.read(1024)
+ payload += await client_reader.read(4096)
except Exception as ex:
name = ex.__class__.__name__
args = " ".join([str(arg) for arg in ex.args])
M wormhole/logger.py +31 -16
@@ 12,26 12,39 @@ class ContextFilter(logging.Filter):
return True
-logger = None
+class Singleton(type):
+ _instances = {}
+
+ def __call__(cls, *args, **kwargs):
+ if cls not in cls._instances:
+ cls._instances[cls] = super(Singleton, cls).__call__(
+ *args, **kwargs
+ )
+ return cls._instances[cls]
-def get_logger(syslog_host=None, syslog_port=514, verbose=0):
- unix_format = "%(asctime)s %(name)s[%(process)d]: %(message)s"
- net_format = "%(asctime)s %(hostname)s %(name)s[%(process)d]: %(message)s"
- date_format = "%b %d %H:%M:%S"
+class Logger(metaclass=Singleton):
+ def __init__(self, syslog_host=None, syslog_port=514, verbose=0):
+ unix_format = "%(asctime)s %(name)s[%(process)d]: %(message)s"
+ net_format = (
+ "%(asctime)s %(hostname)s %(name)s[%(process)d]: %(message)s"
+ )
+ date_format = "%b %d %H:%M:%S"
- global logger
- if logger is None:
- logging.basicConfig(
- level=logging.INFO,
- format="%(asctime)s %(name)s[%(process)d]: %(message)s",
- )
+ self.logger = logging.getLogger("wormhole")
+ self.logger.setLevel(logging.INFO)
logging.getLogger("asyncio").setLevel(logging.CRITICAL)
- logger = logging.getLogger("wormhole")
if verbose >= 1:
- logger.setLevel(logging.DEBUG)
+ self.logger.setLevel(logging.DEBUG)
if verbose >= 2:
logging.getLogger("asyncio").setLevel(logging.DEBUG)
+
+ # Add console handler
+ console_handler = logging.StreamHandler()
+ console_formatter = logging.Formatter(unix_format, datefmt=date_format)
+ console_handler.setFormatter(console_formatter)
+ self.logger.addHandler(console_handler)
+
if syslog_host and syslog_host != "DISABLED":
if syslog_host.startswith("/") and os.path.exists(syslog_host):
syslog = logging.handlers.SysLogHandler(
@@ 39,11 52,13 @@ def get_logger(syslog_host=None, syslog_
)
formatter = logging.Formatter(unix_format, datefmt=date_format)
else:
- logger.addFilter(ContextFilter())
+ self.logger.addFilter(ContextFilter())
syslog = logging.handlers.SysLogHandler(
address=(syslog_host, syslog_port),
)
formatter = logging.Formatter(net_format, datefmt=date_format)
syslog.setFormatter(formatter)
- logger.addHandler(syslog)
- return logger
+ self.logger.addHandler(syslog)
+
+ def get_logger(self):
+ return self.logger
M wormhole/proxy.py +35 -24
@@ 14,11 14,27 @@ sys.path.insert(0, Path(os.path.realpath
import asyncio
from argparse import ArgumentParser
from license import LICENSE
-from logger import get_logger
-from server import start_wormhole_server
+from logger import Logger
from version import VERSION
+def start_server(host, port, authentication):
+ from server import start_wormhole_server
+
+ loop = asyncio.get_event_loop()
+ loop.run_until_complete(start_wormhole_server(host, port, authentication))
+ loop.run_forever()
+
+
+def check_uvloop():
+ try:
+ import uvloop
+ except ImportError:
+ return False
+ else:
+ return True
+
+
def main():
"""CLI frontend function. It takes command line options e.g. host,
port and provides `--help` message.
@@ 81,31 97,26 @@ def main():
if not (1 <= args.port <= 65535):
parser.error("port must be 1-65535")
- logger = get_logger(args.syslog_host, args.syslog_port, args.verbose)
- try:
+ logger = Logger(
+ args.syslog_host, args.syslog_port, args.verbose
+ ).get_logger()
+ if args.verbose:
+ logger.debug(
+ f"[000000][{args.host}]: Using {'uvloop' if check_uvloop() else 'default event loop'}."
+ )
+
+ if check_uvloop():
import uvloop
- except ImportError:
- pass
- else:
- logger.debug(f"[000000][{args.host}]: Using event loop from uvloop.")
+
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
- loop = asyncio.new_event_loop()
+
+ start_server(args.host, args.port, args.authentication)
+
+
+if __name__ == "__main__":
try:
- loop.run_until_complete(
- start_wormhole_server(
- args.host,
- args.port,
- args.authentication,
- )
- )
- loop.run_forever()
+ exit(main())
except OSError:
pass
except KeyboardInterrupt:
- print("bye")
- finally:
- loop.close()
-
-
-if __name__ == "__main__":
- exit(main())
+ print("\nbye")
M wormhole/server.py +35 -63
@@ 1,5 1,6 @@
import asyncio
import functools
+import socket
import sys
from time import time
from authentication import get_ident
@@ 7,45 8,27 @@ from authentication import verify
from handler import process_http
from handler import process_https
from handler import process_request
-from logger import get_logger
-
+from logger import Logger
MAX_RETRY = 3
if sys.platform == "win32":
import win32file
- MAX_TASKS = win32file._getmaxstdio()
+ FREE_TASKS = asyncio.Semaphore(
+ int(0.9 * win32file._getmaxstdio()))
else:
import resource
- MAX_TASKS = resource.getrlimit(resource.RLIMIT_NOFILE)[0]
-
-
-wormhole_semaphore = None
-
+ FREE_TASKS = asyncio.Semaphore(
+ int(0.9 * resource.getrlimit(resource.RLIMIT_NOFILE)[0])
+ )
+MAX_TASKS = FREE_TASKS._value
-def get_wormhole_semaphore():
- max_wormholes = int(0.9 * MAX_TASKS) # Use only 90% of open files limit.
- global wormhole_semaphore
- if wormhole_semaphore is None:
- wormhole_semaphore = asyncio.Semaphore(max_wormholes)
- return wormhole_semaphore
-
-
-def debug_wormhole_semaphore(client_reader, client_writer):
- global wormhole_semaphore
- ident = get_ident(client_reader, client_writer)
- available = wormhole_semaphore._value
- logger = get_logger()
- logger.debug(
- f"[{ident['id']}][{ident['client']}]: "
- "Resource available: "
- f"{100 * available / MAX_TASKS:.2f}% ({available}/{MAX_TASKS})"
- )
+clients = dict()
async def process_wormhole(client_reader, client_writer, auth):
- logger = get_logger()
+ logger = Logger().get_logger()
ident = get_ident(client_reader, client_writer)
request_line, headers, payload = await process_request(
@@ 53,8 36,7 @@ async def process_wormhole(client_reader
)
if not request_line:
logger.debug(
- f"[{ident['id']}][{ident['client']}]: "
- "!!! Task reject (empty request)"
+ f"[{ident['id']}][{ident['client']}]: !!! Task reject (empty request)"
)
return
@@ 66,8 48,7 @@ async def process_wormhole(client_reader
request_method, uri, http_version = request_fields
else:
logger.debug(
- f"[{ident['id']}][{ident['client']}]: "
- "!!! Task reject (invalid request)"
+ f"[{ident['id']}][{ident['client']}]: !!! Task reject (invalid request)"
)
return
@@ 75,47 56,34 @@ async def process_wormhole(client_reader
user_ident = await verify(client_reader, client_writer, headers, auth)
if user_ident is None:
logger.info(
- f"[{ident['id']}][{ident['client']}]: "
- f"{request_method} 407 {uri}"
+ f"[{ident['id']}][{ident['client']}]: {request_method} 407 {uri}"
)
return
ident = user_ident
if request_method == "CONNECT":
- async with get_wormhole_semaphore():
- debug_wormhole_semaphore(client_reader, client_writer)
+ async with FREE_TASKS:
+ logger.debug(
+ f"[{ident['id']}][{ident['client']}]: {FREE_TASKS._value}/{MAX_TASKS} Resource available"
+ )
return await process_https(
client_reader, client_writer, request_method, uri, ident
)
else:
- async with get_wormhole_semaphore():
- debug_wormhole_semaphore(client_reader, client_writer)
+ async with FREE_TASKS:
+ logger.debug(
+ f"[{ident['id']}][{ident['client']}]: {FREE_TASKS._value}/{MAX_TASKS} Resource available"
+ )
return await process_http(
- client_writer,
- request_method,
- uri,
- http_version,
- headers,
- payload,
- ident,
+ client_writer, request_method, uri, http_version, headers, payload, ident,
)
-async def limit_wormhole(client_reader, client_writer, auth):
- async with get_wormhole_semaphore():
- debug_wormhole_semaphore(client_reader, client_writer)
- await process_wormhole(client_reader, client_writer, auth)
- debug_wormhole_semaphore(client_reader, client_writer)
-
-
-clients = dict()
-
-
-def accept_client(client_reader, client_writer, auth):
- logger = get_logger()
+async def accept_client(client_reader, client_writer, auth):
+ logger = Logger().get_logger()
ident = get_ident(client_reader, client_writer)
task = asyncio.ensure_future(
- limit_wormhole(client_reader, client_writer, auth)
+ process_wormhole(client_reader, client_writer, auth)
)
global clients
clients[task] = (client_reader, client_writer)
@@ 125,8 93,7 @@ def accept_client(client_reader, client_
del clients[task]
client_writer.close()
logger.debug(
- f"[{ident['id']}][{ident['client']}]: "
- f"Connection closed ({time() - started_time:.5f} seconds)"
+ f"[{ident['id']}][{ident['client']}]: Connection closed ({time() - started_time:.5f} seconds)"
)
logger.debug(f"[{ident['id']}][{ident['client']}]: Connection started")
@@ 134,14 101,19 @@ def accept_client(client_reader, client_
async def start_wormhole_server(host, port, auth):
- logger = get_logger()
+ logger = Logger().get_logger()
try:
accept = functools.partial(accept_client, auth=auth)
- server = await asyncio.start_server(accept, host, port)
+ # Check if the host string contains an IPv6 address
+ is_ipv6 = ":" in host
+ if is_ipv6:
+ family = socket.AF_INET6
+ else:
+ family = socket.AF_INET
+ server = await asyncio.start_server(accept, host, port, family=family)
except OSError as ex:
logger.critical(
- f"[000000][{host}]: "
- f"!!! Failed to bind server at [{host}:{port}]: {ex.args[1]}"
+ f"[000000][{host}]: !!! Failed to bind server at [{host}:{port}]: {ex.args[1]}"
)
raise
else: