iamthewalrus67 committed on
Commit
d5dc5cf
·
1 Parent(s): f52933f

Try to make ZeroGPU work

Browse files
Files changed (1) hide show
  1. app.py +19 -11
app.py CHANGED
@@ -10,19 +10,23 @@ from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStream
10
 
11
  MODEL_ID = "le-llm/gemma-3-12b-it-reasoning"
12
 
13
- # Load model & tokenizer
14
- device = "cuda" if torch.cuda.is_available() else "cpu"
15
- tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
16
- model = AutoModelForCausalLM.from_pretrained(
17
- MODEL_ID,
18
- torch_dtype=torch.float16 if device == "cuda" else torch.float32,
19
- device_map="auto" if device == "cuda" else None, # helps if multiple GPUs
20
- )
21
-
22
  SYSTEM_PROMPT = (
23
  "You are a helpful, concise assistant. Only write replies as the Assistant. Do not invent or continue User messages."
24
  )
25
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
  def respond(
27
  message,
28
  history: list[dict[str, str]],
@@ -31,6 +35,9 @@ def respond(
31
  temperature,
32
  top_p,
33
  ):
 
 
 
34
  # Build conversation
35
  messages = [{"role": "system", "content": system_message}] + history + [
36
  {"role": "user", "content": message}
@@ -67,7 +74,7 @@ def respond(
67
  partial_output = ""
68
  for new_text in streamer:
69
  partial_output += new_text
70
- yield partial_output # <- streams to Gradio frontend
71
 
72
 
73
  chatbot = gr.ChatInterface(
@@ -87,4 +94,5 @@ chatbot = gr.ChatInterface(
87
  ],
88
  )
89
 
90
- chatbot.launch()
 
 
10
 
11
  MODEL_ID = "le-llm/gemma-3-12b-it-reasoning"
12
 
 
 
 
 
 
 
 
 
 
13
  SYSTEM_PROMPT = (
14
  "You are a helpful, concise assistant. Only write replies as the Assistant. Do not invent or continue User messages."
15
  )
16
 
17
+
18
+ def load_model():
19
+ """Lazy-load model & tokenizer (for zeroGPU)."""
20
+ device = "cuda"# if torch.cuda.is_available() else "cpu"
21
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
22
+ model = AutoModelForCausalLM.from_pretrained(
23
+ MODEL_ID,
24
+ torch_dtype=torch.bfloat16 if device == "cuda" else torch.float32,
25
+ device_map="auto" if device == "cuda" else None,
26
+ )
27
+ return model, tokenizer, device
28
+
29
+
30
  def respond(
31
  message,
32
  history: list[dict[str, str]],
 
35
  temperature,
36
  top_p,
37
  ):
38
+ # Load model/tokenizer each request → allows zeroGPU to cold start & then release
39
+ model, tokenizer, device = load_model()
40
+
41
  # Build conversation
42
  messages = [{"role": "system", "content": system_message}] + history + [
43
  {"role": "user", "content": message}
 
74
  partial_output = ""
75
  for new_text in streamer:
76
  partial_output += new_text
77
+ yield partial_output
78
 
79
 
80
  chatbot = gr.ChatInterface(
 
94
  ],
95
  )
96
 
97
def main() -> None:
    """Start the Gradio chat UI."""
    chatbot.launch()


# Launch only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()