import spacy
import streamlit as st
import pandas as pd
from PyPDF2 import PdfReader
from io import StringIO
import json
import warnings
import os
import ast

#@st.cache_resource
def load_models(model_names: list, args: dict, model_names_dir: list) -> dict:
    """
    Check whether each model name refers to a fine-tuned model located in the
    model_dir or a default model shipped with spaCy, and load it accordingly.
    Parameters:
        model_names: list of model names to load for inference
        args: dict of configuration parameters (must contain 'model_dir')
        model_names_dir: list of the model names that live in model_dir, i.e. the fine-tuned models
    Returns:
        model_dict: a dictionary mapping each model name to its loaded model
    """
    assert model_names is not None and len(model_names) != 0, "No models available"
    model_dict = {}
    for model_name in model_names:
        print(model_name)
        if model_name in model_names_dir:
            # fine-tuned model: load it from the model directory
            try:
                model_path = os.path.join(args['model_dir'], model_name)
                model = spacy.load(model_path)
            except OSError:
                warnings.warn(f"Path to {model_name} not found")
                continue
        else:
            # default model shipped with spaCy
            try:
                model = spacy.load(model_name)
            except OSError:
                warnings.warn(f'Model: {model_name} not found')
                continue
        model_dict[model_name] = model
        print('Model loaded')
    return model_dict
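
# Usage sketch for load_models (a minimal, illustrative example: the
# 'en_core_web_sm' pipeline must be installed, and 'my_finetuned_ner' /
# 'models' are hypothetical names, not part of this app's actual config):
#
#   args = {'model_dir': 'models'}
#   models = load_models(['en_core_web_sm', 'my_finetuned_ner'], args,
#                        model_names_dir=['my_finetuned_ner'])
#   nlp = models['en_core_web_sm']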

def process_text(doc: spacy.tokens.Doc, selected_entities: list, colors: dict) -> list:
    """
    Process the tokens of a spaCy Doc so that consecutive tokens belonging to
    the same entity are grouped together. The resulting list can be passed to
    st-annotated-text for visualization.
    Example: "Hi John, i am sick with cough and flu"
    Entities: person, disease
    Output: [(Hi)(John, 'person', blue)(i am sick)(cough, 'disease', red)(and)(flu, 'disease', red)]
    Parameters:
        doc: spaCy Doc produced by one of the loaded models
        selected_entities: list of entity labels to highlight
        colors: dict mapping entity labels to highlight colors
    Returns:
        tokens: list of plain strings and (span, entity, color, text_color) tuples
    """
    tokens = []
    span = ''
    p_ent = None
    last = len(doc)
    for no, token in enumerate(doc):
        add_span = False
        for ent in selected_entities:
            if token.ent_type_ == ent:
                # accumulate consecutive tokens of the same entity into one span
                span += token.text + " "
                p_ent = ent
                add_span = True
                if no + 1 == last:
                    # the document ends on an entity token, so flush the span
                    tokens.append((span, ent, colors[ent], '#464646'))
        if not add_span and span:
            # current token is plain text, so flush the accumulated entity span
            tokens.append((span, p_ent, colors[p_ent], '#464646'))
            span = ''
            p_ent = None
        if not add_span:
            tokens.append(" " + token.text + " ")
    return tokens
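
# Usage sketch for process_text (hedged: 'PERSON'/'DISEASE' and the color
# values are illustrative; annotated_text is the renderer from the
# st-annotated-text package this output format is built for):
#
#   doc = nlp("Hi John, I am sick with cough and flu")
#   colors = {'PERSON': '#8ef', 'DISEASE': '#faa'}
#   tokens = process_text(doc, ['PERSON', 'DISEASE'], colors)
#   annotated_text(*tokens)  # renders the highlighted spans in Streamlit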

def process_text_compare(infer_input: dict, selected_entities: list, colors: dict) -> list:
    """
    Used when the user wants to compare text annotations between predictions
    and labels. Processes an evaluation record so that character spans are
    grouped by their corresponding entities, in the format expected by
    st-annotated-text for visualization.
    Example: "Hi John, i am sick with cough and flu"
    Entities: person, disease
    Output: [(Hi)(John, 'person', blue)(i am sick)(cough, 'disease', red)(and)(flu, 'disease', red)]
    Parameters:
        infer_input: dict with a 'text' string and an 'entities' list of (start, end, label) offsets
        selected_entities: list of entity labels to highlight
        colors: dict mapping entity labels to highlight colors
    Returns:
        tokens: list of plain strings and (span, entity, color, text_color) tuples
    """
    tokens = []
    start_ = 0
    end_ = len(infer_input['text'])
    for start, end, entities in infer_input['entities']:
        if entities in selected_entities:
            # span of characters covered by the entity (offsets are assumed
            # end-exclusive, as in spaCy's character spans)
            span = infer_input['text'][start:end]
            # flush any plain text that precedes the entity
            if start_ != start:
                b4_span = infer_input['text'][start_:start]
                tokens.append(" " + b4_span + " ")
            tokens.append((span, entities, colors[entities], '#464646'))
            start_ = end
    # flush any plain text remaining after the last entity
    if start_ < end_:
        span = infer_input['text'][start_:end_]
        tokens.append(" " + span + " ")
    return tokens
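
# Usage sketch for process_text_compare (hedged: assumes evaluation records
# shaped like {'text': ..., 'entities': [(start, end, label), ...]} with
# end-exclusive character offsets):
#
#   record = {'text': 'Hi John, I am sick with cough and flu',
#             'entities': [(3, 7, 'PERSON'), (24, 29, 'DISEASE'),
#                          (34, 37, 'DISEASE')]}
#   tokens = process_text_compare(record, ['PERSON', 'DISEASE'], colors)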

def process_files(uploaded_file, text_input):
    """
    The app allows uploading files of multiple types, at present JSON, CSV,
    PDF and TXT. This function detects which kind of file has been uploaded
    and processes it accordingly.
    If a file has been uploaded, its contents replace the existing text_input.
    Parameters:
        uploaded_file: Streamlit UploadedFile, a subclass of BytesIO and therefore "file-like"
        text_input: str / dict / list
    Returns:
        text_input: list / dict / str
    """
    if uploaded_file is not None:
        if uploaded_file.name.endswith('.csv'):
            # literal_eval turns the stringified list in the 'entities'
            # column back into an actual list object
            text_input = pd.read_csv(uploaded_file, converters={'entities': ast.literal_eval})
            text_input = text_input.to_dict('records')
        elif uploaded_file.name.endswith('.json'):
            text_input = json.load(uploaded_file)
        else:
            try:
                # plain text file: decode the raw bytes as UTF-8
                stringio = StringIO(uploaded_file.getvalue().decode("utf-8"))
                text_input = stringio.read()
            except UnicodeDecodeError:
                # not valid UTF-8, so treat the upload as a PDF and
                # extract the text of every page
                text_input = []
                reader = PdfReader(uploaded_file)
                for page in reader.pages:
                    text_input.append(page.extract_text())
                text_input = ''.join(text_input)
    return text_input
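
# Usage sketch for process_files inside the Streamlit app (hedged: the widget
# labels are illustrative, not the app's actual UI strings):
#
#   uploaded_file = st.file_uploader("Upload a csv / json / pdf / txt file")
#   text_input = st.text_area("Or paste text here")
#   text_input = process_files(uploaded_file, text_input)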