|
from random import randint |
|
from typing import Iterator |
|
|
|
from injector import Injector |
|
|
|
from taskweaver.logging import LoggingModule |
|
from taskweaver.memory import Attachment, Post |
|
from taskweaver.memory.attachment import AttachmentType |
|
from taskweaver.role import PostTranslator |
|
|
|
# Canonical serialized LLM response shared by the tests in this module.
# It holds eight entries, one per source line below; the pieces are joined by
# implicit string-literal concatenation into a single-line JSON document.
response_str1 = (
    '{"response": ['
    '{"type": "thought", "content": "This is the thought"}, '
    '{"type": "python", "content": "print(\'This is the code\')"}, '
    '{"type": "text", "content": "This is the text"}, '
    '{"type": "sample", "content": "print(\'This is the sample code\')"}, '
    '{"type": "execution_status", "content": "SUCCESS"}, '
    '{"type": "execution_result", "content": "This is the execution result"}, '
    '{"type": "send_to", "content": "Planner"}, '
    '{"type": "message", "content": "This is the message"}'
    "]}"
)
|
|
|
# Role names kept as module-level constants (not referenced by the tests
# visible in this file).
role_name = "ProgramApe"
executor_name = "CodeExecutor"

# A single PostTranslator, created through an injector configured with
# LoggingModule, is built at import time and shared by the tests below.
app_injector = Injector([LoggingModule])
translator = app_injector.create_object(PostTranslator)
|
|
|
|
|
def test_parse_llm_stream():
    """Feed the sample response to the streaming parser in random-sized chunks.

    ``response_str1`` is split on single spaces and re-emitted in groups of
    1 to 10 words, so chunk boundaries land at arbitrary points in the JSON
    text. The parser is expected to yield eight items, matching the number
    of entries in ``response_str1``.
    """

    def response_str() -> Iterator[str]:
        """Yield the sample response as space-terminated groups of words."""
        tokens = response_str1.split(" ")
        total = len(tokens)
        start = 0
        # str.split always returns at least one element, so at least one
        # chunk is produced.
        while start < total:
            step = randint(1, 10)
            # Every chunk, including the last, ends with a single space.
            yield " ".join(tokens[start : start + step]) + " "
            start += step

    parsed = list(translator.parse_llm_output_stream(response_str()))
    assert len(parsed) == 8
|
|
|
|
|
def test_parse_llm():
    """Check ``raw_text_to_post`` both with and without an early-stop callback."""

    def early_stop(type: AttachmentType, text: str) -> bool:
        """Return True for python, sample and text attachments; ``text`` is unused."""
        return type in (AttachmentType.python, AttachmentType.sample, AttachmentType.text)

    def echo_event(t, v):
        """Print each event passed to the handler as ``<t>: <v>``."""
        print(f"{t}: {v}")

    # With early_stop supplied, the post is expected to hold only the thought
    # and python attachments, and message / send_to are expected to be unset.
    truncated = translator.raw_text_to_post(
        llm_output=response_str1,
        send_from="CodeInterpreter",
        event_handler=echo_event,
        early_stop=early_stop,
    )

    assert truncated.message is None
    assert truncated.send_to is None
    assert truncated.send_from == "CodeInterpreter"

    kept = truncated.attachment_list
    assert len(kept) == 2
    thought, code = kept
    assert thought.type == AttachmentType.thought
    assert thought.content == "This is the thought"
    assert code.type == AttachmentType.python
    assert code.content == "print('This is the code')"

    # Without early_stop, six attachments are expected from the eight entries
    # (presumably send_to and message are stored as post fields rather than
    # attachments; confirm against PostTranslator).
    complete = translator.raw_text_to_post(
        llm_output=response_str1,
        send_from="CodeInterpreter",
        event_handler=echo_event,
    )
    everything = complete.attachment_list
    assert len(everything) == 6

    status, result = everything[4], everything[5]
    assert status.type == AttachmentType.execution_status
    assert status.content == "SUCCESS"
    assert result.type == AttachmentType.execution_result
    assert result.content == "This is the execution result"
|
|
|
|
|
def test_post_to_raw_text():
    """Check that ``post_to_raw_text`` serializes a Post to the expected JSON text."""
    post = Post.create(message="This is the message", send_from="CodeInterpreter", send_to="Planner")

    # No attachments yet, both flags on: only send_to and message are expected.
    fields_only = translator.post_to_raw_text(post=post, if_format_message=True, if_format_send_to=True)
    assert fields_only == (
        '{"response": ['
        '{"type": "send_to", "content": "Planner"}, '
        '{"type": "message", "content": "This is the message"}'
        "]}"
    )

    # Both flags off: an empty response list is expected.
    nothing = translator.post_to_raw_text(post=post, if_format_message=False, if_format_send_to=False)
    assert nothing == '{"response": []}'

    # Attach one item per remaining entry of response_str1, in the same order.
    for kind, body in (
        ("thought", "This is the thought"),
        ("python", "print('This is the code')"),
        ("text", "This is the text"),
        ("sample", "print('This is the sample code')"),
        ("execution_status", "SUCCESS"),
        ("execution_result", "This is the execution result"),
    ):
        post.add_attachment(Attachment.create(type=kind, content=body))

    # With attachments and both flags on, the output is expected to equal
    # response_str1 exactly.
    round_trip = translator.post_to_raw_text(post=post, if_format_message=True, if_format_send_to=True)
    assert round_trip == response_str1
|
|