"""Co-Search: a Streamlit front end for multi-stage question answering.

A query is embedded with a sentence-transformer and narrowed in three stages
(``fetch_stage1/2/3`` from ``utils``). The last stage runs an extractive QA
model over the candidate article sentences.
"""
import datetime  # noqa: F401  (not referenced in this module; kept for compatibility)
import glob  # noqa: F401
import itertools
import json  # noqa: F401
import pickle
import subprocess
import sys
import time

import nltk
import numpy as np  # noqa: F401
import pandas as pd  # noqa: F401
import scipy.spatial  # noqa: F401
import spacy
import streamlit as st
import torch
from nltk.tokenize import sent_tokenize
from sentence_transformers import SentenceTransformer, models
from transformers import (  # noqa: F401  (AutoModel not referenced here)
    AutoModel,
    AutoModelForQuestionAnswering,
    AutoTokenizer,
    pipeline,
)

try:
    # pandas >= 1.0 exposes json_normalize at top level; the
    # pandas.io.json location is deprecated and may be missing in newer pandas.
    from pandas import json_normalize  # noqa: F401
except ImportError:  # very old pandas
    from pandas.io.json import json_normalize  # noqa: F401

# NOTE(review): this star import supplies fetch_stage1/2/3 used in main();
# prefer importing those names explicitly once utils' API is confirmed.
from utils import *  # noqa: F401,F403

SPACY_MODEL_NAME = "en_core_web_sm"
QA_MODEL_NAME = "graviraja/covidbert_squad"
PREP_DATA_PATH = "listfile_3.data"
EMBEDDED_ARTICLES_PATH = "list_of_articles.pkl"
# The original display loop wrote answers until its counter exceeded 3,
# i.e. at most four answers; kept unchanged.
MAX_RESULTS_SHOWN = 4


@st.cache(allow_output_mutation=True)
def _download_nltk_data():
    """Fetch the NLTK sentence-tokenizer data once per server process.

    Cached so Streamlit reruns do not contact the NLTK index every time.
    'punkt_tab' is what ``sent_tokenize`` needs on newer NLTK releases;
    'punkt' covers older ones. ``nltk.download`` reports an unknown resource
    instead of raising, so requesting both is harmless.
    """
    nltk.download("punkt", quiet=True)
    nltk.download("punkt_tab", quiet=True)


@st.cache(allow_output_mutation=True)
def load_spacy_model():
    """Return the spaCy English pipeline, downloading it on first use.

    The download runs through ``sys.executable`` so the package is installed
    into the interpreter that will import it. A bare ``python`` on PATH may be
    a different environment.

    Raises:
        subprocess.CalledProcessError: if the model download fails.
    """
    try:
        return spacy.load(SPACY_MODEL_NAME)
    except OSError:
        # spacy.load raises OSError when the model package is not installed.
        subprocess.run(
            [sys.executable, "-m", "spacy", "download", SPACY_MODEL_NAME],
            check=True,
        )
        return spacy.load(SPACY_MODEL_NAME)


@st.cache(allow_output_mutation=True)
def load_prep_data():
    """Load the prepared article list and sentence-split each article's text.

    Assumes each entry is a mutable sequence: ``main`` shows ``entry[0]`` as
    the article title, and ``entry[1]`` is replaced in place by its sentence
    list. Entries whose ``[1]`` is the ``[]`` placeholder are left untouched.
    (Schema inferred from usage in this file; confirm against the data
    producer.)
    """
    # SECURITY: pickle.load can execute arbitrary code; only load trusted,
    # locally produced files.
    with open(PREP_DATA_PATH, "rb") as filehandle:
        articles = pickle.load(filehandle)
    for entry in articles:
        if entry[1] != []:
            entry[1] = sent_tokenize(entry[1])
    return articles


@st.cache(allow_output_mutation=True)
def build_sent_trans_model():
    """Build a sentence-transformer from the BERT weights in the CWD, mean-pooled."""
    # NOTE(review): models.BERT is deprecated in newer sentence-transformers
    # (models.Transformer replaces it); kept to match the installed version.
    word_embedding_model = models.BERT('./')
    # Mean pooling over token embeddings (no CLS / max pooling).
    pooling_model = models.Pooling(
        word_embedding_model.get_word_embedding_dimension(),
        pooling_mode_mean_tokens=True,
        pooling_mode_cls_token=False,
        pooling_mode_max_tokens=False,
    )
    return SentenceTransformer(modules=[word_embedding_model, pooling_model])


@st.cache(allow_output_mutation=True)
def load_embedded_articles():
    """Load the pickled list consumed by ``fetch_stage1``."""
    # SECURITY: pickle.load on a trusted local file only.
    with open(EMBEDDED_ARTICLES_PATH, 'rb') as f:
        return pickle.load(f)


@st.cache(allow_output_mutation=True)
def load_comprehension_model():
    """Return an extractive question-answering pipeline.

    ``transformers.pipeline`` treats ``device=-1`` as CPU and ``device>=0`` as
    a CUDA ordinal. The GPU is used when one is available, matching the
    device choice ``main`` makes for the sentence model.
    """
    device = 0 if torch.cuda.is_available() else -1
    return pipeline(
        "question-answering",
        model=AutoModelForQuestionAnswering.from_pretrained(QA_MODEL_NAME),
        tokenizer=AutoTokenizer.from_pretrained(QA_MODEL_NAME),
        device=device,
    )


def main():
    """Render the Co-Search page: read a query, run the 3 stages, show answers."""
    _download_nltk_data()  # must precede load_prep_data (uses sent_tokenize)
    spacy_nlp = load_spacy_model()
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    articles = load_prep_data()
    model = build_sent_trans_model()
    model.to(device)
    list_of_articles = load_embedded_articles()
    comprehension_model = load_comprehension_model()

    st.title('Co-Search')
    query = st.text_input("Enter Query", 'What are the corona viruses?', key="query")
    st.write('Using device type: {}'.format(device))

    if not query.strip():
        # Avoid running the full retrieval pipeline on blank input.
        st.info("Please enter a query.")
        return

    with st.spinner('Please wait...'):
        # perf_counter is monotonic; timedelta.seconds would drop whole days.
        start = time.perf_counter()
        query_embedding, results1 = fetch_stage1(query, model, list_of_articles)
        results2 = fetch_stage2(results1, model, articles, query_embedding)
        results3 = fetch_stage3(results2, query, articles, comprehension_model, spacy_nlp)
        elapsed_minutes = (time.perf_counter() - start) / 60

    st.write('Time taken in minutes: %.2f' % elapsed_minutes)
    if results3:
        # Each result is indexed as (article index, score, answer text).
        shown = itertools.islice(results3, MAX_RESULTS_SHOWN)
        for rank, res in enumerate(shown, start=1):
            st.write('{}> {}'.format(rank, res[2]))
            st.write('Score: %.4f' % (res[1]))
            st.write("From the article with title:\n{}".format(articles[res[0]][0]))
            st.write("\n")
    else:
        st.info("There isn't any answer")
    st.success('Done!')


if __name__ == '__main__':
    main()