bebechien committed (verified)
Commit 1eca919 · Parent: c7ffebe

Upload folder using huggingface_hub

Files changed (3)
  1. app.py +15 -1
  2. config.py +1 -1
  3. rag_service.py +7 -20
app.py CHANGED
```diff
@@ -1,4 +1,7 @@
 from huggingface_hub import login
+from sentence_transformers import SentenceTransformer
+from transformers import pipeline
+
 from config import HF_TOKEN, GAME_KNOWLEDGE_DATA, EMBEDDING_MODEL_ID, LLM_MODEL_ID
 from rag_service import RAGService
 from ui import build_gradio_ui
@@ -8,8 +11,19 @@ def main():
     print("Logging into Hugging Face Hub...")
     login(token=HF_TOKEN)
 
+    print("Initializing embedding model...")
+    embedding_model = SentenceTransformer(EMBEDDING_MODEL_ID)
+
+    print("Initializing language model...")
+    llm_pipeline = pipeline(
+        "text-generation",
+        model=LLM_MODEL_ID,
+        device_map="auto",
+        dtype="auto",
+    )
+
     # 1. Create the single service instance. This loads all models and data.
-    rag_service = RAGService(GAME_KNOWLEDGE_DATA, EMBEDDING_MODEL_ID, LLM_MODEL_ID)
+    rag_service = RAGService(GAME_KNOWLEDGE_DATA, embedding_model, llm_pipeline)
 
     # 2. Build the UI, passing the service instance to it.
     demo = build_gradio_ui(rag_service)
```
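The refactor inverts the dependency: app.py now owns model construction, and RAGService consumes whatever it is handed. One practical payoff is that the service can be exercised with lightweight stand-ins. A minimal sketch under that assumption — the fake classes below are hypothetical, not part of the repo, and `get_all_game_data` would also need an offline stub for this to run end to end:

```python
import torch

class FakeEmbeddingModel:
    """Hypothetical stand-in exposing only the surface RAGService touches."""
    device = torch.device("cpu")

    def encode(self, text, prompt_name=None, convert_to_tensor=True):
        # Deterministic dummy vector; the real model returns EmbeddingGemma vectors.
        return torch.ones(8)

class FakePipeline:
    """Hypothetical stand-in for the transformers text-generation pipeline."""
    def __call__(self, prompt, **kwargs):
        return [{"generated_text": prompt + " [stub completion]"}]

# With the new constructor signature, fakes are a drop-in swap:
# rag_service = RAGService(GAME_KNOWLEDGE_DATA, FakeEmbeddingModel(), FakePipeline())
```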
config.py CHANGED
```diff
@@ -9,7 +9,7 @@ from web_helper import get_html, find_wiki_links, get_markdown_from_html, get_ma
 # --- Hugging Face & Model Configuration ---
 HF_TOKEN = os.getenv('HF_TOKEN')
 EMBEDDING_MODEL_ID = "google/embeddinggemma-300M"
-LLM_MODEL_ID = "google/gemma-3-1B-it"
+LLM_MODEL_ID = "google/gemma-3-12B-it"
 
 # --- Data Source Configuration ---
 BASE_URL = "https://hollowknight.wiki"
```
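Because the jump from gemma-3-1B-it to gemma-3-12B-it changes a single constant, the model choice is already well isolated. If per-deployment flexibility were wanted, the same env-var pattern used for HF_TOKEN could apply — a sketch, not something this commit does:

```python
import os

# Hypothetical override: fall back to the committed default when unset.
LLM_MODEL_ID = os.getenv("LLM_MODEL_ID", "google/gemma-3-12B-it")
```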
rag_service.py CHANGED
```diff
@@ -1,33 +1,20 @@
 import spaces
 import torch
-from sentence_transformers import SentenceTransformer, util
-from transformers import pipeline, TextIteratorStreamer
+from sentence_transformers import util
+from transformers import TextIteratorStreamer
 from threading import Thread
 
 # Import project-specific modules
 from config import BASE_URL, DEFAULT_MESSAGE_NO_MATCH, get_all_game_data
 from chat_context import ChatContext
 
-embedding_model = None
-
 class RAGService:
     """Manages model loading, data processing, and chat generation logic."""
-    def __init__(self, data_config: list[dict], embedding_model_id: str, llm_model_id: str):
+    def __init__(self, data_config: list[dict], embedding_model, llm_pipeline):
         print("Initializing RAG Service...")
         self.data_config = data_config
-
-        print("Initializing embedding model...")
-        global embedding_model
-        embedding_model = SentenceTransformer(embedding_model_id)
-
-        print("Initializing language model...")
-        self.llm_pipeline = pipeline(
-            "text-generation",
-            model=llm_model_id,
-            device_map="auto",
-            dtype="auto",
-        )
-
+        self.embedding_model = embedding_model
+        self.llm_pipeline = llm_pipeline
         self.knowledge_base: dict[str, list[dict]] = get_all_game_data(embedding_model)
 
     def _select_content(self, title: str) -> list[dict]:
@@ -40,11 +27,11 @@ class RAGService:
         if not query or not contents:
             return -1
 
-        query_embedding = embedding_model.encode(query, prompt_name="query", convert_to_tensor=True).to(embedding_model.device)
+        query_embedding = self.embedding_model.encode(query, prompt_name="query", convert_to_tensor=True).to(self.embedding_model.device)
 
         try:
             # Stack pre-computed tensors from our knowledge base
-            contents_embeddings = torch.stack([item["embedding"] for item in contents]).to(embedding_model.device)
+            contents_embeddings = torch.stack([item["embedding"] for item in contents]).to(self.embedding_model.device)
         except (RuntimeError, IndexError, TypeError) as e:
             print(f"Warning: Could not stack content embeddings. Error: {e}")
             return -1
```
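The retained `util` import implies that the code surrounding this hunk ranks the stacked content embeddings by cosine similarity against the query and returns -1 when nothing is close enough. A minimal sketch of that lookup, with illustrative names and an assumed threshold:

```python
import torch
from sentence_transformers import util

def best_match_index(query_embedding: torch.Tensor,
                     contents_embeddings: torch.Tensor,
                     threshold: float = 0.5) -> int:
    """Return the index of the most similar content row, or -1 if no
    candidate clears the (assumed) similarity threshold."""
    # util.cos_sim broadcasts a (d,) query against an (N, d) matrix,
    # yielding a (1, N) tensor of cosine similarities.
    scores = util.cos_sim(query_embedding, contents_embeddings)
    best_score, best_idx = torch.max(scores, dim=1)
    if best_score.item() < threshold:
        return -1
    return int(best_idx.item())
```

The retained TextIteratorStreamer and Thread imports likewise suggest generation is streamed: the pipeline would run on a background thread while tokens are yielded to the Gradio UI.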