import os
import time
from typing import List

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from mistralai.client import MistralClient
from mistralai.models.chat_completion import ChatMessage
from mistralai.exceptions import MistralAPIException
from pinecone import Pinecone, ServerlessSpec
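
# NOTE: this targets the pre-1.0 `mistralai` SDK (MistralClient, ChatMessage,
# MistralAPIException); the 1.x SDK reorganized these entry points, so pin
# mistralai<1.0 if reusing this script.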

MISTRAL_API_KEY = os.getenv("MISTRAL_API_KEY")
PINECONE_API_KEY = os.getenv("PINECONE_API_KEY")
INDEX_NAME = os.getenv("PINECONE_INDEX_NAME", "mistral-rag")

if not MISTRAL_API_KEY or not PINECONE_API_KEY:
    raise RuntimeError("Environment variables MISTRAL_API_KEY and PINECONE_API_KEY must be set.")
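# Set these in the shell before launching, e.g. (illustrative values):
#   export MISTRAL_API_KEY="..."
#   export PINECONE_API_KEY="..."
#   export PINECONE_INDEX_NAME="mistral-rag"   # optional; shown value is the default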

# Initialize shared API clients once at module load.
mistral = MistralClient(api_key=MISTRAL_API_KEY)
pc = Pinecone(api_key=PINECONE_API_KEY)


def init_index(dimensions: int):
    """Create the Pinecone index if it does not already exist, then return a handle to it."""
    spec = ServerlessSpec(cloud="aws", region="us-east-1")
    existing = [idx["name"] for idx in pc.list_indexes()]
    if INDEX_NAME not in existing:
        pc.create_index(
            INDEX_NAME,
            dimension=dimensions,
            metric="dotproduct",
            spec=spec,
        )
    # Block until the index is ready to serve requests.
    while not pc.describe_index(INDEX_NAME).status["ready"]:
        time.sleep(1)
    return pc.Index(INDEX_NAME)


EMBED_MODEL = "mistral-embed"

# Set at startup, once the embedding dimension is known.
index = None

app = FastAPI(title="Mistral RAG API")


def chunk_to_upsert(metadata: List[dict]):
    """Batch-embed metadata items and upsert them into Pinecone.

    Starts with a single batch holding everything and halves the batch size
    whenever the embeddings API rejects a request (e.g. payload too large),
    down to a minimum of one item per batch.
    """
    if not metadata:
        return
    batch_size = len(metadata)
    last_exc = None
    while batch_size >= 1:
        try:
            embeds = []
            for i in range(0, len(metadata), batch_size):
                batch = metadata[i : i + batch_size]
                texts = [f"{m['title']}\n{m['content']}" for m in batch]
                resp = mistral.embeddings(input=texts, model=EMBED_MODEL)
                embeds.extend([d.embedding for d in resp.data])

            to_upsert = [
                (m["id"], embeds[idx], {"title": m["title"], "content": m["content"]})
                for idx, m in enumerate(metadata)
            ]
            index.upsert(vectors=to_upsert)
            return
        except MistralAPIException as e:
            last_exc = e
            if batch_size == 1:
                # Already at the smallest possible batch; stop retrying.
                break
            batch_size = max(1, batch_size // 2)
    raise last_exc
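
# Illustrative direct call (hypothetical data; assumes the index is initialized):
#   chunk_to_upsert([{"id": "doc-1", "title": "RAG", "content": "Retrieval-augmented generation ..."}])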


class Chunk(BaseModel):
    id: str
    title: str
    content: str


class IngestRequest(BaseModel):
    chunks: List[Chunk]


class QueryRequest(BaseModel):
    query: str
    top_k: int = 5


@app.on_event("startup")
async def startup_event():
    global index
    # Embed a probe string to discover the embedding dimension (currently
    # 1024 for mistral-embed), then create/attach an index to match.
    sample_embed = mistral.embeddings(input=["test"], model=EMBED_MODEL)
    dims = len(sample_embed.data[0].embedding)
    index = init_index(dimensions=dims)


@app.post("/ingest")
async def ingest(req: IngestRequest):
    if not req.chunks:
        raise HTTPException(status_code=400, detail="No chunks provided to ingest.")
    metadata = [chunk.dict() for chunk in req.chunks]
    try:
        chunk_to_upsert(metadata)
        return {"status": "success", "ingested": len(req.chunks)}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))


@app.post("/query")
async def query(req: QueryRequest):
    # Embed the query and retrieve the most similar chunks from Pinecone.
    embed_resp = mistral.embeddings(input=[req.query], model=EMBED_MODEL)
    xq = embed_resp.data[0].embedding

    res = index.query(vector=xq, top_k=req.top_k, include_metadata=True)
    docs = [match["metadata"]["content"] for match in res["matches"]]

    # Ground the chat model in the retrieved context.
    system_content = (
        "You are a helpful assistant that answers questions about AI "
        "using the context below.\n\nCONTEXT:\n" + "\n---\n".join(docs)
    )
    messages = [
        ChatMessage(role="system", content=system_content),
        ChatMessage(role="user", content=req.query),
    ]
    chat_resp = mistral.chat(model="mistral-large-latest", messages=messages)
    answer = chat_resp.choices[0].message.content
    return {"query": req.query, "answer": answer, "docs": docs}


@app.get("/health")
async def health():
    return {"status": "ok"}


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)
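
# Example requests once the server is running (illustrative values):
#   curl -s -X POST localhost:8000/ingest \
#        -H "Content-Type: application/json" \
#        -d '{"chunks": [{"id": "1", "title": "RAG", "content": "Retrieval-augmented generation ..."}]}'
#   curl -s -X POST localhost:8000/query \
#        -H "Content-Type: application/json" \
#        -d '{"query": "What is retrieval-augmented generation?", "top_k": 3}'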