"""Sentence-aware French-to-English translation of clinical reports with MarianMT."""

import json
from dataclasses import dataclass
from typing import Dict, List, Sequence

import stanza
import transformers

from .anonymize import add_space_to_comma_endpoint, change_name_patient_abbreviations
from .web_utilities import st_cache_data_if, st_cache_resource_if, supported_cache


@dataclass(frozen=True)
class SentenceBoundary:
    """A single sentence plus the whitespace that preceded it in the source text."""

    text: str
    prefix: str

    def __str__(self):
        return self.prefix + self.text


class SentenceBoundaries:
    """An ordered list of SentenceBoundary items; str() reconstructs the full text."""

    def __init__(self) -> None:
        self._sentence_boundaries: List[SentenceBoundary] = []

    @property
    def sentence_boundaries(self) -> List[SentenceBoundary]:
        return self._sentence_boundaries

    def update_sentence_boundaries(
        self, sentence_boundaries_list: List[SentenceBoundary]
    ) -> "SentenceBoundaries":
        self._sentence_boundaries = sentence_boundaries_list
        return self

    def from_doc(self, doc: stanza.Document) -> "SentenceBoundaries":
        """Split a stanza Document into sentences, preserving inter-sentence whitespace."""
        start_idx = 0
        for sent in doc.sentences:
            self.sentence_boundaries.append(
                SentenceBoundary(
                    text=sent.text,
                    prefix=doc.text[start_idx : sent.tokens[0].start_char],
                )
            )
            start_idx = sent.tokens[-1].end_char
        # Keep any trailing whitespace after the last sentence.
        self.sentence_boundaries.append(
            SentenceBoundary(text="", prefix=doc.text[start_idx:])
        )
        return self

    @property
    def nonempty_sentences(self) -> List[str]:
        return [item.text for item in self.sentence_boundaries if item.text]

    def map_sentence_boundaries(self, d: Dict[str, str]) -> "SentenceBoundaries":
        """Replace each sentence with d[sentence] when present, keeping prefixes."""
        return SentenceBoundaries().update_sentence_boundaries(
            [
                SentenceBoundary(text=d.get(sb.text, sb.text), prefix=sb.prefix)
                for sb in self.sentence_boundaries
            ]
        )

    def __str__(self) -> str:
        return "".join(map(str, self.sentence_boundaries))


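# Usage sketch (illustrative, not from the original source); assumes the stanza
# French tokenizer has been downloaded:
#
#   nlp = stanza.Pipeline("fr", processors="tokenize", verbose=False)
#   bounds = SentenceBoundaries().from_doc(nlp.process("Bonjour.  Merci."))
#   bounds.nonempty_sentences                                   # ["Bonjour.", "Merci."]
#   str(bounds)                                                 # "Bonjour.  Merci."
#   str(bounds.map_sentence_boundaries({"Merci.": "Thanks."}))  # "Bonjour.  Thanks."

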
# Deliberately left uncached: caching a generator (e.g. with st.cache_resource)
# would hand the same, already-exhausted iterator to later callers.
def minibatch(seq, size):
    """Yield successive lists of at most `size` items from `seq`."""
    items = []
    for x in seq:
        items.append(x)
        if len(items) >= size:
            yield items
            items = []
    if items:
        yield items


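# Example: list(minibatch(range(5), 2)) == [[0, 1], [2, 3], [4]]

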
class Translator:
    """Sentence-level MarianMT translator built on a stanza sentencizer."""

    def __init__(self, source_lang: str, dest_lang: str, use_gpu: bool = False) -> None:
        self.model_name = f"Helsinki-NLP/opus-mt-{source_lang}-{dest_lang}"
        self.model = transformers.MarianMTModel.from_pretrained(self.model_name)
        self.tokenizer = transformers.MarianTokenizer.from_pretrained(self.model_name)
        self.sentencizer = stanza.Pipeline(
            source_lang, processors="tokenize", verbose=False, use_gpu=use_gpu
        )

    def sentencize(self, texts: Sequence[str]) -> List[SentenceBoundaries]:
        return [
            SentenceBoundaries().from_doc(doc=self.sentencizer.process(text))
            for text in texts
        ]

    def translate(
        self, texts: Sequence[str], batch_size: int = 10, truncation: bool = True
    ) -> Sequence[str]:
        if isinstance(texts, str):
            raise ValueError("Expected a sequence of texts, not a single string")
        text_sentences = self.sentencize(texts)
        # Deduplicate sentences across all texts: each unique sentence is
        # translated once and mapped back into place afterwards.
        translations = {
            sent: None for text in text_sentences for sent in text.nonempty_sentences
        }

        # Sorting by length groups similar-length sentences, reducing padding.
        for text_batch in minibatch(
            sorted(translations, key=len, reverse=True), batch_size
        ):
            tokens = self.tokenizer(
                text_batch, return_tensors="pt", padding=True, truncation=truncation
            )
            translate_tokens = self.model.generate(**tokens)
            translate_batch = [
                self.tokenizer.decode(t, skip_special_tokens=True)
                for t in translate_tokens
            ]
            for text, translated in zip(text_batch, translate_batch):
                translations[text] = translated

        return [
            str(text.map_sentence_boundaries(translations)) for text in text_sentences
        ]


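# Usage sketch (illustrative, not from the original source). The model weights
# ("Helsinki-NLP/opus-mt-fr-en") and the stanza tokenizer are downloaded on
# first use:
#
#   marian_fr_en = Translator("fr", "en")
#   marian_fr_en.translate(["Bonjour le monde.", "Merci beaucoup."])
#   # expected output, e.g.: ["Hello world.", "Thank you very much."]

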
@st_cache_data_if(supported_cache, max_entries=5, ttl=3600)
def translate_report(
    Report, Last_name, First_name, _nlp, _marian_fr_en, dict_correction, abbreviation_dict
):
    Report_name, list_replaced_abb_name = change_name_patient_abbreviations(
        Report, Last_name, First_name, abbreviation_dict
    )
    MarianText_raw = translate_marian(Report_name, _nlp, _marian_fr_en)
    MarianText_space = add_space_to_comma_endpoint(MarianText_raw, _nlp)
    MarianText, list_replaced = correct_marian(
        MarianText_space, dict_correction, Last_name, First_name
    )
    return MarianText, list_replaced, list_replaced_abb_name


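# Hedged example call (variable names are illustrative):
#
#   text_en, corrections, abb_replacements = translate_report(
#       report_text, "Doe", "Jane", nlp, marian_fr_en,
#       get_translation_dict_correction(), abbreviation_dict,
#   )
#
# The leading underscore in _nlp and _marian_fr_en matters: assuming these
# wrappers delegate to st.cache_data / st.cache_resource, Streamlit skips
# hashing parameters whose names start with an underscore, which is required
# here because pipelines and models are not hashable.

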
@st_cache_resource_if(supported_cache, max_entries=5, ttl=3600)
def translate_marian(Report_name, _nlp, _marian_fr_en):
    """Sentencize the report with stanza, then translate it sentence by sentence."""
    list_of_sentence = [
        sentence.text for sentence in _nlp.process(Report_name).sentences
    ]
    return "\n".join(_marian_fr_en.translate(list_of_sentence))


@st_cache_data_if(supported_cache, max_entries=5, ttl=3600)
def correct_marian(MarianText_space, dict_correction, Last_name, First_name):
    MarianText = MarianText_space
    list_replaced = []
    for key, value in dict_correction.items():
        if key in MarianText:
            list_replaced.append(
                {
                    "name": Last_name,
                    "surname": First_name,
                    "type": "marian_correction",
                    "value": key,
                    "correction": value,
                    "lf_detected": True,
                    "manual_validation": True,
                }
            )
            MarianText = MarianText.replace(key, value)
    return MarianText, list_replaced


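# Hedged example: with dict_correction == {" QI ": " IQ "}, the text
# "un QI de 75" becomes "un IQ de 75", and one "marian_correction" record
# describing the replacement is appended to list_replaced.

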
@st_cache_resource_if(supported_cache, max_entries=5, ttl=3600)
def get_translation_dict_correction():
    """Build the post-translation correction dictionary (abbreviations, HPO terms)."""
    dict_correction_FRspec = {
        "PC": "head circumference",
        "palatine slot": "cleft palate",
        "ASD": "autism",
        "ADHD": "attention deficit hyperactivity disorder",
        "IUGR": "intrauterine growth retardation",
        "QI": "IQ ",
        "QIT": "FSIQ ",
        "ITQ": "FSIQ ",
        "DS": "SD",
        "FOP": "patent foramen ovale",
        "PFO": "patent foramen ovale",
        "ARCF": "fetal distress",
        "\n": " ",
        "associated": "with",
        "Mr.": "Mr",
        "Mrs.": "Mrs",
    }

    dict_correction = {}
    for key, value in dict_correction_FRspec.items():
        dict_correction[" " + key + " "] = " " + value + " "

    with open("data/hp_fr_en_translated_marian_review_lwg.json", "r", encoding="utf-8") as infile:
        hpo_translated = json.load(infile)

    for key, value in hpo_translated.items():
        dict_correction[" " + key + " "] = " " + value + " "

    with open("data/fr_abbreviations_translation.json", "r", encoding="utf-8") as infile:
        hpo_translated_abbreviations = json.load(infile)

    for key, value in hpo_translated_abbreviations.items():
        dict_correction[" " + key + " "] = " " + value + " "

    return dict_correction


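# Note: keys and values are padded with spaces (" PC " -> " head circumference ")
# so that only whole-word occurrences are replaced by correct_marian.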