updates
This commit is contained in:
@@ -40,11 +40,14 @@ try:
|
||||
|
||||
DEFAULT_USER = getpass.getuser()
|
||||
del getpass
|
||||
except (ImportError, KeyError):
|
||||
# KeyError occurs when there's no entry in OS database for a current user.
|
||||
except (ImportError, KeyError, OSError):
|
||||
# When there's no entry in OS database for a current user:
|
||||
# KeyError is raised in Python 3.12 and below.
|
||||
# OSError is raised in Python 3.13+
|
||||
DEFAULT_USER = None
|
||||
|
||||
DEBUG = False
|
||||
_DEFAULT_AUTH_PLUGIN = None # if this is not None, use it instead of server's default.
|
||||
|
||||
TEXT_TYPES = {
|
||||
FIELD_TYPE.BIT,
|
||||
@@ -84,8 +87,7 @@ def _lenenc_int(i):
|
||||
return b"\xfe" + struct.pack("<Q", i)
|
||||
else:
|
||||
raise ValueError(
|
||||
"Encoding %x is larger than %x - no representation in LengthEncodedInteger"
|
||||
% (i, (1 << 64))
|
||||
f"Encoding {i:x} is larger than {1 << 64:x} - no representation in LengthEncodedInteger"
|
||||
)
|
||||
|
||||
|
||||
@@ -135,6 +137,7 @@ class Connection:
|
||||
:param ssl_disabled: A boolean value that disables usage of TLS.
|
||||
:param ssl_key: Path to the file that contains a PEM-formatted private key for
|
||||
the client certificate.
|
||||
:param ssl_key_password: The password for the client certificate private key.
|
||||
:param ssl_verify_cert: Set to true to check the server certificate's validity.
|
||||
:param ssl_verify_identity: Set to true to check the server's identity.
|
||||
:param read_default_group: Group to read from in the configuration file.
|
||||
@@ -161,6 +164,7 @@ class Connection:
|
||||
"""
|
||||
|
||||
_sock = None
|
||||
_rfile = None
|
||||
_auth_plugin_name = ""
|
||||
_closed = False
|
||||
_secure = False
|
||||
@@ -201,6 +205,7 @@ class Connection:
|
||||
ssl_cert=None,
|
||||
ssl_disabled=None,
|
||||
ssl_key=None,
|
||||
ssl_key_password=None,
|
||||
ssl_verify_cert=None,
|
||||
ssl_verify_identity=None,
|
||||
compress=None, # not supported
|
||||
@@ -262,7 +267,7 @@ class Connection:
|
||||
if not ssl:
|
||||
ssl = {}
|
||||
if isinstance(ssl, dict):
|
||||
for key in ["ca", "capath", "cert", "key", "cipher"]:
|
||||
for key in ["ca", "capath", "cert", "key", "password", "cipher"]:
|
||||
value = _config("ssl-" + key, ssl.get(key))
|
||||
if value:
|
||||
ssl[key] = value
|
||||
@@ -281,6 +286,8 @@ class Connection:
|
||||
ssl["cert"] = ssl_cert
|
||||
if ssl_key is not None:
|
||||
ssl["key"] = ssl_key
|
||||
if ssl_key_password is not None:
|
||||
ssl["password"] = ssl_key_password
|
||||
if ssl:
|
||||
if not SSL_ENABLED:
|
||||
raise NotImplementedError("ssl module not found")
|
||||
@@ -371,6 +378,12 @@ class Connection:
|
||||
capath = sslp.get("capath")
|
||||
hasnoca = ca is None and capath is None
|
||||
ctx = ssl.create_default_context(cafile=ca, capath=capath)
|
||||
|
||||
# Python 3.13 enables VERIFY_X509_STRICT by default.
|
||||
# But self signed certificates that are generated by MySQL automatically
|
||||
# doesn't pass the verification.
|
||||
ctx.verify_flags &= ~ssl.VERIFY_X509_STRICT
|
||||
|
||||
ctx.check_hostname = not hasnoca and sslp.get("check_hostname", True)
|
||||
verify_mode_value = sslp.get("verify_mode")
|
||||
if verify_mode_value is None:
|
||||
@@ -389,7 +402,9 @@ class Connection:
|
||||
else:
|
||||
ctx.verify_mode = ssl.CERT_NONE if hasnoca else ssl.CERT_REQUIRED
|
||||
if "cert" in sslp:
|
||||
ctx.load_cert_chain(sslp["cert"], keyfile=sslp.get("key"))
|
||||
ctx.load_cert_chain(
|
||||
sslp["cert"], keyfile=sslp.get("key"), password=sslp.get("password")
|
||||
)
|
||||
if "cipher" in sslp:
|
||||
ctx.set_ciphers(sslp["cipher"])
|
||||
ctx.options |= ssl.OP_NO_SSLv2
|
||||
@@ -425,6 +440,8 @@ class Connection:
|
||||
|
||||
def _force_close(self):
|
||||
"""Close connection without QUIT message."""
|
||||
if self._rfile:
|
||||
self._rfile.close()
|
||||
if self._sock:
|
||||
try:
|
||||
self._sock.close()
|
||||
@@ -566,9 +583,9 @@ class Connection:
|
||||
return self._affected_rows
|
||||
|
||||
def kill(self, thread_id):
|
||||
arg = struct.pack("<I", thread_id)
|
||||
self._execute_command(COMMAND.COM_PROCESS_KILL, arg)
|
||||
return self._read_ok_packet()
|
||||
if not isinstance(thread_id, int):
|
||||
raise TypeError("thread_id must be an integer")
|
||||
self.query(f"KILL {thread_id:d}")
|
||||
|
||||
def ping(self, reconnect=True):
|
||||
"""
|
||||
@@ -691,12 +708,7 @@ class Connection:
|
||||
if self.autocommit_mode is not None:
|
||||
self.autocommit(self.autocommit_mode)
|
||||
except BaseException as e:
|
||||
self._rfile = None
|
||||
if sock is not None:
|
||||
try:
|
||||
sock.close()
|
||||
except: # noqa
|
||||
pass
|
||||
self._force_close()
|
||||
|
||||
if isinstance(e, (OSError, IOError)):
|
||||
exc = err.OperationalError(
|
||||
@@ -760,8 +772,6 @@ class Connection:
|
||||
dump_packet(recv_data)
|
||||
buff += recv_data
|
||||
# https://dev.mysql.com/doc/internals/en/sending-more-than-16mbyte.html
|
||||
if bytes_to_read == 0xFFFFFF:
|
||||
continue
|
||||
if bytes_to_read < MAX_PACKET_LEN:
|
||||
break
|
||||
|
||||
@@ -809,16 +819,10 @@ class Connection:
|
||||
|
||||
def _read_query_result(self, unbuffered=False):
|
||||
self._result = None
|
||||
result = MySQLResult(self)
|
||||
if unbuffered:
|
||||
try:
|
||||
result = MySQLResult(self)
|
||||
result.init_unbuffered_query()
|
||||
except:
|
||||
result.unbuffered_active = False
|
||||
result.connection = None
|
||||
raise
|
||||
result.init_unbuffered_query()
|
||||
else:
|
||||
result = MySQLResult(self)
|
||||
result.read()
|
||||
self._result = result
|
||||
if result.server_status is not None:
|
||||
@@ -993,9 +997,8 @@ class Connection:
|
||||
if plugin_name != b"dialog":
|
||||
raise err.OperationalError(
|
||||
CR.CR_AUTH_PLUGIN_CANNOT_LOAD,
|
||||
"Authentication plugin '%s'"
|
||||
" not loaded: - %r missing authenticate method"
|
||||
% (plugin_name, type(handler)),
|
||||
f"Authentication plugin '{plugin_name}'"
|
||||
f" not loaded: - {type(handler)!r} missing authenticate method",
|
||||
)
|
||||
if plugin_name == b"caching_sha2_password":
|
||||
return _auth.caching_sha2_password_auth(self, auth_packet)
|
||||
@@ -1031,16 +1034,14 @@ class Connection:
|
||||
except AttributeError:
|
||||
raise err.OperationalError(
|
||||
CR.CR_AUTH_PLUGIN_CANNOT_LOAD,
|
||||
"Authentication plugin '%s'"
|
||||
" not loaded: - %r missing prompt method"
|
||||
% (plugin_name, handler),
|
||||
f"Authentication plugin '{plugin_name}'"
|
||||
f" not loaded: - {handler!r} missing prompt method",
|
||||
)
|
||||
except TypeError:
|
||||
raise err.OperationalError(
|
||||
CR.CR_AUTH_PLUGIN_ERR,
|
||||
"Authentication plugin '%s'"
|
||||
" %r didn't respond with string. Returned '%r' to prompt %r"
|
||||
% (plugin_name, handler, resp, prompt),
|
||||
f"Authentication plugin '{plugin_name}'"
|
||||
f" {handler!r} didn't respond with string. Returned '{resp!r}' to prompt {prompt!r}",
|
||||
)
|
||||
else:
|
||||
raise err.OperationalError(
|
||||
@@ -1073,9 +1074,8 @@ class Connection:
|
||||
except TypeError:
|
||||
raise err.OperationalError(
|
||||
CR.CR_AUTH_PLUGIN_CANNOT_LOAD,
|
||||
"Authentication plugin '%s'"
|
||||
" not loaded: - %r cannot be constructed with connection object"
|
||||
% (plugin_name, plugin_class),
|
||||
f"Authentication plugin '{plugin_name}'"
|
||||
f" not loaded: - {plugin_class!r} cannot be constructed with connection object",
|
||||
)
|
||||
else:
|
||||
handler = None
|
||||
@@ -1159,6 +1159,9 @@ class Connection:
|
||||
else:
|
||||
self._auth_plugin_name = data[i:server_end].decode("utf-8")
|
||||
|
||||
if _DEFAULT_AUTH_PLUGIN is not None: # for tests
|
||||
self._auth_plugin_name = _DEFAULT_AUTH_PLUGIN
|
||||
|
||||
def get_server_info(self):
|
||||
return self.server_version
|
||||
|
||||
@@ -1213,17 +1216,16 @@ class MySQLResult:
|
||||
:raise OperationalError: If the connection to the MySQL server is lost.
|
||||
:raise InternalError:
|
||||
"""
|
||||
self.unbuffered_active = True
|
||||
first_packet = self.connection._read_packet()
|
||||
|
||||
if first_packet.is_ok_packet():
|
||||
self.connection = None
|
||||
self._read_ok_packet(first_packet)
|
||||
self.unbuffered_active = False
|
||||
self.connection = None
|
||||
elif first_packet.is_load_local_packet():
|
||||
self._read_load_local_packet(first_packet)
|
||||
self.unbuffered_active = False
|
||||
self.connection = None
|
||||
try:
|
||||
self._read_load_local_packet(first_packet)
|
||||
finally:
|
||||
self.connection = None
|
||||
else:
|
||||
self.field_count = first_packet.read_length_encoded_integer()
|
||||
self._get_descriptions()
|
||||
@@ -1232,6 +1234,7 @@ class MySQLResult:
|
||||
# value of a 64bit unsigned integer. Since we're emulating MySQLdb,
|
||||
# we set it to this instead of None, which would be preferred.
|
||||
self.affected_rows = 18446744073709551615
|
||||
self.unbuffered_active = True
|
||||
|
||||
def _read_ok_packet(self, first_packet):
|
||||
ok_packet = OKPacketWrapper(first_packet)
|
||||
|
||||
Reference in New Issue
Block a user