Fix HF embedding to support models loaded onto a different device (CPU/GPU)
Browse files- lightrag/llm.py +6 -2
lightrag/llm.py
CHANGED
|
@@ -693,13 +693,17 @@ async def bedrock_embedding(
|
|
| 693 |
|
| 694 |
|
| 695 |
async def hf_embedding(texts: list[str], tokenizer, embed_model) -> np.ndarray:
|
|
|
|
| 696 |
input_ids = tokenizer(
|
| 697 |
texts, return_tensors="pt", padding=True, truncation=True
|
| 698 |
-
).input_ids
|
| 699 |
with torch.no_grad():
|
| 700 |
outputs = embed_model(input_ids)
|
| 701 |
embeddings = outputs.last_hidden_state.mean(dim=1)
|
| 702 |
-
|
|
|
|
|
|
|
|
|
|
| 703 |
|
| 704 |
|
| 705 |
async def ollama_embedding(texts: list[str], embed_model, **kwargs) -> np.ndarray:
|
|
|
|
| 693 |
|
| 694 |
|
| 695 |
async def hf_embedding(texts: list[str], tokenizer, embed_model) -> np.ndarray:
    """Embed *texts* with a HuggingFace encoder using masked mean pooling.

    Args:
        texts: Strings to embed.
        tokenizer: HuggingFace tokenizer; must return ``input_ids`` and
            ``attention_mask`` tensors (standard ``BatchEncoding`` fields).
        embed_model: HuggingFace encoder returning ``last_hidden_state``;
            may live on any device and use any floating dtype (incl. bf16).

    Returns:
        Float ``np.ndarray`` of shape ``(len(texts), hidden_size)``.
    """
    # Run inputs on whatever device the model was actually loaded to.
    device = next(embed_model.parameters()).device
    encoded = tokenizer(
        texts, return_tensors="pt", padding=True, truncation=True
    )
    input_ids = encoded.input_ids.to(device)
    attention_mask = encoded.attention_mask.to(device)
    with torch.no_grad():
        outputs = embed_model(input_ids, attention_mask=attention_mask)
        last_hidden = outputs.last_hidden_state
        # Mean-pool only over real tokens: a plain .mean(dim=1) would let
        # padding positions dilute the embedding of shorter texts in a batch.
        mask = attention_mask.unsqueeze(-1).to(last_hidden.dtype)
        embeddings = (last_hidden * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1)
    if embeddings.dtype == torch.bfloat16:
        # numpy has no bfloat16 dtype; upcast before .numpy().
        return embeddings.detach().to(torch.float32).cpu().numpy()
    return embeddings.detach().cpu().numpy()
|
| 707 |
|
| 708 |
|
| 709 |
async def ollama_embedding(texts: list[str], embed_model, **kwargs) -> np.ndarray:
|