-
Notifications
You must be signed in to change notification settings - Fork 3
Yhl/llm example #12
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Yhl/llm example #12
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,5 @@ | ||
| # RUNPOD_API_KEY=your_api_key_here | ||
| # FLASH_HOST=localhost | ||
| # FLASH_PORT=8888 | ||
| # LOG_LEVEL=INFO | ||
| # HF_TOKEN=your_huggingface_token |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,40 @@ | ||
| # Flash Build Ignore Patterns | ||
|
|
||
| # Python cache | ||
| __pycache__/ | ||
| *.pyc | ||
|
|
||
| # Virtual environments | ||
| venv/ | ||
| .venv/ | ||
| env/ | ||
|
|
||
| # IDE | ||
| .vscode/ | ||
| .idea/ | ||
|
|
||
| # Environment files | ||
| .env | ||
| .env.local | ||
|
|
||
| # Git | ||
| .git/ | ||
| .gitignore | ||
|
|
||
| # Build artifacts | ||
| dist/ | ||
| build/ | ||
| *.egg-info/ | ||
|
|
||
| # Flash resources | ||
| .tetra_resources.pkl | ||
|
|
||
| # Tests | ||
| tests/ | ||
| test_*.py | ||
| *_test.py | ||
|
|
||
| # Documentation | ||
| docs/ | ||
| *.md | ||
| !README.md |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1 @@ | ||
| """LLM chat inference on a serverless GPU example.""" |
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
| @@ -0,0 +1,170 @@ | ||||||
| ## LLM chat inference on a serverless GPU | ||||||
| # This example runs a small chat LLM (Llama 3.2 1B Instruct) on Runpod serverless GPUs | ||||||
| # using `transformers.pipeline`. | ||||||
| # | ||||||
| # Call it via the FastAPI endpoint (`POST /gpu/llm`) or run this module directly for | ||||||
| # a quick smoke test. | ||||||
| # | ||||||
| # Scaling behavior is controlled by the `LiveServerless` config below. | ||||||
| import os | ||||||
|
|
||||||
| from fastapi import APIRouter | ||||||
| from pydantic import BaseModel | ||||||
|
|
||||||
| from tetra_rp import ( | ||||||
| GpuGroup, | ||||||
| LiveServerless, | ||||||
| remote, | ||||||
| ) | ||||||
|
|
||||||
| # Here, we'll define several variables that change the | ||||||
| # default behavior of our serverless endpoint. `workersMin` sets our endpoint | ||||||
| # to scale to 0 active containers; `workersMax` will allow our endpoint to run | ||||||
| # up to 3 workers in parallel as the endpoint receives more work. We also set | ||||||
| # an idle timeout of 5 minutes so that any active worker stays alive for 5 | ||||||
| # minutes after completing a request. | ||||||
| # | ||||||
| # Hugging Face auth: | ||||||
| # Many `meta-llama/*` models are gated on Hugging Face. Local shell env vars are NOT | ||||||
| # automatically forwarded into serverless containers, so we pass `HF_TOKEN` via `env=...` | ||||||
| # so the remote worker can download the model. | ||||||
| _hf_token = os.getenv("HF_TOKEN") | ||||||
| _worker_env = {"HF_TOKEN": _hf_token} if _hf_token else {} | ||||||
| gpu_config = LiveServerless( | ||||||
| name="02_01_text_generation_gpu_worker", | ||||||
| gpus=[GpuGroup.ANY], # Run on any GPU | ||||||
| env=_worker_env, | ||||||
| workersMin=0, | ||||||
| workersMax=3, | ||||||
| idleTimeout=5, | ||||||
| ) | ||||||
|
|
||||||
|
|
||||||
| # Decorating our function with `remote` will package up the function code and | ||||||
| # deploy it on the infrastructure according to the passed input config. The | ||||||
| # results from the worker will be returned to your terminal. In this example | ||||||
| # the function will return a greeting to the input string passed in the `name` | ||||||
| # key. The code itself will run on a GPU worker, and information about the GPU | ||||||
| # the worker has access to will be included in the response. | ||||||
| # Declare worker dependencies so they're installed in the remote execution environment. | ||||||
| # (Local `requirements.txt` is not automatically shipped to the worker.) | ||||||
| @remote( | ||||||
| resource_config=gpu_config, | ||||||
| dependencies=[ | ||||||
| "torch", | ||||||
| "transformers", | ||||||
| "accelerate", | ||||||
| ], | ||||||
| ) | ||||||
| async def gpu_hello( | ||||||
| input_data: dict, | ||||||
| ) -> dict: | ||||||
| """Generate one chat response using Llama 3.2 1B Instruct on a serverless GPU.""" | ||||||
| import os | ||||||
| import platform | ||||||
| from datetime import datetime | ||||||
|
|
||||||
| import torch | ||||||
| from transformers import pipeline | ||||||
|
|
||||||
| gpu_available = torch.cuda.is_available() | ||||||
| gpu_name = torch.cuda.get_device_name(0) | ||||||
| gpu_count = torch.cuda.device_count() | ||||||
| gpu_memory = torch.cuda.get_device_properties(0).total_memory / (1024**3) | ||||||
|
|
||||||
| # Inputs: | ||||||
| # - Simple: {"message": "...", "system_prompt": "...", "max_new_tokens": 512} | ||||||
| # - Full chat: {"messages": [{"role": "...", "content": "..."}, ...], "max_new_tokens": 512} | ||||||
| system_prompt = input_data.get( | ||||||
| "system_prompt", | ||||||
| "You are a helpful assistant chatbot who always responds in a friendly and helpful manner!", | ||||||
| ) | ||||||
| message = input_data.get("message", "What is gpu?") | ||||||
| messages = input_data.get("messages") or [ | ||||||
| {"role": "system", "content": system_prompt}, | ||||||
| {"role": "user", "content": message}, | ||||||
| ] | ||||||
|
|
||||||
| model_id = "meta-llama/Llama-3.2-1B-Instruct" | ||||||
|
|
||||||
| # Hugging Face auth for gated repos: | ||||||
| hf_token = os.getenv("HF_TOKEN") | ||||||
| if not hf_token: | ||||||
| raise RuntimeError("HF_TOKEN is required to download gated models (e.g. meta-llama/*).") | ||||||
|
|
||||||
| pipe = pipeline( | ||||||
| "text-generation", | ||||||
| model=model_id, | ||||||
| torch_dtype=torch.bfloat16, | ||||||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [🟠 High] [🔵 Bug] The worker is explicitly schedulable on any GPU, but the pipeline forces # 02_ml_inference/01_text_generation/gpu_worker.py
gpu_config = LiveServerless(
gpus=[GpuGroup.ANY], # Run on any GPU
)
...
torch_dtype=torch.bfloat16,
Suggested change
|
||||||
| device_map="auto", | ||||||
| token=hf_token, | ||||||
| ) | ||||||
|
|
||||||
| outputs = pipe( | ||||||
| messages, | ||||||
| max_new_tokens=int(input_data.get("max_new_tokens", 512)), | ||||||
| ) | ||||||
| generated = outputs[0]["generated_text"] | ||||||
| last = generated[-1] if isinstance(generated, list) and generated else generated | ||||||
| assistant_message = last.get("content") if isinstance(last, dict) else str(last) | ||||||
| print(assistant_message) | ||||||
|
|
||||||
| return { | ||||||
| "status": "success", | ||||||
| "message": assistant_message, | ||||||
| "worker_type": "GPU", | ||||||
| "gpu_info": { | ||||||
| "available": gpu_available, | ||||||
| "name": gpu_name, | ||||||
| "count": gpu_count, | ||||||
| "memory_gb": round( | ||||||
| gpu_memory, | ||||||
| 2, | ||||||
| ), | ||||||
| }, | ||||||
| "timestamp": datetime.now().isoformat(), | ||||||
| "platform": platform.system(), | ||||||
| "python_version": platform.python_version(), | ||||||
| } | ||||||
|
|
||||||
|
|
||||||
| # We define a subrouter for our gpu worker so that our main router in `main.py` | ||||||
| # can attach it for routing gpu-specific requests. | ||||||
| gpu_router = APIRouter() | ||||||
|
|
||||||
|
|
||||||
| class MessageRequest(BaseModel): | ||||||
| """Request model for GPU worker.""" | ||||||
|
|
||||||
| message: str = "What is gpu?" | ||||||
| system_prompt: str = ( | ||||||
| "You are a helpful assistant chatbot who always responds in a friendly and helpful manner!" | ||||||
| ) | ||||||
| max_new_tokens: int = 512 | ||||||
|
|
||||||
|
|
||||||
| @gpu_router.post("/llm") | ||||||
| async def llm( | ||||||
| request: MessageRequest, | ||||||
| ): | ||||||
| """Simple GPU worker endpoint.""" | ||||||
| result = await gpu_hello( | ||||||
| { | ||||||
| "message": request.message, | ||||||
| "system_prompt": request.system_prompt, | ||||||
| "max_new_tokens": request.max_new_tokens, | ||||||
| } | ||||||
| ) | ||||||
| return result | ||||||
|
|
||||||
|
|
||||||
| # This code is packaged up as a "worker" that will handle requests sent to the | ||||||
| # endpoint at /gpu/llm, but you can also trigger it directly by running | ||||||
| # python -m workers.gpu.endpoint | ||||||
| if __name__ == "__main__": | ||||||
| import asyncio | ||||||
|
|
||||||
| test_payload = {"message": "Testing GPU worker"} | ||||||
| print(f"Testing GPU worker with payload: {test_payload}") | ||||||
| result = asyncio.run(gpu_hello(test_payload)) | ||||||
| print(f"Result: {result}") | ||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,69 @@ | ||
| ## LLM demo: FastAPI router + serverless GPU worker | ||
| # This example exposes a simple local FastAPI app (this file) with a single LLM endpoint | ||
| # backed by a Runpod serverless GPU worker defined in `gpu_worker.py`. | ||
| # | ||
| # - Local API: runs on your machine via `flash run` (default: http://localhost:8888) | ||
| # - Remote compute: executed on Runpod serverless GPUs via `tetra_rp.remote` | ||
| # | ||
| # Main endpoint: | ||
| # - POST /gpu/llm -> runs Llama chat inference on the remote GPU worker | ||
| # | ||
| # Note: The Llama model used in the worker is gated on Hugging Face, so you must provide | ||
| # `HF_TOKEN` (the worker reads it from the serverless env). | ||
|
|
||
| import logging | ||
| import os | ||
|
|
||
| from fastapi import FastAPI | ||
| from gpu_worker import gpu_router | ||
|
|
||
| logger = logging.getLogger(__name__) | ||
|
|
||
| # We define a simple FastAPI app to serve requests from localhost. | ||
| app = FastAPI( | ||
| title="Flash Application", | ||
| description="Distributed GPU computing with Runpod Flash", | ||
| version="0.1.0", | ||
| ) | ||
|
|
||
| # Attach gpu worker subrouters - this will route any requests to our | ||
| # app with the prefix /gpu to the gpu subrouter. To see the subrouter in action, | ||
| # start the app and execute the following command in another terminal window: | ||
| # curl -X POST http://localhost:8888/gpu/llm -d '{"message": "hello"}' -H "Content-Type: application/json" | ||
| app.include_router( | ||
| gpu_router, | ||
| prefix="/gpu", | ||
| tags=["GPU Workers"], | ||
| ) | ||
|
|
||
|
|
||
| # The homepage for our main endpoint will just return a plaintext json object | ||
| # containing the endpoints defined in this app. | ||
| @app.get("/") | ||
| def home(): | ||
| return { | ||
| "message": "Flash Application", | ||
| "docs": "/docs", | ||
| "endpoints": { | ||
| "gpu_hello": "/gpu/llm", | ||
| }, | ||
| } | ||
|
|
||
|
|
||
| @app.get("/ping") | ||
| def ping(): | ||
| return {"status": "healthy"} | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| import uvicorn | ||
|
|
||
| host = os.getenv("FLASH_HOST", "localhost") | ||
| port = int(os.getenv("FLASH_PORT", 8888)) | ||
| logger.info(f"Starting Flash server on {host}:{port}") | ||
|
|
||
| uvicorn.run( | ||
| app, | ||
| host=host, | ||
| port=port, | ||
| ) |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,4 @@ | ||
| tetra_rp | ||
| torch | ||
| transformers | ||
| accelerate |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[🟠 High] [🟡 Investigate]
This new example imports
tetra_rpdirectly, but the repository’s documented/base install contract is centered onrunpod-flash(uv syncfrom the repo root installs that dependency set), so a clean environment may hitModuleNotFoundErrorwhenflash runimports this module. Verify by running the repo quick-start flow in a fresh venv; if it fails, migrate this example to the supportedrunpod_flashAPI or add/document project-level dependency installation fortetra_rp.