# backend/main.py
import os
import json
from dotenv import load_dotenv
from fastapi import FastAPI, Request, Form, Response
from fastapi.responses import HTMLResponse
from fastapi.templating import Jinja2Templates
from fastapi.staticfiles import StaticFiles
from fastapi.encoders import jsonable_encoder
from langchain_community.vectorstores import FAISS
from langchain_community.llms import HuggingFacePipeline # NEW
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.chains import RetrievalQA
from langchain import PromptTemplate
# -------- optional OpenAI imports (kept, but disabled) ----------
# from langchain.llms import OpenAI
# from langchain.embeddings import OpenAIEmbeddings
# ---------------------------------------------------------------
from ingest import Ingest
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
import torch
# ------------------------------------------------------------------
# 1. ENVIRONMENT
# ------------------------------------------------------------------
load_dotenv()
HF_TOKEN = os.getenv("HF_TOKEN")
if HF_TOKEN is None:
    raise ValueError("HF_TOKEN not set in the environment.")
# OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")  # Optional
# if OPENAI_API_KEY is None:
#     print("OpenAI key missing – OpenAI path disabled.")
# ------------------------------------------------------------------
# 2. LLM & EMBEDDINGS CONFIGURATION
# ------------------------------------------------------------------
DEFAULT_LLM = "google/gemma-3-4b-it" # change here if desired
EMB_EN = "sentence-transformers/all-MiniLM-L6-v2"
EMB_CZ = "Seznam/retromae-small-cs"
def build_hf_llm(model_id: str = DEFAULT_LLM) -> HuggingFacePipeline:
    """
    Creates a HuggingFacePipeline wrapped inside LangChain's LLM interface.
    Works on CPU; uses half precision automatically when CUDA is available.
    """
    dtype = torch.float16 if torch.cuda.is_available() else torch.float32
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        token       = HF_TOKEN,
        torch_dtype = dtype,
        device_map  = "auto",
    )
    tokenizer = AutoTokenizer.from_pretrained(model_id, token=HF_TOKEN)
    gen_pipe = pipeline(
        task           = "text-generation",
        model          = model,
        tokenizer      = tokenizer,
        max_new_tokens = 512,
        temperature    = 0.2,
        top_p          = 0.95,
    )
    return HuggingFacePipeline(pipeline=gen_pipe)
HF_LLM = build_hf_llm() # Initialise once; reuse in every request
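# To experiment with another chat model, pass its Hub repo id explicitly,
# e.g. HF_LLM = build_hf_llm("mistralai/Mistral-7B-Instruct-v0.3")
# (an illustrative id, not one shipped with this repo).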
# OPENAI_LLM = OpenAI(openai_api_key=OPENAI_API_KEY, temperature=0) # optional
# ------------------------------------------------------------------
# 3. FASTAPI PLUMBING
# ------------------------------------------------------------------
app = FastAPI()
templates = Jinja2Templates(directory="templates")
app.mount("/static", StaticFiles(directory="static"), name="static")
# Embedding stores
CZECH_STORE = "stores/czech_512"
ENGLISH_STORE = "stores/english_512"
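# Each store directory is expected to hold a FAISS index written by the Ingest
# helper (FAISS.save_local typically produces index.faiss + index.pkl).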
ingestor = Ingest(
    # openai_api_key        = OPENAI_API_KEY,  # still needed only if you ingest via OpenAI embeds
    chunk                   = 512,
    overlap                 = 256,
    czech_store             = CZECH_STORE,
    english_store           = ENGLISH_STORE,
    czech_embedding_model   = EMB_CZ,
    english_embedding_model = EMB_EN,
)
# ------------------------------------------------------------------
# 4. PROMPTS
# ------------------------------------------------------------------
def prompt_en() -> PromptTemplate:
    tmpl = """You are an electrical engineer and you answer users' ###Question.
# Your answer must be helpful, relevant and closely related to the user's ###Question.
# Quote literally from the ###Context wherever possible.
# Use your own words only to connect or clarify. If you don't know, say so.
###Context: {context}
###Question: {question}
Helpful answer:
"""
    return PromptTemplate(template=tmpl, input_variables=["context", "question"])
def prompt_cz() -> PromptTemplate:
    tmpl = """Jste elektroinženýr a odpovídáte na ###Otázku.
# Odpověď musí být užitečná, relevantní a úzce souviset s ###Otázkou.
# Citujte co nejvíce doslovně z ###Kontextu.
# Vlastními slovy pouze propojujte nebo vysvětlujte. Nevíte-li, řekněte to.
###Kontext: {context}
###Otázka: {question}
Užitečná odpověď:
"""
    return PromptTemplate(template=tmpl, input_variables=["context", "question"])
# ------------------------------------------------------------------
# 5. ROUTES
# ------------------------------------------------------------------
@app.get("/", response_class=HTMLResponse)
def home(request: Request):
    return templates.TemplateResponse("index.html", {"request": request})
@app.post("/ingest_data")
async def ingest_data(folderPath: str = Form(...), language: str = Form(...)):
    if language.lower() == "czech":
        ingestor.data_czech = folderPath
        ingestor.ingest_czech()
        return {"message": "Czech data ingestion complete."}
    ingestor.data_english = folderPath
    ingestor.ingest_english()
    return {"message": "English data ingestion complete."}
@app.post("/get_response")
async def get_response(query: str = Form(...), language: str = Form(...)):
    is_czech   = language.lower() == "czech"
    prompt     = prompt_cz() if is_czech else prompt_en()
    store_path = CZECH_STORE if is_czech else ENGLISH_STORE
    embed_name = EMB_CZ if is_czech else EMB_EN
    # Load the embedding model and the matching FAISS store for the selected language.
    embeddings = HuggingFaceEmbeddings(
        model_name    = embed_name,
        model_kwargs  = {"device": "cpu"},
        encode_kwargs = {"normalize_embeddings": False},
    )
    vectordb  = FAISS.load_local(store_path, embeddings)
    retriever = vectordb.as_retriever(search_kwargs={"k": 2})
    # Retrieval-augmented generation: stuff the top-k chunks into the prompt.
    qa_chain = RetrievalQA.from_chain_type(
        llm                     = HF_LLM,       # <- default open-source model
        # llm                   = OPENAI_LLM,   # <- optional paid model
        chain_type              = "stuff",
        retriever               = retriever,
        return_source_documents = True,
        chain_type_kwargs       = {"prompt": prompt},
        verbose                 = True,
    )
    result   = qa_chain(query)
    answer   = result["result"]
    src_doc  = result["source_documents"][0].page_content
    src_path = result["source_documents"][0].metadata["source"]
    payload = jsonable_encoder(json.dumps({
        "answer"          : answer,
        "source_document" : src_doc,
        "doc"             : src_path,
    }))
    return Response(payload)
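# ------------------------------------------------------------------
# 6. EXAMPLE REQUESTS (illustrative only)
# ------------------------------------------------------------------
# A minimal sketch of how the two endpoints can be exercised, assuming the app
# is served locally with uvicorn on port 8000 (the command, port and folder
# path below are examples, not part of this repo):
#
#   uvicorn main:app --reload --port 8000
#
#   curl -X POST -F "folderPath=data/english" -F "language=english" \
#        http://localhost:8000/ingest_data
#
#   curl -X POST -F "query=What is the rated insulation voltage?" -F "language=english" \
#        http://localhost:8000/get_response
#
# /get_response returns a JSON string with the keys "answer",
# "source_document" and "doc" (see the payload built above).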