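"""FastAPI service implementing a small RAG pipeline: chunks are embedded with
Mistral's mistral-embed model, stored in a Pinecone serverless index, and
retrieved at query time as context for a mistral-large chat completion.

Endpoints: POST /ingest, POST /query, GET /health.

Note: the imports below appear to target the legacy mistralai 0.x client and
the pinecone (v3+) SDK; pin your dependencies accordingly.
"""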
import os
import time
from typing import List
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from mistralai.client import MistralClient
from mistralai.models.chat_completion import ChatMessage
from mistralai.exceptions import MistralAPIException
from pinecone import Pinecone, ServerlessSpec
# Environment variables
MISTRAL_API_KEY = os.getenv("MISTRAL_API_KEY")
PINECONE_API_KEY = os.getenv("PINECONE_API_KEY")
INDEX_NAME = os.getenv("PINECONE_INDEX_NAME", "mistral-rag")
if not MISTRAL_API_KEY or not PINECONE_API_KEY:
    raise RuntimeError("Environment variables MISTRAL_API_KEY and PINECONE_API_KEY must be set.")
# Initialize clients
mistral = MistralClient(api_key=MISTRAL_API_KEY)
pc = Pinecone(api_key=PINECONE_API_KEY)
# Prepare or connect to Pinecone index
def init_index(dimensions: int):
    spec = ServerlessSpec(cloud="aws", region="us-east-1")
    existing = [idx["name"] for idx in pc.list_indexes()]
    if INDEX_NAME not in existing:
        pc.create_index(
            INDEX_NAME,
            dimension=dimensions,
            metric="dotproduct",
            spec=spec,
        )
        # Wait until the new index is ready before returning a handle to it
        while not pc.describe_index(INDEX_NAME).status["ready"]:
            time.sleep(1)
    return pc.Index(INDEX_NAME)
# Embedding model
EMBED_MODEL = "mistral-embed"
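# At the time of writing, mistral-embed returns 1024-dimensional vectors; the
# startup hook below measures the dimension from a sample embedding instead of
# hard-coding it, so a change in the model would be picked up automatically.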
# Dynamic index client (populated on startup)
index = None
app = FastAPI(title="Mistral RAG API")
# Helper: batch-embed chunks and upsert them into Pinecone
def chunk_to_upsert(metadata: List[dict]):
    """Batch-embed metadata items and upsert them into Pinecone.

    Each metadata item must include 'id', 'title', and 'content'. If the
    embeddings API rejects a batch, the batch size is halved and retried.
    """
    batch_size = len(metadata)
    last_exc = None
    while batch_size >= 1:
        try:
            embeds = []
            for i in range(0, len(metadata), batch_size):
                batch = metadata[i : i + batch_size]
                texts = [f"{m['title']}\n{m['content']}" for m in batch]
                resp = mistral.embeddings(input=texts, model=EMBED_MODEL)
                embeds.extend([d.embedding for d in resp.data])
            # Pair each metadata item with its embedding for the upsert
            to_upsert = [
                (m["id"], embeds[idx], {"title": m["title"], "content": m["content"]})
                for idx, m in enumerate(metadata)
            ]
            index.upsert(vectors=to_upsert)
            return
        except MistralAPIException as e:
            last_exc = e
            if batch_size == 1:
                break  # already at the smallest batch size; re-raise below
            batch_size //= 2
    raise last_exc
# Pydantic request models
class Chunk(BaseModel):
    id: str
    title: str
    content: str

class IngestRequest(BaseModel):
    chunks: List[Chunk]

class QueryRequest(BaseModel):
    query: str
    top_k: int = 5
@app.on_event("startup")
async def startup_event():
    global index
    # Determine the embedding dimension from a sample embedding
    sample_embed = mistral.embeddings(input=["test"], model=EMBED_MODEL)
    dims = len(sample_embed.data[0].embedding)
    index = init_index(dimensions=dims)
@app.post("/ingest")
async def ingest(req: IngestRequest):
    if not req.chunks:
        raise HTTPException(status_code=400, detail="No chunks provided to ingest.")
    metadata = [chunk.dict() for chunk in req.chunks]
    try:
        chunk_to_upsert(metadata)
        return {"status": "success", "ingested": len(req.chunks)}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
@app.post("/query")
async def query(req: QueryRequest):
    # Embed the query
    embed_resp = mistral.embeddings(input=[req.query], model=EMBED_MODEL)
    xq = embed_resp.data[0].embedding
    # Query Pinecone for the most similar chunks
    res = index.query(vector=xq, top_k=req.top_k, include_metadata=True)
    docs = [match["metadata"]["content"] for match in res["matches"]]
    # Build a system message that grounds the model in the retrieved context
    system_content = (
        "You are a helpful assistant that answers questions about AI "
        "using the context below.\n\nCONTEXT:\n" + "\n---\n".join(docs)
    )
    messages = [
        ChatMessage(role="system", content=system_content),
        ChatMessage(role="user", content=req.query),
    ]
    chat_resp = mistral.chat(model="mistral-large-latest", messages=messages)
    answer = chat_resp.choices[0].message.content
    return {"query": req.query, "answers": answer, "docs": docs}
@app.get("/health")
async def health():
    return {"status": "ok"}

if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)
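
# Illustrative requests (hypothetical payloads; adjust host, port, and data):
#
#   curl -X POST http://localhost:8000/ingest \
#        -H "Content-Type: application/json" \
#        -d '{"chunks": [{"id": "doc-1", "title": "Attention", "content": "Attention is all you need..."}]}'
#
#   curl -X POST http://localhost:8000/query \
#        -H "Content-Type: application/json" \
#        -d '{"query": "What is attention?", "top_k": 3}'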