salti committed on
Commit
1ef9f8e
1 Parent(s): 8961013
Files changed (2) hide show
  1. .gitignore +1 -0
  2. app.py +60 -6
.gitignore ADDED
@@ -0,0 +1 @@
 
1
+ .mypy_cache
app.py CHANGED
@@ -3,17 +3,71 @@ import torch
3
  from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
4
 
5
 
6
# Load the fine-tuned Arabic T5 paraphrasing checkpoint once at import time.
# The fast (Rust) tokenizer is requested explicitly; eval() disables dropout
# since the model is used for inference only.
tokenizer = AutoTokenizer.from_pretrained(
    "salti/arabic-t5-small-question-paraphrasing", use_fast=True
)
model = AutoModelForSeq2SeqLM.from_pretrained(
    "salti/arabic-t5-small-question-paraphrasing"
).eval()  # fixed: dropped the stray trailing semicolon

# Task prefix prepended to every input ("rephrase: " in Arabic).
prompt = "أعد صياغة: "
9
 
 
10
@torch.inference_mode()
def paraphrase(question):
    """Return one paraphrase of an Arabic *question* from the seq2seq model.

    The task prefix is prepended before tokenization; generation runs with
    the model's default decoding settings (greedy search).
    """
    prompted = prompt + question
    token_ids = tokenizer(prompted, return_tensors="pt").input_ids
    output_ids = model.generate(token_ids)
    # Drop the batch dimension and move to host memory before decoding.
    token_array = output_ids.squeeze().cpu().numpy()
    return tokenizer.decode(token_array, skip_special_tokens=True)
16
 
17
# Minimal Gradio demo: a single text box in, a single text box out.
iface = gr.Interface(fn=paraphrase, inputs="text", outputs="text")

iface.launch()
3
  from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
4
 
5
 
6
# Checkpoint shared by the tokenizer and the model.
MODEL_NAME = "salti/arabic-t5-small-question-paraphrasing"

# Fast (Rust) tokenizer; eval() disables dropout for inference-only use.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True)
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME).eval()

# Task prefix prepended to every input ("rephrase: " in Arabic).
prompt = "أعد صياغة: "
15
 
16
+
17
@torch.inference_mode()
def paraphrase(question, num_beams, encoder_no_repeat_ngram_size):
    """Return one paraphrase of an Arabic *question* via beam search.

    num_beams: beam width used during generation.
    encoder_no_repeat_ngram_size: n-grams of this size from the input are
        forbidden in the output, forcing more diverse rewordings.
    """
    prompted = prompt + question
    token_ids = tokenizer(prompted, return_tensors="pt").input_ids
    output_ids = model.generate(
        token_ids,
        num_beams=num_beams,
        encoder_no_repeat_ngram_size=encoder_no_repeat_ngram_size,
    )
    # Drop the batch dimension and move to host memory before decoding.
    token_array = output_ids.squeeze().cpu().numpy()
    return tokenizer.decode(token_array, skip_special_tokens=True)
32
 
33
+
34
# --- Demo UI -----------------------------------------------------------
# Input widgets: the question text box plus two decoding knobs.
question = gr.inputs.Textbox(label="اكتب سؤالاً باللغة العربية")
num_beams = gr.inputs.Slider(1, 10, step=1, default=1, label="Beam size")
encoder_no_repeat_ngram_size = gr.inputs.Slider(
    0,
    10,
    step=1,
    default=3,
    label="Ngrams of this size won't be copied from the input (forces more diverse outputs)",
)

outputs = gr.outputs.Textbox(label="السؤال بصيغة مختلفة")

# Each example row is [question, beam size, no-repeat ngram size].
examples = [
    ["متى تم اختراع الكتابة؟", 5, 3],
    ["ما عدد حروف اللغة العربية؟", 5, 3],
    ["ما هو الذكاء الصنعي؟", 5, 3],
]

iface = gr.Interface(
    fn=paraphrase,
    inputs=[question, num_beams, encoder_no_repeat_ngram_size],
    outputs=outputs,
    examples=examples,
    title="Arabic question paraphrasing",
    theme="huggingface",
)

iface.launch()