This commit is contained in:
Iliyan Angelov
2025-12-01 06:50:10 +02:00
parent 91f51bc6fe
commit 62c1fe5951
4682 changed files with 544807 additions and 31208 deletions

View File

@@ -21,6 +21,7 @@ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
"""
import sys
from .constants import FIELD_TYPE
@@ -48,8 +49,8 @@ from .times import (
# PyMySQL version.
# Used by setuptools and connection_attrs
VERSION = (1, 1, 0, "final", 1)
VERSION_STRING = "1.1.0"
VERSION = (1, 1, 2, "final")
VERSION_STRING = "1.1.2"
### for mysqlclient compatibility
### Django checks mysqlclient version.

View File

@@ -1,6 +1,7 @@
"""
Implements auth methods
"""
from .err import OperationalError
@@ -165,6 +166,8 @@ def sha256_password_auth(conn, pkt):
if pkt.is_auth_switch_request():
conn.salt = pkt.read_all()
if conn.salt.endswith(b"\0"):
conn.salt = conn.salt[:-1]
if not conn.server_public_key and conn.password:
# Request server public key
if DEBUG:
@@ -214,9 +217,11 @@ def caching_sha2_password_auth(conn, pkt):
if pkt.is_auth_switch_request():
# Try from fast auth
if DEBUG:
print("caching sha2: Trying fast path")
conn.salt = pkt.read_all()
if conn.salt.endswith(b"\0"): # str.removesuffix is available in 3.9
conn.salt = conn.salt[:-1]
if DEBUG:
print(f"caching sha2: Trying fast path. salt={conn.salt.hex()!r}")
scrambled = scramble_caching_sha2(conn.password, conn.salt)
pkt = _roundtrip(conn, scrambled)
# else: fast auth is tried in initial handshake

View File

@@ -45,9 +45,10 @@ class Charsets:
return self._by_id[id]
def by_name(self, name):
name = name.lower()
if name == "utf8":
name = "utf8mb4"
return self._by_name.get(name.lower())
return self._by_name.get(name)
_charsets = Charsets()

View File

@@ -40,11 +40,14 @@ try:
DEFAULT_USER = getpass.getuser()
del getpass
except (ImportError, KeyError):
# KeyError occurs when there's no entry in OS database for a current user.
except (ImportError, KeyError, OSError):
# When there's no entry in OS database for a current user:
# KeyError is raised in Python 3.12 and below.
# OSError is raised in Python 3.13+
DEFAULT_USER = None
DEBUG = False
_DEFAULT_AUTH_PLUGIN = None # if this is not None, use it instead of server's default.
TEXT_TYPES = {
FIELD_TYPE.BIT,
@@ -84,8 +87,7 @@ def _lenenc_int(i):
return b"\xfe" + struct.pack("<Q", i)
else:
raise ValueError(
"Encoding %x is larger than %x - no representation in LengthEncodedInteger"
% (i, (1 << 64))
f"Encoding {i:x} is larger than {1 << 64:x} - no representation in LengthEncodedInteger"
)
@@ -135,6 +137,7 @@ class Connection:
:param ssl_disabled: A boolean value that disables usage of TLS.
:param ssl_key: Path to the file that contains a PEM-formatted private key for
the client certificate.
:param ssl_key_password: The password for the client certificate private key.
:param ssl_verify_cert: Set to true to check the server certificate's validity.
:param ssl_verify_identity: Set to true to check the server's identity.
:param read_default_group: Group to read from in the configuration file.
@@ -161,6 +164,7 @@ class Connection:
"""
_sock = None
_rfile = None
_auth_plugin_name = ""
_closed = False
_secure = False
@@ -201,6 +205,7 @@ class Connection:
ssl_cert=None,
ssl_disabled=None,
ssl_key=None,
ssl_key_password=None,
ssl_verify_cert=None,
ssl_verify_identity=None,
compress=None, # not supported
@@ -262,7 +267,7 @@ class Connection:
if not ssl:
ssl = {}
if isinstance(ssl, dict):
for key in ["ca", "capath", "cert", "key", "cipher"]:
for key in ["ca", "capath", "cert", "key", "password", "cipher"]:
value = _config("ssl-" + key, ssl.get(key))
if value:
ssl[key] = value
@@ -281,6 +286,8 @@ class Connection:
ssl["cert"] = ssl_cert
if ssl_key is not None:
ssl["key"] = ssl_key
if ssl_key_password is not None:
ssl["password"] = ssl_key_password
if ssl:
if not SSL_ENABLED:
raise NotImplementedError("ssl module not found")
@@ -371,6 +378,12 @@ class Connection:
capath = sslp.get("capath")
hasnoca = ca is None and capath is None
ctx = ssl.create_default_context(cafile=ca, capath=capath)
# Python 3.13 enables VERIFY_X509_STRICT by default.
# But self signed certificates that are generated by MySQL automatically
# doesn't pass the verification.
ctx.verify_flags &= ~ssl.VERIFY_X509_STRICT
ctx.check_hostname = not hasnoca and sslp.get("check_hostname", True)
verify_mode_value = sslp.get("verify_mode")
if verify_mode_value is None:
@@ -389,7 +402,9 @@ class Connection:
else:
ctx.verify_mode = ssl.CERT_NONE if hasnoca else ssl.CERT_REQUIRED
if "cert" in sslp:
ctx.load_cert_chain(sslp["cert"], keyfile=sslp.get("key"))
ctx.load_cert_chain(
sslp["cert"], keyfile=sslp.get("key"), password=sslp.get("password")
)
if "cipher" in sslp:
ctx.set_ciphers(sslp["cipher"])
ctx.options |= ssl.OP_NO_SSLv2
@@ -425,6 +440,8 @@ class Connection:
def _force_close(self):
"""Close connection without QUIT message."""
if self._rfile:
self._rfile.close()
if self._sock:
try:
self._sock.close()
@@ -566,9 +583,9 @@ class Connection:
return self._affected_rows
def kill(self, thread_id):
arg = struct.pack("<I", thread_id)
self._execute_command(COMMAND.COM_PROCESS_KILL, arg)
return self._read_ok_packet()
if not isinstance(thread_id, int):
raise TypeError("thread_id must be an integer")
self.query(f"KILL {thread_id:d}")
def ping(self, reconnect=True):
"""
@@ -691,12 +708,7 @@ class Connection:
if self.autocommit_mode is not None:
self.autocommit(self.autocommit_mode)
except BaseException as e:
self._rfile = None
if sock is not None:
try:
sock.close()
except: # noqa
pass
self._force_close()
if isinstance(e, (OSError, IOError)):
exc = err.OperationalError(
@@ -760,8 +772,6 @@ class Connection:
dump_packet(recv_data)
buff += recv_data
# https://dev.mysql.com/doc/internals/en/sending-more-than-16mbyte.html
if bytes_to_read == 0xFFFFFF:
continue
if bytes_to_read < MAX_PACKET_LEN:
break
@@ -809,16 +819,10 @@ class Connection:
def _read_query_result(self, unbuffered=False):
self._result = None
result = MySQLResult(self)
if unbuffered:
try:
result = MySQLResult(self)
result.init_unbuffered_query()
except:
result.unbuffered_active = False
result.connection = None
raise
result.init_unbuffered_query()
else:
result = MySQLResult(self)
result.read()
self._result = result
if result.server_status is not None:
@@ -993,9 +997,8 @@ class Connection:
if plugin_name != b"dialog":
raise err.OperationalError(
CR.CR_AUTH_PLUGIN_CANNOT_LOAD,
"Authentication plugin '%s'"
" not loaded: - %r missing authenticate method"
% (plugin_name, type(handler)),
f"Authentication plugin '{plugin_name}'"
f" not loaded: - {type(handler)!r} missing authenticate method",
)
if plugin_name == b"caching_sha2_password":
return _auth.caching_sha2_password_auth(self, auth_packet)
@@ -1031,16 +1034,14 @@ class Connection:
except AttributeError:
raise err.OperationalError(
CR.CR_AUTH_PLUGIN_CANNOT_LOAD,
"Authentication plugin '%s'"
" not loaded: - %r missing prompt method"
% (plugin_name, handler),
f"Authentication plugin '{plugin_name}'"
f" not loaded: - {handler!r} missing prompt method",
)
except TypeError:
raise err.OperationalError(
CR.CR_AUTH_PLUGIN_ERR,
"Authentication plugin '%s'"
" %r didn't respond with string. Returned '%r' to prompt %r"
% (plugin_name, handler, resp, prompt),
f"Authentication plugin '{plugin_name}'"
f" {handler!r} didn't respond with string. Returned '{resp!r}' to prompt {prompt!r}",
)
else:
raise err.OperationalError(
@@ -1073,9 +1074,8 @@ class Connection:
except TypeError:
raise err.OperationalError(
CR.CR_AUTH_PLUGIN_CANNOT_LOAD,
"Authentication plugin '%s'"
" not loaded: - %r cannot be constructed with connection object"
% (plugin_name, plugin_class),
f"Authentication plugin '{plugin_name}'"
f" not loaded: - {plugin_class!r} cannot be constructed with connection object",
)
else:
handler = None
@@ -1159,6 +1159,9 @@ class Connection:
else:
self._auth_plugin_name = data[i:server_end].decode("utf-8")
if _DEFAULT_AUTH_PLUGIN is not None: # for tests
self._auth_plugin_name = _DEFAULT_AUTH_PLUGIN
def get_server_info(self):
return self.server_version
@@ -1213,17 +1216,16 @@ class MySQLResult:
:raise OperationalError: If the connection to the MySQL server is lost.
:raise InternalError:
"""
self.unbuffered_active = True
first_packet = self.connection._read_packet()
if first_packet.is_ok_packet():
self.connection = None
self._read_ok_packet(first_packet)
self.unbuffered_active = False
self.connection = None
elif first_packet.is_load_local_packet():
self._read_load_local_packet(first_packet)
self.unbuffered_active = False
self.connection = None
try:
self._read_load_local_packet(first_packet)
finally:
self.connection = None
else:
self.field_count = first_packet.read_length_encoded_integer()
self._get_descriptions()
@@ -1232,6 +1234,7 @@ class MySQLResult:
# value of a 64bit unsigned integer. Since we're emulating MySQLdb,
# we set it to this instead of None, which would be preferred.
self.affected_rows = 18446744073709551615
self.unbuffered_active = True
def _read_ok_packet(self, first_packet):
ok_packet = OKPacketWrapper(first_packet)

View File

@@ -27,11 +27,7 @@ def escape_item(val, charset, mapping=None):
def escape_dict(val, charset, mapping=None):
n = {}
for k, v in val.items():
quoted = escape_item(v, charset, mapping)
n[k] = quoted
return n
raise TypeError("dict can not be used as parameter")
def escape_sequence(val, charset, mapping=None):

View File

@@ -136,7 +136,14 @@ del _map_error, ER
def raise_mysql_exception(data):
errno = struct.unpack("<h", data[1:3])[0]
errval = data[9:].decode("utf-8", "replace")
# https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_basic_err_packet.html
# Error packet has optional sqlstate that is 5 bytes and starts with '#'.
if data[3] == 0x23: # '#'
# sqlstate = data[4:9].decode()
# TODO: Append (sqlstate) in the error message. This will be come in next minor release.
errval = data[9:].decode("utf-8", "replace")
else:
errval = data[3:].decode("utf-8", "replace")
errorclass = error_map.get(errno)
if errorclass is None:
errorclass = InternalError if errno < 1000 else OperationalError

View File

@@ -65,8 +65,7 @@ class MysqlPacket:
if len(result) != size:
error = (
"Result length not requested length:\n"
"Expected=%s. Actual=%s. Position: %s. Data Length: %s"
% (size, len(result), self._position, len(self._data))
f"Expected={size}. Actual={len(result)}. Position: {self._position}. Data Length: {len(self._data)}"
)
if DEBUG:
print(error)
@@ -89,8 +88,7 @@ class MysqlPacket:
new_position = self._position + length
if new_position < 0 or new_position > len(self._data):
raise Exception(
"Invalid advance amount (%s) for cursor. "
"Position=%s" % (length, new_position)
f"Invalid advance amount ({length}) for cursor. Position={new_position}"
)
self._position = new_position