COVID_NLI / app.py
hitz02's picture
Update app.py
124530a
import pandas as pd
import numpy as np
import datetime, time
import pickle
import glob
import json
from pandas.io.json import json_normalize
from nltk.tokenize import sent_tokenize
import nltk
import scipy.spatial
from transformers import AutoTokenizer, AutoModel, pipeline, AutoModelForQuestionAnswering
from sentence_transformers import models, SentenceTransformer
import torch
import spacy
import subprocess
import streamlit as st
from utils import *
@st.cache(allow_output_mutation=True)
def load_spacy_model():
    """Download the spaCy ``en_core_web_sm`` model package (best effort).

    Despite the name, this function neither loads nor returns a model: it
    only runs ``<interpreter> -m spacy download en_core_web_sm`` in a child
    process.  Callers are expected to ``spacy.load`` the package afterwards.

    Returns:
        None.
    """
    # Function-scope import so this block does not rely on a file-level
    # ``import sys``.
    import sys

    # Run the download with the interpreter that is executing this app
    # instead of whatever ``python`` resolves to on PATH; otherwise the
    # package can land in a different environment than the one that will
    # import it.
    # The exit status is intentionally not raised: if the package is already
    # installed, a failed download (e.g. no network) should not stop the app,
    # and a genuinely missing package surfaces when the caller loads it.
    subprocess.call(
        [sys.executable, '-m', 'spacy', 'download', 'en_core_web_sm']
    )
@st.cache(allow_output_mutation=True)
def load_prep_data():
    """Read the pickled article records and sentence-split their text field.

    Unpickles ``listfile_3.data`` and, for every record whose element at
    index 1 is not an empty list, replaces that element in place with the
    result of ``sent_tokenize`` applied to it.

    Returns:
        The unpickled collection, with the index-1 fields tokenized in place.
    """
    with open('listfile_3.data', 'rb') as source:
        records = pickle.load(source)

    for position in range(len(records)):
        record = records[position]
        text = record[1]
        # Records whose text slot is an empty list are left untouched.
        if text != []:
            record[1] = sent_tokenize(text)

    return records
@st.cache(allow_output_mutation=True)
def build_sent_trans_model():
    """Assemble a SentenceTransformer from the BERT weights in ``./``.

    The model is a two-module stack: a ``models.BERT`` word-embedding module
    created from the current directory, followed by a ``models.Pooling``
    module configured for mean-token pooling only.

    Returns:
        The constructed ``SentenceTransformer`` instance.
    """
    encoder = models.BERT('./')
    embedding_dim = encoder.get_word_embedding_dimension()

    # Mean pooling strategy: CLS-token and max-token pooling are switched off.
    mean_pooling = models.Pooling(
        embedding_dim,
        pooling_mode_mean_tokens=True,
        pooling_mode_cls_token=False,
        pooling_mode_max_tokens=False,
    )

    return SentenceTransformer(modules=[encoder, mean_pooling])
@st.cache(allow_output_mutation=True)
def load_embedded_articles():
    """Return the object stored in the pickle file ``list_of_articles.pkl``.

    Returns:
        Whatever object was pickled into ``list_of_articles.pkl``.
    """
    with open('list_of_articles.pkl', 'rb') as pickled:
        return pickle.load(pickled)
@st.cache(allow_output_mutation=True)
def load_comprehension_model():
    """Build the question-answering pipeline for ``graviraja/covidbert_squad``.

    Both the model and the tokenizer are loaded from the same checkpoint and
    wrapped in a transformers ``"question-answering"`` pipeline.

    Returns:
        The constructed question-answering pipeline.
    """
    checkpoint = "graviraja/covidbert_squad"

    # For a transformers pipeline, ``device=-1`` means CPU and a non-negative
    # value is a CUDA device ordinal.  The previous hard-coded -1 therefore
    # never used the GPU, contrary to the stated intent; pick the first CUDA
    # device when one is available and fall back to the CPU otherwise.
    device_index = 0 if torch.cuda.is_available() else -1

    return pipeline(
        "question-answering",
        model=AutoModelForQuestionAnswering.from_pretrained(checkpoint),
        tokenizer=AutoTokenizer.from_pretrained(checkpoint),
        device=device_index,
    )
def main():
    """Render the Co-Search page: take a query and show the best answers.

    Steps:
      1. Fetch the NLTK ``punkt`` data, make sure the spaCy model package is
         present and load it, and pick CUDA when available.
      2. Load the article records, sentence-embedding model, embedded
         articles and question-answering pipeline.
      3. Read the query from a text box and run it through ``fetch_stage1``,
         ``fetch_stage2`` and ``fetch_stage3``.
      4. Write the elapsed time and up to four answers (text, score and the
         title of the source article), or an info message when the last
         stage returns nothing.
    """
    # 1-based rank at which the listing stops: the entry at this rank is
    # still displayed, then the loop ends.
    max_displayed = 4

    nltk.download('punkt')
    load_spacy_model()
    spacy_nlp = spacy.load('en_core_web_sm')
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    # These are the article records (title at index 0, sentences at index 1),
    # not embedding vectors.
    articles = load_prep_data()
    model = build_sent_trans_model()
    model.to(device)
    list_of_articles = load_embedded_articles()
    comprehension_model = load_comprehension_model()

    st.title('Co-Search')
    query = st.text_input("Enter Query", 'What are the corona viruses?', key="query")
    st.write('Using device type: {}'.format(device))

    with st.spinner('Please wait...'):
        dt1 = datetime.datetime.now()
        query_embedding, results1 = fetch_stage1(query, model, list_of_articles)
        results2 = fetch_stage2(results1, model, articles, query_embedding)
        results3 = fetch_stage3(results2, query, articles, comprehension_model, spacy_nlp)
        dt2 = datetime.datetime.now()

    # ``timedelta.seconds`` is only the whole-seconds component (it drops
    # microseconds and days), so short runs were reported as 0.00 minutes.
    # ``total_seconds()`` gives the full duration.
    elapsed_minutes = (dt2 - dt1).total_seconds() / 60
    st.write('Time taken in minutes: %.2f' % elapsed_minutes)

    if results3:
        for rank, res in enumerate(results3, start=1):
            st.write('{}> {}'.format(rank, res[2]))
            st.write('Score: %.4f' % (res[1]))
            st.write("From the article with title:\n{}".format(articles[res[0]][0]))
            st.write("\n")
            if rank >= max_displayed:
                break
    else:
        st.info("There isn't any answer")

    st.success('Done!')
# Script entry point: render the page only when this file is executed
# directly, not when it is imported.
if __name__ == "__main__":
    main()