-
Notifications
You must be signed in to change notification settings - Fork 50
FEAT: Add ActiveDirectoryMSI support for bulk copy #573
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
9e0b418
3ca2165
66b8eec
6451a4e
e106d5f
ff1aa41
1b7b434
df24965
9ad28a1
cbb673d
303b66c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -11,15 +11,33 @@ | |
|
|
||
| from mssql_python.logging import logger | ||
| from mssql_python.constants import AuthType, ConstantsDDBC | ||
| from mssql_python.connection_string_parser import _ConnectionStringParser | ||
|
|
||
| # Module-level credential instance cache. | ||
| # Reusing credential objects allows the Azure Identity SDK's built-in | ||
| # in-memory token cache to work, avoiding redundant token acquisitions. | ||
| # See: https://github.com/Azure/azure-sdk-for-python/blob/main/sdk/identity/azure-identity/TOKEN_CACHING.md | ||
| _credential_cache: Dict[str, object] = {} | ||
| # | ||
| # Cache is keyed on (auth_type, sorted credential_kwargs), which is | ||
| # bounded by the distinct credentials a single process ever uses. | ||
| _credential_cache: Dict[object, object] = {} | ||
| _credential_cache_lock = threading.Lock() | ||
|
|
||
|
|
||
| def _credential_cache_key(auth_type: str, credential_kwargs: Optional[Dict[str, str]]): | ||
| """Build a hashable cache key from auth_type and optional credential kwargs. | ||
|
|
||
| Returns the plain auth_type string when no kwargs are provided so that | ||
| callers caching by string (the original behavior) keep working. When | ||
| kwargs are present (e.g. user-assigned MSI client_id), the key is a | ||
| tuple of ``(auth_type, sorted_kwargs_items)`` so different kwargs map | ||
| to different cached credentials. | ||
| """ | ||
| if not credential_kwargs: | ||
| return auth_type | ||
| return (auth_type, tuple(sorted(credential_kwargs.items()))) | ||
|
|
||
|
|
||
| class AADAuth: | ||
| """Handles Azure Active Directory authentication""" | ||
|
|
||
|
|
@@ -37,24 +55,26 @@ def get_token_struct(token: str) -> bytes: | |
| return struct.pack(f"<I{len(token_bytes)}s", len(token_bytes), token_bytes) | ||
|
|
||
| @staticmethod | ||
| def get_token(auth_type: str) -> bytes: | ||
| def get_token(auth_type: str, credential_kwargs: Optional[Dict[str, str]] = None) -> bytes: | ||
| """Get DDBC token struct for the specified authentication type.""" | ||
| token_struct, _ = AADAuth._acquire_token(auth_type) | ||
| token_struct, _ = AADAuth._acquire_token(auth_type, credential_kwargs) | ||
| return token_struct | ||
|
|
||
| @staticmethod | ||
| def get_raw_token(auth_type: str) -> str: | ||
| def get_raw_token(auth_type: str, credential_kwargs: Optional[Dict[str, str]] = None) -> str: | ||
| """Acquire a raw JWT for the mssql-py-core connection (bulk copy). | ||
|
|
||
| Uses the cached credential instance so the Azure Identity SDK's | ||
| built-in token cache can serve a valid token without a round-trip | ||
| when the previous token has not yet expired. | ||
| """ | ||
| _, raw_token = AADAuth._acquire_token(auth_type) | ||
| _, raw_token = AADAuth._acquire_token(auth_type, credential_kwargs) | ||
| return raw_token | ||
|
|
||
| @staticmethod | ||
| def _acquire_token(auth_type: str) -> Tuple[bytes, str]: | ||
| def _acquire_token( | ||
| auth_type: str, credential_kwargs: Optional[Dict[str, str]] = None | ||
| ) -> Tuple[bytes, str]: | ||
| """Internal: acquire token and return (ddbc_struct, raw_jwt).""" | ||
| # Import Azure libraries inside method to support test mocking | ||
| # pylint: disable=import-outside-toplevel | ||
|
|
@@ -63,6 +83,7 @@ def _acquire_token(auth_type: str) -> Tuple[bytes, str]: | |
| DefaultAzureCredential, | ||
| DeviceCodeCredential, | ||
| InteractiveBrowserCredential, | ||
| ManagedIdentityCredential, | ||
| ) | ||
| from azure.core.exceptions import ClientAuthenticationError | ||
| except ImportError as e: | ||
|
|
@@ -76,6 +97,7 @@ def _acquire_token(auth_type: str) -> Tuple[bytes, str]: | |
| "default": DefaultAzureCredential, | ||
| "devicecode": DeviceCodeCredential, | ||
| "interactive": InteractiveBrowserCredential, | ||
| "msi": ManagedIdentityCredential, | ||
| } | ||
|
|
||
| credential_class = credential_map.get(auth_type) | ||
|
|
@@ -89,20 +111,22 @@ def _acquire_token(auth_type: str) -> Tuple[bytes, str]: | |
| credential_class.__name__, | ||
| ) | ||
|
|
||
| kwargs = credential_kwargs or {} | ||
| cache_key = _credential_cache_key(auth_type, kwargs) | ||
| try: | ||
| with _credential_cache_lock: | ||
| if auth_type not in _credential_cache: | ||
| if cache_key not in _credential_cache: | ||
| logger.debug( | ||
| "get_token: Creating new credential instance for auth_type=%s", | ||
| auth_type, | ||
| ) | ||
| _credential_cache[auth_type] = credential_class() | ||
| _credential_cache[cache_key] = credential_class(**kwargs) | ||
| else: | ||
| logger.debug( | ||
| "get_token: Reusing cached credential instance for auth_type=%s", | ||
| auth_type, | ||
| ) | ||
| credential = _credential_cache[auth_type] | ||
| credential = _credential_cache[cache_key] | ||
| raw_token = credential.get_token("https://database.windows.net/.default").token | ||
| logger.info( | ||
| "get_token: Azure AD token acquired successfully - token_length=%d chars", | ||
|
|
@@ -130,6 +154,28 @@ def _acquire_token(auth_type: str) -> Tuple[bytes, str]: | |
| raise RuntimeError(f"Failed to create {credential_class.__name__}: {e}") from e | ||
|
|
||
|
|
||
| def _extract_msi_client_id(connection_string: str) -> Optional[str]: | ||
| """Pull UID out of a connection string for user-assigned MSI. | ||
|
|
||
| For ActiveDirectoryMSI, UID (when present) carries the user-assigned | ||
| identity's ``client_id``. Returns None for system-assigned MSI. | ||
|
|
||
| Uses the canonical ``_ConnectionStringParser`` so braced ODBC values | ||
| are handled correctly: a ``UID={hello=world}`` resolves to the value | ||
| ``hello=world`` (no surrounding braces, no false split on the inner | ||
| ``=``), and a semicolon inside a legitimate braced value (e.g. | ||
| ``Database={foo;uid=victim;bar}``) cannot spoof a top-level ``UID=``. | ||
| """ | ||
| # Connection.__init__ already parsed the same string through | ||
| # _ConnectionStringParser via _construct_connection_string, so by the | ||
| # time we get here the input is guaranteed parseable. No defensive | ||
| # try/except: a parse failure now means a real bug upstream and should | ||
| # propagate, not silently degrade user-assigned MSI to system-assigned. | ||
| parsed = _ConnectionStringParser(validate_keywords=False)._parse(connection_string) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Functionally correct but a tiny perf hit due to double parsing of connections string. An adhoc fix is to maintain a hashmap of connection string and uid. But that's prone to other problems esp concurrency. |
||
| uid = (parsed.get("uid") or "").strip() | ||
|
bewithgaurav marked this conversation as resolved.
|
||
| return uid or None | ||
|
|
||
|
|
||
| def process_auth_parameters(parameters: List[str]) -> Tuple[List[str], Optional[str]]: | ||
| """ | ||
| Process connection parameters and extract authentication type. | ||
|
|
@@ -180,6 +226,10 @@ def process_auth_parameters(parameters: List[str]) -> Tuple[List[str], Optional[ | |
| # Default authentication (uses DefaultAzureCredential) | ||
| logger.debug("process_auth_parameters: Default Azure authentication detected") | ||
| auth_type = "default" | ||
| elif value_lower == AuthType.MSI.value: | ||
| # Managed identity authentication (system- or user-assigned) | ||
| logger.debug("process_auth_parameters: Managed identity authentication detected") | ||
| auth_type = "msi" | ||
| modified_parameters.append(param) | ||
|
|
||
| logger.debug( | ||
|
|
@@ -212,7 +262,9 @@ def remove_sensitive_params(parameters: List[str]) -> List[str]: | |
| return result | ||
|
|
||
|
|
||
| def get_auth_token(auth_type: str) -> Optional[bytes]: | ||
| def get_auth_token( | ||
| auth_type: str, credential_kwargs: Optional[Dict[str, str]] = None | ||
| ) -> Optional[bytes]: | ||
| """Get DDBC authentication token struct based on auth type.""" | ||
| logger.debug("get_auth_token: Starting - auth_type=%s", auth_type) | ||
| if not auth_type: | ||
|
|
@@ -225,7 +277,7 @@ def get_auth_token(auth_type: str) -> Optional[bytes]: | |
| return None # Let Windows handle AADInteractive natively | ||
|
|
||
| try: | ||
| token = AADAuth.get_token(auth_type) | ||
| token = AADAuth.get_token(auth_type, credential_kwargs) | ||
| logger.info("get_auth_token: Token acquired successfully - auth_type=%s", auth_type) | ||
| return token | ||
| except (ValueError, RuntimeError) as e: | ||
|
|
@@ -246,6 +298,7 @@ def extract_auth_type(connection_string: str) -> Optional[str]: | |
| AuthType.INTERACTIVE.value: "interactive", | ||
| AuthType.DEVICE_CODE.value: "devicecode", | ||
| AuthType.DEFAULT.value: "default", | ||
| AuthType.MSI.value: "msi", | ||
| } | ||
| for part in connection_string.split(";"): | ||
| key, _, value = part.strip().partition("=") | ||
|
|
@@ -256,16 +309,28 @@ def extract_auth_type(connection_string: str) -> Optional[str]: | |
|
|
||
| def process_connection_string( | ||
| connection_string: str, | ||
| ) -> Tuple[str, Optional[Dict[int, bytes]], Optional[str]]: | ||
| ) -> Tuple[str, Optional[Dict[int, bytes]], Optional[str], Optional[Dict[str, str]]]: | ||
| """ | ||
| Process connection string and handle authentication. | ||
|
|
||
| NOTE: Returns a 4-tuple. Callers must unpack all four elements. | ||
| Destructuring with three names raises ``ValueError: too many values | ||
| to unpack``. The fourth element (``credential_kwargs``) is needed by | ||
| Connection.__init__ to persist credential constructor args (e.g. the | ||
| user-assigned MSI ``client_id``) for the bulkcopy fresh-token path, | ||
| since UID is stripped from the sanitized connection string. | ||
|
|
||
| Args: | ||
| connection_string: The connection string to process | ||
|
|
||
| Returns: | ||
| Tuple[str, Optional[Dict], Optional[str]]: Processed connection string, | ||
| attrs_before dict if needed, and auth_type string for bulk copy token acquisition | ||
| Tuple[str, Optional[Dict], Optional[str], Optional[Dict[str, str]]]: | ||
| Processed connection string, attrs_before dict if needed, auth_type | ||
| string for bulk copy token acquisition, and credential constructor | ||
| kwargs (e.g. user-assigned MSI ``client_id``) to be persisted on | ||
| the Connection so bulkcopy can re-use them when acquiring a fresh | ||
| token after sanitization has stripped UID from the connection | ||
| string. | ||
|
|
||
| Raises: | ||
| ValueError: If the connection string is invalid or empty | ||
|
|
@@ -301,12 +366,33 @@ def process_connection_string( | |
|
|
||
| modified_parameters, auth_type = process_auth_parameters(parameters) | ||
|
|
||
| # Capture credential kwargs (e.g. user-assigned MSI client_id) before | ||
| # remove_sensitive_params strips UID from the parameter list. Pass the | ||
| # original connection_string (not modified_parameters) so the helper can | ||
| # use the canonical _ConnectionStringParser — handles braced values like | ||
| # UID={hello=world} correctly. | ||
| credential_kwargs: Dict[str, str] = {} | ||
| if auth_type == "msi": | ||
| client_id = _extract_msi_client_id(connection_string) | ||
| if client_id: | ||
| credential_kwargs["client_id"] = client_id | ||
| logger.debug( | ||
| "process_connection_string: ActiveDirectoryMSI with UID — " | ||
| "user-assigned managed identity selected (client_id length=%d)", | ||
| len(client_id), | ||
| ) | ||
| else: | ||
| logger.debug( | ||
| "process_connection_string: ActiveDirectoryMSI without UID — " | ||
| "system-assigned managed identity selected" | ||
| ) | ||
|
|
||
| if auth_type: | ||
| logger.info( | ||
| "process_connection_string: Authentication type detected - auth_type=%s", auth_type | ||
| ) | ||
| modified_parameters = remove_sensitive_params(modified_parameters) | ||
| token_struct = get_auth_token(auth_type) | ||
| token_struct = get_auth_token(auth_type, credential_kwargs or None) | ||
| if token_struct: | ||
| logger.info( | ||
| "process_connection_string: Token authentication configured successfully - auth_type=%s", | ||
|
|
@@ -316,6 +402,7 @@ def process_connection_string( | |
| ";".join(modified_parameters) + ";", | ||
| {ConstantsDDBC.SQL_COPT_SS_ACCESS_TOKEN.value: token_struct}, | ||
| auth_type, | ||
| credential_kwargs or None, | ||
| ) | ||
| else: | ||
| logger.warning( | ||
|
|
@@ -326,4 +413,9 @@ def process_connection_string( | |
| "process_connection_string: Connection string processing complete - has_auth=%s", | ||
| bool(auth_type), | ||
| ) | ||
| return ";".join(modified_parameters) + ";", None, auth_type | ||
| return ( | ||
| ";".join(modified_parameters) + ";", | ||
| None, | ||
| auth_type, | ||
| credential_kwargs or None, | ||
| ) | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -321,6 +321,12 @@ def __init__( | |
| # We intentionally do NOT cache the token — a fresh one is acquired | ||
| # each time bulkcopy() is called to avoid expired-token errors. | ||
| self._auth_type = None | ||
| # Credential constructor kwargs (e.g. user-assigned MSI client_id) | ||
| # captured at __init__ time before remove_sensitive_params strips UID | ||
| # from self.connection_str. bulkcopy() re-uses these when acquiring a | ||
| # fresh token; re-parsing self.connection_str at that point would miss | ||
| # them because UID is already gone. | ||
| self._credential_kwargs: Optional[Dict[str, str]] = None | ||
|
|
||
| # Check if the connection string contains authentication parameters | ||
| # This is important for processing the connection string correctly. | ||
|
|
@@ -335,6 +341,7 @@ def __init__( | |
| # On Windows Interactive, process_connection_string returns None | ||
| # (DDBC handles auth natively), so fall back to the connection string. | ||
| self._auth_type = connection_result[2] or extract_auth_type(self.connection_str) | ||
| self._credential_kwargs = connection_result[3] | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. good call. the 4-tuple is interim. I have filed #580 to refactor process_connection_string to take the parsed connection-string map instead of the raw string (also avoids the 3x parse on the connect path). once that lands, the 4-tuple goes away and callers move to dict access. for now, added a short docstring note about the 4-tuple to cover anyone reaching in before #580 lands. |
||
|
|
||
| self._closed = False | ||
| self._timeout = timeout | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.