# bolete/app.py — Streamlit Space by apjanco (commit 7136222)
# commit message: "working on ent buttons to update query"
import streamlit as st
import textract
import tempfile
import spacy
from spacy.tokens import DocBin, Doc, Span
from collections import Counter
import srsly
from spacy.matcher import PhraseMatcher
# Inject the project stylesheet so the custom classes used below
# (.text, .text_mark) render inside Streamlit's markdown output.
with open("style.css") as css_file:
    st.markdown(f"<style>{css_file.read()}</style>", unsafe_allow_html=True)

st.title('Index and Search a Collection of Documents')

# Seed the search query exactly once per session; the sidebar widgets
# and button callbacks mutate it on later reruns.
if 'query' not in st.session_state:
    st.session_state['query'] = ''
@st.cache
def download_model(select_model: str) -> bool:
    """Download a spaCy model package; return True on success, False on failure.

    Cached by Streamlit so repeated script reruns do not re-download the
    same model. Failures (unknown model name, network errors) are reported
    as False rather than raised, so the caller can simply withhold the
    next UI step instead of crashing the app.
    """
    try:
        spacy.cli.download(select_model)
        return True
    except Exception:  # best-effort: any failure just means "not available"
        return False
def search_docs(query: str, documents: list[Doc], nlp) -> list[Span]:
    """Find every occurrence of the query terms across the given docs.

    The query is a '|'-separated list of literal phrases; each phrase is
    tokenized with *nlp* and matched exactly via a PhraseMatcher. Returns
    the matching spans from all documents, in document order.
    """
    matcher = PhraseMatcher(nlp.vocab)
    patterns = [nlp.make_doc(term) for term in query.split('|')]
    matcher.add(query, patterns)

    hits: list[Span] = []
    for doc in documents:
        # matcher(doc) yields (match_id, start, end) token triples
        for _match_id, start, end in matcher(doc):
            hits.append(doc[start:end])
    return hits
def update_query(arg: str):
    """Button callback: replace the session's current search query with *arg*."""
    st.session_state["query"] = arg
# --- Language / model selection -------------------------------------------
# models.json maps language name -> list of spaCy model names.
models = srsly.read_json('models.json')
models[''] = []  # require the user to choose a language
languages = models.keys()
# Default to the '' entry (last inserted key) so nothing is preselected.
language = st.selectbox("Language", languages, index=len(models.keys()) - 1,
                        help="Select the language of your materials.")

if language:
    select_model = st.selectbox("Model", models[language], help="spaCy model")
    if select_model:
        model_downloaded = download_model(select_model)
        if model_downloaded:
            nlp = spacy.load(select_model)
            nlp.max_length = 1200000  # raise limit for long documents

            # --- Upload and index documents -------------------------------
            uploaded_files = st.file_uploader("Select files to process",
                                              accept_multiple_files=True)
            # NOTE(review): this reassignment runs on every rerun and can
            # clobber the value set by the update_query button callback —
            # confirm intended interaction (likely needs key='query').
            st.session_state.query = st.sidebar.text_input(
                label="Enter your query (use | to separate search terms)",
                value="...")

            documents = []
            all_ents = []
            for uploaded_file in uploaded_files:
                file_suffix = '.' + uploaded_file.name.split('.')[-1]
                # textract only reads from a path, so spill the upload to a
                # named temp file. flush() before handing the name over —
                # buffered writes may not be on disk yet — and the `with`
                # block closes (and deletes) the file when done.
                with tempfile.NamedTemporaryFile(suffix=file_suffix) as temp:
                    temp.write(uploaded_file.getvalue())
                    temp.flush()
                    try:
                        text = textract.process(temp.name).decode('utf-8')
                        doc = nlp(text)
                        doc.user_data['filename'] = uploaded_file.name
                        documents.append(doc)
                        all_ents.extend(doc.ents)
                    except Exception as e:
                        st.error(e)

            # --- Entity sidebar: one button per label, expanding to the ---
            # --- distinct entity texts, which update the query on click ---
            label_freq = Counter(ent.label_ for ent in all_ents)
            for key, value in label_freq.items():
                if st.sidebar.button(key, key=key):
                    st.sidebar.write(value)
                    text_freq = Counter(ent.text for ent in all_ents
                                        if ent.label_ == key)
                    for text, count in text_freq.items():
                        st.sidebar.button(f'{text} ({count})',
                                          on_click=update_query, args=(text,))

            # --- Render matches with sentence context ---------------------
            results_container = st.container()
            results = search_docs(st.session_state.query, documents, nlp)
            for result in results:
                doc = result.doc
                # Sentence context requires the pipeline to set sentence
                # boundaries — TODO confirm all listed models do.
                sent_before = doc[result.sent.start:result.start]
                sent_after = doc[result.end:result.sent.end]
                results_container.markdown(f"""
<div style="border: 2px solid #202d89;border-radius: 15px;"><p>{result.doc.user_data['filename']}</p>
<div class='text'>{sent_before.text} <span class="text_mark"> {result.text}</span>{sent_after.text}</div>
</div>
""", unsafe_allow_html=True)
            #st.download_button('Download', '', 'text/plain')