auto-draft / utils /knowledge.py
shaocongma
Add a generator wrapper using configuration file. Edit the logic of searching references. Add Gradio UI for testing Knowledge database.
94dc00e
raw history blame
No virus
1.96 kB
import tiktoken
from random import shuffle
# `tokenizer`: used to count how many tokens
# NOTE(review): `encoding_for_model` already returns an Encoding object; the
# extra `get_encoding(tokenizer_name.name)` round-trip reproduces the same
# encoding and could be collapsed into a single call. Kept as-is because both
# module-level names (`tokenizer_name`, `tokenizer`) may be imported elsewhere.
tokenizer_name = tiktoken.encoding_for_model('gpt-4')
tokenizer = tiktoken.get_encoding(tokenizer_name.name)
def tiktoken_len(text):
    """Return how many tokens `text` occupies under the module's GPT-4 encoding."""
    # `disallowed_special=()` lets special-token strings pass through as
    # ordinary text instead of raising.
    return len(tokenizer.encode(text, disallowed_special=()))
class Knowledge:
    """Collects text snippets from a vector database and renders them as
    numbered "Reference" lines for inclusion in an LLM prompt."""

    def __init__(self, db):
        # `db`: a vector store exposing `similarity_search_with_score(query, k=...)`
        # returning [(document, score), ...] -- presumably a LangChain-style
        # wrapper; confirm against caller.
        self.db = db
        self.contents = []

    def collect_knowledge(self, keywords_dict: dict, max_query: int):
        """Query the database for each keyword and accumulate the hits.

        keywords_dict:
            {"machine learning": 5, "language model": 2};
        max_query: number of hits requested per keyword; a non-positive
            value disables collection entirely.
        """
        if max_query <= 0:
            return
        for kw in keywords_dict:
            docs = self.db.similarity_search_with_score(kw, k=max_query)
            # Iterate over what the store actually returned: it may yield
            # fewer than `max_query` hits, so indexing 0..max_query-1
            # (as the original code did) would raise IndexError.
            for doc, score in docs:
                content = {"content": doc.page_content.replace('\n', ' '),
                           "score": score}  # todo: add more meta information; clean the page_content
                self.contents.append(content)
        # sort contents by score / shuffle
        shuffle(self.contents)

    def to_prompts(self, max_tokens=2048):
        """Render collected snippets as "Reference i: ...\\n" lines.

        Stops before the cumulative token count (per `tiktoken_len`)
        reaches `max_tokens`; the snippet that would cross the budget is
        dropped. Returns "" when nothing has been collected.
        """
        if not self.contents:
            return ""
        prompts = []
        tokens = 0
        for idx, content in enumerate(self.contents):
            prompt = "Reference {}: {}\n".format(idx, content["content"])
            tokens += tiktoken_len(prompt)
            if tokens >= max_tokens:
                break
            prompts.append(prompt)
        return "".join(prompts)

    def to_json(self):
        """Return collected snippets as a JSON-ready dict keyed by stringified
        index: {"0": {"content": ..., "score": "..."}, ...}; {} when empty.
        """
        if not self.contents:
            return {}
        # Scores are stringified so the structure serializes cleanly even for
        # non-JSON-native score types. (Removed a leftover debug print.)
        return {
            str(idx): {
                "content": content["content"],
                "score": str(content["score"]),
            }
            for idx, content in enumerate(self.contents)
        }