# Scraped from the Hugging Face file viewer (page chrome, not code):
# author: reach-vb · commit 18c7b07 "Create app.py" (verified) · 4.22 kB
import os
import time
from typing import List, Dict
import gradio as gr
from transformers import pipeline
import spaces
# === Config (override via Space secrets/env vars) ===
def _env(name, default, cast):
    """Return env var *name* (or *default* when unset) converted with *cast*."""
    return cast(os.environ.get(name, default))


MODEL_ID = os.environ.get("MODEL_ID", "tlhv/osb-minier")
DEFAULT_MAX_NEW_TOKENS = _env("MAX_NEW_TOKENS", 512, int)
DEFAULT_TEMPERATURE = _env("TEMPERATURE", 0.7, float)
DEFAULT_TOP_P = _env("TOP_P", 0.95, float)
DEFAULT_REPETITION_PENALTY = _env("REPETITION_PENALTY", 1.0, float)
ZGPU_DURATION = _env("ZGPU_DURATION", 120, int)  # seconds; used as the @spaces.GPU duration

# Text-generation pipeline; built lazily on the first request (after a GPU has
# been granted) and intended to be reused by later requests.
_pipe = None
def _to_messages(user_prompt: str) -> List[Dict[str, str]]:
# The provided model expects chat-style messages
return [{"role": "user", "content": user_prompt}]
def _extract_text(outputs) -> str:
    """Pull the generated text out of a text-generation pipeline result.

    Handles the chat-style shape (``generated_text`` is a list of message dicts
    whose last entry is the model's reply) and the plain-string shape. Anything
    unrecognized is stringified so the UI still shows something useful.
    """
    if not (isinstance(outputs, list) and outputs):
        return str(outputs)
    first = outputs[0]
    if isinstance(first, dict):
        generated = first.get("generated_text")
        if isinstance(generated, str):
            return generated
        if isinstance(generated, list) and generated and isinstance(generated[-1], dict):
            last = generated[-1]
            # Check each key explicitly: an empty reply ("") is a valid result and
            # must not fall through to dumping the raw pipeline output.
            for key in ("content", "text"):
                value = last.get(key)
                if isinstance(value, str):
                    return value
    return str(first)


@spaces.GPU(duration=ZGPU_DURATION)
def generate_long_prompt(
    prompt: str,
    max_new_tokens: int,
    temperature: float,
    top_p: float,
    repetition_penalty: float,
):
    """Generate a completion for *prompt* on a ZeroGPU-allocated GPU.

    Args:
        prompt: User text, sent to the model as a single chat "user" turn.
        max_new_tokens: Upper bound on generated tokens (cast to ``int`` because
            UI sliders may deliver floats).
        temperature: Sampling temperature. Values <= 0 switch to greedy
            decoding, because transformers rejects a non-positive temperature
            when sampling is enabled.
        top_p: Nucleus-sampling threshold (not used for greedy decoding).
        repetition_penalty: Passed through to generation unchanged.

    Returns:
        ``(text, meta)``: the generated text and a one-line Markdown summary
        with the model id, elapsed wall-clock time and the token budget.
    """
    global _pipe
    start = time.perf_counter()

    # Nothing to do for an empty prompt; avoid loading/running the model.
    if not prompt or not prompt.strip():
        return "", "Prompt is empty — nothing to generate."

    # Create the pipeline lazily once the GPU is available.
    # NOTE(review): ZeroGPU may execute this function in a separate worker
    # process; if so, this module-level cache would not persist between calls
    # and the model would be reloaded every time. Confirm, and consider loading
    # the pipeline at import time instead.
    if _pipe is None:
        _pipe = pipeline(
            "text-generation",
            model=MODEL_ID,
            torch_dtype="auto",
            device_map="auto",  # let HF accelerate map to the GPU we just got
        )

    max_new_tokens = int(max_new_tokens)
    gen_kwargs = {
        "max_new_tokens": max_new_tokens,
        "repetition_penalty": repetition_penalty,
    }
    if temperature > 0:
        gen_kwargs.update(do_sample=True, temperature=temperature, top_p=top_p)
    else:
        # The UI slider allows 0.0; sampling with temperature <= 0 raises a
        # ValueError in transformers, so treat it as deterministic decoding.
        gen_kwargs["do_sample"] = False

    outputs = _pipe(_to_messages(prompt), **gen_kwargs)
    text = _extract_text(outputs)

    elapsed = time.perf_counter() - start
    meta = f"Model: {MODEL_ID} | Time: {elapsed:.1f}s | max_new_tokens={max_new_tokens}"
    return text, meta
# Monospace font for the prompt box. The prompt Textbox is tagged with
# elem_id="wrap", so the rule needs an id selector (#wrap), not a class selector.
with gr.Blocks(css="#wrap textarea {font-family: ui-monospace, SFMono-Regular, Menlo, Monaco, Consolas, 'Liberation Mono', 'Courier New', monospace;}") as demo:
    gr.Markdown("# ZeroGPU: Long-Prompt Text Generation\nPaste a long prompt and generate text with a Transformers model. Set `MODEL_ID` in Space secrets to switch models.")
    with gr.Row():
        # Left column: prompt input, generation settings and the trigger button.
        with gr.Column():
            prompt = gr.Textbox(
                label="Prompt",
                lines=20,
                placeholder="Paste a long prompt here…",
                elem_id="wrap",
            )
            with gr.Accordion("Advanced settings", open=False):
                max_new_tokens = gr.Slider(16, 4096, value=DEFAULT_MAX_NEW_TOKENS, step=8, label="max_new_tokens")
                temperature = gr.Slider(0.0, 1.5, value=DEFAULT_TEMPERATURE, step=0.05, label="temperature")
                top_p = gr.Slider(0.0, 1.0, value=DEFAULT_TOP_P, step=0.01, label="top_p")
                repetition_penalty = gr.Slider(0.8, 2.0, value=DEFAULT_REPETITION_PENALTY, step=0.05, label="repetition_penalty")
            generate = gr.Button("Generate", variant="primary")
        # Right column: generated text plus the timing/metadata line.
        with gr.Column():
            output = gr.Textbox(label="Output", lines=20)
            meta = gr.Markdown()

    # One generation at a time for this event; also exposed as the "generate" API endpoint.
    generate.click(
        fn=generate_long_prompt,
        inputs=[prompt, max_new_tokens, temperature, top_p, repetition_penalty],
        outputs=[output, meta],
        concurrency_limit=1,
        api_name="generate",
    )

    gr.Examples(
        examples=[
            ["Summarize the following 3 pages of notes into a crisp plan of action…"],
            ["Write a 1200-word blog post about the history of transformers and attention…"],
        ],
        inputs=[prompt],
    )
# Important for ZeroGPU: enable the request queue (at most 32 pending jobs) and
# run one job at a time by default, so users don't compete for the single GPU.
# Gradio 4+ rejects `queue(concurrency_count=...)`; `default_concurrency_limit`
# is its replacement.
if __name__ == "__main__":
    demo.queue(default_concurrency_limit=1, max_size=32).launch()