Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 59 additions & 4 deletions src/mistralai/extra/workflows/encoding/payload_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import urllib.parse
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union

from pydantic import BaseModel
from pydantic import BaseModel, ValidationError

if TYPE_CHECKING:
from cryptography.exceptions import InvalidTag
Expand Down Expand Up @@ -38,6 +38,7 @@
NetworkEncodedResult,
WorkflowContext,
)
from mistralai.client.models.jsonpatchpayloadresponse import JSONPatchPayloadResponse
from mistralai.extra.exceptions import (
WorkflowPayloadEncryptionException,
WorkflowPayloadOffloadingException,
Expand Down Expand Up @@ -151,7 +152,10 @@ def _decrypt(self, data: bytes) -> bytes:
async def _handle_offloading(
self, data: bytes, context: Optional[WorkflowContext]
) -> tuple[bytes, bool]:
if self.offloading_config is None or self.offloading_config.storage_config is None:
if (
self.offloading_config is None
or self.offloading_config.storage_config is None
):
raise WorkflowPayloadOffloadingException(
"You must configure payload offloading storage"
)
Expand Down Expand Up @@ -281,7 +285,10 @@ async def encode_event_payload_content(
if self.encryption_config is None:
return data, []

if force_full_encryption or self.encryption_config.mode == PayloadEncryptionMode.FULL:
if (
force_full_encryption
or self.encryption_config.mode == PayloadEncryptionMode.FULL
):
encrypted_data = self._encrypt(data)
return encrypted_data, [EncodedPayloadOptions.ENCRYPTED]

Expand Down Expand Up @@ -339,8 +346,28 @@ async def decode_event_payload(
return payload_data

encoding_options = [EncodedPayloadOptions(opt) for opt in encoding_options_strs]

# Handle selective encryption for json_patch payloads
if EncodedPayloadOptions.PARTIALLY_ENCRYPTED in encoding_options:
try:
payload = JSONPatchPayloadResponse.model_validate(payload_data)
if isinstance(payload.value, list):
decrypted_patches = self._decrypt_json_patch_selective(
[p.model_dump() for p in payload.value]
)
return {
"type": payload.type,
"value": decrypted_patches,
"encoding_options": [],
}
except ValidationError:
pass # Not a json_patch payload, fall through to full decryption

# Standard full encryption (base64 string value)
encrypted_bytes = base64.b64decode(payload_data["value"])
decrypted_bytes = await self.decode_payload_content(encrypted_bytes, encoding_options)
decrypted_bytes = await self.decode_payload_content(
encrypted_bytes, encoding_options
)
decrypted_value = json.loads(decrypted_bytes)

return {
Expand All @@ -349,6 +376,34 @@ async def decode_event_payload(
"encoding_options": [],
}

_ENCRYPTED_PATCH_TYPE = "__encrypted__"

def _decrypt_json_patch_selective(
self, patches: List[Dict[str, Any]]
) -> List[Dict[str, Any]]:
"""Decrypt patches with EncryptedPatchValue wrapper: {type: "__encrypted__", value: "base64..."}."""
decrypted = []
for patch in patches:
patch_value = patch.get("value")

# EncryptedPatchValue format: {"type": "__encrypted__", "value": "base64-encrypted-data"}
if (
isinstance(patch_value, dict)
and patch_value.get("type") == self._ENCRYPTED_PATCH_TYPE
):
encrypted_b64 = patch_value.get("value", "")
encrypted_data = base64.b64decode(encrypted_b64)
decrypted_bytes = self._decrypt(encrypted_data)
decrypted.append(
{
**patch,
"value": json.loads(decrypted_bytes),
}
)
else:
decrypted.append(patch)
return decrypted

async def encode_network_input(
self, data: Optional[Dict[str, Any]], context: WorkflowContext
) -> NetworkEncodedInput:
Expand Down
Loading