@@ -27,7 +27,7 @@ def get_auth_list(auth: str) -> list[str
return auth_list
-def deny(client_writer: asyncio.StreamWriter) -> None:
+async def deny(client_writer: asyncio.StreamWriter) -> None:
messages = (
b"HTTP/1.1 407 Proxy Authentication Required\r\n",
b'Proxy-Authenticate: Basic realm="Wormhole Proxy"\r\n',
@@ -35,6 +35,7 @@ def deny(client_writer: asyncio.StreamWr
)
for message in messages:
client_writer.write(message)
+ await client_writer.drain()
async def verify(
@@ -51,5 +52,5 @@ async def verify(
)
if user_password in get_auth_list(auth):
return get_ident(client_reader, client_writer, user_password.split(":")[0])
- deny(client_writer)
+ await deny(client_writer)
return None
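Making `deny` a coroutine is the right call here: `StreamWriter.write()` only buffers, and `drain()` is the coroutine that actually waits for the transport to flush. A minimal standalone sketch of the pattern (not part of the patch; the handler name is hypothetical):

```python
import asyncio

async def reject(writer: asyncio.StreamWriter) -> None:
    # write() only appends to the transport buffer; drain() is the coroutine
    # that waits for it to flush. This is why deny() must be async and call
    # sites need `await deny(...)`; a bare deny(...) call would just create
    # a coroutine object and send nothing.
    writer.write(b"HTTP/1.1 407 Proxy Authentication Required\r\n\r\n")
    await writer.drain()
    writer.close()
    await writer.wait_closed()
```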
@@ -10,23 +10,29 @@ async def relay_stream(
ident: dict[str, str],
return_first_line: bool = False,
) -> bytes | None:
+ # Initialize logger for relay diagnostics
logger = Logger().get_logger()
- first_line = None
+ first_line = None
+ # Pump chunks from reader to writer until EOF or error
while True:
try:
+ # Read the next chunk (up to 4 KiB)
line = await stream_reader.read(4096)
if not line:
break
+
+ # Write the chunk, then apply backpressure
stream_writer.write(line)
+ await stream_writer.drain() # Suspend until the write buffer is flushed
except Exception as ex:
-
+ # Log any exceptions during relay
logger.debug(
f"[{ident['id']}][{ident['client']}]: {ex.__class__.__name__}: {' '.join(map(str, ex.args))}"
)
break
-
else:
+ # Capture first line if requested
if return_first_line and first_line is None:
first_line = line[: line.find(b"\r\n")]
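A side note on the unchanged slicing above (an observation, not something this diff introduces): `bytes.find` returns -1 when the separator is absent, so `line[:-1]` would silently drop the last byte of a chunk that happens to contain no `\r\n`. A defensive sketch:

```python
# Hedged sketch: only slice when b"\r\n" actually occurs in the chunk.
idx = line.find(b"\r\n")
first_line = line[:idx] if idx != -1 else line
```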
@@ -40,18 +46,28 @@ async def process_https(
uri: str,
ident: dict[str, str],
) -> None:
+ # Initialize logger for tracking request status
logger = Logger().get_logger()
host, port = get_host_and_port(uri)
+
try:
+ # Open a plain TCP connection; TLS stays end-to-end between client and server
req_reader, req_writer = await asyncio.open_connection(host, port, ssl=False)
+
+ # Send success response to client
client_writer.write(b"HTTP/1.1 200 Connection established\r\n\r\n")
+ await client_writer.drain() # Flush the 200 response before relaying begins
+
+ # Relay data bidirectionally between client and target server
await asyncio.gather(
relay_stream(client_reader, req_writer, ident),
relay_stream(req_reader, client_writer, ident),
)
+ # Log successful tunneling
logger.info(f"[{ident['id']}][{ident['client']}]: {request_method} 200 {uri}")
except Exception as ex:
+ # Log errors during HTTPS processing
logger.error(
f"[{ident['id']}][{ident['client']}]: {request_method} 502 {uri} ({ex.__class__.__name__}: {' '.join(map(str, ex.args))})"
)
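To exercise the CONNECT path end to end, a quick probe can confirm that the 200 line is actually flushed to the client. A sketch, assuming a local proxy on 127.0.0.1:8800 with no proxy auth configured (both are assumptions, not values from the patch):

```python
import asyncio

async def connect_probe(host: str = "127.0.0.1", port: int = 8800) -> None:
    # Hypothetical local proxy endpoint; adjust to your deployment.
    reader, writer = await asyncio.open_connection(host, port)
    writer.write(b"CONNECT example.com:443 HTTP/1.1\r\nHost: example.com:443\r\n\r\n")
    await writer.drain()
    status = await reader.readline()  # expect b"HTTP/1.1 200 Connection established\r\n"
    print(status.decode("ascii", "replace").rstrip())
    writer.close()
    await writer.wait_closed()

asyncio.run(connect_probe())
```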
@@ -66,18 +82,19 @@ async def process_http(
payload: bytes,
ident: dict[str, str],
) -> None:
+ # Initialize logger for request tracking
logger = Logger().get_logger()
hostname = "127.0.0.1"
request_headers = []
has_connection = False
+ # Extract Host and Connection while collecting the remaining headers
for header in headers:
if ": " in header:
name, value = header.split(": ", 1)
match name.lower():
case "host":
hostname = value
-
case "connection":
has_connection = True
request_headers.append(
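The match/case dispatch keeps the header scan to a single pass. A compact standalone sketch of the same pattern (function and variable names are hypothetical, not part of the patch):

```python
def scan_headers(headers: list[str]) -> tuple[str, bool, list[str]]:
    # Single pass: pull out Host, flag Connection, keep the rest verbatim.
    hostname, has_connection, kept = "127.0.0.1", False, []
    for header in headers:
        if ": " not in header:
            continue
        name, value = header.split(": ", 1)
        match name.lower():
            case "host":
                hostname = value
            case "connection":
                has_connection = True
            case _:
                kept.append(header)
    return hostname, has_connection, kept
```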
@@ -93,33 +110,40 @@ async def process_http(
if not has_connection:
request_headers.append("Connection: close")
+ # Rebuild the request line with an origin-form path
path = uri.removeprefix(f"http://{hostname}")
new_head = f"{request_method} {path} {http_version}"
host, port = get_host_and_port(hostname, "80")
try:
-
+ # Open connection to the target server (note: flags is forwarded to getaddrinfo)
req_reader, req_writer = await asyncio.open_connection(
host, port, flags=TCP_NODELAY
)
+
+ # Forward the rewritten request line, headers, and body
req_writer.write(f"{new_head}\r\nHost: {hostname}\r\n".encode())
for header in request_headers:
req_writer.write(f"{header}\r\n".encode())
req_writer.write(b"\r\n")
if payload:
req_writer.write(payload)
- await req_writer.drain()
+ await req_writer.drain() # Flush the full request before reading the response
+ # Relay response and get status code
response_status = await relay_stream(req_reader, client_writer, ident, True)
response_code = (
int(response_status.decode("ascii").split(" ")[1])
if response_status
else 502
)
+
+ # Log request outcome
logger.info(
f"[{ident['id']}][{ident['client']}]: {request_method} {response_code} {uri}"
)
except Exception as ex:
+ # Log errors during HTTP processing
logger.error(
f"[{ident['id']}][{ident['client']}]: {request_method} 502 {uri} ({ex.__class__.__name__}: {' '.join(map(str, ex.args))})"
)
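One caveat on the unchanged `open_connection` call above (an observation, not something this diff changes): `asyncio.open_connection` forwards `flags` to `getaddrinfo`, so passing `TCP_NODELAY` there never touches the socket option. If disabling Nagle is the intent, a sketch of doing it explicitly; note that recent CPython versions already enable TCP_NODELAY on TCP stream transports by default:

```python
import asyncio
import socket

async def open_nodelay(
    host: str, port: int
) -> tuple[asyncio.StreamReader, asyncio.StreamWriter]:
    # Sketch: set TCP_NODELAY on the underlying socket after connecting,
    # since open_connection's flags argument only affects name resolution.
    reader, writer = await asyncio.open_connection(host, port)
    sock = writer.get_extra_info("socket")
    if sock is not None:
        sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
    return reader, writer
```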
@@ -128,29 +152,35 @@ async def process_http(
async def process_request(
client_reader: asyncio.StreamReader, max_retry: int, ident: dict[str, str]
) -> tuple[str, list[str], bytes]:
-
+ # Initialize logger for debugging
logger = Logger().get_logger()
- header = ""
payload = b""
retry = 0
- while True:
- line = await client_reader.readline()
- if not line:
- if not header and retry < max_retry:
- retry += 1
- await asyncio.sleep(0.1)
-
+ # Read the whole header block up to the terminating blank line
+ try:
+ header_bytes = await client_reader.readuntil(b"\r\n\r\n")
+ # Trim the final blank line so header splitting matches the old behavior
+ header = header_bytes.decode("ascii").removesuffix("\r\n")
+ except asyncio.IncompleteReadError:
+ # Retry on incomplete read up to max_retry
+ while retry < max_retry:
+ retry += 1
+ await asyncio.sleep(0.1)
+ try:
+ header_bytes = await client_reader.readuntil(b"\r\n\r\n")
+ header = header_bytes.decode("ascii").removesuffix("\r\n")
+ break
+ except asyncio.IncompleteReadError:
continue
- break
- if line == b"\r\n":
- break
- header += line.decode()
-
+ else:
+ # If retries exhausted, return empty result
+ return "", [], b""
+ # Extract content length and read payload
content_length = get_content_length(header)
- while len(payload) < content_length:
- payload += await client_reader.read(4096)
+ if content_length > 0:
+ payload = await client_reader.readexactly(content_length)
+ # Split headers into lines
header_lines = header.split("\r\n")
return header_lines[0], header_lines[1:-1] if len(header_lines) > 2 else [], payload
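The switch to `readuntil`/`readexactly` is cleaner than the manual loop, with one behavior worth keeping in mind: `readuntil` raises `LimitOverrunError` when the separator does not appear within the stream's buffer limit (64 KiB by default), which an oversized header block would trigger. A hedged sketch of handling both failure modes (helper name is hypothetical):

```python
import asyncio

async def read_header_block(reader: asyncio.StreamReader) -> bytes:
    # Sketch: IncompleteReadError means EOF arrived before b"\r\n\r\n"
    # (ex.partial holds whatever was read); LimitOverrunError means the
    # separator was not found within the buffer limit.
    try:
        return await reader.readuntil(b"\r\n\r\n")
    except asyncio.IncompleteReadError as ex:
        return ex.partial  # connection closed mid-header; may be empty
    except asyncio.LimitOverrunError:
        return b""  # header block exceeded the stream's buffer limit
```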