Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
160 changes: 146 additions & 14 deletions frontends/tgapp.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,19 @@
_QUEUE_WAIT_SECONDS = 1
_ASK_USER_HOOK_KEY = "telegram_ask_user_menu"
_ASK_CALLBACK_PREFIX = "ask:"
_LLM_CALLBACK_PREFIX = "llm:"
_ASK_CANCEL_ACTION = "none"
_ASK_MULTI_DONE_ACTION = "done"
_ASK_TOGGLE_ACTION = "toggle"
_ASK_CANCEL_LABEL = "none of these above"
_ASK_CANCEL_PROMPT = "已取消选择,请直接发送下一步操作。"
_ASK_MULTI_HINT = "可多选:点选项目后点击 Done 提交。"
_ASK_MULTI_EMPTY_HINT = "请至少选择一项,或选择 none of these above。"
_LLM_MENU_PROMPT = "请选择要切换的 LLM:"
_ask_menu_events = Q.Queue()
_ask_menu_store = {}
_llm_menu_store = {}
_MULTI_SELECT_RE = re.compile(r"\[?(?:多选|multi(?:[-_ ]?select)?|select all)\]?", re.IGNORECASE)
_QUOTE_OPEN_TAG = "<_quote_>"
_QUOTE_CLOSE_TAG = "</_quote_>"
_QUOTE_TOKEN_PATTERN = re.escape(_QUOTE_OPEN_TAG) + r"([\s\S]*?)" + re.escape(_QUOTE_CLOSE_TAG)
Expand Down Expand Up @@ -265,7 +273,11 @@ def _extract_ask_user_event(ctx):
if not candidates:
return None
question = str(data.get("question") or "请选择下一步操作:").strip() or "请选择下一步操作:"
return {"question": question, "candidates": candidates}
return {
"question": question,
"candidates": candidates,
"multi": bool(_MULTI_SELECT_RE.search(question)),
}

def _register_ask_user_hook():
if not hasattr(agent, "_turn_end_hooks"):
Expand All @@ -285,25 +297,49 @@ def _drain_latest_ask_user_event():
break
return latest

def _build_ask_user_markup(menu_id, candidates):
rows = [
[InlineKeyboardButton(candidate, callback_data=f"{_ASK_CALLBACK_PREFIX}{menu_id}:{idx}")]
for idx, candidate in enumerate(candidates)
]
def _build_ask_user_markup(menu_id, candidates, multi=False, selected_indexes=None):
selected_indexes = set(selected_indexes or [])
rows = []
for idx, candidate in enumerate(candidates):
if multi:
label = f"✓ {candidate}" if idx in selected_indexes else candidate
action = f"{_ASK_TOGGLE_ACTION}:{idx}"
else:
label = candidate
action = str(idx)
rows.append([
InlineKeyboardButton(label, callback_data=f"{_ASK_CALLBACK_PREFIX}{menu_id}:{action}")
])
if multi:
rows.append([
InlineKeyboardButton("Done", callback_data=f"{_ASK_CALLBACK_PREFIX}{menu_id}:{_ASK_MULTI_DONE_ACTION}")
])
rows.append([
InlineKeyboardButton(_ASK_CANCEL_LABEL, callback_data=f"{_ASK_CALLBACK_PREFIX}{menu_id}:{_ASK_CANCEL_ACTION}")
])
return InlineKeyboardMarkup(rows)

def _parse_ask_callback_data(data):
if not (data or "").startswith(_ASK_CALLBACK_PREFIX):
def _build_llm_markup(menu_id, llms):
rows = []
for idx, name, current in llms:
label = f"→ [{idx}] {name}" if current else f"[{idx}] {name}"
rows.append([
InlineKeyboardButton(label, callback_data=f"{_LLM_CALLBACK_PREFIX}{menu_id}:{idx}")
])
return InlineKeyboardMarkup(rows)

def _parse_menu_callback_data(data, prefix):
if not (data or "").startswith(prefix):
return None, None
payload = data[len(_ASK_CALLBACK_PREFIX):]
payload = data[len(prefix):]
menu_id, sep, action = payload.partition(":")
if not sep or not menu_id or not action:
return None, None
return menu_id, action

def _parse_ask_callback_data(data):
return _parse_menu_callback_data(data, _ASK_CALLBACK_PREFIX)

def _build_text_prompt(text):
return f"{FILE_HINT}\n\n{text}"

Expand All @@ -313,11 +349,15 @@ def _normalize_ask_menu_event(stored):
return {
"question": str(stored.get("question") or "请选择下一步操作:").strip() or "请选择下一步操作:",
"candidates": [str(candidate).strip() for candidate in candidates if str(candidate).strip()],
"multi": bool(stored.get("multi")),
"selected": [int(idx) for idx in stored.get("selected", []) if isinstance(idx, int)],
}
if isinstance(stored, (list, tuple)):
return {
"question": "请选择下一步操作:",
"candidates": [str(candidate).strip() for candidate in stored if str(candidate).strip()],
"multi": False,
"selected": [],
}
return None

Expand Down Expand Up @@ -357,11 +397,18 @@ async def _edit_ask_user_result(query, event, selected=None, cancelled=False):
async def _send_ask_user_menu(root_msg, event):
menu_id = uuid.uuid4().hex[:16]
candidates = event["candidates"]
_ask_menu_store[menu_id] = {"question": event["question"], "candidates": list(candidates)}
multi = bool(event.get("multi"))
_ask_menu_store[menu_id] = {
"question": event["question"],
"candidates": list(candidates),
"multi": multi,
"selected": [],
}
prompt = f"{event['question']}\n\n{_ASK_MULTI_HINT}" if multi else event["question"]
try:
await root_msg.reply_text(
event["question"],
reply_markup=_build_ask_user_markup(menu_id, candidates),
prompt,
reply_markup=_build_ask_user_markup(menu_id, candidates, multi=multi),
)
except Exception as exc:
_ask_menu_store.pop(menu_id, None)
Expand Down Expand Up @@ -845,6 +892,45 @@ async def handle_ask_callback(update, ctx):
await query.answer("菜单已过期")
return await _clear_ask_reply_markup(query)
candidates = event["candidates"]
if event.get("multi") and action.startswith(f"{_ASK_TOGGLE_ACTION}:"):
try:
selected_idx = int(action.split(":", 1)[1])
if selected_idx < 0 or selected_idx >= len(candidates):
raise ValueError
except ValueError:
return await query.answer("菜单无效")
stored = _ask_menu_store.get(menu_id)
if not isinstance(stored, dict):
return await query.answer("菜单已过期")
selected = set(stored.get("selected", []))
if selected_idx in selected:
selected.remove(selected_idx)
else:
selected.add(selected_idx)
stored["selected"] = sorted(selected)
await query.answer()
return await query.edit_message_reply_markup(
reply_markup=_build_ask_user_markup(
menu_id,
candidates,
multi=True,
selected_indexes=stored["selected"],
)
)
if event.get("multi") and action == _ASK_MULTI_DONE_ACTION:
selected_indexes = event.get("selected") or []
if not selected_indexes:
return await query.answer(_ASK_MULTI_EMPTY_HINT, show_alert=True)
selected = "; ".join(candidates[idx] for idx in selected_indexes)
_ask_menu_store.pop(menu_id, None)
await query.answer()
await _edit_ask_user_result(query, event, selected=selected)
if query.message is None:
return
dq = agent.put_task(_build_text_prompt(selected), source="telegram")
task = asyncio.create_task(_stream(dq, query.message))
ctx.user_data['stream_task'] = task
return
if action == _ASK_CANCEL_ACTION:
_ask_menu_store.pop(menu_id, None)
await query.answer()
Expand All @@ -865,6 +951,52 @@ async def handle_ask_callback(update, ctx):
task = asyncio.create_task(_stream(dq, query.message))
ctx.user_data['stream_task'] = task

async def _send_llm_menu(message):
llms = agent.list_llms()
if not llms:
return await message.reply_text("没有可用模型。")
menu_id = uuid.uuid4().hex[:16]
_llm_menu_store[menu_id] = [idx for idx, _, _ in llms]
lines = [f"{'→' if cur else ' '} [{idx}] {name}" for idx, name, cur in llms]
try:
await message.reply_text(
_LLM_MENU_PROMPT,
reply_markup=_build_llm_markup(menu_id, llms),
)
except Exception as exc:
_llm_menu_store.pop(menu_id, None)
print(f"[TG llm menu error] {type(exc).__name__}: {exc}", flush=True)
await message.reply_text("LLMs:\n" + "\n".join(lines))

async def handle_llm_callback(update, ctx):
query = update.callback_query
if query is None:
return
uid = update.effective_user.id if update.effective_user else None
if ALLOWED and uid not in ALLOWED:
return await query.answer("no", show_alert=True)
menu_id, action = _parse_menu_callback_data(query.data, _LLM_CALLBACK_PREFIX)
if not menu_id:
return await query.answer("菜单无效")
valid_indexes = _llm_menu_store.get(menu_id)
if valid_indexes is None:
await query.answer("菜单已过期")
return await _clear_ask_reply_markup(query)
try:
selected_idx = int(action)
except (TypeError, ValueError):
return await query.answer("菜单无效")
if selected_idx not in valid_indexes:
return await query.answer("菜单已过期", show_alert=True)
try:
agent.next_llm(selected_idx)
selected_name = agent.get_llm_name()
except Exception as exc:
return await query.answer(f"切换失败: {exc}", show_alert=True)
_llm_menu_store.pop(menu_id, None)
await query.answer(f"已切换到 [{selected_idx}] {selected_name}")
await query.edit_message_text(f"✅ 已切换到 [{selected_idx}] {selected_name}")

async def cmd_abort(update, ctx):
_cancel_stream_task(ctx)
agent.abort()
Expand All @@ -880,8 +1012,7 @@ async def cmd_llm(update, ctx):
except (ValueError, IndexError):
await update.message.reply_text(f"用法: /llm <0-{len(agent.list_llms())-1}>")
else:
lines = [f"{'→' if cur else ' '} [{i}] {name}" for i, name, cur in agent.list_llms()]
await update.message.reply_text("LLMs:\n" + "\n".join(lines))
await _send_llm_menu(update.message)

async def handle_photo(update, ctx):
uid = update.effective_user.id
Expand Down Expand Up @@ -969,6 +1100,7 @@ async def _error_handler(update, context: ContextTypes.DEFAULT_TYPE):
app = (ApplicationBuilder().token(mykeys['tg_bot_token'])
.request(request).get_updates_request(request).post_init(_sync_commands).build())
app.add_handler(CallbackQueryHandler(handle_ask_callback, pattern=r"^ask:"))
app.add_handler(CallbackQueryHandler(handle_llm_callback, pattern=r"^llm:"))
app.add_handler(MessageHandler(filters.COMMAND, handle_command))
app.add_handler(MessageHandler(filters.PHOTO, handle_photo))
app.add_handler(MessageHandler(filters.Document.ALL, handle_photo))
Expand Down
Loading