|
import streamlit as st |
|
import textract |
|
import tempfile |
|
import spacy |
|
from spacy.tokens import DocBin, Doc, Span |
|
from collections import Counter |
|
import srsly |
|
from spacy.matcher import PhraseMatcher |
|
|
|
|
|
# Inject the project stylesheet into the Streamlit page.
# Explicit encoding avoids platform-dependent default-codec surprises.
with open("style.css", encoding="utf-8") as f:
    st.markdown(f"<style>{f.read()}</style>", unsafe_allow_html=True)

st.title('Index and Search a Collection of Documents')

# Initialise the search query once so widget callbacks can update it
# across Streamlit reruns.
if 'query' not in st.session_state:
    st.session_state['query'] = ''
|
|
|
@st.cache
def download_model(select_model: str) -> bool:
    """Download the named spaCy model, returning True on success.

    The result is cached by Streamlit so the (slow) download runs only
    once per model name per session.

    NOTE(review): ``st.cache`` is deprecated in recent Streamlit releases
    in favour of ``st.cache_resource`` — confirm the installed version
    before migrating.
    """
    try:
        spacy.cli.download(select_model)
        return True
    except Exception:
        # Best-effort: any failure (network error, unknown model name, ...)
        # is reported to the caller as False rather than crashing the app.
        return False
|
|
|
def search_docs(query: str, documents: list[Doc], nlp) -> list[Span]:
    """Find every occurrence of the query terms in the given documents.

    Args:
        query: Search string; multiple terms are separated by ``|``.
        documents: Processed spaCy Docs to search.
        nlp: The loaded spaCy pipeline (supplies tokenisation and vocab).

    Returns:
        A list of matching Spans, in match order per document.
    """
    # Drop blank terms so an empty or malformed query (e.g. "a||b" or "")
    # cannot produce zero-token patterns, which PhraseMatcher rejects.
    terms = [term for term in query.split('|') if term.strip()]
    if not terms:
        return []

    matcher = PhraseMatcher(nlp.vocab)
    # The raw query string doubles as the match-rule key.
    matcher.add(query, [nlp.make_doc(term) for term in terms])

    return [
        doc[start:end]
        for doc in documents
        for _, start, end in matcher(doc)
    ]
|
|
|
def update_query(arg: str):
    """Button callback: store the clicked entity text as the active query."""
    st.session_state['query'] = arg
|
|
|
# Mapping of language name -> list of spaCy model names, loaded from disk.
models = srsly.read_json('models.json')
# Add an empty entry so the selectbox can default to "no language chosen".
models[''] = []
languages = models.keys()
# Default index points at the '' entry appended above (last in insertion order).
language = st.selectbox("Language", languages, index=len(models.keys())-1, help="Select the language of your materials.")
if language:
    select_model = st.selectbox("Model", models[language], help="spaCy model")
    if select_model:
        model_downloaded = download_model(select_model)

        if model_downloaded:

            nlp = spacy.load(select_model)

            # Raise spaCy's default max_length (1,000,000 chars) so longer
            # documents can be processed without an error.
            nlp.max_length = 1200000

# NOTE(review): if no model is selected/downloaded, `nlp` is never bound and
# the processing loop below raises NameError once files are uploaded —
# confirm whether an st.stop() guard is intended here.
uploaded_files = st.file_uploader("Select files to process", accept_multiple_files=True)
# NOTE(review): writing the text_input value back into session_state on every
# rerun overwrites the query set by the sidebar button callbacks — verify this
# interplay is intentional.
st.session_state.query = st.sidebar.text_input(label="Enter your query (use | to separate search terms)", value="...")
|
|
|
# Process every uploaded file: extract its raw text, run the NLP pipeline,
# and collect both the resulting Doc and all of its named entities.
documents = []
all_ents = []
for uploaded_file in uploaded_files:
    # textract dispatches on the file extension, so preserve the original
    # suffix when spooling the upload to a temporary file.
    file_suffix = '.' + uploaded_file.name.split('.')[-1]
    try:
        # Context manager guarantees the temp file is closed and removed;
        # the original left it dangling.
        with tempfile.NamedTemporaryFile(suffix=file_suffix) as temp:
            temp.write(uploaded_file.getvalue())
            # Flush before textract reads the file by name, otherwise the
            # tail of the data may still sit in the write buffer.
            temp.flush()
            text = textract.process(temp.name).decode('utf-8')
        doc = nlp(text)
        # Remember the source filename so search results can cite it.
        doc.user_data['filename'] = uploaded_file.name
        documents.append(doc)
        all_ents.extend(doc.ents)
    except Exception as e:
        # Surface extraction/parsing failures in the UI but keep going
        # with the remaining files.
        st.error(e)
|
|
|
# Sidebar facets: one button per entity label; clicking a label shows its
# frequency plus one button per distinct entity text under that label, each
# of which fills the search query via the update_query callback.
# NOTE(review): ents_container is created but never written to — confirm
# whether it is still needed.
ents_container = st.container()
label_freq = Counter([ent.label_ for ent in all_ents])
for key, value in label_freq.items():
    if st.sidebar.button(key, key=key):
        st.sidebar.write(value)
        # Frequency of each distinct entity surface form under this label.
        text_freq = Counter([ent.text for ent in all_ents if ent.label_ == key])
        for text in text_freq.keys():
            # NOTE(review): these buttons have no explicit `key`; identical
            # labels would collide at runtime — confirm uniqueness holds.
            st.sidebar.button(f'{text} ({text_freq[text]})', on_click=update_query, args=(text, ))
|
|
|
# Render each match inside its sentence context, highlighting the hit.
results_container = st.container()
results = search_docs(st.session_state.query, documents,nlp)
for result in results:
    doc = result.doc
    # Slice the containing sentence into the text before and after the
    # matched span so the hit can be wrapped in a highlight <span>.
    sent_before = doc[result.sent.start:result.start]
    sent_after = doc[result.end:result.sent.end]
    results_container.markdown(f"""
    <div style="border: 2px solid #202d89;border-radius: 15px;"><p>{result.doc.user_data['filename']}</p>
    <div class='text'>{sent_before.text} <span class="text_mark"> {result.text}</span>{sent_after.text}</div>
    </div>
    """, unsafe_allow_html=True)
|
|
|
|
|
|
|
|