dbmoradi60 committed
Commit 7265081 · verified · 1 Parent(s): b7e4714

Create app.py

Files changed (1):
  1. app.py +77 -0
app.py ADDED
@@ -0,0 +1,77 @@
import os

# Enable hf_transfer for faster model downloads (must be set before transformers is imported)
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

app = FastAPI(title="GPT-OSS-20B API")

# Model ID
MODEL_ID = "openai/gpt-oss-20b"

# Load tokenizer
print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

# Load model with CPU offloading
print("Loading model (this may take several minutes)...")
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    device_map="auto",                      # Automatically place layers on available devices
    torch_dtype="auto",                     # Automatic precision
    offload_folder="./offload",             # Offload weights to disk
    max_memory={0: "15GB", "cpu": "30GB"},  # Memory constraints
    trust_remote_code=True                  # Required for custom model code
)
print(f"Model loaded on: {model.device}")
print(f"Model dtype: {model.dtype}")

# Enable gradient checkpointing to reduce memory usage
model.gradient_checkpointing_enable()

class ChatRequest(BaseModel):
    message: str
    max_tokens: int = 256
    temperature: float = 0.7

@app.post("/chat")
async def chat_endpoint(request: ChatRequest):
    try:
        # Prepare input
        messages = [{"role": "user", "content": request.message}]
        inputs = tokenizer.apply_chat_template(
            messages,
            add_generation_prompt=True,
            return_tensors="pt",
            return_dict=True
        ).to(model.device)  # match the device the model was dispatched to

        # Generate response
        with torch.no_grad():
            generated = model.generate(
                **inputs,
                max_new_tokens=request.max_tokens,
                temperature=request.temperature,
                do_sample=True,
                pad_token_id=tokenizer.eos_token_id,
                repetition_penalty=1.1
            )

        # Decode response (only the newly generated tokens)
        response = tokenizer.decode(
            generated[0][inputs["input_ids"].shape[-1]:],
            skip_special_tokens=True
        )
        return {"response": response}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

# Free any cached GPU memory after startup (no-op on CPU-only hosts)
torch.cuda.empty_cache()

if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)
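Once the app is running, the /chat endpoint can be exercised with a short client script. The sketch below is illustrative and not part of the commit: it assumes the server is reachable at http://localhost:8000, that the requests package is installed, and the prompt text is a made-up example.

# client.py: minimal sketch of a call to the /chat endpoint defined above.
# Assumptions (not from the commit): server at http://localhost:8000,
# `requests` installed, and an illustrative prompt.
import requests

payload = {
    "message": "Say hello in one sentence.",  # ChatRequest.message
    "max_tokens": 64,                         # ChatRequest.max_tokens (default 256)
    "temperature": 0.7,                       # ChatRequest.temperature (default 0.7)
}

resp = requests.post("http://localhost:8000/chat", json=payload, timeout=600)
resp.raise_for_status()
print(resp.json()["response"])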