from transformers import AutoModelForCausalLM, AutoTokenizer
import gradio as gr
import torch
import os
from huggingface_hub import login

# Authenticate against the Hugging Face Hub; HF_KEY must hold a token that has been
# granted access to the gated meta-llama/Meta-Llama-3-8B-Instruct repository.
login(os.getenv('HF_KEY'))

def init_model():
    system_prompt = "You are a pirate chatbot who always responds in pirate speak!"
    # system_prompt = "### System:\nYou are StableBeluga, an AI that follows instructions extremely well. Help as much as you can.\n\n"
    # model = AutoModelForCausalLM.from_pretrained(
    #     "stabilityai/StableBeluga2",
    #     torch_dtype=torch.float16,
    #     low_cpu_mem_usage=True,
    #     device_map="auto")
    model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    # bfloat16 halves the weight memory vs. float32; device_map="auto" lets accelerate
    # place the layers on whatever GPU(s) are available.
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.bfloat16,
        device_map="auto",
    )
    # Earlier experiments with other checkpoints, kept here for reference:
    # model = AutoModelForCausalLM.from_pretrained(
    #     "stabilityai/stablelm-2-12b",
    #     torch_dtype="auto",
    # )
    # model.cuda()
    # model = AutoModelForCausalLM.from_pretrained(
    #     'stabilityai/stablelm-2-12b-chat',
    #     device_map="auto",
    # )
    # tokenizer = AutoTokenizer.from_pretrained("stabilityai/stablelm-2-12b")
    # tokenizer = AutoTokenizer.from_pretrained('stabilityai/stablelm-2-12b-chat')
    # print(tokenizer.decode(output[0], skip_special_tokens=True))
    # tokenizer = AutoTokenizer.from_pretrained("stabilityai/StableBeluga-7B", use_fast=True)
    # model = AutoModelForCausalLM.from_pretrained("stabilityai/StableBeluga-7B", load_in_8bit=True, low_cpu_mem_usage=True, device_map=0)
    # model = AutoModelForCausalLM.from_pretrained("stabilityai/StableBeluga-7B", low_cpu_mem_usage=True, device_map=0)
    # model = AutoModelForCausalLM.from_pretrained("stabilityai/StableBeluga-7B", device_map=0)
    return system_prompt, tokenizer, model
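
# Rough sizing assumption (not from the original script): the 8B model in bfloat16 needs
# about 16 GB of GPU memory for the weights alone, so this targets a single large GPU;
# with less memory, device_map="auto" will offload layers to CPU and generation slows down.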

system_prompt, tokenizer, model = init_model()

# def make_prompt(user, syst=system_prompt):
#     # return f"{syst}### User: {user}\n\n### Assistant:\n"
#     return [{'role': 'user', 'content': user}]

def ask_assistant(prompt, token=tokenizer, md=model, system_prompt=system_prompt):
    # Build the conversation in the structure expected by apply_chat_template.
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": prompt},
    ]
    input_ids = token.apply_chat_template(
        messages,
        add_generation_prompt=True,
        return_tensors="pt"
    ).to(md.device)
    # Stop on either the generic end-of-sequence token or Llama 3's end-of-turn token.
    terminators = [
        token.eos_token_id,
        token.convert_tokens_to_ids("<|eot_id|>")
    ]
    outputs = md.generate(
        input_ids,
        max_new_tokens=256,
        eos_token_id=terminators,
        do_sample=True,
        temperature=0.6,
        top_p=0.9,
    )
    # Drop the prompt tokens and decode only the newly generated reply.
    response = outputs[0][input_ids.shape[-1]:]
    return token.decode(response, skip_special_tokens=True)
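
# Quick smoke test (hypothetical, run manually while developing; not part of the Gradio app):
# print(ask_assistant("Where be the best spot to bury treasure?"))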

def ask(model_choice, prompt):
    # The model radio button is informational for now: only Meta-Llama-3-8B-Instruct is
    # loaded above, so every choice is answered by that checkpoint.
    return ask_assistant(prompt)


demo = gr.Interface(
    fn=ask,
    inputs=[
        gr.Radio(["LLaMa-3", "StableBeluga-2-12b", "Falcon-11b"], label="Model", value="LLaMa-3"),
        "text",
    ],
    outputs=["text"],
)
demo.launch()
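
# Deployment note (assumption: remote/headless GPU box): demo.launch(share=True) creates a
# temporary public Gradio link, and demo.launch(server_name="0.0.0.0") makes the app
# reachable from other machines on the network.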