From 9e0b4181f733333dba4d7262c04c5ca04a68b026 Mon Sep 17 00:00:00 2001 From: Gaurav Sharma Date: Tue, 12 May 2026 17:38:11 +0530 Subject: [PATCH 1/7] Add ActiveDirectoryMSI support for bulk copy Adds Authentication=ActiveDirectoryMSI to the auth pipeline: - Zero-arg ManagedIdentityCredential() for system-assigned MSI. - ManagedIdentityCredential(client_id=UID) for user-assigned MSI, matching ODBC's convention where UID carries the identity's client_id under MSI. - Threads optional credential_kwargs through get_auth_token / get_raw_token / _acquire_token so future auth methods that need constructor args (e.g. ClientSecretCredential) can plug in via the same channel. - Cache key remains a plain string for zero-arg auth types and becomes a tuple when kwargs are present, so different client_ids get separate cached credentials. Partial fix for microsoft/mssql-python#534. ServicePrincipal and Password to follow as separate PRs. --- CHANGELOG.md | 1 + mssql_python/auth.py | 91 ++++++++++++++++++++++++++---- mssql_python/constants.py | 1 + mssql_python/cursor.py | 9 ++- tests/test_008_auth.py | 113 ++++++++++++++++++++++++++++++++++++++ 5 files changed, 201 insertions(+), 14 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 517a60bfc..14d42ee46 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ### Added - New feature: Support for macOS and Linux. - Documentation: Added API documentation in the Wiki. +- Bulk copy: `Authentication=ActiveDirectoryMSI` support (system- and user-assigned managed identity). UID is interpreted as the user-assigned identity's `client_id`. Partial fix for #534. ### Changed - Improved error handling in the connection module. diff --git a/mssql_python/auth.py b/mssql_python/auth.py index 40c3e06e2..b1dc7ddd4 100644 --- a/mssql_python/auth.py +++ b/mssql_python/auth.py @@ -16,10 +16,24 @@ # Reusing credential objects allows the Azure Identity SDK's built-in # in-memory token cache to work, avoiding redundant token acquisitions. # See: https://github.com/Azure/azure-sdk-for-python/blob/main/sdk/identity/azure-identity/TOKEN_CACHING.md -_credential_cache: Dict[str, object] = {} +_credential_cache: Dict[object, object] = {} _credential_cache_lock = threading.Lock() +def _credential_cache_key(auth_type: str, credential_kwargs: Optional[Dict[str, str]]): + """Build a hashable cache key from auth_type and optional credential kwargs. + + Returns the plain auth_type string when no kwargs are provided so that + callers caching by string (the original behavior) keep working. When + kwargs are present (e.g. user-assigned MSI client_id), the key is a + tuple of ``(auth_type, sorted_kwargs_items)`` so different kwargs map + to different cached credentials. + """ + if not credential_kwargs: + return auth_type + return (auth_type, tuple(sorted(credential_kwargs.items()))) + + class AADAuth: """Handles Azure Active Directory authentication""" @@ -37,24 +51,30 @@ def get_token_struct(token: str) -> bytes: return struct.pack(f" bytes: + def get_token( + auth_type: str, credential_kwargs: Optional[Dict[str, str]] = None + ) -> bytes: """Get DDBC token struct for the specified authentication type.""" - token_struct, _ = AADAuth._acquire_token(auth_type) + token_struct, _ = AADAuth._acquire_token(auth_type, credential_kwargs) return token_struct @staticmethod - def get_raw_token(auth_type: str) -> str: + def get_raw_token( + auth_type: str, credential_kwargs: Optional[Dict[str, str]] = None + ) -> str: """Acquire a raw JWT for the mssql-py-core connection (bulk copy). Uses the cached credential instance so the Azure Identity SDK's built-in token cache can serve a valid token without a round-trip when the previous token has not yet expired. """ - _, raw_token = AADAuth._acquire_token(auth_type) + _, raw_token = AADAuth._acquire_token(auth_type, credential_kwargs) return raw_token @staticmethod - def _acquire_token(auth_type: str) -> Tuple[bytes, str]: + def _acquire_token( + auth_type: str, credential_kwargs: Optional[Dict[str, str]] = None + ) -> Tuple[bytes, str]: """Internal: acquire token and return (ddbc_struct, raw_jwt).""" # Import Azure libraries inside method to support test mocking # pylint: disable=import-outside-toplevel @@ -63,6 +83,7 @@ def _acquire_token(auth_type: str) -> Tuple[bytes, str]: DefaultAzureCredential, DeviceCodeCredential, InteractiveBrowserCredential, + ManagedIdentityCredential, ) from azure.core.exceptions import ClientAuthenticationError except ImportError as e: @@ -76,6 +97,7 @@ def _acquire_token(auth_type: str) -> Tuple[bytes, str]: "default": DefaultAzureCredential, "devicecode": DeviceCodeCredential, "interactive": InteractiveBrowserCredential, + "msi": ManagedIdentityCredential, } credential_class = credential_map.get(auth_type) @@ -89,20 +111,22 @@ def _acquire_token(auth_type: str) -> Tuple[bytes, str]: credential_class.__name__, ) + kwargs = credential_kwargs or {} + cache_key = _credential_cache_key(auth_type, kwargs) try: with _credential_cache_lock: - if auth_type not in _credential_cache: + if cache_key not in _credential_cache: logger.debug( "get_token: Creating new credential instance for auth_type=%s", auth_type, ) - _credential_cache[auth_type] = credential_class() + _credential_cache[cache_key] = credential_class(**kwargs) else: logger.debug( "get_token: Reusing cached credential instance for auth_type=%s", auth_type, ) - credential = _credential_cache[auth_type] + credential = _credential_cache[cache_key] raw_token = credential.get_token("https://database.windows.net/.default").token logger.info( "get_token: Azure AD token acquired successfully - token_length=%d chars", @@ -130,6 +154,20 @@ def _acquire_token(auth_type: str) -> Tuple[bytes, str]: raise RuntimeError(f"Failed to create {credential_class.__name__}: {e}") from e +def _extract_msi_client_id(parameters: List[str]) -> Optional[str]: + """Pull UID out of connection parameters for user-assigned MSI. + + For ActiveDirectoryMSI, UID (when present) carries the user-assigned + identity's client_id. Returns None for system-assigned MSI. + """ + for param in parameters: + key, _, value = param.strip().partition("=") + if key.strip().lower() == "uid": + value = value.strip() + return value or None + return None + + def process_auth_parameters(parameters: List[str]) -> Tuple[List[str], Optional[str]]: """ Process connection parameters and extract authentication type. @@ -180,6 +218,10 @@ def process_auth_parameters(parameters: List[str]) -> Tuple[List[str], Optional[ # Default authentication (uses DefaultAzureCredential) logger.debug("process_auth_parameters: Default Azure authentication detected") auth_type = "default" + elif value_lower == AuthType.MSI.value: + # Managed identity authentication (system- or user-assigned) + logger.debug("process_auth_parameters: Managed identity authentication detected") + auth_type = "msi" modified_parameters.append(param) logger.debug( @@ -212,7 +254,9 @@ def remove_sensitive_params(parameters: List[str]) -> List[str]: return result -def get_auth_token(auth_type: str) -> Optional[bytes]: +def get_auth_token( + auth_type: str, credential_kwargs: Optional[Dict[str, str]] = None +) -> Optional[bytes]: """Get DDBC authentication token struct based on auth type.""" logger.debug("get_auth_token: Starting - auth_type=%s", auth_type) if not auth_type: @@ -225,7 +269,7 @@ def get_auth_token(auth_type: str) -> Optional[bytes]: return None # Let Windows handle AADInteractive natively try: - token = AADAuth.get_token(auth_type) + token = AADAuth.get_token(auth_type, credential_kwargs) logger.info("get_auth_token: Token acquired successfully - auth_type=%s", auth_type) return token except (ValueError, RuntimeError) as e: @@ -246,6 +290,7 @@ def extract_auth_type(connection_string: str) -> Optional[str]: AuthType.INTERACTIVE.value: "interactive", AuthType.DEVICE_CODE.value: "devicecode", AuthType.DEFAULT.value: "default", + AuthType.MSI.value: "msi", } for part in connection_string.split(";"): key, _, value = part.strip().partition("=") @@ -254,6 +299,20 @@ def extract_auth_type(connection_string: str) -> Optional[str]: return None +def extract_credential_kwargs( + connection_string: str, auth_type: Optional[str] +) -> Dict[str, str]: + """Extract credential constructor kwargs for the given auth type. + + For ActiveDirectoryMSI: returns ``{"client_id": uid}`` when UID is + set (user-assigned MSI) and ``{}`` for system-assigned MSI. + """ + if auth_type != "msi": + return {} + client_id = _extract_msi_client_id(connection_string.split(";")) + return {"client_id": client_id} if client_id else {} + + def process_connection_string( connection_string: str, ) -> Tuple[str, Optional[Dict[int, bytes]], Optional[str]]: @@ -301,12 +360,20 @@ def process_connection_string( modified_parameters, auth_type = process_auth_parameters(parameters) + # Capture credential kwargs (e.g. user-assigned MSI client_id) before + # remove_sensitive_params strips UID from the parameter list. + credential_kwargs: Dict[str, str] = {} + if auth_type == "msi": + client_id = _extract_msi_client_id(modified_parameters) + if client_id: + credential_kwargs["client_id"] = client_id + if auth_type: logger.info( "process_connection_string: Authentication type detected - auth_type=%s", auth_type ) modified_parameters = remove_sensitive_params(modified_parameters) - token_struct = get_auth_token(auth_type) + token_struct = get_auth_token(auth_type, credential_kwargs or None) if token_struct: logger.info( "process_connection_string: Token authentication configured successfully - auth_type=%s", diff --git a/mssql_python/constants.py b/mssql_python/constants.py index 549737c60..5de02eceb 100644 --- a/mssql_python/constants.py +++ b/mssql_python/constants.py @@ -337,6 +337,7 @@ class AuthType(Enum): INTERACTIVE = "activedirectoryinteractive" DEVICE_CODE = "activedirectorydevicecode" DEFAULT = "activedirectorydefault" + MSI = "activedirectorymsi" class SQLTypes: diff --git a/mssql_python/cursor.py b/mssql_python/cursor.py index 05324875e..d53e2f28e 100644 --- a/mssql_python/cursor.py +++ b/mssql_python/cursor.py @@ -2912,10 +2912,15 @@ def bulkcopy( # Token acquisition — only thing cursor must handle (needs azure-identity SDK) if self.connection._auth_type: # Fresh token acquisition for mssql-py-core connection - from mssql_python.auth import AADAuth + from mssql_python.auth import AADAuth, extract_credential_kwargs + credential_kwargs = extract_credential_kwargs( + self.connection.connection_str, self.connection._auth_type + ) try: - raw_token = AADAuth.get_raw_token(self.connection._auth_type) + raw_token = AADAuth.get_raw_token( + self.connection._auth_type, credential_kwargs or None + ) except (RuntimeError, ValueError) as e: raise RuntimeError( f"Bulk copy failed: unable to acquire Azure AD token " diff --git a/tests/test_008_auth.py b/tests/test_008_auth.py index f680518bc..8154e6b0c 100644 --- a/tests/test_008_auth.py +++ b/tests/test_008_auth.py @@ -16,6 +16,7 @@ get_auth_token, process_connection_string, extract_auth_type, + extract_credential_kwargs, _credential_cache, _credential_cache_lock, ) @@ -44,6 +45,17 @@ class MockInteractiveBrowserCredential: def get_token(self, scope): return MockToken() + class MockManagedIdentityCredential: + # Captures construction kwargs so user-assigned MSI tests can assert + # client_id was forwarded correctly. + last_init_kwargs = None + + def __init__(self, **kwargs): + MockManagedIdentityCredential.last_init_kwargs = kwargs + + def get_token(self, scope): + return MockToken() + # Mock ClientAuthenticationError class MockClientAuthenticationError(Exception): pass @@ -52,6 +64,7 @@ class MockIdentity: DefaultAzureCredential = MockDefaultAzureCredential DeviceCodeCredential = MockDeviceCodeCredential InteractiveBrowserCredential = MockInteractiveBrowserCredential + ManagedIdentityCredential = MockManagedIdentityCredential class MockCore: class exceptions: @@ -87,6 +100,7 @@ def test_auth_type_constants(self): assert AuthType.INTERACTIVE.value == "activedirectoryinteractive" assert AuthType.DEVICE_CODE.value == "activedirectorydevicecode" assert AuthType.DEFAULT.value == "activedirectorydefault" + assert AuthType.MSI.value == "activedirectorymsi" class TestAADAuth: @@ -317,6 +331,16 @@ def test_default_auth(self): _, auth_type = process_auth_parameters(params) assert auth_type == "default" + def test_msi_auth(self): + params = ["Authentication=ActiveDirectoryMSI", "Server=test"] + _, auth_type = process_auth_parameters(params) + assert auth_type == "msi" + + def test_msi_auth_case_insensitive(self): + params = ["authentication=activedirectorymsi", "Server=test"] + _, auth_type = process_auth_parameters(params) + assert auth_type == "msi" + class TestRemoveSensitiveParams: def test_remove_sensitive_parameters(self): @@ -407,6 +431,9 @@ def test_devicecode(self): == "devicecode" ) + def test_msi(self): + assert extract_auth_type("Server=test;Authentication=ActiveDirectoryMSI;") == "msi" + def test_no_auth(self): assert extract_auth_type("Server=test;Database=db;") is None @@ -414,6 +441,92 @@ def test_unsupported_auth(self): assert extract_auth_type("Server=test;Authentication=SqlPassword;") is None +class TestManagedIdentity: + """Tests for ActiveDirectoryMSI support (system- and user-assigned).""" + + def test_get_token_system_assigned_msi(self): + """System-assigned MSI: ManagedIdentityCredential() constructed with no kwargs.""" + az = sys.modules["azure.identity"] + + az.ManagedIdentityCredential.last_init_kwargs = None + token_struct = AADAuth.get_token("msi") + assert isinstance(token_struct, bytes) + assert az.ManagedIdentityCredential.last_init_kwargs == {} + + def test_get_raw_token_system_assigned_msi(self): + raw_token = AADAuth.get_raw_token("msi") + assert raw_token == SAMPLE_TOKEN + + def test_get_token_user_assigned_msi(self): + """User-assigned MSI: client_id is forwarded to the credential constructor.""" + az = sys.modules["azure.identity"] + + az.ManagedIdentityCredential.last_init_kwargs = None + client_id = "11111111-2222-3333-4444-555555555555" + token_struct = AADAuth.get_token("msi", {"client_id": client_id}) + assert isinstance(token_struct, bytes) + assert az.ManagedIdentityCredential.last_init_kwargs == {"client_id": client_id} + + def test_msi_separate_cache_entries_per_client_id(self): + """System-assigned and user-assigned MSI must not share a cached credential.""" + AADAuth.get_token("msi") # system-assigned + AADAuth.get_token("msi", {"client_id": "abc"}) + AADAuth.get_token("msi", {"client_id": "def"}) + + # System-assigned uses the bare string key; user-assigned uses tuples. + assert "msi" in _credential_cache + assert ("msi", (("client_id", "abc"),)) in _credential_cache + assert ("msi", (("client_id", "def"),)) in _credential_cache + assert ( + _credential_cache["msi"] + is not _credential_cache[("msi", (("client_id", "abc"),))] + ) + + def test_extract_credential_kwargs_system_assigned(self): + """No UID in connection string → system-assigned MSI → empty kwargs.""" + conn_str = "Server=test;Authentication=ActiveDirectoryMSI;" + assert extract_credential_kwargs(conn_str, "msi") == {} + + def test_extract_credential_kwargs_user_assigned(self): + """UID present → user-assigned MSI → client_id kwarg.""" + conn_str = "Server=test;Authentication=ActiveDirectoryMSI;UID=11111111-2222-3333-4444-555555555555;" + assert extract_credential_kwargs(conn_str, "msi") == { + "client_id": "11111111-2222-3333-4444-555555555555" + } + + def test_extract_credential_kwargs_non_msi(self): + """For non-MSI auth types, kwargs are always empty (UID is ignored).""" + conn_str = "Server=test;Authentication=ActiveDirectoryDefault;UID=user;" + assert extract_credential_kwargs(conn_str, "default") == {} + + def test_extract_credential_kwargs_empty_uid(self): + """Empty UID value is treated as system-assigned MSI.""" + conn_str = "Server=test;Authentication=ActiveDirectoryMSI;UID=;" + assert extract_credential_kwargs(conn_str, "msi") == {} + + def test_process_connection_string_msi_strips_uid(self): + """MSI connection strings: UID is stripped from the ODBC connection + string but the client_id is still applied to the credential.""" + az = sys.modules["azure.identity"] + + az.ManagedIdentityCredential.last_init_kwargs = None + conn_str = ( + "Server=test;Authentication=ActiveDirectoryMSI;" + "UID=11111111-2222-3333-4444-555555555555;Database=testdb" + ) + result_str, attrs, auth_type = process_connection_string(conn_str) + + assert auth_type == "msi" + assert "UID=" not in result_str + assert "Authentication=" not in result_str + assert "Server=test" in result_str + assert "Database=testdb" in result_str + assert attrs is not None + assert az.ManagedIdentityCredential.last_init_kwargs == { + "client_id": "11111111-2222-3333-4444-555555555555" + } + + class TestCredentialInstanceCache: """Tests for the credential instance caching behavior.""" From 3ca21654b9cbef92e8a9591faef7f34c9bd2647b Mon Sep 17 00:00:00 2001 From: Gaurav Sharma Date: Thu, 14 May 2026 17:42:20 +0530 Subject: [PATCH 2/7] FIX: persist credential_kwargs on Connection for bulkcopy MSI Connection.__init__ overwrites self.connection_str with the sanitized (UID-stripped) string returned by process_connection_string. The original implementation re-parsed self.connection_str at bulkcopy time via extract_credential_kwargs, which silently dropped the user-assigned MSI client_id and degraded to system-assigned a wrong-identity bug.MSI Changes: - process_connection_string now returns a 4-tuple including the captured credential_kwargs so callers can persist them. - Connection.__init__ stores _credential_kwargs alongside _auth_type. - cursor.bulkcopy() reads self.connection._credential_kwargs instead of re-parsing self.connection_str. - The public extract_credential_kwargs helper is removed (it only existed to support the broken re-parse path; nothing else needs it). - black --line-length=100 reformats (CI was red). Tests: - test_bulkcopy_path_preserves_user_assigned_msi_client_id: invokes cursor.bulkcopy() with a mocked mssql_py_core, patches AADAuth.get_raw_token to capture the args it receives, and asserts the captured credential_kwargs match Connection._credential_kwargs. Fails if cursor reverts to re-parsing self.connection.connection_str. - test_credential_kwargs_persisted_for_user_assigned_msi: asserts Connection.__init__ stores _credential_kwargs from the 4-tuple. - test_credential_kwargs_none_for_system_assigned_msi. - test_credential_kwargs_none_for_non_msi_auth. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- CHANGELOG.md | 1 - mssql_python/auth.py | 41 ++++------ mssql_python/connection.py | 7 ++ mssql_python/cursor.py | 13 +-- tests/test_008_auth.py | 163 +++++++++++++++++++++++++++++-------- 5 files changed, 159 insertions(+), 66 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 14d42ee46..517a60bfc 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,7 +9,6 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ### Added - New feature: Support for macOS and Linux. - Documentation: Added API documentation in the Wiki. -- Bulk copy: `Authentication=ActiveDirectoryMSI` support (system- and user-assigned managed identity). UID is interpreted as the user-assigned identity's `client_id`. Partial fix for #534. ### Changed - Improved error handling in the connection module. diff --git a/mssql_python/auth.py b/mssql_python/auth.py index b1dc7ddd4..7aaf03add 100644 --- a/mssql_python/auth.py +++ b/mssql_python/auth.py @@ -51,17 +51,13 @@ def get_token_struct(token: str) -> bytes: return struct.pack(f" bytes: + def get_token(auth_type: str, credential_kwargs: Optional[Dict[str, str]] = None) -> bytes: """Get DDBC token struct for the specified authentication type.""" token_struct, _ = AADAuth._acquire_token(auth_type, credential_kwargs) return token_struct @staticmethod - def get_raw_token( - auth_type: str, credential_kwargs: Optional[Dict[str, str]] = None - ) -> str: + def get_raw_token(auth_type: str, credential_kwargs: Optional[Dict[str, str]] = None) -> str: """Acquire a raw JWT for the mssql-py-core connection (bulk copy). Uses the cached credential instance so the Azure Identity SDK's @@ -299,23 +295,9 @@ def extract_auth_type(connection_string: str) -> Optional[str]: return None -def extract_credential_kwargs( - connection_string: str, auth_type: Optional[str] -) -> Dict[str, str]: - """Extract credential constructor kwargs for the given auth type. - - For ActiveDirectoryMSI: returns ``{"client_id": uid}`` when UID is - set (user-assigned MSI) and ``{}`` for system-assigned MSI. - """ - if auth_type != "msi": - return {} - client_id = _extract_msi_client_id(connection_string.split(";")) - return {"client_id": client_id} if client_id else {} - - def process_connection_string( connection_string: str, -) -> Tuple[str, Optional[Dict[int, bytes]], Optional[str]]: +) -> Tuple[str, Optional[Dict[int, bytes]], Optional[str], Optional[Dict[str, str]]]: """ Process connection string and handle authentication. @@ -323,8 +305,13 @@ def process_connection_string( connection_string: The connection string to process Returns: - Tuple[str, Optional[Dict], Optional[str]]: Processed connection string, - attrs_before dict if needed, and auth_type string for bulk copy token acquisition + Tuple[str, Optional[Dict], Optional[str], Optional[Dict[str, str]]]: + Processed connection string, attrs_before dict if needed, auth_type + string for bulk copy token acquisition, and credential constructor + kwargs (e.g. user-assigned MSI ``client_id``) to be persisted on + the Connection so bulkcopy can re-use them when acquiring a fresh + token after sanitization has stripped UID from the connection + string. Raises: ValueError: If the connection string is invalid or empty @@ -383,6 +370,7 @@ def process_connection_string( ";".join(modified_parameters) + ";", {ConstantsDDBC.SQL_COPT_SS_ACCESS_TOKEN.value: token_struct}, auth_type, + credential_kwargs or None, ) else: logger.warning( @@ -393,4 +381,9 @@ def process_connection_string( "process_connection_string: Connection string processing complete - has_auth=%s", bool(auth_type), ) - return ";".join(modified_parameters) + ";", None, auth_type + return ( + ";".join(modified_parameters) + ";", + None, + auth_type, + credential_kwargs or None, + ) diff --git a/mssql_python/connection.py b/mssql_python/connection.py index 35c2fb85d..a04ea5edf 100644 --- a/mssql_python/connection.py +++ b/mssql_python/connection.py @@ -280,6 +280,12 @@ def __init__( # We intentionally do NOT cache the token — a fresh one is acquired # each time bulkcopy() is called to avoid expired-token errors. self._auth_type = None + # Credential constructor kwargs (e.g. user-assigned MSI client_id) + # captured at __init__ time before remove_sensitive_params strips UID + # from self.connection_str. bulkcopy() re-uses these when acquiring a + # fresh token; re-parsing self.connection_str at that point would miss + # them because UID is already gone. + self._credential_kwargs: Optional[Dict[str, str]] = None # Check if the connection string contains authentication parameters # This is important for processing the connection string correctly. @@ -294,6 +300,7 @@ def __init__( # On Windows Interactive, process_connection_string returns None # (DDBC handles auth natively), so fall back to the connection string. self._auth_type = connection_result[2] or extract_auth_type(self.connection_str) + self._credential_kwargs = connection_result[3] self._closed = False self._timeout = timeout diff --git a/mssql_python/cursor.py b/mssql_python/cursor.py index d53e2f28e..b81ac39b2 100644 --- a/mssql_python/cursor.py +++ b/mssql_python/cursor.py @@ -2911,15 +2911,16 @@ def bulkcopy( # Token acquisition — only thing cursor must handle (needs azure-identity SDK) if self.connection._auth_type: - # Fresh token acquisition for mssql-py-core connection - from mssql_python.auth import AADAuth, extract_credential_kwargs + # Fresh token acquisition for mssql-py-core connection. credential + # kwargs (e.g. user-assigned MSI client_id) were captured by + # Connection.__init__ before remove_sensitive_params stripped UID + # from connection_str — re-parsing here would miss them. + from mssql_python.auth import AADAuth - credential_kwargs = extract_credential_kwargs( - self.connection.connection_str, self.connection._auth_type - ) try: raw_token = AADAuth.get_raw_token( - self.connection._auth_type, credential_kwargs or None + self.connection._auth_type, + self.connection._credential_kwargs, ) except (RuntimeError, ValueError) as e: raise RuntimeError( diff --git a/tests/test_008_auth.py b/tests/test_008_auth.py index 8154e6b0c..eaed8c88e 100644 --- a/tests/test_008_auth.py +++ b/tests/test_008_auth.py @@ -16,7 +16,6 @@ get_auth_token, process_connection_string, extract_auth_type, - extract_credential_kwargs, _credential_cache, _credential_cache_lock, ) @@ -368,7 +367,7 @@ def test_remove_sensitive_parameters(self): class TestProcessConnectionString: def test_process_connection_string_with_default_auth(self): conn_str = "Server=test;Authentication=ActiveDirectoryDefault;Database=testdb" - result_str, attrs, auth_type = process_connection_string(conn_str) + result_str, attrs, auth_type, credential_kwargs = process_connection_string(conn_str) assert "Server=test" in result_str assert "Database=testdb" in result_str @@ -376,10 +375,11 @@ def test_process_connection_string_with_default_auth(self): assert ConstantsDDBC.SQL_COPT_SS_ACCESS_TOKEN.value in attrs assert isinstance(attrs[ConstantsDDBC.SQL_COPT_SS_ACCESS_TOKEN.value], bytes) assert auth_type == "default" + assert credential_kwargs is None def test_process_connection_string_no_auth(self): conn_str = "Server=test;Database=testdb;UID=user;PWD=password" - result_str, attrs, auth_type = process_connection_string(conn_str) + result_str, attrs, auth_type, credential_kwargs = process_connection_string(conn_str) assert "Server=test" in result_str assert "Database=testdb" in result_str @@ -387,11 +387,12 @@ def test_process_connection_string_no_auth(self): assert "PWD=password" in result_str assert attrs is None assert auth_type is None + assert credential_kwargs is None def test_process_connection_string_interactive_non_windows(self, monkeypatch): monkeypatch.setattr(platform, "system", lambda: "Darwin") conn_str = "Server=test;Authentication=ActiveDirectoryInteractive;Database=testdb" - result_str, attrs, auth_type = process_connection_string(conn_str) + result_str, attrs, auth_type, credential_kwargs = process_connection_string(conn_str) assert "Server=test" in result_str assert "Database=testdb" in result_str @@ -399,6 +400,7 @@ def test_process_connection_string_interactive_non_windows(self, monkeypatch): assert ConstantsDDBC.SQL_COPT_SS_ACCESS_TOKEN.value in attrs assert isinstance(attrs[ConstantsDDBC.SQL_COPT_SS_ACCESS_TOKEN.value], bytes) assert auth_type == "interactive" + assert credential_kwargs is None def test_error_handling(): @@ -477,36 +479,12 @@ def test_msi_separate_cache_entries_per_client_id(self): assert "msi" in _credential_cache assert ("msi", (("client_id", "abc"),)) in _credential_cache assert ("msi", (("client_id", "def"),)) in _credential_cache - assert ( - _credential_cache["msi"] - is not _credential_cache[("msi", (("client_id", "abc"),))] - ) - - def test_extract_credential_kwargs_system_assigned(self): - """No UID in connection string → system-assigned MSI → empty kwargs.""" - conn_str = "Server=test;Authentication=ActiveDirectoryMSI;" - assert extract_credential_kwargs(conn_str, "msi") == {} - - def test_extract_credential_kwargs_user_assigned(self): - """UID present → user-assigned MSI → client_id kwarg.""" - conn_str = "Server=test;Authentication=ActiveDirectoryMSI;UID=11111111-2222-3333-4444-555555555555;" - assert extract_credential_kwargs(conn_str, "msi") == { - "client_id": "11111111-2222-3333-4444-555555555555" - } - - def test_extract_credential_kwargs_non_msi(self): - """For non-MSI auth types, kwargs are always empty (UID is ignored).""" - conn_str = "Server=test;Authentication=ActiveDirectoryDefault;UID=user;" - assert extract_credential_kwargs(conn_str, "default") == {} - - def test_extract_credential_kwargs_empty_uid(self): - """Empty UID value is treated as system-assigned MSI.""" - conn_str = "Server=test;Authentication=ActiveDirectoryMSI;UID=;" - assert extract_credential_kwargs(conn_str, "msi") == {} + assert _credential_cache["msi"] is not _credential_cache[("msi", (("client_id", "abc"),))] - def test_process_connection_string_msi_strips_uid(self): + def test_process_connection_string_msi_strips_uid_and_returns_kwargs(self): """MSI connection strings: UID is stripped from the ODBC connection - string but the client_id is still applied to the credential.""" + string but the client_id is captured as credential_kwargs (so it can + be persisted on the Connection for the bulkcopy fresh-token path).""" az = sys.modules["azure.identity"] az.ManagedIdentityCredential.last_init_kwargs = None @@ -514,7 +492,7 @@ def test_process_connection_string_msi_strips_uid(self): "Server=test;Authentication=ActiveDirectoryMSI;" "UID=11111111-2222-3333-4444-555555555555;Database=testdb" ) - result_str, attrs, auth_type = process_connection_string(conn_str) + result_str, attrs, auth_type, credential_kwargs = process_connection_string(conn_str) assert auth_type == "msi" assert "UID=" not in result_str @@ -525,6 +503,75 @@ def test_process_connection_string_msi_strips_uid(self): assert az.ManagedIdentityCredential.last_init_kwargs == { "client_id": "11111111-2222-3333-4444-555555555555" } + # client_id must be returned so Connection can persist it for the + # bulkcopy fresh-token path (UID is gone from result_str by then). + assert credential_kwargs == {"client_id": "11111111-2222-3333-4444-555555555555"} + + def test_process_connection_string_msi_system_assigned_no_kwargs(self): + """System-assigned MSI: no UID → credential_kwargs is None.""" + conn_str = "Server=test;Authentication=ActiveDirectoryMSI;Database=testdb" + _, _, auth_type, credential_kwargs = process_connection_string(conn_str) + assert auth_type == "msi" + assert credential_kwargs is None + + def test_bulkcopy_path_preserves_user_assigned_msi_client_id(self): + """Regression test (cursor.bulkcopy() end-to-end) for the silent + system-assigned fallback: the bulkcopy fresh-token code path must + forward Connection._credential_kwargs to AADAuth.get_raw_token, + not re-parse the (now UID-stripped) connection_str. + + Fails if cursor.py is reverted to call extract_credential_kwargs on + self.connection.connection_str, OR if Connection stops persisting + _credential_kwargs.""" + from mssql_python.cursor import Cursor + + client_id = "11111111-2222-3333-4444-555555555555" + + # Mock Connection holding what Connection.__init__ would store after + # process_connection_string strips UID from the user-supplied string. + mock_conn = MagicMock() + # Post-sanitization string: NO UID. If cursor re-parses this, the + # forwarded kwargs will be {} and the assert below will fail. + mock_conn.connection_str = "Server=tcp:test.database.windows.net;Database=testdb;" + mock_conn._auth_type = "msi" + mock_conn._credential_kwargs = {"client_id": client_id} + mock_conn._is_connected = True + + cursor = Cursor.__new__(Cursor) + cursor._connection = mock_conn + cursor.closed = False + cursor.hstmt = None + + captured = {} + + def fake_get_raw_token(auth_type, credential_kwargs=None): + captured["auth_type"] = auth_type + captured["credential_kwargs"] = credential_kwargs + return SAMPLE_TOKEN + + mock_pycore_cursor = MagicMock() + mock_pycore_cursor.bulkcopy.return_value = { + "rows_copied": 1, + "batch_count": 1, + "elapsed_time": 0.1, + } + mock_pycore_conn = MagicMock() + mock_pycore_conn.cursor.return_value = mock_pycore_cursor + mock_pycore_module = MagicMock() + mock_pycore_module.PyCoreConnection = lambda ctx, **kwargs: mock_pycore_conn + + with ( + patch.dict("sys.modules", {"mssql_py_core": mock_pycore_module}), + patch("mssql_python.auth.AADAuth.get_raw_token", side_effect=fake_get_raw_token), + ): + cursor.bulkcopy("dbo.test_table", [(1, "row")], timeout=10) + + assert captured["auth_type"] == "msi" + assert captured["credential_kwargs"] == {"client_id": client_id}, ( + f"bulkcopy must forward Connection._credential_kwargs verbatim; " + f"got {captured['credential_kwargs']!r}. If this is {{}} or None, " + f"the cursor likely re-parses the (UID-stripped) connection_str." + ) class TestCredentialInstanceCache: @@ -737,6 +784,50 @@ def test_auth_type_stored_on_connection(self, mock_ddbc_conn): assert conn._auth_type == "default" conn.close() + @patch("mssql_python.connection.ddbc_bindings.Connection") + def test_credential_kwargs_persisted_for_user_assigned_msi(self, mock_ddbc_conn): + """Connection.__init__ must capture MSI client_id BEFORE + remove_sensitive_params strips UID, and persist it on + self._credential_kwargs so cursor.bulkcopy() can use it later.""" + mock_ddbc_conn.return_value = MagicMock() + from mssql_python import connect + + client_id = "11111111-2222-3333-4444-555555555555" + conn = connect( + f"Server=test;Database=testdb;Authentication=ActiveDirectoryMSI;UID={client_id}" + ) + assert conn._auth_type == "msi" + assert conn._credential_kwargs == {"client_id": client_id} + # And the connection_str on the Connection should NOT contain UID + # (this is what makes _credential_kwargs the source of truth). + assert "UID=" not in conn.connection_str + conn.close() + + @patch("mssql_python.connection.ddbc_bindings.Connection") + def test_credential_kwargs_none_for_system_assigned_msi(self, mock_ddbc_conn): + """System-assigned MSI: no UID → _credential_kwargs stays None.""" + mock_ddbc_conn.return_value = MagicMock() + from mssql_python import connect + + conn = connect("Server=test;Database=testdb;Authentication=ActiveDirectoryMSI") + assert conn._auth_type == "msi" + assert conn._credential_kwargs is None + conn.close() + + @patch("mssql_python.connection.ddbc_bindings.Connection") + def test_credential_kwargs_none_for_non_msi_auth(self, mock_ddbc_conn): + """Non-MSI auth types must not pick up credential_kwargs even if + UID is present (e.g. SQL auth UID).""" + mock_ddbc_conn.return_value = MagicMock() + from mssql_python import connect + + conn = connect( + "Server=test;Database=testdb;Authentication=ActiveDirectoryDefault;UID=user@x" + ) + assert conn._auth_type == "default" + assert conn._credential_kwargs is None + conn.close() + class TestCredentialCacheThreadSafety: """Verify thread-safe behavior of credential instance cache.""" @@ -873,7 +964,7 @@ class TestProcessConnectionStringTokenFailureFallthrough: def test_returns_none_attrs_when_token_acquisition_fails(self): """When auth type is detected but token acquisition fails, - process_connection_string should return (conn_str, None, auth_type).""" + process_connection_string should return (conn_str, None, auth_type, kwargs).""" import sys azure_identity = sys.modules["azure.identity"] @@ -886,7 +977,7 @@ def __init__(self): try: azure_identity.DefaultAzureCredential = CredentialThatAlwaysFails conn_str = "Server=test;Authentication=ActiveDirectoryDefault;Database=testdb" - result_str, attrs, auth_type = process_connection_string(conn_str) + result_str, attrs, auth_type, credential_kwargs = process_connection_string(conn_str) # Auth type was detected assert auth_type == "default" @@ -895,5 +986,7 @@ def __init__(self): # Connection string is still returned (sensitive params removed) assert "Server=test" in result_str assert "Database=testdb" in result_str + # Default auth has no credential kwargs + assert credential_kwargs is None finally: azure_identity.DefaultAzureCredential = original From 6451a4e9dc633775310c33c694e1322c680eb96d Mon Sep 17 00:00:00 2001 From: Gaurav Sharma Date: Thu, 14 May 2026 18:40:18 +0530 Subject: [PATCH 3/7] Use _ConnectionStringParser for MSI client_id extraction Per @saurabh500's review: braced ODBC values like UID={hello=world} need the canonical parser, not naive partition('='). Without this, the helper returns '{hello=world}' verbatim and ManagedIdentityCredential rejects it. Worse, a UID containing a literal ';' would be truncated. _extract_msi_client_id now delegates to _ConnectionStringParser, which handles braces, escaped '}}' inside braces, and '=' inside braced values correctly. validate_keywords=False so the helper never raises on keys the auth flow doesn't care about. Tests: - test_msi_braced_uid_value_is_unwrapped: UID={hello=world} -> 'hello=world' - test_msi_braced_uid_with_semicolon_is_preserved: UID={abc;def;ghi} Note: process_connection_string and extract_auth_type still use naive split(';') for Authentication= detection across all Entra ID auth types. That's pre-existing and tracked separately for a parser-wide refactor. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- mssql_python/auth.py | 33 ++++++++++++++++++++++----------- tests/test_008_auth.py | 22 ++++++++++++++++++++++ 2 files changed, 44 insertions(+), 11 deletions(-) diff --git a/mssql_python/auth.py b/mssql_python/auth.py index 7aaf03add..7d64c7ab6 100644 --- a/mssql_python/auth.py +++ b/mssql_python/auth.py @@ -11,6 +11,7 @@ from mssql_python.logging import logger from mssql_python.constants import AuthType, ConstantsDDBC +from mssql_python.connection_string_parser import _ConnectionStringParser # Module-level credential instance cache. # Reusing credential objects allows the Azure Identity SDK's built-in @@ -150,18 +151,25 @@ def _acquire_token( raise RuntimeError(f"Failed to create {credential_class.__name__}: {e}") from e -def _extract_msi_client_id(parameters: List[str]) -> Optional[str]: - """Pull UID out of connection parameters for user-assigned MSI. +def _extract_msi_client_id(connection_string: str) -> Optional[str]: + """Pull UID out of a connection string for user-assigned MSI. For ActiveDirectoryMSI, UID (when present) carries the user-assigned - identity's client_id. Returns None for system-assigned MSI. + identity's ``client_id``. Returns None for system-assigned MSI. + + Uses the canonical ``_ConnectionStringParser`` so braced ODBC values + are handled correctly: a ``UID={hello=world}`` resolves to the value + ``hello=world`` (no surrounding braces, no false split on the inner + ``=``), and a semicolon inside a legitimate braced value (e.g. + ``Database={foo;uid=victim;bar}``) cannot spoof a top-level ``UID=``. """ - for param in parameters: - key, _, value = param.strip().partition("=") - if key.strip().lower() == "uid": - value = value.strip() - return value or None - return None + try: + parsed = _ConnectionStringParser(validate_keywords=False)._parse(connection_string) + except Exception: # noqa: BLE001 — parser raises ConnectionStringParseError on malformed input; + # absence of UID is the safe answer for credential extraction. + return None + uid = (parsed.get("uid") or "").strip() + return uid or None def process_auth_parameters(parameters: List[str]) -> Tuple[List[str], Optional[str]]: @@ -348,10 +356,13 @@ def process_connection_string( modified_parameters, auth_type = process_auth_parameters(parameters) # Capture credential kwargs (e.g. user-assigned MSI client_id) before - # remove_sensitive_params strips UID from the parameter list. + # remove_sensitive_params strips UID from the parameter list. Pass the + # original connection_string (not modified_parameters) so the helper can + # use the canonical _ConnectionStringParser — handles braced values like + # UID={hello=world} correctly. credential_kwargs: Dict[str, str] = {} if auth_type == "msi": - client_id = _extract_msi_client_id(modified_parameters) + client_id = _extract_msi_client_id(connection_string) if client_id: credential_kwargs["client_id"] = client_id diff --git a/tests/test_008_auth.py b/tests/test_008_auth.py index eaed8c88e..316f7a81d 100644 --- a/tests/test_008_auth.py +++ b/tests/test_008_auth.py @@ -514,6 +514,28 @@ def test_process_connection_string_msi_system_assigned_no_kwargs(self): assert auth_type == "msi" assert credential_kwargs is None + def test_msi_braced_uid_value_is_unwrapped(self): + """A braced UID value (UID={hello=world}) must be unwrapped by the + canonical _ConnectionStringParser; the inner '=' must NOT split the + value. Without parser-aware extraction the helper would return + '{hello=world}' verbatim and ManagedIdentityCredential would reject + it. Regression test for saurabh500's review comment on auth.py.""" + conn_str = "Server=test;Authentication=ActiveDirectoryMSI;UID={hello=world};Database=testdb" + _, _, auth_type, credential_kwargs = process_connection_string(conn_str) + assert auth_type == "msi" + assert credential_kwargs == {"client_id": "hello=world"} + + def test_msi_braced_uid_with_semicolon_is_preserved(self): + """A braced UID value containing a semicolon (legal under ODBC) must + be returned intact, not truncated at the inner ';'.""" + weird_id = "abc;def;ghi" + conn_str = ( + f"Server=test;Authentication=ActiveDirectoryMSI;" f"UID={{{weird_id}}};Database=testdb" + ) + _, _, auth_type, credential_kwargs = process_connection_string(conn_str) + assert auth_type == "msi" + assert credential_kwargs == {"client_id": weird_id} + def test_bulkcopy_path_preserves_user_assigned_msi_client_id(self): """Regression test (cursor.bulkcopy() end-to-end) for the silent system-assigned fallback: the bulkcopy fresh-token code path must From ff1aa4175ada3e72e620dcca758a99436fa4ec8a Mon Sep 17 00:00:00 2001 From: Gaurav Sharma Date: Thu, 14 May 2026 18:49:49 +0530 Subject: [PATCH 4/7] Drop reviewer name-drop in test docstring Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- tests/test_008_auth.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_008_auth.py b/tests/test_008_auth.py index 316f7a81d..f8df6f6f5 100644 --- a/tests/test_008_auth.py +++ b/tests/test_008_auth.py @@ -519,7 +519,7 @@ def test_msi_braced_uid_value_is_unwrapped(self): canonical _ConnectionStringParser; the inner '=' must NOT split the value. Without parser-aware extraction the helper would return '{hello=world}' verbatim and ManagedIdentityCredential would reject - it. Regression test for saurabh500's review comment on auth.py.""" + it.""" conn_str = "Server=test;Authentication=ActiveDirectoryMSI;UID={hello=world};Database=testdb" _, _, auth_type, credential_kwargs = process_connection_string(conn_str) assert auth_type == "msi" From 1b7b434c28245b3653403dd49b6892953c37ce61 Mon Sep 17 00:00:00 2001 From: Gaurav Sharma Date: Thu, 14 May 2026 19:33:13 +0530 Subject: [PATCH 5/7] Address @saurabh500 review: log MSI branch + comment cache shape - Debug log distinguishing user-assigned vs system-assigned MSI when the user passes Authentication=ActiveDirectoryMSI. Helps diagnose which branch was taken when token acquisition fails. Logs client_id length, not value (still identity material). - Comment above _credential_cache explains the cache key shape so the unbounded growth is understood as a deliberate choice rather than an oversight. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- mssql_python/auth.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/mssql_python/auth.py b/mssql_python/auth.py index 7d64c7ab6..c4911fec5 100644 --- a/mssql_python/auth.py +++ b/mssql_python/auth.py @@ -17,6 +17,9 @@ # Reusing credential objects allows the Azure Identity SDK's built-in # in-memory token cache to work, avoiding redundant token acquisitions. # See: https://github.com/Azure/azure-sdk-for-python/blob/main/sdk/identity/azure-identity/TOKEN_CACHING.md +# +# Cache is keyed on (auth_type, sorted credential_kwargs), which is +# bounded by the distinct credentials a single process ever uses. _credential_cache: Dict[object, object] = {} _credential_cache_lock = threading.Lock() @@ -365,6 +368,16 @@ def process_connection_string( client_id = _extract_msi_client_id(connection_string) if client_id: credential_kwargs["client_id"] = client_id + logger.debug( + "process_connection_string: ActiveDirectoryMSI with UID — " + "user-assigned managed identity selected (client_id length=%d)", + len(client_id), + ) + else: + logger.debug( + "process_connection_string: ActiveDirectoryMSI without UID — " + "system-assigned managed identity selected" + ) if auth_type: logger.info( From df24965bda4012d6b6d77843bb0013dcd5cbe599 Mon Sep 17 00:00:00 2001 From: Gaurav Sharma Date: Thu, 14 May 2026 19:52:50 +0530 Subject: [PATCH 6/7] Drop defensive try/except in _extract_msi_client_id Connection.__init__ already parses the same connection string through _ConnectionStringParser via _construct_connection_string (connection.py line 253) before process_connection_string is ever called. By the time _extract_msi_client_id runs, the input is guaranteed parseable. The try/except was dead code. A real parse failure here would indicate an upstream bug and should propagate, not silently degrade user-assigned MSI to system-assigned (which is the wrong-identity failure mode this PR exists to prevent). Brings diff coverage to 100%. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- mssql_python/auth.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/mssql_python/auth.py b/mssql_python/auth.py index c4911fec5..a98d81cf6 100644 --- a/mssql_python/auth.py +++ b/mssql_python/auth.py @@ -166,11 +166,12 @@ def _extract_msi_client_id(connection_string: str) -> Optional[str]: ``=``), and a semicolon inside a legitimate braced value (e.g. ``Database={foo;uid=victim;bar}``) cannot spoof a top-level ``UID=``. """ - try: - parsed = _ConnectionStringParser(validate_keywords=False)._parse(connection_string) - except Exception: # noqa: BLE001 — parser raises ConnectionStringParseError on malformed input; - # absence of UID is the safe answer for credential extraction. - return None + # Connection.__init__ already parsed the same string through + # _ConnectionStringParser via _construct_connection_string, so by the + # time we get here the input is guaranteed parseable. No defensive + # try/except: a parse failure now means a real bug upstream and should + # propagate, not silently degrade user-assigned MSI to system-assigned. + parsed = _ConnectionStringParser(validate_keywords=False)._parse(connection_string) uid = (parsed.get("uid") or "").strip() return uid or None From cbb673d7b4d784ce06920848b35aa8a0b1ee4719 Mon Sep 17 00:00:00 2001 From: Gaurav Sharma Date: Fri, 15 May 2026 14:06:47 +0530 Subject: [PATCH 7/7] Document 4-tuple return arity on process_connection_string Tracking refactor (parse-once, thread the parsed map through the auth path) is a separate follow-up; this docstring helps anyone reaching into the function before that lands. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- mssql_python/auth.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/mssql_python/auth.py b/mssql_python/auth.py index a98d81cf6..9b488c6d4 100644 --- a/mssql_python/auth.py +++ b/mssql_python/auth.py @@ -313,6 +313,13 @@ def process_connection_string( """ Process connection string and handle authentication. + NOTE: Returns a 4-tuple. Callers must unpack all four elements. + Destructuring with three names raises ``ValueError: too many values + to unpack``. The fourth element (``credential_kwargs``) is needed by + Connection.__init__ to persist credential constructor args (e.g. the + user-assigned MSI ``client_id``) for the bulkcopy fresh-token path, + since UID is stripped from the sanitized connection string. + Args: connection_string: The connection string to process