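# FastAPI service for local retrieval-augmented question answering: queries
# are embedded with a sentence-transformers model, matching chunks are
# retrieved from a local Qdrant store, and an extractive DistilBERT QA
# pipeline produces the answer.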
from fastapi import FastAPI
from fastapi.staticfiles import StaticFiles
from langchain_huggingface import HuggingFaceEmbeddings
from transformers import AutoModelForQuestionAnswering, AutoTokenizer, pipeline
from langchain_qdrant import QdrantVectorStore
from qdrant_client import QdrantClient
from qdrant_client.models import Distance
from langchain.schema import Document
from pydantic import BaseModel
import os
import time
import torch

# Globals populated once on startup and shared across requests.
model = None
tokenizer = None
embed_model = None
qdrant = None
qa_pipeline = None

class Item(BaseModel):
    query: str


app = FastAPI()

# Ensure the mounted and cache directories exist before they are used;
# StaticFiles checks the directory when it is instantiated.
os.makedirs("./TestFolder", exist_ok=True)
os.makedirs("./cache", exist_ok=True)
os.makedirs("./offload", exist_ok=True)
os.makedirs("./models", exist_ok=True)

app.mount("/TestFolder", StaticFiles(directory="./TestFolder"), name="TestFolder")

@app.on_event("startup") |
|
async def startup_event(): |
|
global model, tokenizer, dolly_pipeline_hf, embed_model, qdrant, model_name_hf, text_generation_pipeline, qa_pipeline |
|
|
|
print("π Loading model....") |
|
|
|
sentence_embedding_model_path = "sentence-transformers/paraphrase-MiniLM-L6-v2" |
|
start_time = time.perf_counter() |
|
|
|
embed_model = HuggingFaceEmbeddings( |
|
model_name=sentence_embedding_model_path, |
|
cache_folder="./models", |
|
model_kwargs={"device": "cpu"}, |
|
encode_kwargs={"normalize_embeddings": True}, |
|
) |
|
|
|
try: |
|
qdrant_client = QdrantClient(path="qdrant/") |
|
qdrant = QdrantVectorStore(qdrant_client, "MyCollection", embed_model, distance="Dot") |
|
except Exception as e: |
|
print(f"β Error initializing Qdrant: {e}") |
|
|
|
model_path = "distilbert-base-cased-distilled-squad" |
|
model = AutoModelForQuestionAnswering.from_pretrained(model_path, cache_dir="./models") |
|
tokenizer = AutoTokenizer.from_pretrained(model_path, cache_dir="./models") |
|
qa_pipeline = pipeline( |
|
"question-answering", |
|
model=model, |
|
tokenizer=tokenizer, |
|
device=0 if torch.cuda.is_available() else -1 |
|
) |
|
|
|
end_time = time.perf_counter() |
|
print(f"β
Dolly model loaded successfully in {end_time - start_time:.2f} seconds.") |
|
|
|
app.on_event("shutdown") |
|
async def shutdown_event(): |
|
global model, tokenizer, dolly_pipeline_hf |
|
print("πͺ Shutting down the API and releasing model memory.") |
|
del model, tokenizer, dolly_pipeline_hf, embed_model, qdrant, model_name_hf, text_generation_pipeline, qa_pipeline |
|
|
|
|
|
@app.get("/") |
|
def read_root(): |
|
return {"message": "Welcome to FastAPI"} |
|
|
|
@app.post("/search") |
|
def search(Item:Item): |
|
print("Search endpoint") |
|
query = Item.query |
|
|
|
search_result = qdrant.similarity_search( |
|
query=query, k=10 |
|
) |
|
i = 0 |
|
list_res = [] |
|
for res in search_result: |
|
list_res.append({"id":i,"path":res.metadata.get("path"),"content":res.page_content}) |
|
|
|
|
|
return list_res |
|
|
|
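# A minimal indexing sketch, assuming document chunks are produced elsewhere in
# the project; the function name and argument shape are illustrative, not part
# of the original code. It shows how chunks carrying the "path" metadata field
# (read back by /search) could be added to the vector store.
def index_documents(chunks):
    """chunks: iterable of (path, text) pairs."""
    docs = [
        Document(page_content=text, metadata={"path": path})
        for path, text in chunks
    ]
    qdrant.add_documents(docs)
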
@app.post("/ask_localai") |
|
async def ask_localai(item: Item): |
|
query = item.query |
|
|
|
search_result = qdrant.similarity_search(query=query, k=3) |
|
if not search_result: |
|
return {"error": "No relevant results found for the query."} |
|
|
|
context = " ".join([res.page_content for res in search_result]) |
|
if not context.strip(): |
|
return {"error": "No relevant context found."} |
|
|
|
try: |
|
prompt = ( |
|
f"Context: {context}\n\n" |
|
f"Question: {query}\n" |
|
f"Answer concisely and only based on the context provided. Do not repeat the context or the question.\n" |
|
f"Answer:" |
|
) |
|
qa_result = qa_pipeline(question=query, context=context) |
|
answer = qa_result["answer"] |
|
|
|
return { |
|
"question": query, |
|
"answer": answer |
|
} |
|
except Exception as e: |
|
return {"error": "Failed to generate an answer."} |
|
|
|
@app.get("/items/{item_id}") |
|
def read_item(item_id: int, q: str = None): |
|
return {"item_id": item_id, "q": q} |
|
|
|
@app.post("/items/") |
|
def create_item(item: Item): |
|
return {"item": item, "total_price": item.price + (item.tax or 0)} |
|
|
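
# To run locally (assumption: this file is saved as main.py; adjust the module
# name to match the actual filename):
#   uvicorn main:app --host 0.0.0.0 --port 8000
#
# Example request against the QA endpoint (hypothetical query text):
#   curl -X POST http://localhost:8000/ask_localai \
#        -H "Content-Type: application/json" \
#        -d '{"query": "What does the documentation say about setup?"}'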