b10902118 committed
Commit 292daea · Parent(s): 65da904

support JSON output for ollama and openai

Files changed:
- lightrag/llm.py      +38 -14
- lightrag/operate.py  +3 -3
- lightrag/utils.py    +1 -1
lightrag/llm.py
CHANGED
@@ -29,7 +29,11 @@ import torch
 from pydantic import BaseModel, Field
 from typing import List, Dict, Callable, Any
 from .base import BaseKVStorage
-from .utils import compute_args_hash, wrap_embedding_func_with_attrs
+from .utils import (
+    compute_args_hash,
+    wrap_embedding_func_with_attrs,
+    locate_json_string_body_from_string,
+)
 
 os.environ["TOKENIZERS_PARALLELISM"] = "false"
 
@@ -301,7 +305,7 @@ async def ollama_model_if_cache(
     model, prompt, system_prompt=None, history_messages=[], **kwargs
 ) -> str:
     kwargs.pop("max_tokens", None)
-    kwargs.pop("response_format", None)
+    # kwargs.pop("response_format", None)  # allow json
     host = kwargs.pop("host", None)
     timeout = kwargs.pop("timeout", None)
 
@@ -345,9 +349,9 @@ def initialize_lmdeploy_pipeline(
         backend_config=TurbomindEngineConfig(
             tp=tp, model_format=model_format, quant_policy=quant_policy
         ),
-        chat_template_config=ChatTemplateConfig(model_name=chat_template)
-        if chat_template
-        else None,
+        chat_template_config=(
+            ChatTemplateConfig(model_name=chat_template) if chat_template else None
+        ),
         log_level="WARNING",
     )
     return lmdeploy_pipe
@@ -458,9 +462,16 @@ async def lmdeploy_model_if_cache(
     return response
 
 
+class GPTKeywordExtractionFormat(BaseModel):
+    high_level_keywords: List[str]
+    low_level_keywords: List[str]
+
+
 async def gpt_4o_complete(
-    prompt, system_prompt=None, history_messages=[], **kwargs
+    prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
 ) -> str:
+    if keyword_extraction:
+        kwargs["response_format"] = GPTKeywordExtractionFormat
     return await openai_complete_if_cache(
         "gpt-4o",
         prompt,
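Note: the new GPTKeywordExtractionFormat model is handed to openai_complete_if_cache as response_format, which is what lets gpt-4o return schema-conforming JSON for keyword extraction. A minimal sketch of how a Pydantic response_format maps onto OpenAI structured outputs, assuming the openai>=1.40 Python SDK's beta parse helper (extract_keywords_sketch is illustrative, not part of the commit):

    # Sketch: Pydantic response_format -> OpenAI structured outputs.
    # Assumes openai>=1.40; the real openai_complete_if_cache may differ.
    from typing import List

    from openai import AsyncOpenAI
    from pydantic import BaseModel

    class GPTKeywordExtractionFormat(BaseModel):
        high_level_keywords: List[str]
        low_level_keywords: List[str]

    async def extract_keywords_sketch(prompt: str) -> GPTKeywordExtractionFormat:
        client = AsyncOpenAI()
        completion = await client.beta.chat.completions.parse(
            model="gpt-4o",
            messages=[{"role": "user", "content": prompt}],
            response_format=GPTKeywordExtractionFormat,  # schema-enforced JSON
        )
        # .parsed is a validated Pydantic instance (None if the model refused)
        return completion.choices[0].message.parsed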
@@ -471,8 +482,10 @@ async def gpt_4o_complete(
 
 
 async def gpt_4o_mini_complete(
-    prompt, system_prompt=None, history_messages=[], **kwargs
+    prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
 ) -> str:
+    if keyword_extraction:
+        kwargs["response_format"] = GPTKeywordExtractionFormat
     return await openai_complete_if_cache(
         "gpt-4o-mini",
         prompt,
@@ -483,45 +496,56 @@ async def gpt_4o_mini_complete(
 
 
 async def azure_openai_complete(
-    prompt, system_prompt=None, history_messages=[], **kwargs
+    prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
 ) -> str:
-    return await azure_openai_complete_if_cache(
+    result = await azure_openai_complete_if_cache(
         "conversation-4o-mini",
         prompt,
         system_prompt=system_prompt,
         history_messages=history_messages,
         **kwargs,
     )
+    if keyword_extraction:  # TODO: use JSON API
+        return locate_json_string_body_from_string(result)
+    return result
 
 
 async def bedrock_complete(
-    prompt, system_prompt=None, history_messages=[], **kwargs
+    prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
 ) -> str:
-    return await bedrock_complete_if_cache(
+    result = await bedrock_complete_if_cache(
         "anthropic.claude-3-haiku-20240307-v1:0",
         prompt,
         system_prompt=system_prompt,
         history_messages=history_messages,
         **kwargs,
     )
+    if keyword_extraction:  # TODO: use JSON API
+        return locate_json_string_body_from_string(result)
+    return result
 
 
 async def hf_model_complete(
-    prompt, system_prompt=None, history_messages=[], **kwargs
+    prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
 ) -> str:
     model_name = kwargs["hashing_kv"].global_config["llm_model_name"]
-    return await hf_model_if_cache(
+    result = await hf_model_if_cache(
         model_name,
         prompt,
         system_prompt=system_prompt,
         history_messages=history_messages,
         **kwargs,
     )
+    if keyword_extraction:  # TODO: use JSON API
+        return locate_json_string_body_from_string(result)
+    return result
 
 
 async def ollama_model_complete(
-    prompt, system_prompt=None, history_messages=[], **kwargs
+    prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
 ) -> str:
+    if keyword_extraction:
+        kwargs["response_format"] = "json"
     model_name = kwargs["hashing_kv"].global_config["llm_model_name"]
     return await ollama_model_if_cache(
         model_name,
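Note: for Ollama the commit does not use a schema-typed API; keyword extraction instead sets response_format = "json", which ollama_model_if_cache no longer strips (hunk at line 305 above). A rough sketch of the JSON mode this presumably ends up selecting on the Ollama client (model name and plumbing are illustrative; JSON mode guarantees syntactically valid JSON but, unlike OpenAI structured outputs, not a particular schema):

    # Sketch: Ollama JSON mode via the ollama Python client.
    import asyncio

    import ollama

    async def main():
        client = ollama.AsyncClient()
        response = await client.chat(
            model="llama3",  # illustrative model name
            messages=[{"role": "user", "content": "List 3 colors as JSON."}],
            format="json",  # what response_format="json" ultimately selects
        )
        print(response["message"]["content"])

    asyncio.run(main())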
lightrag/operate.py
CHANGED
@@ -461,12 +461,12 @@ async def kg_query(
     use_model_func = global_config["llm_model_func"]
     kw_prompt_temp = PROMPTS["keywords_extraction"]
     kw_prompt = kw_prompt_temp.format(query=query, examples=examples, language=language)
-    result = await use_model_func(kw_prompt)
+    result = await use_model_func(kw_prompt, keyword_extraction=True)
     logger.info("kw_prompt result:")
     print(result)
     try:
-        json_text = locate_json_string_body_from_string(result)
-        keywords_data = json.loads(json_text)
+        # json_text = locate_json_string_body_from_string(result)  # handled in use_model_func
+        keywords_data = json.loads(result)
         hl_keywords = keywords_data.get("high_level_keywords", [])
         ll_keywords = keywords_data.get("low_level_keywords", [])
 
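Note: the parsing burden moves behind the llm_model_func boundary. kg_query now passes keyword_extraction=True and trusts the returned string to be a bare JSON body, so json.loads(result) replaces the locate-then-parse step. A condensed sketch of the contract this imposes on any model function (prompt text and names are stand-ins for the real PROMPTS template and pipeline):

    # Sketch: the calling convention kg_query now relies on. Any llm_model_func
    # must accept keyword_extraction and return a JSON string when it is True.
    import json

    async def kg_query_sketch(query: str, use_model_func):
        kw_prompt = f"Extract keywords from: {query}"  # stand-in prompt
        result = await use_model_func(kw_prompt, keyword_extraction=True)
        keywords_data = json.loads(result)  # result is already a bare JSON body
        hl_keywords = keywords_data.get("high_level_keywords", [])
        ll_keywords = keywords_data.get("low_level_keywords", [])
        return hl_keywords, ll_keywords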
lightrag/utils.py
CHANGED
@@ -54,7 +54,7 @@ def locate_json_string_body_from_string(content: str) -> Union[str, None]:
         maybe_json_str = maybe_json_str.replace("\\n", "")
         maybe_json_str = maybe_json_str.replace("\n", "")
         maybe_json_str = maybe_json_str.replace("'", '"')
-        json.loads(maybe_json_str)
+        # json.loads(maybe_json_str)  # don't check here, cannot validate schema after all
         return maybe_json_str
     except Exception:
         pass
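Note: locate_json_string_body_from_string remains the fallback for backends without a JSON API (azure_openai_complete, bedrock_complete, hf_model_complete above). Commenting out json.loads removes the validity probe, so a syntactically broken candidate is now returned rather than discarded; as the commit's comment says, a parse check here could never validate the schema anyway. A simplified re-creation of the helper's core for illustration, assuming a regex body search as the surrounding lines suggest (not the exact implementation):

    # Sketch: pull the first {...} body out of a chatty model response.
    import re
    from typing import Union

    def locate_json_body_sketch(content: str) -> Union[str, None]:
        match = re.search(r"\{.*\}", content, re.DOTALL)  # outermost braces
        if match is None:
            return None
        maybe_json_str = match.group(0).replace("\\n", "").replace("\n", "")
        return maybe_json_str.replace("'", '"')

    print(locate_json_body_sketch("Sure! {'high_level_keywords': ['AI']} Hope that helps."))
    # -> {"high_level_keywords": ["AI"]}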