MATRIX / app.py
laserbeam2045
fix
6dd176e
raw
history blame
2.96 kB
# app.py
import os
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from huggingface_hub import hf_hub_download
from pyllamacpp.model import Model
# -----------------------------------------------------------------------------
# Hugging Face Hub の設定
# -----------------------------------------------------------------------------
HF_TOKEN = os.environ.get("HF_TOKEN") # 必要に応じて Secrets にセット
REPO_ID = "google/gemma-3-12b-it-qat-q4_0-gguf"
# 実際にリポジトリに置かれている GGUF ファイル名を確認してください。
# 例: "gemma-3-12b-it-qat-q4_0-gguf.gguf"
GGUF_FILENAME = "gemma-3-12b-it-qat-q4_0-gguf.gguf"
# キャッシュ先のパス(リポジトリ直下に置く場合)
MODEL_PATH = os.path.join(os.getcwd(), GGUF_FILENAME)
# -----------------------------------------------------------------------------
# 起動時に一度だけダウンロード
# -----------------------------------------------------------------------------
if not os.path.exists(MODEL_PATH):
print(f"Downloading {GGUF_FILENAME} from {REPO_ID} …")
hf_hub_download(
repo_id=REPO_ID,
filename=GGUF_FILENAME,
token=HF_TOKEN,
repo_type="model", # 明示的にモデルリポジトリを指定
local_dir=os.getcwd(), # カレントディレクトリに保存
local_dir_use_symlinks=False
)
# -----------------------------------------------------------------------------
# llama.cpp (pyllamacpp) で 4bit GGUF モデルをロード
# -----------------------------------------------------------------------------
llm = Model(
model_path=MODEL_PATH,
n_ctx=512, # 必要に応じて調整
n_threads=4, # 実マシンのコア数に合わせて
)
# -----------------------------------------------------------------------------
# FastAPI 定義
# -----------------------------------------------------------------------------
app = FastAPI(title="Gemma3-12B-IT Q4_0 GGUF API")
class GenerationRequest(BaseModel):
prompt: str
max_new_tokens: int = 128
temperature: float = 0.8
top_p: float = 0.95
@app.post("/generate")
async def generate(req: GenerationRequest):
if not req.prompt:
raise HTTPException(status_code=400, detail="`prompt` は必須です。")
# llama.cpp の generate を呼び出し
text = llm.generate(
req.prompt,
top_p=req.top_p,
temp=req.temperature,
n_predict=req.max_new_tokens,
repeat_last_n=64,
repeat_penalty=1.1
)
return {"generated_text": text}
# -----------------------------------------------------------------------------
# ローカル起動用
# -----------------------------------------------------------------------------
if __name__ == "__main__":
import uvicorn
port = int(os.environ.get("PORT", 8000))
uvicorn.run("app:app", host="0.0.0.0", port=port, log_level="info")