# gemma-2-9b-it / app.py
# Author: hysts (HF Staff)
# Commit 132f71a: "Update deps and tidy repo" (5.75 kB)
import os
from collections.abc import Iterator
from threading import Thread
import gradio as gr
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, StoppingCriteria, TextIteratorStreamer
# Markdown text rendered above the chat box; passed to gr.ChatInterface as `description`.
DESCRIPTION = """\
# Gemma 2 9B IT
Gemma 2 is Google's latest iteration of open LLMs.
This is a demo of [`google/gemma-2-9b-it`](https://huggingface.co/google/gemma-2-9b-it), fine-tuned for instruction following.
For more details, please check [our post](https://huggingface.co/blog/gemma2).
👉 Looking for a larger and more powerful version? Try the 27B version in [HuggingChat](https://huggingface.co/chat/models/google/gemma-2-27b-it).
"""
# Ceiling and initial value of the "Max new tokens" slider.
MAX_NEW_TOKENS_LIMIT = 2048
DEFAULT_MAX_NEW_TOKENS = 1024
# Per-request token budget used by `generate` (caps the prompt length and the
# prompt + new tokens total); can be overridden with the MAX_INPUT_TOKENS env var.
MAX_INPUT_TOKENS = int(os.environ.get("MAX_INPUT_TOKENS", "4096"))

# The tokenizer and weights are loaded once, at import time.
MODEL_ID = "google/gemma-2-9b-it"
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, device_map="auto", dtype=torch.bfloat16)
# NOTE(review): overrides the attention sliding window after the model is built;
# confirm this is still needed (the checkpoint config may already use 4096) and
# that the override is honoured by already-constructed layers.
model.config.sliding_window = 4096
model.eval()  # inference mode (disables dropout)
class StopOnSignal(StoppingCriteria):
    """Stopping criterion controlled by an external boolean flag.

    ``model.generate`` consults this object during generation; another thread
    ends generation early by setting ``stopped`` to ``True``.
    """

    def __init__(self) -> None:
        # Raised by the consumer side when it wants generation to end.
        self.stopped = False

    def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor, **kwargs: object) -> bool:  # noqa: ARG002
        # The token ids and scores are irrelevant here; only the flag decides.
        flag = self.stopped
        return flag
@spaces.GPU(duration=90)
def _generate_on_gpu(
    input_ids: torch.Tensor,
    max_new_tokens: int,
    temperature: float,
    top_p: float,
    top_k: int,
    repetition_penalty: float,
) -> Iterator[str]:
    """Sample a reply for ``input_ids`` and stream the accumulated text.

    ``model.generate`` runs in a background thread and feeds decoded text into a
    ``TextIteratorStreamer``; this generator yields the whole reply so far after
    every chunk it receives.

    Args:
        input_ids: Prompt token ids (as built by ``generate``); moved to
            ``model.device`` before generation.
        max_new_tokens: Upper bound on the number of generated tokens.
        temperature: Sampling temperature forwarded to ``model.generate``.
        top_p: Nucleus-sampling threshold forwarded to ``model.generate``.
        top_k: Top-k cutoff forwarded to ``model.generate``.
        repetition_penalty: Repetition penalty forwarded to ``model.generate``.

    Yields:
        The full decoded reply so far (not only the newest chunk).

    Raises:
        gr.Error: If ``model.generate`` raised in the background thread.
    """
    input_ids = input_ids.to(model.device)
    streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
    stop_criteria = StopOnSignal()
    generate_kwargs = {
        "input_ids": input_ids,
        "streamer": streamer,
        "stopping_criteria": [stop_criteria],
        "max_new_tokens": max_new_tokens,
        "do_sample": True,
        "top_p": top_p,
        "top_k": top_k,
        "temperature": temperature,
        "num_beams": 1,
        "repetition_penalty": repetition_penalty,
        "disable_compile": True,
    }
    exception_holder: list[Exception] = []

    def _generate() -> None:
        try:
            model.generate(**generate_kwargs)
        except Exception as e:  # noqa: BLE001
            exception_holder.append(e)
            # generate() only closes the stream when it finishes normally. Close it
            # here so the consumer loop below stops immediately instead of blocking
            # until the streamer timeout and raising queue.Empty in place of this error.
            streamer.end()

    thread = Thread(target=_generate)
    thread.start()

    chunks: list[str] = []
    try:
        for text in streamer:
            chunks.append(text)
            yield "".join(chunks)
    except GeneratorExit:
        # The consumer went away (e.g. the user pressed Stop): ask generate() to
        # halt, wait for the stream to close and the thread to finish, re-raise.
        stop_criteria.stopped = True
        for _ in streamer:
            pass
        thread.join()
        raise
    except Exception:
        # E.g. queue.Empty when no text arrives within the streamer timeout:
        # stop the background generation so it does not outlive this call.
        stop_criteria.stopped = True
        thread.join()
        raise

    thread.join()
    if exception_holder:
        msg = f"Generation failed: {exception_holder[0]}"
        raise gr.Error(msg) from exception_holder[0]
def validate_input(message: str) -> dict:
    """Reject empty or whitespace-only chat messages before ``generate`` runs."""
    has_content = bool(message) and bool(message.strip())
    return gr.validate(has_content, "Please enter a message.")
def _build_conversation(chat_history: list[dict], message: str) -> list[dict]:
    """Convert Gradio chat history plus the new user message into chat-template turns.

    A history entry's ``content`` is either used as-is or, when it is a list of
    typed parts, reduced to the concatenation of its ``"text"`` parts.
    """
    turns: list[dict] = []
    for entry in chat_history:
        body = entry["content"]
        if isinstance(body, list):
            body = "".join(part["text"] for part in body if part.get("type") == "text")
        turns.append({"role": entry["role"], "content": body})
    turns.append({"role": "user", "content": message})
    return turns


def generate(
    message: str,
    chat_history: list[dict],
    max_new_tokens: int = 1024,
    temperature: float = 0.6,
    top_p: float = 0.9,
    top_k: int = 50,
    repetition_penalty: float = 1.2,
) -> Iterator[str]:
    """Chat handler for ``gr.ChatInterface``: tokenize the dialogue and stream a reply.

    Raises:
        gr.Error: If the prompt exceeds ``MAX_INPUT_TOKENS`` or leaves no room for
            new tokens within that budget.
    """
    turns = _build_conversation(chat_history, message)
    encoded = tokenizer.apply_chat_template(turns, add_generation_prompt=True, return_tensors="pt", return_dict=True)
    input_ids = encoded.input_ids

    prompt_len = input_ids.shape[1]
    if prompt_len > MAX_INPUT_TOKENS:
        too_long = f"Input too long ({prompt_len} tokens). Maximum is {MAX_INPUT_TOKENS} tokens."
        raise gr.Error(too_long)

    # MAX_INPUT_TOKENS also bounds prompt + reply, so shrink the reply budget to fit.
    remaining = MAX_INPUT_TOKENS - prompt_len
    max_new_tokens = min(max_new_tokens, remaining)
    if max_new_tokens <= 0:
        raise gr.Error("Input uses the entire context window. No room to generate new tokens.")

    yield from _generate_on_gpu(
        input_ids=input_ids,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        repetition_penalty=repetition_penalty,
    )
def _build_demo() -> gr.ChatInterface:
    """Assemble the chat UI: the streaming ``generate`` handler plus its sampling controls."""
    # Order must match generate()'s parameters after (message, chat_history).
    sampling_controls = [
        gr.Slider(
            label="Max new tokens", minimum=1, maximum=MAX_NEW_TOKENS_LIMIT, step=1, value=DEFAULT_MAX_NEW_TOKENS
        ),
        gr.Slider(label="Temperature", minimum=0.1, maximum=4.0, step=0.1, value=0.6),
        gr.Slider(label="Top-p (nucleus sampling)", minimum=0.05, maximum=1.0, step=0.05, value=0.9),
        gr.Slider(label="Top-k", minimum=1, maximum=1000, step=1, value=50),
        gr.Slider(label="Repetition penalty", minimum=1.0, maximum=2.0, step=0.05, value=1.2),
    ]
    sample_prompts = [
        "Hello there! How are you doing?",
        "Can you explain briefly to me what is the Python programming language?",
        "Explain the plot of Cinderella in a sentence.",
        "How many hours does it take a man to eat a Helicopter?",
        "Write a 100-word article on 'Benefits of Open-Source in AI research'",
    ]
    return gr.ChatInterface(
        fn=generate,
        validator=validate_input,
        additional_inputs=sampling_controls,
        # Each example row supplies only the message; other inputs keep their defaults.
        examples=[[prompt] for prompt in sample_prompts],
        cache_examples=False,
        description=DESCRIPTION,
        fill_height=True,
    )


demo = _build_demo()

if __name__ == "__main__":
    demo.launch(css_paths="style.css")