import gradio as gr
import torch
from peft import PeftModel, PeftConfig
from transformers import AutoTokenizer, AutoModelForCausalLM

# Load the base model in bfloat16 and attach the fine-tuned LoRA adapter.
ref_model = AutoModelForCausalLM.from_pretrained("w601sxs/b1ade-1b", torch_dtype=torch.bfloat16)
peft_model_id = "w601sxs/b1ade-1b-orca-chkpt-506k"
config = PeftConfig.from_pretrained(peft_model_id)
model = PeftModel.from_pretrained(ref_model, peft_model_id)

# The tokenizer comes from the adapter's base model.
tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)
model.eval()

def predict(text):
    inputs = tokenizer(text, return_tensors="pt")
    # Generate up to 128 new tokens without tracking gradients.
    with torch.no_grad():
        outputs = model.generate(input_ids=inputs["input_ids"], max_new_tokens=128)
    # Decode, keep only what follows the "answer:" marker, and drop any echo of the prompt.
    out_text = tokenizer.batch_decode(outputs.detach().cpu().numpy(), skip_special_tokens=True)[0].split("answer:")[-1]
    return out_text.split(text)[-1]
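
# Optional local smoke test (a sketch, not part of the original app). The split on
# "answer:" in predict() suggests an Orca-style prompt that ends with "answer:",
# so the example prompt below is an assumption. Uncomment to query the model
# directly instead of going through the Gradio UI.
#
# if __name__ == "__main__":
#     print(predict("question: What is the capital of France?\nanswer:"))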

# Minimal text-in, text-out Gradio UI around predict().
demo = gr.Interface(
    fn=predict,
    inputs="text",
    outputs="text",
)
demo.launch()