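"""Multi-language QA demo for Hugging Face Spaces.

Takes an English question, translates it into two user-chosen pivot languages
with NLLB-200, answers each translated question with Llama-3.2-3B-Instruct,
translates the answers back to English, and shows them side by side in Gradio.
"""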
import spaces
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, AutoModelForCausalLM
import torch
import pandas as pd
import gradio as gr

# Global variables for models and language mappings
nllb_tokenizer = None
nllb_model = None
llama_tokenizer = None
llama_model = None
flores_dict = {}
def load_models():
    """Load all models once at startup."""
    global nllb_tokenizer, nllb_model, llama_tokenizer, llama_model

    print("Loading NLLB translation model...")
    nllb_tokenizer = AutoTokenizer.from_pretrained("facebook/nllb-200-distilled-600M")
    nllb_model = AutoModelForSeq2SeqLM.from_pretrained(
        "facebook/nllb-200-distilled-600M",
        # load_in_4bit=True,
        device_map="auto"
    )

    print("Loading Llama model...")
    model_id = "meta-llama/Llama-3.2-3B-Instruct"
    llama_tokenizer = AutoTokenizer.from_pretrained(model_id)
    if llama_tokenizer.pad_token is None:
        llama_tokenizer.pad_token = llama_tokenizer.eos_token
    llama_model = AutoModelForCausalLM.from_pretrained(
        model_id,
        # load_in_4bit=True,
        device_map="auto",
        torch_dtype=torch.bfloat16
    )
    print("Models loaded successfully!")

def load_language_keys():
    """Load FLORES-200 language-name -> code mappings from CSV."""
    global flores_dict
    try:
        lang_keys = pd.read_csv('flores_200_keys.csv', header=None)
        # Column 0: human-readable language name, column 1: FLORES-200 code
        flores_dict = dict(zip(lang_keys.iloc[:, 0], lang_keys.iloc[:, 1]))
    except FileNotFoundError:
        # Fall back to a small set of common languages if the CSV is missing
        flores_dict = {
            "English": "eng_Latn",
            "Spanish": "spa_Latn",
            "French": "fra_Latn",
            "German": "deu_Latn",
            "Italian": "ita_Latn",
            "Portuguese": "por_Latn",
            "Russian": "rus_Cyrl",
            "Chinese (Simplified)": "zho_Hans",
            "Japanese": "jpn_Jpan",
            "Korean": "kor_Hang",
            "Arabic": "arb_Arab",
            "Hindi": "hin_Deva"
        }

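# Expected flores_200_keys.csv layout (an assumption, inferred from the parsing
# above): two unlabeled columns mapping language name to FLORES-200 code, e.g.
#   English,eng_Latn
#   Spanish,spa_Latn
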
def translate_to_lang(input_str, target_lang):
    """
    Translate input_str into the FLORES-200 language target_lang with NLLB.
    No GPU decorator here; this runs inside the @spaces.GPU-decorated pipeline.
    """
    # NLLB registers the FLORES-200 codes as additional special tokens
    if target_lang not in nllb_tokenizer.additional_special_tokens:
        return f"Error: {target_lang} is not a valid FLORES 200 language!"

    # Move inputs to the same device as the model
    device = next(nllb_model.parameters()).device
    inputs = nllb_tokenizer(input_str, return_tensors="pt", padding=True, truncation=True, max_length=512)
    inputs = {k: v.to(device) for k, v in inputs.items()}

    with torch.no_grad():
        translated_tokens = nllb_model.generate(
            **inputs,
            # Force decoding to begin in the target language
            forced_bos_token_id=nllb_tokenizer.convert_tokens_to_ids(target_lang),
            max_new_tokens=512,
            do_sample=False,
            num_beams=1
        )
    output_str = nllb_tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)[0]
    return output_str

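# Illustrative usage (output not guaranteed verbatim):
#   translate_to_lang("What is the capital of France?", "fra_Latn")
#   -> "Quelle est la capitale de la France ?"
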
def llama_QA(input_question):
    """
    Answer a question with Llama directly, without pipeline overhead.
    """
    messages = [
        {"role": "system", "content": "You are a helpful chatbot assistant. Answer all questions in the language they are asked in."},
        {"role": "user", "content": input_question},
    ]

    # Format the conversation manually for better control
    formatted_prompt = llama_tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True
    )

    # Move inputs to the same device as the model. The chat template already
    # includes special tokens, so don't let the tokenizer prepend a second BOS.
    device = next(llama_model.parameters()).device
    inputs = llama_tokenizer(
        formatted_prompt,
        return_tensors="pt",
        add_special_tokens=False,
        padding=True,
        truncation=True,
        max_length=2048
    )
    inputs = {k: v.to(device) for k, v in inputs.items()}

    with torch.no_grad():
        outputs = llama_model.generate(
            **inputs,
            max_new_tokens=512,
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
            pad_token_id=llama_tokenizer.eos_token_id
        )

    # Keep only the newly generated tokens (the response)
    response_tokens = outputs[0][inputs['input_ids'].shape[1]:]
    response = llama_tokenizer.decode(response_tokens, skip_special_tokens=True)
    return response.strip()

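# Illustrative usage: the system prompt asks Llama to answer in the question's
# language, so llama_QA("¿Cuál es la capital de Francia?") should reply in Spanish.
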
@spaces.GPU
def process_multilang_qa(input_question, left_lang, right_lang):
    """
    Single GPU-decorated function that handles the entire pipeline:
    English question -> pivot language -> Llama answer -> back to English.
    """
    def answer_via(flores_code):
        # Skip translation when the pivot language is already English
        if flores_code == 'eng_Latn':
            translated_q = input_question
        else:
            translated_q = translate_to_lang(input_question, flores_code)
        response = llama_QA(translated_q)
        if flores_code == 'eng_Latn':
            return response
        return translate_to_lang(response, 'eng_Latn')

    try:
        # Get FLORES codes (fall back to the raw value if it's already a code)
        left_flores = flores_dict.get(left_lang, left_lang)
        right_flores = flores_dict.get(right_lang, right_lang)

        left_final = answer_via(left_flores)

        # Clear some memory between the two pivots
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        right_final = answer_via(right_flores)
        return left_final, right_final
    except Exception as e:
        error_msg = f"Error processing request: {str(e)}"
        return error_msg, error_msg

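# Illustrative round trip: process_multilang_qa("What is the capital of France?",
# "Spanish", "French") pivots the question through Spanish and through French and
# returns both answers translated back to English.
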
def create_interface():
    """Create the Gradio interface."""
    language_choices = list(flores_dict.keys())

    with gr.Blocks(title="Multi-language QA with Llama") as demo:
        with gr.Row():
            question_input = gr.Textbox(
                label="Enter your question (in English)",
                placeholder="What is the capital of France?",
                lines=2
            )
        with gr.Row():
            left_lang = gr.Dropdown(
                choices=language_choices,
                label="Language #1",
                value=language_choices[0] if language_choices else None
            )
            right_lang = gr.Dropdown(
                choices=language_choices,
                label="Language #2",
                value=language_choices[1] if len(language_choices) > 1 else language_choices[0]
            )
        with gr.Row():
            submit_btn = gr.Button("Ask Llama!", variant="primary")
            clear_btn = gr.Button("Clear", variant="secondary")
        with gr.Row():
            left_output = gr.Textbox(
                label="Response via Language #1",
                interactive=False,
                lines=4
            )
            right_output = gr.Textbox(
                label="Response via Language #2",
                interactive=False,
                lines=4
            )

        # Event handlers
        submit_btn.click(
            fn=process_multilang_qa,
            inputs=[question_input, left_lang, right_lang],
            outputs=[left_output, right_output]
        )
        clear_btn.click(
            fn=lambda: ("", "", ""),
            outputs=[question_input, left_output, right_output]
        )

        # # Examples
        # gr.Examples(
        #     examples=[
        #         ["What is the meaning of life?", "Spanish", "French"],
        #         ["How do you cook pasta?", "Italian", "Japanese"],
        #         ["What is artificial intelligence?", "German", "Chinese (Simplified)"]
        #     ],
        #     inputs=[question_input, left_lang, right_lang]
        # )

    return demo

# Initialize everything
if __name__ == "__main__":
    print("Initializing models and language mappings...")
    load_language_keys()
    load_models()

    # Launch the app
    demo = create_interface()
    demo.launch()
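# Assumed Space dependencies (listed in requirements.txt, not pinned here):
# gradio, spaces, transformers, torch, pandas, plus accelerate for
# device_map="auto" and sentencepiece for the NLLB tokenizer.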