# Hugging Face Space: Gradio chat demo for a fine-tuned Llama-2-13b-chat model.
# (Original page scrape included "Spaces:" / "Build error" status noise here.)
import os

# Pin all CUDA work to GPU 0; must be set before torch initializes CUDA.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig, AutoTokenizer
import gradio as gr

# Tokenizer from the base chat model; weights from the local fine-tuned
# checkpoint. device_map={"":0} places the whole model on GPU 0.
# NOTE(review): "Llama-2-13b-chat-hf" is a bare name — presumably resolved
# from a local directory or HF cache; confirm it matches the checkpoint.
tokenizer = AutoTokenizer.from_pretrained("Llama-2-13b-chat-hf", trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained('./results/final_checkpoint', device_map={"": 0}, torch_dtype=torch.bfloat16)
def answer(my_input):
    """Generate a fresh model response for *my_input*.

    Wraps the prompt in Llama-2 [INST] tags, generates up to 256 new
    tokens, and returns only the text after the [/INST] marker.
    """
    text = "<s>[INST] " + my_input + " [/INST]"
    inputs = tokenizer(text, return_tensors="pt").to("cuda")
    # inference_mode: no autograd bookkeeping during generation.
    # (Original redundantly re-moved input_ids to cuda; inputs are already there.)
    with torch.inference_mode():
        outputs = model.generate(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            max_new_tokens=256,
            pad_token_id=tokenizer.eos_token_id,
        )
    result = tokenizer.decode(outputs[0], skip_special_tokens=True)
    # Original split(...)[1] raised IndexError when the marker was missing
    # from the decoded text; fall back to the full decoded output instead.
    marker = '[/INST]'
    if marker in result:
        return result.split(marker, 1)[1].strip()
    return result.strip()
def expand(my_input, response):
    """Continue a previous *response* to *my_input*.

    Re-feeds the prompt plus the existing response so the model extends
    it, and returns the full (extended) text after the [/INST] marker.
    """
    text = "<s>[INST] " + my_input + " [/INST]"
    # Append the prior response so generation continues from its end.
    text += " " + response
    inputs = tokenizer(text, return_tensors="pt").to("cuda")
    # inference_mode: no autograd bookkeeping during generation.
    # (Original redundantly re-moved input_ids to cuda; inputs are already there.)
    with torch.inference_mode():
        outputs = model.generate(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            max_new_tokens=256,
            pad_token_id=tokenizer.eos_token_id,
        )
    result = tokenizer.decode(outputs[0], skip_special_tokens=True)
    # Original split(...)[1] raised IndexError when the marker was missing
    # from the decoded text; fall back to the full decoded output instead.
    marker = '[/INST]'
    if marker in result:
        return result.split(marker, 1)[1].strip()
    return result.strip()
# Gradio UI: one prompt box, one response box, and two actions —
# "Generate" produces a new answer, "Continue" extends the current one.
with gr.Blocks() as demo:
    prompt = gr.Textbox(label='Prompt')
    response = gr.Textbox(label='Response')
    generate_btn = gr.Button('Generate')
    generate_btn.click(fn=answer, inputs=prompt, outputs=response, api_name='answer')
    expand_btn = gr.Button('Continue')
    # "Continue" needs both the prompt and the existing response as inputs.
    expand_btn.click(fn=expand, inputs=[prompt, response], outputs=response, api_name='continue')

# share=True exposes a public gradio.live URL in addition to localhost.
demo.launch(share=True)