Mr-Help committed on
Commit
6867f65
·
verified ·
1 Parent(s): 16cb566

Update main.py

Files changed (1)
  1. main.py +51 -20
main.py CHANGED
@@ -1,31 +1,61 @@
+import os
 import torch
+from fastapi import FastAPI
+from pydantic import BaseModel
 from transformers import AutoModelForCausalLM, AutoTokenizer
 
-MODEL_NAME = "Qwen/Qwen2.5-1.5B-Instruct"
+MODEL_NAME = os.getenv("MODEL_NAME", "Qwen/Qwen2.5-1.5B-Instruct")
 
+app = FastAPI(title="Qwen FastAPI")
+
+tokenizer = None
+model = None
+
+
+class GenerateRequest(BaseModel):
+    system_prompt: str
+    user_prompt: str
+    max_new_tokens: int = 400
+    temperature: float = 0.7
+    top_p: float = 0.9
+    do_sample: bool = True
+
+
+@app.on_event("startup")
+def startup_event():
+    global tokenizer, model
 
-def main():
     # Load tokenizer
     tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
 
-    # Pick a suitable dtype: bfloat16 if a GPU is available, otherwise float32 on CPU
+    # dtype: bfloat16 on CUDA, float32 on CPU
     has_cuda = torch.cuda.is_available()
     dtype = torch.bfloat16 if has_cuda else torch.float32
 
-    # Load model (device_map="auto" handles placement automatically)
+    # Load model (auto device placement)
     model = AutoModelForCausalLM.from_pretrained(
         MODEL_NAME,
         torch_dtype=dtype,
         device_map="auto"
     )
 
-    # Prompt: explain Past Simple in simple English
+    print("Model ready")  # ✅ as requested
+
+
+@app.get("/health")
+def health():
+    return {"status": "ok", "model": MODEL_NAME}
+
+
+@app.post("/generate")
+def generate(req: GenerateRequest):
+    global tokenizer, model
+
     messages = [
-        {"role": "system", "content": "You are a friendly English teacher. Explain clearly and simply."},
-        {"role": "user", "content": "Explain the Past Simple tense in very simple English. Give rules and 8 short examples. Keep it clear for A2 learners."}
+        {"role": "system", "content": req.system_prompt},
+        {"role": "user", "content": req.user_prompt}
     ]
 
-    # Convert chat messages to model input
     text = tokenizer.apply_chat_template(
         messages,
         tokenize=False,
@@ -34,23 +64,24 @@ def main():
 
     model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
 
-    # Generate
+    print("\n=== Incoming Request ===")
+    print("SYSTEM:", req.system_prompt)
+    print("USER:", req.user_prompt)
+
     with torch.no_grad():
         generated_ids = model.generate(
            **model_inputs,
-            max_new_tokens=400,
-            do_sample=True,
-            temperature=0.7,
-            top_p=0.9
+            max_new_tokens=req.max_new_tokens,
+            do_sample=req.do_sample,
+            temperature=req.temperature,
+            top_p=req.top_p,
        )
 
-    # Keep only the newly generated tokens (remove the prompt tokens)
     new_tokens = generated_ids[0, model_inputs["input_ids"].shape[-1]:]
-    response = tokenizer.decode(new_tokens, skip_special_tokens=True)
+    response = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
 
-    print("\n=== Model Response ===\n")
-    print(response.strip())
-    print("\n======================\n")
+    print("\n=== Model Response ===")
+    print(response)
+    print("======================\n")
 
-if __name__ == "__main__":
-    main()
+    return {"response": response}
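For reference, a minimal client sketch (not part of this commit) that exercises the new routes. It assumes the app is served locally, for example with uvicorn main:app --host 0.0.0.0 --port 8000, and that the requests package is installed; the base URL, port, and prompts below are placeholders, and omitted payload fields fall back to the Pydantic defaults defined in GenerateRequest.

import requests

BASE_URL = "http://localhost:8000"  # assumption: local uvicorn serving main:app

# Health check: /health returns the configured model name
print(requests.get(f"{BASE_URL}/health").json())

# /generate payload mirroring the GenerateRequest schema in main.py
payload = {
    "system_prompt": "You are a friendly English teacher. Explain clearly and simply.",
    "user_prompt": "Explain the Past Simple tense in very simple English.",
    "max_new_tokens": 200,
}
print(requests.post(f"{BASE_URL}/generate", json=payload).json()["response"])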