This commit is contained in:
Iliyan Angelov
2025-12-01 06:50:10 +02:00
parent 91f51bc6fe
commit 62c1fe5951
4682 changed files with 544807 additions and 31208 deletions

View File

@@ -40,11 +40,14 @@ try:
DEFAULT_USER = getpass.getuser()
del getpass
except (ImportError, KeyError):
# KeyError occurs when there's no entry in OS database for a current user.
except (ImportError, KeyError, OSError):
# When there's no entry in OS database for a current user:
# KeyError is raised in Python 3.12 and below.
# OSError is raised in Python 3.13+
DEFAULT_USER = None
DEBUG = False
_DEFAULT_AUTH_PLUGIN = None # if this is not None, use it instead of server's default.
TEXT_TYPES = {
FIELD_TYPE.BIT,
@@ -84,8 +87,7 @@ def _lenenc_int(i):
return b"\xfe" + struct.pack("<Q", i)
else:
raise ValueError(
"Encoding %x is larger than %x - no representation in LengthEncodedInteger"
% (i, (1 << 64))
f"Encoding {i:x} is larger than {1 << 64:x} - no representation in LengthEncodedInteger"
)
@@ -135,6 +137,7 @@ class Connection:
:param ssl_disabled: A boolean value that disables usage of TLS.
:param ssl_key: Path to the file that contains a PEM-formatted private key for
the client certificate.
:param ssl_key_password: The password for the client certificate private key.
:param ssl_verify_cert: Set to true to check the server certificate's validity.
:param ssl_verify_identity: Set to true to check the server's identity.
:param read_default_group: Group to read from in the configuration file.
@@ -161,6 +164,7 @@ class Connection:
"""
_sock = None
_rfile = None
_auth_plugin_name = ""
_closed = False
_secure = False
@@ -201,6 +205,7 @@ class Connection:
ssl_cert=None,
ssl_disabled=None,
ssl_key=None,
ssl_key_password=None,
ssl_verify_cert=None,
ssl_verify_identity=None,
compress=None, # not supported
@@ -262,7 +267,7 @@ class Connection:
if not ssl:
ssl = {}
if isinstance(ssl, dict):
for key in ["ca", "capath", "cert", "key", "cipher"]:
for key in ["ca", "capath", "cert", "key", "password", "cipher"]:
value = _config("ssl-" + key, ssl.get(key))
if value:
ssl[key] = value
@@ -281,6 +286,8 @@ class Connection:
ssl["cert"] = ssl_cert
if ssl_key is not None:
ssl["key"] = ssl_key
if ssl_key_password is not None:
ssl["password"] = ssl_key_password
if ssl:
if not SSL_ENABLED:
raise NotImplementedError("ssl module not found")
@@ -371,6 +378,12 @@ class Connection:
capath = sslp.get("capath")
hasnoca = ca is None and capath is None
ctx = ssl.create_default_context(cafile=ca, capath=capath)
# Python 3.13 enables VERIFY_X509_STRICT by default.
# But self signed certificates that are generated by MySQL automatically
# doesn't pass the verification.
ctx.verify_flags &= ~ssl.VERIFY_X509_STRICT
ctx.check_hostname = not hasnoca and sslp.get("check_hostname", True)
verify_mode_value = sslp.get("verify_mode")
if verify_mode_value is None:
@@ -389,7 +402,9 @@ class Connection:
else:
ctx.verify_mode = ssl.CERT_NONE if hasnoca else ssl.CERT_REQUIRED
if "cert" in sslp:
ctx.load_cert_chain(sslp["cert"], keyfile=sslp.get("key"))
ctx.load_cert_chain(
sslp["cert"], keyfile=sslp.get("key"), password=sslp.get("password")
)
if "cipher" in sslp:
ctx.set_ciphers(sslp["cipher"])
ctx.options |= ssl.OP_NO_SSLv2
@@ -425,6 +440,8 @@ class Connection:
def _force_close(self):
"""Close connection without QUIT message."""
if self._rfile:
self._rfile.close()
if self._sock:
try:
self._sock.close()
@@ -566,9 +583,9 @@ class Connection:
return self._affected_rows
def kill(self, thread_id):
arg = struct.pack("<I", thread_id)
self._execute_command(COMMAND.COM_PROCESS_KILL, arg)
return self._read_ok_packet()
if not isinstance(thread_id, int):
raise TypeError("thread_id must be an integer")
self.query(f"KILL {thread_id:d}")
def ping(self, reconnect=True):
"""
@@ -691,12 +708,7 @@ class Connection:
if self.autocommit_mode is not None:
self.autocommit(self.autocommit_mode)
except BaseException as e:
self._rfile = None
if sock is not None:
try:
sock.close()
except: # noqa
pass
self._force_close()
if isinstance(e, (OSError, IOError)):
exc = err.OperationalError(
@@ -760,8 +772,6 @@ class Connection:
dump_packet(recv_data)
buff += recv_data
# https://dev.mysql.com/doc/internals/en/sending-more-than-16mbyte.html
if bytes_to_read == 0xFFFFFF:
continue
if bytes_to_read < MAX_PACKET_LEN:
break
@@ -809,16 +819,10 @@ class Connection:
def _read_query_result(self, unbuffered=False):
self._result = None
result = MySQLResult(self)
if unbuffered:
try:
result = MySQLResult(self)
result.init_unbuffered_query()
except:
result.unbuffered_active = False
result.connection = None
raise
result.init_unbuffered_query()
else:
result = MySQLResult(self)
result.read()
self._result = result
if result.server_status is not None:
@@ -993,9 +997,8 @@ class Connection:
if plugin_name != b"dialog":
raise err.OperationalError(
CR.CR_AUTH_PLUGIN_CANNOT_LOAD,
"Authentication plugin '%s'"
" not loaded: - %r missing authenticate method"
% (plugin_name, type(handler)),
f"Authentication plugin '{plugin_name}'"
f" not loaded: - {type(handler)!r} missing authenticate method",
)
if plugin_name == b"caching_sha2_password":
return _auth.caching_sha2_password_auth(self, auth_packet)
@@ -1031,16 +1034,14 @@ class Connection:
except AttributeError:
raise err.OperationalError(
CR.CR_AUTH_PLUGIN_CANNOT_LOAD,
"Authentication plugin '%s'"
" not loaded: - %r missing prompt method"
% (plugin_name, handler),
f"Authentication plugin '{plugin_name}'"
f" not loaded: - {handler!r} missing prompt method",
)
except TypeError:
raise err.OperationalError(
CR.CR_AUTH_PLUGIN_ERR,
"Authentication plugin '%s'"
" %r didn't respond with string. Returned '%r' to prompt %r"
% (plugin_name, handler, resp, prompt),
f"Authentication plugin '{plugin_name}'"
f" {handler!r} didn't respond with string. Returned '{resp!r}' to prompt {prompt!r}",
)
else:
raise err.OperationalError(
@@ -1073,9 +1074,8 @@ class Connection:
except TypeError:
raise err.OperationalError(
CR.CR_AUTH_PLUGIN_CANNOT_LOAD,
"Authentication plugin '%s'"
" not loaded: - %r cannot be constructed with connection object"
% (plugin_name, plugin_class),
f"Authentication plugin '{plugin_name}'"
f" not loaded: - {plugin_class!r} cannot be constructed with connection object",
)
else:
handler = None
@@ -1159,6 +1159,9 @@ class Connection:
else:
self._auth_plugin_name = data[i:server_end].decode("utf-8")
if _DEFAULT_AUTH_PLUGIN is not None: # for tests
self._auth_plugin_name = _DEFAULT_AUTH_PLUGIN
def get_server_info(self):
return self.server_version
@@ -1213,17 +1216,16 @@ class MySQLResult:
:raise OperationalError: If the connection to the MySQL server is lost.
:raise InternalError:
"""
self.unbuffered_active = True
first_packet = self.connection._read_packet()
if first_packet.is_ok_packet():
self.connection = None
self._read_ok_packet(first_packet)
self.unbuffered_active = False
self.connection = None
elif first_packet.is_load_local_packet():
self._read_load_local_packet(first_packet)
self.unbuffered_active = False
self.connection = None
try:
self._read_load_local_packet(first_packet)
finally:
self.connection = None
else:
self.field_count = first_packet.read_length_encoded_integer()
self._get_descriptions()
@@ -1232,6 +1234,7 @@ class MySQLResult:
# value of a 64bit unsigned integer. Since we're emulating MySQLdb,
# we set it to this instead of None, which would be preferred.
self.affected_rows = 18446744073709551615
self.unbuffered_active = True
def _read_ok_packet(self, first_packet):
ok_packet = OKPacketWrapper(first_packet)