File size: 5,942 Bytes
3d3d712 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 |
import io
import itertools
import json
from json import JSONDecodeError
from typing import Any, Callable, Dict, Iterator, List, Literal, Optional, Union
import ijson
from injector import inject
from taskweaver.logging import TelemetryLogger
from taskweaver.memory import Attachment, Post
from taskweaver.memory.attachment import AttachmentType
class PostTranslator:
    """
    Translate between the raw text output of the LLM and Post objects.

    The text format handled here is a JSON object of the form
    ``{"response": [{"type": ..., "content": ...}, ...]}``.
    The core functions are post_to_raw_text and raw_text_to_post.
    """

    @inject
    def __init__(
        self,
        logger: TelemetryLogger,
    ):
        self.logger = logger

    def raw_text_to_post(
        self,
        llm_output: str,
        send_from: Literal["User", "Planner", "CodeInterpreter"],
        event_handler: Callable[[str, str], None],
        early_stop: Optional[Callable[[Union[AttachmentType, Literal["message", "send_to"]], str], bool]] = None,
        validation_func: Optional[Callable[[Post], None]] = None,
    ) -> Post:
        """
        Convert the raw text output of LLM to a Post object.
        :param llm_output: the complete raw text produced by the LLM
        :param send_from: the role that produced the output
        :param event_handler: called with (type string, content) for every parsed element, and once
            more with ("<send_from>-><send_to>", message) after parsing if send_to has been set
        :param early_stop: optional predicate called with (parsed type, content) after each element;
            parsing stops as soon as it returns True
        :param validation_func: optional callable invoked with the resulting post before it is
            returned (presumably raises on an invalid post -- confirm against callers)
        :return: Post
        :raises AssertionError: if a "send_to" element names an unknown recipient
        """
        post = Post.create(message=None, send_from=send_from, send_to=None)
        self.logger.info(f"LLM output: {llm_output}")
        for item in self.parse_llm_output_stream(iter([llm_output])):
            type_str = item["type"]
            value = item["content"]
            parsed_type: Union[AttachmentType, Literal["message", "send_to"]]
            if type_str == "message":
                post.message = value
                parsed_type = "message"
            elif type_str == "send_to":
                # Raised explicitly (not via `assert`) so the check survives `python -O`
                # while callers still see the same exception type.
                if value not in ("User", "Planner", "CodeInterpreter"):
                    raise AssertionError(f"Invalid send_to value: {value}")
                post.send_to = value  # type: ignore
                parsed_type = "send_to"
            else:
                # Anything else must be an attachment type; AttachmentType() rejects unknown strings.
                parsed_type = AttachmentType(type_str)
                post.add_attachment(Attachment.create(type=parsed_type, content=value))
            event_handler(type_str, value)
            if early_stop is not None and early_stop(parsed_type, value):
                break

        if post.send_to is not None:
            event_handler(post.send_from + "->" + post.send_to, post.message)

        if validation_func is not None:
            validation_func(post)
        return post

    def post_to_raw_text(
        self,
        post: Post,
        content_formatter: Callable[[Attachment], str] = lambda x: x.content,
        if_format_message: bool = True,
        if_format_send_to: bool = True,
        ignored_types: Optional[List[AttachmentType]] = None,
    ) -> str:
        """
        Convert a Post object to raw text in the format of LLM output.
        :param post: the post to serialize
        :param content_formatter: maps an attachment to the string stored as its "content";
            defaults to the attachment's own content
        :param if_format_message: whether to append a "message" element with post.message
        :param if_format_send_to: whether to append a "send_to" element with post.send_to
        :param ignored_types: attachment types to leave out of the output
        :return: str, the JSON text of {"response": [...]}
        """
        structured_llm: List[Dict[str, str]] = []
        for attachment in post.attachment_list:
            if ignored_types is not None and attachment.type in ignored_types:
                continue
            structured_llm.append(
                {
                    "type": attachment.type.value,
                    "content": content_formatter(attachment),
                },
            )
        # Attachments come first, then send_to, then message.
        if if_format_send_to:
            structured_llm.append({"type": "send_to", "content": post.send_to})
        if if_format_message:
            structured_llm.append({"type": "message", "content": post.message})
        return json.dumps({"response": structured_llm})

    def parse_llm_output(self, llm_output: str) -> List[Dict[str, str]]:
        """
        Parse the complete LLM output and return the list stored under "response".
        :param llm_output: JSON text of the form {"response": [...]}
        :return: the parsed "response" list
        :raises JSONDecodeError: if llm_output is not valid JSON
        :raises AssertionError: if the "response" value is not a list
        :raises KeyError: if the top-level object has no "response" key
        :raises TypeError: if the top-level JSON value cannot be indexed by "response"
        """
        try:
            structured_llm_output: Any = json.loads(llm_output)["response"]
            # Raised explicitly (not via `assert`) so the check survives `python -O`.
            if not isinstance(structured_llm_output, list):
                raise AssertionError("LLM output should be a list object")
            return structured_llm_output  # type: ignore
        except (JSONDecodeError, AssertionError, KeyError, TypeError) as e:
            # Log every parse failure, then re-raise the original exception unchanged.
            self.logger.error(
                f"Failed to parse LLM output due to {str(e)}. LLM output:\n {llm_output}",
            )
            raise

    def parse_llm_output_stream(
        self,
        llm_output: Iterator[str],
    ) -> Iterator[Dict[str, str]]:
        """
        Incrementally parse LLM output and yield each complete {"type", "content"} element.

        Only elements whose "type" and "content" are both JSON strings are yielded. A JSON
        error (e.g. truncated output) is logged as a warning and ends the iteration; elements
        parsed before the error have already been yielded.
        :param llm_output: chunks of the LLM output text; they are joined before parsing
        :return: iterator over dicts with the keys "type" and "content"
        """
        # ijson operates on bytes, so hand it a binary stream of the UTF-8 encoded text.
        json_data_stream = io.BytesIO("".join(llm_output).encode("utf-8"))
        parser = ijson.parse(json_data_stream)
        element: Dict[str, Any] = {}
        try:
            for prefix, event, value in parser:
                if prefix == "response.item":
                    if event == "start_map":
                        # A new array item begins: drop leftovers of an incomplete previous
                        # item so fields of different items are never merged.
                        element = {}
                    elif event == "map_key" and value in ("type", "content"):
                        # Key seen, value still pending.
                        element[value] = None
                elif event == "string" and prefix == "response.item.type":
                    element["type"] = value
                elif event == "string" and prefix == "response.item.content":
                    element["content"] = value

                if len(element) == 2 and None not in element.values():
                    yield element
                    element = {}
        except ijson.JSONError as e:
            self.logger.warning(
                f"Failed to parse LLM output stream due to JSONError: {str(e)}",
            )
|