from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
import torch
import numpy as np
from transformers import AutoTokenizer, AutoModel
from typing import List, Union
import json
import logging
import uvicorn

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Model configuration
MODEL_NAME = "Qwen/Qwen3-Embedding-0.6B"  # Qwen3 embedding model
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
MAX_LENGTH = 512  # Maximum sequence length for tokenization

# Global variables for model and tokenizer
model = None
tokenizer = None

def load_model():
    """Load the Qwen3 embedding model and tokenizer."""
    global model, tokenizer
    try:
        logger.info(f"Loading Qwen3 embedding model on device: {DEVICE}")
        # Load tokenizer and model for Qwen3 embedding
        tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
        model = AutoModel.from_pretrained(
            MODEL_NAME,
            trust_remote_code=True,
            torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32,
            device_map="auto" if DEVICE == "cuda" else None,
        )
        if DEVICE == "cpu":
            model = model.to(DEVICE)
        model.eval()
        logger.info("Qwen3 embedding model loaded successfully")
        return True
    except Exception as e:
        logger.error(f"Error loading Qwen3 model: {str(e)}")
        # Fall back to a lightweight sentence-transformers model
        try:
            logger.info("Trying fallback model loading...")
            from sentence_transformers import SentenceTransformer

            model = SentenceTransformer("all-MiniLM-L6-v2")
            tokenizer = None
            logger.info("Fallback model loaded successfully")
            return True
        except Exception as fallback_error:
            logger.error(f"Fallback model loading also failed: {str(fallback_error)}")
            return False

def generate_embeddings(texts: Union[str, List[str]]) -> Union[List[float], List[List[float]]]:
    """Generate embeddings for input text(s) using Qwen3 or the fallback model."""
    global model, tokenizer
    # Normalize input up front so `single_text` and `texts` are always
    # defined, even if an error is raised later
    single_text = isinstance(texts, str)
    if single_text:
        texts = [texts]
    try:
        # Cheap character-level pre-truncation; the tokenizer truncates
        # to MAX_LENGTH tokens below in any case
        texts = [text[:MAX_LENGTH] for text in texts]
        embeddings = []
        for text in texts:
            try:
                if model is not None and tokenizer is not None:
                    # Method 1: use the Qwen3 model directly
                    inputs = tokenizer(
                        text,
                        return_tensors="pt",
                        padding=True,
                        truncation=True,
                        max_length=MAX_LENGTH,
                    ).to(DEVICE)
                    with torch.no_grad():
                        outputs = model(**inputs)
                    # Mean pooling over the last hidden state
                    embedding = outputs.last_hidden_state.mean(dim=1).squeeze().cpu().numpy()
                    embeddings.append(embedding.tolist())
                elif model is not None and hasattr(model, "encode"):
                    # Method 2: sentence-transformers fallback
                    embedding = model.encode(text)
                    embeddings.append(embedding.tolist())
                else:
                    raise RuntimeError("No model available")
            except Exception as e:
                logger.warning(f"Error generating embedding for text: {str(e)}")
                # Zero vector as a last resort (384 is the output dimension
                # of the all-MiniLM-L6-v2 fallback model)
                embeddings.append([0.0] * 384)
        return embeddings[0] if single_text else embeddings
    except Exception as e:
        logger.error(f"Error in generate_embeddings: {str(e)}")
        # Return zero vectors as fallback
        if single_text:
            return [0.0] * 384
        return [[0.0] * 384] * len(texts)

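# Note: the function above uses mean pooling for simplicity. The
# Qwen3-Embedding model card recommends last-token pooling instead; a minimal
# hedged variant (assuming left padding, so the final position holds a real
# token) would replace the mean with the final hidden state:
#     embedding = outputs.last_hidden_state[:, -1].squeeze().cpu().numpy()
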
def compute_similarity(embedding1: List[float], embedding2: List[float]) -> float:
    """Compute cosine similarity between two embeddings."""
    try:
        # Convert to numpy arrays
        emb1 = np.array(embedding1)
        emb2 = np.array(embedding2)
        # Cosine similarity: dot product divided by the product of the norms
        dot_product = np.dot(emb1, emb2)
        norm1 = np.linalg.norm(emb1)
        norm2 = np.linalg.norm(emb2)
        if norm1 == 0 or norm2 == 0:
            return 0.0
        return float(dot_product / (norm1 * norm2))
    except Exception as e:
        logger.error(f"Error computing similarity: {str(e)}")
        return 0.0

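# Quick sanity check for compute_similarity (a sketch with made-up vectors):
# identical directions score 1.0, orthogonal directions 0.0.
#     >>> compute_similarity([1.0, 0.0], [1.0, 0.0])
#     1.0
#     >>> compute_similarity([1.0, 0.0], [0.0, 1.0])
#     0.0
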
def batch_embedding_interface(texts: str) -> str:
    """Interface for batch embedding generation."""
    try:
        # Split texts by newlines
        text_list = [text.strip() for text in texts.split("\n") if text.strip()]
        if not text_list:
            return json.dumps([])
        # Generate embeddings and return them as a JSON string
        embeddings = generate_embeddings(text_list)
        return json.dumps(embeddings)
    except Exception as e:
        logger.error(f"Error in batch_embedding_interface: {str(e)}")
        return json.dumps([])


def single_embedding_interface(text: str) -> str:
    """Interface for single embedding generation."""
    try:
        if not text.strip():
            return json.dumps([])
        # Generate the embedding and return it as a JSON string
        embedding = generate_embeddings(text)
        return json.dumps(embedding)
    except Exception as e:
        logger.error(f"Error in single_embedding_interface: {str(e)}")
        return json.dumps([])


def similarity_interface(embedding1: str, embedding2: str) -> float:
    """Interface for computing similarity between two embeddings."""
    try:
        # Parse embeddings from JSON strings, then compare them
        emb1 = json.loads(embedding1)
        emb2 = json.loads(embedding2)
        return compute_similarity(emb1, emb2)
    except Exception as e:
        logger.error(f"Error in similarity_interface: {str(e)}")
        return 0.0

def health_check():
    """Health check helper."""
    return {"status": "healthy", "model_loaded": model is not None}


# Create FastAPI application
app = FastAPI(
    title="Qwen3 Embedding API",
    description="A stable API for generating text embeddings using the Qwen3-Embedding-0.6B model",
    version="1.0.0",
)

# Add CORS middleware (wide open for demo purposes; restrict origins in production)
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# FastAPI endpoints
@app.get("/")
async def root():
    """Root endpoint with API information."""
    return {
        "message": "Qwen3 Embedding API",
        "version": "1.0.0",
        "model": "Qwen3-Embedding-0.6B",
        "endpoints": {
            "health": "/health",
            "predict": "/api/predict",
            "docs": "/docs",
        },
    }


@app.get("/health")
async def health():
    """Health check endpoint."""
    return health_check()


@app.post("/api/predict")
async def predict(data: dict):
    """Main prediction endpoint for embeddings."""
    try:
        if "data" not in data:
            raise HTTPException(status_code=400, detail="Missing 'data' field in request")
        input_data = data["data"]
        # Handle a single text or a batch of texts
        if isinstance(input_data, str):
            # Single text
            embeddings = generate_embeddings(input_data)
            return {"data": [embeddings]}
        elif isinstance(input_data, list):
            if len(input_data) > 0 and isinstance(input_data[0], str):
                # Single text wrapped in a list
                embeddings = generate_embeddings(input_data[0])
                return {"data": [embeddings]}
            elif len(input_data) > 0 and isinstance(input_data[0], list):
                # Batch of texts wrapped in a list
                embeddings = generate_embeddings(input_data[0])
                return {"data": [embeddings]}
            else:
                raise HTTPException(status_code=400, detail="Invalid data format")
        else:
            raise HTTPException(status_code=400, detail="Invalid data type")
    except HTTPException:
        # Let deliberate 4xx errors pass through instead of becoming 500s
        raise
    except Exception as e:
        logger.error(f"Error in predict endpoint: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")

# Route path chosen to match the /api prefix used by the predict endpoint
@app.post("/api/similarity")
async def similarity(data: dict):
    """Compute similarity between two embeddings."""
    try:
        if "embedding1" not in data or "embedding2" not in data:
            raise HTTPException(status_code=400, detail="Missing embedding1 or embedding2 field")
        emb1 = data["embedding1"]
        emb2 = data["embedding2"]
        if not isinstance(emb1, list) or not isinstance(emb2, list):
            raise HTTPException(status_code=400, detail="Embeddings must be lists")
        sim = compute_similarity(emb1, emb2)
        return {"similarity": sim}
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error in similarity endpoint: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")

def main():
    """Main function to run the application."""
    logger.info("Starting Qwen3 Embedding Model API...")
    # Load model
    if not load_model():
        logger.error("Failed to load model. Exiting...")
        return
    logger.info("Model loaded successfully. Starting FastAPI server...")
    # Run with uvicorn
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=7860,
        log_level="info",
    )


if __name__ == "__main__":
    main()
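
# Example requests once the server is running (a sketch; assumes the default
# host/port configured above and a local deployment):
#     curl http://localhost:7860/health
#     curl -X POST http://localhost:7860/api/predict \
#          -H "Content-Type: application/json" \
#          -d '{"data": "Hello world"}'
#     curl -X POST http://localhost:7860/api/similarity \
#          -H "Content-Type: application/json" \
#          -d '{"embedding1": [0.1, 0.2], "embedding2": [0.1, 0.2]}'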