Kevin Hu committed on
Commit
0c61e3b
·
1 Parent(s): 7240dd7

less text, better extraction (#1869)

Browse files

### What problem does this PR solve?

#1861

### Type of change

- [x] Refactoring

Files changed (1) hide show
  1. graphrag/index.py +6 -5
graphrag/index.py CHANGED
@@ -75,10 +75,11 @@ def build_knowlege_graph_chunks(tenant_id: str, chunks: List[str], callback, ent
75
  llm_bdl = LLMBundle(tenant_id, LLMType.CHAT, tenant.llm_id)
76
  ext = GraphExtractor(llm_bdl)
77
  left_token_count = llm_bdl.max_length - ext.prompt_token_count - 1024
78
- left_token_count = max(llm_bdl.max_length * 0.8, left_token_count)
79
 
80
  assert left_token_count > 0, f"The LLM context length({llm_bdl.max_length}) is smaller than prompt({ext.prompt_token_count})"
81
 
 
82
  texts, graphs = [], []
83
  cnt = 0
84
  threads = []
@@ -86,15 +87,15 @@ def build_knowlege_graph_chunks(tenant_id: str, chunks: List[str], callback, ent
86
  for i in range(len(chunks)):
87
  tkn_cnt = num_tokens_from_string(chunks[i])
88
  if cnt+tkn_cnt >= left_token_count and texts:
89
- for b in range(0, len(texts), 16):
90
- threads.append(exe.submit(ext, ["\n".join(texts[b:b+16])], {"entity_types": entity_types}, callback))
91
  texts = []
92
  cnt = 0
93
  texts.append(chunks[i])
94
  cnt += tkn_cnt
95
  if texts:
96
- for b in range(0, len(texts), 16):
97
- threads.append(exe.submit(ext, ["\n".join(texts[b:b+16])], {"entity_types": entity_types}, callback))
98
 
99
  callback(0.5, "Extracting entities.")
100
  graphs = []
 
75
  llm_bdl = LLMBundle(tenant_id, LLMType.CHAT, tenant.llm_id)
76
  ext = GraphExtractor(llm_bdl)
77
  left_token_count = llm_bdl.max_length - ext.prompt_token_count - 1024
78
+ left_token_count = max(llm_bdl.max_length * 0.6, left_token_count)
79
 
80
  assert left_token_count > 0, f"The LLM context length({llm_bdl.max_length}) is smaller than prompt({ext.prompt_token_count})"
81
 
82
+ BATCH_SIZE=1
83
  texts, graphs = [], []
84
  cnt = 0
85
  threads = []
 
87
  for i in range(len(chunks)):
88
  tkn_cnt = num_tokens_from_string(chunks[i])
89
  if cnt+tkn_cnt >= left_token_count and texts:
90
+ for b in range(0, len(texts), BATCH_SIZE):
91
+ threads.append(exe.submit(ext, ["\n".join(texts[b:b+BATCH_SIZE])], {"entity_types": entity_types}, callback))
92
  texts = []
93
  cnt = 0
94
  texts.append(chunks[i])
95
  cnt += tkn_cnt
96
  if texts:
97
+ for b in range(0, len(texts), BATCH_SIZE):
98
+ threads.append(exe.submit(ext, ["\n".join(texts[b:b+BATCH_SIZE])], {"entity_types": entity_types}, callback))
99
 
100
  callback(0.5, "Extracting entities.")
101
  graphs = []