# text_embedder / main.py
# ex510's picture
# Update main.py
# 03cc16d verified
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field
from sentence_transformers import SentenceTransformer
import uvicorn
import asyncio
from typing import List
import numpy as np
from contextlib import asynccontextmanager
import httpx
import os
import sqlite3
import json
# Globals: module-level state shared by the HTTP handlers and the background worker.
# Embedding model instance; assigned by lifespan() at startup and reset to None at shutdown.
model = None
# Tokenizer taken from the loaded model (model.tokenizer); used by chunk_and_embed().
tokenizer = None
# Model identifier passed to SentenceTransformer and reported by the "/" endpoint.
model_id = 'Qwen/Qwen3-Embedding-0.6B'
# Token budget per chunk: texts longer than this are split by chunk_and_embed().
MAX_TOKENS = 32000
# Path of the SQLite file that backs the request queue.
DB_PATH = "/data/embeddings.db"
# Set to True by process_queue() while a request is being handled; exposed by "/" and "/status".
is_processing = False
def init_database():
    """Create the SQLite queue database, its table and its index if missing.

    The parent directory of ``DB_PATH`` is created when ``DB_PATH`` has a
    directory component. The function is idempotent: both the table and the
    index are created with ``IF NOT EXISTS``.

    Raises:
        OSError: if the parent directory cannot be created.
        sqlite3.Error: if the database cannot be opened or the DDL fails.
    """
    db_dir = os.path.dirname(DB_PATH)
    # A bare filename has an empty dirname and os.makedirs("") raises,
    # so only create the directory when there is one to create.
    if db_dir:
        os.makedirs(db_dir, exist_ok=True)

    conn = sqlite3.connect(DB_PATH)
    try:
        cursor = conn.cursor()
        cursor.execute('''
            CREATE TABLE IF NOT EXISTS embedding_requests (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                request_id TEXT,
                text TEXT NOT NULL,
                embedding TEXT,
                status TEXT DEFAULT 'pending',
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                processed_at TIMESTAMP,
                webhook_sent BOOLEAN DEFAULT 0,
                error_message TEXT
            )
        ''')
        # The worker and the status endpoints all filter on `status`.
        cursor.execute('''
            CREATE INDEX IF NOT EXISTS idx_status
            ON embedding_requests(status)
        ''')
        conn.commit()
    finally:
        # Always release the handle, even if the DDL fails.
        conn.close()
    print("✅ Database initialized successfully")
def save_request_to_db(text: str, request_id: str = None) -> int:
"""Save the incoming request to database"""
conn = sqlite3.connect(DB_PATH)
cursor = conn.cursor()
cursor.execute('''
INSERT INTO embedding_requests (request_id, text, status)
VALUES (?, ?, 'pending')
''', (request_id, text))
row_id = cursor.lastrowid
conn.commit()
conn.close()
print(f"✅ Request saved to DB with ID: {row_id}")
return row_id
def get_next_pending_request():
    """Fetch the oldest request whose status is ``'pending'``.

    Returns:
        A ``(id, request_id, text)`` tuple for the pending row with the
        lowest ``id``, or ``None`` when no row is pending.

    Raises:
        sqlite3.Error: if the query fails.
    """
    conn = sqlite3.connect(DB_PATH)
    try:
        cursor = conn.cursor()
        # Lowest id first gives FIFO ordering of the queue.
        cursor.execute('''
            SELECT id, request_id, text
            FROM embedding_requests
            WHERE status = 'pending'
            ORDER BY id ASC
            LIMIT 1
        ''')
        return cursor.fetchone()
    finally:
        # Close the connection even when the query raises.
        conn.close()
def update_request_processing(row_id: int):
    """Set the status of one request to ``'processing'``.

    Args:
        row_id: Primary key (``id``) of the request to update.

    Raises:
        sqlite3.Error: if the update fails.
    """
    conn = sqlite3.connect(DB_PATH)
    try:
        conn.execute('''
            UPDATE embedding_requests
            SET status = 'processing'
            WHERE id = ?
        ''', (row_id,))
        conn.commit()
    finally:
        # Close the connection even when the update raises.
        conn.close()
def update_embedding_in_db(row_id: int, embedding: List[float]):
    """Store a finished embedding and mark the request ``'completed'``.

    The vector is serialised with ``json.dumps`` into the ``embedding`` TEXT
    column and ``processed_at`` is set to the database's current timestamp.

    Args:
        row_id: Primary key (``id``) of the request to update.
        embedding: The embedding vector to persist.

    Raises:
        sqlite3.Error: if the update fails.
    """
    embedding_json = json.dumps(embedding)
    conn = sqlite3.connect(DB_PATH)
    try:
        conn.execute('''
            UPDATE embedding_requests
            SET embedding = ?,
                status = 'completed',
                processed_at = CURRENT_TIMESTAMP
            WHERE id = ?
        ''', (embedding_json, row_id))
        conn.commit()
    finally:
        # Close the connection even when the update raises.
        conn.close()
    print(f"✅ Embedding saved for ID: {row_id}")
def get_request_data(row_id: int):
    """Load one request row, including its stored embedding.

    Args:
        row_id: Primary key (``id``) of the request to read.

    Returns:
        A ``(id, request_id, text, embedding)`` tuple, where ``embedding`` is
        the raw column value (JSON text, or ``None`` if not yet computed), or
        ``None`` when no row has that id.

    Raises:
        sqlite3.Error: if the query fails.
    """
    conn = sqlite3.connect(DB_PATH)
    try:
        cursor = conn.cursor()
        cursor.execute('''
            SELECT id, request_id, text, embedding
            FROM embedding_requests
            WHERE id = ?
        ''', (row_id,))
        return cursor.fetchone()
    finally:
        # Close the connection even when the query raises.
        conn.close()
def mark_webhook_sent_and_delete(row_id: int):
    """Flag a request as delivered and remove it from the queue table.

    Both statements run on the same connection and are committed together.

    Args:
        row_id: Primary key (``id``) of the request to remove.

    Raises:
        sqlite3.Error: if either statement fails; nothing is committed then.
    """
    conn = sqlite3.connect(DB_PATH)
    try:
        cursor = conn.cursor()
        # First mark as sent
        cursor.execute('''
            UPDATE embedding_requests
            SET webhook_sent = 1
            WHERE id = ?
        ''', (row_id,))
        # Then delete
        cursor.execute('DELETE FROM embedding_requests WHERE id = ?', (row_id,))
        conn.commit()
    finally:
        # Close the connection even when a statement raises.
        conn.close()
    print(f"🗑️ Request deleted from DB: {row_id}")
def mark_request_failed(row_id: int, error_message: str):
    """Record a processing failure for one request.

    Sets ``status`` to ``'failed'``, stores the error text and stamps
    ``processed_at`` with the database's current timestamp.

    Args:
        row_id: Primary key (``id``) of the request that failed.
        error_message: Human-readable description of the failure.

    Raises:
        sqlite3.Error: if the update fails.
    """
    conn = sqlite3.connect(DB_PATH)
    try:
        conn.execute('''
            UPDATE embedding_requests
            SET status = 'failed',
                error_message = ?,
                processed_at = CURRENT_TIMESTAMP
            WHERE id = ?
        ''', (error_message, row_id))
        conn.commit()
    finally:
        # Close the connection even when the update raises.
        conn.close()
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan: set up resources on startup, release them on shutdown.

    Startup initialises the database, loads the embedding model into the
    module-level ``model`` / ``tokenizer`` globals and starts the background
    queue worker. Shutdown cancels the worker and clears the globals.

    Args:
        app: The FastAPI application (required by the lifespan protocol; unused).
    """
    global model, tokenizer

    # Initialize database
    init_database()

    # Load the model
    print(f"Loading model: {model_id}...")
    model = SentenceTransformer(model_id)
    tokenizer = model.tokenizer
    print("✅ Model loaded successfully")

    # Start background processor. Keep a strong reference to the task: the
    # event loop only holds weak references, so an unreferenced task may be
    # garbage-collected before it finishes.
    queue_task = asyncio.create_task(process_queue())
    try:
        yield
    finally:
        # Cleanup
        print("Cleaning up...")
        # Stop the worker before dropping the model it depends on.
        queue_task.cancel()
        try:
            await queue_task
        except asyncio.CancelledError:
            pass
        model = None
        tokenizer = None
# ASGI application object; startup and shutdown work is delegated to `lifespan`.
app = FastAPI(title="Text Embedding API with Queue", lifespan=lifespan)
class TextRequest(BaseModel):
    # Request body accepted by POST /embed/text.

    # The text to embed; must contain at least one character.
    text: str = Field(
        ...,
        min_length=1,
        description="Text to embed",
    )
    # Caller-chosen identifier; stored with the queued row and included in
    # the webhook payload.
    request_id: str | None = Field(
        default=None,
        description="Optional request identifier",
    )
def chunk_and_embed(text: str) -> List[float]:
    """Embed ``text``, splitting it into overlapping chunks when it is too long.

    Texts of at most ``MAX_TOKENS`` tokens are encoded in one call with
    ``normalize_embeddings=True``. Longer texts are cut into windows of
    ``MAX_TOKENS`` tokens that overlap by 50 tokens; each window is encoded
    separately, the chunk vectors are averaged, and the average is scaled
    back to unit length so both paths return vectors of the same magnitude.

    Args:
        text: The text to embed.

    Returns:
        The embedding as a plain list of floats.
    """
    tokens = tokenizer.encode(text, add_special_tokens=False)
    if len(tokens) <= MAX_TOKENS:
        return model.encode(text, normalize_embeddings=True).tolist()

    # Chunking
    chunks = []
    overlap = 50
    start = 0
    while start < len(tokens):
        end = start + MAX_TOKENS
        chunk_tokens = tokens[start:end]
        chunks.append(tokenizer.decode(chunk_tokens, skip_special_tokens=True))
        if end >= len(tokens):
            break
        # Step back so consecutive windows share `overlap` tokens of context.
        start = end - overlap

    chunk_embeddings = [model.encode(chunk, normalize_embeddings=True) for chunk in chunks]
    mean_embedding = np.mean(chunk_embeddings, axis=0)
    # The mean of unit vectors is generally shorter than unit length;
    # re-normalise so long inputs match the single-chunk path above.
    norm = np.linalg.norm(mean_embedding)
    if norm > 0:
        mean_embedding = mean_embedding / norm
    return mean_embedding.tolist()
async def send_to_webhook(webhook_url: str, row_id: int, request_id: str, text: str, embedding: List[float]):
    """POST a completed request to the webhook and drop its row on success.

    Any exception (HTTP failure, non-2xx response, or a failure while deleting
    the row) is logged and swallowed; in that case the row is left in the
    database.

    Args:
        webhook_url: Destination URL for the POST.
        row_id: Primary key of the request row; sent as ``db_id``.
        request_id: Caller-supplied identifier of the request.
        text: The original text that was embedded.
        embedding: The computed embedding vector.
    """
    body = dict(
        db_id=row_id,
        request_id=request_id,
        text=text,
        embedding=embedding,
        status="completed",
    )
    try:
        async with httpx.AsyncClient(timeout=60.0) as http:
            reply = await http.post(webhook_url, json=body)
            # Treat 4xx/5xx answers as delivery failures.
            reply.raise_for_status()
            print(f"✅ Webhook sent successfully for ID: {row_id}")
            # The receiver now owns the data, so the queue row can go.
            mark_webhook_sent_and_delete(row_id)
    except Exception as e:
        # Keep the row when delivery failed.
        print(f"❌ Webhook error for ID {row_id}: {e}")
async def process_queue():
    """Run forever, handling queued requests strictly one at a time.

    Each iteration takes the oldest pending row, marks it ``'processing'``,
    computes the embedding in a worker thread, stores it, and then either
    sends it to the URL in the ``WEBHOOK_URL`` environment variable or, when
    that variable is unset or empty, deletes the row right away. A failure
    while handling one request marks that row ``'failed'``. The loop sleeps
    2 seconds when the queue is empty and 5 seconds after an unexpected error.
    """
    # NOTE(review): nothing visible here moves rows left in 'processing'
    # (e.g. after a crash) back to 'pending' — confirm whether that is handled
    # elsewhere.
    global is_processing
    print("🚀 Queue processor started")

    while True:
        try:
            job = get_next_pending_request()
            if not job:
                # No pending requests
                await asyncio.sleep(2)
                continue

            row_id, request_id, text = job
            is_processing = True
            update_request_processing(row_id)
            print(f"⚙️ Processing request ID: {row_id}")

            try:
                # Encoding is CPU-heavy; run it off the event loop.
                embedding = await asyncio.to_thread(chunk_and_embed, text)
                update_embedding_in_db(row_id, embedding)

                webhook_url = os.environ.get("WEBHOOK_URL")
                if not webhook_url:
                    # No webhook, just delete
                    mark_webhook_sent_and_delete(row_id)
                else:
                    # Send to webhook with ALL data
                    await send_to_webhook(webhook_url, row_id, request_id, text, embedding)
            except Exception as e:
                print(f"❌ Error processing {row_id}: {e}")
                mark_request_failed(row_id, str(e))

            is_processing = False
        except Exception as e:
            print(f"❌ Queue error: {e}")
            is_processing = False
            await asyncio.sleep(5)
@app.get("/")
def home():
    """Liveness endpoint: report the model id and whether the worker is busy."""
    summary = dict(
        status="online",
        model=model_id,
        processing=is_processing,
    )
    return summary
@app.post("/embed/text")
async def embed_text(request: TextRequest):
    """Queue a text for embedding and return immediately.

    The request is only written to the database here; the embedding itself
    is computed later by the background queue processor.

    Args:
        request: Validated request body with ``text`` and optional ``request_id``.

    Returns:
        A dict with ``success``, ``message``, the row's ``db_id`` and the
        initial ``status`` (``"pending"``).

    Raises:
        HTTPException: 500 with the error text if the request cannot be saved.
    """
    try:
        # Save to DB immediately
        db_row_id = save_request_to_db(request.text, request.request_id)
    except Exception as e:
        # Chain the cause so the original traceback is kept in server logs.
        raise HTTPException(status_code=500, detail=str(e)) from e

    # Return immediately
    return {
        "success": True,
        "message": "Request queued successfully",
        "db_id": db_row_id,
        "status": "pending"
    }
@app.get("/status")
def get_status():
    """Return queue statistics.

    Returns:
        A dict with:
        ``queue`` — row counts for the ``pending``, ``processing``,
        ``completed`` and ``failed`` statuses (0 when a status has no rows);
        ``is_processing`` — the worker's busy flag;
        ``next_request`` — ``{"id", "created_at"}`` of the oldest pending row,
        or ``None`` when nothing is pending.

    Raises:
        sqlite3.Error: if a query fails.
    """
    conn = sqlite3.connect(DB_PATH)
    try:
        cursor = conn.cursor()
        # One grouped query instead of one COUNT(*) per status.
        cursor.execute(
            'SELECT status, COUNT(*) FROM embedding_requests GROUP BY status'
        )
        counts = dict(cursor.fetchall())

        # Get next in queue. String literals use single quotes: in SQL,
        # double quotes denote identifiers and SQLite only accepts them as
        # strings through a legacy fallback that can be disabled.
        cursor.execute('''
            SELECT id, created_at
            FROM embedding_requests
            WHERE status = 'pending'
            ORDER BY id ASC
            LIMIT 1
        ''')
        next_request = cursor.fetchone()
    finally:
        # Close the connection even when a query raises.
        conn.close()

    return {
        "queue": {
            "pending": counts.get("pending", 0),
            "processing": counts.get("processing", 0),
            "completed": counts.get("completed", 0),
            "failed": counts.get("failed", 0)
        },
        "is_processing": is_processing,
        "next_request": {
            "id": next_request[0],
            "created_at": next_request[1]
        } if next_request else None
    }
@app.get("/request/{db_id}")
def get_request_info(db_id: int):
    """Return the status of one queued request.

    Args:
        db_id: Primary key (``id``) of the request, as returned when it was queued.

    Returns:
        A dict with ``db_id``, ``request_id``, ``status``, ``created_at``,
        ``processed_at``, ``webhook_sent`` (as a bool) and ``error_message``.

    Raises:
        HTTPException: 404 when no row with that id exists (rows are removed
            once they have been delivered).
        sqlite3.Error: if the query fails.
    """
    conn = sqlite3.connect(DB_PATH)
    try:
        cursor = conn.cursor()
        cursor.execute('''
            SELECT id, request_id, status, created_at, processed_at, webhook_sent, error_message
            FROM embedding_requests
            WHERE id = ?
        ''', (db_id,))
        result = cursor.fetchone()
    finally:
        # Close the connection even when the query raises.
        conn.close()

    if not result:
        raise HTTPException(status_code=404, detail="Request not found or already deleted")

    return {
        "db_id": result[0],
        "request_id": result[1],
        "status": result[2],
        "created_at": result[3],
        "processed_at": result[4],
        # Stored as 0/1 in SQLite; expose a real boolean.
        "webhook_sent": bool(result[5]),
        "error_message": result[6]
    }
if __name__ == "__main__":
    # Script entry point: serve the app on all interfaces, port 7860.
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=7860,
    )