diff --git a/src/mistralai/extra/workflows/encoding/payload_encoder.py b/src/mistralai/extra/workflows/encoding/payload_encoder.py index 611f33fa..1a7fe7ae 100644 --- a/src/mistralai/extra/workflows/encoding/payload_encoder.py +++ b/src/mistralai/extra/workflows/encoding/payload_encoder.py @@ -9,7 +9,7 @@ import urllib.parse from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union -from pydantic import BaseModel +from pydantic import BaseModel, ValidationError if TYPE_CHECKING: from cryptography.exceptions import InvalidTag @@ -38,6 +38,7 @@ NetworkEncodedResult, WorkflowContext, ) +from mistralai.client.models.jsonpatchpayloadresponse import JSONPatchPayloadResponse from mistralai.extra.exceptions import ( WorkflowPayloadEncryptionException, WorkflowPayloadOffloadingException, @@ -151,7 +152,10 @@ def _decrypt(self, data: bytes) -> bytes: async def _handle_offloading( self, data: bytes, context: Optional[WorkflowContext] ) -> tuple[bytes, bool]: - if self.offloading_config is None or self.offloading_config.storage_config is None: + if ( + self.offloading_config is None + or self.offloading_config.storage_config is None + ): raise WorkflowPayloadOffloadingException( "You must configure payload offloading storage" ) @@ -281,7 +285,10 @@ async def encode_event_payload_content( if self.encryption_config is None: return data, [] - if force_full_encryption or self.encryption_config.mode == PayloadEncryptionMode.FULL: + if ( + force_full_encryption + or self.encryption_config.mode == PayloadEncryptionMode.FULL + ): encrypted_data = self._encrypt(data) return encrypted_data, [EncodedPayloadOptions.ENCRYPTED] @@ -339,8 +346,28 @@ async def decode_event_payload( return payload_data encoding_options = [EncodedPayloadOptions(opt) for opt in encoding_options_strs] + + # Handle selective encryption for json_patch payloads + if EncodedPayloadOptions.PARTIALLY_ENCRYPTED in encoding_options: + try: + payload = JSONPatchPayloadResponse.model_validate(payload_data) + if isinstance(payload.value, list): + decrypted_patches = self._decrypt_json_patch_selective( + [p.model_dump() for p in payload.value] + ) + return { + "type": payload.type, + "value": decrypted_patches, + "encoding_options": [], + } + except ValidationError: + pass # Not a json_patch payload, fall through to full decryption + + # Standard full encryption (base64 string value) encrypted_bytes = base64.b64decode(payload_data["value"]) - decrypted_bytes = await self.decode_payload_content(encrypted_bytes, encoding_options) + decrypted_bytes = await self.decode_payload_content( + encrypted_bytes, encoding_options + ) decrypted_value = json.loads(decrypted_bytes) return { @@ -349,6 +376,34 @@ async def decode_event_payload( "encoding_options": [], } + _ENCRYPTED_PATCH_TYPE = "__encrypted__" + + def _decrypt_json_patch_selective( + self, patches: List[Dict[str, Any]] + ) -> List[Dict[str, Any]]: + """Decrypt patches with EncryptedPatchValue wrapper: {type: "__encrypted__", value: "base64..."}.""" + decrypted = [] + for patch in patches: + patch_value = patch.get("value") + + # EncryptedPatchValue format: {"type": "__encrypted__", "value": "base64-encrypted-data"} + if ( + isinstance(patch_value, dict) + and patch_value.get("type") == self._ENCRYPTED_PATCH_TYPE + ): + encrypted_b64 = patch_value.get("value", "") + encrypted_data = base64.b64decode(encrypted_b64) + decrypted_bytes = self._decrypt(encrypted_data) + decrypted.append( + { + **patch, + "value": json.loads(decrypted_bytes), + } + ) + else: + decrypted.append(patch) + return decrypted + async def encode_network_input( self, data: Optional[Dict[str, Any]], context: WorkflowContext ) -> NetworkEncodedInput: