try:
    import flash_attn
except ImportError:
    import subprocess

    print("Installing flash-attn...")
    subprocess.run(
        "pip install flash-attn --no-build-isolation",
        env={"FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
        shell=True,
    )
    import flash_attn

    print("flash-attn installed.")

import os
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TextIteratorStreamer,
    BitsAndBytesConfig,
)
from threading import Thread

import gradio as gr
import spaces

MODEL_NAME_MAP = {
    "150M": "llm-jp/llm-jp-3-150m-instruct3",
    "440M": "llm-jp/llm-jp-3-440m-instruct3",
    "980M": "llm-jp/llm-jp-3-980m-instruct3",
    "1.8B": "llm-jp/llm-jp-3-1.8b-instruct3",
    "3.7B": "llm-jp/llm-jp-3-3.7b-instruct3",
    "7.2B": "llm-jp/llm-jp-3-7.2b-instruct3",
    "13B": "llm-jp/llm-jp-3-13b-instruct3",
}

# 4bit (NF4) quantization with bfloat16 compute, shared by every model size.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
)

MODELS = {
    key: AutoModelForCausalLM.from_pretrained(
        repo_id,
        quantization_config=quantization_config,
        device_map="auto",
        attn_implementation="flash_attention_2",
    )
    for key, repo_id in MODEL_NAME_MAP.items()
}
TOKENIZERS = {
    key: AutoTokenizer.from_pretrained(repo_id)
    for key, repo_id in MODEL_NAME_MAP.items()
}

print("Compiling models...")
for key, model in MODELS.items():
    MODELS[key] = torch.compile(model)
print("Models compiled.")


@spaces.GPU(duration=45)
def generate(
    model_name: str,
    message: str,
    history: list[tuple[str, str]],
    system_message: str,
    max_tokens: int,
    temperature: float,
    top_p: float,
    top_k: int,
):
    if not message or message.strip() == "":
        yield "", history
        return

    messages = [{"role": "system", "content": system_message}]
    for val in history:
        if val[0]:
            messages.append({"role": "user", "content": val[0]})
        if val[1]:
            messages.append({"role": "assistant", "content": val[1]})
    messages.append({"role": "user", "content": message})

    model = MODELS[model_name]
    tokenizer = TOKENIZERS[model_name]

    tokenized_input = tokenizer.apply_chat_template(
        messages, add_generation_prompt=True, tokenize=True, return_tensors="pt"
    ).to(model.device)

    streamer = TextIteratorStreamer(
        tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
    )
    generate_kwargs = dict(
        input_ids=tokenized_input,
        streamer=streamer,
        max_new_tokens=int(max_tokens),
        do_sample=True,
        temperature=float(temperature),
        top_k=int(top_k),
        top_p=float(top_p),
        num_beams=1,
    )
    # Run generation in a background thread so tokens can be streamed as they arrive.
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    # Initialize the streamed partial response.
    partial_message = ""
    for new_token in streamer:
        partial_message += new_token
        new_history = history + [(message, partial_message)]
        # Clear the input text and update the chat history.
        yield "", new_history


def respond(
    model_name: str,
    message: str,
    history: list[tuple[str, str]],
    system_message: str,
    max_tokens: int,
    temperature: float,
    top_p: float,
    top_k: int,
):
    for stream in generate(
        model_name,
        message,
        history,
        system_message,
        max_tokens,
        temperature,
        top_p,
        top_k,
    ):
        yield (*stream,)


def retry(
    model_name: str,
    history: list[tuple[str, str]],
    system_message: str,
    max_tokens: int,
    temperature: float,
    top_p: float,
    top_k: int,
):
    if not history:
        return

    # Remove the last exchange from the history and regenerate a response to its user message.
    last_conversation = history[-1]
    user_message = last_conversation[0]
    history = history[:-1]

    for stream in generate(
        model_name,
        user_message,
        history,
        system_message,
        max_tokens,
        temperature,
        top_p,
        top_k,
    ):
        yield (*stream,)


def demo():
    with gr.Blocks() as ui:
        gr.Markdown(
            """\
# (unofficial) llm-jp/llm-jp-3 instruct3 モデルデモ
モデルは bitsandbytes を用いて 4bit (NF4) 量子化されています
コレクション: https://huggingface.co/collections/llm-jp/llm-jp-3-fine-tuned-models-672c621db852a01eae939731
"""
        )
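
        # UI layout: model selector, chat history, retry/clear buttons, input row, examples, and advanced generation settings.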
        model_name_radio = gr.Radio(
            label="モデル",
            choices=list(MODELS.keys()),
            value=list(MODELS.keys())[0],
        )
        chat_history = gr.Chatbot(value=[])

        with gr.Row():
            retry_btn = gr.Button(value="🔄 再生成", scale=1)
            clear_btn = gr.ClearButton(
                components=[chat_history],
                value="🗑️ 削除",
                scale=1,
            )

        with gr.Row():
            input_text = gr.Textbox(
                value="",
                placeholder="質問を入力してください...",
                show_label=False,
                scale=8,
            )
            start_btn = gr.Button(
                value="送信",
                variant="primary",
                scale=2,
            )

        with gr.Accordion(label="詳細設定", open=False):
            system_prompt_text = gr.Textbox(
                label="システムプロンプト",
                value="以下は、タスクを説明する指示です。要求を適切に満たす応答を書きなさい。",
            )
            max_new_tokens_slider = gr.Slider(
                minimum=1, maximum=2048, value=256, step=1, label="Max new tokens"
            )
            temperature_slider = gr.Slider(
                minimum=0.1, maximum=1.0, value=0.7, step=0.1, label="Temperature"
            )
            top_p_slider = gr.Slider(
                minimum=0.1,
                maximum=1.0,
                value=0.95,
                step=0.05,
                label="Top-p",
            )
            top_k_slider = gr.Slider(
                minimum=10, maximum=500, value=100, step=10, label="Top-k"
            )

        gr.Examples(
            examples=[
                ["情けは人の為ならずとはどういう意味ですか?"],
                ["まどマギで一番可愛いのは誰?"],
            ],
            inputs=[input_text],
            cache_examples=False,
        )

        gr.on(
            triggers=[start_btn.click, input_text.submit],
            fn=respond,
            inputs=[
                model_name_radio,
                input_text,
                chat_history,
                system_prompt_text,
                max_new_tokens_slider,
                temperature_slider,
                top_p_slider,
                top_k_slider,
            ],
            outputs=[input_text, chat_history],
        )
        retry_btn.click(
            retry,
            inputs=[
                model_name_radio,
                chat_history,
                system_prompt_text,
                max_new_tokens_slider,
                temperature_slider,
                top_p_slider,
                top_k_slider,
            ],
            outputs=[input_text, chat_history],
        )

    ui.launch()


if __name__ == "__main__":
    demo()
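
# Usage sketch (assumptions: the script is saved as app.py and a CUDA GPU with enough
# memory for all seven 4bit-quantized checkpoints is available):
#   pip install torch transformers accelerate bitsandbytes gradio spaces
#   python app.py
# flash-attn is installed on first run by the try/except block at the top of the script.
# The `spaces` package and the @spaces.GPU decorator target Hugging Face Spaces (ZeroGPU);
# outside of a Space the decorator should have no effect and generation runs on the local GPU.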