Hieu-Pham committed on
Commit
b8f45c1
1 Parent(s): 5f7897f

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +55 -0
app.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Notebook shell magics (`!pip ...`) are IPython-only syntax and are a
# SyntaxError in a plain .py file, so the same two installs are issued
# through pip as a subprocess instead.
# NOTE(review): for a hosted app, declaring these packages in
# requirements.txt is probably preferable to installing at startup --
# confirm against the deployment setup.
import subprocess
import sys

for _pip_args in (
    ["-q", "-U", "git+https://github.com/huggingface/transformers.git"],
    ["-q", "gradio"],
):
    # Like the notebook magic, a failed install does not abort the script;
    # a genuinely missing package will surface at the import that needs it.
    subprocess.run([sys.executable, "-m", "pip", "install", *_pip_args])
3
+
4
+ from huggingface_hub import notebook_login
5
+
6
# `notebook_login()` displays an interactive widget and is intended for
# Jupyter/Colab; in a plain script there is nothing to render it.
# Authenticate non-interactively from an access token in the environment.
# NOTE(review): assumes the token is exposed as HF_TOKEN -- confirm the
# variable name used by the deployment. When no token is set the login step
# is skipped, which is only sufficient if the model repository is public.
import os

from huggingface_hub import login

_hf_token = os.environ.get("HF_TOKEN")
if _hf_token:
    login(token=_hf_token)
7
+
8
+ from transformers import pipeline
9
+ from transformers import StoppingCriteria, StoppingCriteriaList
10
+ from transformers import AutoTokenizer
11
+ import torch
12
+
13
# Tokenizer loaded from the same repository as the generation model.
tokenizer = AutoTokenizer.from_pretrained(
    "Hieu-Pham/Llama2-7B-QLoRA-cooking-text-gen-merged"
)

# Token strings whose appearance should end generation.
# NOTE(review): strings that are not single entries in the tokenizer's
# vocabulary presumably map to the unknown-token id -- verify the ids
# returned here are the ones intended.
_stop_tokens = [
    "\n",
    "#",
    "\\",
    "`",
    "###",
    "##",
    "Question",
    "Comment",
    "Answer",
]
stop_token_ids = tokenizer.convert_tokens_to_ids(_stop_tokens)
16
+
17
# Custom stopping criterion driven by the module-level `stop_token_ids`.
class StopOnTokens(StoppingCriteria):
    """Signal a stop once the newest generated token is a stop token."""

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        """Return True when the last token of the first sequence is a stop id.

        Only ``input_ids[0]`` is inspected; ``scores`` and any extra keyword
        arguments are ignored.
        """
        return any(input_ids[0][-1] == stop_id for stop_id in stop_token_ids)
24
+
25
# Container holding the single custom stopping criterion.
stopping_criteria = StoppingCriteriaList([StopOnTokens()])

# Generation options forwarded to the pipeline.
# NOTE(review): temperature / top_p / top_k are sampling settings; they
# presumably only take effect when sampling is enabled (e.g. via the model's
# generation config) -- confirm the decoding mode actually in use.
_generation_options = dict(
    return_full_text=False,
    stopping_criteria=stopping_criteria,
    temperature=0.1,
    top_p=0.15,
    top_k=0,
    max_new_tokens=100,
    repetition_penalty=1.1,
)

# Text-generation pipeline built from the fine-tuned model and the tokenizer
# loaded earlier in this file.
pipe = pipeline(
    task="text-generation",
    model="Hieu-Pham/Llama2-7B-QLoRA-cooking-text-gen-merged",
    tokenizer=tokenizer,
    **_generation_options,
)
39
+
40
+ import gradio as gr
41
+
42
def predict(question, context):
    """Generate an answer for a question and its supporting context.

    Args:
        question: The user's question text.
        context: Background text placed after the question in the prompt.

    Returns:
        The generated text of the first result, with every occurrence of
        the word "Question" removed.
    """
    prompt = f"Question: {question} Context: {context} Answer:"
    predictions = pipe(prompt)
    # The text-generation pipeline returns a list of result dicts (one per
    # generated sequence), so the first entry must be selected before
    # reading "generated_text"; indexing the list with a string raises
    # TypeError.
    return predictions[0]["generated_text"].replace("Question", "")
47
+
48
# Input and output widgets for the web UI.
_question_box = gr.Textbox(lines=2, placeholder="Please provide your question", label="Question")
_context_box = gr.Textbox(lines=2, placeholder="Please provide your context", label="Context")
_answer_box = gr.Textbox(lines=2, placeholder="Predicted Answers...")

# Two text inputs feed `predict`; its return value fills the output box.
demo = gr.Interface(
    fn=predict,
    inputs=[_question_box, _context_box],
    outputs=_answer_box,
    title="Question Answering",
)

# Start the web server for the interface.
demo.launch()