import os
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import logging
# Set up logging
logging.basicConfig(level=logging.INFO,
                    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
                    handlers=[logging.StreamHandler()])
logger = logging.getLogger(__name__)
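# StreamHandler writes to stderr by default, so these logs show up in the
# host's container log view (e.g. the "Logs" tab on Hugging Face Spaces).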
# Define paths for storage - avoid persistent folder issues
MODEL_CACHE_DIR = "./model_cache"
HF_HOME_DIR = "./hf_home"
TRANSFORMERS_CACHE_DIR = "./transformers_cache"
# Set environment variables
os.environ["HF_HOME"] = HF_HOME_DIR
os.environ["TRANSFORMERS_CACHE"] = TRANSFORMERS_CACHE_DIR
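# Note: recent transformers releases read HF_HOME and warn that
# TRANSFORMERS_CACHE is deprecated; setting both keeps older and newer
# versions pointed at the same local directories.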
# Create cache directories if they don't exist
os.makedirs(MODEL_CACHE_DIR, exist_ok=True)
os.makedirs(HF_HOME_DIR, exist_ok=True)
os.makedirs(TRANSFORMERS_CACHE_DIR, exist_ok=True)
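# Relative paths keep all downloads inside the app's working directory rather
# than the default ~/.cache, which may not be writable in some hosted setups.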
# Initialize the model and tokenizer - only when explicitly requested
def initialize_model():
    logger.info("Loading model and tokenizer... This may take a few minutes.")
    try:
        # Load the tokenizer
        tokenizer = AutoTokenizer.from_pretrained(
            "abhinand/tamil-llama-7b-instruct-v0.2",
            cache_dir=MODEL_CACHE_DIR
        )
        # CPU-friendly configuration
        model = AutoModelForCausalLM.from_pretrained(
            "abhinand/tamil-llama-7b-instruct-v0.2",
            device_map="auto",
            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
            low_cpu_mem_usage=True,
            cache_dir=MODEL_CACHE_DIR
        )
        logger.info(f"Model device: {next(model.parameters()).device}")
        logger.info("Model and tokenizer loaded successfully!")
        return model, tokenizer
    except Exception as e:
        logger.error(f"Error loading model: {e}")
        return None, None
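# Note: device_map="auto" requires the accelerate package; if it is not
# installed, from_pretrained raises and the handler above returns (None, None).
# The dtype switch uses half precision on GPU and full precision on CPU.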
# Generate response
def generate_response(model, tokenizer, user_input, chat_history, temperature=0.2, top_p=1.0, top_k=40):
    # Check if model and tokenizer are loaded
    if model is None or tokenizer is None:
        return "மாதிரி ஏற்றப்படவில்லை. 'மாதிரியை ஏற்று' பொத்தானைக் கிளிக் செய்யவும்."  # Model not loaded
    try:
        logger.info(f"Generating response for input: {user_input[:50]}...")
        # Simple prompt approach to test basic generation
        prompt = f"<|im_start|>user\n{user_input}<|im_end|>\n<|im_start|>assistant\n"
        # Tokenize input
        inputs = tokenizer(prompt, return_tensors="pt")
        input_ids = inputs["input_ids"].to(model.device)
        attention_mask = inputs["attention_mask"].to(model.device)
        # Debug info
        logger.info(f"Input shape: {input_ids.shape}")
        logger.info(f"Device: {input_ids.device}")
        # Generate response with user-specified parameters
        with torch.no_grad():
            output_ids = model.generate(
                input_ids,
                attention_mask=attention_mask,
                max_new_tokens=100,  # Start with a smaller value for testing
                do_sample=True,
                temperature=temperature,
                top_p=top_p,
                top_k=top_k,
                pad_token_id=tokenizer.eos_token_id
            )
        # Get only the generated part
        new_tokens = output_ids[0, input_ids.shape[1]:]
        response = tokenizer.decode(new_tokens, skip_special_tokens=True)
        logger.info(f"Generated response (raw): {response}")
        # Clean up response if needed
        if "<|im_end|>" in response:
            response = response.split("<|im_end|>")[0].strip()
        logger.info(f"Final response: {response}")
        # Fallback if empty response
        if not response or response.isspace():
            logger.warning("Empty response generated, returning fallback message")
            return "வருந்துகிறேன், பதிலை உருவாக்குவதில் சிக்கல் உள்ளது. மீண்டும் முயற்சிக்கவும்."  # Sorry, there was a problem generating a response
        return response
    except Exception as e:
        logger.error(f"Error generating response: {e}", exc_info=True)
        return f"பிழை ஏற்பட்டது: {str(e)}"  # An error occurred
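# Note: chat_history is accepted above but never folded into the prompt, so
# each reply is generated from the latest user message alone (single-turn).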
# Create the Gradio interface
def create_chatbot_interface():
    with gr.Blocks() as demo:
        title = "# தமிழ் உரையாடல் பொத்தான் (Tamil Chatbot)"
        description = "Tamil LLaMA 7B Instruct model with user-controlled generation parameters."
        gr.Markdown(title)
        gr.Markdown(description)

        # Add a direct testing area to debug the model
        with gr.Tab("Debug Mode"):
            with gr.Row():
                debug_status = gr.Markdown("⚠️ Debug Mode - Model not loaded")
                debug_load_model_btn = gr.Button("Load Model (Debug)")
            debug_model = gr.State(None)
            debug_tokenizer = gr.State(None)
            with gr.Row():
                with gr.Column(scale=3):
                    debug_input = gr.Textbox(label="Input Text", lines=3)
                    debug_submit = gr.Button("Generate Response")
                with gr.Column(scale=3):
                    debug_output = gr.Textbox(label="Raw Output", lines=8)

            def debug_load_model_fn():
                m, t = initialize_model()
                if m is not None and t is not None:
                    return "✅ Debug Model loaded", m, t
                else:
                    return "❌ Debug Model loading failed", None, None

            def debug_generate(input_text, model, tokenizer):
                if model is None or tokenizer is None:
                    return "Model not loaded yet. Please load the model first."
                try:
                    # Simple direct generation for testing
                    prompt = f"<|im_start|>user\n{input_text}<|im_end|>\n<|im_start|>assistant\n"
                    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
                    with torch.no_grad():
                        output_ids = model.generate(
                            inputs["input_ids"],
                            attention_mask=inputs["attention_mask"],
                            max_new_tokens=100,
                            temperature=0.2,
                            do_sample=True,
                            pad_token_id=tokenizer.eos_token_id
                        )
                    full_output = tokenizer.decode(output_ids[0], skip_special_tokens=False)
                    # Approximate slice: decode() may prepend special tokens (e.g. BOS),
                    # so the extracted tail can be offset by a few characters.
                    response = full_output[len(prompt):]
                    # Log the full output for debugging
                    logger.info(f"Debug full output: {full_output}")
                    return f"FULL OUTPUT:\n{full_output}\n\nEXTRACTED:\n{response}"
                except Exception as e:
                    logger.error(f"Debug error: {e}", exc_info=True)
                    return f"Error: {str(e)}"

            debug_load_model_btn.click(
                debug_load_model_fn,
                outputs=[debug_status, debug_model, debug_tokenizer]
            )
            debug_submit.click(
                debug_generate,
                inputs=[debug_input, debug_model, debug_tokenizer],
                outputs=[debug_output]
            )
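        # gr.State holds the loaded model/tokenizer per browser session, so
        # each session that clicks "Load Model" gets its own in-memory copy
        # (disk-cached weights are shared, but RAM usage scales per session).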
        # Regular chatbot interface
        with gr.Tab("Chatbot"):
            # Model loading indicator
            with gr.Row():
                model_status = gr.Markdown("⚠️ மாதிரி ஏற்றப்படவில்லை (Model not loaded)")
                load_model_btn = gr.Button("மாதிரியை ஏற்று (Load Model)")

            # Model and tokenizer states
            model = gr.State(None)
            tokenizer = gr.State(None)

            # Parameter sliders
            with gr.Accordion("Generation Parameters", open=False):
                temperature = gr.Slider(
                    label="temperature",
                    value=0.2,
                    minimum=0.0,
                    maximum=2.0,
                    step=0.05,
                    interactive=True
                )
                top_p = gr.Slider(
                    label="top_p",
                    value=1.0,
                    minimum=0.0,
                    maximum=1.0,
                    step=0.01,
                    interactive=True
                )
                top_k = gr.Slider(
                    label="top_k",
                    value=40,
                    minimum=0,
                    maximum=1000,
                    step=1,
                    interactive=True
                )
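            # A note on the defaults: temperature 0.2 stays close to greedy
            # decoding (with do_sample=True, temperature must be > 0),
            # top_p=1.0 effectively disables nucleus filtering, and top_k=40
            # caps each step to the 40 most likely candidate tokens.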
            # Function to load model on button click
            def load_model_fn():
                m, t = initialize_model()
                if m is not None and t is not None:
                    return "✅ மாதிரி வெற்றிகரமாக ஏற்றப்பட்டது (Model loaded successfully)", m, t
                else:
                    return "❌ மாதிரி ஏற்றுவதில் பிழை (Error loading model)", None, None
            # Function to respond to user messages - with error handling
            def chat_function(message, history, model_state, tokenizer_state, temp, tp, tk):
                if not message.strip():
                    return "", history
                try:
                    # Check if model is loaded
                    if model_state is None:
                        bot_message = "மாதிரி ஏற்றப்படவில்லை. முதலில் 'மாதிரியை ஏற்று' பொத்தானைக் கிளிக் செய்யவும்."  # Model not loaded; click 'Load Model' first
                    else:
                        # Generate bot response with parameters
                        bot_message = generate_response(
                            model_state,
                            tokenizer_state,
                            message,
                            history,
                            temperature=temp,
                            top_p=tp,
                            top_k=tk
                        )
                    # gr.Chatbot(type="messages") expects role/content dicts,
                    # not [user, bot] pairs
                    return "", history + [
                        {"role": "user", "content": message},
                        {"role": "assistant", "content": bot_message},
                    ]
                except Exception as e:
                    logger.error(f"Chat function error: {e}", exc_info=True)
                    return "", history + [
                        {"role": "user", "content": message},
                        {"role": "assistant", "content": f"Error: {str(e)}"},
                    ]
            # Create the chat interface with modern message format
            chatbot = gr.Chatbot(type="messages")
            msg = gr.TextArea(
                placeholder="உங்கள் செய்தி இங்கே தட்டச்சு செய்யவும் (Type your message here...)",
                lines=3
            )
            clear = gr.Button("அழி (Clear)")

            # Set up the chat interface
            msg.submit(
                chat_function,
                [msg, chatbot, model, tokenizer, temperature, top_p, top_k],
                [msg, chatbot]
            )
            clear.click(lambda: None, None, chatbot, queue=False)
            # Connect the model loading button
            load_model_btn.click(
                load_model_fn,
                outputs=[model_status, model, tokenizer]
            )

    return demo
# Create the demo
demo = create_chatbot_interface()

# Launch the demo
if __name__ == "__main__":
    demo.queue(max_size=5).launch()
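# queue(max_size=5) caps pending requests so a single worker is not swamped;
# requests beyond the cap are rejected until the queue drains.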