reach-vb committed
Commit fe65e71 · verified · 1 Parent(s): 30a3eaa
Files changed (1)
  1. app.py +117 -0
app.py ADDED
@@ -0,0 +1,117 @@
+ import os
+ import time
+ from typing import List, Dict
+
+ import gradio as gr
+ from transformers import pipeline
+ import spaces
+
+ # === Config (override via Space secrets/env vars) ===
+ MODEL_ID = os.environ.get("MODEL_ID", "tlhv/osb-minier")
+ DEFAULT_MAX_NEW_TOKENS = int(os.environ.get("MAX_NEW_TOKENS", 512))
+ DEFAULT_TEMPERATURE = float(os.environ.get("TEMPERATURE", 0.7))
+ DEFAULT_TOP_P = float(os.environ.get("TOP_P", 0.95))
+ DEFAULT_REPETITION_PENALTY = float(os.environ.get("REPETITION_PENALTY", 1.0))
+ ZGPU_DURATION = int(os.environ.get("ZGPU_DURATION", 120))  # seconds
+
+ # Cached pipeline (created after GPU is granted)
+ _pipe = None
+
+
+ def _to_messages(user_prompt: str) -> List[Dict[str, str]]:
+     # The provided model expects chat-style messages
+     return [{"role": "user", "content": user_prompt}]
+
+
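+ # ZeroGPU grants a GPU only while the decorated function runs;
+ # `duration` is the per-call allocation cap in seconds.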
+ @spaces.GPU(duration=ZGPU_DURATION)
+ def generate_long_prompt(
+     prompt: str,
+     max_new_tokens: int,
+     temperature: float,
+     top_p: float,
+     repetition_penalty: float,
+ ):
+     """Runs on a ZeroGPU-allocated GPU thanks to the decorator above."""
+     global _pipe
+     start = time.time()
+
+     # Create the pipeline lazily once the GPU is available
+     if _pipe is None:
+         _pipe = pipeline(
+             "text-generation",
+             model=MODEL_ID,
+             torch_dtype="auto",
+             device_map="auto",  # let HF accelerate map to the GPU we just got
+         )
+
+     messages = _to_messages(prompt)
+
+     outputs = _pipe(
+         messages,
+         max_new_tokens=max_new_tokens,
+         do_sample=True,
+         temperature=temperature,
+         top_p=top_p,
+         repetition_penalty=repetition_penalty,
+     )
+
+     # Robust extraction for different pipeline return shapes
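+     # Chat-style pipelines typically return the full message list, e.g.
+     #   [{"generated_text": [{"role": "user", ...}, {"role": "assistant", "content": "..."}]}]
+     # while plain text-generation returns [{"generated_text": "..."}].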
+     text = None
+     if isinstance(outputs, list) and outputs:
+         res = outputs[0]
+         if isinstance(res, dict):
+             gt = res.get("generated_text")
+             if isinstance(gt, list) and gt and isinstance(gt[-1], dict):
+                 text = gt[-1].get("content") or gt[-1].get("text")
+             elif isinstance(gt, str):
+                 text = gt
+         if text is None:
+             text = str(res)
+     else:
+         text = str(outputs)
+
+     elapsed = time.time() - start
+     meta = f"Model: {MODEL_ID} | Time: {elapsed:.1f}s | max_new_tokens={max_new_tokens}"
+     return text, meta
+
+
+ with gr.Blocks(css=".wrap textarea {font-family: ui-monospace, SFMono-Regular, Menlo, Monaco, Consolas, 'Liberation Mono', 'Courier New', monospace;}") as demo:
+     gr.Markdown("# ZeroGPU: Long-Prompt Text Generation\nPaste a long prompt and generate text with a Transformers model. Set `MODEL_ID` in Space secrets to switch models.")
+
+     with gr.Row():
+         with gr.Column():
+             prompt = gr.Textbox(
+                 label="Prompt",
+                 lines=20,
+                 placeholder="Paste a long prompt here…",
+                 elem_classes="wrap",  # matches the `.wrap textarea` CSS rule above
+             )
+             with gr.Accordion("Advanced settings", open=False):
+                 max_new_tokens = gr.Slider(16, 4096, value=DEFAULT_MAX_NEW_TOKENS, step=8, label="max_new_tokens")
+                 temperature = gr.Slider(0.0, 1.5, value=DEFAULT_TEMPERATURE, step=0.05, label="temperature")
+                 top_p = gr.Slider(0.0, 1.0, value=DEFAULT_TOP_P, step=0.01, label="top_p")
+                 repetition_penalty = gr.Slider(0.8, 2.0, value=DEFAULT_REPETITION_PENALTY, step=0.05, label="repetition_penalty")
+             generate = gr.Button("Generate", variant="primary")
+         with gr.Column():
+             output = gr.Textbox(label="Output", lines=20)
+             meta = gr.Markdown()
+
+     generate.click(
+         fn=generate_long_prompt,
+         inputs=[prompt, max_new_tokens, temperature, top_p, repetition_penalty],
+         outputs=[output, meta],
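+         # Run one generation at a time: each call holds a ZeroGPU allocation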
+         concurrency_limit=1,
+         api_name="generate",
+     )
+
+     gr.Examples(
+         examples=[
+             ["Summarize the following 3 pages of notes into a crisp plan of action…"],
+             ["Write a 1200-word blog post about the history of transformers and attention…"],
+         ],
+         inputs=[prompt],
+     )
+
+ # Important for ZeroGPU: use a queue so calls are serialized & resumable
+ if __name__ == "__main__":
+     demo.queue(max_size=32).launch()  # per-event concurrency is set on generate.click above
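+
+ # A minimal client-side sketch (not part of the app): because the click handler
+ # is exposed with api_name="generate", the endpoint can also be called
+ # programmatically once the Space is live. "<user>/<space>" is a placeholder
+ # for the published Space id.
+ #
+ #   from gradio_client import Client
+ #
+ #   client = Client("<user>/<space>")
+ #   text, meta = client.predict(
+ #       "Paste a long prompt here…",  # prompt
+ #       512,                          # max_new_tokens
+ #       0.7,                          # temperature
+ #       0.95,                         # top_p
+ #       1.0,                          # repetition_penalty
+ #       api_name="/generate",
+ #   )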