Upload folder using huggingface_hub

Files changed:
- app.py +15 -1
- config.py +1 -1
- rag_service.py +7 -20
app.py
CHANGED
@@ -1,4 +1,7 @@
 from huggingface_hub import login
+from sentence_transformers import SentenceTransformer
+from transformers import pipeline
+
 from config import HF_TOKEN, GAME_KNOWLEDGE_DATA, EMBEDDING_MODEL_ID, LLM_MODEL_ID
 from rag_service import RAGService
 from ui import build_gradio_ui
@@ -8,8 +11,19 @@ def main():
     print("Logging into Hugging Face Hub...")
     login(token=HF_TOKEN)
 
+    print("Initializing embedding model...")
+    embedding_model = SentenceTransformer(EMBEDDING_MODEL_ID)
+
+    print("Initializing language model...")
+    llm_pipeline = pipeline(
+        "text-generation",
+        model=LLM_MODEL_ID,
+        device_map="auto",
+        dtype="auto",
+    )
+
     # 1. Create the single service instance. This loads all models and data.
-    rag_service = RAGService(GAME_KNOWLEDGE_DATA, EMBEDDING_MODEL_ID, LLM_MODEL_ID)
+    rag_service = RAGService(GAME_KNOWLEDGE_DATA, embedding_model, llm_pipeline)
 
     # 2. Build the UI, passing the service instance to it.
     demo = build_gradio_ui(rag_service)
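Note: the net effect of this change is dependency injection. The SentenceTransformer and the text-generation pipeline are now built once in main() and passed into RAGService, instead of being loaded inside the service. A minimal smoke test of the injected pipeline on its own, assuming a recent transformers version that accepts chat-format input; the prompt and max_new_tokens value are illustrative, not taken from the repo:

from transformers import pipeline
from config import LLM_MODEL_ID

# Same construction as in app.py above; device_map="auto" and dtype="auto"
# let accelerate pick the device and precision.
llm_pipeline = pipeline(
    "text-generation",
    model=LLM_MODEL_ID,
    device_map="auto",
    dtype="auto",
)

# A text-generation pipeline returns the extended conversation; the last
# message is the assistant's reply.
messages = [{"role": "user", "content": "Who is the Pale King?"}]
result = llm_pipeline(messages, max_new_tokens=64)
print(result[0]["generated_text"][-1]["content"])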
config.py
CHANGED
@@ -9,7 +9,7 @@ from web_helper import get_html, find_wiki_links, get_markdown_from_html, get_ma
 # --- Hugging Face & Model Configuration ---
 HF_TOKEN = os.getenv('HF_TOKEN')
 EMBEDDING_MODEL_ID = "google/embeddinggemma-300M"
-LLM_MODEL_ID = "google/gemma-3-
+LLM_MODEL_ID = "google/gemma-3-12B-it"
 
 # --- Data Source Configuration ---
 BASE_URL = "https://hollowknight.wiki"
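Note: HF_TOKEN is read with os.getenv, so it is None when the secret is missing, and the gated Gemma checkpoints then fail later with a less obvious download error. A defensive check such as the following (not part of this repo) would surface the problem at startup:

import os
import sys

# On a Space the token comes from a repository secret named HF_TOKEN;
# locally, export HF_TOKEN=... before launching.
HF_TOKEN = os.getenv("HF_TOKEN")
if not HF_TOKEN:
    sys.exit("HF_TOKEN is not set; gated Gemma checkpoints cannot be downloaded.")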
rag_service.py
CHANGED
@@ -1,33 +1,20 @@
 import spaces
 import torch
-from sentence_transformers import SentenceTransformer, util
-from transformers import pipeline, TextIteratorStreamer
+from sentence_transformers import util
+from transformers import TextIteratorStreamer
 from threading import Thread
 
 # Import project-specific modules
 from config import BASE_URL, DEFAULT_MESSAGE_NO_MATCH, get_all_game_data
 from chat_context import ChatContext
 
-embedding_model = None
-
 class RAGService:
     """Manages model loading, data processing, and chat generation logic."""
-    def __init__(self, data_config: list[dict], embedding_model_id: str, llm_model_id: str):
+    def __init__(self, data_config: list[dict], embedding_model, llm_pipeline):
         print("Initializing RAG Service...")
         self.data_config = data_config
-
-
-        global embedding_model
-        embedding_model = SentenceTransformer(embedding_model_id)
-
-        print("Initializing language model...")
-        self.llm_pipeline = pipeline(
-            "text-generation",
-            model=llm_model_id,
-            device_map="auto",
-            dtype="auto",
-        )
-
+        self.embedding_model = embedding_model
+        self.llm_pipeline = llm_pipeline
         self.knowledge_base: dict[str, list[dict]] = get_all_game_data(embedding_model)
 
     def _select_content(self, title: str) -> list[dict]:
@@ -40,11 +27,11 @@ class RAGService:
         if not query or not contents:
             return -1
 
-        query_embedding = embedding_model.encode(query, prompt_name="query", convert_to_tensor=True).to(embedding_model.device)
+        query_embedding = self.embedding_model.encode(query, prompt_name="query", convert_to_tensor=True).to(self.embedding_model.device)
 
         try:
             # Stack pre-computed tensors from our knowledge base
-            contents_embeddings = torch.stack([item["embedding"] for item in contents]).to(embedding_model.device)
+            contents_embeddings = torch.stack([item["embedding"] for item in contents]).to(self.embedding_model.device)
         except (RuntimeError, IndexError, TypeError) as e:
             print(f"Warning: Could not stack content embeddings. Error: {e}")
             return -1
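Note: the second hunk ends right after the query and content embeddings are prepared; the retained "from sentence_transformers import util" import suggests cosine scoring follows, but that code is outside this diff. A minimal sketch of that step under those assumptions (the function name is hypothetical):

import torch
from sentence_transformers import util

def best_match_index(query_embedding: torch.Tensor,
                     contents_embeddings: torch.Tensor) -> int:
    # util.cos_sim broadcasts the query vector against the (N, D) stack of
    # content embeddings and returns a (1, N) matrix of cosine similarities.
    scores = util.cos_sim(query_embedding, contents_embeddings)
    return int(torch.argmax(scores))

Likewise, the kept TextIteratorStreamer and Thread imports point at the usual streaming pattern: generation runs in a background thread while the streamer is consumed as an iterator. A sketch of that pattern, again hypothetical with respect to this repo's actual generation method:

from threading import Thread
from transformers import TextIteratorStreamer

def stream_reply(llm_pipeline, prompt: str):
    # The streamer decodes tokens as they are generated; skip_prompt drops
    # the echoed input from the output stream.
    streamer = TextIteratorStreamer(
        llm_pipeline.tokenizer, skip_prompt=True, skip_special_tokens=True
    )
    # Run generation in a worker thread so this generator can yield
    # incremental text to the UI as it arrives.
    thread = Thread(
        target=llm_pipeline,
        args=(prompt,),
        kwargs={"streamer": streamer, "max_new_tokens": 256},
    )
    thread.start()
    for chunk in streamer:
        yield chunk
    thread.join()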