import os

import gradio as gr
from transformers import T5ForConditionalGeneration, T5Tokenizer

# Question-generation model: given a context passage and an answer span,
# generate candidate questions whose answer is that span.
question_model = T5ForConditionalGeneration.from_pretrained('t5-small')
question_tokenizer = T5Tokenizer.from_pretrained('t5-small')


def get_question(context: str, answer: str) -> str:
    """Generate up to three candidate questions for *answer* given *context*.

    Builds a "context: ... answer: ..." prompt, beam-searches three
    candidate questions, and returns them as a single string joined by
    the "ยง" delimiter.

    Args:
        context: Passage of text the question should be answerable from.
        answer: The answer span the generated question should target.

    Returns:
        The candidate questions joined by "ยง".
    """
    text = "context: {} answer: {}".format(context, answer)
    # truncation=True keeps the prompt within the model's maximum input
    # length instead of overflowing on long contexts.
    encoding = question_tokenizer.encode_plus(
        text, truncation=True, return_tensors="pt"
    )
    outs = question_model.generate(
        input_ids=encoding["input_ids"],
        attention_mask=encoding["attention_mask"],
        early_stopping=True,
        num_beams=3,               # fewer beams: fewer but higher-quality questions
        num_return_sequences=3,
        no_repeat_ngram_size=3,    # allow some repetition to avoid nonsensical output
        max_length=256,            # shorter max length keeps questions focused
    )
    # skip_special_tokens=True strips <pad>/</s> markers from the decoded
    # text; the original code attempted this with no-op replace("", "")
    # calls, which left the special tokens in the output.
    decoded = [
        question_tokenizer.decode(ids, skip_special_tokens=True) for ids in outs
    ]
    questions = [q.replace("question:", "").strip() for q in decoded]
    # NOTE(review): "ยง" looks like a mojibake of "§"; kept byte-identical
    # because downstream consumers may split on exactly this delimiter —
    # confirm before changing.
    return "ยง".join(questions)


input_context = gr.Textbox()
input_answer = gr.Textbox()
output_question = gr.Textbox()

interface = gr.Interface(
    fn=get_question,
    inputs=[input_context, input_answer],
    outputs=output_question,
)
interface.launch()