import os
import time
from typing import List

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from mistralai.client import MistralClient
from mistralai.models.chat_completion import ChatMessage
from mistralai.exceptions import MistralAPIException
from pinecone import Pinecone, ServerlessSpec

# Environment variables
MISTRAL_API_KEY = os.getenv("MISTRAL_API_KEY")
PINECONE_API_KEY = os.getenv("PINECONE_API_KEY")
INDEX_NAME = os.getenv("PINECONE_INDEX_NAME", "mistral-rag")

if not MISTRAL_API_KEY or not PINECONE_API_KEY:
    raise RuntimeError("Environment variables MISTRAL_API_KEY and PINECONE_API_KEY must be set.")

# Initialize clients
mistral = MistralClient(api_key=MISTRAL_API_KEY)
pc = Pinecone(api_key=PINECONE_API_KEY)


# Prepare or connect to the Pinecone index
def init_index(dimensions: int):
    spec = ServerlessSpec(cloud="aws", region="us-east-1")
    existing = [idx["name"] for idx in pc.list_indexes()]
    if INDEX_NAME not in existing:
        pc.create_index(
            INDEX_NAME,
            dimension=dimensions,
            metric="dotproduct",
            spec=spec,
        )
        # Wait until the index is ready before returning a handle to it
        while not pc.describe_index(INDEX_NAME).status["ready"]:
            time.sleep(1)
    return pc.Index(INDEX_NAME)


# Embedding model
EMBED_MODEL = "mistral-embed"

# Index client (populated on startup)
index = None

app = FastAPI(title="Mistral RAG API")


def chunk_to_upsert(metadata: List[dict]):
    """Batch-embed metadata items and upsert them into Pinecone.

    Each metadata item must include 'id', 'title', and 'content'.
    """
    # Adaptive batching: start with one large batch and halve the batch size
    # whenever the Mistral API rejects the request (e.g. payload too large).
    batch_size = len(metadata)
    last_exc = None
    while batch_size >= 1:
        try:
            embeds = []
            for i in range(0, len(metadata), batch_size):
                batch = metadata[i : i + batch_size]
                texts = [f"{m['title']}\n{m['content']}" for m in batch]
                resp = mistral.embeddings(input=texts, model=EMBED_MODEL)
                embeds.extend([d.embedding for d in resp.data])
            # Build (id, vector, metadata) tuples for the upsert
            to_upsert = [
                (m["id"], embeds[idx], {"title": m["title"], "content": m["content"]})
                for idx, m in enumerate(metadata)
            ]
            index.upsert(vectors=to_upsert)
            return
        except MistralAPIException as e:
            last_exc = e
            if batch_size == 1:
                # Already at the smallest batch size; give up and re-raise
                break
            batch_size = max(1, batch_size // 2)
    raise last_exc


# Pydantic models
class Chunk(BaseModel):
    id: str
    title: str
    content: str


class IngestRequest(BaseModel):
    chunks: List[Chunk]


class QueryRequest(BaseModel):
    query: str
    top_k: int = 5


@app.on_event("startup")
async def startup_event():
    global index
    # Determine the embedding dimension from a sample request
    sample_embed = mistral.embeddings(input=["test"], model=EMBED_MODEL)
    dims = len(sample_embed.data[0].embedding)
    index = init_index(dimensions=dims)


@app.post("/ingest")
async def ingest(req: IngestRequest):
    if not req.chunks:
        raise HTTPException(status_code=400, detail="No chunks provided to ingest.")
    metadata = [chunk.dict() for chunk in req.chunks]
    try:
        chunk_to_upsert(metadata)
        return {"status": "success", "ingested": len(req.chunks)}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))


@app.post("/query")
async def query(req: QueryRequest):
    # Embed the query
    embed_resp = mistral.embeddings(input=[req.query], model=EMBED_MODEL)
    xq = embed_resp.data[0].embedding
    # Query Pinecone
    res = index.query(vector=xq, top_k=req.top_k, include_metadata=True)
    # Retrieve the matched documents
    docs = [match["metadata"]["content"] for match in res["matches"]]
    # Construct the system message with the retrieved context
    system_content = (
        "You are a helpful assistant that answers questions about AI "
        "using the context below.\n\nCONTEXT:\n" + "\n---\n".join(docs)
    )
    messages = [
        ChatMessage(role="system", content=system_content),
        ChatMessage(role="user", content=req.query),
    ]
    chat_resp = mistral.chat(model="mistral-large-latest", messages=messages)
    answer = chat_resp.choices[0].message.content
    return {"query": req.query, "answers": answer, "docs": docs}


@app.get("/health")
async def health():
    return {"status": "ok"}


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)
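# Example usage: a minimal sketch of how the two endpoints could be exercised once the
# server is running. It assumes the service is reachable at http://localhost:8000 and
# that MISTRAL_API_KEY / PINECONE_API_KEY are exported; the ids, titles, and contents
# in the payloads are illustrative only.
#
#   curl -X POST http://localhost:8000/ingest \
#     -H "Content-Type: application/json" \
#     -d '{"chunks": [{"id": "doc-1", "title": "Attention", "content": "Attention lets a model weigh tokens by relevance."}]}'
#
#   curl -X POST http://localhost:8000/query \
#     -H "Content-Type: application/json" \
#     -d '{"query": "What does attention do?", "top_k": 3}'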