madlad400-3b-mt / app.py
jbochi's picture
Model was not trained on more than 128 tokens
5e4ae11
raw
history blame
No virus
2.01 kB
import time
from transformers import T5ForConditionalGeneration, T5Tokenizer, GenerationConfig
import gradio as gr
# Hugging Face Hub id of the checkpoint served by this demo.
MODEL_NAME = "jbochi/madlad400-3b-mt"

# Tokenizer and weights are loaded once, at import time; `inference` reads
# these module-level objects on every request.
print(f"Loading {MODEL_NAME} tokenizer...")
tokenizer = T5Tokenizer.from_pretrained(MODEL_NAME)
print(f"Loading {MODEL_NAME} model...")
model = T5ForConditionalGeneration.from_pretrained(
    MODEL_NAME,
    device_map="auto",  # let transformers/accelerate choose device placement
)
def inference(input_text, target_language, max_length):
    """Translate ``input_text`` into ``target_language`` with the loaded model.

    The prompt is the source text prefixed with a ``<2{code}>`` tag naming
    the target language.

    Args:
        input_text: Text to translate.
        target_language: Language code placed inside the ``<2...>`` tag
            (e.g. ``"en"``, ``"es"``).
        max_length: Maximum length, in tokens, of the generated sequence.
            Coerced to ``int`` because UI sliders may deliver floats.

    Returns:
        A dict with the decoded translation (``"result"``), the wall-clock
        time spent tokenizing, generating and decoding in seconds
        (``"inference_time"``), and the raw input/output token ids.
    """
    # perf_counter is monotonic and high-resolution, unlike time.time().
    start_time = time.perf_counter()
    input_ids = tokenizer(
        f"<2{target_language}> {input_text}", return_tensors="pt"
    ).input_ids
    # Pass max_length as a kwarg: this overrides only that field on a copy of
    # the model's own generation config instead of replacing it wholesale
    # with a fresh GenerationConfig.
    outputs = model.generate(
        input_ids=input_ids.to(model.device),
        max_length=int(max_length),
    )
    output_ids = outputs[0]
    translation = tokenizer.decode(output_ids, skip_special_tokens=True)
    elapsed = time.perf_counter() - start_time
    return {
        'result': translation,
        'inference_time': elapsed,
        'input_token_ids': input_ids[0].tolist(),
        'output_token_ids': output_ids.tolist(),
    }
def _target_language_codes(tok, max_token_id=500):
    """Return the codes of ``<2code>`` tokens among ids ``0..max_token_id-1``."""
    # NOTE(review): assumes every language tag lives in the first 500 ids,
    # as the original hard-coded scan did — confirm against the vocabulary.
    codes = []
    for token_id in range(max_token_id):
        token = tok.decode(token_id)
        # Require the full "<2...>" shape before stripping the "<2" prefix and
        # ">" suffix, so a malformed token cannot yield a truncated code.
        if token.startswith("<2") and token.endswith(">"):
            codes.append(token[2:-1])
    return codes


def run():
    """Build the Gradio interface around ``inference`` and launch it.

    The target-language dropdown is populated from the tokenizer's
    ``<2code>`` tokens. The description reports the device the model
    actually occupies.
    """
    lang_codes = _target_language_codes(tokenizer)
    inputs = [
        gr.components.Textbox(lines=5, label="Input text"),
        gr.components.Dropdown(lang_codes, value="en", label="Target Language"),
        gr.components.Slider(
            minimum=5,
            # Capped at 128: the model was reportedly not trained on more
            # than 128 tokens.
            maximum=128,
            step=1,  # token counts are integers
            value=50,
            label="Max length",
        ),
    ]
    examples = [
        ["I'm a mad lad!", "es", 50],
        ["千里之行,始於足下", "en", 50],
    ]
    outputs = gr.components.JSON()
    title = f"{MODEL_NAME} demo"
    # device_map="auto" may place the model on a GPU, so derive the status
    # from the model instead of hard-coding "CPU" (still prints "CPU" there).
    demo_status = f"Demo is running on {model.device.type.upper()}"
    description = f"Details: https://huggingface.co/{MODEL_NAME}. {demo_status}"
    gr.Interface(
        inference,
        inputs,
        outputs,
        title=title,
        description=description,
        examples=examples,
    ).launch()


if __name__ == "__main__":
    run()