"""Gradio demo that ranks pre-computed responses against a typed query."""

import gzip
import os
import pickle
from timeit import default_timer as timer
from typing import Dict, Optional, Tuple

import gradio as gr
import torch
from torch.nn.functional import cosine_similarity

from model import create_semantic_ranking_model

# Maximum number of ranked responses shown to the user.
TOP_K = 5


def _read_lines(path: str) -> list:
    """Return every line of the text file at *path*, stripped of whitespace."""
    with open(path, "r", encoding="utf-8") as file:
        return [line.strip() for line in file]


### Load example texts ###
questions_texts = _read_lines("questions_texts.txt")
answers_texts = _read_lines("answers_texts.txt")

# Map each example question to the answer on the same line number. setdefault
# keeps the first occurrence, matching a first-match linear scan.
_answer_by_question: Dict[str, str] = {}
for _question, _answer in zip(questions_texts, answers_texts):
    _answer_by_question.setdefault(_question, _answer)

### Model and transforms preparation ###
# Create model and tokenizer
model, tokenizer = create_semantic_ranking_model()

# Load saved weights
model.load_state_dict(
    torch.load(f="all-MiniLM-L6-v2.pth", map_location=torch.device("cpu"))  # load to CPU
)

# Inference only: switch to eval mode once instead of on every request.
model.eval()

# NOTE(security): pickle.load executes arbitrary code embedded in the file.
# Only ship artifacts produced by a trusted pipeline next to this app.
# Load the embeddings
with gzip.open("response_embeddings.pkl.gz", "rb") as f:
    response_embeddings = pickle.load(f)

# Load the response list
with gzip.open("response_list.pkl.gz", "rb") as f:
    response_list = pickle.load(f)


### Predict function ###
def predict(text: str) -> Tuple[Dict, Optional[str], float]:
    """Rank the stored responses against *text* by cosine similarity.

    Args:
        text: A single query string typed by the user.

    Returns:
        A 3-tuple of:
        - ``{"Top Responses": [...]}`` with up to ``TOP_K`` entries of
          ``response_list``, best match first;
        - the answer paired with *text* in the example files when the
          stripped query equals one of the example questions, else ``None``;
        - the elapsed wall-clock time in seconds, rounded to 4 decimals.
    """
    # Start a timer
    start_time = timer()

    # Set up the inputs
    tokenized_inputs = tokenizer(
        text,
        return_tensors="pt",
        max_length=128,
        truncation=True,
        padding="max_length",
    )

    # Get input embeddings
    with torch.inference_mode():
        input_embeddings = model(**tokenized_inputs)

    # Compute similarity scores.
    # assumes input_embeddings is (1, dim) and response_embeddings is
    # (num_responses, dim), giving scores of shape (1, num_responses)
    # — TODO confirm against create_semantic_ranking_model.
    similarity_scores = cosine_similarity(
        input_embeddings.unsqueeze(1),
        response_embeddings.unsqueeze(0),
        dim=2,
    )

    # topk raises when k exceeds the number of candidates, so clamp it.
    k = min(TOP_K, similarity_scores.shape[1])
    # Take the single query row explicitly; a bare squeeze() would also
    # collapse the k dimension when k == 1 and break iteration.
    top_indices = torch.topk(similarity_scores, k=k, dim=1).indices[0].tolist()

    # Retrieve the actual response texts
    top_responses = [response_list[idx] for idx in top_indices]

    # Get actual response (None when the query is not an example question)
    actual_response = _answer_by_question.get(text.strip())

    # Calculate pred time
    pred_time = round(timer() - start_time, 4)

    # Return pred dict, actual response and pred time
    return {"Top Responses": top_responses}, actual_response, pred_time


### Gradio app ###
# Create title, description and article
title = "Semantic Ranking with MiniLM-L6-v2"
description = "[A MiniLM-L6-H384-uncased MiniLM based model](https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2) sentence embedding model trained to rank results from [HuggingFace 🤗 Hello-SimpleAI/HC3](https://huggingface.co/datasets/Hello-SimpleAI/HC3). [Source Code Found Here](https://colab.research.google.com/drive/1o5a9zH1TxzaxLKV5AFUhZE8L8yMnO9Jw?usp=sharing)"
article = "Built with [Gradio](https://github.com/gradio-app/gradio) and [PyTorch](https://pytorch.org/). [Source Code Found Here](https://colab.research.google.com/drive/1o5a9zH1TxzaxLKV5AFUhZE8L8yMnO9Jw?usp=sharing)"

# Create the Gradio demo; the three outputs line up with predict's 3-tuple.
demo = gr.Interface(
    fn=predict,
    inputs=gr.Textbox(lines=2, placeholder="Type your text here..."),
    outputs=[
        gr.JSON(label="Top Responses"),
        gr.Textbox(label="Actual Response"),
        gr.Number(label="Prediction time (s)"),
    ],
    examples=questions_texts,
    title=title,
    description=description,
    article=article,
)

# Launch the demo
demo.launch()