| """ |
| STANLEY AI - Optimized Flask Backend |
| Deploy on Hugging Face Spaces with fast, smaller models |
| """ |
from flask import Flask, request, jsonify
from flask_cors import CORS
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
import torch
import time
import logging
from threading import Thread
import io
import base64
import random
from PIL import Image, ImageDraw, ImageFont
import os
import gc

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = Flask(__name__)
CORS(app)

# Qwen2.5 ships a 1.5B Instruct variant (there is no 1.8B in that line)
MODEL_CONFIG = {
    "primary": "Qwen/Qwen2.5-1.5B-Instruct",
    "fallback": "microsoft/Phi-3-mini-4k-instruct",
    "tiny": "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
}

model = None
tokenizer = None
model_loaded = False
current_model_name = None

# Simple in-process response cache (per worker; cleared on restart)
response_cache = {}
CACHE_SIZE = 200
STANLEY_AI_SYSTEM = """You are STANLEY AI - an advanced assistant with Kiswahili cultural knowledge.
Provide helpful, concise responses. Integrate Kiswahili phrases naturally when relevant.

Key capabilities:
- Answer questions knowledgeably
- Use Kiswahili for greetings, proverbs, and cultural references
- Explain concepts clearly
- Be efficient and to the point

Format: Use **bold** for emphasis. Keep responses under 300 words unless a detailed explanation is needed."""

KISWAHILI_KNOWLEDGE = {
    "greetings": {
        "hello": "Jambo / Habari",
        "how_are_you": "Habari yako?",
        "goodbye": "Kwaheri / Tuonane tena",
        "thank_you": "Asante sana",
        "welcome": "Karibu / Karibuni"
    },
    "proverbs": [
        "Mwenye pupa hadiriki kula tamu - The impatient one misses the sweet things.",
        "Asiyefunzwa na mamae hufunzwa na ulimwengu - He who is not taught by his mother is taught by the world.",
        "Haraka haraka haina baraka - Hurry, hurry has no blessing.",
        "Ukitaka kwenda haraka, nenda peke yako. Ukitaka kwenda mbali, nenda na wenzako - If you want to go fast, go alone. If you want to go far, go together."
    ],
    "lion_king": {
        "simba": "Lion (the main character)",
        "rafiki": "Friend (the wise baboon)",
        "hakuna_matata": "No worries / No problems",
        "mufasa": "Simba's father, the king",
        "nala": "Simba's childhood friend and queen"
    }
}
def load_model_optimized(model_name=None):
    """Load model with optimizations for Hugging Face Spaces"""
    global model, tokenizer, model_loaded, current_model_name

    if not model_name:
        model_name = MODEL_CONFIG["primary"]

    # Already serving the requested model
    if model_loaded and model_name == current_model_name:
        return True

    logger.info(f"🚀 Loading model: {model_name}")

    try:
        # Free any previously loaded model before swapping
        if model is not None:
            del model
            del tokenizer
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

        tokenizer = AutoTokenizer.from_pretrained(
            model_name,
            trust_remote_code=True,
            use_fast=True
        )

        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        # 4-bit quantization needs bitsandbytes and a CUDA device;
        # fall back to full precision on CPU-only Spaces
        use_cuda = torch.cuda.is_available()
        quant_config = BitsAndBytesConfig(load_in_4bit=True) if use_cuda else None

        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.float16 if use_cuda else torch.float32,
            device_map="auto",
            quantization_config=quant_config,
            low_cpu_mem_usage=True,
            trust_remote_code=True
        )

        model.eval()
        model_loaded = True
        current_model_name = model_name

        prewarm_model()

        logger.info(f"✅ Model loaded successfully: {model_name}")
        logger.info(f"📊 Model device: {model.device}")

        return True

    except Exception as e:
        logger.error(f"❌ Error loading model: {e}")

        if model_name != MODEL_CONFIG["fallback"]:
            logger.info("🔄 Trying fallback model...")
            return load_model_optimized(MODEL_CONFIG["fallback"])
        else:
            logger.error("❌ All models failed to load")
            model_loaded = False
            return False
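# Example: force the smallest model up front (handy on CPU-only Spaces; a
# sketch using the keys defined in MODEL_CONFIG above):
#   load_model_optimized(MODEL_CONFIG["tiny"])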
def prewarm_model():
    """Generate a dummy response to warm up the model"""
    try:
        dummy_input = "Hello, STANLEY AI!"
        messages = [
            {"role": "system", "content": "Say hello briefly."},
            {"role": "user", "content": dummy_input}
        ]

        text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        inputs = tokenizer(text, return_tensors="pt").to(model.device)

        with torch.no_grad():
            _ = model.generate(
                **inputs,
                max_new_tokens=10,
                do_sample=False
            )

        logger.info("✅ Model pre-warmed successfully!")
    except Exception as e:
        logger.warning(f"Pre-warm failed: {e}")
def detect_kiswahili_context(text):
    """Detect if text contains Kiswahili or cultural references"""
    if not text:
        return False

    text_lower = text.lower()
    kiswahili_keywords = [
        'swahili', 'kiswahili', 'hakuna', 'matata', 'asante', 'rafiki',
        'jambo', 'mambo', 'pole', 'sawa', 'karibu', 'kwaheri', 'simba',
        'lion king', 'mufasa', 'nala', 'kenya', 'tanzania', 'africa',
        'habari', 'mzee', 'pumbaa', 'timon', 'safari', 'ujamaa'
    ]

    return any(keyword in text_lower for keyword in kiswahili_keywords)
def enhance_with_kiswahili(response, user_message):
    """Add Kiswahili elements to the response"""
    if detect_kiswahili_context(user_message):
        greetings = list(KISWAHILI_KNOWLEDGE["greetings"].values())
        greeting = random.choice(greetings)

        # Attach a proverb when the user asks for advice or wisdom
        if any(word in user_message.lower() for word in ['advice', 'wisdom', 'lesson', 'teach']):
            proverb = random.choice(KISWAHILI_KNOWLEDGE["proverbs"])
            enhanced = f"{greeting}! {response}\n\n**🔥 Kiswahili Proverb:** {proverb}"
        else:
            enhanced = f"{greeting}! {response}"

        # Add a Lion King fact for related prompts
        if any(word in user_message.lower() for word in ['lion', 'simba', 'mufasa', 'disney']):
            lion_fact = "Did you know? 'Simba' means lion in Kiswahili, and 'Rafiki' means friend!"
            enhanced += f"\n\n{lion_fact}"

        return enhanced

    return response
def get_cached_response(user_message):
    """Get a response from the cache"""
    cache_key = user_message.lower().strip()[:80]
    return response_cache.get(cache_key)


def set_cached_response(user_message, response):
    """Cache a response, evicting a random entry when full"""
    cache_key = user_message.lower().strip()[:80]
    if len(response_cache) >= CACHE_SIZE:
        # Random eviction keeps this dependency-free; an LRU would be the
        # more conventional policy
        random_key = random.choice(list(response_cache.keys()))
        del response_cache[random_key]
    response_cache[cache_key] = response
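# Cache-key sketch: keys are lowercased, stripped, and truncated to 80 chars,
# so these two calls share one entry:
#   set_cached_response("Habari yako?", "Nzuri sana!")
#   get_cached_response("  HABARI YAKO?  ")  # -> "Nzuri sana!"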
def generate_response(user_message, max_tokens=512):
    """Generate an optimized response"""

    cached = get_cached_response(user_message)
    if cached:
        logger.info("📦 Using cached response")
        return cached

    if not model_loaded:
        success = load_model_optimized()
        if not success:
            return "I'm still initializing. Please try again in a moment."

    messages = [
        {"role": "system", "content": STANLEY_AI_SYSTEM},
        {"role": "user", "content": user_message}
    ]

    try:
        text = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )

        inputs = tokenizer(text, return_tensors="pt").to(model.device)

        # Note: early_stopping only applies to beam search, so it is omitted
        # here with sampling enabled
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=max_tokens,
                temperature=0.7,
                top_p=0.9,
                top_k=40,
                do_sample=True,
                pad_token_id=tokenizer.eos_token_id,
                repetition_penalty=1.1,
                no_repeat_ngram_size=3
            )

        # Decode only the newly generated tokens, not the prompt
        response = tokenizer.decode(
            outputs[0][inputs['input_ids'].shape[1]:],
            skip_special_tokens=True
        ).strip()

        enhanced_response = enhance_with_kiswahili(response, user_message)

        set_cached_response(user_message, enhanced_response)

        return enhanced_response

    except Exception as e:
        logger.error(f"Generation error: {e}")
        return f"Pole! I encountered an error: {str(e)[:100]}"
def generate_image_simple(prompt, width=512, height=512):
    """Simple image generation using PIL (no external dependencies)"""
    try:
        img = Image.new('RGB', (width, height), color='white')
        draw = ImageDraw.Draw(img)

        # Vertical gradient background
        for i in range(height):
            r = int(100 + 155 * i / height)
            g = int(150 + 105 * i / height)
            b = int(200 + 55 * i / height)
            draw.line([(0, i), (width, i)], fill=(r, g, b))

        # Keyword-driven shapes
        prompt_lower = prompt.lower()

        if any(word in prompt_lower for word in ['sun', 'bright', 'light']):
            draw.ellipse([width//3, height//3, 2*width//3, 2*height//3],
                         fill=(255, 255, 0), outline=(255, 200, 0))

        if any(word in prompt_lower for word in ['tree', 'nature']):
            draw.rectangle([width//2-15, height//2, width//2+15, height-50],
                           fill=(101, 67, 33))
            for i in range(5):
                y_offset = i * 30
                draw.ellipse([width//2-60, height//2-100+y_offset,
                              width//2+60, height//2-40+y_offset],
                             fill=(34, 139, 34))

        if any(word in prompt_lower for word in ['water', 'ocean', 'river']):
            for i in range(0, width, 40):
                draw.arc([i, height-80, i+80, height], 0, 180,
                         fill=(64, 164, 223), width=3)

        # Caption overlay
        try:
            font_size = min(width // 25, 20)
            try:
                font = ImageFont.truetype("arial.ttf", font_size)
            except OSError:
                # arial.ttf is usually absent on Linux hosts
                font = ImageFont.load_default()

            display_text = prompt[:50] + "..." if len(prompt) > 50 else prompt
            text = f"STANLEY AI: {display_text}"

            bbox = draw.textbbox((0, 0), text, font=font)
            text_width = bbox[2] - bbox[0]
            text_height = bbox[3] - bbox[1]

            x = (width - text_width) // 2
            y = 20

            # RGB images have no alpha channel, so use a solid backdrop
            draw.rectangle([x-10, y-5, x+text_width+10, y+text_height+5],
                           fill=(0, 0, 0))

            draw.text((x, y), text, fill=(255, 255, 255), font=font)

        except Exception as font_error:
            logger.warning(f"Could not add text: {font_error}")

        # Encode as a base64 data URL
        buffered = io.BytesIO()
        img.save(buffered, format="PNG", optimize=True)
        img_str = base64.b64encode(buffered.getvalue()).decode()

        return f"data:image/png;base64,{img_str}"

    except Exception as e:
        logger.error(f"Image generation error: {e}")
        # Last resort: a solid random-colour image
        img = Image.new('RGB', (width, height),
                        color=(random.randint(50, 200),
                               random.randint(50, 200),
                               random.randint(50, 200)))
        buffered = io.BytesIO()
        img.save(buffered, format="PNG")
        img_str = base64.b64encode(buffered.getvalue()).decode()
        return f"data:image/png;base64,{img_str}"
@app.route('/')
def home():
    return jsonify({
        "message": "🚀 STANLEY AI API is running!",
        "version": "3.0",
        "status": "active",
        "model": current_model_name or "Loading...",
        "optimized": True,
        "cache_size": len(response_cache),
        "endpoints": [
            "/api/chat - Main chat endpoint",
            "/api/chat-fast - Faster responses",
            "/api/generate-image - Simple image generation",
            "/api/health - System health check",
            "/api/cache/clear - Clear response cache",
            "/api/switch-model - Switch between available models"
        ]
    })
@app.route('/api/health')
def health_check():
    return jsonify({
        "status": "healthy" if model_loaded else "loading",
        "model_loaded": model_loaded,
        "model": current_model_name,
        "cache_entries": len(response_cache),
        "timestamp": time.time()
    })
@app.route('/api/chat', methods=['POST'])
def chat():
    try:
        start_time = time.time()
        data = request.get_json(silent=True) or {}
        user_message = data.get('message', '')

        if not user_message:
            return jsonify({"error": "Tafadhali provide a message"}), 400

        logger.info(f"💬 Processing: {user_message[:60]}...")

        response = generate_response(user_message)
        response_time = round(time.time() - start_time, 2)

        has_kiswahili = detect_kiswahili_context(response)

        return jsonify({
            "response": response,
            "status": "success",
            "response_time": f"{response_time}s",
            "model": current_model_name,
            "cultural_context": has_kiswahili,
            "language": "en+sw" if has_kiswahili else "en",
            "word_count": len(response.split())
        })

    except Exception as e:
        logger.error(f"Chat error: {e}")
        return jsonify({
            "error": f"Pole! Error: {str(e)[:100]}",
            "status": "error"
        }), 500
@app.route('/api/chat-fast', methods=['POST'])
def chat_fast():
    """Faster endpoint with shorter responses"""
    try:
        data = request.get_json(silent=True) or {}
        user_message = data.get('message', '')

        if not user_message:
            return jsonify({"error": "Please provide a message"}), 400

        # Half the token budget for quicker turnaround
        response = generate_response(user_message, max_tokens=256)

        return jsonify({
            "response": response,
            "status": "success",
            "model": f"{current_model_name} (fast mode)",
            "response_type": "concise"
        })

    except Exception as e:
        logger.error(f"Fast chat error: {e}")
        return jsonify({"error": "Quick response failed"}), 500
@app.route('/api/generate-image', methods=['POST'])
def generate_image_endpoint():
    """Simple image generation endpoint"""
    try:
        data = request.get_json(silent=True) or {}
        prompt = data.get('prompt', 'A beautiful landscape')
        width = min(int(data.get('width', 512)), 1024)
        height = min(int(data.get('height', 512)), 1024)

        logger.info(f"🎨 Generating image: {prompt[:40]}...")

        image_data = generate_image_simple(prompt, width, height)

        if image_data:
            return jsonify({
                "image": image_data,
                "prompt": prompt,
                "status": "success",
                "method": "PIL generated",
                "dimensions": f"{width}x{height}"
            })
        else:
            return jsonify({"error": "Could not generate image"}), 500

    except Exception as e:
        return jsonify({"error": f"Image error: {str(e)[:80]}"}), 500
@app.route('/api/cache/clear', methods=['POST'])
def clear_cache():
    """Clear the response cache"""
    cache_size = len(response_cache)
    response_cache.clear()

    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    return jsonify({
        "status": "success",
        "cleared_entries": cache_size,
        "message": "Cache cleared"
    })
@app.route('/api/switch-model', methods=['POST'])
def switch_model():
    """Switch between available models"""
    try:
        data = request.get_json(silent=True) or {}
        model_choice = data.get('model', 'primary')

        model_name = MODEL_CONFIG.get(model_choice, MODEL_CONFIG["primary"])

        success = load_model_optimized(model_name)

        if success:
            return jsonify({
                "status": "success",
                "message": f"Switched to {model_name}",
                "current_model": current_model_name
            })
        else:
            return jsonify({"error": "Failed to switch model"}), 500

    except Exception as e:
        return jsonify({"error": str(e)}), 500
def initialize_app():
    """Initialize the application"""
    logger.info("🚀 Initializing STANLEY AI...")

    # Load the model in the background so the server can start serving
    # immediately; early requests get the "still initializing" message
    background_thread = Thread(target=load_model_optimized, daemon=True)
    background_thread.start()

    logger.info("✅ STANLEY AI initialized and ready!")


# Kick off model loading at import time (works under gunicorn as well)
initialize_app()


if __name__ == '__main__':
    port = int(os.environ.get('PORT', 7860))
    app.run(debug=False, host='0.0.0.0', port=port, threaded=True)
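# Deployment sketch (assumptions, not part of this repo's config): Hugging
# Face Spaces expects the server on port 7860; with a Docker Space you might
# run e.g. `gunicorn -w 1 -b 0.0.0.0:7860 app:app` (module name assumed).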