hayas's picture
Fix
b3a69a4
raw
history blame contribute delete
No virus
5.21 kB
#!/usr/bin/env python
import os
from threading import Thread
from typing import Iterator
import gradio as gr
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
DESCRIPTION = """# Swallow-13B instruct"""
if not torch.cuda.is_available():
DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"
if torch.cuda.is_available():
model_name = "tokyotech-llm/Swallow-13b-instruct-hf"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
model_name, load_in_8bit=True, low_cpu_mem_usage=True, device_map="auto"
)
# Prompts whose encoded length exceeds this many tokens are rejected by `run`.
MAX_INPUT_TOKENS = 2048

# Japanese instruction templates, filled via `str.format` in `create_prompt`.
PROMPT_DICT = {
    # "Below is an instruction describing a task, paired with an input that provides
    # further context. Write a response that appropriately completes the request."
    "prompt_input": (
        "以下に、あるタスクを説明する指示があり、それに付随する入力が更なる文脈を提供しています。"
        "リクエストを適切に完了するための回答を記述してください。\n\n"
        "### 指示:\n{instruction}\n\n"
        "### 入力:\n{input}\n\n"
        "### 応答:"
    ),
    # "Below is an instruction describing a task.
    # Write a response that appropriately completes the request."
    "prompt_no_input": (
        "以下に、あるタスクを説明する指示があります。"
        "リクエストを適切に完了するための回答を記述してください。\n\n"
        "### 指示:\n{instruction}\n\n"
        "### 応答:"
    ),
}
def create_prompt(instruction: str, input_text: str | None = None) -> str:
    """Fill the matching template from ``PROMPT_DICT`` with the user's request.

    Args:
        instruction: Task description, substituted for ``{instruction}``.
        input_text: Optional extra context. Any falsy value (``None`` or ``""``)
            selects the ``"prompt_no_input"`` template; otherwise the text is
            substituted for ``{input}`` in the ``"prompt_input"`` template.

    Returns:
        The fully formatted prompt string.
    """
    # Guard clause: no usable context means the shorter template without an input section.
    if not input_text:
        return PROMPT_DICT["prompt_no_input"].format(instruction=instruction)
    return PROMPT_DICT["prompt_input"].format(instruction=instruction, input=input_text)
@spaces.GPU
@torch.inference_mode()
def run(
    instruction: str,
    input_text: str | None = None,
    max_new_tokens: int = 256,
    temperature: float = 0.99,
    top_p: float = 0.95,
) -> Iterator[str]:
    """Stream a model completion for *instruction* and optional *input_text*.

    Args:
        instruction: Task description inserted into the prompt template.
        input_text: Optional extra context; ``""`` or ``None`` selects the
            no-input template.
        max_new_tokens: Upper bound on the number of generated tokens.
        temperature: Sampling temperature. A value ``<= 0`` switches to greedy
            decoding, because transformers rejects sampling with a
            non-positive temperature.
        top_p: Nucleus-sampling threshold; only used when sampling.

    Yields:
        The accumulated output text after each chunk produced by the streamer.

    Raises:
        gr.Error: If the encoded prompt is longer than ``MAX_INPUT_TOKENS``.
    """
    # An empty "Input (optional)" textbox arrives as "", which means "no input".
    if input_text == "":
        input_text = None
    prompt = create_prompt(instruction, input_text)
    input_ids = tokenizer.encode(prompt, add_special_tokens=False, return_tensors="pt")
    if input_ids.shape[-1] > MAX_INPUT_TOKENS:
        raise gr.Error(f"Input exceeds maximum number of tokens ({MAX_INPUT_TOKENS})")

    # Values from sliders or the API may arrive as JSON ints (e.g. 2) or floats
    # (e.g. 256.0). transformers' temperature warper only accepts a float.
    max_new_tokens = int(max_new_tokens)
    temperature = float(temperature)

    streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = {
        "input_ids": input_ids.to(model.device),
        "streamer": streamer,
        "max_new_tokens": max_new_tokens,
    }
    if temperature > 0:
        generate_kwargs.update(do_sample=True, temperature=temperature, top_p=top_p)
    else:
        # The temperature slider allows 0.0. Sampling with it would raise inside the
        # worker thread and only surface here as a 20 s streamer timeout, so decode
        # greedily instead.
        generate_kwargs["do_sample"] = False

    # `generate` runs in a background thread; decoded text comes back via the streamer.
    worker = Thread(target=model.generate, kwargs=generate_kwargs)
    worker.start()
    outputs = []
    for text in streamer:
        outputs.append(text)
        yield "".join(outputs)
    # The streamer is exhausted once generation finishes; reap the worker thread.
    worker.join()
def process_example(instruction: str, input_text: str) -> Iterator[str]:
    """Stream `run` output for one `gr.Examples` row.

    Only the instruction and input are forwarded, so `run` uses its default
    `max_new_tokens`, `temperature` and `top_p`. Each partial output yielded
    by `run` is re-yielded unchanged.
    """
    # NOTE(review): `yield from` keeps this a generator function; Gradio presumably
    # relies on that to stream (and cache) example outputs — confirm before rewriting.
    yield from run(instruction, input_text)
with gr.Blocks(css="style.css") as demo:
gr.Markdown(DESCRIPTION)
gr.DuplicateButton(
value="Duplicate Space for private use",
elem_id="duplicate-button",
visible=os.getenv("SHOW_DUPLICATE_BUTTON") == "1",
)
with gr.Row():
with gr.Column():
instruction = gr.Textbox(label="Instruction", lines=5)
input_text = gr.Textbox(label="Input (optional)", lines=5)
run_button = gr.Button()
with gr.Accordion(label="Advanced Options", open=False):
max_new_tokens = gr.Slider(label="Max New Tokens", minimum=1, maximum=1024, step=1, value=256)
temperature = gr.Slider(label="Temperature", minimum=0.0, maximum=2.0, step=0.01, value=0.99)
top_p = gr.Slider(label="Top P", minimum=0.0, maximum=1.0, step=0.01, value=0.95)
with gr.Column():
output = gr.Textbox(label="Output", lines=10)
run_button.click(
fn=run,
inputs=[instruction, input_text, max_new_tokens, temperature, top_p],
outputs=output,
api_name="run",
)
gr.Examples(
examples=[
["以下のトピックに関する詳細な情報を提供してください。", "東京工業大学の主なキャンパスについて教えてください。"],
["以下のトピックに関する詳細な情報を提供してください。", "夢オチとは何かについて教えてください。"],
["暴れん坊将軍って誰のことですか?", ""],
],
inputs=[instruction, input_text],
outputs=output,
fn=process_example,
cache_examples=os.getenv("CACHE_EXAMPLES") == "1",
api_name=False,
)
if __name__ == "__main__":
demo.queue(max_size=20).launch()