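"""Minimal RAG service built on FastAPI, Mistral, and Pinecone.

Endpoints:
    POST /ingest  - embed text chunks with mistral-embed and upsert into Pinecone
    POST /query   - retrieve top-k chunks and answer with mistral-large-latest
    GET  /health  - liveness probe
"""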
import os
import time
from typing import List

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from mistralai.client import MistralClient
from mistralai.models.chat_completion import ChatMessage
from mistralai.exceptions import MistralAPIException
from pinecone import Pinecone, ServerlessSpec

# Environment variables
MISTRAL_API_KEY = os.getenv("MISTRAL_API_KEY")
PINECONE_API_KEY = os.getenv("PINECONE_API_KEY")
INDEX_NAME = os.getenv("PINECONE_INDEX_NAME", "mistral-rag")

if not MISTRAL_API_KEY or not PINECONE_API_KEY:
    raise RuntimeError("Environment variables MISTRAL_API_KEY and PINECONE_API_KEY must be set.")

# Initialize clients
mistral = MistralClient(api_key=MISTRAL_API_KEY)
pc = Pinecone(api_key=PINECONE_API_KEY)

# Prepare or connect to Pinecone index
def init_index(dimensions: int):
    spec = ServerlessSpec(cloud="aws", region="us-east-1")
    existing = [idx["name"] for idx in pc.list_indexes()]
    if INDEX_NAME not in existing:
        pc.create_index(
            INDEX_NAME,
            dimension=dimensions,
            metric="dotproduct",
            spec=spec
        )
        # wait until ready
        while not pc.describe_index(INDEX_NAME).status["ready"]:
            time.sleep(1)
    return pc.Index(INDEX_NAME)
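
# Note: metric="dotproduct" above assumes mistral-embed returns (near) unit-length
# vectors, in which case dot product behaves like cosine similarity. If you swap in
# an embedding model that is not normalized, prefer metric="cosine".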

# Embedding model
EMBED_MODEL = "mistral-embed"

# Dynamic index client (populated on startup)
index = None

app = FastAPI(title="Mistral RAG API")

# Ingestion helper
def chunk_to_upsert(metadata: List[dict]):
    """Helper to batch-embed and upsert metadata items into Pinecone."""
    # Perform adaptive batching
    batch_size = len(metadata)
    last_exc = None
    while batch_size >= 1:
        try:
            embeds = []
            for i in range(0, len(metadata), batch_size):
                batch = metadata[i : i + batch_size]
                texts = [f"{m['title']}\n{m['content']}" for m in batch]
                resp = mistral.embeddings(input=texts, model=EMBED_MODEL)
                embeds.extend([d.embedding for d in resp.data])
            # Prepare upsert list
            # Each metadata item must include 'id', 'title', 'content'
            to_upsert = [
                (m["id"], embeds[idx], {"title": m["title"], "content": m["content"]})
                for idx, m in enumerate(metadata)
            ]
            index.upsert(vectors=to_upsert)
            return
        except MistralAPIException as e:
            # Back off: halve the batch size and retry; give up once
            # single-item batches also fail (avoids an infinite loop).
            last_exc = e
            if batch_size == 1:
                break
            batch_size //= 2
    raise last_exc
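
# Sketch of calling the helper directly (hypothetical data; normally this is
# reached via POST /ingest with the same three fields per chunk):
#   chunk_to_upsert([
#       {"id": "doc-1", "title": "RAG", "content": "Retrieval-augmented generation ..."},
#       {"id": "doc-2", "title": "Embeddings", "content": "Dense vector representations ..."},
#   ])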

# Pydantic request models
class Chunk(BaseModel):
    id: str
    title: str
    content: str

class IngestRequest(BaseModel):
    chunks: List[Chunk]

class QueryRequest(BaseModel):
    query: str
    top_k: int = 5

@app.on_event("startup")
async def startup_event():
    global index
    # Sanity check embed dimension from a sample
    sample_embed = mistral.embeddings(input=["test"], model=EMBED_MODEL)
    dims = len(sample_embed.data[0].embedding)
    index = init_index(dimensions=dims)

@app.post("/ingest")
async def ingest(req: IngestRequest):
    if not req.chunks:
        raise HTTPException(status_code=400, detail="No chunks provided to ingest.")
    metadata = [chunk.dict() for chunk in req.chunks]
    try:
        chunk_to_upsert(metadata)
        return {"status": "success", "ingested": len(req.chunks)}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

@app.post("/query")
async def query(req: QueryRequest):
    # Embed the query
    embed_resp = mistral.embeddings(input=[req.query], model=EMBED_MODEL)
    xq = embed_resp.data[0].embedding
    # Query Pinecone
    res = index.query(vector=xq, top_k=req.top_k, include_metadata=True)
    # Retrieve docs
    docs = [match['metadata']['content'] for match in res['matches']]
    # Construct system message
    system_content = "You are a helpful assistant that answers questions about AI using the context below.\n\nCONTEXT:\n" + "\n---\n".join(docs)
    messages = [
        ChatMessage(role="system", content=system_content),
        ChatMessage(role="user", content=req.query)
    ]
    chat_resp = mistral.chat(model="mistral-large-latest", messages=messages)
    answer = chat_resp.choices[0].message.content
    return {"query": req.query, "answers": answer, "docs": docs}

@app.get("/health")
async def health():
    return {"status": "ok"}

if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)
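
# Example requests once the server is running (hypothetical payloads):
#   curl -X POST localhost:8000/ingest -H 'Content-Type: application/json' \
#        -d '{"chunks": [{"id": "1", "title": "RAG", "content": "Retrieval-augmented generation pairs search with an LLM."}]}'
#   curl -X POST localhost:8000/query -H 'Content-Type: application/json' \
#        -d '{"query": "What is RAG?", "top_k": 3}'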