test / test.py
yapzanan's picture
Upload 4 files
0ee3b61
from flask import Flask, request, jsonify
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
from flores200_codes import flores_codes
from flask_cors import CORS
app = Flask(__name__)
model_dict = {}
CORS(app, origins=["http://localhost:3000"])
def load_models():
# build model and tokenizer
model_name_dict = {
'nllb-distilled-600M': 'facebook/nllb-200-distilled-600M',
# 'nllb-1.3B': 'facebook/nllb-200-1.3B',
# 'nllb-distilled-1.3B': 'facebook/nllb-200-distilled-1.3B',
# 'nllb-3.3B': 'facebook/nllb-200-3.3B',
}
for call_name, real_name in model_name_dict.items():
print('\tLoading model: %s' % call_name)
model = AutoModelForSeq2SeqLM.from_pretrained(real_name)
tokenizer = AutoTokenizer.from_pretrained(real_name)
model_dict[call_name+'_model'] = model
model_dict[call_name+'_tokenizer'] = tokenizer
def translation(source, target, text):
source = flores_codes[source]
target = flores_codes[target]
model_name = 'nllb-distilled-600M'
model = model_dict[model_name + '_model']
tokenizer = model_dict[model_name + '_tokenizer']
translator = pipeline('translation', model=model, tokenizer=tokenizer, src_lang=source, tgt_lang=target)
output = translator(text, max_length=400)
full_output = output
output = output[0]['translation_text']
result = {
'source': source,
'target': target,
'result': output,
'full_output': full_output
}
return result
@app.route("/translate", methods=["POST"])
def translate():
data = request.get_json()
source = data.get("source")
target = data.get("target")
text = data.get("text")
print(source, target, text)
result = translation(source, target, text)
return jsonify(result)
@app.route("/languages", methods=["GET"])
def getlanguages():
return jsonify(list(flores_codes.keys()))
if __name__ == '__main__':
print('\tinit models')
load_models()
app.run()