from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from typing import Optional
import logging
import os

try:
    import torch
    from transformers import AutoTokenizer, AutoModelForCausalLM
    from peft import PeftModel
    ML_DEPENDENCIES_AVAILABLE = True
except ImportError as e:
    ML_DEPENDENCIES_AVAILABLE = False
    MISSING_DEPENDENCY_ERROR = str(e)

logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

app = FastAPI(
    title="Astra - Ayurvedic AI Assistant",
    description="Meet Astra, your intelligent Ayurvedic Assistant powered by Llama 3.1 8B with specialized Ayurveda knowledge. Astra provides complete, thorough information about Ayurvedic medicine, herbs, wellness practices, and holistic health.",
    version="1.0.0"
)

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

model = None
tokenizer = None
model_loaded = False

# Get Hugging Face token from environment (required for gated models)
HF_TOKEN = os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACE_TOKEN")

# Models configuration - Unsloth 4-bit quantized build for faster, lighter inference
BASE_MODEL = "unsloth/Meta-Llama-3.1-8B-Instruct-bnb-4bit"
LORA_MODEL = "ayureasehealthcare/llama3-ayurveda-lora-v3"

# Astra - Ayurvedic Assistant System Prompt
ASTRA_SYSTEM_PROMPT = """You are Astra, a knowledgeable and compassionate Ayurvedic Assistant. Your purpose is to provide complete, accurate, and helpful information about Ayurveda, traditional wellness practices, herbs, treatments, and holistic health.

Guidelines for your responses:
1. Always provide COMPLETE information - never give partial or incomplete answers
2. Be thorough and comprehensive in your explanations
3. Use clear, accessible language while maintaining accuracy
4. Include relevant details about benefits, usage, precautions, and traditional wisdom
5. When discussing herbs or treatments, provide complete information including:
   - Traditional uses and benefits
   - Properties according to Ayurveda (doshas, qualities)
   - Preparation methods when relevant
   - Any important precautions or considerations
6. If a question requires a detailed answer, provide all necessary information
7. Be warm, supportive, and encouraging in your tone
8. Always complete your response - do not leave answers unfinished

Remember: You are Astra, here to share the ancient wisdom of Ayurveda completely and thoroughly."""

class TextGenerationRequest(BaseModel):
    prompt: str = Field(..., description="Your question or message to Astra")
    max_length: Optional[int] = Field(1024, description="Maximum number of new tokens to generate (default: 1024 for complete responses)")
    temperature: Optional[float] = Field(0.7, description="Sampling temperature (0.0-2.0)")
    top_p: Optional[float] = Field(0.9, description="Nucleus sampling parameter")
    top_k: Optional[int] = Field(50, description="Top-k sampling parameter")
    include_system_prompt: Optional[bool] = Field(True, description="Include Astra's system instructions")

class GenerationResponse(BaseModel):
    generated_text: str
    prompt: str
    model_info: dict

class ModelStatus(BaseModel):
    loaded: bool
    base_model: str
    lora_model: str
    device: str
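
# Illustrative /generate request body (only "prompt" is required; the other
# values shown are the documented defaults):
# {
#     "prompt": "What are the benefits of ashwagandha?",
#     "max_length": 1024,
#     "temperature": 0.7,
#     "top_p": 0.9,
#     "top_k": 50,
#     "include_system_prompt": true
# }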

@app.get("/")
async def root():
    """Root endpoint with API information"""
    return {
        "assistant_name": "Astra",
        "message": "Welcome to Astra - Your Ayurvedic AI Assistant",
        "version": "1.0.0",
        "description": "Astra provides complete and thorough information about Ayurvedic medicine, herbs, wellness practices, and holistic health using advanced AI with specialized Ayurveda knowledge.",
        "capabilities": [
            "Comprehensive information about Ayurvedic herbs and their benefits",
            "Detailed explanations of Ayurvedic principles and practices",
            "Complete guidance on doshas, body types, and balance",
            "Traditional wellness practices and natural remedies",
            "Holistic health and nutrition advice based on Ayurveda"
        ],
        "endpoints": {
            "health": "/health",
            "status": "/status",
            "load_model": "/load-model (POST)",
            "generate": "/generate (POST) - Chat with Astra",
            "docs": "/docs"
        },
        "note": "Visit /docs for interactive API documentation. Astra always provides complete, thorough responses."
    }

@app.get("/health")
async def health_check():
    """Health check endpoint"""
    device_info = "unknown"
    if ML_DEPENDENCIES_AVAILABLE:
        device_info = "cuda" if torch.cuda.is_available() else "cpu"
    return {
        "status": "healthy",
        "model_loaded": model_loaded,
        "ml_dependencies_available": ML_DEPENDENCIES_AVAILABLE,
        "device": device_info
    }

@app.get("/status", response_model=ModelStatus)
async def get_status():
    """Get current model status"""
    device_info = "unknown"
    if ML_DEPENDENCIES_AVAILABLE:
        device_info = "cuda" if torch.cuda.is_available() else "cpu"
    return ModelStatus(
        loaded=model_loaded,
        base_model=BASE_MODEL,
        lora_model=LORA_MODEL,
        device=device_info
    )

@app.post("/load-model")
async def load_model():
    """Load the base model with LoRA adapters"""
    global model, tokenizer, model_loaded

    if not ML_DEPENDENCIES_AVAILABLE:
        raise HTTPException(
            status_code=503,
            detail={
                "error": "ML dependencies not installed",
                "message": "Please install the required ML libraries: pip install torch transformers peft accelerate bitsandbytes sentencepiece protobuf",
                "missing_dependency": MISSING_DEPENDENCY_ERROR
            }
        )

    try:
        logger.info("Starting model loading process...")
        if model_loaded:
            return {
                "message": "Model already loaded",
                "base_model": BASE_MODEL,
                "lora_model": LORA_MODEL
            }

        logger.info(f"Loading base model: {BASE_MODEL}")

        # Check for HF token (required for gated models like Llama)
        if not HF_TOKEN:
            logger.warning("HF_TOKEN not found. You may need it for gated models.")
            logger.info("Set HF_TOKEN in Secrets (Replit) or Space Settings (HF)")

        # Load tokenizer with authentication
        tokenizer = AutoTokenizer.from_pretrained(
            BASE_MODEL,
            token=HF_TOKEN,
            trust_remote_code=True
        )
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        # Load base model (Unsloth bnb-4bit checkpoints are already quantized)
        base_model = AutoModelForCausalLM.from_pretrained(
            BASE_MODEL,
            device_map="auto",
            torch_dtype=torch.bfloat16,  # bfloat16 for the non-quantized layers
            trust_remote_code=True,
            token=HF_TOKEN  # Authentication for gated models (if needed)
        )

        # Load LoRA adapter if specified
        if LORA_MODEL:
            logger.info(f"Loading LoRA adapters: {LORA_MODEL}")
            model = PeftModel.from_pretrained(
                base_model,
                LORA_MODEL,
                token=HF_TOKEN  # Authentication for private adapters
            )
        else:
            logger.info("No LoRA adapter specified, using base model")
            model = base_model

        model.eval()
        model_loaded = True
        logger.info("Model loaded successfully")
        return {
            "message": "Model loaded successfully",
            "base_model": BASE_MODEL,
            "lora_model": LORA_MODEL,
            "device": str(next(model.parameters()).device)
        }
    except Exception as e:
        logger.error(f"Error loading model: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Failed to load model: {str(e)}")

@app.post("/generate", response_model=GenerationResponse)
async def generate_text(request: TextGenerationRequest):
    """Chat with Astra - Get complete Ayurvedic information and guidance"""
    if not model_loaded:
        raise HTTPException(
            status_code=400,
            detail="Model not loaded. Please call /load-model first"
        )

    try:
        logger.info(f"Astra processing query: {request.prompt[:50]}...")

        # Format the prompt with Astra's system instructions for complete responses
        if request.include_system_prompt:
            formatted_prompt = f"""{ASTRA_SYSTEM_PROMPT}

User Question: {request.prompt}

Astra's Complete Response:"""
        else:
            formatted_prompt = f"""Question: {request.prompt}

Answer:"""

        inputs = tokenizer(formatted_prompt, return_tensors="pt").to(model.device)

        # Generate with settings tuned for complete responses.
        # max_new_tokens/min_new_tokens count generated tokens only, so the
        # long system prompt does not eat into the response budget.
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=request.max_length,
                min_new_tokens=100,  # Ensure responses have substance
                temperature=request.temperature,
                top_p=request.top_p,
                top_k=request.top_k,
                do_sample=True,
                pad_token_id=tokenizer.eos_token_id,
                repetition_penalty=1.1,  # Reduce repetition
                no_repeat_ngram_size=3  # Prevent repetitive phrases
            )

        generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)

        # Extract only Astra's response (remove the prompt from the output)
        if request.include_system_prompt and "Astra's Complete Response:" in generated_text:
            astra_response = generated_text.split("Astra's Complete Response:")[-1].strip()
        elif "Answer:" in generated_text:
            astra_response = generated_text.split("Answer:")[-1].strip()
        else:
            astra_response = generated_text

        logger.info("Astra response generated successfully")
        return GenerationResponse(
            generated_text=astra_response,
            prompt=request.prompt,
            model_info={
                "assistant": "Astra - Ayurvedic AI Assistant",
                "base_model": BASE_MODEL,
                "lora_model": LORA_MODEL,
                "parameters": {
                    "max_new_tokens": request.max_length,
                    "min_new_tokens": 100,
                    "temperature": request.temperature,
                    "top_p": request.top_p,
                    "top_k": request.top_k
                }
            }
        )
    except Exception as e:
        logger.error(f"Error during text generation: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Text generation failed: {str(e)}")
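
# Optional alternative prompt formatting (a sketch, not wired into /generate):
# Llama 3.1 Instruct checkpoints are trained with a chat template, so building
# the prompt via tokenizer.apply_chat_template usually tracks the model's
# training format better than the plain-text framing above. Assumes the
# tokenizer has been loaded via /load-model and ships a chat_template (true
# for the Instruct models). The helper name is ours, not part of the API.
def build_chat_prompt(user_prompt: str, include_system_prompt: bool = True) -> str:
    """Render a prompt string using the tokenizer's built-in chat template."""
    messages = []
    if include_system_prompt:
        messages.append({"role": "system", "content": ASTRA_SYSTEM_PROMPT})
    messages.append({"role": "user", "content": user_prompt})
    return tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True  # Append the assistant header so the model answers next
    )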

if __name__ == "__main__":
    import uvicorn

    # Use port 7860 for Hugging Face Spaces, otherwise port 5000 for Replit.
    # Set the PORT=7860 environment variable for HF Spaces deployment.
    port = int(os.getenv("PORT", "5000"))
    uvicorn.run(app, host="0.0.0.0", port=port)
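
# Example session (illustrative; assumes a local run on the default port 5000):
#
#   curl -X POST http://localhost:5000/load-model
#
#   curl -X POST http://localhost:5000/generate \
#        -H "Content-Type: application/json" \
#        -d '{"prompt": "What are the benefits of ashwagandha?"}'
#
# The first call downloads and loads the base model plus LoRA adapters (this
# can take several minutes on first run); the second returns a
# GenerationResponse JSON with Astra's answer in "generated_text".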