from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch
import os
import yaml
import transformers

app = FastAPI()

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Load the fine-tuned translation model from the Hugging Face Hub. The
# tokenizer loaded here is replaced below by an NllbTokenizer configured
# for the language codes used by this API.
tokenizer = AutoTokenizer.from_pretrained("EzekielMW/Eksl_dataset")
model = AutoModelForSeq2SeqLM.from_pretrained("EzekielMW/Eksl_dataset")

# Directory for server logs and training outputs.
drive_folder = "./serverlogs"

if not os.path.exists(drive_folder):
    os.makedirs(drive_folder)

# Batch size settings used in the training configuration below: an effective
# batch of 480 examples is reached by accumulating gradients over
# 480 / 6 = 80 micro-batches.
effective_train_batch_size = 480
train_batch_size = 6
eval_batch_size = train_batch_size

gradient_accumulation_steps = int(effective_train_batch_size / train_batch_size)

yaml_config = '''
training_args:
  output_dir: "{drive_folder}"
  eval_strategy: steps
  eval_steps: 100
  save_steps: 100
  gradient_accumulation_steps: {gradient_accumulation_steps}
  learning_rate: 3.0e-4  # Include decimal point to parse as float
  # optim: adafactor
  per_device_train_batch_size: {train_batch_size}
  per_device_eval_batch_size: {eval_batch_size}
  weight_decay: 0.01
  save_total_limit: 3
  max_steps: 500
  predict_with_generate: True
  fp16: True
  logging_dir: "{drive_folder}"
  load_best_model_at_end: True
  metric_for_best_model: loss
  seed: 123
  push_to_hub: False

max_input_length: 128
eval_pretrained_model: False
early_stopping_patience: 4
data_dir: .

# Use a 600M parameter model here, which is easier to train on a free Colab
# instance. Bigger models work better, however: results will improve if you
# are able to train on nllb-200-1.3B instead.
model_checkpoint: facebook/nllb-200-distilled-600M

datasets:
  train:
    huggingface_load:
      # We will load two datasets here: English/KSL Gloss, and also SALT
      # Swahili/English, so that we can try out multi-way translation.
      - path: EzekielMW/Eksl_dataset
        split: train[:-1000]
      - path: sunbird/salt
        name: text-all
        split: train
    source:
      # This is a text translation only, no audio.
      type: text
      # The source text can be any of English, KSL or Swahili.
      language: [eng,ksl,swa]
      preprocessing:
        # The models are case sensitive, so if the training text is all
        # capitals, then it will only learn to translate capital letters and
        # won't understand lower case. Make everything lower case for now.
        - lower_case
        # We can also augment the spelling of the input text, which makes the
        # model more robust to spelling errors.
        - augment_characters
    target:
      type: text
      # The target text can be any of English, KSL or Swahili.
      language: [eng,ksl,swa]
      # The models are case sensitive: make everything lower case for now.
      preprocessing:
        - lower_case
    shuffle: True
    allow_same_src_and_tgt_language: False

  validation:
    huggingface_load:
      # Use the last 1,000 of the KSL examples for validation.
      - path: EzekielMW/Eksl_dataset
        split: train[-1000:]
      # Add some Swahili validation text.
      - path: sunbird/salt
        name: text-all
        split: dev
    source:
      type: text
      language: [swa,ksl,eng]
      preprocessing:
        - lower_case
    target:
      type: text
      language: [swa,ksl,eng]
      preprocessing:
        - lower_case
    allow_same_src_and_tgt_language: False
'''

yaml_config = yaml_config.format(
    drive_folder=drive_folder,
    train_batch_size=train_batch_size,
    eval_batch_size=eval_batch_size,
    gradient_accumulation_steps=gradient_accumulation_steps,
)

config = yaml.safe_load(yaml_config)

# Training arguments parsed from the config above. `training_settings` itself
# is not referenced again in this script, but `config` is (for the model
# checkpoint name).
training_settings = transformers.Seq2SeqTrainingArguments(
    **config["training_args"])
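
# For reference, a couple of values from the parsed config defined above:
#
#   config["training_args"]["max_steps"]  -> 500
#   config["model_checkpoint"]            -> "facebook/nllb-200-distilled-600M"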

# Short language codes accepted by the API, mapped onto NLLB language tokens.
# KSL has no language code of its own in NLLB-200, so an existing code
# (ace_Latn) is repurposed for it.
LANGUAGE_CODES = ["eng", "swa", "ksl"]

code_mapping = {
    'eng': 'eng_Latn',
    'swa': 'swh_Latn',
    'ksl': 'ace_Latn',
}

tokenizer = transformers.NllbTokenizer.from_pretrained(
    config['model_checkpoint'],
    src_lang='eng_Latn',
    tgt_lang='eng_Latn')

offset = tokenizer.sp_model_size + tokenizer.fairseq_offset  # not used below

# Make the short codes resolve to the ids of the corresponding NLLB language
# tokens, so they can be passed straight to convert_tokens_to_ids().
for code in LANGUAGE_CODES:
    i = tokenizer.convert_tokens_to_ids(code_mapping[code])
    tokenizer._added_tokens_encoder[code] = i
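
# Optional sanity check (illustrative only; the exact ids depend on the
# tokenizer version): each short code should now resolve to the same id as
# its NLLB counterpart.
#
#   for code in LANGUAGE_CODES:
#       print(code,
#             tokenizer.convert_tokens_to_ids(code),
#             tokenizer.convert_tokens_to_ids(code_mapping[code]))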


def translate(text, source_language, target_language):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    inputs = tokenizer(text.lower(), return_tensors="pt").to(device)
    # Overwrite the first input token with the requested source-language token.
    inputs['input_ids'][0][0] = tokenizer.convert_tokens_to_ids(source_language)
    translated_tokens = model.to(device).generate(
        **inputs,
        forced_bos_token_id=tokenizer.convert_tokens_to_ids(target_language),
        max_length=100,
        num_beams=5,
    )
    result = tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)[0]

    # The model is trained on lower-cased text; KSL gloss output is returned
    # in upper case.
    if target_language == 'ksl':
        result = result.upper()

    return result
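
# Example of calling translate() directly (a sketch; the exact output depends
# on the fine-tuned model):
#
#   translate("good morning", "eng", "ksl")         # -> a KSL gloss string
#   translate("habari ya asubuhi", "swa", "eng")    # Swahili -> English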


@app.post("/translate")
async def translate_text(request: Request):
    data = await request.json()
    text = data.get("text")
    source_language = data.get("source_language")
    target_language = data.get("target_language")

    translation = translate(text, source_language, target_language)
    return {"translation": translation}


@app.get("/")
async def root():
    return {"message": "Welcome to the translation API!"}