# rosetta / generator.py
# Author: yhavinga — commit 46ffa30 ("Add app"), 4.05 kB.
import os
import streamlit as st
import torch
from transformers import (
AutoModelForCausalLM,
AutoModelForSeq2SeqLM,
AutoTokenizer,
)
# Index of the GPU used for inference: the last visible CUDA device, or -1 when
# torch reports no CUDA devices, in which case models and inputs stay on the CPU.
device = torch.cuda.device_count() - 1
# NOTE(review): the name says NL->EN but the value is the en->nl task string.
# Code in this file compares tasks against the value, so it is kept as-is —
# confirm which translation direction is actually intended.
TRANSLATION_NL_TO_EN = "translation_en_to_nl"
def _resolve_access_token():
    """Return the Hugging Face access token to use, or ``None``.

    The ``netherator`` entry of ``.streamlit/secrets.toml`` takes precedence.
    The ``HF_ACCESS_TOKEN`` environment variable is used in three cases:
    the file is absent, Streamlit reports it as missing, or it does not
    define the ``netherator`` key.
    """
    token = None
    if os.path.exists(".streamlit/secrets.toml"):
        try:
            token = st.secrets.get("netherator")
        except FileNotFoundError:
            token = None
    # Falling back also when the key is missing keeps a secrets file that only
    # holds unrelated settings from silently disabling authentication.
    return token or os.environ.get("HF_ACCESS_TOKEN", None)


@st.cache(suppress_st_warning=True, allow_output_mutation=True)
def load_model(model_name, task):
    """Load a tokenizer and model, converting the weights from Flax.

    Wrapped in ``st.cache`` so repeated calls with the same arguments reuse
    the loaded objects.

    Args:
        model_name: Model id/path passed to ``from_pretrained``.
        task: Task string. Any task containing ``"translation"`` loads a
            sequence-to-sequence model; everything else loads a causal LM.

    Returns:
        A ``(tokenizer, model)`` tuple.
        - If the tokenizer has no ``pad_token``, its ``eos_token`` is used
          instead.
        - The model is moved to ``cuda:{device}`` when ``device != -1``.
    """
    # NOTE(review): presumably silences the HF tokenizers parallelism/fork
    # warning — confirm.
    os.environ["TOKENIZERS_PARALLELISM"] = "false"
    access_token = _resolve_access_token()
    tokenizer = AutoTokenizer.from_pretrained(
        model_name, from_flax=True, use_auth_token=access_token
    )
    if tokenizer.pad_token is None:
        # Padding needs a pad token; reuse EOS when the tokenizer lacks one.
        print("Adding pad_token to the tokenizer")
        tokenizer.pad_token = tokenizer.eos_token
    auto_model_class = (
        AutoModelForSeq2SeqLM if "translation" in task else AutoModelForCausalLM
    )
    model = auto_model_class.from_pretrained(
        model_name, from_flax=True, use_auth_token=access_token
    )
    if device != -1:
        model.to(f"cuda:{device}")
    return tokenizer, model
class Generator:
    """Wrap one Hugging Face model/tokenizer pair used for a single task.

    The model and tokenizer are obtained through ``load_model`` at construction
    time. If the model config declares a ``prefix`` for ``task`` in
    ``task_specific_params``, that prefix is prepended to every input text.
    """

    def __init__(self, model_name, task, desc):
        """Store the model identity and load the model immediately.

        Args:
            model_name: Model identifier passed to ``load_model``.
            task: Task string; also the key looked up in
                ``model.config.task_specific_params``.
            desc: Human-readable description, returned by ``str(generator)``.
        """
        self.model_name = model_name
        self.task = task
        self.desc = desc
        self.tokenizer = None
        self.model = None
        self.prefix = ""
        self.load()

    def load(self):
        """Load the model/tokenizer if needed, then resolve the input prefix.

        Calling this repeatedly is safe: ``load_model`` is only invoked while
        no model is set.
        """
        if self.model is None:
            print(f"Loading model {self.model_name}")
            self.tokenizer, self.model = load_model(self.model_name, self.task)
        # task_specific_params may be None, or may have no (or a None) entry for
        # this task; treat all of these as "no prefix" explicitly rather than
        # relying on a caught TypeError.
        all_params = self.model.config.task_specific_params or {}
        task_params = all_params.get(self.task) or {}
        self.prefix = task_params.get("prefix", self.prefix)

    def generate(self, text: str, **generate_kwargs) -> list:
        """Generate output text(s) for ``text`` with the wrapped model.

        Args:
            text: Input text; ``self.prefix`` is prepended before tokenizing.
            **generate_kwargs: Forwarded unchanged to ``model.generate``.
                ``max_length``, if given, is also handed to the tokenizer.

        Returns:
            A list of decoded strings, one per generated sequence, with
            ``<pad>`` and ``</s>`` markers removed.
        """
        batch_encoded = self.tokenizer(
            self.prefix + text,
            # max_length is optional for model.generate, so don't raise a
            # KeyError when the caller omits it.
            max_length=generate_kwargs.get("max_length"),
            padding=False,
            truncation=False,
            return_tensors="pt",
        )
        if device != -1:
            batch_encoded = batch_encoded.to(f"cuda:{device}")
        output_ids = self.model.generate(
            batch_encoded["input_ids"],
            attention_mask=batch_encoded["attention_mask"],
            **generate_kwargs,
        )
        decoded_preds = self.tokenizer.batch_decode(
            output_ids.cpu().numpy(), skip_special_tokens=False
        )
        # Special tokens are kept by batch_decode; only padding and end-of-
        # sequence markers are stripped here.
        return [
            pred.replace("<pad> ", "").replace("<pad>", "").replace("</s>", "")
            for pred in decoded_preds
        ]

    def __str__(self):
        """Return the human-readable description given at construction."""
        return self.desc
class GeneratorFactory:
    """Build and hold a collection of ``Generator`` instances.

    A generator is only created when no existing one matches it on
    ``model_name``, ``task`` and ``desc``.
    """

    def __init__(self, generator_list):
        """Create a generator for every entry, showing a Streamlit spinner.

        Args:
            generator_list: Iterable of dicts whose keys are the arguments of
                ``add_generator`` (``model_name``, ``task``, ``desc``).
        """
        self.generators = []
        for spec in generator_list:
            with st.spinner(text=f"Loading the model {spec['desc']} ..."):
                self.add_generator(**spec)

    def add_generator(self, model_name, task, desc):
        """Add a ``Generator`` unless an identical one is already present."""
        existing = self.get_generator(model_name=model_name, task=task, desc=desc)
        if existing is None:
            # Generator.__init__ already calls load(), so no second call here.
            self.generators.append(Generator(model_name, task, desc))

    def get_generator(self, **kwargs):
        """Return the first generator whose attributes equal all ``kwargs``.

        Returns ``None`` when no generator matches.
        """
        for g in self.generators:
            # Generator expression lets all() stop at the first mismatch.
            if all(g.__dict__.get(k) == v for k, v in kwargs.items()):
                return g
        return None

    def __iter__(self):
        """Iterate over the held generators in insertion order."""
        return iter(self.generators)

    def gpt_descs(self):
        """Return descriptions of generators whose task is ``TRANSLATION_NL_TO_EN``.

        NOTE(review): despite the name, this selects translation generators,
        which ``load_model`` loads as seq2seq rather than causal-LM models —
        confirm the intended use with callers.
        """
        return [g.desc for g in self.generators if g.task == TRANSLATION_NL_TO_EN]