b10902118 committed on
Commit
292daea
·
1 Parent(s): 65da904

support JSON output for ollama and openai

Browse files
Files changed (3) hide show
  1. lightrag/llm.py +38 -14
  2. lightrag/operate.py +3 -3
  3. lightrag/utils.py +1 -1
lightrag/llm.py CHANGED
@@ -29,7 +29,11 @@ import torch
29
  from pydantic import BaseModel, Field
30
  from typing import List, Dict, Callable, Any
31
  from .base import BaseKVStorage
32
- from .utils import compute_args_hash, wrap_embedding_func_with_attrs
 
 
 
 
33
 
34
  os.environ["TOKENIZERS_PARALLELISM"] = "false"
35
 
@@ -301,7 +305,7 @@ async def ollama_model_if_cache(
301
  model, prompt, system_prompt=None, history_messages=[], **kwargs
302
  ) -> str:
303
  kwargs.pop("max_tokens", None)
304
- kwargs.pop("response_format", None)
305
  host = kwargs.pop("host", None)
306
  timeout = kwargs.pop("timeout", None)
307
 
@@ -345,9 +349,9 @@ def initialize_lmdeploy_pipeline(
345
  backend_config=TurbomindEngineConfig(
346
  tp=tp, model_format=model_format, quant_policy=quant_policy
347
  ),
348
- chat_template_config=ChatTemplateConfig(model_name=chat_template)
349
- if chat_template
350
- else None,
351
  log_level="WARNING",
352
  )
353
  return lmdeploy_pipe
@@ -458,9 +462,16 @@ async def lmdeploy_model_if_cache(
458
  return response
459
 
460
 
 
 
 
 
 
461
  async def gpt_4o_complete(
462
- prompt, system_prompt=None, history_messages=[], **kwargs
463
  ) -> str:
 
 
464
  return await openai_complete_if_cache(
465
  "gpt-4o",
466
  prompt,
@@ -471,8 +482,10 @@ async def gpt_4o_complete(
471
 
472
 
473
  async def gpt_4o_mini_complete(
474
- prompt, system_prompt=None, history_messages=[], **kwargs
475
  ) -> str:
 
 
476
  return await openai_complete_if_cache(
477
  "gpt-4o-mini",
478
  prompt,
@@ -483,45 +496,56 @@ async def gpt_4o_mini_complete(
483
 
484
 
485
  async def azure_openai_complete(
486
- prompt, system_prompt=None, history_messages=[], **kwargs
487
  ) -> str:
488
- return await azure_openai_complete_if_cache(
489
  "conversation-4o-mini",
490
  prompt,
491
  system_prompt=system_prompt,
492
  history_messages=history_messages,
493
  **kwargs,
494
  )
 
 
 
495
 
496
 
497
  async def bedrock_complete(
498
- prompt, system_prompt=None, history_messages=[], **kwargs
499
  ) -> str:
500
- return await bedrock_complete_if_cache(
501
  "anthropic.claude-3-haiku-20240307-v1:0",
502
  prompt,
503
  system_prompt=system_prompt,
504
  history_messages=history_messages,
505
  **kwargs,
506
  )
 
 
 
507
 
508
 
509
  async def hf_model_complete(
510
- prompt, system_prompt=None, history_messages=[], **kwargs
511
  ) -> str:
512
  model_name = kwargs["hashing_kv"].global_config["llm_model_name"]
513
- return await hf_model_if_cache(
514
  model_name,
515
  prompt,
516
  system_prompt=system_prompt,
517
  history_messages=history_messages,
518
  **kwargs,
519
  )
 
 
 
520
 
521
 
522
  async def ollama_model_complete(
523
- prompt, system_prompt=None, history_messages=[], **kwargs
524
  ) -> str:
 
 
525
  model_name = kwargs["hashing_kv"].global_config["llm_model_name"]
526
  return await ollama_model_if_cache(
527
  model_name,
 
29
  from pydantic import BaseModel, Field
30
  from typing import List, Dict, Callable, Any
31
  from .base import BaseKVStorage
32
+ from .utils import (
33
+ compute_args_hash,
34
+ wrap_embedding_func_with_attrs,
35
+ locate_json_string_body_from_string,
36
+ )
37
 
38
  os.environ["TOKENIZERS_PARALLELISM"] = "false"
39
 
 
305
  model, prompt, system_prompt=None, history_messages=[], **kwargs
306
  ) -> str:
307
  kwargs.pop("max_tokens", None)
308
+ # kwargs.pop("response_format", None) # allow json
309
  host = kwargs.pop("host", None)
310
  timeout = kwargs.pop("timeout", None)
311
 
 
349
  backend_config=TurbomindEngineConfig(
350
  tp=tp, model_format=model_format, quant_policy=quant_policy
351
  ),
352
+ chat_template_config=(
353
+ ChatTemplateConfig(model_name=chat_template) if chat_template else None
354
+ ),
355
  log_level="WARNING",
356
  )
357
  return lmdeploy_pipe
 
462
  return response
463
 
464
 
465
class GPTKeywordExtractionFormat(BaseModel):
    """Response schema for structured keyword extraction via OpenAI.

    Passed as ``response_format`` by the GPT completion wrappers when
    ``keyword_extraction=True``, so the model replies with JSON holding
    exactly these two keyword lists.
    """

    high_level_keywords: List[str]
    low_level_keywords: List[str]
468
+
469
+
470
  async def gpt_4o_complete(
471
+ prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
472
  ) -> str:
473
+ if keyword_extraction:
474
+ kwargs["response_format"] = GPTKeywordExtractionFormat
475
  return await openai_complete_if_cache(
476
  "gpt-4o",
477
  prompt,
 
482
 
483
 
484
  async def gpt_4o_mini_complete(
485
+ prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
486
  ) -> str:
487
+ if keyword_extraction:
488
+ kwargs["response_format"] = GPTKeywordExtractionFormat
489
  return await openai_complete_if_cache(
490
  "gpt-4o-mini",
491
  prompt,
 
496
 
497
 
498
async def azure_openai_complete(
    prompt, system_prompt=None, history_messages=None, keyword_extraction=False, **kwargs
) -> str:
    """Complete *prompt* with the Azure OpenAI "conversation-4o-mini" deployment.

    Args:
        prompt: User prompt sent to the model.
        system_prompt: Optional system message for the conversation.
        history_messages: Prior chat messages; ``None`` means no history.
        keyword_extraction: When True, extract and return only the JSON body
            of the raw model output (the Azure path does not use a native
            JSON response format yet).
        **kwargs: Forwarded to ``azure_openai_complete_if_cache``.

    Returns:
        The completion text, or its extracted JSON body when
        ``keyword_extraction`` is set.
    """
    # A literal [] default would be one shared list across all calls
    # (mutable-default pitfall); use None and build a fresh list instead.
    if history_messages is None:
        history_messages = []
    result = await azure_openai_complete_if_cache(
        "conversation-4o-mini",
        prompt,
        system_prompt=system_prompt,
        history_messages=history_messages,
        **kwargs,
    )
    if keyword_extraction:  # TODO: use JSON API
        return locate_json_string_body_from_string(result)
    return result
511
 
512
 
513
async def bedrock_complete(
    prompt, system_prompt=None, history_messages=None, keyword_extraction=False, **kwargs
) -> str:
    """Complete *prompt* with AWS Bedrock (Claude 3 Haiku).

    Args:
        prompt: User prompt sent to the model.
        system_prompt: Optional system message for the conversation.
        history_messages: Prior chat messages; ``None`` means no history.
        keyword_extraction: When True, extract and return only the JSON body
            of the raw model output (Bedrock has no native JSON mode here).
        **kwargs: Forwarded to ``bedrock_complete_if_cache``.

    Returns:
        The completion text, or its extracted JSON body when
        ``keyword_extraction`` is set.
    """
    # Avoid the shared mutable-default-argument pitfall ([] is evaluated
    # once and reused across calls); create a fresh list per call.
    if history_messages is None:
        history_messages = []
    result = await bedrock_complete_if_cache(
        "anthropic.claude-3-haiku-20240307-v1:0",
        prompt,
        system_prompt=system_prompt,
        history_messages=history_messages,
        **kwargs,
    )
    if keyword_extraction:  # TODO: use JSON API
        return locate_json_string_body_from_string(result)
    return result
526
 
527
 
528
async def hf_model_complete(
    prompt, system_prompt=None, history_messages=None, keyword_extraction=False, **kwargs
) -> str:
    """Complete *prompt* with the HuggingFace model named in the global config.

    Args:
        prompt: User prompt sent to the model.
        system_prompt: Optional system message for the conversation.
        history_messages: Prior chat messages; ``None`` means no history.
        keyword_extraction: When True, extract and return only the JSON body
            of the raw model output (HF models have no native JSON mode).
        **kwargs: Must contain ``hashing_kv`` (its ``global_config`` supplies
            ``llm_model_name``); the rest is forwarded to ``hf_model_if_cache``.

    Returns:
        The completion text, or its extracted JSON body when
        ``keyword_extraction`` is set.
    """
    # Avoid the shared mutable-default-argument pitfall ([] would be a
    # single list object reused by every call).
    if history_messages is None:
        history_messages = []
    model_name = kwargs["hashing_kv"].global_config["llm_model_name"]
    result = await hf_model_if_cache(
        model_name,
        prompt,
        system_prompt=system_prompt,
        history_messages=history_messages,
        **kwargs,
    )
    if keyword_extraction:  # TODO: use JSON API
        return locate_json_string_body_from_string(result)
    return result
542
 
543
 
544
  async def ollama_model_complete(
545
+ prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
546
  ) -> str:
547
+ if keyword_extraction:
548
+ kwargs["response_format"] = "json"
549
  model_name = kwargs["hashing_kv"].global_config["llm_model_name"]
550
  return await ollama_model_if_cache(
551
  model_name,
lightrag/operate.py CHANGED
@@ -461,12 +461,12 @@ async def kg_query(
461
  use_model_func = global_config["llm_model_func"]
462
  kw_prompt_temp = PROMPTS["keywords_extraction"]
463
  kw_prompt = kw_prompt_temp.format(query=query, examples=examples, language=language)
464
- result = await use_model_func(kw_prompt)
465
  logger.info("kw_prompt result:")
466
  print(result)
467
  try:
468
- json_text = locate_json_string_body_from_string(result)
469
- keywords_data = json.loads(json_text)
470
  hl_keywords = keywords_data.get("high_level_keywords", [])
471
  ll_keywords = keywords_data.get("low_level_keywords", [])
472
 
 
461
  use_model_func = global_config["llm_model_func"]
462
  kw_prompt_temp = PROMPTS["keywords_extraction"]
463
  kw_prompt = kw_prompt_temp.format(query=query, examples=examples, language=language)
464
+ result = await use_model_func(kw_prompt, keyword_extraction=True)
465
  logger.info("kw_prompt result:")
466
  print(result)
467
  try:
468
+ # json_text = locate_json_string_body_from_string(result) # handled in use_model_func
469
+ keywords_data = json.loads(result)
470
  hl_keywords = keywords_data.get("high_level_keywords", [])
471
  ll_keywords = keywords_data.get("low_level_keywords", [])
472
 
lightrag/utils.py CHANGED
@@ -54,7 +54,7 @@ def locate_json_string_body_from_string(content: str) -> Union[str, None]:
54
  maybe_json_str = maybe_json_str.replace("\\n", "")
55
  maybe_json_str = maybe_json_str.replace("\n", "")
56
  maybe_json_str = maybe_json_str.replace("'", '"')
57
- json.loads(maybe_json_str)
58
  return maybe_json_str
59
  except Exception:
60
  pass
 
54
  maybe_json_str = maybe_json_str.replace("\\n", "")
55
  maybe_json_str = maybe_json_str.replace("\n", "")
56
  maybe_json_str = maybe_json_str.replace("'", '"')
57
+ # json.loads(maybe_json_str) # don't check here, cannot validate schema after all
58
  return maybe_json_str
59
  except Exception:
60
  pass