Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 75 additions & 6 deletions temporalio/contrib/pydantic.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
Pydantic v1 is not supported.
"""

import dataclasses
from dataclasses import dataclass
from typing import Any

Expand All @@ -33,11 +34,63 @@
# implements __get_pydantic_core_schema__ so that pydantic unwraps proxied types.


def _sanitize_for_json(obj: Any) -> Any:
"""Sanitize a value tree so pydantic_core's Rust serializer can encode it.

Handles two cases that crash ``pydantic_core.to_json``:
* **str** with Unicode surrogates (U+D800-U+DFFF)
* **bytes** with non-UTF-8 content
"""
if isinstance(obj, str):
return obj.encode("utf-16", "surrogatepass").decode("utf-16", "replace")
if isinstance(obj, bytes):
return obj.decode("utf-8", errors="replace").encode("utf-8")
if isinstance(obj, dict):
new_dict: dict[Any, Any] = {}
changed = False
for k, v in obj.items():
new_k = _sanitize_for_json(k)
new_v = _sanitize_for_json(v)
new_dict[new_k] = new_v
if new_k is not k or new_v is not v:
changed = True
return new_dict if changed else obj
if isinstance(obj, (list, tuple)):
new_items = [_sanitize_for_json(item) for item in obj]
changed = any(new is not old for new, old in zip(new_items, obj))
if not changed:
return obj
return type(obj)(new_items)
if hasattr(obj, "model_fields"):
updates: dict[str, Any] = {}
for field_name in obj.model_fields:
val = getattr(obj, field_name)
sanitized = _sanitize_for_json(val)
if sanitized is not val:
updates[field_name] = sanitized
return obj.model_copy(update=updates) if updates else obj
if dataclasses.is_dataclass(obj) and not isinstance(obj, type):
updates = {}
for f in dataclasses.fields(obj):
val = getattr(obj, f.name)
sanitized = _sanitize_for_json(val)
if sanitized is not val:
updates[f.name] = sanitized
return dataclasses.replace(obj, **updates) if updates else obj
return obj


@dataclass
class ToJsonOptions:
    """Options controlling pydantic JSON serialization.

    Attributes:
        exclude_unset: When ``True``, model fields that were never
            explicitly set are omitted from the serialized output.
        lossy_utf8: When ``True``, values that would crash pydantic_core's
            Rust serializer (strings containing Unicode surrogates, bytes
            holding non-UTF-8 content) are sanitized instead of raising:
            surrogates are replaced with U+FFFD and invalid byte sequences
            are decoded with ``errors='replace'``. This is lossy but keeps
            payloads with arbitrary binary data serializable.
    """

    exclude_unset: bool = False
    lossy_utf8: bool = False


class PydanticJSONPlainPayloadConverter(EncodingPayloadConverter):
Expand Down Expand Up @@ -71,13 +124,29 @@ def to_payload(self, value: Any) -> temporalio.api.common.v1.Payload | None:
See
https://docs.pydantic.dev/latest/api/pydantic_core/#pydantic_core.to_json.
"""
data = (
self._schema_serializer.to_json(
value, exclude_unset=self._to_json_options.exclude_unset
try:
data = (
self._schema_serializer.to_json(
value, exclude_unset=self._to_json_options.exclude_unset
)
if self._to_json_options
else to_json(value)
)
except Exception:
if not (self._to_json_options and self._to_json_options.lossy_utf8):
raise
# pydantic_core's Rust serializer cannot encode strings with
# Unicode surrogates or bytes with non-UTF-8 content.
# Sanitize the value tree, then retry the same serializer path.
sanitized = _sanitize_for_json(value)
data = (
self._schema_serializer.to_json(
sanitized,
exclude_unset=self._to_json_options.exclude_unset,
)
if self._to_json_options
else to_json(sanitized)
)
if self._to_json_options
else to_json(value)
)
return temporalio.api.common.v1.Payload(
metadata={"encoding": self.encoding.encode()}, data=data
)
Expand Down
114 changes: 113 additions & 1 deletion tests/contrib/pydantic/test_pydantic.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import dataclasses
import datetime
import json
import os
import pathlib
import uuid
Expand All @@ -9,7 +10,12 @@
from pydantic import BaseModel

from temporalio.client import Client
from temporalio.contrib.pydantic import pydantic_data_converter
from temporalio.contrib.pydantic import (
PydanticJSONPlainPayloadConverter,
ToJsonOptions,
_sanitize_for_json,
pydantic_data_converter,
)
from temporalio.worker import Worker
from temporalio.worker.workflow_sandbox._restrictions import (
RestrictionContext,
Expand Down Expand Up @@ -380,3 +386,109 @@ def test_model_instantiation_from_restricted_proxy_values():
assert p.path_field == restricted_path
assert p.uuid_field == restricted_uuid
assert p.datetime_field == restricted_datetime


# --- Surrogate / non-UTF-8 sanitization tests ---


def test_sanitize_for_json_surrogate_pair():
    # U+D83D/U+DE00 form a well-formed surrogate pair for U+1F600
    # (grinning face); sanitization recombines it losslessly.
    assert _sanitize_for_json("\ud83d\ude00") == "\U0001f600"


def test_sanitize_for_json_lone_surrogate():
    # An unpaired high surrogate has no valid encoding; it becomes
    # U+FFFD (the Unicode replacement character).
    assert _sanitize_for_json("\ud800") == "\ufffd"


def test_sanitize_for_json_invalid_bytes():
    # \x89 is not valid UTF-8; it is replaced with U+FFFD, which
    # re-encodes to \xef\xbf\xbd in UTF-8.
    assert _sanitize_for_json(b"\x89PNG") == b"\xef\xbf\xbdPNG"


def test_to_payload_raises_without_lossy_utf8():
    # With lossy_utf8 disabled (the default), a surrogate-bearing string
    # still crashes the Rust serializer and the error propagates.
    conv = PydanticJSONPlainPayloadConverter()
    with pytest.raises(Exception):
        conv.to_payload({"text": "hello \ud800 world"})


def test_to_payload_with_surrogate_string():
    conv = PydanticJSONPlainPayloadConverter(ToJsonOptions(lossy_utf8=True))
    result = conv.to_payload({"text": "hello \ud800 world"})
    assert result is not None
    # The payload data must be valid JSON with the surrogate replaced.
    decoded = json.loads(result.data)
    assert decoded["text"] == "hello \ufffd world"


def test_to_payload_with_invalid_bytes():
    class BytesModel(BaseModel):
        data: bytes

    conv = PydanticJSONPlainPayloadConverter(ToJsonOptions(lossy_utf8=True))
    # b"\x89PNG\r\n" is a PNG signature prefix — deliberately not UTF-8.
    assert conv.to_payload(BytesModel(data=b"\x89PNG\r\n")) is not None


def test_to_payload_with_exclude_unset():
    class UnsetModel(BaseModel):
        text: str
        count: int = 0

    conv = PydanticJSONPlainPayloadConverter(
        ToJsonOptions(exclude_unset=True, lossy_utf8=True)
    )
    # text carries a lone surrogate; count keeps its default (unset).
    result = conv.to_payload(UnsetModel(text="hello \ud800 world"))
    assert result is not None
    body = json.loads(result.data)
    # The surrogate was sanitized and the unset field was excluded,
    # proving exclude_unset survives the lossy retry path.
    assert body["text"] == "hello \ufffd world"
    assert "count" not in body


def test_sanitize_for_json_pydantic_model():
    class MixedModel(BaseModel):
        text: str
        data: bytes
        count: int

    sanitized = _sanitize_for_json(
        MixedModel(text="hello \ud800", data=b"\x89PNG", count=42)
    )

    # The model type is preserved; dirty fields are cleaned, clean
    # fields are untouched.
    assert isinstance(sanitized, MixedModel)
    assert sanitized.text == "hello \ufffd"
    assert sanitized.data == b"\xef\xbf\xbdPNG"
    assert sanitized.count == 42


def test_sanitize_for_json_dataclass():
    @dataclasses.dataclass
    class MixedDC:
        text: str
        data: bytes
        count: int

    sanitized = _sanitize_for_json(MixedDC(text="hello \ud800", data=b"\x89PNG", count=42))

    # The dataclass type is preserved; only the dirty fields change.
    assert isinstance(sanitized, MixedDC)
    assert sanitized.text == "hello \ufffd"
    assert sanitized.data == b"\xef\xbf\xbdPNG"
    assert sanitized.count == 42


def test_sanitize_for_json_nested_structures():
    class InnerModel(BaseModel):
        value: str

    # A model nested inside a list inside a dict exercises recursion
    # through all three container branches.
    tree = {"items": [InnerModel(value="abc \ud800 def")]}
    sanitized = _sanitize_for_json(tree)

    inner = sanitized["items"][0]
    assert isinstance(inner, InnerModel)
    assert inner.value == "abc \ufffd def"
Loading