import os
import json
import subprocess
from threading import Thread

import torch
import spaces
import gradio as gr
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    TextIteratorStreamer,
)
# Install FlashAttention 2 at runtime, skipping the CUDA build in favor of the
# prebuilt wheel; merge in os.environ so pip still sees PATH and friends
subprocess.run(
    "pip install flash-attn --no-build-isolation",
    env={**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
    shell=True,
)
MODEL_ID = os.environ.get("MODEL_ID")
CHAT_TEMPLATE = os.environ.get("CHAT_TEMPLATE")
MODEL_NAME = MODEL_ID.split("/")[-1]
CONTEXT_LENGTH = int(os.environ.get("CONTEXT_LENGTH"))
COLOR = os.environ.get("COLOR")
EMOJI = os.environ.get("EMOJI")
DESCRIPTION = os.environ.get("DESCRIPTION")
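# Example configuration (hypothetical values for illustration only; the real
# ones are set as variables/secrets on the Space):
#   MODEL_ID=mistralai/Mistral-7B-Instruct-v0.2
#   CHAT_TEMPLATE=Mistral Instruct
#   CONTEXT_LENGTH=8192
#   COLOR=blue
#   EMOJI=🤖
#   DESCRIPTION=Chat with Mistral 7B Instruct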
# @spaces.GPU requests a GPU slice per call on ZeroGPU hardware; the `spaces`
# import is otherwise unused
@spaces.GPU()
def predict(
    message,
    history,
    system_prompt,
    temperature,
    max_new_tokens,
    top_k,
    repetition_penalty,
    top_p,
):
    # Format history with a given chat template
    if CHAT_TEMPLATE == "ChatML":
        stop_tokens = ["<|endoftext|>", "<|im_end|>"]
        instruction = "<|im_start|>system\n" + system_prompt + "\n<|im_end|>\n"
        for human, assistant in history:
            instruction += (
                "<|im_start|>user\n"
                + human
                + "\n<|im_end|>\n<|im_start|>assistant\n"
                + assistant
            )
        instruction += (
            "\n<|im_start|>user\n" + message + "\n<|im_end|>\n<|im_start|>assistant\n"
        )
    elif CHAT_TEMPLATE == "Mistral Instruct":
        stop_tokens = ["</s>", "[INST]", "[INST] ", "<s>", "[/INST]", "[/INST] "]
        instruction = "<s>[INST] " + system_prompt
        for human, assistant in history:
            instruction += human + " [/INST] " + assistant + "</s>[INST]"
        instruction += " " + message + " [/INST]"
    else:
        raise Exception(
            "Incorrect chat template, select 'ChatML' or 'Mistral Instruct'"
        )
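    # For reference, the assembled ChatML prompt has the shape:
    #   <|im_start|>system\n{system_prompt}\n<|im_end|>\n
    #   <|im_start|>user\n{message}\n<|im_end|>\n<|im_start|>assistant\n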
    print(instruction)

    streamer = TextIteratorStreamer(
        tokenizer, skip_prompt=True, skip_special_tokens=True
    )
    enc = tokenizer([instruction], return_tensors="pt", padding=True, truncation=True)
    input_ids, attention_mask = enc.input_ids, enc.attention_mask
    # Keep only the most recent CONTEXT_LENGTH tokens, truncating the attention
    # mask together with the input IDs so the two stay aligned
    if input_ids.shape[1] > CONTEXT_LENGTH:
        input_ids = input_ids[:, -CONTEXT_LENGTH:]
        attention_mask = attention_mask[:, -CONTEXT_LENGTH:]
    generate_kwargs = dict(
        input_ids=input_ids.to(device),
        attention_mask=attention_mask.to(device),
        streamer=streamer,
        do_sample=True,
        temperature=temperature,
        max_new_tokens=max_new_tokens,
        top_k=top_k,
        repetition_penalty=repetition_penalty,
        top_p=top_p,
    )
    # Run generation on a background thread so tokens can be streamed to the UI
    # as the streamer yields them
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()
    outputs = []
    for new_token in streamer:
        outputs.append(new_token)
        if new_token in stop_tokens:
            break
        yield "".join(outputs)
# Load model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Load the weights in 4-bit; bnb_4bit_compute_dtype has no effect unless
# 4-bit loading is actually enabled
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
if tokenizer.pad_token is None:
    # Some models ship without a pad token; reuse EOS so padding=True works
    tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    device_map="auto",
    quantization_config=quantization_config,
    attn_implementation="flash_attention_2",
)
# Create Gradio interface
gr.ChatInterface(
    predict,
    title=EMOJI + " " + MODEL_NAME,
    description=DESCRIPTION,
    examples=[
        ["Can you solve the equation 2x + 3 = 11 for x?"],
        ["Write an epic poem about Ancient Rome."],
        ["Who was the first person to walk on the Moon?"],
        ["Use a list comprehension to create a list of squares for numbers from 1 to 10."],
        ["Recommend some popular science fiction books."],
        ["Can you write a short story about a time-traveling detective?"],
    ],
    additional_inputs_accordion=gr.Accordion(label="⚙️ Parameters", open=False),
    additional_inputs=[
        gr.Textbox(
            "Perform the task to the best of your ability.", label="System prompt"
        ),
        gr.Slider(0, 1, 0.8, label="Temperature"),
        gr.Slider(128, 4096, 1024, label="Max new tokens"),
        gr.Slider(1, 80, 40, label="Top K sampling"),
        gr.Slider(0, 2, 1.1, label="Repetition penalty"),
        gr.Slider(0, 1, 0.95, label="Top P sampling"),
    ],
    theme=gr.themes.Soft(primary_hue=COLOR),
).queue().launch()
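# To try the app outside the Space (assuming this file is the Space's app.py,
# the environment variables above are exported, and a CUDA GPU is available):
#   python app.py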