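"""FastAPI service exposing a minimal RAG pipeline over Mistral and Pinecone.

- POST /ingest embeds text chunks with the mistral-embed model and upserts
  them into a Pinecone serverless index.
- POST /query embeds the question, retrieves the top-k most similar chunks,
  and asks mistral-large-latest to answer using them as context.
- GET /health is a simple liveness probe.
"""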
import os
import time
from typing import List
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from mistralai.client import MistralClient
from mistralai.models.chat_completion import ChatMessage
from mistralai.exceptions import MistralAPIException
from pinecone import Pinecone, ServerlessSpec
# Environment variables
MISTRAL_API_KEY = os.getenv("MISTRAL_API_KEY")
PINECONE_API_KEY = os.getenv("PINECONE_API_KEY")
INDEX_NAME = os.getenv("PINECONE_INDEX_NAME", "mistral-rag")
if not MISTRAL_API_KEY or not PINECONE_API_KEY:
raise RuntimeError("Environment variables MISTRAL_API_KEY and PINECONE_API_KEY must be set.")
# Initialize clients
mistral = MistralClient(api_key=MISTRAL_API_KEY)
pc = Pinecone(api_key=PINECONE_API_KEY)
# Prepare or connect to Pinecone index
def init_index(dimensions: int):
    """Create the Pinecone serverless index if it does not exist, then return a handle to it."""
    spec = ServerlessSpec(cloud="aws", region="us-east-1")
    existing = [idx["name"] for idx in pc.list_indexes()]
    if INDEX_NAME not in existing:
        pc.create_index(
            INDEX_NAME,
            dimension=dimensions,
            metric="dotproduct",
            spec=spec
        )
        # Wait until the new index reports ready before handing it back
        while not pc.describe_index(INDEX_NAME).status["ready"]:
            time.sleep(1)
    return pc.Index(INDEX_NAME)
# Embedding model
EMBED_MODEL = "mistral-embed"
# Dynamic index client (populated on startup)
index = None
app = FastAPI(title="Mistral RAG API")
# Embedding/upsert helper
def chunk_to_upsert(metadata: List[dict]):
    """Batch-embed metadata items and upsert them into Pinecone.

    Uses adaptive batching: start with one batch covering every item and
    halve the batch size whenever the Mistral API rejects a request, giving
    up once single-item batches still fail.
    """
    if not metadata:
        return
    batch_size = len(metadata)
    last_exc = None
    while batch_size >= 1:
        try:
            embeds = []
            for i in range(0, len(metadata), batch_size):
                batch = metadata[i : i + batch_size]
                texts = [f"{m['title']}\n{m['content']}" for m in batch]
                resp = mistral.embeddings(input=texts, model=EMBED_MODEL)
                embeds.extend([d.embedding for d in resp.data])
            # Each metadata item must include 'id', 'title', and 'content'
            to_upsert = [
                (m["id"], embeds[idx], {"title": m["title"], "content": m["content"]})
                for idx, m in enumerate(metadata)
            ]
            index.upsert(vectors=to_upsert)
            return
        except MistralAPIException as e:
            last_exc = e
            if batch_size == 1:
                break  # smallest possible batch still failed; stop retrying
            batch_size //= 2
    raise last_exc
# Pydantic request models
class Chunk(BaseModel):
id: str
title: str
content: str
class IngestRequest(BaseModel):
chunks: List[Chunk]
class QueryRequest(BaseModel):
query: str
top_k: int = 5
@app.on_event("startup")
async def startup_event():
global index
# Sanity check embed dimension from a sample
sample_embed = mistral.embeddings(input=["test"], model=EMBED_MODEL)
dims = len(sample_embed.data[0].embedding)
index = init_index(dimensions=dims)
@app.post("/ingest")
async def ingest(req: IngestRequest):
if not req.chunks:
raise HTTPException(status_code=400, detail="No chunks provided to ingest.")
metadata = [chunk.dict() for chunk in req.chunks]
try:
chunk_to_upsert(metadata)
return {"status": "success", "ingested": len(req.chunks)}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
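# Illustrative /ingest request. The id/title/content values are made up for
# the example, and the port assumes the default from the __main__ block below:
#
#   curl -X POST http://localhost:8000/ingest \
#     -H "Content-Type: application/json" \
#     -d '{"chunks": [{"id": "doc-1", "title": "RAG", "content": "Retrieval-augmented generation combines search with an LLM."}]}'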
@app.post("/query")
async def query(req: QueryRequest):
# Embed the query
embed_resp = mistral.embeddings(input=[req.query], model=EMBED_MODEL)
xq = embed_resp.data[0].embedding
# Query Pinecone
res = index.query(vector=xq, top_k=req.top_k, include_metadata=True)
# Retrieve docs
docs = [match['metadata']['content'] for match in res['matches']]
# Construct system message
system_content = "You are a helpful assistant that answers questions about AI using the context below.\n\nCONTEXT:\n" + "\n---\n".join(docs)
messages = [
ChatMessage(role="system", content=system_content),
ChatMessage(role="user", content=req.query)
]
chat_resp = mistral.chat(model="mistral-large-latest", messages=messages)
answer = chat_resp.choices[0].message.content
return {"query": req.query, "answers": answer, "docs": docs}
@app.get("/health")
async def health():
return {"status": "ok"}
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=8000)
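# Equivalent CLI launch, assuming this file is saved as main.py:
#   uvicorn main:app --host 0.0.0.0 --port 8000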