# https://www.gradio.app/guides/using-hugging-face-integrations
"""Gradio chat demo for the Shisa 7B bilingual (English/Japanese) model.

Loads ``augmxnt/shisa-7b-v1`` with a 4-bit NF4 bitsandbytes quantization
config and serves a streaming ``gr.ChatInterface`` with an editable system
prompt.
"""
import html
import logging
import time
from pprint import pprint
from threading import Thread

import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TextIteratorStreamer

# Model
model_name = "augmxnt/shisa-7b-v1"

# Generation limit; also quoted in the UI footer so the two never disagree.
MAX_NEW_TOKENS = 200

# UI Settings
title = "Shisa 7B"
description = "Test out Shisa 7B in either English or Japanese. If you aren't getting the right language outputs, you can try changing the system prompt to the appropriate language. Note, we are running `load_in_4bit` to fit in 16GB of VRAM"
placeholder = "Type Here / ここに入力してください"
examples = [
    ["What are the best slices of pizza in New York City?"],
    ["東京でおすすめのラーメン屋ってどこ?"],
    ['How do I program a simple "hello world" in Python?'],
    ["Pythonでシンプルな「ハローワールド」をプログラムするにはどうすればいいですか?"],
]

# LLM Settings
# Initial system prompt; also the fallback when the UI textbox is left empty.
system_prompt = 'You are a helpful, bilingual assistant. Reply in the same language as the user.'
default_prompt = system_prompt

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    # 4-bit NF4 with double quantization, computing in bfloat16.
    quantization_config=BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type='nf4',
        bnb_4bit_use_double_quant=True,
        bnb_4bit_compute_dtype=torch.bfloat16,
    ),
    # NOTE(review): newer transformers versions deprecate this flag in favour of
    # attn_implementation="flash_attention_2"; switch once the pinned version allows it.
    use_flash_attention_2=True,
)


def _build_chat_history(message, history, system_prompt):
    """Return the role/content message list passed to ``apply_chat_template``.

    ``history`` is indexed as ``[user_text, assistant_text]`` pairs (Gradio's
    tuple-style chat history).
    """
    chat_history = [{"role": "system", "content": system_prompt}]
    for h in history:
        chat_history.append({"role": "user", "content": h[0]})
        # The assistant half of a pair may be None (e.g. a turn whose generation
        # failed); send "" so the chat template still renders.
        chat_history.append({"role": "assistant", "content": h[1] if h[1] is not None else ""})
    chat_history.append({"role": "user", "content": message})
    return chat_history


def chat(message, history, system_prompt):
    """Stream the model's reply to ``message``.

    Args:
        message: The user's new message.
        history: Prior turns as ``[user, assistant]`` pairs.
        system_prompt: System prompt from the UI textbox; ``default_prompt`` is
            used when it is empty.

    Yields:
        The reply accumulated so far, after each chunk from the streamer.
    """
    print('---')
    pprint(history)

    if not system_prompt:
        system_prompt = default_prompt

    # Rebuild the whole conversation on every call; simpler than tracking state.
    chat_history = _build_chat_history(message, history, system_prompt)
    input_ids = tokenizer.apply_chat_template(chat_history, add_generation_prompt=True, return_tensors="pt")

    # For multi-GPU (device_map="auto"), send inputs to the device of the first parameter.
    first_param_device = next(model.parameters()).device
    input_ids = input_ids.to(first_param_device)

    # A TextIteratorStreamer wraps a single queue. A module-level instance would be
    # reused across requests, so tokens left over from an abandoned (Stop / timeout)
    # or overlapping generation would leak into the next reply. Use one per call.
    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)

    generate_kwargs = dict(
        inputs=input_ids,
        streamer=streamer,
        max_new_tokens=MAX_NEW_TOKENS,
        do_sample=True,
        temperature=0.7,
        repetition_penalty=1.15,
        top_p=0.95,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.eos_token_id,
    )

    # Run generation in a background thread and consume its output as it streams.
    # https://www.gradio.app/main/guides/creating-a-chatbot-fast#example-using-a-local-open-source-llm-with-hugging-face
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    partial_message = ""
    for new_token in streamer:
        partial_message += new_token
        yield partial_message

    # The stream is exhausted once generation has finished; reap the worker thread.
    t.join()


chat_interface = gr.ChatInterface(
    chat,
    chatbot=gr.Chatbot(height=400),
    textbox=gr.Textbox(placeholder=placeholder, container=False, scale=7),
    title=title,
    description=description,
    theme="soft",
    examples=examples,
    cache_examples=False,
    undo_btn="Delete Previous",
    clear_btn="Clear",
    additional_inputs=[
        gr.Textbox(system_prompt, label="System Prompt (Change the language of the prompt for better replies)"),
    ],
)

# Render the ChatInterface inside a Blocks context because Gradio fails on autoreload otherwise.
# See https://huggingface.co/spaces/ysharma/Explore_llamav2_with_TGI/blob/main/app.py#L219
with gr.Blocks() as demo:
    chat_interface.render()
    gr.Markdown(f"You can try asking this question in Japanese or English. We limit output to {MAX_NEW_TOKENS} tokens.")

demo.queue().launch()