From 665fa444d1e22c8aa5754aa97c46c6fa19481aa7 Mon Sep 17 00:00:00 2001 From: leejet Date: Thu, 21 May 2026 22:23:05 +0800 Subject: [PATCH] perf: run LTX audio VAE decode in one ggml graph --- src/ltx_audio_vae.h | 454 ++++++++++++++++++--------------------- src/stable-diffusion.cpp | 15 +- 2 files changed, 216 insertions(+), 253 deletions(-) diff --git a/src/ltx_audio_vae.h b/src/ltx_audio_vae.h index aad8e0f87..a160338f4 100644 --- a/src/ltx_audio_vae.h +++ b/src/ltx_audio_vae.h @@ -2,6 +2,7 @@ #define __SD_LTX_AUDIO_VAE_H__ #include +#include #include #include #include @@ -171,90 +172,59 @@ namespace LTXV { } }; - static sd::Tensor squeeze_trailing_singleton_dims(sd::Tensor tensor) { - while (tensor.dim() > 0 && tensor.shape().back() == 1) { - tensor = tensor.squeeze(static_cast(tensor.dim() - 1)); - } - return tensor; - } - - static sd::Tensor normalize_waveform_for_host(sd::Tensor waveform) { - waveform = squeeze_trailing_singleton_dims(std::move(waveform)); - if (waveform.empty()) { - return waveform; - } - if (waveform.dim() == 1) { - return waveform.reshape({waveform.shape()[0], 1, 1}); - } - if (waveform.dim() == 2) { - return waveform.reshape({waveform.shape()[0], waveform.shape()[1], 1}); - } - if (waveform.dim() == 3) { - return waveform; + static ggml_tensor* compute_log_mel_spectrogram(GGMLRunnerContext* runner_ctx, + ggml_tensor* waveform, + ggml_tensor* forward_basis, + ggml_tensor* mel_basis, + int hop_length) { + auto ctx = runner_ctx->ggml_ctx; + GGML_ASSERT(ctx != nullptr); + GGML_ASSERT(waveform != nullptr); + GGML_ASSERT(forward_basis != nullptr); + GGML_ASSERT(mel_basis != nullptr); + GGML_ASSERT(waveform->type == GGML_TYPE_F32); + GGML_ASSERT(forward_basis->type == GGML_TYPE_F32); + GGML_ASSERT(mel_basis->type == GGML_TYPE_F32); + GGML_ASSERT(forward_basis->ne[1] == 1); + + const int64_t time = waveform->ne[0]; + const int64_t channels = waveform->ne[1]; + const int64_t batch = waveform->ne[2]; + const int64_t filter_len = forward_basis->ne[0]; + const int64_t stft_channels = forward_basis->ne[2]; + const int64_t n_freqs = stft_channels / 2; + const int64_t n_mels = mel_basis->ne[1]; + const int64_t left_pad = std::max(0, filter_len - hop_length); + const int64_t padded_time = time + left_pad; + const int64_t frame_count = padded_time < filter_len ? 0 : 1 + (padded_time - filter_len) / hop_length; + + GGML_ASSERT(stft_channels % 2 == 0); + GGML_ASSERT(mel_basis->ne[0] == n_freqs); + GGML_ASSERT(waveform->ne[3] == 1); + GGML_ASSERT(frame_count > 0); + + auto x = ggml_reshape_3d(ctx, waveform, time, 1, channels * batch); + if (left_pad > 0) { + x = ggml_pad_ext(ctx, x, static_cast(left_pad), 0, 0, 0, 0, 0, 0, 0); } - throw std::runtime_error("Unsupported waveform rank for host processing: rank=" + std::to_string(waveform.dim())); - } - static sd::Tensor load_param_tensor_f32(ggml_tensor* tensor) { - GGML_ASSERT(tensor != nullptr); - return squeeze_trailing_singleton_dims(sd::make_sd_tensor_from_ggml(tensor)); - } + auto frames = ggml_conv_1d(ctx, forward_basis, x, hop_length, 0, 1); + GGML_ASSERT(frames->ne[0] == frame_count); + GGML_ASSERT(frames->ne[1] == stft_channels); + GGML_ASSERT(frames->ne[2] == channels * batch); - static sd::Tensor compute_log_mel_spectrogram(const sd::Tensor& waveform_in, - const sd::Tensor& forward_basis, - const sd::Tensor& mel_basis, - int hop_length) { - auto waveform = normalize_waveform_for_host(waveform_in); - GGML_ASSERT(forward_basis.dim() >= 3); - GGML_ASSERT(mel_basis.dim() >= 2); - - const int64_t time = waveform.shape()[0]; - const int64_t channels = waveform.shape()[1]; - const int64_t batch = waveform.shape()[2]; - const int64_t filter_len = forward_basis.shape()[0]; - const int64_t basis_freq2 = forward_basis.shape().back(); - const int64_t n_freqs = basis_freq2 / 2; - const int64_t n_mels = mel_basis.shape()[1]; - const int64_t left_pad = std::max(0, filter_len - hop_length); - const int64_t padded_time = time + left_pad; - const int64_t frame_count = padded_time < filter_len ? 0 : 1 + (padded_time - filter_len) / hop_length; - - sd::Tensor log_mel({n_mels, frame_count, channels, batch}); - std::vector padded(static_cast(padded_time), 0.0f); - std::vector magnitude(static_cast(n_freqs), 0.0f); - - for (int64_t b = 0; b < batch; ++b) { - for (int64_t c = 0; c < channels; ++c) { - std::fill(padded.begin(), padded.end(), 0.0f); - for (int64_t t = 0; t < time; ++t) { - padded[static_cast(t + left_pad)] = waveform.index(t, c, b); - } + auto real = ggml_ext_slice(ctx, frames, 1, 0, n_freqs); + auto imag = ggml_ext_slice(ctx, frames, 1, n_freqs, stft_channels); + auto magnitude = ggml_sqrt(ctx, + ggml_add(ctx, + ggml_sqr(ctx, real), + ggml_sqr(ctx, imag))); - for (int64_t frame = 0; frame < frame_count; ++frame) { - const int64_t frame_offset = frame * hop_length; - for (int64_t f = 0; f < n_freqs; ++f) { - double real = 0.0; - double imag = 0.0; - for (int64_t k = 0; k < filter_len; ++k) { - const float sample = padded[static_cast(frame_offset + k)]; - real += static_cast(sample) * static_cast(forward_basis.index(k, 0, f)); - imag += static_cast(sample) * static_cast(forward_basis.index(k, 0, f + n_freqs)); - } - magnitude[static_cast(f)] = static_cast(std::sqrt(real * real + imag * imag)); - } - - for (int64_t m = 0; m < n_mels; ++m) { - double mel_value = 0.0; - for (int64_t f = 0; f < n_freqs; ++f) { - mel_value += static_cast(mel_basis.index(f, m)) * static_cast(magnitude[static_cast(f)]); - } - log_mel.index(m, frame, c, b) = static_cast(std::log(std::max(mel_value, 1e-5))); - } - } - } - } + magnitude = ggml_cont(ctx, ggml_permute(ctx, magnitude, 1, 0, 2, 3)); + auto mel = ggml_mul_mat(ctx, mel_basis, magnitude); + mel = ggml_log(ctx, ggml_clamp(ctx, mel, 1e-5f, std::numeric_limits::max())); - return log_mel; + return ggml_reshape_4d(ctx, mel, n_mels, frame_count, channels, batch); } static std::vector build_hann_resample_filter(int ratio) { @@ -276,75 +246,6 @@ namespace LTXV { return filter; } - static sd::Tensor upsample_waveform_hann(const sd::Tensor& waveform_in, int ratio) { - auto waveform = normalize_waveform_for_host(waveform_in); - if (ratio <= 1) { - return waveform; - } - - const int lowpass_filter_width = 6; - const double rolloff = 0.99; - const int width = static_cast(std::ceil(static_cast(lowpass_filter_width) / rolloff)); - const int kernel_size = 2 * width * ratio + 1; - const int pad = width; - const int pad_left = 2 * width * ratio; - const int pad_right = kernel_size - ratio; - const int64_t time = waveform.shape()[0]; - const int64_t channels = waveform.shape()[1]; - const int64_t batch = waveform.shape()[2]; - const int64_t padded_time = time + 2 * pad; - const int64_t conv_out_time = (padded_time - 1) * ratio + kernel_size; - const int64_t cropped_time = conv_out_time - pad_left - pad_right; - auto filter = build_hann_resample_filter(ratio); - - sd::Tensor output({cropped_time, channels, batch}); - std::vector padded(static_cast(padded_time), 0.0f); - std::vector conv_out(static_cast(conv_out_time), 0.0f); - - for (int64_t b = 0; b < batch; ++b) { - for (int64_t c = 0; c < channels; ++c) { - std::fill(padded.begin(), padded.end(), 0.0f); - const float first = waveform.index(0, c, b); - const float last = waveform.index(time - 1, c, b); - for (int i = 0; i < pad; ++i) { - padded[static_cast(i)] = first; - padded[static_cast(pad + time + i)] = last; - } - for (int64_t t = 0; t < time; ++t) { - padded[static_cast(pad + t)] = waveform.index(t, c, b); - } - - std::fill(conv_out.begin(), conv_out.end(), 0.0f); - for (int64_t t = 0; t < padded_time; ++t) { - const double sample = static_cast(padded[static_cast(t)]) * ratio; - const int64_t out_base = t * ratio; - for (int k = 0; k < kernel_size; ++k) { - conv_out[static_cast(out_base + k)] += static_cast(sample * filter[static_cast(k)]); - } - } - - for (int64_t t = 0; t < cropped_time; ++t) { - output.index(t, c, b) = conv_out[static_cast(t + pad_left)]; - } - } - } - - return output; - } - - static sd::Tensor crop_waveform_samples(const sd::Tensor& waveform_in, int64_t target_samples) { - auto waveform = normalize_waveform_for_host(waveform_in); - if (waveform.shape()[0] == target_samples) { - return waveform; - } - if (waveform.shape()[0] > target_samples) { - return sd::ops::slice(waveform, 0, 0, target_samples); - } - sd::Tensor output({target_samples, waveform.shape()[1], waveform.shape()[2]}); - sd::ops::slice_assign(&output, 0, 0, waveform.shape()[0], waveform); - return output; - } - static ggml_type audio_conv_weight_type(ggml_type type) { return type == GGML_TYPE_BF16 ? GGML_TYPE_F16 : type; } @@ -413,22 +314,101 @@ namespace LTXV { return ggml_reshape_4d(ctx, out, out->ne[0], out->ne[1], 1, 1); } + static ggml_tensor* reverse_1d_filter(ggml_context* ctx, ggml_tensor* filter) { + GGML_ASSERT(ctx != nullptr); + GGML_ASSERT(filter != nullptr); + GGML_ASSERT(filter->ne[1] == 1); + GGML_ASSERT(filter->ne[2] == 1); + GGML_ASSERT(filter->ne[3] == 1); + + ggml_tensor* reversed = nullptr; + for (int64_t k = filter->ne[0] - 1; k >= 0; --k) { + auto slice = ggml_ext_slice(ctx, filter, 0, k, k + 1); + reversed = reversed == nullptr ? slice : ggml_concat(ctx, reversed, slice, 0); + } + return reversed; + } + static ggml_tensor* depthwise_conv_transpose1d(ggml_context* ctx, ggml_tensor* x, ggml_tensor* filter, int stride) { GGML_ASSERT(x->ne[2] == 1 && x->ne[3] == 1); GGML_ASSERT(filter->ne[1] == 1); + GGML_ASSERT(filter->ne[2] == 1 && filter->ne[3] == 1); + + const int64_t time = x->ne[0]; + const int64_t channels = x->ne[1]; + const int64_t kernel_size = filter->ne[0]; + const int64_t out_time = (time - 1) * stride + kernel_size; + + auto x_flat = ggml_reshape_3d(ctx, x, 1, time, channels); + if (stride > 1) { + auto zero_unit = ggml_ext_scale(ctx, x_flat, 0.0f); + auto zero_tail = zero_unit; + for (int i = 1; i < stride - 1; ++i) { + zero_tail = ggml_concat(ctx, zero_tail, zero_unit, 0); + } + x_flat = ggml_concat(ctx, x_flat, zero_tail, 0); + } + x_flat = ggml_reshape_3d(ctx, x_flat, time * stride, 1, channels); + + auto reversed_filter = reverse_1d_filter(ctx, filter); + auto out = ggml_conv_1d(ctx, reversed_filter, x_flat, 1, static_cast(kernel_size - 1), 1); + if (out->ne[0] > out_time) { + out = ggml_ext_slice(ctx, out, 0, 0, out_time); + } + GGML_ASSERT(out->ne[0] == out_time); + GGML_ASSERT(out->ne[1] == 1); + GGML_ASSERT(out->ne[2] == channels); + + out = ggml_ext_scale(ctx, out, static_cast(stride)); + return ggml_reshape_4d(ctx, out, out_time, channels, 1, 1); + } + + static ggml_tensor* upsample_waveform_hann(GGMLRunnerContext* runner_ctx, + ggml_tensor* waveform, + ggml_tensor* filter, + int ratio) { + auto ctx = runner_ctx->ggml_ctx; + GGML_ASSERT(ctx != nullptr); + GGML_ASSERT(waveform != nullptr); + GGML_ASSERT(filter != nullptr); + GGML_ASSERT(waveform->ne[3] == 1); + if (ratio <= 1) { + return waveform; + } + + const int lowpass_filter_width = 6; + const double rolloff = 0.99; + const int width = static_cast(std::ceil(static_cast(lowpass_filter_width) / rolloff)); + const int kernel_size = 2 * width * ratio + 1; + const int pad = width; + const int pad_left = 2 * width * ratio; + const int pad_right = kernel_size - ratio; + const int64_t time = waveform->ne[0]; + const int64_t channels = waveform->ne[1]; + const int64_t batch = waveform->ne[2]; + + GGML_ASSERT(filter->ne[0] == kernel_size); + + auto x = ggml_reshape_3d(ctx, waveform, time, channels * batch, 1); + x = replicate_pad_1d(runner_ctx, x, pad, pad); + x = depthwise_conv_transpose1d(ctx, x, filter, ratio); + x = ggml_ext_slice(ctx, x, 0, pad_left, x->ne[0] - pad_right); + return ggml_reshape_3d(ctx, x, x->ne[0], channels, batch); + } - ggml_tensor* out = nullptr; - for (int64_t c = 0; c < x->ne[1]; ++c) { - auto xi = ggml_ext_slice(ctx, x, 1, c, c + 1); - auto yi = ggml_conv_transpose_1d(ctx, filter, xi, stride, 0, 1); - yi = ggml_ext_scale(ctx, yi, static_cast(stride)); - yi = ggml_reshape_4d(ctx, yi, yi->ne[0], 1, 1, 1); - out = out == nullptr ? yi : ggml_concat(ctx, out, yi, 1); + static ggml_tensor* crop_waveform_samples(ggml_context* ctx, + ggml_tensor* waveform, + int64_t target_samples) { + GGML_ASSERT(ctx != nullptr); + GGML_ASSERT(waveform != nullptr); + if (waveform->ne[0] == target_samples) { + return waveform; } - return out; + GGML_ASSERT(waveform->ne[0] > target_samples); + return ggml_ext_slice(ctx, waveform, 0, 0, target_samples); } struct PixelNorm2D : public UnaryBlock { @@ -950,41 +930,66 @@ namespace LTXV { } } - ggml_tensor* decode_to_mel(GGMLRunnerContext* ctx, - ggml_tensor* latent, - int target_time, - int target_freq) { - auto mean = params["audio_vae.per_channel_statistics.mean-of-means"]; - auto stddev = params["audio_vae.per_channel_statistics.std-of-means"]; - auto decoder = std::dynamic_pointer_cast(blocks["audio_vae.decoder"]); - return decoder->forward(ctx, latent, mean, stddev, target_time, target_freq); - } + ggml_tensor* decode(GGMLRunnerContext* ctx, + ggml_tensor* latent, + ggml_tensor* bwe_skip_filter) { + int target_time = static_cast(latent->ne[1]) * config.latent_downsample_factor() - + (config.latent_downsample_factor() - 1); + int target_freq = config.mel_bins; - ggml_tensor* run_vocoder(GGMLRunnerContext* ctx, ggml_tensor* mel) { - auto vocoder = std::dynamic_pointer_cast(blocks["vocoder.vocoder"]); - return vocoder->forward(ctx, mel); - } + auto decoder = std::dynamic_pointer_cast(blocks["audio_vae.decoder"]); + auto mean = params["audio_vae.per_channel_statistics.mean-of-means"]; + auto stddev = params["audio_vae.per_channel_statistics.std-of-means"]; + auto mel = decoder->forward(ctx, latent, mean, stddev, target_time, target_freq); + auto vocoder = std::dynamic_pointer_cast(blocks["vocoder.vocoder"]); + auto waveform = vocoder->forward(ctx, mel); - ggml_tensor* run_bwe_generator(GGMLRunnerContext* ctx, ggml_tensor* mel) { - GGML_ASSERT(config.has_bwe); - auto bwe_generator = std::dynamic_pointer_cast(blocks["vocoder.bwe_generator"]); - return bwe_generator->forward(ctx, mel); - } + if (config.has_bwe) { + GGML_ASSERT(bwe_skip_filter != nullptr); + const int bwe_ratio = config.bwe_output_sample_rate / config.bwe_input_sample_rate; + const int64_t low_time = waveform->ne[0]; + const int64_t out_time = low_time * bwe_ratio; + int64_t remainder = low_time % config.bwe_hop_length; + auto bwe_waveform = waveform; + if (remainder != 0) { + bwe_waveform = ggml_pad_ext(ctx->ggml_ctx, + bwe_waveform, + 0, + static_cast(config.bwe_hop_length - remainder), + 0, + 0, + 0, + 0, + 0, + 0); + } - ggml_tensor* mel_basis_tensor() const { - auto iter = params.find("vocoder.mel_stft.mel_basis"); - return iter == params.end() ? nullptr : iter->second; - } + auto mel_basis = params["vocoder.mel_stft.mel_basis"]; + auto stft_basis = params["vocoder.mel_stft.stft_fn.forward_basis"]; + GGML_ASSERT(mel_basis != nullptr && stft_basis != nullptr); + auto bwe_mel = compute_log_mel_spectrogram(ctx, bwe_waveform, stft_basis, mel_basis, config.bwe_hop_length); + auto bwe_generator = std::dynamic_pointer_cast(blocks["vocoder.bwe_generator"]); + auto residual = bwe_generator->forward(ctx, bwe_mel); + + auto skip = upsample_waveform_hann(ctx, + bwe_waveform, + bwe_skip_filter, + bwe_ratio); + waveform = ggml_clamp(ctx->ggml_ctx, + ggml_add(ctx->ggml_ctx, residual, skip), + -1.0f, + 1.0f); + waveform = crop_waveform_samples(ctx->ggml_ctx, waveform, out_time); + } - ggml_tensor* stft_forward_basis_tensor() const { - auto iter = params.find("vocoder.mel_stft.stft_fn.forward_basis"); - return iter == params.end() ? nullptr : iter->second; + return waveform; } }; struct LTXAudioVAERunner : public GGMLRunner { LTXAudioVAEConfig config; LTXAudioVAE model; + sd::Tensor bwe_skip_filter_tensor; LTXAudioVAERunner(ggml_backend_t backend, ggml_backend_t params_backend, @@ -994,6 +999,10 @@ namespace LTXV { config(LTXAudioVAEConfig::detect_from_weights(tensor_storage_map)), model(config) { model.init(params_ctx, tensor_storage_map, prefix); + if (config.has_bwe) { + const int bwe_ratio = config.bwe_output_sample_rate / config.bwe_input_sample_rate; + bwe_skip_filter_tensor = sd::Tensor::from_vector(build_hann_resample_filter(bwe_ratio)); + } } void get_param_tensors(std::map& tensors, const std::string prefix) { @@ -1008,77 +1017,22 @@ namespace LTXV { return "ltx_audio_vae"; } - ggml_cgraph* build_base_graph(const sd::Tensor& latent_tensor) { - auto latent = make_input(latent_tensor); - int target_time = static_cast(latent_tensor.shape()[1]) * config.latent_downsample_factor() - - (config.latent_downsample_factor() - 1); - int target_freq = config.mel_bins; - - ggml_cgraph* gf = new_graph_custom(655360); - auto runner_ctx = GGMLRunner::get_context(); - auto mel = model.decode_to_mel(&runner_ctx, latent, target_time, target_freq); - auto waveform = model.run_vocoder(&runner_ctx, mel); - ggml_build_forward_expand(gf, waveform); - return gf; - } - - ggml_cgraph* build_bwe_graph(const sd::Tensor& mel_tensor) { - auto mel = make_input(mel_tensor); - ggml_cgraph* gf = new_graph_custom(655360); - auto runner_ctx = GGMLRunner::get_context(); - auto residual = model.run_bwe_generator(&runner_ctx, mel); - ggml_build_forward_expand(gf, residual); - return gf; - } - - sd::Tensor compute_base_waveform(int n_threads, - const sd::Tensor& latent_tensor) { - auto get_graph = [&]() -> ggml_cgraph* { - return build_base_graph(latent_tensor); - }; - return restore_trailing_singleton_dims(GGMLRunner::compute(get_graph, n_threads, false), 4); - } - - sd::Tensor compute_bwe_residual(int n_threads, - const sd::Tensor& mel_tensor) { - auto get_graph = [&]() -> ggml_cgraph* { - return build_bwe_graph(mel_tensor); - }; - return restore_trailing_singleton_dims(GGMLRunner::compute(get_graph, n_threads, false), 4); - } - sd::Tensor decode(int n_threads, const sd::Tensor& latent_tensor) { - auto waveform = compute_base_waveform(n_threads, latent_tensor); - if (!config.has_bwe || waveform.empty()) { - return waveform; - } - - auto waveform_host = normalize_waveform_for_host(waveform); - const int64_t low_time = waveform_host.shape()[0]; - const int64_t out_time = low_time * config.bwe_output_sample_rate / config.bwe_input_sample_rate; - int64_t remainder = low_time % config.bwe_hop_length; - if (remainder != 0) { - sd::Tensor padded({low_time + (config.bwe_hop_length - remainder), waveform_host.shape()[1], waveform_host.shape()[2]}); - sd::ops::slice_assign(&padded, 0, 0, low_time, waveform_host); - waveform_host = std::move(padded); - } - - auto mel_basis_tensor = model.mel_basis_tensor(); - auto stft_basis_tensor = model.stft_forward_basis_tensor(); - GGML_ASSERT(mel_basis_tensor != nullptr && stft_basis_tensor != nullptr); - auto mel_basis = load_param_tensor_f32(mel_basis_tensor); - auto forward_basis = load_param_tensor_f32(stft_basis_tensor); - auto bwe_mel = compute_log_mel_spectrogram(waveform_host, forward_basis, mel_basis, config.bwe_hop_length); - auto residual_raw = compute_bwe_residual(n_threads, bwe_mel); - if (residual_raw.empty()) { - return waveform; - } - auto residual = normalize_waveform_for_host(residual_raw); - auto skip = upsample_waveform_hann(waveform_host, config.bwe_output_sample_rate / config.bwe_input_sample_rate); - auto combined = sd::ops::clamp(residual + skip, -1.0f, 1.0f); - auto cropped = crop_waveform_samples(combined, out_time); - return restore_trailing_singleton_dims(cropped, 4); + int64_t t0 = ggml_time_ms(); + auto get_graph = [&]() -> ggml_cgraph* { + auto latent = make_input(latent_tensor); + ggml_tensor* bwe_skip_filter = config.has_bwe ? make_input(bwe_skip_filter_tensor) : nullptr; + ggml_cgraph* gf = new_graph_custom(655360); + auto runner_ctx = GGMLRunner::get_context(); + auto waveform = model.decode(&runner_ctx, latent, bwe_skip_filter); + ggml_build_forward_expand(gf, waveform); + return gf; + }; + auto result = restore_trailing_singleton_dims(GGMLRunner::compute(get_graph, n_threads, false), 4); + int64_t t1 = ggml_time_ms(); + LOG_INFO("ltx audio vae decode completed, taking %.2fs", (t1 - t0) * 1.0f / 1000); + return result; } void test(const std::string& input_path) { diff --git a/src/stable-diffusion.cpp b/src/stable-diffusion.cpp index 8da6b489a..8d6806228 100644 --- a/src/stable-diffusion.cpp +++ b/src/stable-diffusion.cpp @@ -5218,14 +5218,24 @@ SD_API bool generate_video(sd_ctx_t* sd_ctx, sd_ctx->sd->diffusion_model->free_params_buffer(); } + int64_t latent_end = ggml_time_ms(); + LOG_INFO("generating latent video completed, taking %.2fs", (latent_end - latent_start) * 1.0f / 1000); + sd_audio_t* generated_audio = nullptr; if (sd_version_is_ltxav(sd_ctx->sd->version) && latents.audio_length > 0 && sd_ctx->sd->audio_vae_model != nullptr) { + int64_t audio_latent_decode_start = ggml_time_ms(); + auto audio_latent = unpack_ltxav_audio_latent(final_latent, latents.audio_length, sd_ctx->sd->get_latent_channel()); if (!audio_latent.empty()) { + LOG_DEBUG("decode audio latent %dx%dx%dx%d", + (int)audio_latent.shape()[0], + (int)audio_latent.shape()[1], + (int)audio_latent.shape()[2], + (int)audio_latent.shape()[3]); auto waveform = sd_ctx->sd->decode_ltx_audio_latent(audio_latent); if (!waveform.empty()) { generated_audio = waveform_to_sd_audio(sd_ctx->sd, waveform); @@ -5233,6 +5243,8 @@ SD_API bool generate_video(sd_ctx_t* sd_ctx, LOG_WARN("LTX audio latent decode failed; continuing with silent video output"); } } + int64_t audio_latent_decode_end = ggml_time_ms(); + LOG_INFO("decoding audio latent completed, taking %.2fs", (audio_latent_decode_end - audio_latent_decode_start) * 1.0f / 1000); } if (latents.video_conditioning_frame_count > 0) { @@ -5245,9 +5257,6 @@ SD_API bool generate_video(sd_ctx_t* sd_ctx, final_latent = sd::ops::slice(final_latent, 2, latents.ref_image_num, final_latent.shape()[2]); } - int64_t latent_end = ggml_time_ms(); - LOG_INFO("generating latent video completed, taking %.2fs", (latent_end - latent_start) * 1.0f / 1000); - auto result = decode_video_outputs(sd_ctx, latent_upscale_enabled ? hires_request : request, final_latent, num_frames_out); if (result == nullptr) { free_sd_audio(generated_audio);