diff --git a/src/maxdiffusion/configs/base_wan_14b.yml b/src/maxdiffusion/configs/base_wan_14b.yml
index f432928a..319bfbc7 100644
--- a/src/maxdiffusion/configs/base_wan_14b.yml
+++ b/src/maxdiffusion/configs/base_wan_14b.yml
@@ -41,6 +41,11 @@ revision: ''
 weights_dtype: 'bfloat16'
 # This sets the layer's dtype in the model. Ex: nn.Dense(dtype=activations_dtype)
 activations_dtype: 'bfloat16'
+# The dtype for text_encoder model during load/compile
+text_encoder_dtype: 'float32'
+
+# Whether to compile the text_encoder with torch.compile
+compile_text_encoder: False
 
 # Replicates vae across devices instead of using the model's sharding annotations for sharding.
 replicate_vae: False
diff --git a/src/maxdiffusion/configs/base_wan_1_3b.yml b/src/maxdiffusion/configs/base_wan_1_3b.yml
index 0e055265..3134ed93 100644
--- a/src/maxdiffusion/configs/base_wan_1_3b.yml
+++ b/src/maxdiffusion/configs/base_wan_1_3b.yml
@@ -41,6 +41,11 @@ revision: ''
 weights_dtype: 'bfloat16'
 # This sets the layer's dtype in the model. Ex: nn.Dense(dtype=activations_dtype)
 activations_dtype: 'bfloat16'
+# The dtype for text_encoder model during load/compile
+text_encoder_dtype: 'float32'
+
+# Whether to compile the text_encoder with torch.compile
+compile_text_encoder: False
 
 # Replicates vae across devices instead of using the model's sharding annotations for sharding.
 replicate_vae: False
diff --git a/src/maxdiffusion/configs/base_wan_27b.yml b/src/maxdiffusion/configs/base_wan_27b.yml
index bf29fa86..dfe300dd 100644
--- a/src/maxdiffusion/configs/base_wan_27b.yml
+++ b/src/maxdiffusion/configs/base_wan_27b.yml
@@ -41,6 +41,11 @@ revision: ''
 weights_dtype: 'bfloat16'
 # This sets the layer's dtype in the model. Ex: nn.Dense(dtype=activations_dtype)
 activations_dtype: 'bfloat16'
+# The dtype for text_encoder model during load/compile
+text_encoder_dtype: 'float32'
+
+# Whether to compile the text_encoder with torch.compile
+compile_text_encoder: False
 
 # Replicates vae across devices instead of using the model's sharding annotations for sharding.
 replicate_vae: False
diff --git a/src/maxdiffusion/configs/base_wan_animate.yml b/src/maxdiffusion/configs/base_wan_animate.yml
index 8f95c855..7b3334c7 100644
--- a/src/maxdiffusion/configs/base_wan_animate.yml
+++ b/src/maxdiffusion/configs/base_wan_animate.yml
@@ -41,6 +41,11 @@ revision: ''
 weights_dtype: 'bfloat16'
 # This sets the layer's dtype in the model. Ex: nn.Dense(dtype=activations_dtype)
 activations_dtype: 'bfloat16'
+# The dtype for text_encoder model during load/compile
+text_encoder_dtype: 'float32'
+
+# Whether to compile the text_encoder with torch.compile
+compile_text_encoder: False
 
 # Replicates vae across devices instead of using the model's sharding annotations for sharding.
 replicate_vae: False
diff --git a/src/maxdiffusion/configs/base_wan_i2v_14b.yml b/src/maxdiffusion/configs/base_wan_i2v_14b.yml
index ca2d239a..f722e04e 100644
--- a/src/maxdiffusion/configs/base_wan_i2v_14b.yml
+++ b/src/maxdiffusion/configs/base_wan_i2v_14b.yml
@@ -41,6 +41,11 @@ revision: ''
 weights_dtype: 'bfloat16'
 # This sets the layer's dtype in the model. Ex: nn.Dense(dtype=activations_dtype)
 activations_dtype: 'bfloat16'
+# The dtype for text_encoder model during load/compile
+text_encoder_dtype: 'float32'
+
+# Whether to compile the text_encoder with torch.compile
+compile_text_encoder: False
 
 # Replicates vae across devices instead of using the model's sharding annotations for sharding.
 replicate_vae: False
diff --git a/src/maxdiffusion/configs/base_wan_i2v_27b.yml b/src/maxdiffusion/configs/base_wan_i2v_27b.yml
index 90799524..0aa533b4 100644
--- a/src/maxdiffusion/configs/base_wan_i2v_27b.yml
+++ b/src/maxdiffusion/configs/base_wan_i2v_27b.yml
@@ -41,6 +41,11 @@ revision: ''
 weights_dtype: 'bfloat16'
 # This sets the layer's dtype in the model. Ex: nn.Dense(dtype=activations_dtype)
 activations_dtype: 'bfloat16'
+# The dtype for text_encoder model during load/compile
+text_encoder_dtype: 'float32'
+
+# Whether to compile the text_encoder with torch.compile
+compile_text_encoder: False
 
 # Replicates vae across devices instead of using the model's sharding annotations for sharding.
 replicate_vae: False
diff --git a/src/maxdiffusion/pipelines/wan/wan_pipeline.py b/src/maxdiffusion/pipelines/wan/wan_pipeline.py
index 5a5cfa29..c304ee42 100644
--- a/src/maxdiffusion/pipelines/wan/wan_pipeline.py
+++ b/src/maxdiffusion/pipelines/wan/wan_pipeline.py
@@ -270,13 +270,15 @@ def __init__(
 
   @classmethod
   def load_text_encoder(cls, config: HyperParameters):
-    torch_dtype = getattr(torch, str(config.weights_dtype), torch.float32)
+    text_encoder_dtype = getattr(config, "text_encoder_dtype", "float32")
+    torch_dtype = getattr(torch, str(text_encoder_dtype), torch.float32)
     text_encoder = UMT5EncoderModel.from_pretrained(
         config.pretrained_model_name_or_path,
         subfolder="text_encoder",
         torch_dtype=torch_dtype,
     )
-    text_encoder = torch.compile(text_encoder)
+    # NOTE(review): fallback default must match the `compile_text_encoder: False`
+    # default documented in every base config; a `True` fallback here would silently
+    # compile the encoder for any config object missing the key.
+    if getattr(config, "compile_text_encoder", False):
+      text_encoder = torch.compile(text_encoder)
     return text_encoder
 
   @classmethod