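"""Gradio chat app for the abhinand/tamil-llama-7b-instruct-v0.2 model.

Provides a Tamil-language chatbot tab with adjustable sampling parameters
and a separate debug tab for inspecting raw model output. The model is
loaded on demand via a button rather than at import time.
"""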
import os
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import logging

# Set up logging
logging.basicConfig(level=logging.INFO, 
                    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
                    handlers=[logging.StreamHandler()])
logger = logging.getLogger(__name__)

# Define local, writable cache paths to avoid permission issues with persistent storage
MODEL_CACHE_DIR = "./model_cache"
HF_HOME_DIR = "./hf_home"
TRANSFORMERS_CACHE_DIR = "./transformers_cache"

# Set environment variables
os.environ["HF_HOME"] = HF_HOME_DIR
os.environ["TRANSFORMERS_CACHE"] = TRANSFORMERS_CACHE_DIR

# Create cache directories if they don't exist
os.makedirs(MODEL_CACHE_DIR, exist_ok=True)
os.makedirs(HF_HOME_DIR, exist_ok=True)
os.makedirs(TRANSFORMERS_CACHE_DIR, exist_ok=True)

# Initialize the model and tokenizer - only when explicitly requested
def initialize_model():
    logger.info("Loading model and tokenizer... This may take a few minutes.")
    
    try:
        # Load the tokenizer
        tokenizer = AutoTokenizer.from_pretrained(
            "abhinand/tamil-llama-7b-instruct-v0.2",
            cache_dir=MODEL_CACHE_DIR
        )
        
        # Load the model: float16 on GPU, float32 on CPU (device_map="auto" requires accelerate)
        model = AutoModelForCausalLM.from_pretrained(
            "abhinand/tamil-llama-7b-instruct-v0.2",
            device_map="auto",
            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
            low_cpu_mem_usage=True,
            cache_dir=MODEL_CACHE_DIR
        )
        
        logger.info(f"Model device: {next(model.parameters()).device}")
        logger.info("Model and tokenizer loaded successfully!")
        return model, tokenizer
    except Exception as e:
        logger.error(f"Error loading model: {e}")
        return None, None

# Generate response
def generate_response(model, tokenizer, user_input, chat_history, temperature=0.2, top_p=1.0, top_k=40):
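    # Note: chat_history is accepted but not yet folded into the prompt;
    # each turn is generated from the latest user message alone.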
    # Check if model and tokenizer are loaded
    if model is None or tokenizer is None:
        return "மாதிரி ஏற்றப்படவில்லை. 'மாதிரியை ஏற்று' பொத்தானைக் கிளிக் செய்யவும்."  # Model not loaded
    
    try:
        logger.info(f"Generating response for input: {user_input[:50]}...")
        
        # Simple prompt approach to test basic generation
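        # ChatML-style template; this assumes the v0.2 instruct model expects
        # <|im_start|>/<|im_end|> markers (check the model card if output
        # looks malformed)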
        prompt = f"<|im_start|>user\n{user_input}<|im_end|>\n<|im_start|>assistant\n"
        
        # Tokenize input
        inputs = tokenizer(prompt, return_tensors="pt")
        input_ids = inputs["input_ids"].to(model.device)
        attention_mask = inputs["attention_mask"].to(model.device)
        
        # Debug info
        logger.info(f"Input shape: {input_ids.shape}")
        logger.info(f"Device: {input_ids.device}")
        
        # Generate response with user-specified parameters
        with torch.no_grad():
            output_ids = model.generate(
                input_ids,
                attention_mask=attention_mask,
                max_new_tokens=100,  # Start with a smaller value for testing
                do_sample=True,
                temperature=temperature,
                top_p=top_p,
                top_k=top_k,
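                # LLaMA tokenizers ship without a pad token, so reuse EOS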
                pad_token_id=tokenizer.eos_token_id
            )
        
        # Get only the generated part
        new_tokens = output_ids[0, input_ids.shape[1]:]
        response = tokenizer.decode(new_tokens, skip_special_tokens=True)
        
        logger.info(f"Generated response (raw): {response}")
        
        # Clean up response if needed
        if "<|im_end|>" in response:
            response = response.split("<|im_end|>")[0].strip()
        
        logger.info(f"Final response: {response}")
        
        # Fallback if empty response
        if not response or response.isspace():
            logger.warning("Empty response generated, returning fallback message")
            return "வருந்துகிறேன், பதிலை உருவாக்குவதில் சிக்கல் உள்ளது. மீண்டும் முயற்சிக்கவும்."  # Sorry, there was a problem generating a response
        
        return response
        
    except Exception as e:
        logger.error(f"Error generating response: {e}", exc_info=True)
        return f"பிழை ஏற்பட்டது: {str(e)}"  # Error occurred

# Create the Gradio interface
def create_chatbot_interface():
    with gr.Blocks() as demo:
        title = "# தமிழ் உரையாடல் பொத்தான் (Tamil Chatbot)"
        description = "Tamil LLaMA 7B Instruct model with user-controlled generation parameters."
        
        gr.Markdown(title)
        gr.Markdown(description)
        
        # Add a direct testing area to debug the model
        with gr.Tab("Debug Mode"):
            with gr.Row():
                debug_status = gr.Markdown("⚠️ Debug Mode - Model not loaded")
                debug_load_model_btn = gr.Button("Load Model (Debug)")
            
            debug_model = gr.State(None)
            debug_tokenizer = gr.State(None)
            
            with gr.Row():
                with gr.Column(scale=3):
                    debug_input = gr.Textbox(label="Input Text", lines=3)
                    debug_submit = gr.Button("Generate Response")
                with gr.Column(scale=3):
                    debug_output = gr.Textbox(label="Raw Output", lines=8)
            
            def debug_load_model_fn():
                m, t = initialize_model()
                if m is not None and t is not None:
                    return "✅ Debug Model loaded", m, t
                else:
                    return "❌ Debug Model loading failed", None, None
            
            def debug_generate(input_text, model, tokenizer):
                if model is None:
                    return "Model not loaded yet. Please load the model first."
                
                try:
                    # Simple direct generation for testing
                    prompt = f"<|im_start|>user\n{input_text}<|im_end|>\n<|im_start|>assistant\n"
                    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
                    
                    with torch.no_grad():
                        output_ids = model.generate(
                            inputs["input_ids"],
                            max_new_tokens=100,
                            temperature=0.2,
                            do_sample=True
                        )
                    
                    full_output = tokenizer.decode(output_ids[0], skip_special_tokens=False)
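                    # Approximate extraction: slicing by len(prompt) assumes
                    # decode() reproduces the prompt text verbatim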
                    response = full_output[len(prompt):]
                    
                    # Log the full output for debugging
                    logger.info(f"Debug full output: {full_output}")
                    
                    return f"FULL OUTPUT:\n{full_output}\n\nEXTRACTED:\n{response}"
                except Exception as e:
                    logger.error(f"Debug error: {e}", exc_info=True)
                    return f"Error: {str(e)}"
            
            debug_load_model_btn.click(
                debug_load_model_fn,
                outputs=[debug_status, debug_model, debug_tokenizer]
            )
            
            debug_submit.click(
                debug_generate,
                inputs=[debug_input, debug_model, debug_tokenizer],
                outputs=[debug_output]
            )
        
        # Regular chatbot interface
        with gr.Tab("Chatbot"):
            # Model loading indicator
            with gr.Row():
                model_status = gr.Markdown("⚠️ மாதிரி ஏற்றப்படவில்லை (Model not loaded)")
                load_model_btn = gr.Button("மாதிரியை ஏற்று (Load Model)")
            
            # Model and tokenizer states
            model = gr.State(None)
            tokenizer = gr.State(None)
            
            # Parameter sliders
            with gr.Accordion("Generation Parameters", open=False):
                temperature = gr.Slider(
                    label="temperature", 
                    value=0.2, 
                    minimum=0.0, 
                    maximum=2.0, 
                    step=0.05, 
                    interactive=True
                )
                
                top_p = gr.Slider(
                    label="top_p", 
                    value=1.0, 
                    minimum=0.0, 
                    maximum=1.0, 
                    step=0.01, 
                    interactive=True
                )
                
                top_k = gr.Slider(
                    label="top_k", 
                    value=40, 
                    minimum=0, 
                    maximum=1000, 
                    step=1, 
                    interactive=True
                )
            
            # Function to load model on button click
            def load_model_fn():
                m, t = initialize_model()
                if m is not None and t is not None:
                    return "✅ மாதிரி வெற்றிகரமாக ஏற்றப்பட்டது (Model loaded successfully)", m, t
                else:
                    return "❌ மாதிரி ஏற்றுவதில் பிழை (Error loading model)", None, None
            
            # Function to respond to user messages - with error handling
            def chat_function(message, history, model_state, tokenizer_state, temp, tp, tk):
                if not message.strip():
                    return "", history
                
                try:
                    # Check if model is loaded
                    if model_state is None:
                        bot_message = "மாதிரி ஏற்றப்படவில்லை. முதலில் 'மாதிரியை ஏற்று' பொத்தானைக் கிளிக் செய்யவும்."
                    else:
                        # Generate bot response with parameters
                        bot_message = generate_response(
                            model_state, 
                            tokenizer_state, 
                            message, 
                            history,
                            temperature=temp, 
                            top_p=tp, 
                            top_k=tk
                        )
                    
                    # gr.Chatbot(type="messages") expects a list of
                    # {"role", "content"} dicts, not [user, bot] pairs
                    return "", history + [
                        {"role": "user", "content": message},
                        {"role": "assistant", "content": bot_message},
                    ]

                except Exception as e:
                    logger.error(f"Chat function error: {e}", exc_info=True)
                    return "", history + [
                        {"role": "user", "content": message},
                        {"role": "assistant", "content": f"Error: {str(e)}"},
                    ]
            
            # Create the chat interface with modern message format
            chatbot = gr.Chatbot(type="messages")
            msg = gr.TextArea(
                placeholder="உங்கள் செய்தி இங்கே தட்டச்சு செய்யவும் (Type your message here...)",
                lines=3
            )
            clear = gr.Button("அழி (Clear)")
            
            # Set up the chat interface
            msg.submit(
                chat_function,
                [msg, chatbot, model, tokenizer, temperature, top_p, top_k],
                [msg, chatbot]
            )
            
            clear.click(lambda: None, None, chatbot, queue=False)
            
            # Connect the model loading button
            load_model_btn.click(
                load_model_fn, 
                outputs=[model_status, model, tokenizer]
            )
    
    return demo

# Create the demo
demo = create_chatbot_interface()

# Launch the demo
if __name__ == "__main__":
    demo.queue(max_size=5).launch()
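
# Usage note: run `python app.py`, open the printed local URL, and click the
# "Load Model" button before chatting. On CPU the model loads in float32, so
# a 7B model needs on the order of 28 GB of RAM (7B params x 4 bytes).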