# Page header captured together with the file (not program code):
#   davda54's picture
#   Update app.py
#   9c65c2d verified
import os
import json
import torch
import shutil
from datetime import datetime
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, TextIteratorStreamer
from transformers.generation import LogitsProcessor
import huggingface_hub
from huggingface_hub import Repository
from threading import Thread
import gradio as gr
# Load the tokenizer and resolve the ids of the special and language-tag tokens.
print(f"Starting to load the model to memory")
tokenizer = AutoTokenizer.from_pretrained("nort5_en-no_base")
(
    cls_index,
    sep_index,
    eos_index,
    pad_index,
    eng_index,
    nob_index,
    nno_index,
) = (
    tokenizer.convert_tokens_to_ids(token)
    for token in ("[CLS]", "[SEP]", "[EOS]", "[PAD]", ">>eng<<", ">>nob<<", ">>nno<<")
)

# Load the seq2seq model, move it to the GPU when one is available and
# switch it to evaluation mode.
model = AutoModelForSeq2SeqLM.from_pretrained("nort5_en-no_base", trust_remote_code=True)
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"SYSTEM: Running on {device}", flush=True)
model = model.to(device)
model.eval()
print(f"Sucessfully loaded the model to the memory")
# Choices offered by both language dropdowns, in display order.
LANGUAGES = ["🇬🇧 English", "🇳🇴 Norwegian (Bokmål)", "🇳🇴 Norwegian (Nynorsk)"]

# Dropdown label -> id of the matching language-tag token.
LANGUAGE_IDS = dict(zip(LANGUAGES, (eng_index, nob_index, nno_index)))

# Local clone ("data") of the dataset repository that stores the usage log.
STATS_REPO = "https://huggingface.co/datasets/ltg/usage_statistics"
HF_TOKEN = os.getenv("HF_TOKEN")
dataset = Repository(local_dir="data", clone_from=STATS_REPO, use_auth_token=HF_TOKEN)
def _append_usage_timestamp(path):
    """Append the current local time to *path* as one JSON line, then push."""
    with open(path, "a", encoding="utf-8") as f:
        line = json.dumps(str(datetime.now()), ensure_ascii=False)
        f.write(f"{line}\n")
    # Push only after the file is closed so the new line is flushed to disk.
    dataset.push_to_hub(blocking=False)


# log the timestamp of the query
def add_anonymous_usage_log(path):
    """Record one anonymous usage event (a timestamp only) in the stats repo.

    The local clone is pulled, the timestamp is appended to *path* and the
    change is pushed without blocking. If any of those steps fails, the local
    clone in ``data`` is deleted, cloned again from ``STATS_REPO`` and the
    append-and-push is retried once; an error during that retry propagates
    to the caller.

    Args:
        path: JSON-lines file inside the local clone to append to.
    """
    global dataset
    try:
        dataset.git_pull()
        _append_usage_timestamp(path)
    except Exception as error:
        # `Exception` rather than a bare `except:` so that KeyboardInterrupt
        # and SystemExit are not mistaken for a broken clone.
        print(
            f"SYSTEM: Usage log update failed ({error!r}); re-cloning the statistics repository",
            flush=True,
        )
        # The directory may already be missing or half-written; a fresh clone
        # is created right below either way.
        shutil.rmtree("data", ignore_errors=True)
        dataset = Repository(
            local_dir="data", clone_from=STATS_REPO, use_auth_token=HF_TOKEN
        )
        _append_usage_timestamp(path)
class BatchStreamer(TextIteratorStreamer):
    """Iterator streamer that handles a whole batch of sequences at once.

    Every sequence in the batch gets its own token cache. After each `put`
    the complete text decoded so far (one stripped line per sequence, joined
    with newlines) is emitted, so consumers receive the full running output
    rather than a delta.
    """

    def _decode_cache(self):
        """Decode every cached sequence and join the results with newlines."""
        paragraphs = [
            # Use the tokenizer this streamer was constructed with instead of
            # relying on a module-level global.
            self.tokenizer.decode(tokens, **self.decode_kwargs).strip()
            for tokens in self.token_cache
        ]
        return "\n".join(paragraphs)

    def put(self, value):
        """Append new tokens for every sequence and emit the decoded batch.

        Args:
            value: tensor whose first dimension indexes the sequences of the
                batch; each row is either a single token id or a list of ids.
        """
        # The first call determines how many sequences are being tracked.
        if not self.token_cache:
            self.token_cache = [[] for _ in range(value.size(0))]

        for tokens, new_tokens in zip(self.token_cache, value.tolist()):
            if isinstance(new_tokens, int):
                tokens.append(new_tokens)
            else:
                tokens.extend(new_tokens)

        # Re-decode everything so that multi-token characters and words are
        # rendered correctly once they are complete.
        self.on_finalized_text(self._decode_cache())

    def end(self):
        """Emit the final text, reset the caches and signal the stream end."""
        if self.token_cache:
            printable_text = self._decode_cache()
            self.token_cache = []
            self.print_len = 0
        else:
            printable_text = ""

        self.next_tokens_are_prompt = True
        self.on_finalized_text(printable_text, stream_end=True)
class RepetitionPenaltyLogitsProcessor(LogitsProcessor):
    """Lower the scores of tokens that already occur in `input_ids`.

    The per-token penalty is derived from the bias of the last layer of
    `model.classifier.nonlinearity`: the bias is log-softmax normalised and
    shifted so that its maximum is zero, then scaled by `penalty`. For a
    positive `penalty` every entry is therefore <= 0, and the token with the
    largest bias is not penalised at all.
    """

    def __init__(self, penalty: float, model):
        """Precompute the per-token penalty vector.

        Args:
            penalty: scale applied to the shifted log-probabilities.
            model: model exposing `classifier.nonlinearity[-1].bias`.
        """
        last_bias = model.classifier.nonlinearity[-1].bias.detach()
        # Explicit `dim`: the implicit-dimension form is deprecated and warns.
        last_bias = torch.nn.functional.log_softmax(last_bias, dim=-1)
        self.penalty = penalty * (last_bias - last_bias.max())

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        """Add the penalty to the scores of all tokens present in `input_ids`.

        `scores` is modified in place and also returned.
        """
        penalty = self.penalty.unsqueeze(0).to(input_ids.device)
        # Pick the penalised scores of the already-seen tokens and write them
        # back; all other vocabulary entries stay untouched.
        penalized_score = torch.gather(scores + penalty, 1, input_ids)
        scores.scatter_(1, input_ids, penalized_score)
        return scores
def translate(source, source_language, target_language):
    """Translate `source` line by line and stream the result.

    This is a generator: every yielded value is the text produced by the
    streamer so far (stripped). Each input line is translated as a separate
    sequence of one batch.

    Args:
        source: text to translate; lines are separated by ``\\n``.
        source_language: key of `LANGUAGE_IDS` naming the input language.
        target_language: key of `LANGUAGE_IDS` naming the output language.

    Yields:
        The stripped text emitted by the streamer after each generation step.
        When both languages are equal the stripped input is yielded once;
        blank input yields a single empty string without running the model.
    """
    if source_language == target_language:
        yield source.strip()
        return

    # Nothing to translate: skip the model and the usage log.
    if not source.strip():
        yield ""
        return

    max_length = 512
    # Three prefix tokens plus the closing [SEP] surround every line, so the
    # line itself is truncated to keep [SEP] within `max_length`.
    max_content_length = max_length - 4
    prefix = [cls_index, LANGUAGE_IDS[target_language], LANGUAGE_IDS[source_language]]

    lines = [line.strip() for line in source.split('\n')]
    sequences = [
        torch.tensor(prefix + ids[:max_content_length] + [sep_index])
        for ids in tokenizer(lines).input_ids
    ]
    source_subwords = torch.nn.utils.rnn.pad_sequence(
        sequences, batch_first=True, padding_value=pad_index
    ).to(device)

    streamer = BatchStreamer(tokenizer, timeout=60.0, skip_special_tokens=True)

    def generate(model, **kwargs):
        # bfloat16 autocast is only enabled when not running on the CPU.
        with torch.inference_mode(), torch.autocast(
            enabled=device != "cpu", device_type=device, dtype=torch.bfloat16
        ):
            return model.generate(**kwargs)

    generate_kwargs = dict(
        streamer=streamer,
        input_ids=source_subwords,
        attention_mask=(source_subwords != pad_index).long(),
        max_new_tokens=max_length - 1,
        num_beams=1,
        logits_processor=[RepetitionPenaltyLogitsProcessor(1.0, model)],
        do_sample=False,
        use_cache=True,
    )

    # Generation runs in a background thread; this generator consumes the
    # streamer so partial output can be shown while decoding continues.
    worker = Thread(target=generate, args=(model,), kwargs=generate_kwargs)
    worker.start()

    for new_text in streamer:
        yield new_text.strip()

    worker.join()
    add_anonymous_usage_log("data/no-en-translation.jsonl")
def switch_inputs(source, target, source_language, target_language):
    """Swap the two texts and the two languages.

    Returns:
        A 4-tuple ``(target, source, target_language, source_language)``.
    """
    swapped_texts = (target, source)
    swapped_languages = (target_language, source_language)
    return (*swapped_texts, *swapped_languages)
with gr.Blocks() as demo:
    gr.Markdown("# Norwegian-English translation")

    with gr.Row():
        # Input panel: language selector, editable text box and submit button.
        with gr.Column(scale=7, variant="panel"):
            source_language = gr.Dropdown(LANGUAGES, value=LANGUAGES[1], show_label=False)
            source = gr.Textbox(
                label="Source text",
                placeholder="What do you want to translate?",
                show_label=False,
                lines=7,
                max_lines=100,
                autofocus=True,
            )
            submit = gr.Button("Submit", variant="primary")

        # Output panel: language selector and read-only translation box.
        with gr.Column(scale=7, variant="panel"):
            target_language = gr.Dropdown(LANGUAGES, value=LANGUAGES[0], show_label=False)
            target = gr.Textbox(
                label="Translation",
                show_label=False,
                interactive=False,
                lines=7,
                max_lines=100,
            )

    # Components that are locked while a translation is running.
    locked_components = [source, submit, source_language, target_language]

    def update_state_after_user():
        """Disable the input controls while a request is being processed."""
        return {component: gr.update(interactive=False) for component in locked_components}

    def update_state_after_return():
        """Re-enable the input controls once the request has finished."""
        return {component: gr.update(interactive=True) for component in locked_components}

    def _attach_translation(trigger):
        """Register lock -> translate -> unlock on the given event trigger."""
        locked = trigger(
            fn=update_state_after_user,
            inputs=None,
            outputs=list(locked_components),
            queue=False,
        )
        translated = locked.then(
            fn=translate,
            inputs=[source, source_language, target_language],
            outputs=[target],
            queue=True,
        )
        return translated.then(
            fn=update_state_after_return,
            inputs=None,
            outputs=list(locked_components),
            queue=False,
        )

    # The same pipeline runs on pressing Enter in the text box and on clicking the button.
    submit_event = _attach_translation(source.submit)
    submit_click_event = _attach_translation(submit.click)

demo.queue(max_size=32, concurrency_count=2)
demo.launch()