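"""Gradio app for chatting with the abhinand/tamil-llama-7b-instruct-v0.2 model.

Dependencies (assumed, not pinned here): gradio, transformers, torch, and
accelerate (needed because from_pretrained is called with device_map="auto").
"""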
import os
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import logging

# Set up logging
logging.basicConfig(level=logging.INFO,
                    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
                    handlers=[logging.StreamHandler()])
logger = logging.getLogger(__name__)
# Define paths for storage - avoid persistent folder issues
MODEL_CACHE_DIR = "./model_cache"
HF_HOME_DIR = "./hf_home"
TRANSFORMERS_CACHE_DIR = "./transformers_cache"

# Set environment variables
os.environ["HF_HOME"] = HF_HOME_DIR
os.environ["TRANSFORMERS_CACHE"] = TRANSFORMERS_CACHE_DIR
# Create cache directories if they don't exist
os.makedirs(MODEL_CACHE_DIR, exist_ok=True)
os.makedirs(HF_HOME_DIR, exist_ok=True)
os.makedirs(TRANSFORMERS_CACHE_DIR, exist_ok=True)
# Initialize the model and tokenizer - only when explicitly requested
def initialize_model():
    logger.info("Loading model and tokenizer... This may take a few minutes.")
    try:
        # Load the tokenizer
        tokenizer = AutoTokenizer.from_pretrained(
            "abhinand/tamil-llama-7b-instruct-v0.2",
            cache_dir=MODEL_CACHE_DIR
        )
        # CPU-friendly configuration
        model = AutoModelForCausalLM.from_pretrained(
            "abhinand/tamil-llama-7b-instruct-v0.2",
            device_map="auto",
            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
            low_cpu_mem_usage=True,
            cache_dir=MODEL_CACHE_DIR
        )
        logger.info(f"Model device: {next(model.parameters()).device}")
        logger.info("Model and tokenizer loaded successfully!")
        return model, tokenizer
    except Exception as e:
        logger.error(f"Error loading model: {e}")
        return None, None
# Generate response
def generate_response(model, tokenizer, user_input, chat_history, temperature=0.2, top_p=1.0, top_k=40):
    # Check if model and tokenizer are loaded
    if model is None or tokenizer is None:
        return "மாதிரி ஏற்றப்படவில்லை. 'மாதிரியை ஏற்று' பொத்தானைக் கிளிக் செய்யவும்."  # Model not loaded
    try:
        logger.info(f"Generating response for input: {user_input[:50]}...")
        # Simple prompt approach to test basic generation
        prompt = f"<|im_start|>user\n{user_input}<|im_end|>\n<|im_start|>assistant\n"
        # Tokenize input
        inputs = tokenizer(prompt, return_tensors="pt")
        input_ids = inputs["input_ids"].to(model.device)
        attention_mask = inputs["attention_mask"].to(model.device)
        # Debug info
        logger.info(f"Input shape: {input_ids.shape}")
        logger.info(f"Device: {input_ids.device}")
        # Generate response with user-specified parameters
        with torch.no_grad():
            output_ids = model.generate(
                input_ids,
                attention_mask=attention_mask,
                max_new_tokens=100,  # Start with a smaller value for testing
                do_sample=True,
                temperature=temperature,
                top_p=top_p,
                top_k=top_k,
                pad_token_id=tokenizer.eos_token_id
            )
        # Get only the generated part
        new_tokens = output_ids[0, input_ids.shape[1]:]
        response = tokenizer.decode(new_tokens, skip_special_tokens=True)
        logger.info(f"Generated response (raw): {response}")
        # Clean up response if needed
        if "<|im_end|>" in response:
            response = response.split("<|im_end|>")[0].strip()
        logger.info(f"Final response: {response}")
        # Fallback if empty response
        if not response or response.isspace():
            logger.warning("Empty response generated, returning fallback message")
            return "வருந்துகிறேன், பதிலை உருவாக்குவதில் சிக்கல் உள்ளது. மீண்டும் முயற்சிக்கவும்."  # Sorry, there was a problem generating a response
        return response
    except Exception as e:
        logger.error(f"Error generating response: {e}", exc_info=True)
        return f"பிழை ஏற்பட்டது: {str(e)}"  # Error occurred
# Create the Gradio interface
def create_chatbot_interface():
    with gr.Blocks() as demo:
        title = "# தமிழ் உரையாடல் பொத்தான் (Tamil Chatbot)"
        description = "Tamil LLaMA 7B Instruct model with user-controlled generation parameters."
        gr.Markdown(title)
        gr.Markdown(description)

        # Add a direct testing area to debug the model
        with gr.Tab("Debug Mode"):
            with gr.Row():
                debug_status = gr.Markdown("⚠️ Debug Mode - Model not loaded")
                debug_load_model_btn = gr.Button("Load Model (Debug)")
            debug_model = gr.State(None)
            debug_tokenizer = gr.State(None)
            with gr.Row():
                with gr.Column(scale=3):
                    debug_input = gr.Textbox(label="Input Text", lines=3)
                    debug_submit = gr.Button("Generate Response")
                with gr.Column(scale=3):
                    debug_output = gr.Textbox(label="Raw Output", lines=8)

            def debug_load_model_fn():
                m, t = initialize_model()
                if m is not None and t is not None:
                    return "✅ Debug Model loaded", m, t
                else:
                    return "❌ Debug Model loading failed", None, None
            def debug_generate(input_text, model, tokenizer):
                if model is None:
                    return "Model not loaded yet. Please load the model first."
                try:
                    # Simple direct generation for testing
                    prompt = f"<|im_start|>user\n{input_text}<|im_end|>\n<|im_start|>assistant\n"
                    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
                    with torch.no_grad():
                        output_ids = model.generate(
                            inputs["input_ids"],
                            attention_mask=inputs["attention_mask"],
                            max_new_tokens=100,
                            temperature=0.2,
                            do_sample=True,
                            pad_token_id=tokenizer.eos_token_id
                        )
                    full_output = tokenizer.decode(output_ids[0], skip_special_tokens=False)
                    # Slice by token count rather than by len(prompt): the decoded string
                    # may contain extra special tokens (e.g. <s>), so character offsets drift.
                    new_tokens = output_ids[0, inputs["input_ids"].shape[1]:]
                    response = tokenizer.decode(new_tokens, skip_special_tokens=True)
                    # Log the full output for debugging
                    logger.info(f"Debug full output: {full_output}")
                    return f"FULL OUTPUT:\n{full_output}\n\nEXTRACTED:\n{response}"
                except Exception as e:
                    logger.error(f"Debug error: {e}", exc_info=True)
                    return f"Error: {str(e)}"
            debug_load_model_btn.click(
                debug_load_model_fn,
                outputs=[debug_status, debug_model, debug_tokenizer]
            )
            debug_submit.click(
                debug_generate,
                inputs=[debug_input, debug_model, debug_tokenizer],
                outputs=[debug_output]
            )
        # Regular chatbot interface
        with gr.Tab("Chatbot"):
            # Model loading indicator
            with gr.Row():
                model_status = gr.Markdown("⚠️ மாதிரி ஏற்றப்படவில்லை (Model not loaded)")
                load_model_btn = gr.Button("மாதிரியை ஏற்று (Load Model)")

            # Model and tokenizer states
            model = gr.State(None)
            tokenizer = gr.State(None)

            # Parameter sliders
            with gr.Accordion("Generation Parameters", open=False):
                temperature = gr.Slider(
                    label="temperature",
                    value=0.2,
                    minimum=0.0,
                    maximum=2.0,
                    step=0.05,
                    interactive=True
                )
                top_p = gr.Slider(
                    label="top_p",
                    value=1.0,
                    minimum=0.0,
                    maximum=1.0,
                    step=0.01,
                    interactive=True
                )
                top_k = gr.Slider(
                    label="top_k",
                    value=40,
                    minimum=0,
                    maximum=1000,
                    step=1,
                    interactive=True
                )

            # Function to load model on button click
            def load_model_fn():
                m, t = initialize_model()
                if m is not None and t is not None:
                    return "✅ மாதிரி வெற்றிகரமாக ஏற்றப்பட்டது (Model loaded successfully)", m, t
                else:
                    return "❌ மாதிரி ஏற்றுவதில் பிழை (Error loading model)", None, None
            # Function to respond to user messages - with error handling
            def chat_function(message, history, model_state, tokenizer_state, temp, tp, tk):
                if not message.strip():
                    return "", history
                try:
                    # Check if model is loaded
                    if model_state is None:
                        bot_message = "மாதிரி ஏற்றப்படவில்லை. முதலில் 'மாதிரியை ஏற்று' பொத்தானைக் கிளிக் செய்யவும்."
                    else:
                        # Generate bot response with parameters
                        bot_message = generate_response(
                            model_state,
                            tokenizer_state,
                            message,
                            history,
                            temperature=temp,
                            top_p=tp,
                            top_k=tk
                        )
                    # Append in the role/content format expected by gr.Chatbot(type="messages")
                    return "", history + [
                        {"role": "user", "content": message},
                        {"role": "assistant", "content": bot_message}
                    ]
                except Exception as e:
                    logger.error(f"Chat function error: {e}", exc_info=True)
                    return "", history + [
                        {"role": "user", "content": message},
                        {"role": "assistant", "content": f"Error: {str(e)}"}
                    ]
            # Create the chat interface with modern message format
            chatbot = gr.Chatbot(type="messages")
            msg = gr.TextArea(
                placeholder="உங்கள் செய்தி இங்கே தட்டச்சு செய்யவும் (Type your message here...)",
                lines=3
            )
            clear = gr.Button("அழி (Clear)")

            # Set up the chat interface
            msg.submit(
                chat_function,
                [msg, chatbot, model, tokenizer, temperature, top_p, top_k],
                [msg, chatbot]
            )
            clear.click(lambda: None, None, chatbot, queue=False)

            # Connect the model loading button
            load_model_btn.click(
                load_model_fn,
                outputs=[model_status, model, tokenizer]
            )

    return demo

# Create and launch the demo
demo = create_chatbot_interface()

if __name__ == "__main__":
    demo.queue(max_size=5).launch()