rosetta / generator.py
yhavinga's picture
Fix double model load
f3e8368
raw
history blame contribute delete
No virus
5.71 kB
import os
import re
import streamlit as st
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
# Index of the CUDA device used for inference: the last device reported by
# torch, or -1 when there is none (-1 means the model and inputs are not moved
# onto a GPU).
device = -1 + torch.cuda.device_count()
def get_access_token():
    """Return the access token passed as ``use_auth_token`` when loading models.

    The ``babel`` entry of Streamlit's secrets is used when the project-local
    ``.streamlit/secrets.toml`` file exists and defines it. Otherwise, including
    when the file exists but lacks that key, the ``HF_ACCESS_TOKEN``
    environment variable is used.

    Returns:
        The token, or ``None`` when neither source provides one.
    """
    access_token = None
    # st.secrets is only consulted when the local secrets file is present.
    if os.path.exists(".streamlit/secrets.toml"):
        try:
            access_token = st.secrets.get("babel")
        except FileNotFoundError:
            # Kept from the original: st.secrets itself may raise this.
            access_token = None
    if access_token is None:
        # Fall back to the environment when the secret is absent.
        access_token = os.environ.get("HF_ACCESS_TOKEN", None)
    return access_token
@st.cache_resource
def load_model(model_name):
    """Load the tokenizer and seq2seq model for ``model_name``.

    Decorated with ``st.cache_resource``, so Streamlit reuses the returned pair
    for repeated calls with the same ``model_name`` instead of reloading it.

    Model weights are tried as PyTorch first, then converted from Flax, then
    from TensorFlow. If all three attempts raise ``EnvironmentError``, the
    error from the TensorFlow attempt is re-raised. When ``device`` is not -1,
    the model is moved to ``cuda:<device>``.

    Args:
        model_name: Model identifier passed to ``from_pretrained``.

    Returns:
        ``(tokenizer, model)``
    """
    os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
    # Resolve the token once instead of re-reading secrets/env per attempt.
    access_token = get_access_token()
    tokenizer = AutoTokenizer.from_pretrained(
        model_name,
        # Slow tokenizer for model names containing "ul2".
        use_fast=("ul2" not in model_name),
        use_auth_token=access_token,
    )
    if tokenizer.pad_token is None:
        # No pad token defined: reuse EOS so padding is possible.
        print("Adding pad_token to the tokenizer")
        tokenizer.pad_token = tokenizer.eos_token
    for framework in [None, "flax", "tf"]:
        try:
            model = AutoModelForSeq2SeqLM.from_pretrained(
                model_name,
                from_flax=(framework == "flax"),
                from_tf=(framework == "tf"),
                use_auth_token=access_token,
            )
            break
        except EnvironmentError:
            # Try the next weight format; give up after the last one.
            if framework == "tf":
                raise
    if device != -1:
        model.to(f"cuda:{device}")
    return tokenizer, model
class Generator:
    """A seq2seq model/tokenizer pair plus the settings used to generate with it.

    Args:
        model_name: Model identifier passed to ``load_model``.
        task: Key looked up in the model config's ``task_specific_params`` to
            obtain an input prefix and generation overrides.
        desc: Human-readable description of this generator.
        split_sentences: When true, multi-line input is generated line by line.
    """

    def __init__(self, model_name, task, desc, split_sentences):
        self.model_name = model_name
        self.task = task
        self.desc = desc
        self.split_sentences = split_sentences
        self.tokenizer = None
        self.model = None
        # Prepended to every input; replaced by the task's "prefix" if it has one.
        self.prefix = ""
        # Defaults; ``load`` overwrites each key with the model config's value
        # and then with the task-specific params' value, when those define it.
        self.gen_kwargs = {
            "max_length": 128,
            "num_beams": 6,
            "num_beam_groups": 3,
            "no_repeat_ngram_size": 0,
            "early_stopping": True,
            "num_return_sequences": 1,
            "length_penalty": 1.0,
        }
        self.load()

    def load(self):
        """Fetch tokenizer and model via ``load_model`` and apply config overrides."""
        print(f"Loading model {self.model_name}")
        self.tokenizer, self.model = load_model(self.model_name)
        config = self.model.config
        # Only keys already present in gen_kwargs are taken from the config.
        for key in self.gen_kwargs:
            if key in config.__dict__:
                self.gen_kwargs[key] = config.__dict__[key]
        # task_specific_params may be None; treat that as "no overrides".
        all_task_params = getattr(config, "task_specific_params", None) or {}
        task_params = all_task_params.get(self.task)
        if not task_params:
            return
        if "prefix" in task_params:
            self.prefix = task_params["prefix"]
        for key in self.gen_kwargs:
            if key in task_params:
                self.gen_kwargs[key] = task_params[key]

    def generate(self, text: str, streamer=None, **generate_kwargs) -> "tuple[str, dict]":
        """Generate output for ``text``.

        CRLF and lone CR line endings are converted to LF, and runs of blank
        lines are collapsed into one newline. With ``split_sentences``, each
        line is generated separately and blank lines are passed through.

        Args:
            text: Input text.
            streamer: Optional streamer forwarded to ``model.generate``.
            **generate_kwargs: Per-call overrides merged over ``self.gen_kwargs``.

        Returns:
            ``(output_text, effective_generate_kwargs)``.
        """
        # Normalize line endings so the collapse/split below sees only "\n";
        # otherwise "\r\n\r\n" leaves empty lines that would reach the model.
        text = re.sub(r"\r\n?", "\n", text)
        # Replace two or more newlines with a single newline.
        text = re.sub(r"\n{2,}", "\n", text)
        generate_kwargs = {**self.gen_kwargs, **generate_kwargs}

        if self.split_sentences and "\n" in text:
            # Generate per line; blank/whitespace-only lines are kept verbatim
            # rather than sent to the model.
            translated = [
                self.generate(line, streamer, **generate_kwargs)[0]
                if line.strip()
                else line
                for line in text.splitlines()
            ]
            return "\n".join(translated), generate_kwargs

        # Tokenizers with a newline_token attribute encode "\n" as that token.
        # str.replace avoids re.sub interpreting backslashes in the token.
        if hasattr(self.tokenizer, "newline_token"):
            text = text.replace("\n", self.tokenizer.newline_token)
        batch_encoded = self.tokenizer(
            self.prefix + text,
            max_length=generate_kwargs["max_length"],
            padding=False,
            truncation=False,
            return_tensors="pt",
        )
        if device != -1:
            batch_encoded.to(f"cuda:{device}")
        output_ids = self.model.generate(
            batch_encoded["input_ids"],
            attention_mask=batch_encoded["attention_mask"],
            streamer=streamer,
            **generate_kwargs,
        )
        decoded_preds = self.tokenizer.batch_decode(
            output_ids.cpu().numpy(), skip_special_tokens=False
        )

        def replace_tokens(pred):
            # Special tokens are kept by batch_decode (presumably so the
            # newline token survives) and the pad/eos markers stripped here.
            pred = pred.replace("<pad> ", "").replace("<pad>", "").replace("</s>", "")
            if hasattr(self.tokenizer, "newline_token"):
                pred = pred.replace(self.tokenizer.newline_token, "\n")
            return pred

        # Only the first returned sequence is used.
        return replace_tokens(decoded_preds[0]), generate_kwargs

    def __str__(self):
        return self.model_name
class GeneratorFactory:
    """Create and hold ``Generator`` instances, one per model_name/task/desc.

    Args:
        generator_list: Iterable of dicts whose keys are the parameters of
            ``add_generator`` (``model_name``, ``task``, ``desc``,
            ``split_sentences``).
    """

    def __init__(self, generator_list):
        self.generators = []
        for g in generator_list:
            # Show a Streamlit spinner while each model loads.
            with st.spinner(text=f"Loading the model {g['desc']} ..."):
                self.add_generator(**g)

    def add_generator(self, model_name, task, desc, split_sentences):
        """Create and register a ``Generator`` unless an equal one exists.

        Identity is ``model_name``/``task``/``desc``; ``split_sentences`` is
        not compared.
        """
        if self.get_generator(model_name=model_name, task=task, desc=desc) is None:
            self.generators.append(Generator(model_name, task, desc, split_sentences))

    @staticmethod
    def _matches(generator, criteria):
        """Return True when each criteria value equals the generator's attribute."""
        # Compares instance attributes only; a missing attribute reads as None.
        return all(generator.__dict__.get(k) == v for k, v in criteria.items())

    def get_generator(self, **kwargs):
        """Return the first generator matching ``kwargs``, or ``None``."""
        return next((g for g in self.generators if self._matches(g, kwargs)), None)

    def __iter__(self):
        return iter(self.generators)

    def filter(self, **kwargs):
        """Return every generator matching ``kwargs``, in insertion order."""
        return [g for g in self.generators if self._matches(g, kwargs)]