Mario Vignieri committed on
Commit
b02de85
·
1 Parent(s): f5099cd

fix hf_embed: use the MPS or CPU torch device when CUDA is not available (for macOS users)

Browse files
Files changed (1) hide show
  1. lightrag/llm/hf.py +16 -1
lightrag/llm/hf.py CHANGED
@@ -138,16 +138,31 @@ async def hf_model_complete(
138
 
139
 
140
  async def hf_embed(texts: list[str], tokenizer, embed_model) -> np.ndarray:
141
- device = next(embed_model.parameters()).device
 
 
 
 
 
 
 
 
 
 
 
142
  encoded_texts = tokenizer(
143
  texts, return_tensors="pt", padding=True, truncation=True
144
  ).to(device)
 
 
145
  with torch.no_grad():
146
  outputs = embed_model(
147
  input_ids=encoded_texts["input_ids"],
148
  attention_mask=encoded_texts["attention_mask"],
149
  )
150
  embeddings = outputs.last_hidden_state.mean(dim=1)
 
 
151
  if embeddings.dtype == torch.bfloat16:
152
  return embeddings.detach().to(torch.float32).cpu().numpy()
153
  else:
 
138
 
139
 
140
  async def hf_embed(texts: list[str], tokenizer, embed_model) -> np.ndarray:
141
+ # Detect the appropriate device
142
+ if torch.cuda.is_available():
143
+ device = next(embed_model.parameters()).device # Use CUDA if available
144
+ elif torch.backends.mps.is_available():
145
+ device = torch.device("mps") # Use MPS for Apple Silicon
146
+ else:
147
+ device = torch.device("cpu") # Fallback to CPU
148
+
149
+ # Move the model to the detected device
150
+ embed_model = embed_model.to(device)
151
+
152
+ # Tokenize the input texts and move them to the same device
153
  encoded_texts = tokenizer(
154
  texts, return_tensors="pt", padding=True, truncation=True
155
  ).to(device)
156
+
157
+ # Perform inference
158
  with torch.no_grad():
159
  outputs = embed_model(
160
  input_ids=encoded_texts["input_ids"],
161
  attention_mask=encoded_texts["attention_mask"],
162
  )
163
  embeddings = outputs.last_hidden_state.mean(dim=1)
164
+
165
+ # Convert embeddings to NumPy
166
  if embeddings.dtype == torch.bfloat16:
167
  return embeddings.detach().to(torch.float32).cpu().numpy()
168
  else: