"""Gradio demo: COVID-QA answering with IA3-adapted flan-t5-large models.

Pipeline per request: clean the supplied context, split it into overlapping
sentence chunks, rank the chunks against the question with a cross-encoder,
answer with the QA adapter model, and optionally paraphrase the answer and
compare against the un-adapted base model.
"""
import math
import re

import datasets
import gradio as gr
import nltk
import torch
from gradio.components import Checkbox, Textbox
from nltk import sent_tokenize, word_tokenize
from peft import PeftModel
from sentence_transformers import CrossEncoder
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, T5ForConditionalGeneration

# Sentence-tokenizer data used by sent_tokenize. Newer NLTK releases (3.9+)
# load the "punkt_tab" resource instead of "punkt", so fetch both to work
# across NLTK versions.
nltk.download('punkt')
nltk.download('punkt_tab')

# Load cross encoder.
# Number of top-ranked chunks concatenated into the context fed to the QA model.
top_k = 10
cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')

# Load the base model, the fine-tuned QA model and the paraphrase model.
model_name = "google/flan-t5-large"
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Plain base model, used for the "pretrained" comparison outputs. The QA and
# paraphrase models below each load their own copy of the base weights,
# presumably so that wrapping them with PeftModel leaves this one untouched.
pretrained_model = T5ForConditionalGeneration.from_pretrained(model_name)

peft_name = "legacy107/flan-t5-large-ia3-covidqa"
model = T5ForConditionalGeneration.from_pretrained(model_name)
model = PeftModel.from_pretrained(model, peft_name)

peft_name = "legacy107/flan-t5-large-ia3-bioasq-paraphrase"
paraphrase_model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
paraphrase_model = PeftModel.from_pretrained(paraphrase_model, peft_name)

max_length = 512         # input tokens: prompts are padded/truncated to this length
max_target_length = 200  # upper bound on newly generated tokens per answer

# Load the dataset and keep (up to) 10 random rows as UI examples.
num_examples = 10
dataset = datasets.load_dataset("minh21/COVID-QA-Chunk-64-testset-biencoder-data-90_10", split="train")
dataset = dataset.shuffle()
# min() guards against splits with fewer rows than requested.
dataset = dataset.select(range(min(num_examples, len(dataset))))

# Context chunking parameters.
min_sentences_per_chunk = 3
chunk_size = 64  # a chunk grows only while its word_tokenize count stays below this
window_size = math.ceil(min_sentences_per_chunk * 0.25)  # initial overlap, in sentences
over_lap_chunk_size = chunk_size * 0.25  # overlap keeps growing while <= this many tokens


def chunk_splitter(context):
    """Split *context* into overlapping chunks of whole sentences.

    Each chunk gets at least ``min_sentences_per_chunk`` sentences. After that,
    a sentence is added only while the chunk's word_tokenize count stays below
    ``chunk_size``. When a chunk is closed, the next one starts with the
    longest tail of the previous chunk (at least ``window_size`` sentences)
    whose token count stays within ``over_lap_chunk_size``.

    Returns a list of chunk strings (empty when *context* has no sentences).
    """
    sentences = sent_tokenize(context)
    chunks = []
    current_chunk = []
    for sentence in sentences:
        if len(current_chunk) < min_sentences_per_chunk:
            current_chunk.append(sentence)
            continue
        if len(word_tokenize(' '.join(current_chunk) + " " + sentence)) < chunk_size:
            current_chunk.append(sentence)
            continue

        # The chunk is full: emit it, then seed the next chunk with an overlap tail.
        chunks.append(' '.join(current_chunk))
        new_chunk = current_chunk[-window_size:]
        new_window = window_size
        # The last tail that satisfied the token budget. It defaults to the
        # minimal window if even that exceeds the budget.
        buffer_new_chunk = new_chunk
        while len(word_tokenize(' '.join(new_chunk))) <= over_lap_chunk_size:
            buffer_new_chunk = new_chunk
            new_window += 1
            new_chunk = current_chunk[-new_window:]
            if new_window >= len(current_chunk):
                break
        # Slices are fresh lists, so appending does not touch the emitted chunk.
        current_chunk = buffer_new_chunk
        current_chunk.append(sentence)

    if current_chunk:
        chunks.append(' '.join(current_chunk))
    return chunks


def clean_data(text):
    """Strip boilerplate from a paper's text before chunking.

    Keeps only the part after the first "\\nAbstract: " marker (if present).
    Then removes URLs (including the malformed "http//:" form), DOIs, the
    "(0.xx MB DOC)" attachment notes and "www.<word>[.org]" fragments.
    """
    # Extract abstract content.
    index = text.find("\nAbstract: ")
    if index != -1:
        cleaned_text = text[index + len("\nAbstract: "):]
    else:
        cleaned_text = text  # If "\nAbstract: " is not found, keep the original text

    # Remove both http and https links (also the malformed "http//:" variant).
    cleaned_text = re.sub(r'(http(s|)\/\/:( |)\S+)|(http(s|):\/\/( |)\S+)', '', cleaned_text)
    # Remove DOI patterns like "doi:10.1371/journal.pone.0007211.s003". The
    # match must end on a word character so sentence-final punctuation is kept.
    cleaned_text = re.sub(r'doi:( |)[\w./-]*\w', '', cleaned_text)
    # Remove the "(0.11 MB DOC)" pattern.
    cleaned_text = re.sub(r'\(0\.\d+ MB DOC\)', '', cleaned_text)
    # Remove "www.<word>" fragments, plus a literal ".org" suffix.
    cleaned_text = re.sub(r'www\.\w+(\.org|)', '', cleaned_text)
    return cleaned_text


def _encode(text):
    """Tokenize *text* into PyTorch tensors padded/truncated to ``max_length``."""
    return tokenizer(
        text,
        return_tensors="pt",
        padding="max_length",
        truncation=True,
        max_length=max_length,
    )


def _generate(seq2seq_model, encoded):
    """Run ``generate`` on *encoded* inputs and decode the first sequence to text."""
    with torch.no_grad():
        generated_ids = seq2seq_model.generate(
            input_ids=encoded.input_ids,
            # Inputs are padded to max_length. Pass the mask explicitly so pad
            # positions are ignored instead of relying on generate() to infer it.
            attention_mask=encoded.attention_mask,
            max_new_tokens=max_target_length,
        )
    return tokenizer.decode(generated_ids[0], skip_special_tokens=True)


def paraphrase_answer(question, answer, use_pretrained=False):
    """Rewrite *answer* to *question* as a more natural sentence.

    Uses the paraphrase adapter model, or the plain base model when
    *use_pretrained* is true. Returns the decoded paraphrase.
    """
    # Combine question and answer into the paraphrase prompt.
    # NOTE(review): the wording must match the prompt the paraphrase adapter
    # was fine-tuned on — TODO confirm the exact spacing/punctuation.
    input_text = f"question: {question}. Paraphrase the answer to make it more natural answer: {answer}"
    chosen_model = pretrained_model if use_pretrained else paraphrase_model
    return _generate(chosen_model, _encode(input_text))


def retrieve_context(question, contexts):
    """Rank *contexts* against *question* with the cross-encoder.

    Returns the ``top_k`` best chunks, highest score first, joined with
    spaces and with newlines flattened. Returns "" when *contexts* is empty.
    """
    # CrossEncoder.predict is not guaranteed to accept an empty list (e.g. a
    # blank Context box), so short-circuit.
    if not contexts:
        return ""
    cross_scores = cross_encoder.predict(
        [[question, chunk] for chunk in contexts], show_progress_bar=False
    )
    # sorted() is stable, so chunks with equal scores keep their original order.
    ranked = sorted(range(len(contexts)), key=lambda i: cross_scores[i], reverse=True)
    return " ".join(contexts[i] for i in ranked[:top_k]).replace("\n", " ")


# Define your function to generate answers
def generate_answer(question, context, ground, do_pretrained, do_natural, do_pretrained_natural):
    """Gradio callback: answer *question* from *context*.

    *ground* (the ground-truth answer) is only displayed by the UI and is
    unused here. The three flags enable the optional outputs:

    - *do_natural*: paraphrase the QA answer with the paraphrase adapter.
    - *do_pretrained*: answer the same prompt with the plain base model.
    - *do_pretrained_natural*: paraphrase the QA model's answer with the
      plain base model.

    Returns (answer, retrieved context, natural answer, pretrained answer,
    pretrained natural answer). Disabled outputs are "".
    """
    contexts = chunk_splitter(clean_data(context))
    context = retrieve_context(question, contexts)

    # Combine question and retrieved context; reuse the encoding for both models.
    encoded = _encode(f"question: {question} context: {context}")
    generated_answer = _generate(model, encoded)

    # Paraphrase answer.
    paraphrased_answer = ""
    if do_natural:
        paraphrased_answer = paraphrase_answer(question, generated_answer)

    # Get pretrained model's answer.
    pretrained_answer = ""
    if do_pretrained:
        pretrained_answer = _generate(pretrained_model, encoded)

    # Get pretrained model's natural answer (paraphrase of the QA model's answer).
    pretrained_paraphrased_answer = ""
    if do_pretrained_natural:
        pretrained_paraphrased_answer = paraphrase_answer(question, generated_answer, True)

    return generated_answer, context, paraphrased_answer, pretrained_answer, pretrained_paraphrased_answer


def list_examples():
    """Build Gradio example rows: [question, context, answer, True, True, True]."""
    return [
        [example["question"], example["context"], example["answer"], True, True, True]
        for example in dataset
    ]


# Create a Gradio interface
iface = gr.Interface(
    fn=generate_answer,
    inputs=[
        Textbox(label="Question"),
        Textbox(label="Context"),
        Textbox(label="Ground truth"),
        Checkbox(label="Include pretrained model's result"),
        Checkbox(label="Include natural answer"),
        Checkbox(label="Include pretrained model's natural answer"),
    ],
    outputs=[
        Textbox(label="Generated Answer"),
        Textbox(label="Retrieved Context"),
        Textbox(label="Natural Answer"),
        Textbox(label="Pretrained Model's Answer"),
        Textbox(label="Pretrained Model's Natural Answer"),
    ],
    examples=list_examples(),
    examples_per_page=1,
)

# Launch the Gradio interface
iface.launch()