-
Notifications
You must be signed in to change notification settings - Fork 1.7k
chore: update fun_asr_nano batch inference model logic #2858
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -54,7 +54,9 @@ def __init__( | |
| else -1 | ||
| ) | ||
| audio_encoder = ( | ||
| model.model.model.encoder if hasattr(model.model, "model") else model.model.encoder | ||
| model.model.model.encoder | ||
| if hasattr(model.model, "model") | ||
| else model.model.encoder | ||
| ) | ||
| else: | ||
| encoder_class = tables.encoder_classes.get(audio_encoder) | ||
|
|
@@ -135,7 +137,9 @@ def __init__( | |
| if init_param_path is not None: | ||
| src_state = torch.load(init_param_path, map_location="cpu") | ||
| flag = self.ctc_decoder.load_state_dict(src_state, strict=False) | ||
| logging.info(f"Loading ctc_decoder ckpt: {init_param_path}, status: {flag}") | ||
| logging.info( | ||
| f"Loading ctc_decoder ckpt: {init_param_path}, status: {flag}" | ||
| ) | ||
| freeze = ctc_decoder_conf.get("freeze", False) | ||
| if freeze: | ||
| for _, param in self.ctc_decoder.named_parameters(): | ||
|
|
@@ -189,7 +193,9 @@ def forward( | |
| encoder_out, encoder_out_lens = self.encode(speech, speech_lengths) | ||
|
|
||
| # audio_adaptor | ||
| encoder_out, encoder_out_lens = self.audio_adaptor(encoder_out, encoder_out_lens) | ||
| encoder_out, encoder_out_lens = self.audio_adaptor( | ||
| encoder_out, encoder_out_lens | ||
| ) | ||
|
|
||
| batch_size, token_num, dims = inputs_embeds.shape | ||
| fake_token_len = kwargs.get("fake_token_len") | ||
|
|
@@ -228,7 +234,9 @@ def forward( | |
| stats["batch_size_speech"] = batch_size_speech | ||
| stats["batch_size_x_frames"] = frames * batch_size_speech | ||
| stats["batch_size_real_frames"] = speech_lengths.sum().item() | ||
| stats["padding_frames"] = stats["batch_size_x_frames"] - stats["batch_size_real_frames"] | ||
| stats["padding_frames"] = ( | ||
| stats["batch_size_x_frames"] - stats["batch_size_real_frames"] | ||
| ) | ||
|
|
||
| device_type = next(self.parameters()).device.type | ||
| with torch.autocast( | ||
|
|
@@ -247,15 +255,19 @@ def forward( | |
|
|
||
| with torch.no_grad(): | ||
| preds = torch.argmax(model_outputs.logits, -1) | ||
| acc_att = compute_accuracy(preds[:, :-1], labels_ids[:, 1:], ignore_label=-100) | ||
| acc_att = compute_accuracy( | ||
| preds[:, :-1], labels_ids[:, 1:], ignore_label=-100 | ||
| ) | ||
| stats["acc"] = acc_att | ||
|
|
||
| stats["loss"] = torch.clone(loss.detach()) | ||
| stats["batch_size"] = batch_size | ||
|
|
||
| stats["batch_size_x_tokens"] = token_num * batch_size | ||
| stats["batch_size_real_tokens"] = attention_mask.sum().item() | ||
| stats["padding_tokens"] = stats["batch_size_x_tokens"] - stats["batch_size_real_tokens"] | ||
| stats["padding_tokens"] = ( | ||
| stats["batch_size_x_tokens"] - stats["batch_size_real_tokens"] | ||
| ) | ||
|
|
||
| dialog_turns = (fbank_beg > 0).sum(-1) | ||
| dialog_turns_max = torch.max(dialog_turns).int().item() | ||
|
|
@@ -305,7 +317,9 @@ def data_template(self, data): | |
|
|
||
| return contents | ||
|
|
||
| def data_load_speech(self, contents: dict, tokenizer, frontend, meta_data={}, **kwargs): | ||
| def data_load_speech( | ||
| self, contents: dict, tokenizer, frontend, meta_data={}, **kwargs | ||
| ): | ||
| system = contents["system"] | ||
| user = contents["user"] | ||
| assistant = contents["assistant"] | ||
|
|
@@ -326,7 +340,9 @@ def data_load_speech(self, contents: dict, tokenizer, frontend, meta_data={}, ** | |
| [], | ||
| ) | ||
| input_source_ids = [] | ||
| for i, (system_prompt, user_prompt, target_out) in enumerate(zip(system, user, assistant)): | ||
| for i, (system_prompt, user_prompt, target_out) in enumerate( | ||
| zip(system, user, assistant) | ||
| ): | ||
| if i >= kwargs.get("multiturn_num_max", 5): | ||
| break | ||
| if len(input_ids) > kwargs.get("max_token_length", 1500): | ||
|
|
@@ -341,16 +357,12 @@ def data_load_speech(self, contents: dict, tokenizer, frontend, meta_data={}, ** | |
| else: | ||
| source_input = f"<|im_start|>system\n{system_prompt}<|im_end|>\n<|im_start|>user\n{user_prompt}<|im_end|>\n<|im_start|>assistant\n" | ||
| if not sys_prompt: | ||
| source_input = ( | ||
| f"<|im_start|>user\n{user_prompt}<|im_end|>\n<|im_start|>assistant\n" | ||
| ) | ||
| source_input = f"<|im_start|>user\n{user_prompt}<|im_end|>\n<|im_start|>assistant\n" | ||
| else: | ||
| if kwargs.get("infer_with_assistant_input", False): | ||
| source_input = f"<|im_start|>user\n{user_prompt}" | ||
| else: | ||
| source_input = ( | ||
| f"<|im_start|>user\n{user_prompt}<|im_end|>\n<|im_start|>assistant\n" | ||
| ) | ||
| source_input = f"<|im_start|>user\n{user_prompt}<|im_end|>\n<|im_start|>assistant\n" | ||
| if not do_think: | ||
| source_input += "<think>\n\n</think>\n\n" | ||
| if kwargs.get("prev_text", None) is not None: | ||
|
|
@@ -383,7 +395,9 @@ def data_load_speech(self, contents: dict, tokenizer, frontend, meta_data={}, ** | |
| time2 = time.perf_counter() | ||
| meta_data["load_data"] = f"{time2 - time1:0.3f}" | ||
| except Exception as e: | ||
| logging.error(f"Loading wav failed! {str(e)}, {traceback.format_exc()}") | ||
| logging.error( | ||
| f"Loading wav failed! {str(e)}, {traceback.format_exc()}" | ||
| ) | ||
|
|
||
| speech, speech_lengths = extract_fbank( | ||
| data_src, | ||
|
|
@@ -425,7 +439,9 @@ def data_load_speech(self, contents: dict, tokenizer, frontend, meta_data={}, ** | |
| fbank.append(speech[0, :, :]) | ||
| fbank_lens.append(speech_lengths) | ||
|
|
||
| input_ids = torch.tensor(input_ids, dtype=torch.int64) # [: self.max_token_length] | ||
| input_ids = torch.tensor( | ||
| input_ids, dtype=torch.int64 | ||
| ) # [: self.max_token_length] | ||
| attention_mask = torch.tensor([1] * len(input_ids), dtype=torch.int32) | ||
| labels = torch.tensor(labels, dtype=torch.int64) # [: self.max_token_length] | ||
|
|
||
|
|
@@ -436,7 +452,9 @@ def data_load_speech(self, contents: dict, tokenizer, frontend, meta_data={}, ** | |
| target_ids = torch.tensor(target_ids, dtype=torch.int64) | ||
|
|
||
| if len(fbank) > 0: | ||
| speech = torch.nn.utils.rnn.pad_sequence(fbank, batch_first=True, padding_value=0.0) | ||
| speech = torch.nn.utils.rnn.pad_sequence( | ||
| fbank, batch_first=True, padding_value=0.0 | ||
| ) | ||
| speech_lengths = torch.nn.utils.rnn.pad_sequence( | ||
| fbank_lens, batch_first=True, padding_value=-1 | ||
| ) | ||
|
|
@@ -469,11 +487,10 @@ def inference_prepare( | |
| ): | ||
| meta_data = {} | ||
|
|
||
| if kwargs.get("batch_size", 1) > 1: | ||
| raise NotImplementedError("batch decoding is not implemented") | ||
|
|
||
| contents = self.data_template(data_in[0]) | ||
| output = self.data_load_speech(contents, tokenizer, frontend, meta_data=meta_data, **kwargs) | ||
| output = self.data_load_speech( | ||
| contents, tokenizer, frontend, meta_data=meta_data, **kwargs | ||
| ) | ||
| batch = to_device(output, kwargs["device"]) | ||
|
|
||
| # audio encoder | ||
|
|
@@ -494,7 +511,9 @@ def inference_prepare( | |
| encoder_out, encoder_out_lens = self.encode(speech, speech_lengths) | ||
|
|
||
| # audio_adaptor | ||
| adaptor_out, adaptor_out_lens = self.audio_adaptor(encoder_out, encoder_out_lens) | ||
| adaptor_out, adaptor_out_lens = self.audio_adaptor( | ||
| encoder_out, encoder_out_lens | ||
| ) | ||
| meta_data["encoder_out"] = encoder_out | ||
| meta_data["encoder_out_lens"] = encoder_out_lens | ||
| meta_data["audio_adaptor_out"] = adaptor_out | ||
|
|
@@ -566,7 +585,10 @@ def generate_chatml(self, prompt: str, data: Union[str, torch.Tensor]): | |
| if isinstance(data, str): | ||
| return [ | ||
| {"role": "system", "content": "You are a helpful assistant."}, | ||
| {"role": "user", "content": f"{prompt}<|startofspeech|>!{data}<|endofspeech|>"}, | ||
| { | ||
| "role": "user", | ||
| "content": f"{prompt}<|startofspeech|>!{data}<|endofspeech|>", | ||
| }, | ||
| {"role": "assistant", "content": "null"}, | ||
| ] | ||
| elif isinstance(data, torch.Tensor): | ||
|
|
@@ -590,15 +612,45 @@ def inference( | |
| **kwargs, | ||
| ): | ||
| prompt = self.get_prompt( | ||
| kwargs.get("hotwords", []), kwargs.get("language", None), kwargs.get("itn", True) | ||
| kwargs.get("hotwords", []), | ||
| kwargs.get("language", None), | ||
| kwargs.get("itn", True), | ||
| ) | ||
| data_in = [self.generate_chatml(prompt, data) for data in data_in] | ||
|
|
||
| if key is None: | ||
| key = [] | ||
| for _ in data_in: | ||
| chars = string.ascii_letters + string.digits | ||
| key.append("rand_key_" + "".join(random.choice(chars) for _ in range(13))) | ||
| key.append( | ||
| "rand_key_" + "".join(random.choice(chars) for _ in range(13)) | ||
| ) | ||
|
|
||
| # 批量推理:LLM 自回归解码不支持跨样本 padding 批处理, | ||
| # 对每条音频独立推理后聚合结果,实现对 batch_size_s > 0 的支持。 | ||
| if len(data_in) > 1: | ||
| all_results = [] | ||
| last_meta = {} | ||
| for i, single_data in enumerate(data_in): | ||
| single_key = [key[i]] if i < len(key) else None | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If a References
|
||
| try: | ||
| res, meta = self.inference_llm( | ||
| [single_data], | ||
| data_lengths=None, | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The References
|
||
| key=single_key, | ||
| tokenizer=tokenizer, | ||
| frontend=frontend, | ||
| **kwargs, | ||
| ) | ||
| all_results.extend(res) | ||
| last_meta = meta | ||
| except Exception as e: | ||
| logging.error( | ||
| f"batch item {i} inference failed: {str(e)}, {traceback.format_exc()}" | ||
| ) | ||
| if single_key: | ||
| all_results.append({"key": single_key[0], "text": ""}) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The dummy result dictionary used in the exception handler is missing several keys ( References
|
||
| return all_results, last_meta | ||
|
Comment on lines
+633
to
+653
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The |
||
|
|
||
| return self.inference_llm( | ||
| data_in, | ||
|
|
@@ -626,7 +678,9 @@ def inference_llm( | |
| if self.ctc_decoder is not None: | ||
| encoder_out = meta_data["encoder_out"] | ||
| encoder_out_lens = meta_data["encoder_out_lens"] | ||
| decoder_out, decoder_out_lens = self.ctc_decoder(encoder_out, encoder_out_lens) | ||
| decoder_out, decoder_out_lens = self.ctc_decoder( | ||
| encoder_out, encoder_out_lens | ||
| ) | ||
| ctc_logits = self.ctc.log_softmax(decoder_out) | ||
|
|
||
| b, n, d = encoder_out.size() | ||
|
|
@@ -665,7 +719,8 @@ def inference_llm( | |
| inputs_embeds=inputs_embeds, | ||
| attention_mask=attention_mask, | ||
| max_new_tokens=kwargs.get("max_length", 512), | ||
| pad_token_id=self.llm.config.pad_token_id or self.llm.config.eos_token_id, | ||
| pad_token_id=self.llm.config.pad_token_id | ||
| or self.llm.config.eos_token_id, | ||
| **llm_kwargs, | ||
| ) | ||
|
|
||
|
|
@@ -683,7 +738,8 @@ def inference_llm( | |
| inputs_embeds=inputs_embeds, | ||
| attention_mask=attention_mask, | ||
| labels=labels_ids, | ||
| pad_token_id=self.llm.config.pad_token_id or self.llm.config.eos_token_id, | ||
| pad_token_id=self.llm.config.pad_token_id | ||
| or self.llm.config.eos_token_id, | ||
| **llm_kwargs, | ||
| ) | ||
|
|
||
|
|
@@ -722,8 +778,12 @@ def inference_llm( | |
| result["ctc_timestamps"] = forced_align( | ||
| ctc_result["ctc_logits"], target_ids, self.blank_id | ||
| ) | ||
| target_ids = torch.tensor(self.ctc_tokenizer.encode(result["text"]), dtype=torch.int64) | ||
| result["timestamps"] = forced_align(ctc_result["ctc_logits"], target_ids, self.blank_id) | ||
| target_ids = torch.tensor( | ||
| self.ctc_tokenizer.encode(result["text"]), dtype=torch.int64 | ||
| ) | ||
| result["timestamps"] = forced_align( | ||
| ctc_result["ctc_logits"], target_ids, self.blank_id | ||
| ) | ||
| for timestamps in [result["timestamps"], result["ctc_timestamps"]]: | ||
| for timestamp in timestamps: | ||
| timestamp["token"] = self.ctc_tokenizer.decode([timestamp["token"]]) | ||
|
|
@@ -741,6 +801,8 @@ def inference_llm( | |
| def from_pretrained(model: str = None, **kwargs): | ||
| from funasr import AutoModel | ||
|
|
||
| model, kwargs = AutoModel.build_model(model=model, trust_remote_code=True, **kwargs) | ||
| model, kwargs = AutoModel.build_model( | ||
| model=model, trust_remote_code=True, **kwargs | ||
| ) | ||
|
|
||
| return model, kwargs | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The removal of the
NotImplementedErrorforbatch_size > 1is problematic becauseinference_prepare(and consequentlyinference_llm) still only processes the first element ofdata_in(data_in[0]). While theinferencemethod now handles multiple inputs via a loop, callinginference_llmdirectly with a list of multiple items will result in only the first item being processed without any warning or error. It is recommended to keep a check on the length ofdata_inwithininference_llmorinference_prepareto prevent silent data loss.