File size: 4,292 Bytes
4c042d9
7cdaf14
4c042d9
 
 
26e1bea
 
0126f24
6941336
4c042d9
3dfae67
4c042d9
ee202f8
b16d274
ee202f8
 
4c042d9
7136222
 
4c042d9
 
26e1bea
 
 
 
3dfae67
26e1bea
 
 
 
 
 
3dfae67
7cdaf14
3dfae67
 
4c6bd94
3dfae67
 
9496b6e
 
 
3dfae67
 
 
0126f24
6941336
7136222
 
 
4c042d9
 
 
 
 
 
 
26e1bea
4c042d9
 
 
 
 
 
4c6bd94
4c042d9
 
7136222
 
9496b6e
7136222
4c042d9
 
6941336
 
4c042d9
 
 
 
 
0126f24
9496b6e
7136222
 
 
54cc41f
4c042d9
 
7136222
 
 
 
 
 
 
 
 
 
4c6bd94
7136222
4c6bd94
 
 
 
 
0126f24
 
4c6bd94
 
 
4c042d9
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
import streamlit as st
from typing import List
import textract
import tempfile
import spacy
import subprocess
import scispacy
from spacy.tokens import DocBin, Doc, Span
from collections import Counter
import srsly
from spacy.matcher import PhraseMatcher

# Inject the app's stylesheet into the page.
# BUG FIX: read with an explicit encoding so the CSS is decoded the same
# way on every platform, instead of the locale's default codec.
with open("style.css", encoding="utf-8") as f:
    st.markdown("<style>" + f.read() + "</style>", unsafe_allow_html=True)

st.title('Index and Search a Collection of Documents')
# Initialise the query once per session so sidebar callbacks can update it.
if 'query' not in st.session_state:
    st.session_state['query'] = ''

@st.cache
def download_model(language: str, select_model: str) -> bool:
    """Install the selected spaCy/scispaCy model; return True on success.

    'Science' models are pip-installed from the URLs in scispacy.json;
    every other language goes through spaCy's own model downloader.
    Cached by Streamlit so a given model is fetched at most once.
    """
    if language == 'Science':
        urls = srsly.read_json('scispacy.json')
        # BUG FIX: the original returned True unconditionally, even when
        # pip failed; report the actual outcome of the install instead.
        result = subprocess.run(['pip', 'install', f'{urls[select_model]}'])
        return result.returncode == 0
    else:
        try:
            spacy.cli.download(select_model)
            return True
        except Exception:
            # Best-effort: a failed download is reported as False so the
            # UI simply does not proceed, rather than crashing the app.
            return False

def search_docs(query:str, documents:List[Doc], nlp) -> List[Span]:
    """Return every span matching the query across all documents.

    The query may hold several terms separated by '|'; each term is
    matched as an exact phrase using spaCy's PhraseMatcher.
    """
    matcher = PhraseMatcher(nlp.vocab)
    phrase_patterns = [nlp.make_doc(term) for term in query.split('|')]
    matcher.add(query, phrase_patterns)

    # matcher(doc) yields (match_id, start, end) token offsets; slice the
    # doc at those offsets to produce the result spans.
    return [
        doc[start:end]
        for doc in documents
        for _match_id, start, end in matcher(doc)
    ]

def update_query(arg:str):
    """Button callback: make the clicked entity text the active query."""
    st.session_state['query'] = arg

models = srsly.read_json('models.json')
models[''] = [] #require the user to choose a language
languages = models.keys()
# Default to the empty entry so the user must actively pick a language.
language = st.selectbox("Language", languages, index=len(models.keys())-1, help="Select the language of your materials.")
if language:
    select_model = st.selectbox("Model", models[language], help="spaCy model")
    if select_model:
        model_downloaded = download_model(language, select_model)

        if model_downloaded:

            nlp = spacy.load(select_model)

            # Raise spaCy's default 1,000,000-char limit so somewhat larger
            # documents can still be processed in a single pass.
            nlp.max_length = 1200000

            uploaded_files = st.file_uploader("Select files to process", accept_multiple_files=True)
            st.session_state.query = st.sidebar.text_input(label="Enter your query (use | to separate search terms)", value="...")

            documents = []
            all_ents = []
            for uploaded_file in uploaded_files:
                # textract dispatches on the file extension, so keep the
                # upload's suffix on the temp file it reads from.
                file_suffix = '.' + uploaded_file.name.split('.')[-1]
                # BUG FIX: use the context manager so the temp file is closed
                # and removed deterministically instead of leaking until GC.
                with tempfile.NamedTemporaryFile(suffix=file_suffix) as temp:
                    temp.write(uploaded_file.getvalue())
                    # BUG FIX: flush so the bytes are on disk before textract
                    # re-opens the file by name.
                    temp.flush()
                    try:
                        text = textract.process(temp.name)
                        text = text.decode('utf-8')
                        doc = nlp(text)
                        # Remember the source file so results can cite it.
                        doc.user_data['filename'] = uploaded_file.name
                        documents.append(doc)
                        all_ents.extend(doc.ents)
                    except Exception as e:
                        # Surface parser/NLP failures per-file without
                        # aborting the rest of the batch.
                        st.error(e)

            # Sidebar: entity-label frequencies; clicking a label lists its
            # entity texts, and clicking a text sets it as the query.
            ents_container = st.container()
            label_freq = Counter([ent.label_ for ent in all_ents])
            for key, value in label_freq.items():
                if st.sidebar.button(key, key=key):
                    st.sidebar.write(value)
                    text_freq = Counter([ent.text for ent in all_ents if ent.label_ == key])
                    for text in text_freq.keys():
                        st.sidebar.button(f'{text} ({text_freq[text]})', on_click=update_query, args=(text, ))

            # Main panel: each match rendered with its sentence context.
            results_container = st.container()
            results = search_docs(st.session_state.query, documents, nlp)
            for result in results:
                doc = result.doc
                sent_before = doc[result.sent.start:result.start]
                sent_after = doc[result.end:result.sent.end]
                results_container.markdown(f"""
                <div style="border: 2px solid #202d89;border-radius: 15px;"><p>{result.doc.user_data['filename']}</p>
                <div class='text'>{sent_before.text} <span class="text_mark"> {result.text}</span>{sent_after.text}</div>
                </div>
                """, unsafe_allow_html=True)

            #st.download_button('Download', '', 'text/plain')