Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 42 additions & 5 deletions sagemaker-train/src/sagemaker/train/common_utils/finetune_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,21 +358,58 @@ def _get_fine_tuning_options_and_model_arn(model_name: str, customization_techni
raise ValueError(f"No recipes found with Smtj for technique: {customization_technique}")

# Select recipe based on training type
# Collect override_params from ALL matching recipes (standard + subscription)
recipe = None
if (isinstance(training_type, TrainingType) and training_type == TrainingType.LORA) or training_type == "LORA":
recipe = next((r for r in recipes_with_template if r.get("Peft")), None)
recipe = next((r for r in recipes_with_template if r.get("Peft") and not r.get("IsSubscriptionModel")), None)
elif (isinstance(training_type, TrainingType) and training_type == TrainingType.FULL) or training_type == "FULL":
recipe = next((r for r in recipes_with_template if not r.get("Peft")), None)
recipe = next((r for r in recipes_with_template if not r.get("Peft") and not r.get("IsSubscriptionModel")), None)

if not recipe:
raise ValueError(f"No recipes found with Smtj for technique: {customization_technique},training_type:{training_type}")

elif recipe and recipe.get("SmtjOverrideParamsS3Uri"):
# Start with the standard recipe's override_params
options_dict = {}
if recipe.get("SmtjOverrideParamsS3Uri"):
s3_uri = recipe["SmtjOverrideParamsS3Uri"]
s3 = boto3.client("s3")
bucket, key = s3_uri.replace("s3://", "").split("/", 1)
s3 = sagemaker_session.boto_session.client("s3")
uri_path = s3_uri.replace("s3://", "")
bucket, key = uri_path.split("/", 1)
obj = s3.get_object(Bucket=bucket, Key=key)
options_dict = json.loads(obj["Body"].read())

# Auto-detect and merge subscription recipe's override_params if available
if (isinstance(training_type, TrainingType) and training_type == TrainingType.LORA) or training_type == "LORA":
sub_recipe = next((r for r in recipes_with_template if r.get("Peft") and r.get("IsSubscriptionModel")), None)
else:
sub_recipe = next((r for r in recipes_with_template if not r.get("Peft") and r.get("IsSubscriptionModel")), None)

if sub_recipe and sub_recipe.get("SmtjOverrideParamsS3Uri"):
try:
sub_s3_uri = sub_recipe["SmtjOverrideParamsS3Uri"].replace("{customer_id}", sagemaker_session.boto_session.client("sts").get_caller_identity()["Account"])
sub_uri_path = sub_s3_uri.replace("s3://", "")
# Handle access point ARN URIs
if sub_uri_path.startswith("arn:"):
arn_parts = sub_uri_path.split("/", 2)
sub_bucket = arn_parts[0] + "/" + arn_parts[1]
sub_key = arn_parts[2] if len(arn_parts) > 2 else ""
else:
sub_bucket, sub_key = sub_uri_path.split("/", 1)
s3_sub = sagemaker_session.boto_session.client("s3")
sub_obj = s3_sub.get_object(Bucket=sub_bucket, Key=sub_key)
sub_options = json.loads(sub_obj["Body"].read())
# Merge: subscription params into _specs only (don't set defaults)
# This makes them settable but not serialized unless user explicitly sets them
for k, v in sub_options.items():
if k not in options_dict:
v_copy = v.copy() if isinstance(v, dict) else v
if isinstance(v_copy, dict):
v_copy['default'] = None # No default — won't appear in to_dict() unless set
options_dict[k] = v_copy
except Exception as e:
logger.debug(f"Could not fetch subscription recipe override_params: {type(e).__name__}: {e}")

if options_dict:
return FineTuningOptions(options_dict), model_arn, is_gated_model
else:
return FineTuningOptions({}), model_arn, is_gated_model
Expand Down
140 changes: 140 additions & 0 deletions sagemaker-train/tests/unit/train/common_utils/test_finetune_utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import json
import pytest
from unittest.mock import Mock, patch, MagicMock
from sagemaker.train.common_utils.finetune_utils import (
Expand Down Expand Up @@ -304,6 +305,8 @@ def test__get_fine_tuning_options_and_model_arn(self, mock_boto_client, mock_get
mock_s3_client.get_object.return_value = {
"Body": Mock(read=Mock(return_value=b'{"learning_rate": 0.001}'))
}
mock_session.boto_session.client.return_value = mock_s3_client
mock_session.boto_session.client.return_value = mock_s3_client

result = _get_fine_tuning_options_and_model_arn("test-model", "SFT", "LORA", mock_session)

Expand Down Expand Up @@ -551,3 +554,140 @@ def test__validate_s3_path_exists_with_prefix_not_exists(self, mock_boto_client)
mock_s3_client.put_object.assert_called_once_with(Bucket="test-bucket", Key="prefix/", Body=b'')



@patch('sagemaker.train.common_utils.finetune_utils._get_hub_content_metadata')
def test__get_fine_tuning_options_with_subscription_recipe_enabled(self, mock_get_hub_content):
"""When and user is subscribed, datamix HPs are available."""
mock_session = Mock()
mock_session.boto_session.region_name = "us-east-1"
mock_s3 = Mock()
mock_sts = Mock()
mock_sts.get_caller_identity.return_value = {"Account": "123456789012"}
mock_session.boto_session.client.side_effect = lambda service, **kwargs: mock_s3 if service == "s3" else mock_sts

mock_get_hub_content.return_value = {
'hub_content_arn': "arn:aws:sagemaker:us-east-1:123456789012:model/test-model",
'hub_content_document': {
"GatedBucket": False,
"RecipeCollection": [
{
"CustomizationTechnique": "SFT",
"SmtjRecipeTemplateS3Uri": "s3://bucket/template.yaml",
"SmtjOverrideParamsS3Uri": "s3://bucket/standard_params.json",
"Name": "standard_sft"
},
{
"CustomizationTechnique": "SFT",
"SmtjRecipeTemplateS3Uri": "s3://arn:aws:s3:us-east-1:334772094012:accesspoint/recipes-123456789012/source/template.yaml",
"SmtjOverrideParamsS3Uri": "s3://arn:aws:s3:us-east-1:334772094012:accesspoint/recipes-{customer_id}/source/params.json",
"Name": "datamix_sft",
"IsSubscriptionModel": True
}
]
}
}

# Standard recipe returns base params
standard_params = json.dumps({"max_steps": {"type": "integer", "required": True, "default": 100}})
# Subscription recipe returns datamix params
datamix_params = json.dumps({"customer_data_percent": {"type": "integer", "required": False, "default": 50}})

mock_s3.get_object.side_effect = [
{"Body": Mock(read=Mock(return_value=standard_params.encode()))},
{"Body": Mock(read=Mock(return_value=datamix_params.encode()))},
]

options, model_arn, is_gated = _get_fine_tuning_options_and_model_arn(
"test-model", "SFT", "FULL", mock_session,
)

assert "max_steps" in options._specs
assert "customer_data_percent" in options._specs
assert options._specs["customer_data_percent"]["default"] is None # defaults are None so they dont serialize unless explicitly set

@patch('sagemaker.train.common_utils.finetune_utils._get_hub_content_metadata')
def test__get_fine_tuning_options_subscription_disabled_no_datamix_hps(self, mock_get_hub_content):
"""When (default), datamix HPs are NOT available."""
mock_session = Mock()
mock_session.boto_session.region_name = "us-east-1"
mock_s3 = Mock()
mock_session.boto_session.client.side_effect = lambda service, **kwargs: mock_s3

mock_get_hub_content.return_value = {
'hub_content_arn': "arn:aws:sagemaker:us-east-1:123456789012:model/test-model",
'hub_content_document': {
"GatedBucket": False,
"RecipeCollection": [
{
"CustomizationTechnique": "SFT",
"SmtjRecipeTemplateS3Uri": "s3://bucket/template.yaml",
"SmtjOverrideParamsS3Uri": "s3://bucket/standard_params.json",
"Name": "standard_sft"
},
{
"CustomizationTechnique": "SFT",
"SmtjRecipeTemplateS3Uri": "s3://arn:aws:s3:us-east-1:334772094012:accesspoint/recipes-{customer_id}/source/template.yaml",
"SmtjOverrideParamsS3Uri": "s3://arn:aws:s3:us-east-1:334772094012:accesspoint/recipes-{customer_id}/source/params.json",
"Name": "datamix_sft",
"IsSubscriptionModel": True
}
]
}
}

standard_params = json.dumps({"max_steps": {"type": "integer", "required": True, "default": 100}})
mock_s3.get_object.return_value = {"Body": Mock(read=Mock(return_value=standard_params.encode()))}

options, model_arn, is_gated = _get_fine_tuning_options_and_model_arn(
"test-model", "SFT", "FULL", mock_session,
)

assert "max_steps" in options._specs
assert "customer_data_percent" not in options._specs

@patch('sagemaker.train.common_utils.finetune_utils._get_hub_content_metadata')
def test__get_fine_tuning_options_subscription_enabled_but_not_subscribed(self, mock_get_hub_content):
"""When but user is NOT subscribed, falls back gracefully."""
mock_session = Mock()
mock_session.boto_session.region_name = "us-east-1"
mock_s3 = Mock()
mock_sts = Mock()
mock_sts.get_caller_identity.return_value = {"Account": "999999999999"}
mock_session.boto_session.client.side_effect = lambda service, **kwargs: mock_s3 if service == "s3" else mock_sts

mock_get_hub_content.return_value = {
'hub_content_arn': "arn:aws:sagemaker:us-east-1:123456789012:model/test-model",
'hub_content_document': {
"GatedBucket": False,
"RecipeCollection": [
{
"CustomizationTechnique": "SFT",
"SmtjRecipeTemplateS3Uri": "s3://bucket/template.yaml",
"SmtjOverrideParamsS3Uri": "s3://bucket/standard_params.json",
"Name": "standard_sft"
},
{
"CustomizationTechnique": "SFT",
"SmtjRecipeTemplateS3Uri": "s3://arn:aws:s3:us-east-1:334772094012:accesspoint/recipes-{customer_id}/source/template.yaml",
"SmtjOverrideParamsS3Uri": "s3://arn:aws:s3:us-east-1:334772094012:accesspoint/recipes-{customer_id}/source/params.json",
"Name": "datamix_sft",
"IsSubscriptionModel": True
}
]
}
}

standard_params = json.dumps({"max_steps": {"type": "integer", "required": True, "default": 100}})
# First call succeeds (standard recipe), second call fails (access denied)
mock_s3.get_object.side_effect = [
{"Body": Mock(read=Mock(return_value=standard_params.encode()))},
Exception("Access Denied"),
]

options, model_arn, is_gated = _get_fine_tuning_options_and_model_arn(
"test-model", "SFT", "FULL", mock_session,
)

# Should still have standard params, just not datamix ones
assert "max_steps" in options._specs
assert "customer_data_percent" not in options._specs
Loading