# backend/main.py
import os
import json
from dotenv import load_dotenv
from fastapi import FastAPI, Request, Form, Response
from fastapi.responses import HTMLResponse
from fastapi.templating import Jinja2Templates
from fastapi.staticfiles import StaticFiles
from fastapi.encoders import jsonable_encoder
from langchain_community.vectorstores import FAISS
from langchain_community.llms import HuggingFacePipeline  # NEW
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.chains import RetrievalQA
from langchain.prompts import PromptTemplate
# -------- optional OpenAI imports (kept, but disabled) ----------
# from langchain.llms import OpenAI
# from langchain.embeddings import OpenAIEmbeddings
# ---------------------------------------------------------------
from ingest import Ingest
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
import torch

# ------------------------------------------------------------------
# 1. ENVIRONMENT
# ------------------------------------------------------------------
load_dotenv()
HF_TOKEN = os.getenv("HF_TOKEN")
if HF_TOKEN is None:
    raise ValueError("HF_TOKEN not set in the environment.")

# OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")  # Optional
# if OPENAI_API_KEY is None:
#     print("OpenAI key missing – OpenAI path disabled.")
# ------------------------------------------------------------------
# 2. LLM & EMBEDDINGS CONFIGURATION
# ------------------------------------------------------------------
DEFAULT_LLM = "google/gemma-3-4b-it"  # change here if desired
EMB_EN = "sentence-transformers/all-MiniLM-L6-v2"
EMB_CZ = "Seznam/retromae-small-cs"

def build_hf_llm(model_id: str = DEFAULT_LLM) -> HuggingFacePipeline:
    """
    Creates a HuggingFacePipeline wrapped inside LangChain's LLM interface.
    Works on CPU; uses half precision automatically when CUDA is available.
    """
    dtype = torch.float16 if torch.cuda.is_available() else torch.float32
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        token       = HF_TOKEN,
        torch_dtype = dtype,
        device_map  = "auto",
    )
    tokenizer = AutoTokenizer.from_pretrained(model_id, token=HF_TOKEN)
    gen_pipe = pipeline(
        task           = "text-generation",
        model          = model,
        tokenizer      = tokenizer,
        max_new_tokens = 512,
        do_sample      = True,   # required for temperature/top_p to take effect
        temperature    = 0.2,
        top_p          = 0.95,
    )
    return HuggingFacePipeline(pipeline=gen_pipe)
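# Quick smoke test for the wrapper above (run manually, not at import time;
# the prompt string is only an example):
#
#   print(build_hf_llm().invoke("State Ohm's law in one sentence."))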
HF_LLM = build_hf_llm()  # Initialise once; reuse in every request
# OPENAI_LLM = OpenAI(openai_api_key=OPENAI_API_KEY, temperature=0)  # optional

# ------------------------------------------------------------------
# 3. FASTAPI PLUMBING
# ------------------------------------------------------------------
app = FastAPI()
templates = Jinja2Templates(directory="templates")
app.mount("/static", StaticFiles(directory="static"), name="static")

# Embedding stores
CZECH_STORE = "stores/czech_512"
ENGLISH_STORE = "stores/english_512"

ingestor = Ingest(
    # openai_api_key = OPENAI_API_KEY,  # needed only if you ingest via OpenAI embeddings
    chunk                   = 512,
    overlap                 = 256,
    czech_store             = CZECH_STORE,
    english_store           = ENGLISH_STORE,
    czech_embedding_model   = EMB_CZ,
    english_embedding_model = EMB_EN,
)

# ------------------------------------------------------------------
# 4. PROMPTS
# ------------------------------------------------------------------
def prompt_en() -> PromptTemplate:
    tmpl = """You are an electrical engineer and you answer users' ###Question.
Your answer must be helpful, relevant and closely related to the user's ###Question.
Quote literally from the ###Context wherever possible.
Use your own words only to connect or clarify. If you don't know, say so.
###Context: {context}
###Question: {question}
Helpful answer:
"""
    return PromptTemplate(template=tmpl, input_variables=["context", "question"])

def prompt_cz() -> PromptTemplate:
    # Czech-language variant of the same instructions.
    tmpl = """Jste elektroinženýr a odpovídáte na ###Otázku.
Odpověď musí být užitečná, relevantní a úzce souviset s ###Otázkou.
Citujte co nejvíce doslovně z ###Kontextu.
Vlastními slovy pouze propojujte nebo vysvětlujte. Nevíte-li, řekněte to.
###Kontext: {context}
###Otázka: {question}
Užitečná odpověď:
"""
    return PromptTemplate(template=tmpl, input_variables=["context", "question"])
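# Rendering a template fills both placeholders (values below are
# illustrative only):
#
#   prompt_en().format(context="<retrieved chunks>", question="<user query>")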
# ------------------------------------------------------------------
# 5. ROUTES
# ------------------------------------------------------------------
# Note: the route decorators were missing from the listing; the paths used
# below ("/", "/ingest", "/get_response") are assumed and must match the
# frontend's form actions.
@app.get("/", response_class=HTMLResponse)
def home(request: Request):
    return templates.TemplateResponse("index.html", {"request": request})

@app.post("/ingest")
async def ingest_data(folderPath: str = Form(...), language: str = Form(...)):
    if language.lower() == "czech":
        ingestor.data_czech = folderPath
        ingestor.ingest_czech()
        return {"message": "Czech data ingestion complete."}
    ingestor.data_english = folderPath
    ingestor.ingest_english()
    return {"message": "English data ingestion complete."}

@app.post("/get_response")
async def get_response(query: str = Form(...), language: str = Form(...)):
    is_czech   = language.lower() == "czech"
    prompt     = prompt_cz() if is_czech else prompt_en()
    store_path = CZECH_STORE if is_czech else ENGLISH_STORE
    embed_name = EMB_CZ if is_czech else EMB_EN

    embeddings = HuggingFaceEmbeddings(
        model_name    = embed_name,
        model_kwargs  = {"device": "cpu"},
        encode_kwargs = {"normalize_embeddings": False},
    )
    # Recent LangChain releases require this flag to unpickle a FAISS index;
    # only load stores you created yourself.
    vectordb = FAISS.load_local(store_path, embeddings,
                                allow_dangerous_deserialization=True)
    retriever = vectordb.as_retriever(search_kwargs={"k": 2})

    qa_chain = RetrievalQA.from_chain_type(
        llm                     = HF_LLM,      # <- default open-source model
        # llm                   = OPENAI_LLM,  # <- optional paid model
        chain_type              = "stuff",
        retriever               = retriever,
        return_source_documents = True,
        chain_type_kwargs       = {"prompt": prompt},
        verbose                 = True,
    )
    result   = qa_chain(query)
    answer   = result["result"]
    src_doc  = result["source_documents"][0].page_content
    src_path = result["source_documents"][0].metadata["source"]

    # json.dumps already returns a str, so jsonable_encoder is redundant here;
    # return the serialized payload directly with a JSON media type.
    payload = json.dumps({
        "answer"          : answer,
        "source_document" : src_doc,
        "doc"             : src_path,
    })
    return Response(content=payload, media_type="application/json")
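# How to run and exercise the API (paths match the assumed decorators above;
# host/port are uvicorn defaults):
#
#   uvicorn main:app --reload
#
#   curl -X POST http://127.0.0.1:8000/ingest \
#        -F "folderPath=data/english" -F "language=english"
#
#   curl -X POST http://127.0.0.1:8000/get_response \
#        -F "query=What is the rated insulation voltage?" -F "language=english"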