# HG changeset patch # User Chaiwat Suttipongsakul # Date 1741690293 -25200 # Tue Mar 11 17:51:33 2025 +0700 # Node ID 5b2a1fb601293b54010ea8909d37d6a80a29bc78 # Parent 6c5dea2457b9802c96dbd7b9d226eed30adb92d3 Modernize codebase to Python 3.11+ with optional performance enhancements - Bump minimum Python version to 3.11 - Make uvloop (non-Windows) and winloop (Windows) optional dependencies - Add signal handlers (SIGINT, SIGTERM) for graceful shutdown - Update setup.py to use README.md with Markdown support diff --git a/README.md b/README.md --- a/README.md +++ b/README.md @@ -10,8 +10,9 @@ ## Dependency -- Python \>= 3.6 +- Python \>= 3.11 - [uvloop](https://github.com/MagicStack/uvloop) (optional) +- [winloop](https://github.com/Vizonex/Winloop) (optional for Windows) ## Docker Image Usage @@ -99,7 +100,7 @@ # License MIT License (included in -[license.py](https://bit.ly/wormhole-proxy-license)) +[license.py](https://hg.sr.ht/~cwt/wormhole/raw/wormhole/license.py)) # Notice @@ -110,5 +111,5 @@ - Wormhole may not work in: - some ISPs - some firewalls - - some browers + - some browsers - some web sites diff --git a/README.rst b/README.rst deleted file mode 100644 --- a/README.rst +++ /dev/null @@ -1,124 +0,0 @@ -Wormhole -======== - -**Wormhole** is a forward proxy without caching. You may use it for: - -- Modifying requests to look like they are originated from the IP address - that *Wormhole* is running on. -- Adding an authentication layer to the internet users in your organization. -- Logging internet activities to your syslog server. - -Dependency ----------- - -- Python >= 3.6 -- `uvloop `_ (optional) - -Docker Image Usage ------------------- - -Run without authentication -~~~~~~~~~~~~~~~~~~~~~~~~~~ - -.. code:: shell - - $ docker pull bashell/wormhole - $ docker run -d -p 8800:8800 bashell/wormhole - -Run with authentication -~~~~~~~~~~~~~~~~~~~~~~~ - -- Create an empty directory on your docker host -- Create an authentication file that contains username and password in this - format ``username:password`` -- Link that directory to the container via option ``-v`` and also run wormhole - container with option ``-a /path/to/authentication_file`` - -Example: - -.. code:: shell - - $ docker pull bashell/wormhole - $ mkdir -p /path/to/dir - $ echo "user1:password1" > /path/to/dir/wormhole.passwd - $ docker run -d -v /path/to/dir:/opt/wormhole \ - -p 8800:8800 bashell/wormhole \ - -a /opt/wormhole/wormhole.passwd - -How to install --------------- - -Stable Version -~~~~~~~~~~~~~~ - -Please install the **stable version** using ``pip`` command: - -.. code:: shell - - $ pip install wormhole-proxy - -Development Snapshot -~~~~~~~~~~~~~~~~~~~~ - -You can install the **development snapshot** using ``pip`` with ``mercurial``: - -.. code:: shell - - $ pip install hg+https://hg.sr.ht/~cwt/wormhole - -Or install from your local clone: - -.. code:: shell - - $ hg clone https://hg.sr.ht/~cwt/wormhole - $ cd wormhole/ - $ pip install -e . - -You can also install the latest ``tip`` snapshot using the following -command: - -.. code:: shell - - $ pip install https://hg.sr.ht/~cwt/wormhole/archive/tip.tar.gz - -How to use ----------- - -#. Run **wormhole** command - - .. code:: shell - - $ wormhole - -#. Set browser's proxy setting to - - .. code:: shell - - host: 127.0.0.1 - port: 8800 - -Command help ------------- - -.. code:: shell - - $ wormhole --help - -License -------- - -MIT License (included in `license.py `_) - -Notice ------- - -- This project is forked and converted to Mercurial from - `WARP `_ on GitHub. -- Authentication file contains ``username`` and ``password`` in **plain - text**, keep it secret! *(I will try to encrypt/encode it soon.)* -- Wormhole may not work in: - - - some ISPs - - some firewalls - - some browers - - some web sites diff --git a/setup.py b/setup.py --- a/setup.py +++ b/setup.py @@ -7,7 +7,7 @@ def readme(): - with open("README.rst", encoding="utf-8") as readme_file: + with open("README.md", encoding="utf-8") as readme_file: return "\n" + readme_file.read() @@ -18,8 +18,9 @@ author_email="cwt@bashell.com", url="https://hg.sr.ht/~cwt/wormhole", license="MIT", - description="Asynchronous I/O HTTP and HTTPS Proxy on Python >= 3.6", + description="Asynchronous I/O HTTP and HTTPS Proxy on Python >= 3.11", long_description=readme(), + long_description_content_type="text/markdown", keywords="wormhole asynchronous web proxy", classifiers=[ "Development Status :: 5 - Production/Stable", @@ -29,18 +30,22 @@ "License :: OSI Approved :: MIT License", "Operating System :: POSIX", "Operating System :: Microsoft :: Windows", - "Programming Language :: Python :: 3.6", - "Programming Language :: Python :: 3.7", - "Programming Language :: Python :: 3.8", - "Programming Language :: Python :: 3.9", - "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", "Programming Language :: Python :: Implementation :: CPython", "Topic :: Internet :: Proxy Servers", ], + setup_requires=["setuptools>=40.1.0"], install_requires=[ 'pywin32;platform_system=="Windows"', - 'uvloop;platform_system=="Linux"', ], + extras_require={ + "performance": [ + 'winloop;platform_system=="Windows"', + 'uvloop;platform_system!="Windows"', + ], + }, packages=["wormhole"], include_package_data=True, entry_points={"console_scripts": ["wormhole = wormhole.proxy:main"]}, diff --git a/wormhole/authentication.py b/wormhole/authentication.py --- a/wormhole/authentication.py +++ b/wormhole/authentication.py @@ -1,49 +1,55 @@ from base64 import decodebytes +import asyncio -def get_ident(client_reader, client_writer, user=None): +def get_ident( + client_reader: asyncio.StreamReader, + client_writer: asyncio.StreamWriter, + user: str | None = None, +) -> dict[str, str]: client = client_writer.get_extra_info("peername")[0] if user: client = f"{user}@{client}" return {"id": hex(id(client_reader))[-6:], "client": client} -auth_list = list() +auth_list: list[str] = [] -def get_auth_list(auth): +def get_auth_list(auth: str) -> list[str]: global auth_list if not auth_list: - auth_list = [ - line.strip() - for line in open(auth, "r") - if line.strip() and not line.strip().startswith("#") - ] + with open(auth, "r") as f: + + auth_list = [ + line.strip() for line in f if line.strip() and not line.startswith("#") + ] return auth_list -def deny(client_writer): - [ +def deny(client_writer: asyncio.StreamWriter) -> None: + messages = ( + b"HTTP/1.1 407 Proxy Authentication Required\r\n", + b'Proxy-Authenticate: Basic realm="Wormhole Proxy"\r\n', + b"\r\n", + ) + for message in messages: client_writer.write(message) - for message in ( - b"HTTP/1.1 407 Proxy Authentication Required\r\n", - b'Proxy-Authenticate: Basic realm="Wormhole Proxy"\r\n', - b"\r\n", - ) - ] -async def verify(client_reader, client_writer, headers, auth): - proxy_auth = [ - header - for header in headers - if header.lower().startswith("proxy-authorization:") - ] +async def verify( + client_reader: asyncio.StreamReader, + client_writer: asyncio.StreamWriter, + headers: list[str], + auth: str, +) -> dict[str, str] | None: + proxy_auth = [h for h in headers if h.lower().startswith("proxy-authorization:")] if proxy_auth: - user_password = decodebytes( - proxy_auth[0].split(" ")[2].encode("ascii") - ).decode("ascii") + + user_password = decodebytes(proxy_auth[0].split(" ")[2].encode("ascii")).decode( + "ascii" + ) if user_password in get_auth_list(auth): - user = user_password.split(":")[0] - return get_ident(client_reader, client_writer, user) - return deny(client_writer) + return get_ident(client_reader, client_writer, user_password.split(":")[0]) + deny(client_writer) + return None diff --git a/wormhole/handler.py b/wormhole/handler.py --- a/wormhole/handler.py +++ b/wormhole/handler.py @@ -1,203 +1,156 @@ import asyncio from socket import TCP_NODELAY from logger import Logger -from tools import get_content_length -from tools import get_host_and_port +from tools import get_content_length, get_host_and_port async def relay_stream( - stream_reader, stream_writer, ident, return_first_line=False -): + stream_reader: asyncio.StreamReader, + stream_writer: asyncio.StreamWriter, + ident: dict[str, str], + return_first_line: bool = False, +) -> bytes | None: logger = Logger().get_logger() + first_line = None while True: try: line = await stream_reader.read(4096) - if len(line) == 0: + if not line: break stream_writer.write(line) except Exception as ex: - error_message = "%s: %s" % ( - ex.__class__.__name__, - " ".join([str(arg) for arg in ex.args]), + + logger.debug( + f"[{ident['id']}][{ident['client']}]: {ex.__class__.__name__}: {' '.join(map(str, ex.args))}" ) - logger.debug(f"[{ident['id']}][{ident['client']}]: {error_message}") break + else: if return_first_line and first_line is None: first_line = line[: line.find(b"\r\n")] + return first_line async def process_https( - client_reader, client_writer, request_method, uri, ident -): - response_code = 200 - error_message = None + client_reader: asyncio.StreamReader, + client_writer: asyncio.StreamWriter, + request_method: str, + uri: str, + ident: dict[str, str], +) -> None: + logger = Logger().get_logger() host, port = get_host_and_port(uri) - logger = Logger().get_logger() try: - req_reader, req_writer = await asyncio.open_connection( - host, port, ssl=False - ) - client_writer.write(b"HTTP/1.1 200 Connection established\r\n") - client_writer.write(b"\r\n") - # HTTPS need to log here, as the connection may keep alive for long. - logger.info( - f"[{ident['id']}][{ident['client']}]: " - f"{request_method} {response_code} {uri}" + req_reader, req_writer = await asyncio.open_connection(host, port, ssl=False) + client_writer.write(b"HTTP/1.1 200 Connection established\r\n\r\n") + await asyncio.gather( + relay_stream(client_reader, req_writer, ident), + relay_stream(req_reader, client_writer, ident), ) - tasks = [ - relay_stream(client_reader, req_writer, ident), - relay_stream(req_reader, client_writer, ident), - ] - await asyncio.gather(*tasks) + logger.info(f"[{ident['id']}][{ident['client']}]: {request_method} 200 {uri}") except Exception as ex: - response_code = 502 - error_message = "%s: %s" % ( - ex.__class__.__name__, - " ".join([str(arg) for arg in ex.args]), - ) - - if error_message: logger.error( - f"[{ident['id']}][{ident['client']}]: " - f"{request_method} {response_code} {uri} ({error_message})" + f"[{ident['id']}][{ident['client']}]: {request_method} 502 {uri} ({ex.__class__.__name__}: {' '.join(map(str, ex.args))})" ) async def process_http( - client_writer, request_method, uri, http_version, headers, payload, ident -): - response_status = None - response_code = None - error_message = None - hostname = "127.0.0.1" # hostname (with optional port) e.g. example.com:80 + client_writer: asyncio.StreamWriter, + request_method: str, + uri: str, + http_version: str, + headers: list[str], + payload: bytes, + ident: dict[str, str], +) -> None: + logger = Logger().get_logger() + hostname = "127.0.0.1" request_headers = [] - request_headers_end_index = 0 - has_connection_header = False + has_connection = False for header in headers: - name_and_value = header.split(": ", 1) - - if len(name_and_value) == 2: - name, value = name_and_value - else: - name, value = name_and_value[0], None + if ": " in header: + name, value = header.split(": ", 1) + match name.lower(): + case "host": + hostname = value - if name.lower() == "host": - if value is not None: - hostname = value - elif name.lower() == "connection": - has_connection_header = True - if value.lower() in ("keep-alive", "persist"): - # current version of this program does not support - # the HTTP keep-alive feature - request_headers.append("Connection: close") - else: - request_headers.append(header) - elif name.lower() != "proxy-connection": - request_headers.append(header) - if len(header) == 0 and request_headers_end_index == 0: - request_headers_end_index = len(request_headers) - 1 + case "connection": + has_connection = True + request_headers.append( + "Connection: close" + if value.lower() in ("keep-alive", "persist") + else header + ) + case "proxy-connection": + continue + case _: + request_headers.append(header) - if request_headers_end_index == 0: - request_headers_end_index = len(request_headers) + if not has_connection: + request_headers.append("Connection: close") - if not has_connection_header: - request_headers.insert(request_headers_end_index, "Connection: close") + path = uri.removeprefix(f"http://{hostname}") + new_head = f"{request_method} {path} {http_version}" + host, port = get_host_and_port(hostname, "80") - path = uri[len(hostname) + 7 :] # 7 is len('http://') - new_head = " ".join([request_method, path, http_version]) - host, port = get_host_and_port(hostname, 80) try: + req_reader, req_writer = await asyncio.open_connection( host, port, flags=TCP_NODELAY ) - req_writer.write(f"{new_head}\r\n".encode()) - await req_writer.drain() - - req_writer.write(f"Host: {hostname}".encode()) - req_writer.write(b"\r\n") - - [ + req_writer.write(f"{new_head}\r\nHost: {hostname}\r\n".encode()) + for header in request_headers: req_writer.write(f"{header}\r\n".encode()) - for header in request_headers - ] req_writer.write(b"\r\n") - - if payload != b"": + if payload: req_writer.write(payload) - req_writer.write(b"\r\n") await req_writer.drain() - response_status = await relay_stream( - req_reader, client_writer, ident, True + response_status = await relay_stream(req_reader, client_writer, ident, True) + response_code = ( + int(response_status.decode("ascii").split(" ")[1]) + if response_status + else 502 + ) + logger.info( + f"[{ident['id']}][{ident['client']}]: {request_method} {response_code} {uri}" ) except Exception as ex: - response_code = 502 - error_message = "%s: %s" % ( - ex.__class__.__name__, - " ".join([str(arg) for arg in ex.args]), - ) - - if response_code is None: - response_code = int(response_status.decode("ascii").split(" ")[1]) - - logger = Logger().get_logger() - if error_message is None: - logger.info( - f"[{ident['id']}][{ident['client']}]: " - f"{request_method} {response_code} {uri}" - ) - else: logger.error( - f"[{ident['id']}][{ident['client']}]: " - f"{request_method} {response_code} {uri} ({error_message})" + f"[{ident['id']}][{ident['client']}]: {request_method} 502 {uri} ({ex.__class__.__name__}: {' '.join(map(str, ex.args))})" ) -async def process_request(client_reader, max_retry, ident): +async def process_request( + client_reader: asyncio.StreamReader, max_retry: int, ident: dict[str, str] +) -> tuple[str, list[str], bytes]: + logger = Logger().get_logger() - request_line = "" - headers = [] header = "" payload = b"" - try: - retry = 0 - while True: - line = await client_reader.readline() - if not line: - if len(header) == 0 and retry < max_retry: - # handle the case when the client make connection - # but sending data is delayed for some reasons - retry += 1 - await asyncio.sleep(0.1) - continue - else: - break - if line == b"\r\n": - break - if line != b"": - header += line.decode() + retry = 0 + + while True: + line = await client_reader.readline() + if not line: + if not header and retry < max_retry: + retry += 1 + await asyncio.sleep(0.1) - content_length = get_content_length(header) - while len(payload) < content_length: - payload += await client_reader.read(4096) - except Exception as ex: - name = ex.__class__.__name__ - args = " ".join([str(arg) for arg in ex.args]) - logger.debug( - f"[{ident['id']}][{ident['client']}]: " - f"!!! Task reject ({name}: {args})" - ) + continue + break + if line == b"\r\n": + + break + header += line.decode() - if header: - header_lines = header.split("\r\n") - if len(header_lines) > 1: - request_line = header_lines[0] - if len(header_lines) > 2: - headers = header_lines[1:-1] + content_length = get_content_length(header) + while len(payload) < content_length: + payload += await client_reader.read(4096) - return request_line, headers, payload + header_lines = header.split("\r\n") + return header_lines[0], header_lines[1:-1] if len(header_lines) > 2 else [], payload diff --git a/wormhole/license.py b/wormhole/license.py --- a/wormhole/license.py +++ b/wormhole/license.py @@ -3,7 +3,7 @@ LICENSE = """ The MIT License (MIT) -Copyright (c) 2016 cwt(at)bashell(dot)com +Copyright (c) 2016-2025 cwt(at)bashell(dot)com Copyright (c) 2013 devunt(at)gmail(dot)com Permission is hereby granted, free of charge, to any person obtaining a copy diff --git a/wormhole/logger.py b/wormhole/logger.py --- a/wormhole/logger.py +++ b/wormhole/logger.py @@ -1,64 +1,76 @@ import logging import logging.handlers import os + import socket +from datetime import datetime + class ContextFilter(logging.Filter): - hostname = socket.gethostname() - def filter(self, record): + hostname: str = socket.gethostname() + + def filter(self, record: logging.LogRecord) -> bool: record.hostname = self.hostname return True class Singleton(type): - _instances = {} + _instances: dict[type, object] = {} - def __call__(cls, *args, **kwargs): + def __call__(cls, *args: tuple, **kwargs: dict) -> object: if cls not in cls._instances: - cls._instances[cls] = super(Singleton, cls).__call__( - *args, **kwargs - ) + cls._instances[cls] = super().__call__(*args, **kwargs) return cls._instances[cls] class Logger(metaclass=Singleton): - def __init__(self, syslog_host=None, syslog_port=514, verbose=0): - unix_format = "%(asctime)s %(name)s[%(process)d]: %(message)s" - net_format = ( - "%(asctime)s %(hostname)s %(name)s[%(process)d]: %(message)s" - ) - date_format = "%b %d %H:%M:%S" - + def __init__( + self, syslog_host: str | None = None, syslog_port: int = 514, verbose: int = 0 + ) -> None: self.logger = logging.getLogger("wormhole") self.logger.setLevel(logging.INFO) logging.getLogger("asyncio").setLevel(logging.CRITICAL) + if verbose >= 1: self.logger.setLevel(logging.DEBUG) if verbose >= 2: logging.getLogger("asyncio").setLevel(logging.DEBUG) - # Add console handler + # Console handler console_handler = logging.StreamHandler() - console_formatter = logging.Formatter(unix_format, datefmt=date_format) - console_handler.setFormatter(console_formatter) + console_handler.setFormatter( + logging.Formatter( + "%(asctime)s %(name)s[%(process)d]: %(message)s", + datefmt="%b %d %H:%M:%S", + ) + ) self.logger.addHandler(console_handler) + # Syslog handler if syslog_host and syslog_host != "DISABLED": if syslog_host.startswith("/") and os.path.exists(syslog_host): - syslog = logging.handlers.SysLogHandler( - address=syslog_host, + syslog = logging.handlers.SysLogHandler(address=syslog_host) + syslog.setFormatter( + logging.Formatter( + "%(asctime)s %(name)s[%(process)d]: %(message)s", + datefmt="%b %d %H:%M:%S", + ) ) - formatter = logging.Formatter(unix_format, datefmt=date_format) else: self.logger.addFilter(ContextFilter()) syslog = logging.handlers.SysLogHandler( - address=(syslog_host, syslog_port), + address=(syslog_host, syslog_port) ) - formatter = logging.Formatter(net_format, datefmt=date_format) - syslog.setFormatter(formatter) + + syslog.setFormatter( + logging.Formatter( + "%(asctime)s %(hostname)s %(name)s[%(process)d]: %(message)s", + datefmt="%b %d %H:%M:%S", + ) + ) self.logger.addHandler(syslog) - def get_logger(self): + def get_logger(self) -> logging.Logger: return self.logger diff --git a/wormhole/proxy.py b/wormhole/proxy.py --- a/wormhole/proxy.py +++ b/wormhole/proxy.py @@ -1,54 +1,48 @@ #!/usr/bin/env python3 import sys - -if sys.version_info < (3, 6): - print("Error: You need python 3.6 or newer.") - exit(1) +import signal -import os -from pathlib import Path - -sys.path.insert(0, Path(os.path.realpath(__file__)).parent.as_posix()) +if sys.version_info < (3, 11): + print("Error: You need Python 3.11 or newer.") + sys.exit(1) import asyncio from argparse import ArgumentParser +from pathlib import Path + from license import LICENSE from logger import Logger from version import VERSION - -def start_server(host, port, authentication): - from server import start_wormhole_server +sys.path.insert(0, str(Path(__file__).parent.resolve())) - loop = asyncio.get_event_loop() - loop.run_until_complete(start_wormhole_server(host, port, authentication)) - loop.run_forever() +try: + if sys.platform in ("win32", "cygwin"): + import winloop as uvloop + else: + import uvloop +except ImportError: + uvloop = None -def check_uvloop(): - try: - import uvloop - except ImportError: - return False - else: - return True +async def start_server( + host: str, port: int, auth: str | None, shutdown_event: asyncio.Event +) -> None: + from server import start_wormhole_server + + server = await start_wormhole_server(host, port, auth) + await shutdown_event.wait() # รอสัญญาณ shutdown + server.close() + await server.wait_closed() -def main(): - """CLI frontend function. It takes command line options e.g. host, - port and provides `--help` message. - """ +def main() -> int: parser = ArgumentParser( - description=( - f"Wormhole({VERSION}): Asynchronous IO HTTP and HTTPS Proxy" - ) + description=f"Wormhole({VERSION}): Asynchronous IO HTTP and HTTPS Proxy" ) parser.add_argument( - "-H", - "--host", - default="0.0.0.0", - help="Host to listen [default: %(default)s]", + "-H", "--host", default="0.0.0.0", help="Host to listen [default: %(default)s]" ) parser.add_argument( "-p", @@ -57,14 +51,12 @@ default=8800, help="Port to listen [default: %(default)d]", ) + parser.add_argument( "-a", "--authentication", default="", - help=( - "File contains username and password list " - "for proxy authentication [default: no authentication]" - ), + help="File contains username and password list [default: no auth]", ) parser.add_argument( "-S", @@ -77,46 +69,46 @@ "--syslog-port", type=int, default=514, - help="Syslog Port to listen [default: %(default)d]", + help="Syslog Port [default: %(default)d]", ) parser.add_argument( - "-l", - "--license", - action="store_true", - default=False, - help="Print LICENSE and exit", + "-l", "--license", action="store_true", help="Print LICENSE and exit" ) parser.add_argument( "-v", "--verbose", action="count", default=0, help="Print verbose" ) args = parser.parse_args() + if args.license: print(parser.description) print(LICENSE) - exit() - if not (1 <= args.port <= 65535): + return 0 + + if not 1 <= args.port <= 65535: parser.error("port must be 1-65535") - logger = Logger( - args.syslog_host, args.syslog_port, args.verbose - ).get_logger() + logger = Logger(args.syslog_host, args.syslog_port, args.verbose).get_logger() if args.verbose: logger.debug( - f"[000000][{args.host}]: Using {'uvloop' if check_uvloop() else 'default event loop'}." + f"[000000][{args.host}]: Using {uvloop.__name__ if uvloop else 'default event loop'}" ) - if check_uvloop(): - import uvloop + loop = asyncio.get_event_loop() + shutdown_event = asyncio.Event() + + # Signal handler + def handle_shutdown(signum, frame): - asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) + print("\nbye") + shutdown_event.set() - start_server(args.host, args.port, args.authentication) + signal.signal(signal.SIGINT, handle_shutdown) + signal.signal(signal.SIGTERM, handle_shutdown) # จับ SIGTERM ด้วย (เช่นจาก Docker) + + runner = uvloop.run if uvloop else asyncio.run + runner(start_server(args.host, args.port, args.authentication, shutdown_event)) + return 0 if __name__ == "__main__": - try: - exit(main()) - except OSError: - pass - except KeyboardInterrupt: - print("\nbye") + sys.exit(main()) diff --git a/wormhole/server.py b/wormhole/server.py --- a/wormhole/server.py +++ b/wormhole/server.py @@ -3,31 +3,29 @@ import socket import sys from time import time -from authentication import get_ident -from authentication import verify -from handler import process_http -from handler import process_https -from handler import process_request +from authentication import get_ident, verify +from handler import process_http, process_https, process_request + from logger import Logger -MAX_RETRY = 3 + +MAX_RETRY: int = 3 + if sys.platform == "win32": import win32file - FREE_TASKS = asyncio.Semaphore( - int(0.9 * win32file._getmaxstdio())) + MAX_TASKS: int = int(0.9 * win32file._getmaxstdio()) else: import resource - FREE_TASKS = asyncio.Semaphore( - int(0.9 * resource.getrlimit(resource.RLIMIT_NOFILE)[0]) - ) -MAX_TASKS = FREE_TASKS._value - -clients = dict() + MAX_TASKS: int = int(0.9 * resource.getrlimit(resource.RLIMIT_NOFILE)[0]) -async def process_wormhole(client_reader, client_writer, auth): +async def process_wormhole( + client_reader: asyncio.StreamReader, + client_writer: asyncio.StreamWriter, + auth: str | None, +) -> None: logger = Logger().get_logger() ident = get_ident(client_reader, client_writer) @@ -35,25 +33,28 @@ client_reader, MAX_RETRY, ident ) if not request_line: + logger.debug( f"[{ident['id']}][{ident['client']}]: !!! Task reject (empty request)" ) return request_fields = request_line.split(" ") - if len(request_fields) == 2: - request_method, uri = request_fields - http_version = "HTTP/1.0" - elif len(request_fields) == 3: - request_method, uri, http_version = request_fields - else: - logger.debug( - f"[{ident['id']}][{ident['client']}]: !!! Task reject (invalid request)" - ) - return + match len(request_fields): + case 2: + request_method, uri = request_fields + http_version = "HTTP/1.0" + case 3: + request_method, uri, http_version = request_fields + case _: + logger.debug( + f"[{ident['id']}][{ident['client']}]: !!! Task reject (invalid request)" + ) + return if auth: user_ident = await verify(client_reader, client_writer, headers, auth) + if user_ident is None: logger.info( f"[{ident['id']}][{ident['client']}]: {request_method} 407 {uri}" @@ -61,61 +62,63 @@ return ident = user_ident - if request_method == "CONNECT": - async with FREE_TASKS: - logger.debug( - f"[{ident['id']}][{ident['client']}]: {FREE_TASKS._value}/{MAX_TASKS} Resource available" - ) - return await process_https( - client_reader, client_writer, request_method, uri, ident + async with asyncio.TaskGroup() as tg: + if request_method == "CONNECT": + tg.create_task( + process_https(client_reader, client_writer, request_method, uri, ident) ) - else: - async with FREE_TASKS: - logger.debug( - f"[{ident['id']}][{ident['client']}]: {FREE_TASKS._value}/{MAX_TASKS} Resource available" - ) - return await process_http( - client_writer, request_method, uri, http_version, headers, payload, ident, + else: + tg.create_task( + process_http( + client_writer, + request_method, + uri, + http_version, + headers, + payload, + ident, + ) ) -async def accept_client(client_reader, client_writer, auth): +async def accept_client( + client_reader: asyncio.StreamReader, + client_writer: asyncio.StreamWriter, + auth: str | None, +) -> None: logger = Logger().get_logger() + ident = get_ident(client_reader, client_writer) - task = asyncio.ensure_future( - process_wormhole(client_reader, client_writer, auth) - ) - global clients - clients[task] = (client_reader, client_writer) started_time = time() - def client_done(task): - del clients[task] + async def client_task() -> None: + await process_wormhole(client_reader, client_writer, auth) client_writer.close() logger.debug( f"[{ident['id']}][{ident['client']}]: Connection closed ({time() - started_time:.5f} seconds)" ) logger.debug(f"[{ident['id']}][{ident['client']}]: Connection started") - task.add_done_callback(client_done) + await client_task() -async def start_wormhole_server(host, port, auth): +async def start_wormhole_server( + host: str, port: int, auth: str | None +) -> asyncio.Server: logger = Logger().get_logger() try: accept = functools.partial(accept_client, auth=auth) - # Check if the host string contains an IPv6 address - is_ipv6 = ":" in host - if is_ipv6: - family = socket.AF_INET6 - else: - family = socket.AF_INET - server = await asyncio.start_server(accept, host, port, family=family) + family = socket.AF_INET6 if ":" in host else socket.AF_INET + server = await asyncio.start_server( + accept, host, port, family=family, limit=MAX_TASKS + ) except OSError as ex: logger.critical( f"[000000][{host}]: !!! Failed to bind server at [{host}:{port}]: {ex.args[1]}" ) + raise + else: logger.info(f"[000000][{host}]: wormhole bound at {host}:{port}") return server diff --git a/wormhole/tools.py b/wormhole/tools.py --- a/wormhole/tools.py +++ b/wormhole/tools.py @@ -1,26 +1,19 @@ import re - REGEX_HOST = re.compile(r"(.+?):([0-9]{1,5})") -REGEX_CONTENT_LENGTH = re.compile( - r"\r\nContent-Length: ([0-9]+)\r\n", re.IGNORECASE -) +REGEX_CONTENT_LENGTH = re.compile(r"\r\nContent-Length: ([0-9]+)\r\n", re.IGNORECASE) -def get_host_and_port(hostname, default_port=None): - match = REGEX_HOST.search(hostname) - if match: - host = match.group(1) - port = int(match.group(2)) - else: - host = hostname - port = int(default_port) - return host, port +def get_host_and_port( + hostname: str, default_port: str | None = None +) -> tuple[str, int]: + if match := REGEX_HOST.search(hostname): + return match.group(1), int(match.group(2)) + return hostname, int(default_port or "80") -def get_content_length(header): - match = REGEX_CONTENT_LENGTH.search(header) - if match: +def get_content_length(header: str) -> int: + if match := REGEX_CONTENT_LENGTH.search(header): return int(match.group(1)) return 0 diff --git a/wormhole/version.py b/wormhole/version.py --- a/wormhole/version.py +++ b/wormhole/version.py @@ -1,1 +1,1 @@ -VERSION = "v3.0.2" +VERSION = "v3.1.0"