M README.md +4 -3
@@ 10,8 10,9 @@
## Dependency
-- Python \>= 3.6
+- Python \>= 3.11
- [uvloop](https://github.com/MagicStack/uvloop) (optional)
+- [winloop](https://github.com/Vizonex/Winloop) (optional for Windows)
## Docker Image Usage
@@ 99,7 100,7 @@ 2. Set browser\'s proxy setting to
# License
MIT License (included in
-[license.py](https://bit.ly/wormhole-proxy-license))
+[license.py](https://hg.sr.ht/~cwt/wormhole/raw/wormhole/license.py))
# Notice
@@ 110,5 111,5 @@ MIT License (included in
- Wormhole may not work in:
- some ISPs
- some firewalls
- - some browers
+ - some browsers
- some web sites
R README.rst => +0 -124
@@ 1,124 0,0 @@
-Wormhole
-========
-
-**Wormhole** is a forward proxy without caching. You may use it for:
-
-- Modifying requests to look like they are originated from the IP address
- that *Wormhole* is running on.
-- Adding an authentication layer to the internet users in your organization.
-- Logging internet activities to your syslog server.
-
-Dependency
-----------
-
-- Python >= 3.6
-- `uvloop <https://github.com/MagicStack/uvloop>`_ (optional)
-
-Docker Image Usage
-------------------
-
-Run without authentication
-~~~~~~~~~~~~~~~~~~~~~~~~~~
-
-.. code:: shell
-
- $ docker pull bashell/wormhole
- $ docker run -d -p 8800:8800 bashell/wormhole
-
-Run with authentication
-~~~~~~~~~~~~~~~~~~~~~~~
-
-- Create an empty directory on your docker host
-- Create an authentication file that contains username and password in this
- format ``username:password``
-- Link that directory to the container via option ``-v`` and also run wormhole
- container with option ``-a /path/to/authentication_file``
-
-Example:
-
-.. code:: shell
-
- $ docker pull bashell/wormhole
- $ mkdir -p /path/to/dir
- $ echo "user1:password1" > /path/to/dir/wormhole.passwd
- $ docker run -d -v /path/to/dir:/opt/wormhole \
- -p 8800:8800 bashell/wormhole \
- -a /opt/wormhole/wormhole.passwd
-
-How to install
---------------
-
-Stable Version
-~~~~~~~~~~~~~~
-
-Please install the **stable version** using ``pip`` command:
-
-.. code:: shell
-
- $ pip install wormhole-proxy
-
-Development Snapshot
-~~~~~~~~~~~~~~~~~~~~
-
-You can install the **development snapshot** using ``pip`` with ``mercurial``:
-
-.. code:: shell
-
- $ pip install hg+https://hg.sr.ht/~cwt/wormhole
-
-Or install from your local clone:
-
-.. code:: shell
-
- $ hg clone https://hg.sr.ht/~cwt/wormhole
- $ cd wormhole/
- $ pip install -e .
-
-You can also install the latest ``tip`` snapshot using the following
-command:
-
-.. code:: shell
-
- $ pip install https://hg.sr.ht/~cwt/wormhole/archive/tip.tar.gz
-
-How to use
-----------
-
-#. Run **wormhole** command
-
- .. code:: shell
-
- $ wormhole
-
-#. Set browser's proxy setting to
-
- .. code:: shell
-
- host: 127.0.0.1
- port: 8800
-
-Command help
-------------
-
-.. code:: shell
-
- $ wormhole --help
-
-License
--------
-
-MIT License (included in `license.py <https://bit.ly/wormhole-proxy-license>`_)
-
-Notice
-------
-
-- This project is forked and converted to Mercurial from
- `WARP <https://github.com/devunt/warp>`_ on GitHub.
-- Authentication file contains ``username`` and ``password`` in **plain
- text**, keep it secret! *(I will try to encrypt/encode it soon.)*
-- Wormhole may not work in:
-
- - some ISPs
- - some firewalls
- - some browers
- - some web sites
M setup.py +13 -8
@@ 7,7 7,7 @@ from wormhole.version import VERSION
def readme():
- with open("README.rst", encoding="utf-8") as readme_file:
+ with open("README.md", encoding="utf-8") as readme_file:
return "\n" + readme_file.read()
@@ 18,8 18,9 @@ setup(
author_email="cwt@bashell.com",
url="https://hg.sr.ht/~cwt/wormhole",
license="MIT",
- description="Asynchronous I/O HTTP and HTTPS Proxy on Python >= 3.6",
+ description="Asynchronous I/O HTTP and HTTPS Proxy on Python >= 3.11",
long_description=readme(),
+ long_description_content_type="text/markdown",
keywords="wormhole asynchronous web proxy",
classifiers=[
"Development Status :: 5 - Production/Stable",
@@ 29,18 30,22 @@ setup(
"License :: OSI Approved :: MIT License",
"Operating System :: POSIX",
"Operating System :: Microsoft :: Windows",
- "Programming Language :: Python :: 3.6",
- "Programming Language :: Python :: 3.7",
- "Programming Language :: Python :: 3.8",
- "Programming Language :: Python :: 3.9",
- "Programming Language :: Python :: 3.10",
+ "Programming Language :: Python :: 3.11",
+ "Programming Language :: Python :: 3.12",
+ "Programming Language :: Python :: 3.13",
"Programming Language :: Python :: Implementation :: CPython",
"Topic :: Internet :: Proxy Servers",
],
+ setup_requires=["setuptools>=40.1.0"],
install_requires=[
'pywin32;platform_system=="Windows"',
- 'uvloop;platform_system=="Linux"',
],
+ extras_require={
+ "performance": [
+ 'winloop;platform_system=="Windows"',
+ 'uvloop;platform_system!="Windows"',
+ ],
+ },
packages=["wormhole"],
include_package_data=True,
entry_points={"console_scripts": ["wormhole = wormhole.proxy:main"]},
M wormhole/authentication.py +34 -28
@@ 1,49 1,55 @@
from base64 import decodebytes
+import asyncio
-def get_ident(client_reader, client_writer, user=None):
+def get_ident(
+ client_reader: asyncio.StreamReader,
+ client_writer: asyncio.StreamWriter,
+ user: str | None = None,
+) -> dict[str, str]:
client = client_writer.get_extra_info("peername")[0]
if user:
client = f"{user}@{client}"
return {"id": hex(id(client_reader))[-6:], "client": client}
-auth_list = list()
+auth_list: list[str] = []
-def get_auth_list(auth):
+def get_auth_list(auth: str) -> list[str]:
global auth_list
if not auth_list:
- auth_list = [
- line.strip()
- for line in open(auth, "r")
- if line.strip() and not line.strip().startswith("#")
- ]
+ with open(auth, "r") as f:
+
+ auth_list = [
+ line.strip() for line in f if line.strip() and not line.startswith("#")
+ ]
return auth_list
-def deny(client_writer):
- [
+def deny(client_writer: asyncio.StreamWriter) -> None:
+ messages = (
+ b"HTTP/1.1 407 Proxy Authentication Required\r\n",
+ b'Proxy-Authenticate: Basic realm="Wormhole Proxy"\r\n',
+ b"\r\n",
+ )
+ for message in messages:
client_writer.write(message)
- for message in (
- b"HTTP/1.1 407 Proxy Authentication Required\r\n",
- b'Proxy-Authenticate: Basic realm="Wormhole Proxy"\r\n',
- b"\r\n",
- )
- ]
-async def verify(client_reader, client_writer, headers, auth):
- proxy_auth = [
- header
- for header in headers
- if header.lower().startswith("proxy-authorization:")
- ]
+async def verify(
+ client_reader: asyncio.StreamReader,
+ client_writer: asyncio.StreamWriter,
+ headers: list[str],
+ auth: str,
+) -> dict[str, str] | None:
+ proxy_auth = [h for h in headers if h.lower().startswith("proxy-authorization:")]
if proxy_auth:
- user_password = decodebytes(
- proxy_auth[0].split(" ")[2].encode("ascii")
- ).decode("ascii")
+
+ user_password = decodebytes(proxy_auth[0].split(" ")[2].encode("ascii")).decode(
+ "ascii"
+ )
if user_password in get_auth_list(auth):
- user = user_password.split(":")[0]
- return get_ident(client_reader, client_writer, user)
- return deny(client_writer)
+ return get_ident(client_reader, client_writer, user_password.split(":")[0])
+ deny(client_writer)
+ return None
M wormhole/handler.py +95 -142
@@ 1,203 1,156 @@
import asyncio
from socket import TCP_NODELAY
from logger import Logger
-from tools import get_content_length
-from tools import get_host_and_port
+from tools import get_content_length, get_host_and_port
async def relay_stream(
- stream_reader, stream_writer, ident, return_first_line=False
-):
+ stream_reader: asyncio.StreamReader,
+ stream_writer: asyncio.StreamWriter,
+ ident: dict[str, str],
+ return_first_line: bool = False,
+) -> bytes | None:
logger = Logger().get_logger()
+
first_line = None
while True:
try:
line = await stream_reader.read(4096)
- if len(line) == 0:
+ if not line:
break
stream_writer.write(line)
except Exception as ex:
- error_message = "%s: %s" % (
- ex.__class__.__name__,
- " ".join([str(arg) for arg in ex.args]),
+
+ logger.debug(
+ f"[{ident['id']}][{ident['client']}]: {ex.__class__.__name__}: {' '.join(map(str, ex.args))}"
)
- logger.debug(f"[{ident['id']}][{ident['client']}]: {error_message}")
break
+
else:
if return_first_line and first_line is None:
first_line = line[: line.find(b"\r\n")]
+
return first_line
async def process_https(
- client_reader, client_writer, request_method, uri, ident
-):
- response_code = 200
- error_message = None
+ client_reader: asyncio.StreamReader,
+ client_writer: asyncio.StreamWriter,
+ request_method: str,
+ uri: str,
+ ident: dict[str, str],
+) -> None:
+ logger = Logger().get_logger()
host, port = get_host_and_port(uri)
- logger = Logger().get_logger()
try:
- req_reader, req_writer = await asyncio.open_connection(
- host, port, ssl=False
- )
- client_writer.write(b"HTTP/1.1 200 Connection established\r\n")
- client_writer.write(b"\r\n")
- # HTTPS need to log here, as the connection may keep alive for long.
- logger.info(
- f"[{ident['id']}][{ident['client']}]: "
- f"{request_method} {response_code} {uri}"
+ req_reader, req_writer = await asyncio.open_connection(host, port, ssl=False)
+ client_writer.write(b"HTTP/1.1 200 Connection established\r\n\r\n")
+ await asyncio.gather(
+ relay_stream(client_reader, req_writer, ident),
+ relay_stream(req_reader, client_writer, ident),
)
- tasks = [
- relay_stream(client_reader, req_writer, ident),
- relay_stream(req_reader, client_writer, ident),
- ]
- await asyncio.gather(*tasks)
+ logger.info(f"[{ident['id']}][{ident['client']}]: {request_method} 200 {uri}")
except Exception as ex:
- response_code = 502
- error_message = "%s: %s" % (
- ex.__class__.__name__,
- " ".join([str(arg) for arg in ex.args]),
- )
-
- if error_message:
logger.error(
- f"[{ident['id']}][{ident['client']}]: "
- f"{request_method} {response_code} {uri} ({error_message})"
+ f"[{ident['id']}][{ident['client']}]: {request_method} 502 {uri} ({ex.__class__.__name__}: {' '.join(map(str, ex.args))})"
)
async def process_http(
- client_writer, request_method, uri, http_version, headers, payload, ident
-):
- response_status = None
- response_code = None
- error_message = None
- hostname = "127.0.0.1" # hostname (with optional port) e.g. example.com:80
+ client_writer: asyncio.StreamWriter,
+ request_method: str,
+ uri: str,
+ http_version: str,
+ headers: list[str],
+ payload: bytes,
+ ident: dict[str, str],
+) -> None:
+ logger = Logger().get_logger()
+ hostname = "127.0.0.1"
request_headers = []
- request_headers_end_index = 0
- has_connection_header = False
+ has_connection = False
for header in headers:
- name_and_value = header.split(": ", 1)
-
- if len(name_and_value) == 2:
- name, value = name_and_value
- else:
- name, value = name_and_value[0], None
+ if ": " in header:
+ name, value = header.split(": ", 1)
+ match name.lower():
+ case "host":
+ hostname = value
- if name.lower() == "host":
- if value is not None:
- hostname = value
- elif name.lower() == "connection":
- has_connection_header = True
- if value.lower() in ("keep-alive", "persist"):
- # current version of this program does not support
- # the HTTP keep-alive feature
- request_headers.append("Connection: close")
- else:
- request_headers.append(header)
- elif name.lower() != "proxy-connection":
- request_headers.append(header)
- if len(header) == 0 and request_headers_end_index == 0:
- request_headers_end_index = len(request_headers) - 1
+ case "connection":
+ has_connection = True
+ request_headers.append(
+ "Connection: close"
+ if value.lower() in ("keep-alive", "persist")
+ else header
+ )
+ case "proxy-connection":
+ continue
+ case _:
+ request_headers.append(header)
- if request_headers_end_index == 0:
- request_headers_end_index = len(request_headers)
+ if not has_connection:
+ request_headers.append("Connection: close")
- if not has_connection_header:
- request_headers.insert(request_headers_end_index, "Connection: close")
+ path = uri.removeprefix(f"http://{hostname}")
+ new_head = f"{request_method} {path} {http_version}"
+ host, port = get_host_and_port(hostname, "80")
- path = uri[len(hostname) + 7 :] # 7 is len('http://')
- new_head = " ".join([request_method, path, http_version])
- host, port = get_host_and_port(hostname, 80)
try:
+
req_reader, req_writer = await asyncio.open_connection(
host, port, flags=TCP_NODELAY
)
- req_writer.write(f"{new_head}\r\n".encode())
- await req_writer.drain()
-
- req_writer.write(f"Host: {hostname}".encode())
- req_writer.write(b"\r\n")
-
- [
+ req_writer.write(f"{new_head}\r\nHost: {hostname}\r\n".encode())
+ for header in request_headers:
req_writer.write(f"{header}\r\n".encode())
- for header in request_headers
- ]
req_writer.write(b"\r\n")
-
- if payload != b"":
+ if payload:
req_writer.write(payload)
- req_writer.write(b"\r\n")
await req_writer.drain()
- response_status = await relay_stream(
- req_reader, client_writer, ident, True
+ response_status = await relay_stream(req_reader, client_writer, ident, True)
+ response_code = (
+ int(response_status.decode("ascii").split(" ")[1])
+ if response_status
+ else 502
+ )
+ logger.info(
+ f"[{ident['id']}][{ident['client']}]: {request_method} {response_code} {uri}"
)
except Exception as ex:
- response_code = 502
- error_message = "%s: %s" % (
- ex.__class__.__name__,
- " ".join([str(arg) for arg in ex.args]),
- )
-
- if response_code is None:
- response_code = int(response_status.decode("ascii").split(" ")[1])
-
- logger = Logger().get_logger()
- if error_message is None:
- logger.info(
- f"[{ident['id']}][{ident['client']}]: "
- f"{request_method} {response_code} {uri}"
- )
- else:
logger.error(
- f"[{ident['id']}][{ident['client']}]: "
- f"{request_method} {response_code} {uri} ({error_message})"
+ f"[{ident['id']}][{ident['client']}]: {request_method} 502 {uri} ({ex.__class__.__name__}: {' '.join(map(str, ex.args))})"
)
-async def process_request(client_reader, max_retry, ident):
+async def process_request(
+ client_reader: asyncio.StreamReader, max_retry: int, ident: dict[str, str]
+) -> tuple[str, list[str], bytes]:
+
logger = Logger().get_logger()
- request_line = ""
- headers = []
header = ""
payload = b""
- try:
- retry = 0
- while True:
- line = await client_reader.readline()
- if not line:
- if len(header) == 0 and retry < max_retry:
- # handle the case when the client make connection
- # but sending data is delayed for some reasons
- retry += 1
- await asyncio.sleep(0.1)
- continue
- else:
- break
- if line == b"\r\n":
- break
- if line != b"":
- header += line.decode()
+ retry = 0
+
+ while True:
+ line = await client_reader.readline()
+ if not line:
+ if not header and retry < max_retry:
+ retry += 1
+ await asyncio.sleep(0.1)
- content_length = get_content_length(header)
- while len(payload) < content_length:
- payload += await client_reader.read(4096)
- except Exception as ex:
- name = ex.__class__.__name__
- args = " ".join([str(arg) for arg in ex.args])
- logger.debug(
- f"[{ident['id']}][{ident['client']}]: "
- f"!!! Task reject ({name}: {args})"
- )
+ continue
+ break
+ if line == b"\r\n":
+
+ break
+ header += line.decode()
- if header:
- header_lines = header.split("\r\n")
- if len(header_lines) > 1:
- request_line = header_lines[0]
- if len(header_lines) > 2:
- headers = header_lines[1:-1]
+ content_length = get_content_length(header)
+ while len(payload) < content_length:
+ payload += await client_reader.read(4096)
- return request_line, headers, payload
+ header_lines = header.split("\r\n")
+ return header_lines[0], header_lines[1:-1] if len(header_lines) > 2 else [], payload
M wormhole/license.py +1 -1
@@ 3,7 3,7 @@
LICENSE = """
The MIT License (MIT)
-Copyright (c) 2016 cwt(at)bashell(dot)com
+Copyright (c) 2016-2025 cwt(at)bashell(dot)com
Copyright (c) 2013 devunt(at)gmail(dot)com
Permission is hereby granted, free of charge, to any person obtaining a copy
M wormhole/logger.py +36 -24
@@ 1,64 1,76 @@
import logging
import logging.handlers
import os
+
import socket
+from datetime import datetime
+
class ContextFilter(logging.Filter):
- hostname = socket.gethostname()
- def filter(self, record):
+ hostname: str = socket.gethostname()
+
+ def filter(self, record: logging.LogRecord) -> bool:
record.hostname = self.hostname
return True
class Singleton(type):
- _instances = {}
+ _instances: dict[type, object] = {}
- def __call__(cls, *args, **kwargs):
+ def __call__(cls, *args: tuple, **kwargs: dict) -> object:
if cls not in cls._instances:
- cls._instances[cls] = super(Singleton, cls).__call__(
- *args, **kwargs
- )
+ cls._instances[cls] = super().__call__(*args, **kwargs)
return cls._instances[cls]
class Logger(metaclass=Singleton):
- def __init__(self, syslog_host=None, syslog_port=514, verbose=0):
- unix_format = "%(asctime)s %(name)s[%(process)d]: %(message)s"
- net_format = (
- "%(asctime)s %(hostname)s %(name)s[%(process)d]: %(message)s"
- )
- date_format = "%b %d %H:%M:%S"
-
+ def __init__(
+ self, syslog_host: str | None = None, syslog_port: int = 514, verbose: int = 0
+ ) -> None:
self.logger = logging.getLogger("wormhole")
self.logger.setLevel(logging.INFO)
logging.getLogger("asyncio").setLevel(logging.CRITICAL)
+
if verbose >= 1:
self.logger.setLevel(logging.DEBUG)
if verbose >= 2:
logging.getLogger("asyncio").setLevel(logging.DEBUG)
- # Add console handler
+ # Console handler
console_handler = logging.StreamHandler()
- console_formatter = logging.Formatter(unix_format, datefmt=date_format)
- console_handler.setFormatter(console_formatter)
+ console_handler.setFormatter(
+ logging.Formatter(
+ "%(asctime)s %(name)s[%(process)d]: %(message)s",
+ datefmt="%b %d %H:%M:%S",
+ )
+ )
self.logger.addHandler(console_handler)
+ # Syslog handler
if syslog_host and syslog_host != "DISABLED":
if syslog_host.startswith("/") and os.path.exists(syslog_host):
- syslog = logging.handlers.SysLogHandler(
- address=syslog_host,
+ syslog = logging.handlers.SysLogHandler(address=syslog_host)
+ syslog.setFormatter(
+ logging.Formatter(
+ "%(asctime)s %(name)s[%(process)d]: %(message)s",
+ datefmt="%b %d %H:%M:%S",
+ )
)
- formatter = logging.Formatter(unix_format, datefmt=date_format)
else:
self.logger.addFilter(ContextFilter())
syslog = logging.handlers.SysLogHandler(
- address=(syslog_host, syslog_port),
+ address=(syslog_host, syslog_port)
)
- formatter = logging.Formatter(net_format, datefmt=date_format)
- syslog.setFormatter(formatter)
+
+ syslog.setFormatter(
+ logging.Formatter(
+ "%(asctime)s %(hostname)s %(name)s[%(process)d]: %(message)s",
+ datefmt="%b %d %H:%M:%S",
+ )
+ )
self.logger.addHandler(syslog)
- def get_logger(self):
+ def get_logger(self) -> logging.Logger:
return self.logger
M wormhole/proxy.py +50 -58
@@ 1,54 1,48 @@
#!/usr/bin/env python3
import sys
-
-if sys.version_info < (3, 6):
- print("Error: You need python 3.6 or newer.")
- exit(1)
+import signal
-import os
-from pathlib import Path
-
-sys.path.insert(0, Path(os.path.realpath(__file__)).parent.as_posix())
+if sys.version_info < (3, 11):
+ print("Error: You need Python 3.11 or newer.")
+ sys.exit(1)
import asyncio
from argparse import ArgumentParser
+from pathlib import Path
+
from license import LICENSE
from logger import Logger
from version import VERSION
-
-def start_server(host, port, authentication):
- from server import start_wormhole_server
+sys.path.insert(0, str(Path(__file__).parent.resolve()))
- loop = asyncio.get_event_loop()
- loop.run_until_complete(start_wormhole_server(host, port, authentication))
- loop.run_forever()
+try:
+ if sys.platform in ("win32", "cygwin"):
+ import winloop as uvloop
+ else:
+ import uvloop
+except ImportError:
+ uvloop = None
-def check_uvloop():
- try:
- import uvloop
- except ImportError:
- return False
- else:
- return True
+async def start_server(
+ host: str, port: int, auth: str | None, shutdown_event: asyncio.Event
+) -> None:
+ from server import start_wormhole_server
+
+ server = await start_wormhole_server(host, port, auth)
+ await shutdown_event.wait() # รอสัญญาณ shutdown
+ server.close()
+ await server.wait_closed()
-def main():
- """CLI frontend function. It takes command line options e.g. host,
- port and provides `--help` message.
- """
+def main() -> int:
parser = ArgumentParser(
- description=(
- f"Wormhole({VERSION}): Asynchronous IO HTTP and HTTPS Proxy"
- )
+ description=f"Wormhole({VERSION}): Asynchronous IO HTTP and HTTPS Proxy"
)
parser.add_argument(
- "-H",
- "--host",
- default="0.0.0.0",
- help="Host to listen [default: %(default)s]",
+ "-H", "--host", default="0.0.0.0", help="Host to listen [default: %(default)s]"
)
parser.add_argument(
"-p",
@@ 57,14 51,12 @@ def main():
default=8800,
help="Port to listen [default: %(default)d]",
)
+
parser.add_argument(
"-a",
"--authentication",
default="",
- help=(
- "File contains username and password list "
- "for proxy authentication [default: no authentication]"
- ),
+ help="File contains username and password list [default: no auth]",
)
parser.add_argument(
"-S",
@@ 77,46 69,46 @@ def main():
"--syslog-port",
type=int,
default=514,
- help="Syslog Port to listen [default: %(default)d]",
+ help="Syslog Port [default: %(default)d]",
)
parser.add_argument(
- "-l",
- "--license",
- action="store_true",
- default=False,
- help="Print LICENSE and exit",
+ "-l", "--license", action="store_true", help="Print LICENSE and exit"
)
parser.add_argument(
"-v", "--verbose", action="count", default=0, help="Print verbose"
)
args = parser.parse_args()
+
if args.license:
print(parser.description)
print(LICENSE)
- exit()
- if not (1 <= args.port <= 65535):
+ return 0
+
+ if not 1 <= args.port <= 65535:
parser.error("port must be 1-65535")
- logger = Logger(
- args.syslog_host, args.syslog_port, args.verbose
- ).get_logger()
+ logger = Logger(args.syslog_host, args.syslog_port, args.verbose).get_logger()
if args.verbose:
logger.debug(
- f"[000000][{args.host}]: Using {'uvloop' if check_uvloop() else 'default event loop'}."
+ f"[000000][{args.host}]: Using {uvloop.__name__ if uvloop else 'default event loop'}"
)
- if check_uvloop():
- import uvloop
+ loop = asyncio.get_event_loop()
+ shutdown_event = asyncio.Event()
+
+ # Signal handler
+ def handle_shutdown(signum, frame):
- asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
+ print("\nbye")
+ shutdown_event.set()
- start_server(args.host, args.port, args.authentication)
+ signal.signal(signal.SIGINT, handle_shutdown)
+ signal.signal(signal.SIGTERM, handle_shutdown) # จับ SIGTERM ด้วย (เช่นจาก Docker)
+
+ runner = uvloop.run if uvloop else asyncio.run
+ runner(start_server(args.host, args.port, args.authentication, shutdown_event))
+ return 0
if __name__ == "__main__":
- try:
- exit(main())
- except OSError:
- pass
- except KeyboardInterrupt:
- print("\nbye")
+ sys.exit(main())
M wormhole/server.py +59 -56
@@ 3,31 3,29 @@ import functools
import socket
import sys
from time import time
-from authentication import get_ident
-from authentication import verify
-from handler import process_http
-from handler import process_https
-from handler import process_request
+from authentication import get_ident, verify
+from handler import process_http, process_https, process_request
+
from logger import Logger
-MAX_RETRY = 3
+
+MAX_RETRY: int = 3
+
if sys.platform == "win32":
import win32file
- FREE_TASKS = asyncio.Semaphore(
- int(0.9 * win32file._getmaxstdio()))
+ MAX_TASKS: int = int(0.9 * win32file._getmaxstdio())
else:
import resource
- FREE_TASKS = asyncio.Semaphore(
- int(0.9 * resource.getrlimit(resource.RLIMIT_NOFILE)[0])
- )
-MAX_TASKS = FREE_TASKS._value
-
-clients = dict()
+ MAX_TASKS: int = int(0.9 * resource.getrlimit(resource.RLIMIT_NOFILE)[0])
-async def process_wormhole(client_reader, client_writer, auth):
+async def process_wormhole(
+ client_reader: asyncio.StreamReader,
+ client_writer: asyncio.StreamWriter,
+ auth: str | None,
+) -> None:
logger = Logger().get_logger()
ident = get_ident(client_reader, client_writer)
@@ 35,25 33,28 @@ async def process_wormhole(client_reader
client_reader, MAX_RETRY, ident
)
if not request_line:
+
logger.debug(
f"[{ident['id']}][{ident['client']}]: !!! Task reject (empty request)"
)
return
request_fields = request_line.split(" ")
- if len(request_fields) == 2:
- request_method, uri = request_fields
- http_version = "HTTP/1.0"
- elif len(request_fields) == 3:
- request_method, uri, http_version = request_fields
- else:
- logger.debug(
- f"[{ident['id']}][{ident['client']}]: !!! Task reject (invalid request)"
- )
- return
+ match len(request_fields):
+ case 2:
+ request_method, uri = request_fields
+ http_version = "HTTP/1.0"
+ case 3:
+ request_method, uri, http_version = request_fields
+ case _:
+ logger.debug(
+ f"[{ident['id']}][{ident['client']}]: !!! Task reject (invalid request)"
+ )
+ return
if auth:
user_ident = await verify(client_reader, client_writer, headers, auth)
+
if user_ident is None:
logger.info(
f"[{ident['id']}][{ident['client']}]: {request_method} 407 {uri}"
@@ 61,61 62,63 @@ async def process_wormhole(client_reader
return
ident = user_ident
- if request_method == "CONNECT":
- async with FREE_TASKS:
- logger.debug(
- f"[{ident['id']}][{ident['client']}]: {FREE_TASKS._value}/{MAX_TASKS} Resource available"
- )
- return await process_https(
- client_reader, client_writer, request_method, uri, ident
+ async with asyncio.TaskGroup() as tg:
+ if request_method == "CONNECT":
+ tg.create_task(
+ process_https(client_reader, client_writer, request_method, uri, ident)
)
- else:
- async with FREE_TASKS:
- logger.debug(
- f"[{ident['id']}][{ident['client']}]: {FREE_TASKS._value}/{MAX_TASKS} Resource available"
- )
- return await process_http(
- client_writer, request_method, uri, http_version, headers, payload, ident,
+ else:
+ tg.create_task(
+ process_http(
+ client_writer,
+ request_method,
+ uri,
+ http_version,
+ headers,
+ payload,
+ ident,
+ )
)
-async def accept_client(client_reader, client_writer, auth):
+async def accept_client(
+ client_reader: asyncio.StreamReader,
+ client_writer: asyncio.StreamWriter,
+ auth: str | None,
+) -> None:
logger = Logger().get_logger()
+
ident = get_ident(client_reader, client_writer)
- task = asyncio.ensure_future(
- process_wormhole(client_reader, client_writer, auth)
- )
- global clients
- clients[task] = (client_reader, client_writer)
started_time = time()
- def client_done(task):
- del clients[task]
+ async def client_task() -> None:
+ await process_wormhole(client_reader, client_writer, auth)
client_writer.close()
logger.debug(
f"[{ident['id']}][{ident['client']}]: Connection closed ({time() - started_time:.5f} seconds)"
)
logger.debug(f"[{ident['id']}][{ident['client']}]: Connection started")
- task.add_done_callback(client_done)
+ await client_task()
-async def start_wormhole_server(host, port, auth):
+async def start_wormhole_server(
+ host: str, port: int, auth: str | None
+) -> asyncio.Server:
logger = Logger().get_logger()
try:
accept = functools.partial(accept_client, auth=auth)
- # Check if the host string contains an IPv6 address
- is_ipv6 = ":" in host
- if is_ipv6:
- family = socket.AF_INET6
- else:
- family = socket.AF_INET
- server = await asyncio.start_server(accept, host, port, family=family)
+ family = socket.AF_INET6 if ":" in host else socket.AF_INET
+ server = await asyncio.start_server(
+ accept, host, port, family=family, limit=MAX_TASKS
+ )
except OSError as ex:
logger.critical(
f"[000000][{host}]: !!! Failed to bind server at [{host}:{port}]: {ex.args[1]}"
)
+
raise
+
else:
logger.info(f"[000000][{host}]: wormhole bound at {host}:{port}")
return server
M wormhole/tools.py +9 -16
@@ 1,26 1,19 @@
import re
-
REGEX_HOST = re.compile(r"(.+?):([0-9]{1,5})")
-REGEX_CONTENT_LENGTH = re.compile(
- r"\r\nContent-Length: ([0-9]+)\r\n", re.IGNORECASE
-)
+REGEX_CONTENT_LENGTH = re.compile(r"\r\nContent-Length: ([0-9]+)\r\n", re.IGNORECASE)
-def get_host_and_port(hostname, default_port=None):
- match = REGEX_HOST.search(hostname)
- if match:
- host = match.group(1)
- port = int(match.group(2))
- else:
- host = hostname
- port = int(default_port)
- return host, port
+def get_host_and_port(
+ hostname: str, default_port: str | None = None
+) -> tuple[str, int]:
+ if match := REGEX_HOST.search(hostname):
+ return match.group(1), int(match.group(2))
+ return hostname, int(default_port or "80")
-def get_content_length(header):
- match = REGEX_CONTENT_LENGTH.search(header)
- if match:
+def get_content_length(header: str) -> int:
+ if match := REGEX_CONTENT_LENGTH.search(header):
return int(match.group(1))
return 0
M wormhole/version.py +1 -1
@@ 1,1 1,1 @@
-VERSION = "v3.0.2"
+VERSION = "v3.1.0"