chore: update fun_asr_nano batch inference model logic #2858
Code Review
This pull request introduces batch inference support for the fun_asr_nano model by implementing a sequential processing loop within the inference method and removes the previous restriction on batch sizes greater than one. The changes also include extensive code reformatting for better readability. Feedback highlights several logic issues in the new batch implementation: the error handler's dummy result is missing required keys, metadata is lost for all but the last item in a batch, and there are potential TypeError risks if input keys are mismatched. Additionally, the implementation currently ignores user-provided data lengths and may lead to silent data loss if internal methods are called directly with multiple items.
        f"batch item {i} inference failed: {str(e)}, {traceback.format_exc()}"
    )
    if single_key:
        all_results.append({"key": single_key[0], "text": ""})
The dummy result dictionary used in the exception handler is missing several keys (text_tn, label, ctc_text, timestamps, etc.) that are present in a standard result. Downstream code expecting these keys will encounter a KeyError. Ensure the error result structure is consistent with the successful one.
References
- Ensure that invalid inputs or states are safely handled in all cases.
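One way to keep the error path consistent is to build the placeholder from a single template, as in the minimal sketch below. The helper name `make_empty_result` and the default values are assumptions for illustration; only the key names come from the comment above, and the real result may carry additional keys.

```python
# Hypothetical helper, not part of the PR. Key names are taken from the
# review comment; the default values and types are assumptions.
RESULT_DEFAULTS = {
    "text": "",
    "text_tn": "",
    "label": "",
    "ctc_text": "",
    "timestamps": [],
}


def make_empty_result(key):
    """Return a placeholder result with the same keys as a successful one."""
    result = {"key": key}
    for name, default in RESULT_DEFAULTS.items():
        # Copy list defaults so callers cannot mutate the shared template.
        result[name] = list(default) if isinstance(default, list) else default
    return result
```

In the exception handler, `all_results.append(make_empty_result(single_key[0]))` would then replace the two-key dictionary.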
    if kwargs.get("batch_size", 1) > 1:
        raise NotImplementedError("batch decoding is not implemented")

    contents = self.data_template(data_in[0])
The removal of the NotImplementedError for batch_size > 1 is problematic because inference_prepare (and consequently inference_llm) still only processes the first element of data_in (data_in[0]). While the inference method now handles multiple inputs via a loop, calling inference_llm directly with a list of multiple items will result in only the first item being processed without any warning or error. It is recommended to keep a check on the length of data_in within inference_llm or inference_prepare to prevent silent data loss.
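A length check of the kind suggested here could look like the sketch below. The function name `ensure_single_item` is hypothetical, and the sketch assumes `data_in` arrives as a list or tuple; the actual `inference_prepare` may accept other input types that would need their own handling.

```python
# Hypothetical guard, not part of the PR. Assumes data_in is a list or tuple.
def ensure_single_item(data_in):
    """Return the only item in data_in, or raise instead of dropping the rest."""
    if not isinstance(data_in, (list, tuple)):
        raise TypeError(f"expected a list or tuple, got {type(data_in).__name__}")
    if len(data_in) != 1:
        raise ValueError(
            f"expected exactly one item, got {len(data_in)}; "
            "call inference() to process a batch"
        )
    return data_in[0]
```

Calling this at the top of `inference_prepare` in place of the bare `data_in[0]` would turn the silent truncation into an explicit error.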
    last_meta = {}
    for i, single_data in enumerate(data_in):
        single_key = [key[i]] if i < len(key) else None
        try:
            res, meta = self.inference_llm(
                [single_data],
                data_lengths=None,
                key=single_key,
                tokenizer=tokenizer,
                frontend=frontend,
                **kwargs,
            )
            all_results.extend(res)
            last_meta = meta
        except Exception as e:
            logging.error(
                f"batch item {i} inference failed: {str(e)}, {traceback.format_exc()}"
            )
            if single_key:
                all_results.append({"key": single_key[0], "text": ""})
    return all_results, last_meta
The last_meta variable only captures the metadata from the final iteration of the loop. This leads to the loss of metadata (such as encoder outputs, internal states, or specific timing information) for all other items in the batch. Consider returning a list of metadata dictionaries or merging them to maintain consistency with the all_results list, especially if this model is used in pipelines where per-sample metadata is required.
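The list-of-metadata option can be sketched as follows. `run_batch` and the `infer_one` callable are stand-ins invented for this example; `infer_one` is assumed to return a `(results, meta)` pair for a single item, as `inference_llm` does in the loop above.

```python
# Hypothetical sketch, not part of the PR. infer_one stands in for a call
# that returns (list_of_results, meta_dict) for a single input.
def run_batch(items, infer_one):
    """Run infer_one per item and keep one metadata dict per item."""
    all_results = []
    all_meta = []
    for item in items:
        res, meta = infer_one(item)
        all_results.extend(res)
        # Append rather than overwrite, so no item's metadata is lost.
        all_meta.append(meta)
    return all_results, all_meta
```

This changes the second return value from a dictionary to a list, so any caller that expects a single metadata dictionary would need to be updated.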
    all_results = []
    last_meta = {}
    for i, single_data in enumerate(data_in):
        single_key = [key[i]] if i < len(key) else None
If a key list is provided by the user but is shorter than data_in, single_key will be set to None. This will cause a TypeError in inference_llm at line 687 (or line 764) when it attempts to subscript key. It is safer to validate the length of key at the beginning of the inference method or provide a fallback key in the loop.
References
- Defensive programming: ensure that invalid inputs or states are safely handled.
- Ensure appropriate null/nil/None checks or other language-idiomatic guards exist before object property accesses.
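Validating the keys once, before the loop, might look like the sketch below. The name `resolve_keys`, the `batch_item` prefix, and the choice to generate keys when none are supplied are all assumptions; the comment above only asks for a length check or a fallback.

```python
# Hypothetical helper, not part of the PR. The fallback naming scheme is an
# assumption chosen for illustration.
def resolve_keys(key, num_items, prefix="batch_item"):
    """Return one key per input item, or raise on a length mismatch."""
    if key is None:
        # No keys supplied: generate a stable fallback key for each item.
        return [f"{prefix}_{i}" for i in range(num_items)]
    if len(key) != num_items:
        raise ValueError(
            f"got {len(key)} keys for {num_items} inputs; lengths must match"
        )
    return list(key)
```

With `keys = resolve_keys(key, len(data_in))` computed up front, the loop can pass `[keys[i]]` unconditionally and never hands `None` to `inference_llm`.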
    try:
        res, meta = self.inference_llm(
            [single_data],
            data_lengths=None,
The data_lengths argument is hardcoded to None in the batch loop, which means any lengths provided by the user to the inference method are ignored for batch items. This may affect model components that rely on explicit length information.
References
- Verify code functionality and ensure alignment between function descriptions and implementations.
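If the lengths are meant to be forwarded, one option is to slice out the entry for each item, as sketched below. The helper `length_for_item` is hypothetical, and the sketch assumes `data_lengths` is either `None` or a sliceable sequence (a list or a 1-D tensor) aligned index-for-index with `data_in`, which the PR does not confirm.

```python
# Hypothetical helper, not part of the PR. Assumes data_lengths, when given,
# is a sliceable sequence aligned with data_in.
def length_for_item(data_lengths, i):
    """Return the length entry for item i as a one-element sequence."""
    if data_lengths is None:
        return None
    # Slicing keeps the container type, so the callee still receives a
    # batch-shaped value of size one.
    return data_lengths[i : i + 1]
```

The loop would then pass `data_lengths=length_for_item(data_lengths, i)` instead of the hardcoded `None`.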
No description provided.