# app.py -- uploaded by CMLL, commit 4ed0b9b ("Update app.py"), 4.37 kB.
import os
from threading import Thread
from typing import Iterator
import gradio as gr
import spaces
import torch
from transformers import pipeline, AutoTokenizer
# Upper bound and initial value of the "Max new tokens" slider.
MAX_MAX_NEW_TOKENS = 2048
DEFAULT_MAX_NEW_TOKENS = 1024
# Prompt length budget in tokens; can be overridden through the environment.
MAX_INPUT_TOKEN_LENGTH = int(os.environ.get("MAX_INPUT_TOKEN_LENGTH", "4096"))

# Markdown rendered at the top of the page.
DESCRIPTION = """\
# ZhongJing 2 1.8B Merge
This Space demonstrates model [CMLL/ZhongJing-2-1_8b-merge](https://huggingface.co/CMLL/ZhongJing-2-1_8b-merge) for text generation. Feel free to play with it, or duplicate to run generations without a queue! If you want to run your own service, you can also [deploy the model on Inference Endpoints](https://huggingface.co/inference-endpoints).
"""

# Markdown rendered at the bottom of the page.
LICENSE = """
<p/>
---
As a derivative work of [CMLL/ZhongJing-2-1_8b-merge](https://huggingface.co/CMLL/ZhongJing-2-1_8b-merge),
this demo is governed by the original [license](https://huggingface.co/CMLL/ZhongJing-2-1_8b-merge/LICENSE).
"""

if torch.cuda.is_available():
    # The model and tokenizer are only loaded when CUDA is reported available;
    # `pipe` and `tokenizer` are therefore undefined on CPU-only hosts.
    model_id = "CMLL/ZhongJing-2-1_8b-merge"
    pipe = pipeline("text-generation", model=model_id)
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    tokenizer.use_default_system_prompt = False
else:
    # Tell the visitor up front that generation will not work here.
    DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"
@spaces.GPU
def generate(
    message: str,
    chat_history: list[tuple[str, str]],
    system_prompt: str,
    max_new_tokens: int = 1024,
    temperature: float = 0.6,
    top_p: float = 0.9,
    top_k: int = 50,
    repetition_penalty: float = 1.2,
) -> Iterator[str]:
    """Stream a sampled reply to ``message`` given the prior conversation.

    The conversation is flattened into ``"role: content"`` lines, tokenized,
    trimmed to the last ``MAX_INPUT_TOKEN_LENGTH`` tokens if necessary, and
    fed to the model behind ``pipe``. Generation runs exactly once, in a
    worker thread, while this generator relays the decoded text as it arrives.

    Args:
        message: The new user message.
        chat_history: Earlier ``(user, assistant)`` turns, oldest first.
        system_prompt: Optional system instruction; skipped when empty.
        max_new_tokens: Maximum number of tokens to generate.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling probability mass.
        top_k: Number of highest-probability tokens kept for sampling.
        repetition_penalty: Penalty applied to already generated tokens.

    Yields:
        The reply accumulated so far, growing with every decoded chunk.
    """
    # Imported here because the module-level import only brings in
    # `pipeline` and `AutoTokenizer`.
    from transformers import TextIteratorStreamer

    conversation = []
    if system_prompt:
        conversation.append({"role": "system", "content": system_prompt})
    for user, assistant in chat_history:
        conversation.extend(
            [
                {"role": "user", "content": user},
                {"role": "assistant", "content": assistant},
            ]
        )
    conversation.append({"role": "user", "content": message})

    input_text = "\n".join(f"{entry['role']}: {entry['content']}" for entry in conversation)
    encoded = tokenizer(input_text, return_tensors="pt")

    # Keep working with plain tensors: rebinding the tokenizer output to a
    # dict (as the trimming step used to) loses `.to()` and attribute access.
    input_ids = encoded["input_ids"]
    attention_mask = encoded.get("attention_mask")
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        # Drop the oldest tokens so the most recent turns survive.
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
        if attention_mask is not None:
            attention_mask = attention_mask[:, -MAX_INPUT_TOKEN_LENGTH:]
        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")

    # skip_prompt=True so only newly generated text is streamed back; the
    # timeout keeps the consumer from blocking forever if the worker dies.
    streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = {
        "input_ids": input_ids.to(pipe.device),
        "streamer": streamer,
        "max_new_tokens": max_new_tokens,
        "do_sample": True,
        "top_p": top_p,
        "top_k": top_k,
        "temperature": temperature,
        "repetition_penalty": repetition_penalty,
    }
    if attention_mask is not None:
        generate_kwargs["attention_mask"] = attention_mask.to(pipe.device)

    # Call the underlying model directly: the pipeline wrapper expects prompt
    # strings, not token-id tensors. The worker thread is the only place
    # generation runs; this generator just drains the streamer.
    worker = Thread(target=pipe.model.generate, kwargs=generate_kwargs)
    worker.start()

    outputs = []
    for text in streamer:
        outputs.append(text)
        yield "".join(outputs)
# Extra controls for the chat UI, listed in the same order as the parameters
# of `generate` that follow `message` and `chat_history`.
_additional_inputs = [
    gr.Textbox(lines=6, label="System prompt"),
    gr.Slider(
        minimum=1,
        maximum=MAX_MAX_NEW_TOKENS,
        value=DEFAULT_MAX_NEW_TOKENS,
        step=1,
        label="Max new tokens",
    ),
    gr.Slider(minimum=0.1, maximum=4.0, value=0.6, step=0.1, label="Temperature"),
    gr.Slider(
        minimum=0.05,
        maximum=1.0,
        value=0.9,
        step=0.05,
        label="Top-p (nucleus sampling)",
    ),
    gr.Slider(minimum=1, maximum=1000, value=50, step=1, label="Top-k"),
    gr.Slider(minimum=1.0, maximum=2.0, value=1.2, step=0.05, label="Repetition penalty"),
]

# Chat widget backed by `generate`, with a few canned example prompts.
chat_interface = gr.ChatInterface(
    fn=generate,
    additional_inputs=_additional_inputs,
    stop_btn=None,
    examples=[
        ["请问气虚体质有哪些症状表现?"],
        ["简单介绍一下中医的五行学说。"],
        ["桑螵蛸是什么?有什么功效作用?"],
        ["Write a 100-word article on 'Benefits of Open-Source in AI research'"],
    ],
)
# Page layout, top to bottom: description, duplicate button, chat UI, license.
with gr.Blocks(css="style.css") as demo:
    gr.Markdown(DESCRIPTION)
    gr.DuplicateButton(elem_id="duplicate-button", value="Duplicate Space for private use")
    chat_interface.render()
    gr.Markdown(LICENSE)

if __name__ == "__main__":
    # Enable the request queue (max_size=20), then start serving.
    demo.queue(max_size=20)
    demo.launch()