laserbeam2045 committed
Commit 7f80d8c · 1 Parent(s): b2b7327
Files changed (2)
  1. app.py +71 -41
  2. requirements.txt +4 -8
app.py CHANGED
@@ -1,47 +1,77 @@
  import os
  import torch
- from fastapi import FastAPI
- from transformers import AutoTokenizer, AutoModelForCausalLM
  from pydantic import BaseModel
- import logging
-
- logging.basicConfig(level=logging.INFO)
- logger = logging.getLogger(__name__)
-
- app = FastAPI()
-
- model_name = "google/gemma-2-2b-it"
- tokenizer = None
- model = None
-
- try:
-     logger.info(f"Loading model: {model_name}")
-     tokenizer = AutoTokenizer.from_pretrained(model_name, token=os.getenv("HF_TOKEN"))
-     model = AutoModelForCausalLM.from_pretrained(
-         model_name,
-         torch_dtype=torch.float16,  # reduce memory usage
-         device_map="cpu",  # no GPU available
-         token=os.getenv("HF_TOKEN"),
-         low_cpu_mem_usage=True
-     )
-     logger.info("Model loaded successfully")
- except Exception as e:
-     logger.error(f"Model load error: {e}")
-     raise
 
- class TextInput(BaseModel):
-     text: str
-     max_length: int = 50
 
  @app.post("/generate")
- async def generate_text(input: TextInput):
-     try:
-         logger.info(f"Generating text for input: {input.text}")
-         inputs = tokenizer(input.text, return_tensors="pt", max_length=512, truncation=True).to("cpu")
-         outputs = model.generate(**inputs, max_length=input.max_length)
-         result = tokenizer.decode(outputs[0], skip_special_tokens=True)
-         logger.info(f"Generated text: {result}")
-         return {"generated_text": result}
-     except Exception as e:
-         logger.error(f"Generation error: {e}")
-         return {"error": str(e)}
 
+ # app.py
  import os
  import torch
+ from fastapi import FastAPI, HTTPException
  from pydantic import BaseModel
+ from transformers import AutoTokenizer, AutoModelForCausalLM
+
+ # -----------------------------------------------------------------------------
+ # Configuration
+ # -----------------------------------------------------------------------------
+ MODEL_ID = "google/gemma-3-4b-it"
+ # Set the HF_TOKEN environment variable if a Hugging Face token is required
+ HF_TOKEN = os.environ.get("HF_TOKEN", None)
+
+ # -----------------------------------------------------------------------------
+ # Device setup (only CPU is available on the free Spaces tier)
+ # -----------------------------------------------------------------------------
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
+ # -----------------------------------------------------------------------------
+ # Load the tokenizer and model
+ # -----------------------------------------------------------------------------
+ tokenizer = AutoTokenizer.from_pretrained(
+     MODEL_ID,
+     use_auth_token=HF_TOKEN,
+     trust_remote_code=True
+ )
+
+ model = AutoModelForCausalLM.from_pretrained(
+     MODEL_ID,
+     use_auth_token=HF_TOKEN,
+     torch_dtype=torch.float32,  # float32 on CPU-only environments
+     device_map="auto" if torch.cuda.is_available() else None
+ )
+ model.to(device)
+
+ # -----------------------------------------------------------------------------
+ # FastAPI app
+ # -----------------------------------------------------------------------------
+ app = FastAPI(title="Gemma3-4B-IT API")
+
+ class GenerationRequest(BaseModel):
+     prompt: str
+     max_new_tokens: int = 128
+     temperature: float = 0.8
+     top_p: float = 0.95
 
  @app.post("/generate")
+ async def generate(req: GenerationRequest):
+     if not req.prompt:
+         raise HTTPException(status_code=400, detail="prompt is required.")
+     # Tokenize the prompt
+     inputs = tokenizer(
+         req.prompt,
+         return_tensors="pt",
+         padding=True,
+         truncation=True,
+     ).to(device)
+     # Generate
+     generation_output = model.generate(
+         **inputs,
+         max_new_tokens=req.max_new_tokens,
+         temperature=req.temperature,
+         top_p=req.top_p,
+         do_sample=True,
+         pad_token_id=tokenizer.eos_token_id
+     )
+     text = tokenizer.decode(generation_output[0], skip_special_tokens=True)
+     return {"generated_text": text}
+
+ # -----------------------------------------------------------------------------
+ # For running locally
+ # -----------------------------------------------------------------------------
+ if __name__ == "__main__":
+     import uvicorn
+     port = int(os.environ.get("PORT", 8000))
+     uvicorn.run("app:app", host="0.0.0.0", port=port, log_level="info")
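
With this version of app.py running (for example via python app.py locally, or as the Space's own server), the /generate endpoint accepts a JSON body matching GenerationRequest. A minimal client sketch follows, assuming the API is reachable at http://localhost:8000 and that the requests package is installed; the prompt text and parameter values are placeholders, not part of the commit:

    # client sketch: call the /generate endpoint defined in app.py above
    import requests

    payload = {
        "prompt": "Explain what a tokenizer does, in one sentence.",
        "max_new_tokens": 64,
        "temperature": 0.8,
        "top_p": 0.95,
    }
    # generation on CPU can be slow, so allow a generous timeout
    resp = requests.post("http://localhost:8000/generate", json=payload, timeout=300)
    resp.raise_for_status()
    print(resp.json()["generated_text"])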
requirements.txt CHANGED
@@ -1,9 +1,5 @@
  huggingface_hub==0.25.2
- torch==2.1.0
- numpy<2.0
- transformers==4.44.2
- bitsandbytes==0.42.0
- accelerate==0.26.1
- fastapi==0.115.0
- uvicorn==0.30.6
- gradio==4.15.0
 
  huggingface_hub==0.25.2
+ fastapi
+ uvicorn[standard]
+ transformers>=4.50.0.dev0
+ torch
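
One note on the transformers pin: requirements.txt now asks for a development version (>=4.50.0.dev0), the line that introduced Gemma 3 support. A small startup guard is one way to surface a version mismatch early instead of failing inside from_pretrained. This is a sketch under that assumption, not part of the commit; the MIN_MAJOR/MIN_MINOR names are introduced here only for illustration:

    # sketch: fail fast if the installed transformers predates Gemma 3 support (>= 4.50)
    import transformers

    MIN_MAJOR, MIN_MINOR = 4, 50
    major, minor = (int(p) for p in transformers.__version__.split(".")[:2])
    if (major, minor) < (MIN_MAJOR, MIN_MINOR):
        raise RuntimeError(
            f"transformers {transformers.__version__} is too old for google/gemma-3-4b-it; "
            "install a 4.50+ build."
        )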