Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file removed docs/images/.DS_Store
Binary file not shown.
124 changes: 93 additions & 31 deletions funasr/models/fun_asr_nano/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,9 @@ def __init__(
else -1
)
audio_encoder = (
model.model.model.encoder if hasattr(model.model, "model") else model.model.encoder
model.model.model.encoder
if hasattr(model.model, "model")
else model.model.encoder
)
else:
encoder_class = tables.encoder_classes.get(audio_encoder)
Expand Down Expand Up @@ -135,7 +137,9 @@ def __init__(
if init_param_path is not None:
src_state = torch.load(init_param_path, map_location="cpu")
flag = self.ctc_decoder.load_state_dict(src_state, strict=False)
logging.info(f"Loading ctc_decoder ckpt: {init_param_path}, status: {flag}")
logging.info(
f"Loading ctc_decoder ckpt: {init_param_path}, status: {flag}"
)
freeze = ctc_decoder_conf.get("freeze", False)
if freeze:
for _, param in self.ctc_decoder.named_parameters():
Expand Down Expand Up @@ -189,7 +193,9 @@ def forward(
encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)

# audio_adaptor
encoder_out, encoder_out_lens = self.audio_adaptor(encoder_out, encoder_out_lens)
encoder_out, encoder_out_lens = self.audio_adaptor(
encoder_out, encoder_out_lens
)

batch_size, token_num, dims = inputs_embeds.shape
fake_token_len = kwargs.get("fake_token_len")
Expand Down Expand Up @@ -228,7 +234,9 @@ def forward(
stats["batch_size_speech"] = batch_size_speech
stats["batch_size_x_frames"] = frames * batch_size_speech
stats["batch_size_real_frames"] = speech_lengths.sum().item()
stats["padding_frames"] = stats["batch_size_x_frames"] - stats["batch_size_real_frames"]
stats["padding_frames"] = (
stats["batch_size_x_frames"] - stats["batch_size_real_frames"]
)

device_type = next(self.parameters()).device.type
with torch.autocast(
Expand All @@ -247,15 +255,19 @@ def forward(

with torch.no_grad():
preds = torch.argmax(model_outputs.logits, -1)
acc_att = compute_accuracy(preds[:, :-1], labels_ids[:, 1:], ignore_label=-100)
acc_att = compute_accuracy(
preds[:, :-1], labels_ids[:, 1:], ignore_label=-100
)
stats["acc"] = acc_att

stats["loss"] = torch.clone(loss.detach())
stats["batch_size"] = batch_size

stats["batch_size_x_tokens"] = token_num * batch_size
stats["batch_size_real_tokens"] = attention_mask.sum().item()
stats["padding_tokens"] = stats["batch_size_x_tokens"] - stats["batch_size_real_tokens"]
stats["padding_tokens"] = (
stats["batch_size_x_tokens"] - stats["batch_size_real_tokens"]
)

dialog_turns = (fbank_beg > 0).sum(-1)
dialog_turns_max = torch.max(dialog_turns).int().item()
Expand Down Expand Up @@ -305,7 +317,9 @@ def data_template(self, data):

return contents

def data_load_speech(self, contents: dict, tokenizer, frontend, meta_data={}, **kwargs):
def data_load_speech(
self, contents: dict, tokenizer, frontend, meta_data={}, **kwargs
):
system = contents["system"]
user = contents["user"]
assistant = contents["assistant"]
Expand All @@ -326,7 +340,9 @@ def data_load_speech(self, contents: dict, tokenizer, frontend, meta_data={}, **
[],
)
input_source_ids = []
for i, (system_prompt, user_prompt, target_out) in enumerate(zip(system, user, assistant)):
for i, (system_prompt, user_prompt, target_out) in enumerate(
zip(system, user, assistant)
):
if i >= kwargs.get("multiturn_num_max", 5):
break
if len(input_ids) > kwargs.get("max_token_length", 1500):
Expand All @@ -341,16 +357,12 @@ def data_load_speech(self, contents: dict, tokenizer, frontend, meta_data={}, **
else:
source_input = f"<|im_start|>system\n{system_prompt}<|im_end|>\n<|im_start|>user\n{user_prompt}<|im_end|>\n<|im_start|>assistant\n"
if not sys_prompt:
source_input = (
f"<|im_start|>user\n{user_prompt}<|im_end|>\n<|im_start|>assistant\n"
)
source_input = f"<|im_start|>user\n{user_prompt}<|im_end|>\n<|im_start|>assistant\n"
else:
if kwargs.get("infer_with_assistant_input", False):
source_input = f"<|im_start|>user\n{user_prompt}"
else:
source_input = (
f"<|im_start|>user\n{user_prompt}<|im_end|>\n<|im_start|>assistant\n"
)
source_input = f"<|im_start|>user\n{user_prompt}<|im_end|>\n<|im_start|>assistant\n"
if not do_think:
source_input += "<think>\n\n</think>\n\n"
if kwargs.get("prev_text", None) is not None:
Expand Down Expand Up @@ -383,7 +395,9 @@ def data_load_speech(self, contents: dict, tokenizer, frontend, meta_data={}, **
time2 = time.perf_counter()
meta_data["load_data"] = f"{time2 - time1:0.3f}"
except Exception as e:
logging.error(f"Loading wav failed! {str(e)}, {traceback.format_exc()}")
logging.error(
f"Loading wav failed! {str(e)}, {traceback.format_exc()}"
)

speech, speech_lengths = extract_fbank(
data_src,
Expand Down Expand Up @@ -425,7 +439,9 @@ def data_load_speech(self, contents: dict, tokenizer, frontend, meta_data={}, **
fbank.append(speech[0, :, :])
fbank_lens.append(speech_lengths)

input_ids = torch.tensor(input_ids, dtype=torch.int64) # [: self.max_token_length]
input_ids = torch.tensor(
input_ids, dtype=torch.int64
) # [: self.max_token_length]
attention_mask = torch.tensor([1] * len(input_ids), dtype=torch.int32)
labels = torch.tensor(labels, dtype=torch.int64) # [: self.max_token_length]

Expand All @@ -436,7 +452,9 @@ def data_load_speech(self, contents: dict, tokenizer, frontend, meta_data={}, **
target_ids = torch.tensor(target_ids, dtype=torch.int64)

if len(fbank) > 0:
speech = torch.nn.utils.rnn.pad_sequence(fbank, batch_first=True, padding_value=0.0)
speech = torch.nn.utils.rnn.pad_sequence(
fbank, batch_first=True, padding_value=0.0
)
speech_lengths = torch.nn.utils.rnn.pad_sequence(
fbank_lens, batch_first=True, padding_value=-1
)
Expand Down Expand Up @@ -469,11 +487,10 @@ def inference_prepare(
):
meta_data = {}

if kwargs.get("batch_size", 1) > 1:
raise NotImplementedError("batch decoding is not implemented")

contents = self.data_template(data_in[0])
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The removal of the NotImplementedError for batch_size > 1 is problematic because inference_prepare (and consequently inference_llm) still only processes the first element of data_in (data_in[0]). While the inference method now handles multiple inputs via a loop, calling inference_llm directly with a list of multiple items will result in only the first item being processed without any warning or error. It is recommended to keep a check on the length of data_in within inference_llm or inference_prepare to prevent silent data loss.

output = self.data_load_speech(contents, tokenizer, frontend, meta_data=meta_data, **kwargs)
output = self.data_load_speech(
contents, tokenizer, frontend, meta_data=meta_data, **kwargs
)
batch = to_device(output, kwargs["device"])

# audio encoder
Expand All @@ -494,7 +511,9 @@ def inference_prepare(
encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)

# audio_adaptor
adaptor_out, adaptor_out_lens = self.audio_adaptor(encoder_out, encoder_out_lens)
adaptor_out, adaptor_out_lens = self.audio_adaptor(
encoder_out, encoder_out_lens
)
meta_data["encoder_out"] = encoder_out
meta_data["encoder_out_lens"] = encoder_out_lens
meta_data["audio_adaptor_out"] = adaptor_out
Expand Down Expand Up @@ -566,7 +585,10 @@ def generate_chatml(self, prompt: str, data: Union[str, torch.Tensor]):
if isinstance(data, str):
return [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": f"{prompt}<|startofspeech|>!{data}<|endofspeech|>"},
{
"role": "user",
"content": f"{prompt}<|startofspeech|>!{data}<|endofspeech|>",
},
{"role": "assistant", "content": "null"},
]
elif isinstance(data, torch.Tensor):
Expand All @@ -590,15 +612,45 @@ def inference(
**kwargs,
):
prompt = self.get_prompt(
kwargs.get("hotwords", []), kwargs.get("language", None), kwargs.get("itn", True)
kwargs.get("hotwords", []),
kwargs.get("language", None),
kwargs.get("itn", True),
)
data_in = [self.generate_chatml(prompt, data) for data in data_in]

if key is None:
key = []
for _ in data_in:
chars = string.ascii_letters + string.digits
key.append("rand_key_" + "".join(random.choice(chars) for _ in range(13)))
key.append(
"rand_key_" + "".join(random.choice(chars) for _ in range(13))
)

# 批量推理:LLM 自回归解码不支持跨样本 padding 批处理,
# 对每条音频独立推理后聚合结果,实现对 batch_size_s > 0 的支持。
if len(data_in) > 1:
all_results = []
last_meta = {}
for i, single_data in enumerate(data_in):
single_key = [key[i]] if i < len(key) else None
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

If a key list is provided by the user but is shorter than data_in, single_key will be set to None. This will cause a TypeError in inference_llm at line 687 (or line 764) when it attempts to subscript key. It is safer to validate the length of key at the beginning of the inference method or provide a fallback key in the loop.

References
  1. Defensive programming: ensure that invalid inputs or states are safely handled. (link)
  2. Ensure appropriate null/nil/None checks or other language-idiomatic guards exist before object property accesses.

try:
res, meta = self.inference_llm(
[single_data],
data_lengths=None,
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The data_lengths argument is hardcoded to None in the batch loop, which means any lengths provided by the user to the inference method are ignored for batch items. This may affect model components that rely on explicit length information.

References
  1. Verify code functionality and ensure alignment between function descriptions and implementations. (link)

key=single_key,
tokenizer=tokenizer,
frontend=frontend,
**kwargs,
)
all_results.extend(res)
last_meta = meta
except Exception as e:
logging.error(
f"batch item {i} inference failed: {str(e)}, {traceback.format_exc()}"
)
if single_key:
all_results.append({"key": single_key[0], "text": ""})
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The dummy result dictionary used in the exception handler is missing several keys (text_tn, label, ctc_text, timestamps, etc.) that are present in a standard result. Downstream code expecting these keys will encounter a KeyError. Ensure the error result structure is consistent with the successful one.

References
  1. Ensure that invalid inputs or states are safely handled in all cases. (link)

return all_results, last_meta
Comment on lines +633 to +653
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The last_meta variable only captures the metadata from the final iteration of the loop. This leads to the loss of metadata (such as encoder outputs, internal states, or specific timing information) for all other items in the batch. Consider returning a list of metadata dictionaries or merging them to maintain consistency with the all_results list, especially if this model is used in pipelines where per-sample metadata is required.


return self.inference_llm(
data_in,
Expand Down Expand Up @@ -626,7 +678,9 @@ def inference_llm(
if self.ctc_decoder is not None:
encoder_out = meta_data["encoder_out"]
encoder_out_lens = meta_data["encoder_out_lens"]
decoder_out, decoder_out_lens = self.ctc_decoder(encoder_out, encoder_out_lens)
decoder_out, decoder_out_lens = self.ctc_decoder(
encoder_out, encoder_out_lens
)
ctc_logits = self.ctc.log_softmax(decoder_out)

b, n, d = encoder_out.size()
Expand Down Expand Up @@ -665,7 +719,8 @@ def inference_llm(
inputs_embeds=inputs_embeds,
attention_mask=attention_mask,
max_new_tokens=kwargs.get("max_length", 512),
pad_token_id=self.llm.config.pad_token_id or self.llm.config.eos_token_id,
pad_token_id=self.llm.config.pad_token_id
or self.llm.config.eos_token_id,
**llm_kwargs,
)

Expand All @@ -683,7 +738,8 @@ def inference_llm(
inputs_embeds=inputs_embeds,
attention_mask=attention_mask,
labels=labels_ids,
pad_token_id=self.llm.config.pad_token_id or self.llm.config.eos_token_id,
pad_token_id=self.llm.config.pad_token_id
or self.llm.config.eos_token_id,
**llm_kwargs,
)

Expand Down Expand Up @@ -722,8 +778,12 @@ def inference_llm(
result["ctc_timestamps"] = forced_align(
ctc_result["ctc_logits"], target_ids, self.blank_id
)
target_ids = torch.tensor(self.ctc_tokenizer.encode(result["text"]), dtype=torch.int64)
result["timestamps"] = forced_align(ctc_result["ctc_logits"], target_ids, self.blank_id)
target_ids = torch.tensor(
self.ctc_tokenizer.encode(result["text"]), dtype=torch.int64
)
result["timestamps"] = forced_align(
ctc_result["ctc_logits"], target_ids, self.blank_id
)
for timestamps in [result["timestamps"], result["ctc_timestamps"]]:
for timestamp in timestamps:
timestamp["token"] = self.ctc_tokenizer.decode([timestamp["token"]])
Expand All @@ -741,6 +801,8 @@ def inference_llm(
def from_pretrained(model: str = None, **kwargs):
from funasr import AutoModel

model, kwargs = AutoModel.build_model(model=model, trust_remote_code=True, **kwargs)
model, kwargs = AutoModel.build_model(
model=model, trust_remote_code=True, **kwargs
)

return model, kwargs