"""Utilities for the writing-assistant Streamlit app: quantized-model download,
ERRANT-compatible tokenization, and cached model/annotator loading."""

from pathlib import Path
import errant
import requests
import spacy
import streamlit as st
from fastT5 import export_and_get_onnx_model, get_onnx_model
from transformers import AutoTokenizer  # type: ignore

# (ERRANT-tokenized form, plain form) pairs: applying them left-to-right re-spaces
# plain text so punctuation and clitics become separate tokens, as spaCy/ERRANT expect.
_ERRANT_TOKENIZER_MAPPINGS = [
    (" .", "."),
    (" ,", ","),
    (" ?", "?"),
    (" !", "!"),
    (" :", ":"),
    (" ;", ";"),
    (" n't", "n't"),  # do n't
    (" 're ", "'re "),  # we 're
    (" 'm ", "'m "),  # I 'm
    (" 've ", "'ve "),  # I 've
    (" 'll ", "'ll "),  # I 'll
    (" 's ", "'s "),  # Laura 's (singular possessive)
    ("s ' ", "s' "),  # years ' (plural possessive)
    # (" `` ", ' "'),
    # (" '' ", '" '),
]

# @st.cache
# def download_quantized_model(model_name: str):
#     for fname in (
#         f"{model_name}-decoder-quantized.onnx",
#         f"{model_name}-encoder-quantized.onnx",
#         f"{model_name}-init-decoder-quantized.onnx",
#     ):
#         out_path = Path(f"models/{fname}")
#         if not out_path.exists():
#             s3.download_file(
#                 "alex.apollo.ai", f"hf-writing-assistant/{model_name}/{fname}", str(out_path)
#             )

@st.cache
def download_quantized_model(model_name: str):
    """Download the quantized ONNX encoder/decoder files for `model_name` into ./models."""
    s3_base = "https://s3.eu-west-1.amazonaws.com/alex.apollo.ai/hf-writing-assistant"
    model_name = model_name.split("/")[-1]  # drop any "org/" prefix from the hub id
    Path("models/").mkdir(exist_ok=True)
    for fname in (
        f"{model_name}-decoder-quantized.onnx",
        f"{model_name}-encoder-quantized.onnx",
        f"{model_name}-init-decoder-quantized.onnx",
    ):
        url = f"{s3_base}/{model_name}/{fname}"
        response = requests.get(url)
        response.raise_for_status()  # fail loudly rather than writing an error page to disk
        with open(f"models/{fname}", "wb") as fp:
            fp.write(response.content)
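
# Usage (hypothetical hub id): download_quantized_model("some-org/t5-grammar") fetches
# models/t5-grammar-{encoder,decoder,init-decoder}-quantized.onnx from the S3 bucket.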

@st.cache
def errant_tokenize(text):
    """Re-space plain text to match ERRANT's whitespace-tokenized format."""
    for tokenized, plain in _ERRANT_TOKENIZER_MAPPINGS:
        text = text.replace(plain, tokenized)
    # Undo over-eager splitting around ellipses and common abbreviations.
    text = (
        text.replace(". . .", "...")
        .replace("etc .", "etc.")
        .replace("Mr .", "Mr.")
        .replace("U .S .A", "U.S.A")
    )
    return text
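
# For example: errant_tokenize("We're fine, aren't we?") yields
# "We 're fine , are n't we ?" — one space-separated token per ERRANT/spaCy token.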

@st.cache
def download_spacy_model(model="en"):
    """Ensure the given spaCy model is installed, downloading it on first use."""
    try:
        spacy.load(model)
    except OSError:  # model not installed yet
        spacy.cli.download(model)  # type: ignore
    return True
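
# Note: spaCy 3 removed bare "en" shortcut links, so on newer spaCy versions callers
# would pass a full pipeline name instead, e.g. download_spacy_model("en_core_web_sm").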

@st.cache(allow_output_mutation=True)
def load_model_and_tokenizer(model_name: str):
    model = None
    try:
        # fastT5 raises an AssertionError if the exported ONNX files are not found.
        model = get_onnx_model(model_name)
    except AssertionError:
        # First run: export the model to ONNX, then load the exported files.
        model = export_and_get_onnx_model(model_name)
    assert model
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    return model, tokenizer

@st.cache(allow_output_mutation=True)
def get_annotator(lang: str):
    """Load an ERRANT annotator for the given language code (e.g. "en")."""
    return errant.load(lang)
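
# Minimal end-to-end sketch (illustration only, not part of the app). The checkpoint
# name and the "grammar: " task prefix are placeholders; any T5-style correction
# model exported via fastT5 would be used the same way.
if __name__ == "__main__":
    download_spacy_model("en")
    model, tokenizer = load_model_and_tokenizer("t5-small")  # placeholder checkpoint
    annotator = get_annotator("en")

    text = "She go to school yesterday."
    inputs = tokenizer("grammar: " + text, return_tensors="pt")
    tokens = model.generate(
        input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"]
    )
    corrected = tokenizer.decode(tokens.squeeze(), skip_special_tokens=True)

    # ERRANT expects whitespace-tokenized input, hence errant_tokenize on both sides.
    orig = annotator.parse(errant_tokenize(text))
    cor = annotator.parse(errant_tokenize(corrected))
    for edit in annotator.annotate(orig, cor):
        print(edit)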