"""
An ``asyncio.Protocol`` subclass for lower level IO handling.
"""
import asyncio
import collections
import re
import ssl
from typing import Deque, Optional, cast

from .errors import (
    SMTPDataError,
    SMTPReadTimeoutError,
    SMTPResponseException,
    SMTPServerDisconnected,
    SMTPTimeoutError,
)
from .response import SMTPResponse
from .typing import SMTPStatus


__all__ = ("SMTPProtocol",)


# Longest accepted response line, in bytes, including its trailing "\n".
# Longer lines make the response parser raise SMTPResponseException.
MAX_LINE_LENGTH = 8192
# Matches CRLF, a lone LF, or a lone CR (a CR not followed by LF).
LINE_ENDINGS_REGEX = re.compile(rb"(?:\r\n|\n|\r(?!\n))")
# Matches a period at the very start of any line (multiline mode).
PERIOD_REGEX = re.compile(rb"(?m)^\.")


class FlowControlMixin(asyncio.Protocol):
    """
    Reusable flow control logic for StreamWriter.drain().

    This implements the protocol methods pause_writing(),
    resume_writing() and connection_lost().  If the subclass overrides
    these it must call the super methods.

    StreamWriter.drain() must wait for _drain_helper() coroutine.

    Copied from stdlib as per recommendation: https://bugs.python.org/msg343685.
    Logging and asserts removed, type annotations added.
    """

    def __init__(self, loop: Optional[asyncio.AbstractEventLoop] = None):
        if loop is None:
            self._loop = asyncio.get_event_loop()
        else:
            self._loop = loop
        # True between pause_writing() and resume_writing().
        self._paused = False
        # Futures that _drain_helper() callers are blocked on while paused.
        self._drain_waiters: Deque[asyncio.Future[None]] = collections.deque()
        self._connection_lost = False

    def pause_writing(self) -> None:
        """Mark writing as paused; subsequent drains will block."""
        self._paused = True

    def resume_writing(self) -> None:
        """Mark writing as resumed and release every blocked drain."""
        self._paused = False

        for waiter in self._drain_waiters:
            if not waiter.done():
                waiter.set_result(None)

    def connection_lost(self, exc: Optional[Exception]) -> None:
        """
        Record the lost connection and, if paused, wake blocked drains.

        Blocked drains are completed normally when ``exc`` is None, otherwise
        ``exc`` is raised inside them.
        """
        self._connection_lost = True
        # Wake up the writer(s) if currently paused.
        if not self._paused:
            return

        for waiter in self._drain_waiters:
            if not waiter.done():
                if exc is None:
                    waiter.set_result(None)
                else:
                    waiter.set_exception(exc)

    async def _drain_helper(self) -> None:
        """
        Wait until writing is no longer paused.

        :raises ConnectionResetError: if the connection was already lost.
        """
        if self._connection_lost:
            raise ConnectionResetError("Connection lost")
        if not self._paused:
            return
        waiter = self._loop.create_future()
        self._drain_waiters.append(waiter)
        try:
            await waiter
        finally:
            self._drain_waiters.remove(waiter)

    def _get_close_waiter(self, stream: asyncio.StreamWriter) -> "asyncio.Future[None]":
        raise NotImplementedError


class SMTPProtocol(FlowControlMixin, asyncio.BaseProtocol):
    """
    Low level SMTP protocol: buffers incoming bytes, parses replies, and
    serializes commands so that each one is matched with its own reply.
    """

    def __init__(
        self,
        loop: Optional[asyncio.AbstractEventLoop] = None,
    ) -> None:
        super().__init__(loop=loop)
        # True when the current transport is TLS (set on connect/STARTTLS).
        self._over_ssl = False
        # Raw bytes received but not yet consumed as a complete reply.
        self._buffer = bytearray()
        # Future resolved with the next parsed reply (or a failure).
        self._response_waiter: Optional[asyncio.Future[SMTPResponse]] = None

        self.transport: Optional[asyncio.BaseTransport] = None
        # Ensures only one command/reply exchange is in flight at a time.
        self._command_lock: Optional[asyncio.Lock] = None
        # Close waiter handed out by _get_close_waiter(); resolved on disconnect.
        self._closed: "asyncio.Future[None]" = self._loop.create_future()
        # Set once QUIT is written, so the following disconnect is not an error.
        self._quit_sent = False

    def _get_close_waiter(self, stream: asyncio.StreamWriter) -> "asyncio.Future[None]":
        return self._closed

    def __del__(self) -> None:
        # Avoid 'Future exception was never retrieved' warnings
        # Some unknown race conditions can sometimes trigger these :(
        # Guard against a partially-initialized instance (__init__ raised
        # before the waiter attribute was assigned).
        if getattr(self, "_response_waiter", None) is not None:
            self._retrieve_response_exception()

    @property
    def is_connected(self) -> bool:
        """
        Check if our transport is still connected.
        """
        return bool(self.transport is not None and not self.transport.is_closing())

    def connection_made(self, transport: asyncio.BaseTransport) -> None:
        """Store the transport and reset per-connection state."""
        self.transport = cast(asyncio.Transport, transport)
        self._over_ssl = transport.get_extra_info("sslcontext") is not None
        self._response_waiter = self._loop.create_future()
        self._command_lock = asyncio.Lock()
        self._quit_sent = False

    def connection_lost(self, exc: Optional[Exception]) -> None:
        """
        Fail any pending reply (unless QUIT was sent) and drop the transport.

        A pending reply waiter receives ``SMTPServerDisconnected`` chained to
        ``exc`` when one is given.
        """
        super().connection_lost(exc)

        if not self._quit_sent:
            smtp_exc = SMTPServerDisconnected("Connection lost")
            if exc:
                smtp_exc.__cause__ = exc

            if self._response_waiter and not self._response_waiter.done():
                self._response_waiter.set_exception(smtp_exc)

        # Complete the close waiter so nothing awaiting it hangs forever.
        # A plain result (not an exception) avoids unretrieved-exception noise.
        if not self._closed.done():
            self._closed.set_result(None)

        self.transport = None
        self._command_lock = None

    def data_received(self, data: bytes) -> None:
        """
        Buffer incoming bytes and resolve the reply waiter once a full reply
        has been received.

        :raises RuntimeError: if called before connection_made().
        """
        if self._response_waiter is None:
            raise RuntimeError(
                f"data_received called without a response waiter set: {data!r}"
            )
        elif self._response_waiter.done():
            # We got a response without issuing a command; ignore it.
            return

        self._buffer.extend(data)

        # Only a chunk containing a line ending can complete a reply line;
        # without one the set of complete lines in the buffer is unchanged.
        if b"\n" not in data:
            return

        try:
            response = self._read_response_from_buffer()
        except Exception as exc:
            self._response_waiter.set_exception(exc)
        else:
            if response is not None:
                self._response_waiter.set_result(response)

    def eof_received(self) -> bool:
        """Fail any pending reply with SMTPServerDisconnected."""
        exc = SMTPServerDisconnected("Unexpected EOF received")
        if self._response_waiter and not self._response_waiter.done():
            self._response_waiter.set_exception(exc)

        # Returning false closes the transport
        return False

    def _retrieve_response_exception(self) -> Optional[BaseException]:
        """
        Return any exception that has been set on the response waiter.

        Used to avoid 'Future exception was never retrieved' warnings
        """
        if (
            self._response_waiter
            and self._response_waiter.done()
            and not self._response_waiter.cancelled()
        ):
            return self._response_waiter.exception()

        return None

    def _read_response_from_buffer(self) -> Optional[SMTPResponse]:
        """
        Parse the actual response (if any) from the data buffer.

        Lines whose 4th byte is ``-`` continue the reply; the first line
        without it ends the reply.  The consumed bytes are removed from the
        buffer only when a complete reply was parsed; otherwise the buffer is
        left untouched and None is returned.

        :raises SMTPResponseException: on an over-long or malformed line.
        """
        code = -1
        message = bytearray()
        offset = 0
        message_complete = False

        while True:
            line_end_index = self._buffer.find(b"\n", offset)
            if line_end_index == -1:
                break

            line = bytes(self._buffer[offset : line_end_index + 1])

            if len(line) > MAX_LINE_LENGTH:
                raise SMTPResponseException(
                    SMTPStatus.unrecognized_command, "Response too long"
                )

            try:
                code = int(line[:3])
            except ValueError:
                raise SMTPResponseException(
                    SMTPStatus.invalid_response.value,
                    f"Malformed SMTP response line: {line!r}",
                ) from None

            offset += len(line)
            # Join the text of multi-line replies with "\n".
            if len(message):
                message.extend(b"\n")
            message.extend(line[4:].strip(b" \t\r\n"))
            if line[3:4] != b"-":
                message_complete = True
                break

        if message_complete:
            response = SMTPResponse(
                code, bytes(message).decode("utf-8", "surrogateescape")
            )
            del self._buffer[:offset]

            return response
        else:
            return None

    async def read_response(self, timeout: Optional[float] = None) -> SMTPResponse:
        """
        Get a status response from the server.

        This method must be awaited once per command sent; if multiple commands
        are written to the transport without awaiting, response data will be lost.

        Returns an :class:`.response.SMTPResponse` namedtuple consisting of:
          - server response code (e.g. 250, or such, if all goes well)
          - server response string (multiline responses are converted to a
            single, multiline string).

        :raises SMTPServerDisconnected: if there is no connection.
        :raises SMTPReadTimeoutError: if no reply arrives within ``timeout``.
        """
        if self._response_waiter is None:
            raise SMTPServerDisconnected("Connection lost")

        try:
            result = await asyncio.wait_for(self._response_waiter, timeout)
        except (TimeoutError, asyncio.TimeoutError) as exc:
            raise SMTPReadTimeoutError("Timed out waiting for server response") from exc
        finally:
            # If we were disconnected, don't create a new waiter
            if self.transport is None:
                self._response_waiter = None
            else:
                self._response_waiter = self._loop.create_future()

        return result

    def write(self, data: bytes) -> None:
        """
        Write raw bytes to the transport.

        :raises SMTPServerDisconnected: if the transport is gone or closing.
        :raises RuntimeError: if the transport has no ``write`` method.
        """
        if self.transport is None or self.transport.is_closing():
            raise SMTPServerDisconnected("Connection lost")
        if not hasattr(self.transport, "write"):
            raise RuntimeError(
                f"Transport {self.transport!r} does not support writing."
            )

        self.transport.write(data)  # type: ignore

    async def execute_command(
        self, *args: bytes, timeout: Optional[float] = None
    ) -> SMTPResponse:
        """
        Sends an SMTP command along with any args to the server, and returns
        a response.

        The args are joined with spaces and terminated with CRLF; the write
        and the reply read happen while holding the command lock.
        """
        if self._command_lock is None:
            raise SMTPServerDisconnected("Server not connected")

        command = b" ".join(args) + b"\r\n"

        async with self._command_lock:
            self.write(command)

            if command == b"QUIT\r\n":
                self._quit_sent = True

            response = await self.read_response(timeout=timeout)

        return response

    async def execute_data_command(
        self, message: bytes, timeout: Optional[float] = None
    ) -> SMTPResponse:
        """
        Sends an SMTP DATA command to the server, followed by encoded message
        content.

        Automatically quotes lines beginning with a period per RFC821.
        Lone CR and LF characters are normalized to CRLF, and the message is
        terminated with a line containing a single period.

        :raises SMTPServerDisconnected: if not connected.
        :raises SMTPDataError: if DATA or the message body is rejected.
        """
        if self._command_lock is None:
            raise SMTPServerDisconnected("Server not connected")

        message = LINE_ENDINGS_REGEX.sub(b"\r\n", message)
        message = PERIOD_REGEX.sub(b"..", message)
        if not message.endswith(b"\r\n"):
            message += b"\r\n"
        message += b".\r\n"

        async with self._command_lock:
            self.write(b"DATA\r\n")
            start_response = await self.read_response(timeout=timeout)
            if start_response.code != SMTPStatus.start_input:
                raise SMTPDataError(start_response.code, start_response.message)

            self.write(message)
            response = await self.read_response(timeout=timeout)
            if response.code != SMTPStatus.completed:
                raise SMTPDataError(response.code, response.message)

        return response

    async def start_tls(
        self,
        tls_context: ssl.SSLContext,
        server_hostname: Optional[str] = None,
        timeout: Optional[float] = None,
    ) -> SMTPResponse:
        """
        Puts the connection to the SMTP server into TLS mode.

        Sends STARTTLS, and on a ``ready`` reply upgrades the transport via
        ``loop.start_tls``.

        :raises RuntimeError: if the connection already uses TLS.
        :raises SMTPServerDisconnected: if not connected, or the connection is
            lost/reset during the upgrade.
        :raises SMTPResponseException: if the server does not reply ``ready``.
        :raises SMTPTimeoutError: if the TLS handshake times out or is aborted.
        """
        if self._over_ssl:
            raise RuntimeError("Already using TLS.")
        if self._command_lock is None:
            raise SMTPServerDisconnected("Server not connected")

        async with self._command_lock:
            self.write(b"STARTTLS\r\n")
            response = await self.read_response(timeout=timeout)
            if response.code != SMTPStatus.ready:
                raise SMTPResponseException(response.code, response.message)

            # Check for disconnect after response
            if self.transport is None or self.transport.is_closing():
                raise SMTPServerDisconnected("Connection lost")

            try:
                tls_transport = await self._loop.start_tls(
                    self.transport,
                    self,
                    tls_context,
                    server_side=False,
                    server_hostname=server_hostname,
                    ssl_handshake_timeout=timeout,
                )
            except (TimeoutError, asyncio.TimeoutError) as exc:
                raise SMTPTimeoutError("Timed out while upgrading transport") from exc
            # SSLProtocol only raises ConnectionAbortedError on timeout
            except ConnectionAbortedError as exc:
                if exc.args:
                    abort_message = exc.args[0]
                else:
                    abort_message = "Timed out while upgrading transport"
                raise SMTPTimeoutError(abort_message) from exc
            except ConnectionResetError as exc:
                if exc.args:
                    message = exc.args[0]
                else:
                    message = "Connection was reset while upgrading transport"
                raise SMTPServerDisconnected(message) from exc

            self.transport = tls_transport
            # The connection is now encrypted; a repeated STARTTLS must be
            # rejected by the guard at the top of this method.
            self._over_ssl = True

        return response