# bolete / app.py
# apjanco
# type error
# 7cdaf14
import html
import subprocess
import sys
import tempfile
from collections import Counter
from typing import List

import scispacy
import spacy
import srsly
import streamlit as st
import textract
from spacy.matcher import PhraseMatcher
from spacy.tokens import DocBin, Doc, Span
# Inject the app's custom stylesheet into the page.
with open("style.css") as css_file:
    css = css_file.read()
st.markdown(f"<style>{css}</style>", unsafe_allow_html=True)

st.title('Index and Search a Collection of Documents')

# Make sure a query entry exists in the session state before anything reads it.
if "query" not in st.session_state:
    st.session_state.query = ""
@st.cache
def download_model(language: str, select_model: str) -> bool:
    """Install the requested model so that ``spacy.load`` can find it.

    Args:
        language: Language group chosen in the UI. The value ``'Science'``
            selects a model that is pip-installed from the URL listed for
            it in ``scispacy.json``; any other value goes through
            ``spacy.cli.download``.
        select_model: Name of the model to install.

    Returns:
        True if the installation step reported success, False otherwise.

    Note:
        The function is wrapped in ``st.cache``, so the returned flag is
        memoised per ``(language, select_model)`` pair.
    """
    if language == 'Science':
        urls = srsly.read_json('scispacy.json')
        try:
            url = urls[select_model]
        except KeyError:
            # Unknown model name: report failure instead of crashing the app.
            return False
        # Run pip through the current interpreter so the package lands in the
        # same environment that will later call spacy.load().
        completed = subprocess.run(
            [sys.executable, '-m', 'pip', 'install', str(url)]
        )
        # A non-zero exit status means pip could not install the package.
        return completed.returncode == 0

    try:
        spacy.cli.download(select_model)
    except (Exception, SystemExit):
        # spaCy's CLI helpers signal failure via sys.exit(), which raises
        # SystemExit rather than an Exception subclass.
        return False
    return True
def search_docs(query: str, documents: List[Doc], nlp) -> List[Span]:
    """Find every occurrence of the query terms in the given documents.

    Args:
        query: One or more search terms separated by ``|``. Whitespace
            around each term is ignored and empty terms are dropped.
        documents: Processed spaCy documents to search.
        nlp: The loaded spaCy pipeline; its tokenizer turns each term into
            a pattern and its vocab backs the matcher.

    Returns:
        One span per match, in document order. An empty list when the query
        contains no usable term.
    """
    # Strip each term: a leading/trailing space would otherwise become part
    # of the pattern and stop it from matching the document text.
    terms = [term.strip() for term in query.split('|')]
    terms = [term for term in terms if term]
    if not terms:
        return []

    matcher = PhraseMatcher(nlp.vocab)
    matcher.add(query, [nlp.make_doc(term) for term in terms])

    # Each match is a (match_id, start, end) triple of token offsets.
    return [
        doc[start:end]
        for doc in documents
        for _match_id, start, end in matcher(doc)
    ]
def update_query(arg: str) -> None:
    """Store ``arg`` as the current search query in the session state."""
    st.session_state['query'] = arg
# Model registry: the keys are the language choices, the values are the
# model names offered for that language.
models = srsly.read_json('models.json')
models[''] = []  # require the user to choose a language
languages = list(models)
language = st.selectbox(
    "Language",
    languages,
    index=len(languages) - 1,  # start on the blank entry appended above
    help="Select the language of your materials.",
)

if language:
    select_model = st.selectbox("Model", models[language], help="spaCy model")
    if select_model:
        model_downloaded = download_model(language, select_model)

        if model_downloaded:
            nlp = spacy.load(select_model)
            nlp.max_length = 1200000

            uploaded_files = st.file_uploader(
                "Select files to process", accept_multiple_files=True
            )
            # Bind the text box to st.session_state['query'] through its key
            # instead of assigning its return value: the assignment used to
            # overwrite whatever update_query() had stored when an entity
            # button was clicked, so those clicks had no effect.
            st.sidebar.text_input(
                label="Enter your query (use | to separate search terms)",
                key='query',
            )

            documents = []
            all_ents = []
            for uploaded_file in uploaded_files:
                # textract is given a path, so keep the upload's extension
                # on the temporary file.
                file_suffix = '.' + uploaded_file.name.split('.')[-1]
                try:
                    with tempfile.NamedTemporaryFile(suffix=file_suffix) as temp:
                        temp.write(uploaded_file.getvalue())
                        # Flush so the bytes are on disk before textract
                        # opens the file by name; otherwise small uploads
                        # can still be sitting in the write buffer.
                        temp.flush()
                        raw_text = textract.process(temp.name).decode('utf-8')
                    doc = nlp(raw_text)
                    doc.user_data['filename'] = uploaded_file.name
                    documents.append(doc)
                    all_ents.extend(doc.ents)
                except Exception as e:
                    # Per-file boundary: report the problem and carry on
                    # with the remaining uploads.
                    st.error(e)

            ents_container = st.container()  # NOTE(review): not used further in this script
            label_freq = Counter(ent.label_ for ent in all_ents)
            for key, value in label_freq.items():
                # One sidebar button per entity label; clicking it lists the
                # entity texts carrying that label.
                if st.sidebar.button(key, key=key):
                    st.sidebar.write(value)
                    text_freq = Counter(
                        ent.text for ent in all_ents if ent.label_ == key
                    )
                    for ent_text, count in text_freq.items():
                        st.sidebar.button(
                            f'{ent_text} ({count})',
                            on_click=update_query,
                            args=(ent_text, ),
                        )

            results_container = st.container()
            query = st.session_state.query
            # A blank query cannot match anything, so skip the search.
            results = search_docs(query, documents, nlp) if query.strip() else []
            for result in results:
                doc = result.doc
                # NOTE(review): Span.sent needs sentence boundaries - confirm
                # every selectable model provides them.
                sentence = result.sent
                # Escape everything that comes from the uploaded files: the
                # markup below is rendered with unsafe_allow_html=True.
                filename = html.escape(str(doc.user_data['filename']))
                before = html.escape(doc[sentence.start:result.start].text)
                hit = html.escape(result.text)
                after = html.escape(doc[result.end:sentence.end].text)
                results_container.markdown(
                    "\n"
                    f'<div style="border: 2px solid #202d89;border-radius: 15px;"><p>{filename}</p>\n'
                    f"<div class='text'>{before} <span class=\"text_mark\"> {hit}</span>{after}</div>\n"
                    "</div>\n",
                    unsafe_allow_html=True,
                )
            # TODO: offer the results as a download, e.g.
            # st.download_button('Download', '', 'text/plain')
        else:
            # Surface the failure instead of leaving the page silently empty.
            st.error(f"Could not download or install the model '{select_model}'.")