diff --git a/mssql_python/auth.py b/mssql_python/auth.py index 40c3e06e..ceded2df 100644 --- a/mssql_python/auth.py +++ b/mssql_python/auth.py @@ -11,15 +11,33 @@ from mssql_python.logging import logger from mssql_python.constants import AuthType, ConstantsDDBC +from mssql_python.connection_string_parser import _ConnectionStringParser # Module-level credential instance cache. # Reusing credential objects allows the Azure Identity SDK's built-in # in-memory token cache to work, avoiding redundant token acquisitions. # See: https://github.com/Azure/azure-sdk-for-python/blob/main/sdk/identity/azure-identity/TOKEN_CACHING.md -_credential_cache: Dict[str, object] = {} +# +# Cache is keyed on (auth_type, sorted credential_kwargs), which is +# bounded by the distinct credentials a single process ever uses. +_credential_cache: Dict[object, object] = {} _credential_cache_lock = threading.Lock() +def _credential_cache_key(auth_type: str, credential_kwargs: Optional[Dict[str, str]]): + """Build a hashable cache key from auth_type and optional credential kwargs. + + Returns the plain auth_type string when no kwargs are provided so that + callers caching by string (the original behavior) keep working. When + kwargs are present (e.g. user-assigned MSI client_id), the key is a + tuple of ``(auth_type, sorted_kwargs_items)`` so different kwargs map + to different cached credentials. 
+ """ + if not credential_kwargs: + return auth_type + return (auth_type, tuple(sorted(credential_kwargs.items()))) + + class AADAuth: """Handles Azure Active Directory authentication""" @@ -37,24 +55,26 @@ def get_token_struct(token: str) -> bytes: return struct.pack(f" bytes: + def get_token(auth_type: str, credential_kwargs: Optional[Dict[str, str]] = None) -> bytes: """Get DDBC token struct for the specified authentication type.""" - token_struct, _ = AADAuth._acquire_token(auth_type) + token_struct, _ = AADAuth._acquire_token(auth_type, credential_kwargs) return token_struct @staticmethod - def get_raw_token(auth_type: str) -> str: + def get_raw_token(auth_type: str, credential_kwargs: Optional[Dict[str, str]] = None) -> str: """Acquire a raw JWT for the mssql-py-core connection (bulk copy). Uses the cached credential instance so the Azure Identity SDK's built-in token cache can serve a valid token without a round-trip when the previous token has not yet expired. """ - _, raw_token = AADAuth._acquire_token(auth_type) + _, raw_token = AADAuth._acquire_token(auth_type, credential_kwargs) return raw_token @staticmethod - def _acquire_token(auth_type: str) -> Tuple[bytes, str]: + def _acquire_token( + auth_type: str, credential_kwargs: Optional[Dict[str, str]] = None + ) -> Tuple[bytes, str]: """Internal: acquire token and return (ddbc_struct, raw_jwt).""" # Import Azure libraries inside method to support test mocking # pylint: disable=import-outside-toplevel @@ -63,6 +83,7 @@ def _acquire_token(auth_type: str) -> Tuple[bytes, str]: DefaultAzureCredential, DeviceCodeCredential, InteractiveBrowserCredential, + ManagedIdentityCredential, ) from azure.core.exceptions import ClientAuthenticationError except ImportError as e: @@ -76,6 +97,7 @@ def _acquire_token(auth_type: str) -> Tuple[bytes, str]: "default": DefaultAzureCredential, "devicecode": DeviceCodeCredential, "interactive": InteractiveBrowserCredential, + "msi": ManagedIdentityCredential, } credential_class = 
credential_map.get(auth_type) @@ -89,20 +111,22 @@ def _acquire_token(auth_type: str) -> Tuple[bytes, str]: credential_class.__name__, ) + kwargs = credential_kwargs or {} + cache_key = _credential_cache_key(auth_type, kwargs) try: with _credential_cache_lock: - if auth_type not in _credential_cache: + if cache_key not in _credential_cache: logger.debug( "get_token: Creating new credential instance for auth_type=%s", auth_type, ) - _credential_cache[auth_type] = credential_class() + _credential_cache[cache_key] = credential_class(**kwargs) else: logger.debug( "get_token: Reusing cached credential instance for auth_type=%s", auth_type, ) - credential = _credential_cache[auth_type] + credential = _credential_cache[cache_key] raw_token = credential.get_token("https://database.windows.net/.default").token logger.info( "get_token: Azure AD token acquired successfully - token_length=%d chars", @@ -130,6 +154,157 @@ def _acquire_token(auth_type: str) -> Tuple[bytes, str]: raise RuntimeError(f"Failed to create {credential_class.__name__}: {e}") from e +def _parse_tenant_id(sts_url: str) -> Optional[str]: + """Extract tenant ID (GUID or domain) from a FedAuthInfo STS URL. + + Expected formats: + https://login.microsoftonline.com/<tenant>/ + https://login.microsoftonline.com/<tenant>/?... + https://login.microsoftonline.com/<tenant> + where <tenant> is either a GUID (e.g. ``aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee``) + or a verified domain (e.g. ``contoso.onmicrosoft.com``). Both forms are + accepted by ``azure.identity.ClientSecretCredential``. + """ + # pylint: disable=import-outside-toplevel + from urllib.parse import urlparse + + try: + parsed = urlparse(sts_url) + except (ValueError, AttributeError): + return None + # Reject anything that isn't an https URL with a netloc. ``urlparse`` will + # happily put a bare string like ``"tenant-guid"`` into ``path``, which + # would then look like a valid tenant. Azure AD STS URLs are always https. 
+ if parsed.scheme != "https" or not parsed.netloc: + return None + path = (parsed.path or "").strip("/") + if not path: + return None + first_segment = path.split("/", 1)[0] + return first_segment or None + + +class ServicePrincipalAuth: + """Builds an ``entra_id_token_factory`` callable for ActiveDirectoryServicePrincipal. + + The bulkcopy path through mssql-py-core uses callback-based token + acquisition (FedAuth workflow ``0x02``) because tenant_id is only known + from the STS URL that the server returns during the TDS handshake. + """ + + @staticmethod + def make_token_factory(client_id: str, client_secret: str): + """Return a callable suitable for ``entra_id_token_factory``. + + Signature: ``(spn: str, sts_url: str, auth_method: str) -> bytes``. + Returns the JWT encoded as UTF-16LE bytes (the TDS FedAuth wire format). + + ``ClientSecretCredential`` instances are reused across calls via the + module-level ``_credential_cache``, keyed by + ``("serviceprincipal", tenant_id, client_id)`` so that azure-identity's + in-memory token cache (which is per-credential-instance) actually + works across handshake retries, reconnects, and separate bulkcopy + invocations using the same identity. + """ + if not client_id: + raise ValueError("ServicePrincipal auth requires a non-empty client_id (UID)") + if not client_secret: + raise ValueError("ServicePrincipal auth requires a non-empty client_secret (PWD)") + + def _factory(spn: str, sts_url: str, auth_method: str) -> bytes: + # pylint: disable=import-outside-toplevel,unused-argument + try: + from azure.identity import ClientSecretCredential + from azure.core.exceptions import ClientAuthenticationError + except ImportError as e: + raise RuntimeError( + "Azure authentication libraries are not installed. 
" + "Please install with: pip install azure-identity azure-core" + ) from e + + if not spn: + raise RuntimeError( + "ServicePrincipal token factory: empty SPN from server " + "(cannot construct token scope)" + ) + tenant_id = _parse_tenant_id(sts_url) + if not tenant_id: + raise RuntimeError(f"Could not extract tenant_id from STS URL: {sts_url!r}") + + logger.info( + "ServicePrincipal token factory: acquiring token for tenant=%s, spn=%s", + tenant_id, + spn, + ) + try: + # Reuse the shared credential cache (introduced for MSI in PR #573) + # so SP credentials get the same per-instance token reuse semantics + # as the other AD methods. Key includes tenant_id so a server that + # somehow returns different tenants on different handshakes still + # gets distinct credentials. client_secret is intentionally NOT in + # the key — credentials are looked up by identity, not by secret; + # if the secret rotates, the closure will still hold the old one + # and AAD will reject the token, surfacing as ClientAuthenticationError. + cache_key = _credential_cache_key( + "serviceprincipal", + {"tenant_id": tenant_id, "client_id": client_id}, + ) + with _credential_cache_lock: + credential = _credential_cache.get(cache_key) + if credential is None: + credential = ClientSecretCredential( + tenant_id=tenant_id, + client_id=client_id, + client_secret=client_secret, + ) + _credential_cache[cache_key] = credential + # mssql-tds passes the resource SPN; azure-identity wants a scope. + scope = spn if spn.endswith("/.default") else spn.rstrip("/") + "/.default" + token = credential.get_token(scope).token + logger.info( + "ServicePrincipal token factory: token acquired, length=%d chars", + len(token), + ) + return token.encode("utf-16-le") + except ClientAuthenticationError as e: + # Keep the detailed provider error in debug logs only. The + # surfaced message is intentionally generic so that any + # secret-bearing provider text never reaches the user-facing + # exception chain. 
+ logger.error( + "ServicePrincipal authentication failed: tenant=%s, error=%s", + tenant_id, + str(e), + ) + raise RuntimeError( + "ServicePrincipal authentication failed; " "see debug logs for provider details" + ) from None + + return _factory + + +def _extract_msi_client_id(connection_string: str) -> Optional[str]: + """Pull UID out of a connection string for user-assigned MSI. + + For ActiveDirectoryMSI, UID (when present) carries the user-assigned + identity's ``client_id``. Returns None for system-assigned MSI. + + Uses the canonical ``_ConnectionStringParser`` so braced ODBC values + are handled correctly: a ``UID={hello=world}`` resolves to the value + ``hello=world`` (no surrounding braces, no false split on the inner + ``=``), and a semicolon inside a legitimate braced value (e.g. + ``Database={foo;uid=victim;bar}``) cannot spoof a top-level ``UID=``. + """ + # Connection.__init__ already parsed the same string through + # _ConnectionStringParser via _construct_connection_string, so by the + # time we get here the input is guaranteed parseable. No defensive + # try/except: a parse failure now means a real bug upstream and should + # propagate, not silently degrade user-assigned MSI to system-assigned. + parsed = _ConnectionStringParser(validate_keywords=False)._parse(connection_string) + uid = (parsed.get("uid") or "").strip() + return uid or None + + def process_auth_parameters(parameters: List[str]) -> Tuple[List[str], Optional[str]]: """ Process connection parameters and extract authentication type. 
@@ -180,6 +355,21 @@ def process_auth_parameters(parameters: List[str]) -> Tuple[List[str], Optional[ # Default authentication (uses DefaultAzureCredential) logger.debug("process_auth_parameters: Default Azure authentication detected") auth_type = "default" + elif value_lower == AuthType.MSI.value: + # Managed identity authentication (system- or user-assigned) + logger.debug("process_auth_parameters: Managed identity authentication detected") + auth_type = "msi" + elif value_lower == AuthType.SERVICE_PRINCIPAL.value: + # ServicePrincipal authentication. ODBC (msodbcsql 17.3+) + # handles this natively for regular queries, so leave + # auth_type=None to let ODBC own the query path. + # Bulkcopy still needs the auth type — extract_auth_type() + # propagates it as "serviceprincipal" so the bulkcopy path + # can register an entra_id_token_factory callback (Model B, + # required because tenant_id is only known from the STS URL + # that the server returns during the FedAuth handshake). + logger.debug("process_auth_parameters: Service principal authentication detected") + auth_type = None modified_parameters.append(param) logger.debug( @@ -212,7 +402,9 @@ def remove_sensitive_params(parameters: List[str]) -> List[str]: return result -def get_auth_token(auth_type: str) -> Optional[bytes]: +def get_auth_token( + auth_type: str, credential_kwargs: Optional[Dict[str, str]] = None +) -> Optional[bytes]: """Get DDBC authentication token struct based on auth type.""" logger.debug("get_auth_token: Starting - auth_type=%s", auth_type) if not auth_type: @@ -225,7 +417,7 @@ def get_auth_token(auth_type: str) -> Optional[bytes]: return None # Let Windows handle AADInteractive natively try: - token = AADAuth.get_token(auth_type) + token = AADAuth.get_token(auth_type, credential_kwargs) logger.info("get_auth_token: Token acquired successfully - auth_type=%s", auth_type) return token except (ValueError, RuntimeError) as e: @@ -246,6 +438,8 @@ def extract_auth_type(connection_string: 
str) -> Optional[str]: AuthType.INTERACTIVE.value: "interactive", AuthType.DEVICE_CODE.value: "devicecode", AuthType.DEFAULT.value: "default", + AuthType.MSI.value: "msi", + AuthType.SERVICE_PRINCIPAL.value: "serviceprincipal", } for part in connection_string.split(";"): key, _, value = part.strip().partition("=") @@ -256,7 +450,7 @@ def extract_auth_type(connection_string: str) -> Optional[str]: def process_connection_string( connection_string: str, -) -> Tuple[str, Optional[Dict[int, bytes]], Optional[str]]: +) -> Tuple[str, Optional[Dict[int, bytes]], Optional[str], Optional[Dict[str, str]]]: """ Process connection string and handle authentication. @@ -264,8 +458,13 @@ def process_connection_string( connection_string: The connection string to process Returns: - Tuple[str, Optional[Dict], Optional[str]]: Processed connection string, - attrs_before dict if needed, and auth_type string for bulk copy token acquisition + Tuple[str, Optional[Dict], Optional[str], Optional[Dict[str, str]]]: + Processed connection string, attrs_before dict if needed, auth_type + string for bulk copy token acquisition, and credential constructor + kwargs (e.g. user-assigned MSI ``client_id``) to be persisted on + the Connection so bulkcopy can re-use them when acquiring a fresh + token after sanitization has stripped UID from the connection + string. Raises: ValueError: If the connection string is invalid or empty @@ -301,12 +500,33 @@ def process_connection_string( modified_parameters, auth_type = process_auth_parameters(parameters) + # Capture credential kwargs (e.g. user-assigned MSI client_id) before + # remove_sensitive_params strips UID from the parameter list. Pass the + # original connection_string (not modified_parameters) so the helper can + # use the canonical _ConnectionStringParser — handles braced values like + # UID={hello=world} correctly. 
+ credential_kwargs: Dict[str, str] = {} + if auth_type == "msi": + client_id = _extract_msi_client_id(connection_string) + if client_id: + credential_kwargs["client_id"] = client_id + logger.debug( + "process_connection_string: ActiveDirectoryMSI with UID — " + "user-assigned managed identity selected (client_id length=%d)", + len(client_id), + ) + else: + logger.debug( + "process_connection_string: ActiveDirectoryMSI without UID — " + "system-assigned managed identity selected" + ) + if auth_type: logger.info( "process_connection_string: Authentication type detected - auth_type=%s", auth_type ) modified_parameters = remove_sensitive_params(modified_parameters) - token_struct = get_auth_token(auth_type) + token_struct = get_auth_token(auth_type, credential_kwargs or None) if token_struct: logger.info( "process_connection_string: Token authentication configured successfully - auth_type=%s", @@ -316,6 +536,7 @@ def process_connection_string( ";".join(modified_parameters) + ";", {ConstantsDDBC.SQL_COPT_SS_ACCESS_TOKEN.value: token_struct}, auth_type, + credential_kwargs or None, ) else: logger.warning( @@ -326,4 +547,9 @@ def process_connection_string( "process_connection_string: Connection string processing complete - has_auth=%s", bool(auth_type), ) - return ";".join(modified_parameters) + ";", None, auth_type + return ( + ";".join(modified_parameters) + ";", + None, + auth_type, + credential_kwargs or None, + ) diff --git a/mssql_python/connection.py b/mssql_python/connection.py index 0064917a..e5876380 100644 --- a/mssql_python/connection.py +++ b/mssql_python/connection.py @@ -284,6 +284,12 @@ def __init__( # We intentionally do NOT cache the token — a fresh one is acquired # each time bulkcopy() is called to avoid expired-token errors. self._auth_type = None + # Credential constructor kwargs (e.g. user-assigned MSI client_id) + # captured at __init__ time before remove_sensitive_params strips UID + # from self.connection_str. 
bulkcopy() re-uses these when acquiring a + # fresh token; re-parsing self.connection_str at that point would miss + # them because UID is already gone. + self._credential_kwargs: Optional[Dict[str, str]] = None # Check if the connection string contains authentication parameters # This is important for processing the connection string correctly. @@ -298,6 +304,7 @@ def __init__( # On Windows Interactive, process_connection_string returns None # (DDBC handles auth natively), so fall back to the connection string. self._auth_type = connection_result[2] or extract_auth_type(self.connection_str) + self._credential_kwargs = connection_result[3] self._closed = False self._timeout = timeout diff --git a/mssql_python/constants.py b/mssql_python/constants.py index 549737c6..f9f9331d 100644 --- a/mssql_python/constants.py +++ b/mssql_python/constants.py @@ -337,6 +337,8 @@ class AuthType(Enum): INTERACTIVE = "activedirectoryinteractive" DEVICE_CODE = "activedirectorydevicecode" DEFAULT = "activedirectorydefault" + MSI = "activedirectorymsi" + SERVICE_PRINCIPAL = "activedirectoryserviceprincipal" class SQLTypes: diff --git a/mssql_python/cursor.py b/mssql_python/cursor.py index f0b1d6a6..9915eea2 100644 --- a/mssql_python/cursor.py +++ b/mssql_python/cursor.py @@ -2934,24 +2934,59 @@ def bulkcopy( # Token acquisition — only thing cursor must handle (needs azure-identity SDK) if self.connection._auth_type: # Fresh token acquisition for mssql-py-core connection - from mssql_python.auth import AADAuth - - try: - raw_token = AADAuth.get_raw_token(self.connection._auth_type) - except (RuntimeError, ValueError) as e: - raise RuntimeError( - f"Bulk copy failed: unable to acquire Azure AD token " - f"for auth_type '{self.connection._auth_type}': {e}" - ) from e - pycore_context["access_token"] = raw_token - # Token replaces credential fields — py-core's validator rejects - # access_token combined with authentication/user_name/password. 
- for key in ("authentication", "user_name", "password"): - pycore_context.pop(key, None) - logger.debug( - "Bulk copy: acquired fresh Azure AD token for auth_type=%s", - self.connection._auth_type, - ) + from mssql_python.auth import AADAuth, ServicePrincipalAuth + + if self.connection._auth_type == "serviceprincipal": + # Model B: callback-based. tenant_id is only known from the + # STS URL the server returns mid-handshake, so we register a + # factory that py-core invokes during FedAuth (workflow 0x02). + client_id = params.get("uid", "") + client_secret = params.get("pwd", "") + if not client_id or not client_secret: + raise RuntimeError( + "Bulk copy with Authentication=ActiveDirectoryServicePrincipal " + "requires UID (client_id) and PWD (client_secret) in the " + "connection string." + ) + try: + factory = ServicePrincipalAuth.make_token_factory(client_id, client_secret) + except (RuntimeError, ValueError) as e: + raise RuntimeError( + f"Bulk copy failed: unable to build ServicePrincipal token factory: {e}" + ) from e + pycore_context["entra_id_token_factory"] = factory + # Keep authentication/user_name/password in pycore_context — + # py-core's auth validator + transformer need them to resolve + # the auth method to ActiveDirectoryServicePrincipal before + # the factory is dispatched at handshake time. + logger.debug("Bulk copy: registered ServicePrincipal token factory") + else: + # Model A: pre-acquired token. Used for Default, DeviceCode, + # Interactive (non-Windows), MSI (system- or user-assigned), + # and any other AD method whose tenant_id is discoverable + # client-side via Azure Identity SDK. credential kwargs + # (e.g. user-assigned MSI client_id) were captured by + # Connection.__init__ before remove_sensitive_params stripped + # UID from connection_str — re-parsing here would miss them. 
+ try: + raw_token = AADAuth.get_raw_token( + self.connection._auth_type, + self.connection._credential_kwargs, + ) + except (RuntimeError, ValueError) as e: + raise RuntimeError( + f"Bulk copy failed: unable to acquire Azure AD token " + f"for auth_type '{self.connection._auth_type}': {e}" + ) from e + pycore_context["access_token"] = raw_token + # Token replaces credential fields — py-core's validator rejects + # access_token combined with authentication/user_name/password. + for key in ("authentication", "user_name", "password"): + pycore_context.pop(key, None) + logger.debug( + "Bulk copy: acquired fresh Azure AD token for auth_type=%s", + self.connection._auth_type, + ) pycore_connection = None pycore_cursor = None @@ -3001,9 +3036,17 @@ def bulkcopy( raise type(e)(str(e)) from None finally: - # Clear sensitive data to minimize memory exposure + # Clear sensitive data to minimize memory exposure. The + # entra_id_token_factory closure captures client_secret, so drop + # our dict reference to it (Rust still holds an Arc until the + # connection is dropped, but at least we don't keep an extra ref). 
if pycore_context: - for key in ("password", "user_name", "access_token"): + for key in ( + "password", + "user_name", + "access_token", + "entra_id_token_factory", + ): pycore_context.pop(key, None) # Clean up bulk copy resources for resource in (pycore_cursor, pycore_connection): diff --git a/tests/test_008_auth.py b/tests/test_008_auth.py index f680518b..54f6236d 100644 --- a/tests/test_008_auth.py +++ b/tests/test_008_auth.py @@ -11,6 +11,8 @@ from unittest.mock import patch, MagicMock from mssql_python.auth import ( AADAuth, + ServicePrincipalAuth, + _parse_tenant_id, process_auth_parameters, remove_sensitive_params, get_auth_token, @@ -44,6 +46,31 @@ class MockInteractiveBrowserCredential: def get_token(self, scope): return MockToken() + class MockClientSecretCredential: + # Captures construction kwargs and get_token args so ServicePrincipal + # tests can assert the right tenant/client_id/secret/scope flowed + # through from the connection string + STS URL. + last_init_kwargs = None + last_scope = None + + def __init__(self, **kwargs): + MockClientSecretCredential.last_init_kwargs = kwargs + + def get_token(self, scope): + MockClientSecretCredential.last_scope = scope + return MockToken() + + class MockManagedIdentityCredential: + # Captures construction kwargs so user-assigned MSI tests can assert + # client_id was forwarded correctly. 
+ last_init_kwargs = None + + def __init__(self, **kwargs): + MockManagedIdentityCredential.last_init_kwargs = kwargs + + def get_token(self, scope): + return MockToken() + # Mock ClientAuthenticationError class MockClientAuthenticationError(Exception): pass @@ -52,6 +79,8 @@ class MockIdentity: DefaultAzureCredential = MockDefaultAzureCredential DeviceCodeCredential = MockDeviceCodeCredential InteractiveBrowserCredential = MockInteractiveBrowserCredential + ClientSecretCredential = MockClientSecretCredential + ManagedIdentityCredential = MockManagedIdentityCredential class MockCore: class exceptions: @@ -87,6 +116,8 @@ def test_auth_type_constants(self): assert AuthType.INTERACTIVE.value == "activedirectoryinteractive" assert AuthType.DEVICE_CODE.value == "activedirectorydevicecode" assert AuthType.DEFAULT.value == "activedirectorydefault" + assert AuthType.MSI.value == "activedirectorymsi" + assert AuthType.SERVICE_PRINCIPAL.value == "activedirectoryserviceprincipal" class TestAADAuth: @@ -317,6 +348,30 @@ def test_default_auth(self): _, auth_type = process_auth_parameters(params) assert auth_type == "default" + def test_service_principal_auth_leaves_odbc_path_alone(self): + """ServicePrincipal is handled natively by ODBC. 
process_auth_parameters + must return auth_type=None so the ODBC path doesn't pre-acquire a token + (which would require tenant_id we don't have client-side).""" + params = ["Authentication=ActiveDirectoryServicePrincipal", "Server=test"] + modified_params, auth_type = process_auth_parameters(params) + assert "Authentication=ActiveDirectoryServicePrincipal" in modified_params + assert auth_type is None + + def test_service_principal_auth_case_insensitive(self): + params = ["authentication=activedirectoryserviceprincipal", "Server=test"] + _, auth_type = process_auth_parameters(params) + assert auth_type is None + + def test_msi_auth(self): + params = ["Authentication=ActiveDirectoryMSI", "Server=test"] + _, auth_type = process_auth_parameters(params) + assert auth_type == "msi" + + def test_msi_auth_case_insensitive(self): + params = ["authentication=activedirectorymsi", "Server=test"] + _, auth_type = process_auth_parameters(params) + assert auth_type == "msi" + class TestRemoveSensitiveParams: def test_remove_sensitive_parameters(self): @@ -344,7 +399,7 @@ def test_remove_sensitive_parameters(self): class TestProcessConnectionString: def test_process_connection_string_with_default_auth(self): conn_str = "Server=test;Authentication=ActiveDirectoryDefault;Database=testdb" - result_str, attrs, auth_type = process_connection_string(conn_str) + result_str, attrs, auth_type, credential_kwargs = process_connection_string(conn_str) assert "Server=test" in result_str assert "Database=testdb" in result_str @@ -352,10 +407,11 @@ def test_process_connection_string_with_default_auth(self): assert ConstantsDDBC.SQL_COPT_SS_ACCESS_TOKEN.value in attrs assert isinstance(attrs[ConstantsDDBC.SQL_COPT_SS_ACCESS_TOKEN.value], bytes) assert auth_type == "default" + assert credential_kwargs is None def test_process_connection_string_no_auth(self): conn_str = "Server=test;Database=testdb;UID=user;PWD=password" - result_str, attrs, auth_type = process_connection_string(conn_str) + 
result_str, attrs, auth_type, credential_kwargs = process_connection_string(conn_str) assert "Server=test" in result_str assert "Database=testdb" in result_str @@ -363,11 +419,12 @@ def test_process_connection_string_no_auth(self): assert "PWD=password" in result_str assert attrs is None assert auth_type is None + assert credential_kwargs is None def test_process_connection_string_interactive_non_windows(self, monkeypatch): monkeypatch.setattr(platform, "system", lambda: "Darwin") conn_str = "Server=test;Authentication=ActiveDirectoryInteractive;Database=testdb" - result_str, attrs, auth_type = process_connection_string(conn_str) + result_str, attrs, auth_type, credential_kwargs = process_connection_string(conn_str) assert "Server=test" in result_str assert "Database=testdb" in result_str @@ -375,6 +432,7 @@ def test_process_connection_string_interactive_non_windows(self, monkeypatch): assert ConstantsDDBC.SQL_COPT_SS_ACCESS_TOKEN.value in attrs assert isinstance(attrs[ConstantsDDBC.SQL_COPT_SS_ACCESS_TOKEN.value], bytes) assert auth_type == "interactive" + assert credential_kwargs is None def test_error_handling(): @@ -407,6 +465,15 @@ def test_devicecode(self): == "devicecode" ) + def test_serviceprincipal(self): + assert ( + extract_auth_type("Server=test;Authentication=ActiveDirectoryServicePrincipal;") + == "serviceprincipal" + ) + + def test_msi(self): + assert extract_auth_type("Server=test;Authentication=ActiveDirectoryMSI;") == "msi" + def test_no_auth(self): assert extract_auth_type("Server=test;Database=db;") is None @@ -414,6 +481,159 @@ def test_unsupported_auth(self): assert extract_auth_type("Server=test;Authentication=SqlPassword;") is None +class TestManagedIdentity: + """Tests for ActiveDirectoryMSI support (system- and user-assigned).""" + + def test_get_token_system_assigned_msi(self): + """System-assigned MSI: ManagedIdentityCredential() constructed with no kwargs.""" + az = sys.modules["azure.identity"] + + 
az.ManagedIdentityCredential.last_init_kwargs = None + token_struct = AADAuth.get_token("msi") + assert isinstance(token_struct, bytes) + assert az.ManagedIdentityCredential.last_init_kwargs == {} + + def test_get_raw_token_system_assigned_msi(self): + raw_token = AADAuth.get_raw_token("msi") + assert raw_token == SAMPLE_TOKEN + + def test_get_token_user_assigned_msi(self): + """User-assigned MSI: client_id is forwarded to the credential constructor.""" + az = sys.modules["azure.identity"] + + az.ManagedIdentityCredential.last_init_kwargs = None + client_id = "11111111-2222-3333-4444-555555555555" + token_struct = AADAuth.get_token("msi", {"client_id": client_id}) + assert isinstance(token_struct, bytes) + assert az.ManagedIdentityCredential.last_init_kwargs == {"client_id": client_id} + + def test_msi_separate_cache_entries_per_client_id(self): + """System-assigned and user-assigned MSI must not share a cached credential.""" + AADAuth.get_token("msi") # system-assigned + AADAuth.get_token("msi", {"client_id": "abc"}) + AADAuth.get_token("msi", {"client_id": "def"}) + + # System-assigned uses the bare string key; user-assigned uses tuples. 
+ assert "msi" in _credential_cache + assert ("msi", (("client_id", "abc"),)) in _credential_cache + assert ("msi", (("client_id", "def"),)) in _credential_cache + assert _credential_cache["msi"] is not _credential_cache[("msi", (("client_id", "abc"),))] + + def test_process_connection_string_msi_strips_uid_and_returns_kwargs(self): + """MSI connection strings: UID is stripped from the ODBC connection + string but the client_id is captured as credential_kwargs (so it can + be persisted on the Connection for the bulkcopy fresh-token path).""" + az = sys.modules["azure.identity"] + + az.ManagedIdentityCredential.last_init_kwargs = None + conn_str = ( + "Server=test;Authentication=ActiveDirectoryMSI;" + "UID=11111111-2222-3333-4444-555555555555;Database=testdb" + ) + result_str, attrs, auth_type, credential_kwargs = process_connection_string(conn_str) + + assert auth_type == "msi" + assert "UID=" not in result_str + assert "Authentication=" not in result_str + assert "Server=test" in result_str + assert "Database=testdb" in result_str + assert attrs is not None + assert az.ManagedIdentityCredential.last_init_kwargs == { + "client_id": "11111111-2222-3333-4444-555555555555" + } + # client_id must be returned so Connection can persist it for the + # bulkcopy fresh-token path (UID is gone from result_str by then). + assert credential_kwargs == {"client_id": "11111111-2222-3333-4444-555555555555"} + + def test_process_connection_string_msi_system_assigned_no_kwargs(self): + """System-assigned MSI: no UID → credential_kwargs is None.""" + conn_str = "Server=test;Authentication=ActiveDirectoryMSI;Database=testdb" + _, _, auth_type, credential_kwargs = process_connection_string(conn_str) + assert auth_type == "msi" + assert credential_kwargs is None + + def test_msi_braced_uid_value_is_unwrapped(self): + """A braced UID value (UID={hello=world}) must be unwrapped by the + canonical _ConnectionStringParser; the inner '=' must NOT split the + value. 
Without parser-aware extraction the helper would return + '{hello=world}' verbatim and ManagedIdentityCredential would reject + it.""" + conn_str = "Server=test;Authentication=ActiveDirectoryMSI;UID={hello=world};Database=testdb" + _, _, auth_type, credential_kwargs = process_connection_string(conn_str) + assert auth_type == "msi" + assert credential_kwargs == {"client_id": "hello=world"} + + def test_msi_braced_uid_with_semicolon_is_preserved(self): + """A braced UID value containing a semicolon (legal under ODBC) must + be returned intact, not truncated at the inner ';'.""" + weird_id = "abc;def;ghi" + conn_str = ( + f"Server=test;Authentication=ActiveDirectoryMSI;" f"UID={{{weird_id}}};Database=testdb" + ) + _, _, auth_type, credential_kwargs = process_connection_string(conn_str) + assert auth_type == "msi" + assert credential_kwargs == {"client_id": weird_id} + + def test_bulkcopy_path_preserves_user_assigned_msi_client_id(self): + """Regression test (cursor.bulkcopy() end-to-end) for the silent + system-assigned fallback: the bulkcopy fresh-token code path must + forward Connection._credential_kwargs to AADAuth.get_raw_token, + not re-parse the (now UID-stripped) connection_str. + + Fails if cursor.py is reverted to call extract_credential_kwargs on + self.connection.connection_str, OR if Connection stops persisting + _credential_kwargs.""" + from mssql_python.cursor import Cursor + + client_id = "11111111-2222-3333-4444-555555555555" + + # Mock Connection holding what Connection.__init__ would store after + # process_connection_string strips UID from the user-supplied string. + mock_conn = MagicMock() + # Post-sanitization string: NO UID. If cursor re-parses this, the + # forwarded kwargs will be {} and the assert below will fail. 
+ mock_conn.connection_str = "Server=tcp:test.database.windows.net;Database=testdb;" + mock_conn._auth_type = "msi" + mock_conn._credential_kwargs = {"client_id": client_id} + mock_conn._is_connected = True + + cursor = Cursor.__new__(Cursor) + cursor._connection = mock_conn + cursor.closed = False + cursor.hstmt = None + + captured = {} + + def fake_get_raw_token(auth_type, credential_kwargs=None): + captured["auth_type"] = auth_type + captured["credential_kwargs"] = credential_kwargs + return SAMPLE_TOKEN + + mock_pycore_cursor = MagicMock() + mock_pycore_cursor.bulkcopy.return_value = { + "rows_copied": 1, + "batch_count": 1, + "elapsed_time": 0.1, + } + mock_pycore_conn = MagicMock() + mock_pycore_conn.cursor.return_value = mock_pycore_cursor + mock_pycore_module = MagicMock() + mock_pycore_module.PyCoreConnection = lambda ctx, **kwargs: mock_pycore_conn + + with ( + patch.dict("sys.modules", {"mssql_py_core": mock_pycore_module}), + patch("mssql_python.auth.AADAuth.get_raw_token", side_effect=fake_get_raw_token), + ): + cursor.bulkcopy("dbo.test_table", [(1, "row")], timeout=10) + + assert captured["auth_type"] == "msi" + assert captured["credential_kwargs"] == {"client_id": client_id}, ( + f"bulkcopy must forward Connection._credential_kwargs verbatim; " + f"got {captured['credential_kwargs']!r}. If this is {{}} or None, " + f"the cursor likely re-parses the (UID-stripped) connection_str." 
+ ) + + class TestCredentialInstanceCache: """Tests for the credential instance caching behavior.""" @@ -624,6 +844,50 @@ def test_auth_type_stored_on_connection(self, mock_ddbc_conn): assert conn._auth_type == "default" conn.close() + @patch("mssql_python.connection.ddbc_bindings.Connection") + def test_credential_kwargs_persisted_for_user_assigned_msi(self, mock_ddbc_conn): + """Connection.__init__ must capture MSI client_id BEFORE + remove_sensitive_params strips UID, and persist it on + self._credential_kwargs so cursor.bulkcopy() can use it later.""" + mock_ddbc_conn.return_value = MagicMock() + from mssql_python import connect + + client_id = "11111111-2222-3333-4444-555555555555" + conn = connect( + f"Server=test;Database=testdb;Authentication=ActiveDirectoryMSI;UID={client_id}" + ) + assert conn._auth_type == "msi" + assert conn._credential_kwargs == {"client_id": client_id} + # And the connection_str on the Connection should NOT contain UID + # (this is what makes _credential_kwargs the source of truth). + assert "UID=" not in conn.connection_str + conn.close() + + @patch("mssql_python.connection.ddbc_bindings.Connection") + def test_credential_kwargs_none_for_system_assigned_msi(self, mock_ddbc_conn): + """System-assigned MSI: no UID → _credential_kwargs stays None.""" + mock_ddbc_conn.return_value = MagicMock() + from mssql_python import connect + + conn = connect("Server=test;Database=testdb;Authentication=ActiveDirectoryMSI") + assert conn._auth_type == "msi" + assert conn._credential_kwargs is None + conn.close() + + @patch("mssql_python.connection.ddbc_bindings.Connection") + def test_credential_kwargs_none_for_non_msi_auth(self, mock_ddbc_conn): + """Non-MSI auth types must not pick up credential_kwargs even if + UID is present (e.g. 
SQL auth UID).""" + mock_ddbc_conn.return_value = MagicMock() + from mssql_python import connect + + conn = connect( + "Server=test;Database=testdb;Authentication=ActiveDirectoryDefault;UID=user@x" + ) + assert conn._auth_type == "default" + assert conn._credential_kwargs is None + conn.close() + class TestCredentialCacheThreadSafety: """Verify thread-safe behavior of credential instance cache.""" @@ -760,7 +1024,7 @@ class TestProcessConnectionStringTokenFailureFallthrough: def test_returns_none_attrs_when_token_acquisition_fails(self): """When auth type is detected but token acquisition fails, - process_connection_string should return (conn_str, None, auth_type).""" + process_connection_string should return (conn_str, None, auth_type, kwargs).""" import sys azure_identity = sys.modules["azure.identity"] @@ -773,7 +1037,7 @@ def __init__(self): try: azure_identity.DefaultAzureCredential = CredentialThatAlwaysFails conn_str = "Server=test;Authentication=ActiveDirectoryDefault;Database=testdb" - result_str, attrs, auth_type = process_connection_string(conn_str) + result_str, attrs, auth_type, credential_kwargs = process_connection_string(conn_str) # Auth type was detected assert auth_type == "default" @@ -782,5 +1046,237 @@ def __init__(self): # Connection string is still returned (sensitive params removed) assert "Server=test" in result_str assert "Database=testdb" in result_str + # Default auth has no credential kwargs + assert credential_kwargs is None finally: azure_identity.DefaultAzureCredential = original + + +class TestParseTenantId: + def test_guid_tenant(self): + url = "https://login.microsoftonline.com/aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee/" + assert _parse_tenant_id(url) == "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee" + + def test_guid_tenant_no_trailing_slash(self): + url = "https://login.microsoftonline.com/aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee" + assert _parse_tenant_id(url) == "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee" + + def test_domain_tenant(self): + url = 
"https://login.microsoftonline.com/contoso.onmicrosoft.com/" + assert _parse_tenant_id(url) == "contoso.onmicrosoft.com" + + def test_tenant_with_query_string(self): + url = "https://login.microsoftonline.com/tenant-guid/?foo=bar" + assert _parse_tenant_id(url) == "tenant-guid" + + def test_extra_path_segments_after_tenant(self): + url = "https://login.microsoftonline.com/tenant-guid/oauth2/authorize" + assert _parse_tenant_id(url) == "tenant-guid" + + def test_empty_string(self): + assert _parse_tenant_id("") is None + + def test_no_path(self): + assert _parse_tenant_id("https://login.microsoftonline.com/") is None + + def test_rejects_bare_string_without_scheme(self): + # urlparse puts a bare string into path; without a scheme/netloc check + # this would be silently treated as a tenant id. + assert _parse_tenant_id("tenant-guid") is None + + def test_rejects_path_only_url(self): + assert _parse_tenant_id("/tenant-guid/oauth2") is None + + def test_rejects_http_scheme(self): + # Azure AD STS URLs are always https. Reject http to avoid trusting + # a downgraded URL. 
+ assert _parse_tenant_id("http://login.microsoftonline.com/tenant/") is None + + +class TestServicePrincipalAuth: + """Tests for the ActiveDirectoryServicePrincipal token factory.""" + + def test_make_token_factory_returns_callable(self): + factory = ServicePrincipalAuth.make_token_factory("client-id", "client-secret") + assert callable(factory) + + def test_factory_requires_client_id(self): + with pytest.raises(ValueError, match="client_id"): + ServicePrincipalAuth.make_token_factory("", "client-secret") + + def test_factory_requires_client_secret(self): + with pytest.raises(ValueError, match="client_secret"): + ServicePrincipalAuth.make_token_factory("client-id", "") + + def test_factory_returns_utf16le_bytes(self): + factory = ServicePrincipalAuth.make_token_factory("client-id", "client-secret") + result = factory( + "https://database.windows.net/", + "https://login.microsoftonline.com/tenant-guid/", + "activedirectoryserviceprincipal", + ) + assert isinstance(result, bytes) + # SAMPLE_TOKEN is hex chars (ASCII). UTF-16LE encoding doubles each byte + # and inserts a 0x00 high byte after each ASCII char. 
+ assert result == SAMPLE_TOKEN.encode("utf-16-le") + assert len(result) == len(SAMPLE_TOKEN) * 2 + + def test_factory_forwards_credentials_to_ClientSecretCredential(self): + az = sys.modules["azure.identity"] + az.ClientSecretCredential.last_init_kwargs = None + az.ClientSecretCredential.last_scope = None + + factory = ServicePrincipalAuth.make_token_factory( + "11111111-2222-3333-4444-555555555555", "my-secret" + ) + factory( + "https://database.windows.net/", + "https://login.microsoftonline.com/aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee/", + "activedirectoryserviceprincipal", + ) + + assert az.ClientSecretCredential.last_init_kwargs == { + "tenant_id": "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee", + "client_id": "11111111-2222-3333-4444-555555555555", + "client_secret": "my-secret", + } + + def test_factory_builds_scope_from_spn(self): + az = sys.modules["azure.identity"] + az.ClientSecretCredential.last_scope = None + + factory = ServicePrincipalAuth.make_token_factory("cid", "secret") + factory( + "https://database.windows.net/", + "https://login.microsoftonline.com/tenant/", + "activedirectoryserviceprincipal", + ) + assert az.ClientSecretCredential.last_scope == "https://database.windows.net/.default" + + def test_factory_keeps_existing_default_suffix(self): + az = sys.modules["azure.identity"] + az.ClientSecretCredential.last_scope = None + + factory = ServicePrincipalAuth.make_token_factory("cid", "secret") + factory( + "https://database.windows.net/.default", + "https://login.microsoftonline.com/tenant/", + "activedirectoryserviceprincipal", + ) + assert az.ClientSecretCredential.last_scope == "https://database.windows.net/.default" + + def test_factory_errors_on_unparseable_sts_url(self): + factory = ServicePrincipalAuth.make_token_factory("cid", "secret") + with pytest.raises(RuntimeError, match="Could not extract tenant_id"): + factory( + "https://database.windows.net/", + "https://login.microsoftonline.com/", # no tenant segment + 
"activedirectoryserviceprincipal", + ) + + def test_factory_propagates_authentication_error(self): + from azure.core.exceptions import ClientAuthenticationError + + class FailingCred: + def __init__(self, **kwargs): + pass + + def get_token(self, scope): + raise ClientAuthenticationError("AADSTS7000215: Invalid client secret") + + original = sys.modules["azure.identity"].ClientSecretCredential + sys.modules["azure.identity"].ClientSecretCredential = FailingCred + try: + factory = ServicePrincipalAuth.make_token_factory("cid", "secret") + with pytest.raises(RuntimeError, match="ServicePrincipal authentication failed"): + factory( + "https://database.windows.net/", + "https://login.microsoftonline.com/tenant-guid/", + "activedirectoryserviceprincipal", + ) + finally: + sys.modules["azure.identity"].ClientSecretCredential = original + + def test_factory_does_not_leak_provider_message_in_runtime_error(self): + """The user-facing RuntimeError must not echo the provider message + (which can carry tenant ids, claims, or other sensitive context). 
+ Provider detail is preserved in debug logs only.""" + from azure.core.exceptions import ClientAuthenticationError + + secret_marker = "AADSTS7000215_SECRET_MARKER_in_provider_message" + + class FailingCred: + def __init__(self, **kwargs): + pass + + def get_token(self, scope): + raise ClientAuthenticationError(secret_marker) + + original = sys.modules["azure.identity"].ClientSecretCredential + sys.modules["azure.identity"].ClientSecretCredential = FailingCred + try: + factory = ServicePrincipalAuth.make_token_factory("cid", "secret") + try: + factory( + "https://database.windows.net/", + "https://login.microsoftonline.com/tenant-guid/", + "activedirectoryserviceprincipal", + ) + except RuntimeError as e: + full_chain = str(e) + cause = e.__cause__ + while cause is not None: + full_chain += " || " + str(cause) + cause = getattr(cause, "__cause__", None) + assert ( + secret_marker not in full_chain + ), f"Provider message leaked into surfaced exception chain: {full_chain}" + finally: + sys.modules["azure.identity"].ClientSecretCredential = original + + def test_factory_rejects_empty_spn(self): + factory = ServicePrincipalAuth.make_token_factory("cid", "secret") + with pytest.raises(RuntimeError, match="empty SPN"): + factory( + "", + "https://login.microsoftonline.com/tenant-guid/", + "activedirectoryserviceprincipal", + ) + + def test_factory_caches_credential_per_tenant(self): + """ClientSecretCredential must be reused across calls for the same + tenant so azure-identity's per-instance token cache actually works.""" + az = sys.modules["azure.identity"] + construction_count = {"n": 0} + + original = az.ClientSecretCredential + + class _Tok: + token = SAMPLE_TOKEN + + class CountingCred: + def __init__(self, **kwargs): + construction_count["n"] += 1 + + def get_token(self, scope): + return _Tok() + + az.ClientSecretCredential = CountingCred + try: + factory = ServicePrincipalAuth.make_token_factory("cid", "secret") + sts = 
"https://login.microsoftonline.com/tenant-guid/" + for _ in range(3): + factory("https://database.windows.net/", sts, "activedirectoryserviceprincipal") + assert construction_count["n"] == 1, ( + f"Expected 1 ClientSecretCredential construction across 3 calls, " + f"got {construction_count['n']}" + ) + # A different tenant should produce a second instance. + factory( + "https://database.windows.net/", + "https://login.microsoftonline.com/other-tenant/", + "activedirectoryserviceprincipal", + ) + assert construction_count["n"] == 2 + finally: + az.ClientSecretCredential = original