updates
This commit is contained in:
@@ -21,6 +21,7 @@ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
THE SOFTWARE.
|
||||
"""
|
||||
|
||||
import sys
|
||||
|
||||
from .constants import FIELD_TYPE
|
||||
@@ -48,8 +49,8 @@ from .times import (
|
||||
|
||||
# PyMySQL version.
|
||||
# Used by setuptools and connection_attrs
|
||||
VERSION = (1, 1, 0, "final", 1)
|
||||
VERSION_STRING = "1.1.0"
|
||||
VERSION = (1, 1, 2, "final")
|
||||
VERSION_STRING = "1.1.2"
|
||||
|
||||
### for mysqlclient compatibility
|
||||
### Django checks mysqlclient version.
|
||||
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -1,6 +1,7 @@
|
||||
"""
|
||||
Implements auth methods
|
||||
"""
|
||||
|
||||
from .err import OperationalError
|
||||
|
||||
|
||||
@@ -165,6 +166,8 @@ def sha256_password_auth(conn, pkt):
|
||||
|
||||
if pkt.is_auth_switch_request():
|
||||
conn.salt = pkt.read_all()
|
||||
if conn.salt.endswith(b"\0"):
|
||||
conn.salt = conn.salt[:-1]
|
||||
if not conn.server_public_key and conn.password:
|
||||
# Request server public key
|
||||
if DEBUG:
|
||||
@@ -214,9 +217,11 @@ def caching_sha2_password_auth(conn, pkt):
|
||||
|
||||
if pkt.is_auth_switch_request():
|
||||
# Try from fast auth
|
||||
if DEBUG:
|
||||
print("caching sha2: Trying fast path")
|
||||
conn.salt = pkt.read_all()
|
||||
if conn.salt.endswith(b"\0"): # str.removesuffix is available in 3.9
|
||||
conn.salt = conn.salt[:-1]
|
||||
if DEBUG:
|
||||
print(f"caching sha2: Trying fast path. salt={conn.salt.hex()!r}")
|
||||
scrambled = scramble_caching_sha2(conn.password, conn.salt)
|
||||
pkt = _roundtrip(conn, scrambled)
|
||||
# else: fast auth is tried in initial handshake
|
||||
|
||||
@@ -45,9 +45,10 @@ class Charsets:
|
||||
return self._by_id[id]
|
||||
|
||||
def by_name(self, name):
|
||||
name = name.lower()
|
||||
if name == "utf8":
|
||||
name = "utf8mb4"
|
||||
return self._by_name.get(name.lower())
|
||||
return self._by_name.get(name)
|
||||
|
||||
|
||||
_charsets = Charsets()
|
||||
|
||||
@@ -40,11 +40,14 @@ try:
|
||||
|
||||
DEFAULT_USER = getpass.getuser()
|
||||
del getpass
|
||||
except (ImportError, KeyError):
|
||||
# KeyError occurs when there's no entry in OS database for a current user.
|
||||
except (ImportError, KeyError, OSError):
|
||||
# When there's no entry in OS database for a current user:
|
||||
# KeyError is raised in Python 3.12 and below.
|
||||
# OSError is raised in Python 3.13+
|
||||
DEFAULT_USER = None
|
||||
|
||||
DEBUG = False
|
||||
_DEFAULT_AUTH_PLUGIN = None # if this is not None, use it instead of server's default.
|
||||
|
||||
TEXT_TYPES = {
|
||||
FIELD_TYPE.BIT,
|
||||
@@ -84,8 +87,7 @@ def _lenenc_int(i):
|
||||
return b"\xfe" + struct.pack("<Q", i)
|
||||
else:
|
||||
raise ValueError(
|
||||
"Encoding %x is larger than %x - no representation in LengthEncodedInteger"
|
||||
% (i, (1 << 64))
|
||||
f"Encoding {i:x} is larger than {1 << 64:x} - no representation in LengthEncodedInteger"
|
||||
)
|
||||
|
||||
|
||||
@@ -135,6 +137,7 @@ class Connection:
|
||||
:param ssl_disabled: A boolean value that disables usage of TLS.
|
||||
:param ssl_key: Path to the file that contains a PEM-formatted private key for
|
||||
the client certificate.
|
||||
:param ssl_key_password: The password for the client certificate private key.
|
||||
:param ssl_verify_cert: Set to true to check the server certificate's validity.
|
||||
:param ssl_verify_identity: Set to true to check the server's identity.
|
||||
:param read_default_group: Group to read from in the configuration file.
|
||||
@@ -161,6 +164,7 @@ class Connection:
|
||||
"""
|
||||
|
||||
_sock = None
|
||||
_rfile = None
|
||||
_auth_plugin_name = ""
|
||||
_closed = False
|
||||
_secure = False
|
||||
@@ -201,6 +205,7 @@ class Connection:
|
||||
ssl_cert=None,
|
||||
ssl_disabled=None,
|
||||
ssl_key=None,
|
||||
ssl_key_password=None,
|
||||
ssl_verify_cert=None,
|
||||
ssl_verify_identity=None,
|
||||
compress=None, # not supported
|
||||
@@ -262,7 +267,7 @@ class Connection:
|
||||
if not ssl:
|
||||
ssl = {}
|
||||
if isinstance(ssl, dict):
|
||||
for key in ["ca", "capath", "cert", "key", "cipher"]:
|
||||
for key in ["ca", "capath", "cert", "key", "password", "cipher"]:
|
||||
value = _config("ssl-" + key, ssl.get(key))
|
||||
if value:
|
||||
ssl[key] = value
|
||||
@@ -281,6 +286,8 @@ class Connection:
|
||||
ssl["cert"] = ssl_cert
|
||||
if ssl_key is not None:
|
||||
ssl["key"] = ssl_key
|
||||
if ssl_key_password is not None:
|
||||
ssl["password"] = ssl_key_password
|
||||
if ssl:
|
||||
if not SSL_ENABLED:
|
||||
raise NotImplementedError("ssl module not found")
|
||||
@@ -371,6 +378,12 @@ class Connection:
|
||||
capath = sslp.get("capath")
|
||||
hasnoca = ca is None and capath is None
|
||||
ctx = ssl.create_default_context(cafile=ca, capath=capath)
|
||||
|
||||
# Python 3.13 enables VERIFY_X509_STRICT by default.
|
||||
# But self signed certificates that are generated by MySQL automatically
|
||||
# doesn't pass the verification.
|
||||
ctx.verify_flags &= ~ssl.VERIFY_X509_STRICT
|
||||
|
||||
ctx.check_hostname = not hasnoca and sslp.get("check_hostname", True)
|
||||
verify_mode_value = sslp.get("verify_mode")
|
||||
if verify_mode_value is None:
|
||||
@@ -389,7 +402,9 @@ class Connection:
|
||||
else:
|
||||
ctx.verify_mode = ssl.CERT_NONE if hasnoca else ssl.CERT_REQUIRED
|
||||
if "cert" in sslp:
|
||||
ctx.load_cert_chain(sslp["cert"], keyfile=sslp.get("key"))
|
||||
ctx.load_cert_chain(
|
||||
sslp["cert"], keyfile=sslp.get("key"), password=sslp.get("password")
|
||||
)
|
||||
if "cipher" in sslp:
|
||||
ctx.set_ciphers(sslp["cipher"])
|
||||
ctx.options |= ssl.OP_NO_SSLv2
|
||||
@@ -425,6 +440,8 @@ class Connection:
|
||||
|
||||
def _force_close(self):
|
||||
"""Close connection without QUIT message."""
|
||||
if self._rfile:
|
||||
self._rfile.close()
|
||||
if self._sock:
|
||||
try:
|
||||
self._sock.close()
|
||||
@@ -566,9 +583,9 @@ class Connection:
|
||||
return self._affected_rows
|
||||
|
||||
def kill(self, thread_id):
|
||||
arg = struct.pack("<I", thread_id)
|
||||
self._execute_command(COMMAND.COM_PROCESS_KILL, arg)
|
||||
return self._read_ok_packet()
|
||||
if not isinstance(thread_id, int):
|
||||
raise TypeError("thread_id must be an integer")
|
||||
self.query(f"KILL {thread_id:d}")
|
||||
|
||||
def ping(self, reconnect=True):
|
||||
"""
|
||||
@@ -691,12 +708,7 @@ class Connection:
|
||||
if self.autocommit_mode is not None:
|
||||
self.autocommit(self.autocommit_mode)
|
||||
except BaseException as e:
|
||||
self._rfile = None
|
||||
if sock is not None:
|
||||
try:
|
||||
sock.close()
|
||||
except: # noqa
|
||||
pass
|
||||
self._force_close()
|
||||
|
||||
if isinstance(e, (OSError, IOError)):
|
||||
exc = err.OperationalError(
|
||||
@@ -760,8 +772,6 @@ class Connection:
|
||||
dump_packet(recv_data)
|
||||
buff += recv_data
|
||||
# https://dev.mysql.com/doc/internals/en/sending-more-than-16mbyte.html
|
||||
if bytes_to_read == 0xFFFFFF:
|
||||
continue
|
||||
if bytes_to_read < MAX_PACKET_LEN:
|
||||
break
|
||||
|
||||
@@ -809,16 +819,10 @@ class Connection:
|
||||
|
||||
def _read_query_result(self, unbuffered=False):
|
||||
self._result = None
|
||||
result = MySQLResult(self)
|
||||
if unbuffered:
|
||||
try:
|
||||
result = MySQLResult(self)
|
||||
result.init_unbuffered_query()
|
||||
except:
|
||||
result.unbuffered_active = False
|
||||
result.connection = None
|
||||
raise
|
||||
result.init_unbuffered_query()
|
||||
else:
|
||||
result = MySQLResult(self)
|
||||
result.read()
|
||||
self._result = result
|
||||
if result.server_status is not None:
|
||||
@@ -993,9 +997,8 @@ class Connection:
|
||||
if plugin_name != b"dialog":
|
||||
raise err.OperationalError(
|
||||
CR.CR_AUTH_PLUGIN_CANNOT_LOAD,
|
||||
"Authentication plugin '%s'"
|
||||
" not loaded: - %r missing authenticate method"
|
||||
% (plugin_name, type(handler)),
|
||||
f"Authentication plugin '{plugin_name}'"
|
||||
f" not loaded: - {type(handler)!r} missing authenticate method",
|
||||
)
|
||||
if plugin_name == b"caching_sha2_password":
|
||||
return _auth.caching_sha2_password_auth(self, auth_packet)
|
||||
@@ -1031,16 +1034,14 @@ class Connection:
|
||||
except AttributeError:
|
||||
raise err.OperationalError(
|
||||
CR.CR_AUTH_PLUGIN_CANNOT_LOAD,
|
||||
"Authentication plugin '%s'"
|
||||
" not loaded: - %r missing prompt method"
|
||||
% (plugin_name, handler),
|
||||
f"Authentication plugin '{plugin_name}'"
|
||||
f" not loaded: - {handler!r} missing prompt method",
|
||||
)
|
||||
except TypeError:
|
||||
raise err.OperationalError(
|
||||
CR.CR_AUTH_PLUGIN_ERR,
|
||||
"Authentication plugin '%s'"
|
||||
" %r didn't respond with string. Returned '%r' to prompt %r"
|
||||
% (plugin_name, handler, resp, prompt),
|
||||
f"Authentication plugin '{plugin_name}'"
|
||||
f" {handler!r} didn't respond with string. Returned '{resp!r}' to prompt {prompt!r}",
|
||||
)
|
||||
else:
|
||||
raise err.OperationalError(
|
||||
@@ -1073,9 +1074,8 @@ class Connection:
|
||||
except TypeError:
|
||||
raise err.OperationalError(
|
||||
CR.CR_AUTH_PLUGIN_CANNOT_LOAD,
|
||||
"Authentication plugin '%s'"
|
||||
" not loaded: - %r cannot be constructed with connection object"
|
||||
% (plugin_name, plugin_class),
|
||||
f"Authentication plugin '{plugin_name}'"
|
||||
f" not loaded: - {plugin_class!r} cannot be constructed with connection object",
|
||||
)
|
||||
else:
|
||||
handler = None
|
||||
@@ -1159,6 +1159,9 @@ class Connection:
|
||||
else:
|
||||
self._auth_plugin_name = data[i:server_end].decode("utf-8")
|
||||
|
||||
if _DEFAULT_AUTH_PLUGIN is not None: # for tests
|
||||
self._auth_plugin_name = _DEFAULT_AUTH_PLUGIN
|
||||
|
||||
def get_server_info(self):
|
||||
return self.server_version
|
||||
|
||||
@@ -1213,17 +1216,16 @@ class MySQLResult:
|
||||
:raise OperationalError: If the connection to the MySQL server is lost.
|
||||
:raise InternalError:
|
||||
"""
|
||||
self.unbuffered_active = True
|
||||
first_packet = self.connection._read_packet()
|
||||
|
||||
if first_packet.is_ok_packet():
|
||||
self.connection = None
|
||||
self._read_ok_packet(first_packet)
|
||||
self.unbuffered_active = False
|
||||
self.connection = None
|
||||
elif first_packet.is_load_local_packet():
|
||||
self._read_load_local_packet(first_packet)
|
||||
self.unbuffered_active = False
|
||||
self.connection = None
|
||||
try:
|
||||
self._read_load_local_packet(first_packet)
|
||||
finally:
|
||||
self.connection = None
|
||||
else:
|
||||
self.field_count = first_packet.read_length_encoded_integer()
|
||||
self._get_descriptions()
|
||||
@@ -1232,6 +1234,7 @@ class MySQLResult:
|
||||
# value of a 64bit unsigned integer. Since we're emulating MySQLdb,
|
||||
# we set it to this instead of None, which would be preferred.
|
||||
self.affected_rows = 18446744073709551615
|
||||
self.unbuffered_active = True
|
||||
|
||||
def _read_ok_packet(self, first_packet):
|
||||
ok_packet = OKPacketWrapper(first_packet)
|
||||
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -27,11 +27,7 @@ def escape_item(val, charset, mapping=None):
|
||||
|
||||
|
||||
def escape_dict(val, charset, mapping=None):
|
||||
n = {}
|
||||
for k, v in val.items():
|
||||
quoted = escape_item(v, charset, mapping)
|
||||
n[k] = quoted
|
||||
return n
|
||||
raise TypeError("dict can not be used as parameter")
|
||||
|
||||
|
||||
def escape_sequence(val, charset, mapping=None):
|
||||
|
||||
@@ -136,7 +136,14 @@ del _map_error, ER
|
||||
|
||||
def raise_mysql_exception(data):
|
||||
errno = struct.unpack("<h", data[1:3])[0]
|
||||
errval = data[9:].decode("utf-8", "replace")
|
||||
# https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_basic_err_packet.html
|
||||
# Error packet has optional sqlstate that is 5 bytes and starts with '#'.
|
||||
if data[3] == 0x23: # '#'
|
||||
# sqlstate = data[4:9].decode()
|
||||
# TODO: Append (sqlstate) in the error message. This will be come in next minor release.
|
||||
errval = data[9:].decode("utf-8", "replace")
|
||||
else:
|
||||
errval = data[3:].decode("utf-8", "replace")
|
||||
errorclass = error_map.get(errno)
|
||||
if errorclass is None:
|
||||
errorclass = InternalError if errno < 1000 else OperationalError
|
||||
|
||||
@@ -65,8 +65,7 @@ class MysqlPacket:
|
||||
if len(result) != size:
|
||||
error = (
|
||||
"Result length not requested length:\n"
|
||||
"Expected=%s. Actual=%s. Position: %s. Data Length: %s"
|
||||
% (size, len(result), self._position, len(self._data))
|
||||
f"Expected={size}. Actual={len(result)}. Position: {self._position}. Data Length: {len(self._data)}"
|
||||
)
|
||||
if DEBUG:
|
||||
print(error)
|
||||
@@ -89,8 +88,7 @@ class MysqlPacket:
|
||||
new_position = self._position + length
|
||||
if new_position < 0 or new_position > len(self._data):
|
||||
raise Exception(
|
||||
"Invalid advance amount (%s) for cursor. "
|
||||
"Position=%s" % (length, new_position)
|
||||
f"Invalid advance amount ({length}) for cursor. Position={new_position}"
|
||||
)
|
||||
self._position = new_position
|
||||
|
||||
|
||||
Reference in New Issue
Block a user