|
|
|
r"""Read prompt templates from XML files. |
|
-*- coding: utf-8 -*- |
|
|
|
Module : prompt.reader |
|
|
|
File Name : reader.py |
|
|
|
Description : Read prompt template |
|
|
|
Creation Date : 2024-07-16 |
|
|
|
Author : Frank Kang(frankkang@zju.edu.cn) |
|
""" |
|
import xml.etree.ElementTree as ET |
|
from xml.etree.ElementTree import Element |
|
|
|
|
|
class Query(object):
    """One ``<query>`` element of a prompt template.

    The constructor copies three pieces of data off the XML node:

    * ``rank``  -- the node's ``rank`` attribute, converted with ``int``.
    * ``title`` -- the text of the ``<title>`` child element.
    * ``text``  -- the text of the ``<text>`` child element.

    No validation is performed: a missing ``rank`` attribute or a missing
    child element surfaces as the exception raised by ``int`` / attribute
    access on ``None``.
    """

    def __init__(self, query_node: Element) -> None:
        super().__init__()
        # Order matters for which error a malformed node reports first:
        # rank, then title, then text.
        rank_attr = query_node.get('rank')
        self.rank = int(rank_attr)
        self.title = Query.Get_Title(query_node)
        text_node = query_node.find('text')
        self.text = text_node.text

    @staticmethod
    def Get_Title(query_node: Element) -> str:
        """Return the text of ``query_node``'s ``<title>`` child.

        Args:
            query_node: the ``<query>`` element to inspect.

        Returns:
            The ``.text`` of the first ``<title>`` child.
        """
        title_node = query_node.find('title')
        return title_node.text
|
|
|
|
|
class AssistantCreateQuery(Query):
    """Query whose title is ``'System Message'``.

    Calling an instance formats the query text and returns the keyword
    arguments for ``client.beta.assistants.create``.
    """

    # Title string that marks a <query> node as an assistant-creation query.
    TITLE = 'System Message'

    def __init__(self, query_node: Element) -> None:
        super(AssistantCreateQuery, self).__init__(query_node)

    def __call__(self, name, *args, tools=None, model="gpt-4-1106-preview", **kwds) -> dict:
        """Get parameters used for client.beta.assistants.create

        Args:
            name: value stored under the ``'name'`` key.
            *args: positional arguments forwarded to ``str.format`` on the
                query text.
            tools: value stored under the ``'tools'`` key. When ``None``
                (the default) a fresh ``[{"type": "code_interpreter"}]``
                list is created for this call.
            model: value stored under the ``'model'`` key.
            **kwds: keyword arguments forwarded to ``str.format`` on the
                query text.

        Returns:
            dict: parameters used for client.beta.assistants.create, with
            keys ``'name'``, ``'instructions'``, ``'tools'`` and ``'model'``.
        """
        if tools is None:
            # Build the default per call. A list literal in the signature is
            # created once and shared, so a caller mutating the returned
            # 'tools' list would change the default for every later call.
            tools = [{"type": "code_interpreter"}]
        instructions = self.text.format(*args, **kwds)
        return {'name': name, 'instructions': instructions, 'tools': tools, 'model': model}
|
|
|
|
|
class MessageQuery(Query):
    """Query whose title is ``'User Message'``.

    Calling an instance formats the query text and wraps it in a
    user-role message dict.
    """

    # Title string that marks a <query> node as a user-message query.
    TITLE = 'User Message'

    def __init__(self, query_node: Element) -> None:
        super().__init__(query_node)

    def __call__(self, *args, **kwds) -> dict:
        """Format the query text like ``str.format`` and wrap it as a message.

        Args:
            *args: positional arguments forwarded to ``str.format``.
            **kwds: keyword arguments forwarded to ``str.format``.

        Returns:
            dict: ``{'role': 'user', 'content': <formatted text>}``.
        """
        content = self.text.format(*args, **kwds)
        message = dict(role='user', content=content)
        return message
|
|
|
|
|
class Prompt(object):
    """Queries loaded from an XML prompt file, grouped by rank.

    After construction ``self.queries`` maps each integer rank to the list
    of query objects carrying that rank, in document order.
    """

    def __init__(self, path) -> None:
        """Initialise the prompt from an XML file.

        Args:
            path: location of the XML file, passed unchanged to
                ``xml.etree.ElementTree.parse``.

        Raises:
            TypeError: if a ``<query>`` element has an unsupported title.
        """
        super().__init__()
        self.path = path
        root = ET.parse(path).getroot()
        self.queries = {}
        # Only direct <query> children of the root element are considered.
        for node in root.findall('query'):
            self.__read_query__(node)

    def __read_query__(self, query_node: Element):
        """Build the query object for ``query_node`` and file it by rank.

        Args:
            query_node: a ``<query>`` element from the prompt file.

        Raises:
            TypeError: if the node's title matches no known query class.
        """
        title = Query.Get_Title(query_node)
        query: Query
        # Try the known query classes in a fixed order; the first whose
        # TITLE equals the node's title builds the query.
        for query_cls in (AssistantCreateQuery, MessageQuery):
            if title == query_cls.TITLE:
                query = query_cls(query_node)
                break
        else:
            raise TypeError('Title not supported!')

        self.queries.setdefault(query.rank, []).append(query)
|
|
|
|
|
|