lterriel's picture
clean & refactor components + add doc
74e2066
# -*- coding:utf-8 -*-
from io import BytesIO
import re
from zipfile import ZipFile
import os
from pathlib import Path
import streamlit as st
from cassis import load_typesystem, load_cas_from_xmi
def st_pb(method):
    """Streamlit decorator that shows a progress bar while *method* runs.

    *method* must be a generator function. Each value it yields is indexed
    as ``(progress, name, ok)``:

    - ``progress`` is passed to ``st.progress``;
    - a status line mentioning ``name`` is written, with a success mark
      when ``ok`` is truthy and a failure mark otherwise.

    The generator's ``return`` value becomes the return value of the
    decorated call.
    """
    import functools

    @functools.wraps(method)
    def progress_bar(ref, *args, **kwargs):
        container = st.empty()
        bar = st.progress(0)
        steps = method(ref, *args, **kwargs)
        while True:
            # Keep the try minimal: only ``next`` signals completion, and the
            # generator's return value travels on the StopIteration.
            try:
                step = next(steps)
            except StopIteration as stop:
                return stop.value
            bar.progress(step[0])
            if step[2]:
                container.write("✅ Processing... " + step[1])
            else:
                container.write("❌️ Error with..." + step[1])

    return progress_bar
class Project:
    """Annotation project loaded from a zip export, with extracted named entities.

    The constructor collects the XMI documents and the type system from the
    archive, then calls ``extract_ne`` to fill ``annotations``.

    Args:
        zip_project: a ``ZipFile`` (used as-is when ``remote`` is true and
            ``type`` is ``"global"``), otherwise anything ``ZipFile()``
            accepts (a path or a file-like object).
        type: ``"global"`` (documents read from nested zips under a
            ``curation/`` folder) or ``"iaa"`` (top-level ``.xmi`` files,
            one per annotator name).
        remote: only consulted to decide whether ``zip_project`` is an
            already-opened ``ZipFile`` to read directly.

    Raises:
        ValueError: if the archive must be opened and ``type`` is neither
            ``"global"`` nor ``"iaa"``.
    """

    # UIMA type whose instances are collected by ``extract_ne``.
    NAMED_ENTITY_TYPE = 'de.tudarmstadt.ukp.dkpro.core.api.ner.type.NamedEntity'

    def __init__(self, zip_project, type, remote):
        # zip container that holds the XMI files and the type system
        self.zip_project = zip_project
        self.remote = remote
        # 'iaa' or 'global'
        self.type = type
        # source filename of each document (parallel to xmi_documents)
        self.documents = []
        # decoded XMI text of each document
        self.xmi_documents = []
        # cassis type system, loaded from the first TypeSystem.xml needed
        self.typesystem = None
        # annotator names ('iaa' projects only)
        self.annotators = []
        # extracted annotations, keyed by source filename:
        # {"Filename.xmi": {"mentions": [...], "labels": [...]}, ...}
        self.annotations = {}
        if isinstance(self.zip_project, ZipFile) and self.remote and self.type == "global":
            self._load_remote_global()
        else:
            self._load_local()
        self.extract_ne()

    def _load_remote_global(self):
        """Read every ``.xmi`` member of the already-opened remote archive."""
        for fp in self.zip_project.namelist():
            if self.typesystem is None:
                self.typesystem = load_typesystem(BytesIO(self.zip_project.read('TypeSystem.xml')))
            if fp.endswith('.xmi'):
                self.documents.append(fp)
                self.xmi_documents.append(self.zip_project.read(fp).decode("utf-8"))

    def _load_local(self):
        """Open ``zip_project`` and dispatch matching members by project type."""
        if self.type == "global":
            # curated per-document zips; skip '._' macOS resource-fork entries
            pattern = re.compile(r'.*curation/.*/(?!\._).*zip$')
        elif self.type == "iaa":
            pattern = re.compile(r'.*xm[il]$')
        else:
            raise ValueError(
                "unknown project type {!r}; expected 'global' or 'iaa'".format(self.type)
            )
        with ZipFile(self.zip_project) as project_zip:
            for fp in project_zip.namelist():
                if not pattern.match(fp):
                    continue
                if self.type == "global":
                    self._add_curated_documents(project_zip, fp)
                else:
                    self._add_iaa_member(project_zip, fp)

    def _add_curated_documents(self, project_zip, fp):
        """Read the XMI files of the nested curation zip *fp*."""
        with ZipFile(BytesIO(project_zip.read(fp))) as annotation_zip:
            if self.typesystem is None:
                self.typesystem = load_typesystem(BytesIO(annotation_zip.read('TypeSystem.xml')))
            for name in annotation_zip.namelist():
                if name.endswith('.xmi'):
                    # the document is named after the folder holding its zip
                    self.documents.append(Path(fp).parent.name)
                    self.xmi_documents.append(annotation_zip.read(name).decode("utf-8"))

    def _add_iaa_member(self, project_zip, fp):
        """Load the type system or record one annotator's XMI from member *fp*."""
        if self.typesystem is None and fp.endswith('.xml'):
            # NOTE(review): always reads the root 'TypeSystem.xml', whichever
            # '.xml' member triggered the load — confirm that is intended.
            self.typesystem = load_typesystem(BytesIO(project_zip.read('TypeSystem.xml')))
        elif fp.endswith('.xmi'):
            self.documents.append(fp)
            # annotator name is the member path without its extension
            self.annotators.append(os.path.splitext(fp)[0])
            self.xmi_documents.append(project_zip.read(fp).decode("utf-8"))

    @st_pb
    def extract_ne(self):
        """Collect named-entity mentions and labels from every loaded document.

        Generator consumed by ``st_pb``: after each document it yields
        ``(fraction_done, source_filename, succeeded)``. A document whose XMI
        fails to load or parse is reported as failed rather than aborting the
        whole extraction; its ``annotations`` entry may then be missing or
        partial.
        """
        total = len(self.documents)
        pairs = zip(self.xmi_documents, self.documents)
        for count, (xmi, src) in enumerate(pairs, start=1):
            doc_flag = True
            try:
                cas = load_cas_from_xmi(xmi, typesystem=self.typesystem)
                entry = self.annotations[src] = {
                    "mentions": [],
                    "labels": []
                }
                for ne in cas.select(self.NAMED_ENTITY_TYPE):
                    entry["mentions"].append(ne.get_covered_text())
                    entry["labels"].append(ne.value)
            except Exception:
                # best-effort: flag the document, keep processing the others
                doc_flag = False
            yield count / total, src, doc_flag