from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import os
from langchain.chains import RetrievalQA
from langchain.llms import HuggingFaceHub
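# NOTE: this relative import requires the app to be loaded as a package module
# (e.g. `uvicorn app.main:app`); running the file directly with `python app.py`
# would raise an "attempted relative import" error.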
from ..utils.vector_store import get_vector_store
app = FastAPI()
class QueryRequest(BaseModel):
query: str
# Option 1: Hugging Face Inference API (Free with rate limits)
def setup_hf_api_model():
"""Uses Hugging Face's free Inference API"""
# Get free token from https://huggingface.co/settings/tokens
# Set this in your deployment
hf_token = os.getenv("HUGGINGFACE_API_TOKEN")
if not hf_token:
raise ValueError("HUGGINGFACE_API_TOKEN environment variable required")
llm = HuggingFaceHub(
repo_id="microsoft/DialoGPT-medium", # Free model
model_kwargs={
"temperature": 0.1,
"max_length": 512
},
huggingfacehub_api_token=hf_token
)
return llm
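# NOTE (assumption): on newer LangChain releases, HuggingFaceHub is deprecated.
# A roughly equivalent sketch, assuming `langchain-community` is installed:
#
#   from langchain_community.llms import HuggingFaceEndpoint
#   llm = HuggingFaceEndpoint(
#       repo_id="microsoft/DialoGPT-medium",
#       temperature=0.1,
#       max_new_tokens=512,
#       huggingfacehub_api_token=hf_token,
#   )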
# Option 2: Cohere API (Free tier: 100 API calls/month)
def setup_cohere_model():
"""Uses Cohere's free tier"""
from langchain.llms import Cohere
cohere_api_key = os.getenv("COHERE_API_KEY")
if not cohere_api_key:
raise ValueError("COHERE_API_KEY required")
llm = Cohere(
cohere_api_key=cohere_api_key,
model="command-light", # Free tier model
temperature=0.1
)
return llm
# Option 3: Together AI (Free credits)
def setup_together_model():
"""Uses Together AI's free credits"""
from langchain.llms import Together
together_api_key = os.getenv("TOGETHER_API_KEY")
if not together_api_key:
raise ValueError("TOGETHER_API_KEY required")
llm = Together(
together_api_key=together_api_key,
model="meta-llama/Llama-2-7b-chat-hf",
temperature=0.1
)
return llm
# Initialize model (try different options in order of preference)
llm = None
model_used = "none"
try:
    llm = setup_hf_api_model()
    model_used = "huggingface"
    print("✅ Using Hugging Face Inference API")
except Exception:
    try:
        llm = setup_cohere_model()
        model_used = "cohere"
        print("✅ Using Cohere API")
    except Exception:
        try:
            llm = setup_together_model()
            model_used = "together"
            print("✅ Using Together AI")
        except Exception as e:
            print(f"❌ Failed to initialize any model: {e}")
# Setup QA chain
qa_chain = None
if llm:
try:
retriever = get_vector_store().as_retriever()
qa_chain = RetrievalQA.from_chain_type(
llm=llm,
retriever=retriever,
return_source_documents=True
)
print("β
QA chain ready")
except Exception as e:
print(f"β QA chain failed: {e}")
@app.get("/")
def root():
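    # Lightweight status probe; returns e.g.
    # {"status": "running", "model": "huggingface", "qa_ready": true}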
return {
"status": "running",
"model": model_used,
"qa_ready": qa_chain is not None
}
@app.post("/ask")
def ask_question(request: QueryRequest):
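    # Expects a JSON body like {"query": "..."} and returns the answer plus
    # up to three truncated source snippets.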
    if qa_chain is None:
        raise HTTPException(status_code=503, detail="Service not ready")
try:
result = qa_chain({"query": request.query})
return {
"answer": result["result"],
"model_used": model_used,
"sources": [
{
"content": doc.page_content[:200] + "...",
"metadata": doc.metadata
}
for doc in result["source_documents"][:3]
]
}
except Exception as e:
raise HTTPException(status_code=500, detail=f"Error: {str(e)}")
@app.post("/ask/{store_type}")
def ask_specific_store(store_type: str, request: QueryRequest):
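    # Same request body as /ask, but queries one named vector store
    # ("mes", "general", or "tech") instead of the default.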
    if llm is None:
        raise HTTPException(status_code=503, detail="LLM not available")
store_paths = {
"mes": "./vector_stores/mes_db",
"general": "./vector_stores/general_db",
"tech": "./vector_stores/tech_db"
}
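    # NOTE: these persist directories are assumed to exist on disk (built by a
    # separate ingestion step); adjust the paths to match your deployment.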
if store_type not in store_paths:
raise HTTPException(status_code=400, detail="Invalid store type")
try:
vector_store = get_vector_store(
persist_directory=store_paths[store_type])
retriever = vector_store.as_retriever()
store_qa_chain = RetrievalQA.from_chain_type(
llm=llm,
retriever=retriever,
return_source_documents=True
)
result = store_qa_chain({"query": request.query})
return {
"answer": result["result"],
"store_used": store_type,
"model_used": model_used,
"sources": [doc.page_content[:200] + "..." for doc in result["source_documents"][:3]]
}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
if __name__ == "__main__":
import uvicorn
port = int(os.environ.get("PORT", 8000))
uvicorn.run(app, host="0.0.0.0", port=port)
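# Example client usage (a minimal sketch; assumes the service is reachable on
# localhost:8000 — the query strings below are illustrative):
#
#   import requests
#   resp = requests.post(
#       "http://localhost:8000/ask",
#       json={"query": "What does the MES module do?"},
#   )
#   print(resp.json()["answer"])
#
#   # Target a specific vector store:
#   resp = requests.post(
#       "http://localhost:8000/ask/mes",
#       json={"query": "How are work orders tracked?"},
#   )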