import spaces
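# `spaces` provides the @spaces.GPU decorator used on Hugging Face ZeroGPU
# Spaces ("Running on Zero") to attach a GPU only while a decorated
# function is executing.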
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, AutoModelForCausalLM
import torch
import pandas as pd
import gradio as gr
import gc
# Global variables for models
nllb_tokenizer = None
nllb_model = None
llama_tokenizer = None
llama_model = None
flores_dict = {}
def load_models():
    """Load all models once at startup"""
    global nllb_tokenizer, nllb_model, llama_tokenizer, llama_model

    print("Loading NLLB translation model...")
    nllb_tokenizer = AutoTokenizer.from_pretrained("facebook/nllb-200-distilled-600M")
    nllb_model = AutoModelForSeq2SeqLM.from_pretrained(
        "facebook/nllb-200-distilled-600M",
        # load_in_4bit=True,
        device_map="auto"
    )

    print("Loading Llama model...")
    model_id = "meta-llama/Llama-3.2-3B-Instruct"
    llama_tokenizer = AutoTokenizer.from_pretrained(model_id)
    if llama_tokenizer.pad_token is None:
        llama_tokenizer.pad_token = llama_tokenizer.eos_token

    llama_model = AutoModelForCausalLM.from_pretrained(
        model_id,
        # load_in_4bit=True,
        device_map="auto",
        torch_dtype=torch.bfloat16
    )
    print("Models loaded successfully!")
def load_language_keys():
    """Load FLORES language mappings"""
    global flores_dict
    try:
        lang_keys = pd.read_csv('flores_200_keys.csv', header=None)
        flores_dict = dict(zip(lang_keys[0], lang_keys[1]))
    except FileNotFoundError:
        # Fallback with common languages if CSV not found
        flores_dict = {
            "English": "eng_Latn",
            "Spanish": "spa_Latn",
            "French": "fra_Latn",
            "German": "deu_Latn",
            "Italian": "ita_Latn",
            "Portuguese": "por_Latn",
            "Russian": "rus_Cyrl",
            "Chinese (Simplified)": "zho_Hans",
            "Japanese": "jpn_Jpan",
            "Korean": "kor_Hang",
            "Arabic": "arb_Arab",
            "Hindi": "hin_Deva"
        }
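# Expected flores_200_keys.csv layout, inferred from the loader above
# (two columns, no header row: display name, FLORES-200 code), e.g.:
#
#   English,eng_Latn
#   Spanish,spa_Latn
#   ...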
def translate_to_lang(input_str, target_lang, src_lang="eng_Latn"):
    """
    Efficient translation function without its own GPU decorator.
    src_lang is the FLORES-200 code of input_str's language (defaults to
    English); without it, NLLB would treat back-translations as English input.
    """
    if target_lang not in nllb_tokenizer.additional_special_tokens:
        return f"Error: {target_lang} is not a valid FLORES 200 language!"

    # Tell the tokenizer the source language, then move the encoded inputs
    # to the same device as the model
    nllb_tokenizer.src_lang = src_lang
    device = next(nllb_model.parameters()).device
    inputs = nllb_tokenizer(input_str, return_tensors="pt", padding=True, truncation=True, max_length=512)
    inputs = {k: v.to(device) for k, v in inputs.items()}

    with torch.no_grad():
        translated_tokens = nllb_model.generate(
            **inputs,
            forced_bos_token_id=nllb_tokenizer.convert_tokens_to_ids(target_lang),
            max_new_tokens=512,
            do_sample=False,
            num_beams=1
        )

    output_str = nllb_tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)[0]
    return output_str
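# Usage sketch (assuming load_models() has already run):
#   translate_to_lang("What is the capital of France?", "spa_Latn")
# returns the Spanish translation as a plain string.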
def llama_QA(input_question):
    """
    Efficient Llama QA without pipeline overhead
    """
    messages = [
        {"role": "system", "content": "You are a helpful chatbot assistant. Answer all questions in the language they are asked in."},
        {"role": "user", "content": input_question},
    ]

    # Format the conversation manually for better control
    formatted_prompt = llama_tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True
    )

    # Move inputs to the same device as the model; skip special tokens here
    # because the chat template already includes <|begin_of_text|>
    device = next(llama_model.parameters()).device
    inputs = llama_tokenizer(
        formatted_prompt,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=2048,
        add_special_tokens=False
    )
    inputs = {k: v.to(device) for k, v in inputs.items()}

    with torch.no_grad():
        outputs = llama_model.generate(
            **inputs,
            max_new_tokens=512,
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
            pad_token_id=llama_tokenizer.eos_token_id
        )

    # Extract only the new tokens (the model's response)
    response_tokens = outputs[0][inputs['input_ids'].shape[1]:]
    response = llama_tokenizer.decode(response_tokens, skip_special_tokens=True)
    return response.strip()
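# Usage sketch (assuming load_models() has already run):
#   llama_QA("¿Cuál es la capital de Francia?")
# The system prompt asks the model to answer in the language of the
# question, so the reply here should come back in Spanish.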
@spaces.GPU
def process_multilang_qa(input_question, left_lang, right_lang):
    """
    Single GPU-decorated function that handles the entire pipeline
    """
    try:
        # Get FLORES codes
        left_flores = flores_dict.get(left_lang, left_lang)
        right_flores = flores_dict.get(right_lang, right_lang)

        # Process left language: translate the question, answer it,
        # then translate the answer back to English
        if left_flores == 'eng_Latn':
            left_translated_q = input_question
        else:
            left_translated_q = translate_to_lang(input_question, left_flores)

        left_response = llama_QA(left_translated_q)

        if left_flores == 'eng_Latn':
            left_final = left_response
        else:
            left_final = translate_to_lang(left_response, 'eng_Latn', src_lang=left_flores)

        # Clear some memory between operations
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        # Process right language
        if right_flores == 'eng_Latn':
            right_translated_q = input_question
        else:
            right_translated_q = translate_to_lang(input_question, right_flores)

        right_response = llama_QA(right_translated_q)

        if right_flores == 'eng_Latn':
            right_final = right_response
        else:
            right_final = translate_to_lang(right_response, 'eng_Latn', src_lang=right_flores)

        return left_final, right_final

    except Exception as e:
        error_msg = f"Error processing request: {str(e)}"
        return error_msg, error_msg
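# Usage sketch:
#   left, right = process_multilang_qa("What is the capital of France?",
#                                      "Spanish", "French")
# Both answers come back translated to English, so the two routes can be
# compared side by side.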
def create_interface():
    """Create Gradio interface"""
    language_choices = list(flores_dict.keys())

    with gr.Blocks(title="Multi-language QA with Llama") as demo:
        with gr.Row():
            question_input = gr.Textbox(
                label="Enter your question (in English)",
                placeholder="What is the capital of France?",
                lines=2
            )

        with gr.Row():
            left_lang = gr.Dropdown(
                choices=language_choices,
                label="Language #1",
                value=language_choices[0] if language_choices else None
            )
            right_lang = gr.Dropdown(
                choices=language_choices,
                label="Language #2",
                value=language_choices[1] if len(language_choices) > 1 else language_choices[0]
            )

        with gr.Row():
            submit_btn = gr.Button("Ask Llama!", variant="primary")
            clear_btn = gr.Button("Clear", variant="secondary")

        with gr.Row():
            left_output = gr.Textbox(
                label="Response via Language #1",
                interactive=False,
                lines=4
            )
            right_output = gr.Textbox(
                label="Response via Language #2",
                interactive=False,
                lines=4
            )

        # Event handlers
        submit_btn.click(
            fn=process_multilang_qa,
            inputs=[question_input, left_lang, right_lang],
            outputs=[left_output, right_output]
        )

        clear_btn.click(
            fn=lambda: ("", "", ""),
            outputs=[question_input, left_output, right_output]
        )

        # # Examples
        # gr.Examples(
        #     examples=[
        #         ["What is the meaning of life?", "Spanish", "French"],
        #         ["How do you cook pasta?", "Italian", "Japanese"],
        #         ["What is artificial intelligence?", "German", "Chinese (Simplified)"]
        #     ],
        #     inputs=[question_input, left_lang, right_lang]
        # )

    return demo
# Initialize everything
if __name__ == "__main__":
    print("Initializing models and language mappings...")
    load_language_keys()
    load_models()

    # Launch the app
    demo = create_interface()
    demo.launch()
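# Note (an assumption, not in the original code): under concurrent traffic,
# Gradio's request queue is often enabled before launching, e.g.:
#   demo.queue().launch()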