Hieu-Pham's picture
Update app.py
24074bc
raw
history blame contribute delete
No virus
1.67 kB
from transformers import pipeline
from transformers import StoppingCriteria, StoppingCriteriaList
from transformers import AutoTokenizer
import torch
# Tokenizer for the merged Llama-2-7B QLoRA cooking model, fetched from the HF Hub.
tokenizer = AutoTokenizer.from_pretrained("Hieu-Pham/Llama2-7B-QLoRA-cooking-text-gen-merged")
# Token ids that should halt generation (newline, markdown markers, prompt-section labels).
# NOTE(review): multi-character strings such as "###", "Question" or "Context" are unlikely
# to be single tokens in the Llama vocabulary, so convert_tokens_to_ids may return the
# unk-token id for them — verify each entry actually maps to a real vocabulary token.
stop_token_ids = tokenizer.convert_tokens_to_ids(["\n", "#", "\\", "`", "###", "##", "Question", "Comment", "Answer", "Context"])
# define custom stopping criteria object
class StopOnTokens(StoppingCriteria):
    """Stopping criterion: halt generation when the last emitted token id is in
    the module-level ``stop_token_ids`` list."""

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        """Return ``True`` when the most recent token of batch row 0 is a stop token.

        Args:
            input_ids: (batch, seq_len) token ids generated so far; only row 0 is
                inspected (the pipeline generates a single sequence at a time).
            scores: per-step logits, unused but required by the interface.

        Returns:
            ``True`` to stop generation, ``False`` to continue.
        """
        # Single membership test replaces the original manual equality loop.
        return input_ids[0, -1].item() in stop_token_ids
# Wrap the custom criterion in the list container expected by `generate`.
stopping_criteria = StoppingCriteriaList([StopOnTokens()])
# Text-generation pipeline over the fine-tuned cooking model.  Sampling is kept
# nearly deterministic: low temperature plus a tight nucleus (top_p=0.15) and
# top_k=0 (no top-k cutoff, nucleus sampling only).
pipe = pipeline(
task="text-generation",
model="Hieu-Pham/Llama2-7B-QLoRA-cooking-text-gen-merged",
tokenizer=tokenizer,
return_full_text=False,  # return only the newly generated continuation, not the prompt
stopping_criteria=stopping_criteria,
temperature=0.1,
top_p=0.15,
top_k=0,
max_new_tokens=100,  # cap answer length
repetition_penalty=1.1  # mildly discourage the model from looping
)
import gradio as gr
def predict(question, context, generator=None):
    """Generate an answer to *question* grounded in *context*.

    Args:
        question: the user's question text.
        context: supporting context the answer should be drawn from.
        generator: optional callable with the pipeline's signature
            (prompt -> list of {"generated_text": str}); defaults to the
            module-level ``pipe``.  Injectable for testing.

    Returns:
        The generated answer with any literal "Question"/"Answer" labels removed.
    """
    text_gen = pipe if generator is None else generator
    # `prompt` instead of the original `input`, which shadowed the builtin.
    # The format mirrors the fine-tuning template.
    prompt = f"Question: {question} Context: {context} Answer:"
    predictions = text_gen(prompt)
    # The model sometimes echoes the section labels; strip them from the output.
    output = predictions[0]["generated_text"].replace("Question", "")
    output = output.replace("Answer", "")
    return output
# Build the Gradio UI: two multi-line text inputs (question + context) feeding
# `predict`, one text output for the generated answer.
question_box = gr.Textbox(lines=2, placeholder="Please provide your question", label="Question")
context_box = gr.Textbox(lines=2, placeholder="Please provide your context", label="Context")
answer_box = gr.Textbox(lines=2, placeholder="Answers...", label="Output")

demo = gr.Interface(
    predict,
    inputs=[question_box, context_box],
    outputs=answer_box,
    title="Cooking Recipe MRC",
)
demo.launch()