from fastapi import FastAPI, Request
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from fastapi.middleware.cors import CORSMiddleware
import torch
import os
import yaml
import transformers
app = FastAPI()
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Adjust this as needed
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
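# Illustrative only: for a real deployment you would normally replace "*" with
# the specific frontend origin(s), e.g.
#   allow_origins=["https://your-frontend.example.com"]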
# Load the model and tokenizer
tokenizer = AutoTokenizer.from_pretrained("EzekielMW/Eksl_dataset")
model = AutoModelForSeq2SeqLM.from_pretrained("EzekielMW/Eksl_dataset")
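# (Note: this tokenizer is replaced further below by an NllbTokenizer that
# registers the custom "eng"/"swa"/"ksl" language codes.)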
# Where should output files be stored locally
drive_folder = "./serverlogs"
if not os.path.exists(drive_folder):
os.makedirs(drive_folder)
# Large batch sizes generally give good results for translation
effective_train_batch_size = 480
train_batch_size = 6
eval_batch_size = train_batch_size
gradient_accumulation_steps = int(effective_train_batch_size / train_batch_size)
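# With the values above this gives 480 / 6 = 80 gradient accumulation steps,
# i.e. gradients from 80 batches of 6 are accumulated before each optimizer update.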
# Everything in one yaml string, so that it can all be logged.
yaml_config = '''
training_args:
    output_dir: "{drive_folder}"
    eval_strategy: steps
    eval_steps: 100
    save_steps: 100
    gradient_accumulation_steps: {gradient_accumulation_steps}
    learning_rate: 3.0e-4  # Include decimal point to parse as float
    # optim: adafactor
    per_device_train_batch_size: {train_batch_size}
    per_device_eval_batch_size: {eval_batch_size}
    weight_decay: 0.01
    save_total_limit: 3
    max_steps: 500
    predict_with_generate: True
    fp16: True
    logging_dir: "{drive_folder}"
    load_best_model_at_end: True
    metric_for_best_model: loss
    seed: 123
    push_to_hub: False

max_input_length: 128
eval_pretrained_model: False
early_stopping_patience: 4
data_dir: .

# Use a 600M parameter model here, which is easier to train on a free Colab
# instance. Bigger models work better, however: results will improve if you
# can train on nllb-200-1.3B instead.
model_checkpoint: facebook/nllb-200-distilled-600M

datasets:
    train:
        huggingface_load:
            # We will load two datasets here: English/KSL Gloss, and also SALT
            # Swahili/English, so that we can try out multi-way translation.
            - path: EzekielMW/Eksl_dataset
              split: train[:-1000]
            - path: sunbird/salt
              name: text-all
              split: train
        source:
            # This is a text translation only, no audio.
            type: text
            # The source text can be any of English, KSL or Swahili.
            language: [eng,ksl,swa]
            preprocessing:
                # The models are case sensitive, so if the training text is all
                # capitals, then it will only learn to translate capital letters
                # and won't understand lower case. Make everything lower case
                # for now.
                - lower_case
                # We can also augment the spelling of the input text, which
                # makes the model more robust to spelling errors.
                - augment_characters
        target:
            type: text
            # The target text can be any of English, KSL or Swahili.
            language: [eng,ksl,swa]
            # The models are case sensitive: make everything lower case for now.
            preprocessing:
                - lower_case
        shuffle: True
        allow_same_src_and_tgt_language: False

    validation:
        huggingface_load:
            # Use the last 1000 of the KSL examples for validation.
            - path: EzekielMW/Eksl_dataset
              split: train[-1000:]
            # Add some Swahili validation text.
            - path: sunbird/salt
              name: text-all
              split: dev
        source:
            type: text
            language: [swa,ksl,eng]
            preprocessing:
                - lower_case
        target:
            type: text
            language: [swa,ksl,eng]
            preprocessing:
                - lower_case
        allow_same_src_and_tgt_language: False
'''
yaml_config = yaml_config.format(
    drive_folder=drive_folder,
    train_batch_size=train_batch_size,
    eval_batch_size=eval_batch_size,
    gradient_accumulation_steps=gradient_accumulation_steps,
)
config = yaml.safe_load(yaml_config)
training_settings = transformers.Seq2SeqTrainingArguments(
    **config["training_args"])
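# Note: training_settings mirrors the configuration the model was trained with;
# the inference endpoints below do not use it. Illustrative access into the
# parsed config:
#   config["training_args"]["per_device_train_batch_size"]  # -> 6
#   config["model_checkpoint"]  # -> "facebook/nllb-200-distilled-600M"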
# The pre-trained model that we use has support for some African languages, but
# we need to adapt the tokenizer to languages that it wasn't trained with,
# such as KSL. Here we reuse the token from a different language.
LANGUAGE_CODES = ["eng", "swa", "ksl"]
code_mapping = {
    # Exact/close mapping
    'eng': 'eng_Latn',
    'swa': 'swh_Latn',
    # Arbitrary mapping: reuse a spare language code for KSL
    'ksl': 'ace_Latn',
}
tokenizer = transformers.NllbTokenizer.from_pretrained(
    config['model_checkpoint'],
    src_lang='eng_Latn',
    tgt_lang='eng_Latn')
offset = tokenizer.sp_model_size + tokenizer.fairseq_offset  # (not used further in this app)
for code in LANGUAGE_CODES:
    i = tokenizer.convert_tokens_to_ids(code_mapping[code])
    tokenizer._added_tokens_encoder[code] = i
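# Optional sanity check (illustrative): each custom code should now resolve to
# the ID of the NLLB code it was mapped onto, e.g.
#   tokenizer.convert_tokens_to_ids("ksl") == tokenizer.convert_tokens_to_ids("ace_Latn")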
# Define a translation function
def translate(text, source_language, target_language):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    inputs = tokenizer(text.lower(), return_tensors="pt").to(device)
    # Overwrite the leading language-code token with the requested source language.
    inputs['input_ids'][0][0] = tokenizer.convert_tokens_to_ids(source_language)
    translated_tokens = model.to(device).generate(
        **inputs,
        forced_bos_token_id=tokenizer.convert_tokens_to_ids(target_language),
        max_length=100,
        num_beams=5,
    )
    result = tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)[0]
    # KSL gloss is written in capitals, so restore upper case for KSL output.
    if target_language == 'ksl':
        result = result.upper()
    return result
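# Example call (illustrative), using the language codes registered above
# ("eng", "swa", "ksl"):
#   translate("Good morning", "eng", "ksl")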
@app.post("/translate")
async def translate_text(request: Request):
    data = await request.json()
    text = data.get("text")
    source_language = data.get("source_language")
    target_language = data.get("target_language")
    translation = translate(text, source_language, target_language)
    return {"translation": translation}
@app.get("/")
async def root():
    return {"message": "Welcome to the translation API!"}
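# Minimal local-run sketch (assumes uvicorn is installed; 7860 is the usual
# Hugging Face Spaces port, adjust as needed):
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)

# Example request (illustrative):
#   curl -X POST http://localhost:7860/translate \
#     -H "Content-Type: application/json" \
#     -d '{"text": "Good morning", "source_language": "eng", "target_language": "ksl"}'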