374 lines
13 KiB
Python
374 lines
13 KiB
Python
"""
|
|
An ``asyncio.Protocol`` subclass for lower level IO handling.
|
|
"""
|
|
import asyncio
|
|
import collections
|
|
import re
|
|
import ssl
|
|
from typing import Deque, Optional, cast
|
|
|
|
from .errors import (
|
|
SMTPDataError,
|
|
SMTPReadTimeoutError,
|
|
SMTPResponseException,
|
|
SMTPServerDisconnected,
|
|
SMTPTimeoutError,
|
|
)
|
|
from .response import SMTPResponse
|
|
from .typing import SMTPStatus
|
|
|
|
|
|
# Names exported by ``from <module> import *``: only the protocol class.
__all__ = ("SMTPProtocol",)
|
|
|
|
|
|
# Upper bound, in bytes, on a single server response line (terminator
# included); the response parser raises "Response too long" beyond this.
MAX_LINE_LENGTH = 8192
|
|
# Matches any line ending: CRLF, a lone LF, or a lone CR (a CR that is not
# followed by LF). Used to normalise outgoing message data to CRLF.
LINE_ENDINGS_REGEX = re.compile(rb"(?:\r\n|\n|\r(?!\n))")
|
|
# Matches a period at the start of any line (multiline mode). Used to double
# leading periods in outgoing message data before it is sent.
PERIOD_REGEX = re.compile(rb"(?m)^\.")
|
|
|
|
|
|
class FlowControlMixin(asyncio.Protocol):
    """
    Write-side flow control shared by protocols that back a ``drain()`` call.

    Supplies ``pause_writing()``, ``resume_writing()`` and
    ``connection_lost()``. A subclass overriding any of them must call the
    ``super()`` implementation. A writer's ``drain()`` is expected to await
    ``_drain_helper()``.

    Copied from the standard library as recommended in
    https://bugs.python.org/msg343685, with logging and asserts removed and
    type annotations added.
    """

    def __init__(self, loop: Optional[asyncio.AbstractEventLoop] = None):
        """
        Set up flow-control state.

        :param loop: event loop used to create waiter futures; when ``None``
            the loop returned by ``asyncio.get_event_loop()`` is used.
        """
        self._loop = asyncio.get_event_loop() if loop is None else loop

        self._paused = False
        # One future per coroutine currently blocked in ``_drain_helper``.
        self._drain_waiters: Deque[asyncio.Future[None]] = collections.deque()
        self._connection_lost = False

    def pause_writing(self) -> None:
        """Mark writing as paused so later drains block."""
        self._paused = True

    def resume_writing(self) -> None:
        """Mark writing as resumed and release every blocked drain."""
        self._paused = False

        for pending in self._drain_waiters:
            if pending.done():
                continue
            pending.set_result(None)

    def connection_lost(self, exc: Optional[Exception]) -> None:
        """
        Record the lost connection and wake any blocked drains.

        Waiters are only touched when writing is currently paused; each one
        still pending receives ``exc`` when given, otherwise a ``None``
        result.
        """
        self._connection_lost = True

        # Wake up the writer(s) if currently paused.
        if self._paused:
            for pending in self._drain_waiters:
                if pending.done():
                    continue
                if exc is not None:
                    pending.set_exception(exc)
                else:
                    pending.set_result(None)

    async def _drain_helper(self) -> None:
        """
        Wait until writing is resumed.

        Returns immediately when writing is not paused.

        :raises ConnectionResetError: if the connection was already lost.
        """
        if self._connection_lost:
            raise ConnectionResetError("Connection lost")

        if self._paused:
            blocker = self._loop.create_future()
            self._drain_waiters.append(blocker)
            try:
                await blocker
            finally:
                # Always deregister, including on cancellation or error.
                self._drain_waiters.remove(blocker)

    def _get_close_waiter(self, stream: asyncio.StreamWriter) -> "asyncio.Future[None]":
        """Hook for subclasses; this base implementation always raises."""
        raise NotImplementedError
|
|
|
|
|
|
class SMTPProtocol(FlowControlMixin, asyncio.BaseProtocol):
    """
    asyncio protocol implementing the client side of an SMTP exchange.

    Incoming bytes are accumulated in an internal buffer and parsed into
    :class:`SMTPResponse` objects, which are delivered through a single
    "response waiter" future that ``read_response`` awaits. Commands are
    serialised with an ``asyncio.Lock`` created when a connection is made.
    """

    def __init__(
        self,
        loop: Optional[asyncio.AbstractEventLoop] = None,
    ) -> None:
        """
        :param loop: event loop to use; passed through to the flow-control
            base class.
        """
        super().__init__(loop=loop)
        self._over_ssl = False
        # Raw bytes received from the server that are not yet parsed.
        self._buffer = bytearray()
        # Future resolved with the next parsed response (or an exception).
        self._response_waiter: Optional[asyncio.Future[SMTPResponse]] = None

        self.transport: Optional[asyncio.BaseTransport] = None
        self._command_lock: Optional[asyncio.Lock] = None
        self._closed: "asyncio.Future[None]" = self._loop.create_future()
        # Set once QUIT has been written, so a subsequent disconnect is not
        # reported as an error.
        self._quit_sent = False

    def _get_close_waiter(self, stream: asyncio.StreamWriter) -> "asyncio.Future[None]":
        """Return the future created in ``__init__`` as the close waiter."""
        # NOTE(review): nothing in this class resolves ``_closed``; confirm
        # whether any caller awaits it.
        return self._closed

    def __del__(self) -> None:
        # Avoid 'Future exception was never retrieved' warnings
        # Some unknown race conditions can sometimes trigger these :(
        self._retrieve_response_exception()

    @property
    def is_connected(self) -> bool:
        """
        Check if our transport is still connected.
        """
        return bool(self.transport is not None and not self.transport.is_closing())

    def connection_made(self, transport: asyncio.BaseTransport) -> None:
        """
        Store the transport and reset per-connection state.

        TLS is considered active when the transport reports an
        ``sslcontext`` in its extra info.
        """
        self.transport = cast(asyncio.Transport, transport)
        self._over_ssl = transport.get_extra_info("sslcontext") is not None
        self._response_waiter = self._loop.create_future()
        self._command_lock = asyncio.Lock()
        self._quit_sent = False

    def connection_lost(self, exc: Optional[Exception]) -> None:
        """
        Handle a closed connection.

        Unless QUIT was sent, a pending response waiter is failed with
        :class:`SMTPServerDisconnected` (chained to ``exc`` when given).
        The transport and command lock are then dropped.
        """
        super().connection_lost(exc)

        if not self._quit_sent:
            smtp_exc = SMTPServerDisconnected("Connection lost")
            if exc:
                smtp_exc.__cause__ = exc

            if self._response_waiter and not self._response_waiter.done():
                self._response_waiter.set_exception(smtp_exc)

        self.transport = None
        self._command_lock = None

    def data_received(self, data: bytes) -> None:
        """
        Buffer incoming bytes and resolve the response waiter when a full
        response is available.

        :raises RuntimeError: if no response waiter exists.
        """
        if self._response_waiter is None:
            raise RuntimeError(
                f"data_received called without a response waiter set: {data!r}"
            )
        elif self._response_waiter.done():
            # We got a response without issuing a command; ignore it.
            return

        self._buffer.extend(data)

        # If we got an obvious partial message, don't try to parse the buffer
        last_linebreak = data.rfind(b"\n")
        if last_linebreak == -1:
            # This chunk completed no line. Keep waiting only while the
            # unterminated tail of the buffer is within the line limit;
            # otherwise fall through so the parser reports the oversized
            # line instead of letting the buffer grow without bound.
            unterminated = len(self._buffer) - (self._buffer.rfind(b"\n") + 1)
            if unterminated <= MAX_LINE_LENGTH:
                return
        elif data[last_linebreak + 3 : last_linebreak + 4] == b"-":
            # NOTE(review): for a partial line following the last newline,
            # the code/text separator would sit at offset +4, not +3 --
            # confirm the intended offset. Left unchanged: skipping here is
            # only a shortcut, the parser handles partial data itself.
            return

        try:
            response = self._read_response_from_buffer()
        except Exception as exc:
            self._response_waiter.set_exception(exc)
        else:
            if response is not None:
                self._response_waiter.set_result(response)

    def eof_received(self) -> bool:
        """
        Fail any pending response waiter with :class:`SMTPServerDisconnected`.
        """
        exc = SMTPServerDisconnected("Unexpected EOF received")
        if self._response_waiter and not self._response_waiter.done():
            self._response_waiter.set_exception(exc)

        # Returning false closes the transport
        return False

    def _retrieve_response_exception(self) -> Optional[BaseException]:
        """
        Return any exception that has been set on the response waiter.

        Used to avoid 'Future exception was never retrieved' warnings
        """
        # ``getattr`` because this is reached from ``__del__``, which also
        # runs for instances whose ``__init__`` failed before the attribute
        # was assigned.
        waiter = getattr(self, "_response_waiter", None)
        if waiter and waiter.done() and not waiter.cancelled():
            return waiter.exception()

        return None

    def _read_response_from_buffer(self) -> Optional[SMTPResponse]:
        """
        Parse the actual response (if any) from the data buffer.

        Lines are consumed until one whose fourth byte is not ``-`` (the
        final line of a response). The consumed bytes are removed from the
        buffer only when a complete response was found.

        :return: the parsed response, or ``None`` if the buffer does not yet
            hold a complete response.
        :raises SMTPResponseException: if a line exceeds ``MAX_LINE_LENGTH``
            or does not start with an integer code.
        """
        code = -1
        message = bytearray()
        offset = 0
        message_complete = False

        while True:
            line_end_index = self._buffer.find(b"\n", offset)
            if line_end_index == -1:
                # An unterminated tail already longer than the limit can only
                # become an over-long line; reject it now rather than
                # buffering more of it.
                if len(self._buffer) - offset > MAX_LINE_LENGTH:
                    raise SMTPResponseException(
                        SMTPStatus.unrecognized_command, "Response too long"
                    )
                break

            line = bytes(self._buffer[offset : line_end_index + 1])

            if len(line) > MAX_LINE_LENGTH:
                raise SMTPResponseException(
                    SMTPStatus.unrecognized_command, "Response too long"
                )

            try:
                code = int(line[:3])
            except ValueError:
                raise SMTPResponseException(
                    SMTPStatus.invalid_response.value,
                    f"Malformed SMTP response line: {line!r}",
                ) from None

            offset += len(line)
            # Join the text of multi-line responses with a single newline.
            if message:
                message.extend(b"\n")
            message.extend(line[4:].strip(b" \t\r\n"))
            if line[3:4] != b"-":
                message_complete = True
                break

        if message_complete:
            response = SMTPResponse(
                code, bytes(message).decode("utf-8", "surrogateescape")
            )
            del self._buffer[:offset]
            return response
        else:
            return None

    async def read_response(self, timeout: Optional[float] = None) -> SMTPResponse:
        """
        Get a status response from the server.

        This method must be awaited once per command sent; if multiple commands
        are written to the transport without awaiting, response data will be lost.

        Returns an :class:`.response.SMTPResponse` namedtuple consisting of:
          - server response code (e.g. 250, or such, if all goes well)
          - server response string (multiline responses are converted to a
            single, multiline string).

        :raises SMTPServerDisconnected: if there is no response waiter.
        :raises SMTPReadTimeoutError: if no response arrives within ``timeout``.
        """
        if self._response_waiter is None:
            raise SMTPServerDisconnected("Connection lost")

        try:
            result = await asyncio.wait_for(self._response_waiter, timeout)
        except (TimeoutError, asyncio.TimeoutError) as exc:
            raise SMTPReadTimeoutError("Timed out waiting for server response") from exc
        finally:
            # If we were disconnected, don't create a new waiter
            if self.transport is None:
                self._response_waiter = None
            else:
                self._response_waiter = self._loop.create_future()

        return result

    def write(self, data: bytes) -> None:
        """
        Write raw bytes to the transport.

        :raises SMTPServerDisconnected: if there is no transport or it is
            closing.
        :raises RuntimeError: if the transport has no ``write`` attribute.
        """
        if self.transport is None or self.transport.is_closing():
            raise SMTPServerDisconnected("Connection lost")
        if not hasattr(self.transport, "write"):
            raise RuntimeError(
                f"Transport {self.transport!r} does not support writing."
            )

        self.transport.write(data)  # type: ignore

    async def execute_command(
        self, *args: bytes, timeout: Optional[float] = None
    ) -> SMTPResponse:
        """
        Sends an SMTP command along with any args to the server, and returns
        a response.

        The arguments are joined with single spaces and terminated with CRLF.

        :raises SMTPServerDisconnected: if no connection has been made.
        """
        if self._command_lock is None:
            raise SMTPServerDisconnected("Server not connected")
        command = b" ".join(args) + b"\r\n"

        async with self._command_lock:
            self.write(command)

            # Remember that we asked to quit, so the disconnect that follows
            # is not reported as an error by ``connection_lost``.
            if command == b"QUIT\r\n":
                self._quit_sent = True

            response = await self.read_response(timeout=timeout)

        return response

    async def execute_data_command(
        self, message: bytes, timeout: Optional[float] = None
    ) -> SMTPResponse:
        """
        Sends an SMTP DATA command to the server, followed by encoded message content.

        Automatically quotes lines beginning with a period per RFC821.
        Lone CR and LF characters are converted to CRLF, and the message is
        terminated with a line containing a single period.

        :raises SMTPServerDisconnected: if no connection has been made.
        :raises SMTPDataError: if either the DATA command or the message body
            is not answered with the expected status code.
        """
        if self._command_lock is None:
            raise SMTPServerDisconnected("Server not connected")

        message = LINE_ENDINGS_REGEX.sub(b"\r\n", message)
        message = PERIOD_REGEX.sub(b"..", message)
        if not message.endswith(b"\r\n"):
            message += b"\r\n"
        message += b".\r\n"

        async with self._command_lock:
            self.write(b"DATA\r\n")
            start_response = await self.read_response(timeout=timeout)
            if start_response.code != SMTPStatus.start_input:
                raise SMTPDataError(start_response.code, start_response.message)

            self.write(message)
            response = await self.read_response(timeout=timeout)
            if response.code != SMTPStatus.completed:
                raise SMTPDataError(response.code, response.message)

        return response

    async def start_tls(
        self,
        tls_context: ssl.SSLContext,
        server_hostname: Optional[str] = None,
        timeout: Optional[float] = None,
    ) -> SMTPResponse:
        """
        Puts the connection to the SMTP server into TLS mode.

        Sends STARTTLS, and on a "ready" reply upgrades the transport via
        the event loop's ``start_tls``.

        :return: the server's reply to the STARTTLS command.
        :raises RuntimeError: if TLS is already active on this connection.
        :raises SMTPServerDisconnected: if not connected, or the connection
            is lost or reset during the upgrade.
        :raises SMTPResponseException: if the server refuses STARTTLS.
        :raises SMTPTimeoutError: if the upgrade times out or is aborted.
        """
        if self._over_ssl:
            raise RuntimeError("Already using TLS.")
        if self._command_lock is None:
            raise SMTPServerDisconnected("Server not connected")

        async with self._command_lock:
            self.write(b"STARTTLS\r\n")
            response = await self.read_response(timeout=timeout)
            if response.code != SMTPStatus.ready:
                raise SMTPResponseException(response.code, response.message)

            # Check for disconnect after response
            if self.transport is None or self.transport.is_closing():
                raise SMTPServerDisconnected("Connection lost")

            try:
                tls_transport = await self._loop.start_tls(
                    self.transport,
                    self,
                    tls_context,
                    server_side=False,
                    server_hostname=server_hostname,
                    ssl_handshake_timeout=timeout,
                )
            except (TimeoutError, asyncio.TimeoutError) as exc:
                raise SMTPTimeoutError("Timed out while upgrading transport") from exc
            # SSLProtocol only raises ConnectionAbortedError on timeout
            except ConnectionAbortedError as exc:
                # Guard against an exception raised without arguments, as the
                # ConnectionResetError branch below already does.
                if exc.args:
                    message = exc.args[0]
                else:
                    message = "Connection was aborted while upgrading transport"
                raise SMTPTimeoutError(message) from exc
            except ConnectionResetError as exc:
                if exc.args:
                    message = exc.args[0]
                else:
                    message = "Connection was reset while upgrading transport"
                raise SMTPServerDisconnected(message) from exc

            self.transport = tls_transport
            # The loop's start_tls does not call connection_made() again, so
            # record the upgrade here; otherwise the "Already using TLS."
            # guard above could never fire for an upgraded connection.
            self._over_ssl = True

        return response
|