"""Gradio demo that generates a group of questions for a context passage.

The context is tokenized into overlapping windows, several candidate
questions are sampled from a BART model for each window, and a genetic
optimizer (``GAOptimizer``) narrows the candidates down to the requested
group size.
"""
import re

import gradio as gr
import torch
from qgg_utils.optim import GAOptimizer  # https://github.com/p208p2002/qgg-utils.git
from transformers import BartForConditionalGeneration, BartTokenizerFast

# Maximum number of token ids fed to the model per window.
MAX_LENGTH = 512
# Upper bound on how many context windows are used for generation.
_MAX_CHUNKS = 10

default_context = "Facebook is an online social media and social networking service owned by American company Meta Platforms. Founded in 2004 by Mark Zuckerberg with fellow Harvard College students and roommates Eduardo Saverin, Andrew McCollum, Dustin Moskovitz, and Chris Hughes, its name comes from the face book directories often given to American university students. Membership was initially limited to Harvard students, gradually expanding to other North American universities and, since 2006, anyone over 13 years old. As of July 2022, Facebook claimed 2.93 billion monthly active users,[6] and ranked third worldwide among the most visited websites as of July 2022. It was the most downloaded mobile app of the 2010s."

model = BartForConditionalGeneration.from_pretrained("p208p2002/qmst-qgg")
tokenizer = BartTokenizerFast.from_pretrained("p208p2002/qmst-qgg")


def feedback_generation(model, tokenizer, input_ids, feedback_times=3):
    """Sample ``feedback_times`` questions for one window of context token ids.

    Before every round, a run of BOS tokens is prepended to the prompt; the
    run length is the number of questions generated so far plus one. The
    prompt is then truncated to ``MAX_LENGTH`` ids and one sequence is
    sampled from ``model.generate``.

    Args:
        model: Generation model exposing ``generate``.
        tokenizer: Tokenizer matching ``model``; must define ``bos_token``.
        input_ids: Token ids (a list of ints) of the context window. The
            caller's list is not modified.
        feedback_times: Number of questions to sample.

    Returns:
        A list of ``feedback_times`` decoded question strings with the
        pad/EOS/BOS tokens and the ``[Q:]`` marker removed.

    Raises:
        ValueError: If ``tokenizer.bos_token`` is ``None``.
    """
    bos_token = tokenizer.bos_token
    if bos_token is None:
        # The original checked for None only after already multiplying the
        # token, which would have failed with an opaque TypeError.
        raise ValueError(
            "feedback_generation requires a tokenizer with a bos_token"
        )

    # Follow the model's device when it exposes one; fall back to CPU.
    device = getattr(model, "device", "cpu")

    # Special tokens to strip from decoded text; unset ones are skipped.
    removable_tokens = [
        token
        for token in (tokenizer.pad_token, tokenizer.eos_token, bos_token)
        if token
    ]

    prompt_ids = list(input_ids)
    questions = []
    for _ in range(feedback_times):
        marker_text = bos_token * (len(questions) + 1)
        marker_ids = tokenizer(marker_text, add_special_tokens=False)["input_ids"]
        # NOTE(review): the markers are prepended to the already-prefixed
        # prompt, so they accumulate across rounds (1, 3, 6, ... BOS tokens).
        # This matches the original behaviour -- confirm it is what the model
        # was trained with.
        prompt_ids = (marker_ids + prompt_ids)[:MAX_LENGTH]

        sample_outputs = model.generate(
            input_ids=torch.LongTensor(prompt_ids).unsqueeze(0).to(device),
            attention_mask=torch.LongTensor([1] * len(prompt_ids))
            .unsqueeze(0)
            .to(device),
            max_length=50,
            early_stopping=True,
            temperature=1.0,
            do_sample=True,
            top_p=0.9,
            top_k=10,
            num_beams=1,
            no_repeat_ngram_size=5,
            num_return_sequences=1,
        )

        question = tokenizer.decode(sample_outputs[0], skip_special_tokens=False)
        for token in removable_tokens:
            question = question.replace(token, "")
        # Strip last so whitespace next to the removed "[Q:]" marker does not
        # leak into the result.
        questions.append(question.replace("[Q:]", "").strip())
    return questions


def gen_quesion_group(context, question_group_size):
    """Generate a bulleted group of questions for ``context``.

    The context is split into overlapping windows of at most ``MAX_LENGTH``
    tokens (only the first ``_MAX_CHUNKS`` windows are used). For each
    window, ``2 * question_group_size`` candidates are sampled, and
    ``GAOptimizer`` is applied until no more than ``question_group_size``
    candidates remain.

    Args:
        context: Passage to generate questions about.
        question_group_size: Desired number of questions; coerced to ``int``.

    Returns:
        The selected questions, one per line, each prefixed with ``" - "``.
        An empty string is returned when ``context`` is empty or blank.
    """
    question_group_size = int(question_group_size)
    print(context, question_group_size)

    # Nothing to ask about: skip the expensive generation entirely.
    if not context or not context.strip():
        return ""

    candidate_pool_size = question_group_size * 2

    encoded = tokenizer.batch_encode_plus(
        [context],
        stride=MAX_LENGTH - int(MAX_LENGTH * 0.7),
        max_length=MAX_LENGTH,
        truncation=True,
        add_special_tokens=False,
        return_overflowing_tokens=True,
        return_length=True,
    )
    # Cap the work on long inputs without mutating the tokenizer output.
    windows = encoded.input_ids[:_MAX_CHUNKS]

    candidate_questions = []
    for window_ids in windows:
        candidate_questions.extend(
            feedback_generation(
                model=model,
                tokenizer=tokenizer,
                input_ids=window_ids,
                feedback_times=candidate_pool_size,
            )
        )

    # NOTE(review): assumes optimize() returns fewer candidates than it is
    # given; otherwise this loop would not terminate -- confirm against
    # qgg_utils.
    while len(candidate_questions) > question_group_size:
        qgg_optim = GAOptimizer(len(candidate_questions), question_group_size)
        candidate_questions = qgg_optim.optimize(candidate_questions, context)

    return "\n".join(f" - {q}" for q in candidate_questions)


demo = gr.Interface(
    fn=gen_quesion_group,
    inputs=[
        gr.Textbox(
            lines=10,
            value=default_context,
            label="Context",
            placeholder="Paste some context here",
        ),
        gr.Slider(3, 8, step=1, label="Group Size"),
    ],
    outputs=gr.Textbox(
        lines=8,
        label="Generation Question Group",
    ),
)

demo.launch()