import torch from transformers import AutoTokenizer, AutoModelForSeq2SeqLM import gradio as gr from sklearn.ensemble import RandomForestClassifier from sklearn.feature_extraction.text import TfidfVectorizer import pickle vectorizer = pickle.load(open("tfidf.pickle", "rb")) # clf = pickle.load(open("classifier.pickle", "rb")) example_context = "ফলস্বরূপ, ১৯৭৯ সালে, সনি এবং ফিলিপস একটি নতুন ডিজিটাল অডিও ডিস্ক ডিজাইন করার জন্য প্রকৌশলীদের একটি যৌথ টাস্ক ফোর্স গঠন করে। ইঞ্জিনিয়ার কিস শুহামার ইমমিনক এবং তোশিতাদা দোই এর নেতৃত্বে, গবেষণাটি লেজার এবং অপটিক্যাল ডিস্ক প্রযুক্তিকে এগিয়ে নিয়ে যায়। এক বছর পরীক্ষা-নিরীক্ষা ও আলোচনার পর টাস্ক ফোর্স রেড বুক সিডি-ডিএ স্ট্যান্ডার্ড তৈরি করে। প্রথম প্রকাশিত হয় ১৯৮০ সালে। আইইসি কর্তৃক ১৯৮৭ সালে আন্তর্জাতিক মান হিসেবে আনুষ্ঠানিকভাবে এই মান গৃহীত হয় এবং ১৯৯৬ সালে বিভিন্ন সংশোধনী মানের অংশ হয়ে ওঠে।'" example_answer = "১৯৮০" def choose_model(model_choice): if model_choice=="mt5-small": return "jannatul17/squad-bn-qgen-mt5-small-v1" elif model_choice=="mt5-base": return "Tahsin-Mayeesha/squad-bn-mt5-base2" else : return "jannatul17/squad-bn-qgen-banglat5-v1" def generate_questions(model_choice,context,answer,numReturnSequences=1,num_beams=None,do_sample=False,top_p=None,top_k=None,temperature=None): model_name = choose_model(model_choice) model = AutoModelForSeq2SeqLM.from_pretrained(model_name) tokenizer = AutoTokenizer.from_pretrained(model_name) text='answer: '+answer + ' context: ' + context text_encoding = tokenizer.encode_plus( text,return_tensors="pt" ) model.eval() generated_ids = model.generate( input_ids=text_encoding['input_ids'], attention_mask=text_encoding['attention_mask'], max_length=120, num_beams=num_beams, do_sample=do_sample, top_k = top_k, top_p = top_p, temperature = temperature, num_return_sequences=numReturnSequences ) text = [] for id in generated_ids: text.append(tokenizer.decode(id,skip_special_tokens=True,clean_up_tokenization_spaces=True).replace('question: ',' ')) question = " ".join(text) #correctness_pred = clf.predict(vectorizer.transform([question]))[0] #if correctness_pred == 1: # correctness = "Correct" #else : # correctness = "Incorrect" return question demo = gr.Interface(fn=generate_questions, inputs=[gr.Dropdown(label="Model", choices=["mt5-small","mt5-base","banglat5"],value="banglat5"), gr.Textbox(label='Context'), gr.Textbox(label='Answer'), # hyperparameters gr.Slider(1, 3, 1, step=1, label="Num return Sequences"), # beam search gr.Slider(1, 10,value=None, step=1, label="Beam width"), # top-k/top-p gr.Checkbox(label="Do Random Sample",value=False), gr.Slider(0, 50, value=None, step=1, label="Top K"), gr.Slider(0, 1, value=None, label="Top P/Nucleus Sampling"), gr.Slider(0, 1, value=None, label="Temperature") ] , # output outputs=[gr.Textbox(label='Question')], examples=[["banglat5",example_context,example_answer]], cache_examples=False, title="Bangla Question Generation") demo.launch()