anpigon committed on
Commit
c2c8656
•
1 Parent(s): 6cf0662

chore: Update device for HuggingFaceBgeEmbeddings to dynamic device selection

Browse files
Files changed (1) hide show
  1. app.py +17 -2
app.py CHANGED
@@ -1,10 +1,11 @@
1
  import os
 
2
 
3
  import gradio as gr
4
  from dotenv import load_dotenv
5
  from langchain.callbacks.base import BaseCallbackHandler
6
  from langchain.embeddings import CacheBackedEmbeddings
7
- from langchain.retrievers import BM25Retriever, EnsembleRetriever
8
  from langchain.storage import LocalFileStore
9
  from langchain_anthropic import ChatAnthropic
10
  from langchain_community.chat_models import ChatOllama
@@ -100,11 +101,25 @@ print(f"분할된 .ipynb 파일의 개수: {len(ipynb_docs)}")
100
  combined_documents = py_docs + mdx_docs + ipynb_docs
101
  print(f"์ด ๋„ํ๋จผํŠธ ๊ฐœ์ˆ˜: {len(combined_documents)}")
102
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
103
  # Initialize embeddings and cache
104
  store = LocalFileStore("~/.cache/embedding")
105
  embeddings = HuggingFaceBgeEmbeddings(
106
  model_name="BAAI/bge-m3",
107
- model_kwargs={"device": "cuda:0"},
108
  encode_kwargs={"normalize_embeddings": True},
109
  )
110
  cached_embeddings = CacheBackedEmbeddings.from_bytes_store(
 
1
  import os
2
+ import torch
3
 
4
  import gradio as gr
5
  from dotenv import load_dotenv
6
  from langchain.callbacks.base import BaseCallbackHandler
7
  from langchain.embeddings import CacheBackedEmbeddings
8
+ from langchain_community.retrievers import BM25Retriever, EnsembleRetriever
9
  from langchain.storage import LocalFileStore
10
  from langchain_anthropic import ChatAnthropic
11
  from langchain_community.chat_models import ChatOllama
 
101
  combined_documents = py_docs + mdx_docs + ipynb_docs
102
  print(f"์ด ๋„ํ๋จผํŠธ ๊ฐœ์ˆ˜: {len(combined_documents)}")
103
 
104
+
105
+ # Define the device setting function
106
+ def get_device():
107
+ if torch.cuda.is_available():
108
+ return "cuda:0"
109
+ elif torch.backends.mps.is_available():
110
+ return "mps"
111
+ else:
112
+ return "cpu"
113
+
114
+
115
+ # Use the function to set the device in model_kwargs
116
+ device = get_device()
117
+
118
  # Initialize embeddings and cache
119
  store = LocalFileStore("~/.cache/embedding")
120
  embeddings = HuggingFaceBgeEmbeddings(
121
  model_name="BAAI/bge-m3",
122
+ model_kwargs={"device": device},
123
  encode_kwargs={"normalize_embeddings": True},
124
  )
125
  cached_embeddings = CacheBackedEmbeddings.from_bytes_store(