p208p2002 committed on
Commit
b947701
β€’
1 Parent(s): b1e857a
Files changed (4) hide show
  1. .gitignore +3 -0
  2. README.md +2 -1
  3. app.py +92 -0
  4. requirements.txt +6 -0
.gitignore ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ pyproject.toml
2
+ .venv/
3
+ flagged/
README.md CHANGED
@@ -1,12 +1,13 @@
1
  ---
2
  title: Question Group Generator
3
- emoji: 🏒
4
  colorFrom: blue
5
  colorTo: yellow
6
  sdk: gradio
7
  sdk_version: 3.4
8
  app_file: app.py
9
  pinned: false
 
10
  ---
11
 
12
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
  ---
2
  title: Question Group Generator
3
+ emoji: πŸ§‘β€πŸ«
4
  colorFrom: blue
5
  colorTo: yellow
6
  sdk: gradio
7
  sdk_version: 3.4
8
  app_file: app.py
9
  pinned: false
10
+ python_version: 3.8.9
11
  ---
12
 
13
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import gradio as gr
from transformers import BartTokenizerFast, BartForConditionalGeneration
import torch
import re
from qgg_utils.optim import GAOptimizer # https://github.com/p208p2002/qgg-utils.git

# Maximum number of tokens fed to the BART encoder in a single pass; also
# used as the stride base when chunking long contexts.
MAX_LENGTH=512

# Example passage pre-filled in the UI so visitors can try the demo instantly.
default_context = "Facebook is an online social media and social networking service owned by American company Meta Platforms. Founded in 2004 by Mark Zuckerberg with fellow Harvard College students and roommates Eduardo Saverin, Andrew McCollum, Dustin Moskovitz, and Chris Hughes, its name comes from the face book directories often given to American university students. Membership was initially limited to Harvard students, gradually expanding to other North American universities and, since 2006, anyone over 13 years old. As of July 2022, Facebook claimed 2.93 billion monthly active users,[6] and ranked third worldwide among the most visited websites as of July 2022. It was the most downloaded mobile app of the 2010s."

# Question-group-generation model and its tokenizer, downloaded from the Hub
# at startup (module import time).
model=BartForConditionalGeneration.from_pretrained("p208p2002/qmst-qgg")
tokenizer=BartTokenizerFast.from_pretrained("p208p2002/qmst-qgg")
14
def feedback_generation(model, tokenizer, input_ids, feedback_times = 3):
    """Sample `feedback_times` questions for one context chunk, with feedback.

    Before each round, one BOS token per question generated so far (plus one)
    is prepended to `input_ids` as the feedback signal, and the sequence is
    clipped to MAX_LENGTH tokens. Returns the decoded question strings.
    """
    questions = []
    device = 'cpu'  # Spaces CPU runtime; no GPU assumed
    for _ in range(feedback_times):
        # Feedback marker: (number of questions so far + 1) BOS tokens.
        marker_text = tokenizer.bos_token * (len(questions) + 1)
        marker_ids = tokenizer(marker_text, add_special_tokens=False)['input_ids']
        input_ids = (marker_ids + input_ids)[:MAX_LENGTH]

        batch = torch.LongTensor(input_ids).unsqueeze(0).to(device)
        generated = model.generate(
            input_ids=batch,
            attention_mask=torch.ones_like(batch),
            max_length=50,
            early_stopping=True,
            temperature=1.0,
            do_sample=True,
            top_p=0.9,
            top_k=10,
            num_beams=1,
            no_repeat_ngram_size=5,
            num_return_sequences=1,
        )

        # Decode, then scrub special tokens and the "[Q:]" prompt marker.
        text = tokenizer.decode(generated[0], skip_special_tokens=False)
        for special in (tokenizer.pad_token, tokenizer.eos_token):
            text = re.sub(re.escape(special), '', text)
        if tokenizer.bos_token is not None:
            text = re.sub(re.escape(tokenizer.bos_token), '', text)
        questions.append(text.strip().replace("[Q:]", ""))
    return questions
46
+
47
def gen_quesion_group(context, question_group_size):
    """Generate a group of questions for `context` and return them as text.

    Oversamples candidates (2x the requested group size per chunk), then uses
    a genetic-algorithm optimizer to select the final group. Returns one
    " - question" bullet per line.
    """
    group_size = int(question_group_size)
    print(context, group_size)
    pool_size = group_size * 2

    # Chunk long contexts into overlapping MAX_LENGTH windows (30% stride).
    encoded = tokenizer.batch_encode_plus(
        [context],
        stride=MAX_LENGTH - int(MAX_LENGTH * 0.7),
        max_length=MAX_LENGTH,
        truncation=True,
        add_special_tokens=False,
        return_overflowing_tokens=True,
        return_length=True,
    )

    # Cap the number of chunks to bound runtime on very long inputs.
    chunks = encoded.input_ids[:10] if len(encoded.input_ids) >= 10 else encoded.input_ids

    candidates = []
    for chunk_ids in chunks:
        candidates.extend(
            feedback_generation(
                model=model,
                tokenizer=tokenizer,
                input_ids=chunk_ids,
                feedback_times=pool_size,
            )
        )

    # GA selection: repeatedly shrink the candidate pool to the target size.
    while len(candidates) > group_size:
        optimizer = GAOptimizer(len(candidates), group_size)
        candidates = optimizer.optimize(candidates, context)

    # Render as a bulleted list.
    return '\n'.join(f" - {q}" for q in candidates)
80
+
81
# Gradio UI: a context textbox plus a group-size slider in, the generated
# question list out.
context_box = gr.Textbox(
    lines=10,
    value=default_context,
    label="Context",
    placeholder="Paste some context here",
)
size_slider = gr.Slider(3, 8, step=1, label="Group Size")
result_box = gr.Textbox(lines=8, label="Generation Question Group")

demo = gr.Interface(
    fn=gen_quesion_group,
    inputs=[context_box, size_slider],
    outputs=result_box,
)
demo.launch()
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ gradio==3.4
2
+ torch==1.12.1
3
+ transformers==4.22.2
4
+ git+https://github.com/p208p2002/qgg-utils.git
5
+ git+https://github.com/voidful/nlg-eval.git@master
6
+ stanza