davidlzs committed
Commit ccc7d21 · 1 Parent(s): ddcc625

fix hf embedding to support loading to different device

Files changed (1):
  lightrag/llm.py  +6 -2
lightrag/llm.py CHANGED

@@ -693,13 +693,17 @@ async def bedrock_embedding(
 
 
 async def hf_embedding(texts: list[str], tokenizer, embed_model) -> np.ndarray:
+    device = next(embed_model.parameters()).device
     input_ids = tokenizer(
         texts, return_tensors="pt", padding=True, truncation=True
-    ).input_ids
+    ).input_ids.to(device)
     with torch.no_grad():
         outputs = embed_model(input_ids)
         embeddings = outputs.last_hidden_state.mean(dim=1)
-    return embeddings.detach().numpy()
+    if embeddings.dtype == torch.bfloat16:
+        return embeddings.detach().to(torch.float32).cpu().numpy()
+    else:
+        return embeddings.detach().cpu().numpy()
 
 
 async def ollama_embedding(texts: list[str], embed_model, **kwargs) -> np.ndarray:
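
For context, a minimal usage sketch of the patched function (not part of the commit): with the tokenized inputs now moved to the model's device and bfloat16 outputs cast to float32 before conversion to NumPy, hf_embedding works whether the Hugging Face model is loaded on CPU or GPU. The model name and device selection below are illustrative assumptions, not fixed by LightRAG.

# Illustrative sketch only -- the model name and device choice are assumptions.
import asyncio

import torch
from transformers import AutoModel, AutoTokenizer

from lightrag.llm import hf_embedding

model_name = "sentence-transformers/all-MiniLM-L6-v2"  # hypothetical example model
tokenizer = AutoTokenizer.from_pretrained(model_name)
embed_model = AutoModel.from_pretrained(model_name)

# Load the model onto a GPU when one is available; before this commit the
# tokenized inputs stayed on the CPU, causing a device mismatch in the forward pass.
embed_model = embed_model.to("cuda" if torch.cuda.is_available() else "cpu")

embeddings = asyncio.run(
    hf_embedding(["hello world", "graph rag"], tokenizer, embed_model)
)
print(embeddings.shape)  # (2, hidden_size) NumPy array on the host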