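"""llm_multilingual_demo / llm_translate_gradio.py

Gradio demo: multilingual question answering with Llama 3.2. An English
question is translated into two user-selected languages with NLLB-200,
answered by Llama-3.2-3B-Instruct in each language, and the answers are
translated back to English for side-by-side comparison.
"""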
import spaces
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, AutoModelForCausalLM
import torch
import pandas as pd
import gradio as gr
import gc
# Global variables for models
nllb_tokenizer = None
nllb_model = None
llama_tokenizer = None
llama_model = None
flores_dict = {}
def load_models():
"""Load all models once at startup"""
global nllb_tokenizer, nllb_model, llama_tokenizer, llama_model
print("Loading NLLB translation model...")
nllb_tokenizer = AutoTokenizer.from_pretrained("facebook/nllb-200-distilled-600M")
nllb_model = AutoModelForSeq2SeqLM.from_pretrained(
"facebook/nllb-200-distilled-600M",
# load_in_4bit=True,
device_map="auto"
)
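    # device_map="auto" (from accelerate) places the model on a GPU when one
    # is available; the earlier 4-bit quantized loading has been removed.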
print("Loading Llama model...")
model_id = "meta-llama/Llama-3.2-3B-Instruct"
llama_tokenizer = AutoTokenizer.from_pretrained(model_id)
if llama_tokenizer.pad_token is None:
llama_tokenizer.pad_token = llama_tokenizer.eos_token
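    # bfloat16 uses half the memory of float32 while keeping the same
    # exponent range, so the model loads without 4-bit quantization.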
llama_model = AutoModelForCausalLM.from_pretrained(
model_id,
# load_in_4bit=True,
device_map="auto",
torch_dtype=torch.bfloat16
)
print("Models loaded successfully!")
def load_language_keys():
"""Load FLORES language mappings"""
global flores_dict
try:
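        # flores_200_keys.csv: no header row, column 0 = language name,
        # column 1 = FLORES-200 code (e.g. "English,eng_Latn")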
lang_keys = pd.read_csv('flores_200_keys.csv', header=None)
        flores_dict = dict(zip(lang_keys[0], lang_keys[1]))
except FileNotFoundError:
# Fallback with common languages if CSV not found
flores_dict = {
"English": "eng_Latn",
"Spanish": "spa_Latn",
"French": "fra_Latn",
"German": "deu_Latn",
"Italian": "ita_Latn",
"Portuguese": "por_Latn",
"Russian": "rus_Cyrl",
"Chinese (Simplified)": "zho_Hans",
"Japanese": "jpn_Jpan",
"Korean": "kor_Hang",
"Arabic": "arb_Arab",
"Hindi": "hin_Deva"
}
def translate_to_lang(input_str, target_lang, src_lang="eng_Latn"):
    """
    Translate input_str from src_lang into target_lang (both FLORES-200
    codes) using the globally loaded NLLB model.
    """
    # Unknown language codes map to the UNK token, so this one check catches
    # both typos and unsupported codes
    if nllb_tokenizer.convert_tokens_to_ids(target_lang) == nllb_tokenizer.unk_token_id:
        return f"Error: {target_lang} is not a valid FLORES-200 language code!"
    # Tell the tokenizer the source language, then move the encoded inputs
    # to the same device as the model
    nllb_tokenizer.src_lang = src_lang
    device = next(nllb_model.parameters()).device
    inputs = nllb_tokenizer(input_str, return_tensors="pt", padding=True, truncation=True, max_length=512)
    inputs = {k: v.to(device) for k, v in inputs.items()}
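    # forced_bos_token_id forces the first generated token to be the
    # target-language code, which is how NLLB selects the output language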
with torch.no_grad():
translated_tokens = nllb_model.generate(
**inputs,
forced_bos_token_id=nllb_tokenizer.convert_tokens_to_ids(target_lang),
max_new_tokens=512,
do_sample=False,
num_beams=1
)
output_str = nllb_tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)[0]
return output_str
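# Hypothetical usage (output illustrative, not from a real run):
#   translate_to_lang("What is the capital of France?", "fra_Latn")
#   -> "Quelle est la capitale de la France ?"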
def llama_QA(input_question):
"""
Efficient Llama QA without pipeline overhead
"""
messages = [
{"role": "system", "content": "You are a helpful chatbot assistant. Answer all questions in the language they are asked in."},
{"role": "user", "content": input_question},
]
# Format the conversation manually for better control
formatted_prompt = llama_tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True
)
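    # apply_chat_template(tokenize=False) returns the prompt as a plain
    # string wrapped in the model's chat special tokens; tokenizing it
    # separately below keeps control over padding and truncation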
# Move inputs to the same device as model
device = next(llama_model.parameters()).device
inputs = llama_tokenizer(
formatted_prompt,
return_tensors="pt",
padding=True,
truncation=True,
max_length=2048
)
inputs = {k: v.to(device) for k, v in inputs.items()}
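    # Sampling (do_sample with temperature/top_p) trades determinism for
    # more natural, varied answers than greedy decoding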
with torch.no_grad():
outputs = llama_model.generate(
**inputs,
max_new_tokens=512,
do_sample=True,
temperature=0.7,
top_p=0.9,
pad_token_id=llama_tokenizer.eos_token_id
)
# Extract only the new tokens (response)
response_tokens = outputs[0][inputs['input_ids'].shape[1]:]
response = llama_tokenizer.decode(response_tokens, skip_special_tokens=True)
return response.strip()
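# Hypothetical behaviour: llama_QA("¿Cuál es la capital de Francia?") should
# answer in Spanish, per the system prompt above.

# @spaces.GPU requests a ZeroGPU allocation on Hugging Face Spaces for the
# duration of each call; on hardware with a persistent GPU it has no effect.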
@spaces.GPU
def process_multilang_qa(input_question, left_lang, right_lang):
"""
Single GPU-decorated function that handles the entire pipeline
"""
try:
# Get FLORES codes
left_flores = flores_dict.get(left_lang, left_lang)
right_flores = flores_dict.get(right_lang, right_lang)
# Process left language
if left_flores == 'eng_Latn':
left_translated_q = input_question
else:
left_translated_q = translate_to_lang(input_question, left_flores)
left_response = llama_QA(left_translated_q)
if left_flores == 'eng_Latn':
left_final = left_response
        else:
            # Back-translate to English, telling NLLB the answer's language
            left_final = translate_to_lang(left_response, 'eng_Latn', src_lang=left_flores)
        # Release Python-level references and cached GPU memory between the
        # two passes
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
# Process right language
if right_flores == 'eng_Latn':
right_translated_q = input_question
else:
right_translated_q = translate_to_lang(input_question, right_flores)
right_response = llama_QA(right_translated_q)
if right_flores == 'eng_Latn':
right_final = right_response
        else:
            right_final = translate_to_lang(right_response, 'eng_Latn', src_lang=right_flores)
return left_final, right_final
except Exception as e:
error_msg = f"Error processing request: {str(e)}"
return error_msg, error_msg
def create_interface():
"""Create Gradio interface"""
language_choices = list(flores_dict.keys())
with gr.Blocks(title="Multi-language QA with Llama") as demo:
with gr.Row():
question_input = gr.Textbox(
label="Enter your question (in English)",
placeholder="What is the capital of France?",
lines=2
)
with gr.Row():
left_lang = gr.Dropdown(
choices=language_choices,
label="Language #1",
value=language_choices[0] if language_choices else None
)
right_lang = gr.Dropdown(
choices=language_choices,
label="Language #2",
value=language_choices[1] if len(language_choices) > 1 else language_choices[0]
)
with gr.Row():
submit_btn = gr.Button("Ask Llama!", variant="primary")
clear_btn = gr.Button("Clear", variant="secondary")
with gr.Row():
left_output = gr.Textbox(
label="Response via Language #1",
interactive=False,
lines=4
)
right_output = gr.Textbox(
label="Response via Language #2",
interactive=False,
lines=4
)
# Event handlers
submit_btn.click(
fn=process_multilang_qa,
inputs=[question_input, left_lang, right_lang],
outputs=[left_output, right_output]
)
clear_btn.click(
fn=lambda: ("", "", ""),
outputs=[question_input, left_output, right_output]
)
# # Examples
# gr.Examples(
# examples=[
# ["What is the meaning of life?", "Spanish", "French"],
# ["How do you cook pasta?", "Italian", "Japanese"],
# ["What is artificial intelligence?", "German", "Chinese (Simplified)"]
# ],
# inputs=[question_input, left_lang, right_lang]
# )
return demo
# Initialize everything
if __name__ == "__main__":
print("Initializing models and language mappings...")
load_language_keys()
load_models()
# Launch the app
demo = create_interface()
    demo.launch()