# backend/main.py
import os
import json

from dotenv import load_dotenv
from fastapi import FastAPI, Request, Form, Response
from fastapi.responses import HTMLResponse
from fastapi.templating import Jinja2Templates
from fastapi.staticfiles import StaticFiles
from fastapi.encoders import jsonable_encoder

from langchain_community.vectorstores import FAISS
from langchain_community.llms import HuggingFacePipeline  # NEW
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.chains import RetrievalQA
from langchain.prompts import PromptTemplate

# -------- optional OpenAI imports (kept, but disabled) ----------
# from langchain.llms import OpenAI
# from langchain.embeddings import OpenAIEmbeddings
# ---------------------------------------------------------------

from ingest import Ingest
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
import torch
# ------------------------------------------------------------------
# 1. ENVIRONMENT
# ------------------------------------------------------------------
load_dotenv()

HF_TOKEN = os.getenv("HF_TOKEN")
if HF_TOKEN is None:
    raise ValueError("HF_TOKEN not set in the environment.")

# OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")  # Optional
# if OPENAI_API_KEY is None:
#     print("OpenAI key missing – OpenAI path disabled.")
# ------------------------------------------------------------------
# 2. LLM & EMBEDDINGS CONFIGURATION
# ------------------------------------------------------------------
DEFAULT_LLM = "google/gemma-3-4b-it"  # change here if desired
EMB_EN = "sentence-transformers/all-MiniLM-L6-v2"
EMB_CZ = "Seznam/retromae-small-cs"
def build_hf_llm(model_id: str = DEFAULT_LLM) -> HuggingFacePipeline:
    """
    Create a HuggingFacePipeline wrapped in LangChain's LLM interface.
    Works on CPU; uses half precision automatically when CUDA is available.
    """
    dtype = torch.float16 if torch.cuda.is_available() else torch.float32

    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        token=HF_TOKEN,
        torch_dtype=dtype,
        device_map="auto",
    )
    tokenizer = AutoTokenizer.from_pretrained(model_id, token=HF_TOKEN)

    gen_pipe = pipeline(
        task="text-generation",
        model=model,
        tokenizer=tokenizer,
        max_new_tokens=512,
        do_sample=True,  # without this, transformers ignores temperature/top_p
        temperature=0.2,
        top_p=0.95,
    )
    return HuggingFacePipeline(pipeline=gen_pipe)
HF_LLM = build_hf_llm()  # Initialise once; reuse in every request

# OPENAI_LLM = OpenAI(openai_api_key=OPENAI_API_KEY, temperature=0)  # optional
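
# Quick sanity check (hypothetical snippet, not part of the request flow):
# uncomment to confirm the model loads and generates before wiring it into
# the retrieval chain below.
# print(HF_LLM.invoke("Briefly explain what a residual current device does."))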
# ------------------------------------------------------------------
# 3. FASTAPI PLUMBING
# ------------------------------------------------------------------
app = FastAPI()
templates = Jinja2Templates(directory="templates")
app.mount("/static", StaticFiles(directory="static"), name="static")

# Embedding stores
CZECH_STORE = "stores/czech_512"
ENGLISH_STORE = "stores/english_512"

ingestor = Ingest(
    # openai_api_key=OPENAI_API_KEY,  # needed only if you ingest via OpenAI embeddings
    chunk=512,
    overlap=256,
    czech_store=CZECH_STORE,
    english_store=ENGLISH_STORE,
    czech_embedding_model=EMB_CZ,
    english_embedding_model=EMB_EN,
)
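
# For reference — a minimal sketch of what Ingest presumably does with the
# chunk/overlap settings above (its real implementation lives in ingest.py,
# not shown here), using the RecursiveCharacterTextSplitter imported above:
#
#   splitter = RecursiveCharacterTextSplitter(chunk_size=512, chunk_overlap=256)
#   chunks = splitter.split_documents(raw_documents)
#   FAISS.from_documents(chunks, embeddings).save_local(CZECH_STORE)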
# ------------------------------------------------------------------
# 4. PROMPTS
# ------------------------------------------------------------------
def prompt_en() -> PromptTemplate:
    tmpl = """You are an electrical engineer and you answer users' ###Question.
# Your answer must be helpful, relevant and closely related to the user's ###Question.
# Quote literally from the ###Context wherever possible.
# Use your own words only to connect or clarify. If you don't know, say so.
###Context: {context}
###Question: {question}
Helpful answer:
"""
    return PromptTemplate(template=tmpl, input_variables=["context", "question"])


def prompt_cz() -> PromptTemplate:
    tmpl = """Jste elektroinženýr a odpovídáte na ###Otázku.
# Odpověď musí být užitečná, relevantní a úzce souviset s ###Otázkou.
# Citujte co nejvíce doslovně z ###Kontextu.
# Vlastními slovy pouze propojujte nebo vysvětlujte. Nevíte-li, řekněte to.
###Kontext: {context}
###Otázka: {question}
Užitečná odpověď:
"""
    return PromptTemplate(template=tmpl, input_variables=["context", "question"])
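
# Example (hypothetical values) of how the chain fills the template — handy
# when debugging variable-name mismatches between prompt and chain:
#
#   prompt_en().format(
#       context="Breakers trip when current exceeds the rated value.",
#       question="Why did my breaker trip?",
#   )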
# ------------------------------------------------------------------
# 5. ROUTES
# ------------------------------------------------------------------
# NOTE: the route decorators were missing from the scraped file; the paths
# below ("/", "/ingest", "/get_response") are assumptions — adjust to match
# whatever the frontend actually calls.
@app.get("/", response_class=HTMLResponse)
def home(request: Request):
    return templates.TemplateResponse("index.html", {"request": request})


@app.post("/ingest")
async def ingest_data(folderPath: str = Form(...), language: str = Form(...)):
    if language.lower() == "czech":
        ingestor.data_czech = folderPath
        ingestor.ingest_czech()
        return {"message": "Czech data ingestion complete."}
    ingestor.data_english = folderPath
    ingestor.ingest_english()
    return {"message": "English data ingestion complete."}
@app.post("/get_response")
async def get_response(query: str = Form(...), language: str = Form(...)):
    is_czech = language.lower() == "czech"

    prompt = prompt_cz() if is_czech else prompt_en()
    store_path = CZECH_STORE if is_czech else ENGLISH_STORE
    embed_name = EMB_CZ if is_czech else EMB_EN

    embeddings = HuggingFaceEmbeddings(
        model_name=embed_name,
        model_kwargs={"device": "cpu"},
        encode_kwargs={"normalize_embeddings": False},
    )

    vectordb = FAISS.load_local(
        store_path,
        embeddings,
        allow_dangerous_deserialization=True,  # required by newer langchain_community releases
    )
    retriever = vectordb.as_retriever(search_kwargs={"k": 2})

    qa_chain = RetrievalQA.from_chain_type(
        llm=HF_LLM,        # <- default open-source model
        # llm=OPENAI_LLM,  # <- optional paid model
        chain_type="stuff",
        retriever=retriever,
        return_source_documents=True,
        chain_type_kwargs={"prompt": prompt},
        verbose=True,
    )

    result = qa_chain.invoke({"query": query})  # invoke() replaces the deprecated __call__
    answer = result["result"]
    src_doc = result["source_documents"][0].page_content
    src_path = result["source_documents"][0].metadata["source"]

    # json.dumps already returns a string, so the jsonable_encoder wrapper was
    # a no-op; return the JSON directly with the correct media type instead.
    payload = json.dumps({
        "answer": answer,
        "source_document": src_doc,
        "doc": src_path,
    })
    return Response(content=payload, media_type="application/json")
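
# ------------------------------------------------------------------
# 6. LOCAL ENTRY POINT
# ------------------------------------------------------------------
# A minimal sketch for running the app locally; host and port are
# assumptions, not part of the original deployment config.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)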