# Source: ai-luoshaoye / memory.py (uploaded by BarryWang, "Upload memory.py",
# commit c478992, 1.76 kB).
# -*- coding: utf-8 -*-
# @Time : 2023/3/21 13:04
# @Author : BarryWang
# @FileName: memory.py
# @Github : https://github.com/BarryWangQwQ
# import time
import json

from txtai.embeddings import Embeddings
class Dialogue:
    """One user/assistant exchange, convertible into chat-message dicts."""

    user_content: str  # the user's message
    assistant_content: str  # the assistant's reply

    def __init__(self, user_content: str, assistant_content: str):
        self.user_content, self.assistant_content = user_content, assistant_content

    def raw(self):
        """Return the exchange as ``[user_message, assistant_message]`` dicts.

        Each dict has the keys ``'role'`` and ``'content'``, in that order.
        """
        turns = (
            ('user', self.user_content),
            ('assistant', self.assistant_content),
        )
        return [{'role': role, 'content': text} for role, text in turns]
class MemoryBlocks:
    """Similarity-searchable store of past dialogues, backed by txtai Embeddings.

    Each stored dialogue is indexed by its user message. ``search`` returns the
    chat messages of the stored dialogues whose user text is most similar to a
    question.
    """

    def __init__(self):
        # 'content': True enables txtai content storage, so the stored fields
        # (text, raw) can be selected back with SQL in search().
        self.embeddings = Embeddings(
            {
                "path": "sentence-transformers/distiluse-base-multilingual-cased-v2",
                'content': True
            }
        )
        print('已加载模拟记忆区块')  # "simulated memory blocks loaded"

    def upsert(self, dialogue_list):
        """Index every ``Dialogue`` in ``dialogue_list``.

        Ids are the positions in ``dialogue_list`` starting at 0. Under txtai's
        upsert semantics, a later call therefore updates the entries that an
        earlier call stored at the same positions.
        """
        self.embeddings.upsert(
            (
                uid, {'text': dialogue.user_content, 'raw': dialogue.raw()}, None
            ) for uid, dialogue in enumerate(dialogue_list)
        )

    def search(self, question: str, limit: int = 5) -> list:
        """Return chat messages of the stored dialogues most similar to ``question``.

        Args:
            question: free text matched against the stored user messages.
            limit: maximum number of dialogues to retrieve (default 5).

        Returns:
            A flat list of ``{'role': ..., 'content': ...}`` dicts. Each matched
            row's stored ``raw`` list is concatenated, in the order txtai
            returns the rows.
        """
        # The query is SQL text, so a single quote in the question (e.g. "what's")
        # would end the string literal early and break or alter the query.
        # Double it, per standard SQL escaping. int() keeps `limit` from
        # injecting SQL.
        escaped = question.replace("'", "''")
        results = self.embeddings.search(
            "SELECT text, score, raw FROM txtai WHERE similar('{0}') limit {1}".format(
                escaped, int(limit)
            )
        )
        neighborhoods = []
        for row in results:
            raw = row['raw']
            if raw is None:
                continue
            # The original code eval()'d this field as a string, i.e. the stored
            # list comes back serialized. Parse it as JSON data rather than
            # executing it as Python code. An already-decoded list is accepted too.
            neighborhoods += json.loads(raw) if isinstance(raw, str) else raw
        return neighborhoods
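# Example usage (upsert/search are methods, so create a MemoryBlocks first):
# data = [
#     Dialogue('你的生日是哪一天?', '我的生日是1997年8月12日'),
#     Dialogue('你叫什么名字?', '我叫洛少爷'),
#     Dialogue('你擅长什么?', '我擅长音乐,偶尔也会配音'),
# ]
# if __name__ == "__main__":
#     memory = MemoryBlocks()
#     memory.upsert(data)
#     print(memory.search("你是谁?"))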