From d7be3125725084973e37a724d7c9c1c8c2eb5350 Mon Sep 17 00:00:00 2001 From: szdtzpj <167522563+szdtzpj@users.noreply.github.com> Date: Tue, 19 May 2026 16:06:37 +0800 Subject: [PATCH] feat(tg): add inline selection menus --- frontends/tgapp.py | 160 ++++++++++++++++-- tests/test_tgapp_inline_selection.py | 234 +++++++++++++++++++++++++++ 2 files changed, 380 insertions(+), 14 deletions(-) create mode 100644 tests/test_tgapp_inline_selection.py diff --git a/frontends/tgapp.py b/frontends/tgapp.py index fddab2f59..0716af94b 100644 --- a/frontends/tgapp.py +++ b/frontends/tgapp.py @@ -42,11 +42,19 @@ _QUEUE_WAIT_SECONDS = 1 _ASK_USER_HOOK_KEY = "telegram_ask_user_menu" _ASK_CALLBACK_PREFIX = "ask:" +_LLM_CALLBACK_PREFIX = "llm:" _ASK_CANCEL_ACTION = "none" +_ASK_MULTI_DONE_ACTION = "done" +_ASK_TOGGLE_ACTION = "toggle" _ASK_CANCEL_LABEL = "none of these above" _ASK_CANCEL_PROMPT = "已取消选择,请直接发送下一步操作。" +_ASK_MULTI_HINT = "可多选:点选项目后点击 Done 提交。" +_ASK_MULTI_EMPTY_HINT = "请至少选择一项,或选择 none of these above。" +_LLM_MENU_PROMPT = "请选择要切换的 LLM:" _ask_menu_events = Q.Queue() _ask_menu_store = {} +_llm_menu_store = {} +_MULTI_SELECT_RE = re.compile(r"\[?(?:多选|multi(?:[-_ ]?select)?|select all)\]?", re.IGNORECASE) _QUOTE_OPEN_TAG = "<_quote_>" _QUOTE_CLOSE_TAG = "" _QUOTE_TOKEN_PATTERN = re.escape(_QUOTE_OPEN_TAG) + r"([\s\S]*?)" + re.escape(_QUOTE_CLOSE_TAG) @@ -265,7 +273,11 @@ def _extract_ask_user_event(ctx): if not candidates: return None question = str(data.get("question") or "请选择下一步操作:").strip() or "请选择下一步操作:" - return {"question": question, "candidates": candidates} + return { + "question": question, + "candidates": candidates, + "multi": bool(_MULTI_SELECT_RE.search(question)), + } def _register_ask_user_hook(): if not hasattr(agent, "_turn_end_hooks"): @@ -285,25 +297,49 @@ def _drain_latest_ask_user_event(): break return latest -def _build_ask_user_markup(menu_id, candidates): - rows = [ - [InlineKeyboardButton(candidate, callback_data=f"{_ASK_CALLBACK_PREFIX}{menu_id}:{idx}")] - for idx, candidate in enumerate(candidates) - ] +def _build_ask_user_markup(menu_id, candidates, multi=False, selected_indexes=None): + selected_indexes = set(selected_indexes or []) + rows = [] + for idx, candidate in enumerate(candidates): + if multi: + label = f"✓ {candidate}" if idx in selected_indexes else candidate + action = f"{_ASK_TOGGLE_ACTION}:{idx}" + else: + label = candidate + action = str(idx) + rows.append([ + InlineKeyboardButton(label, callback_data=f"{_ASK_CALLBACK_PREFIX}{menu_id}:{action}") + ]) + if multi: + rows.append([ + InlineKeyboardButton("Done", callback_data=f"{_ASK_CALLBACK_PREFIX}{menu_id}:{_ASK_MULTI_DONE_ACTION}") + ]) rows.append([ InlineKeyboardButton(_ASK_CANCEL_LABEL, callback_data=f"{_ASK_CALLBACK_PREFIX}{menu_id}:{_ASK_CANCEL_ACTION}") ]) return InlineKeyboardMarkup(rows) -def _parse_ask_callback_data(data): - if not (data or "").startswith(_ASK_CALLBACK_PREFIX): +def _build_llm_markup(menu_id, llms): + rows = [] + for idx, name, current in llms: + label = f"→ [{idx}] {name}" if current else f"[{idx}] {name}" + rows.append([ + InlineKeyboardButton(label, callback_data=f"{_LLM_CALLBACK_PREFIX}{menu_id}:{idx}") + ]) + return InlineKeyboardMarkup(rows) + +def _parse_menu_callback_data(data, prefix): + if not (data or "").startswith(prefix): return None, None - payload = data[len(_ASK_CALLBACK_PREFIX):] + payload = data[len(prefix):] menu_id, sep, action = payload.partition(":") if not sep or not menu_id or not action: return None, None return menu_id, action +def _parse_ask_callback_data(data): + return _parse_menu_callback_data(data, _ASK_CALLBACK_PREFIX) + def _build_text_prompt(text): return f"{FILE_HINT}\n\n{text}" @@ -313,11 +349,15 @@ def _normalize_ask_menu_event(stored): return { "question": str(stored.get("question") or "请选择下一步操作:").strip() or "请选择下一步操作:", "candidates": [str(candidate).strip() for candidate in candidates if str(candidate).strip()], + "multi": bool(stored.get("multi")), + "selected": [int(idx) for idx in stored.get("selected", []) if isinstance(idx, int)], } if isinstance(stored, (list, tuple)): return { "question": "请选择下一步操作:", "candidates": [str(candidate).strip() for candidate in stored if str(candidate).strip()], + "multi": False, + "selected": [], } return None @@ -357,11 +397,18 @@ async def _edit_ask_user_result(query, event, selected=None, cancelled=False): async def _send_ask_user_menu(root_msg, event): menu_id = uuid.uuid4().hex[:16] candidates = event["candidates"] - _ask_menu_store[menu_id] = {"question": event["question"], "candidates": list(candidates)} + multi = bool(event.get("multi")) + _ask_menu_store[menu_id] = { + "question": event["question"], + "candidates": list(candidates), + "multi": multi, + "selected": [], + } + prompt = f"{event['question']}\n\n{_ASK_MULTI_HINT}" if multi else event["question"] try: await root_msg.reply_text( - event["question"], - reply_markup=_build_ask_user_markup(menu_id, candidates), + prompt, + reply_markup=_build_ask_user_markup(menu_id, candidates, multi=multi), ) except Exception as exc: _ask_menu_store.pop(menu_id, None) @@ -845,6 +892,45 @@ async def handle_ask_callback(update, ctx): await query.answer("菜单已过期") return await _clear_ask_reply_markup(query) candidates = event["candidates"] + if event.get("multi") and action.startswith(f"{_ASK_TOGGLE_ACTION}:"): + try: + selected_idx = int(action.split(":", 1)[1]) + if selected_idx < 0 or selected_idx >= len(candidates): + raise ValueError + except ValueError: + return await query.answer("菜单无效") + stored = _ask_menu_store.get(menu_id) + if not isinstance(stored, dict): + return await query.answer("菜单已过期") + selected = set(stored.get("selected", [])) + if selected_idx in selected: + selected.remove(selected_idx) + else: + selected.add(selected_idx) + stored["selected"] = sorted(selected) + await query.answer() + return await query.edit_message_reply_markup( + reply_markup=_build_ask_user_markup( + menu_id, + candidates, + multi=True, + selected_indexes=stored["selected"], + ) + ) + if event.get("multi") and action == _ASK_MULTI_DONE_ACTION: + selected_indexes = event.get("selected") or [] + if not selected_indexes: + return await query.answer(_ASK_MULTI_EMPTY_HINT, show_alert=True) + selected = "; ".join(candidates[idx] for idx in selected_indexes) + _ask_menu_store.pop(menu_id, None) + await query.answer() + await _edit_ask_user_result(query, event, selected=selected) + if query.message is None: + return + dq = agent.put_task(_build_text_prompt(selected), source="telegram") + task = asyncio.create_task(_stream(dq, query.message)) + ctx.user_data['stream_task'] = task + return if action == _ASK_CANCEL_ACTION: _ask_menu_store.pop(menu_id, None) await query.answer() @@ -865,6 +951,52 @@ async def handle_ask_callback(update, ctx): task = asyncio.create_task(_stream(dq, query.message)) ctx.user_data['stream_task'] = task +async def _send_llm_menu(message): + llms = agent.list_llms() + if not llms: + return await message.reply_text("没有可用模型。") + menu_id = uuid.uuid4().hex[:16] + _llm_menu_store[menu_id] = [idx for idx, _, _ in llms] + lines = [f"{'→' if cur else ' '} [{idx}] {name}" for idx, name, cur in llms] + try: + await message.reply_text( + _LLM_MENU_PROMPT, + reply_markup=_build_llm_markup(menu_id, llms), + ) + except Exception as exc: + _llm_menu_store.pop(menu_id, None) + print(f"[TG llm menu error] {type(exc).__name__}: {exc}", flush=True) + await message.reply_text("LLMs:\n" + "\n".join(lines)) + +async def handle_llm_callback(update, ctx): + query = update.callback_query + if query is None: + return + uid = update.effective_user.id if update.effective_user else None + if ALLOWED and uid not in ALLOWED: + return await query.answer("no", show_alert=True) + menu_id, action = _parse_menu_callback_data(query.data, _LLM_CALLBACK_PREFIX) + if not menu_id: + return await query.answer("菜单无效") + valid_indexes = _llm_menu_store.get(menu_id) + if valid_indexes is None: + await query.answer("菜单已过期") + return await _clear_ask_reply_markup(query) + try: + selected_idx = int(action) + except (TypeError, ValueError): + return await query.answer("菜单无效") + if selected_idx not in valid_indexes: + return await query.answer("菜单已过期", show_alert=True) + try: + agent.next_llm(selected_idx) + selected_name = agent.get_llm_name() + except Exception as exc: + return await query.answer(f"切换失败: {exc}", show_alert=True) + _llm_menu_store.pop(menu_id, None) + await query.answer(f"已切换到 [{selected_idx}] {selected_name}") + await query.edit_message_text(f"✅ 已切换到 [{selected_idx}] {selected_name}") + async def cmd_abort(update, ctx): _cancel_stream_task(ctx) agent.abort() @@ -880,8 +1012,7 @@ async def cmd_llm(update, ctx): except (ValueError, IndexError): await update.message.reply_text(f"用法: /llm <0-{len(agent.list_llms())-1}>") else: - lines = [f"{'→' if cur else ' '} [{i}] {name}" for i, name, cur in agent.list_llms()] - await update.message.reply_text("LLMs:\n" + "\n".join(lines)) + await _send_llm_menu(update.message) async def handle_photo(update, ctx): uid = update.effective_user.id @@ -969,6 +1100,7 @@ async def _error_handler(update, context: ContextTypes.DEFAULT_TYPE): app = (ApplicationBuilder().token(mykeys['tg_bot_token']) .request(request).get_updates_request(request).post_init(_sync_commands).build()) app.add_handler(CallbackQueryHandler(handle_ask_callback, pattern=r"^ask:")) + app.add_handler(CallbackQueryHandler(handle_llm_callback, pattern=r"^llm:")) app.add_handler(MessageHandler(filters.COMMAND, handle_command)) app.add_handler(MessageHandler(filters.PHOTO, handle_photo)) app.add_handler(MessageHandler(filters.Document.ALL, handle_photo)) diff --git a/tests/test_tgapp_inline_selection.py b/tests/test_tgapp_inline_selection.py new file mode 100644 index 000000000..a800a3b7a --- /dev/null +++ b/tests/test_tgapp_inline_selection.py @@ -0,0 +1,234 @@ +import importlib +import sys +import types +import unittest +from pathlib import Path + + +REPO_ROOT = Path(__file__).resolve().parents[1] +FRONTENDS = REPO_ROOT / "frontends" +for path in (str(REPO_ROOT), str(FRONTENDS)): + if path not in sys.path: + sys.path.insert(0, path) + + +def _install_import_stubs(): + telegram = types.ModuleType("telegram") + + class InlineKeyboardButton: + def __init__(self, text, callback_data=None): + self.text = text + self.callback_data = callback_data + + class InlineKeyboardMarkup: + def __init__(self, rows): + self.inline_keyboard = rows + + telegram.BotCommand = object + telegram.InlineKeyboardButton = InlineKeyboardButton + telegram.InlineKeyboardMarkup = InlineKeyboardMarkup + + constants = types.ModuleType("telegram.constants") + constants.ChatType = types.SimpleNamespace(PRIVATE="private") + constants.MessageLimit = types.SimpleNamespace(MAX_TEXT_LENGTH=4096) + constants.ParseMode = types.SimpleNamespace(MARKDOWN_V2="MarkdownV2") + + error = types.ModuleType("telegram.error") + error.RetryAfter = type("RetryAfter", (Exception,), {}) + + ext = types.ModuleType("telegram.ext") + ext.ApplicationBuilder = object + ext.CallbackQueryHandler = object + ext.MessageHandler = object + ext.ContextTypes = types.SimpleNamespace(DEFAULT_TYPE=object) + ext.filters = types.SimpleNamespace( + COMMAND=object(), + PHOTO=object(), + TEXT=object(), + Document=types.SimpleNamespace(ALL=object()), + ) + + helpers = types.ModuleType("telegram.helpers") + helpers.escape_markdown = lambda text, version=2, entity_type=None: text or "" + + request = types.ModuleType("telegram.request") + request.HTTPXRequest = object + + class FakeAgent: + def __init__(self): + self.verbose = False + self.inc_out = False + self.llm_no = 0 + self.prompts = [] + + def list_llms(self): + return [ + (0, "gpt-4o", self.llm_no == 0), + (1, "claude-sonnet", self.llm_no == 1), + ] + + def next_llm(self, n): + if n not in (0, 1): + raise IndexError(n) + self.llm_no = n + + def get_llm_name(self): + return self.list_llms()[self.llm_no][1] + + def put_task(self, prompt, source=None): + self.prompts.append((prompt, source)) + return object() + + agentmain = types.ModuleType("agentmain") + agentmain.GeneraticAgent = FakeAgent + + chatapp_common = types.ModuleType("chatapp_common") + chatapp_common.FILE_HINT = "FILE_HINT" + chatapp_common.HELP_TEXT = "" + chatapp_common.TELEGRAM_MENU_COMMANDS = [] + chatapp_common.clean_reply = lambda text: text + chatapp_common.ensure_single_instance = lambda *args, **kwargs: None + chatapp_common.extract_files = lambda text: [] + chatapp_common.format_restore = lambda: (([], "", 0), None) + chatapp_common.redirect_log = lambda *args, **kwargs: None + chatapp_common.require_runtime = lambda *args, **kwargs: None + chatapp_common.split_text = lambda text, limit: [text] + + continue_cmd = types.ModuleType("continue_cmd") + continue_cmd.handle_frontend_command = lambda *args, **kwargs: "" + continue_cmd.reset_conversation = lambda *args, **kwargs: "" + + btw_cmd = types.ModuleType("btw_cmd") + btw_cmd.handle_frontend_command = lambda *args, **kwargs: "" + + llmcore = types.ModuleType("llmcore") + llmcore.mykeys = {} + + sys.modules.update( + { + "telegram": telegram, + "telegram.constants": constants, + "telegram.error": error, + "telegram.ext": ext, + "telegram.helpers": helpers, + "telegram.request": request, + "agentmain": agentmain, + "chatapp_common": chatapp_common, + "continue_cmd": continue_cmd, + "btw_cmd": btw_cmd, + "llmcore": llmcore, + } + ) + + +_install_import_stubs() +tgapp = importlib.import_module("tgapp") + + +class FakeMessage: + text = "/llm" + + def __init__(self): + self.replies = [] + + async def reply_text(self, text, reply_markup=None, **kwargs): + self.replies.append(types.SimpleNamespace(text=text, reply_markup=reply_markup)) + return self.replies[-1] + + +class FakeQuery: + def __init__(self, data): + self.data = data + self.message = FakeMessage() + self.answers = [] + self.edited_text = None + self.edited_markup = None + + async def answer(self, text=None, show_alert=False): + self.answers.append((text, show_alert)) + + async def edit_message_text(self, text, reply_markup=None): + self.edited_text = text + self.edited_markup = reply_markup + + async def edit_message_reply_markup(self, reply_markup=None): + self.edited_markup = reply_markup + + +class FakeUpdate: + effective_user = types.SimpleNamespace(id=1) + + def __init__(self, query=None, message=None): + self.callback_query = query + self.message = message + + +class TelegramInlineSelectionTests(unittest.IsolatedAsyncioTestCase): + async def asyncSetUp(self): + tgapp._ask_menu_store.clear() + tgapp._llm_menu_store.clear() + tgapp.agent = tgapp.GeneraticAgent() + + async def fake_stream(*args, **kwargs): + return None + + self._original_stream = tgapp._stream + tgapp._stream = fake_stream + + async def asyncTearDown(self): + tgapp._stream = self._original_stream + + def test_multi_ask_markup_tracks_selected_items(self): + markup = tgapp._build_ask_user_markup( + "menu", + ["Python", "Go"], + multi=True, + selected_indexes=[0], + ) + + rows = markup.inline_keyboard + self.assertEqual(rows[0][0].text, "✓ Python") + self.assertEqual(rows[0][0].callback_data, "ask:menu:toggle:0") + self.assertEqual(rows[1][0].callback_data, "ask:menu:toggle:1") + self.assertEqual(rows[2][0].callback_data, "ask:menu:done") + + async def test_multi_ask_done_submits_joined_selection(self): + tgapp._ask_menu_store["menu"] = { + "question": "Pick [多选]", + "candidates": ["Python", "JavaScript", "Go"], + "multi": True, + "selected": [0, 2], + } + query = FakeQuery("ask:menu:done") + ctx = types.SimpleNamespace(user_data={}) + + await tgapp.handle_ask_callback(FakeUpdate(query=query), ctx) + + self.assertNotIn("menu", tgapp._ask_menu_store) + self.assertIn("Python; Go", query.edited_text) + self.assertEqual(tgapp.agent.prompts[-1][0], "FILE_HINT\n\nPython; Go") + self.assertEqual(tgapp.agent.prompts[-1][1], "telegram") + + async def test_llm_command_sends_inline_keyboard_and_callback_switches(self): + message = FakeMessage() + + await tgapp.cmd_llm(FakeUpdate(message=message), types.SimpleNamespace()) + + self.assertEqual(message.replies[0].text, tgapp._LLM_MENU_PROMPT) + menu_id = next(iter(tgapp._llm_menu_store)) + rows = message.replies[0].reply_markup.inline_keyboard + self.assertEqual(rows[1][0].callback_data, f"llm:{menu_id}:1") + + query = FakeQuery(f"llm:{menu_id}:1") + await tgapp.handle_llm_callback( + FakeUpdate(query=query), + types.SimpleNamespace(user_data={}), + ) + + self.assertEqual(tgapp.agent.llm_no, 1) + self.assertIn("claude-sonnet", query.edited_text) + self.assertNotIn(menu_id, tgapp._llm_menu_store) + + +if __name__ == "__main__": + unittest.main()