Spaces:
Runtime error
Runtime error
from abc import abstractmethod | |
import uuid | |
from text2vec import semantic_search | |
from utils import ( | |
get_relevant_history, | |
load_knowledge_base_qa, | |
load_knowledge_base_UnstructuredFile, | |
get_embedding, | |
extract, | |
) | |
import json | |
from typing import Dict, List | |
import os | |
from googleapiclient.discovery import build | |
import requests | |
from selenium import webdriver | |
from selenium.webdriver.common.by import By | |
from selenium.webdriver.support.ui import WebDriverWait | |
from selenium.webdriver.support import expected_conditions as EC | |
from bs4 import BeautifulSoup | |
import base64 | |
import re | |
from datetime import datetime, timedelta | |
from typing import Tuple, List, Any, Dict | |
from email.mime.text import MIMEText | |
from email.mime.multipart import MIMEMultipart | |
from google.auth.transport.requests import Request | |
from google.oauth2.credentials import Credentials | |
from google_auth_oauthlib.flow import InstalledAppFlow | |
from googleapiclient.discovery import build | |
from googleapiclient.errors import HttpError | |
from tqdm import tqdm | |
class ToolComponent:
    """Base class for tools an agent can invoke.

    Subclasses override ``func``; each subclass defines its own arguments and
    the shape of the dictionary it returns.
    """

    def __init__(self, name=None):
        """Create the component.

        :param name: optional display name, stored as ``self.name``. Accepting
            it lets subclasses call ``super().__init__(name)`` without a
            TypeError.
        """
        self.name = name

    def func(self):
        """Run the tool. The base implementation does nothing and returns None."""
        pass
class KnowledgeBaseComponent(ToolComponent):
    """
    Inject knowledge base
    top_k : Top_k with the highest matching degree (used for non-QA knowledge bases)
    type : "QA" or others
    knowledge_base(json_path) : knowledge_base_path
    score_threshold : minimum score the best hit must reach for any knowledge
        to be returned (default 0.5)
    """

    def __init__(self, top_k, type, knowledge_base, score_threshold=0.5):
        super().__init__()
        self.top_k = top_k
        self.type = type
        self.knowledge_base = knowledge_base
        self.score_threshold = score_threshold
        if self.type == "QA":
            (
                self.kb_embeddings,
                self.kb_questions,
                self.kb_answers,
                self.kb_chunks,
            ) = load_knowledge_base_qa(self.knowledge_base)
        else:
            # The unstructured loader only yields embeddings and chunks; no
            # kb_questions / kb_answers exist in this mode.
            self.kb_embeddings, self.kb_chunks = load_knowledge_base_UnstructuredFile(
                self.knowledge_base
            )

    def _collect_knowledge(self, hits, limit, render):
        """Concatenate ``render(idx)`` for hits with distinct chunks, up to ``limit`` of them."""
        knowledge = ""
        seen_chunks = []
        for hit in hits:
            matching_idx = hit["corpus_id"]
            chunk = self.kb_chunks[matching_idx]
            # Deduplicate on the chunk itself (the original compared chunks but
            # stored answers, so duplicates were never detected).
            if chunk in seen_chunks:
                continue
            knowledge += render(matching_idx)
            seen_chunks.append(chunk)
            if len(seen_chunks) == limit:
                break
        return knowledge

    def func(self, agent):
        """Retrieve knowledge relevant to the latest entry of ``agent.long_term_memory``.

        :return: {"prompt": ...} holding either the matched content or
            "No matching knowledge base" when the best hit scores below
            ``self.score_threshold`` (or nothing was returned at all).
        """
        query = (
            agent.long_term_memory[-1]["content"]
            if len(agent.long_term_memory) > 0
            else ""
        )
        # Prefer the text inside <query>...</query>; fall back to the whole
        # message when extract() yields nothing.
        query = extract(query, "query") or query
        query_embedding = get_embedding(query)
        # Results are indexed per query; take the list for our single query.
        # A wide pool (50) is fetched and then deduplicated down.
        hits = semantic_search(query_embedding, self.kb_embeddings, top_k=50)[0]
        if not hits:
            # No candidates at all (e.g. an empty knowledge base).
            return {"prompt": "No matching knowledge base"}
        score = hits[0]["score"]
        print(score)
        if score < self.score_threshold:
            return {"prompt": "No matching knowledge base"}
        if self.type == "QA":
            # QA mode injects only the single best-matching question/answer pair.
            knowledge = self._collect_knowledge(
                hits,
                1,
                lambda idx: f"question:{self.kb_questions[idx]},answer:{self.kb_answers[idx]}\n\n",
            )
        else:
            knowledge = self._collect_knowledge(
                hits, self.top_k, lambda idx: f"{self.kb_chunks[idx]}\n\n"
            )
            print(knowledge)
        return {"prompt": "The relevant content is: " + knowledge + "\n"}
class StaticComponent(ToolComponent):
    """Tool that answers every call with the same preconfigured output."""

    def __init__(self, output):
        super().__init__()
        # Fixed payload handed back on every call.
        self.output = output

    def func(self, agent):
        """Ignore ``agent`` and return the stored output under the "response" key."""
        return dict(response=self.output)
class ExtractComponent(ToolComponent):
    """
    Ask the LLM to pull tagged values out of the conversation and store each
    one in ``agent.environment.shared_memory`` under its tag name.

    extract_words (list): tag names to extract
    system_prompt: system prompt for the extraction request
    last_prompt: closing instruction; when falsy, a format hint built from
        extract_words is used instead
    """

    def __init__(
        self,
        extract_words,
        system_prompt,
        last_prompt=None,
    ):
        super().__init__()
        self.extract_words = extract_words
        self.system_prompt = system_prompt
        tag_templates = "".join(
            f"<{word}> the content you need to extract </{word}>"
            for word in extract_words
        )
        self.default_prompt = (
            "Please strictly adhere to the following format for outputting:\n"
            + tag_templates
        )
        self.last_prompt = last_prompt or self.default_prompt

    def func(self, agent):
        """Query the LLM once and write every extracted tag into shared memory.

        :return: an empty dict (results are communicated via shared memory)
        """
        reply = agent.LLM.get_response(
            agent.long_term_memory,
            self.system_prompt,
            self.last_prompt,
            stream=False,
        )
        for word in self.extract_words:
            # When the tag is missing/empty, store the whole reply instead.
            agent.environment.shared_memory[word] = extract(reply, word) or reply
        return {}
"""Search sources: chatgpt/search engines/specific search sources/can even be multimodal (if it comes to clothing)""" | |
class WebSearchComponent(ToolComponent):
    """search engines: query Bing or Google and hand result snippets to the agent"""

    __ENGINE_NAME__: List = ["google", "bing"]

    def __init__(self, engine_name: str, api: Dict):
        """
        :param engine_name: The name of the search engine used ("google" or "bing")
        :param api: credentials per engine, e.g.
            {"bing": "<subscription key>",
             "google": {"api_key": "<key>", "cse_id": "<custom search engine id>"}}
        """
        super(WebSearchComponent, self).__init__()
        # Determine whether the key and engine_name of the api are legal
        assert (
            engine_name in WebSearchComponent.__ENGINE_NAME__
        ), f"unsupported engine `{engine_name}`"
        for api_name in api:
            assert (
                api_name in WebSearchComponent.__ENGINE_NAME__
            ), f"unsupported engine `{api_name}` in api"
        self.api = api
        self.engine_name = engine_name
        self.search: Dict = {"bing": self._bing_search, "google": self._google_search}

    def _bing_search(self, query: str, **kwargs):
        """Search with the Bing Web Search v7 API; extra kwargs are ignored.

        :return: {"meta data": [{"snippet", "title", "link"}, ...]}, empty list
            when Bing returns no web results
        """
        subscription_key = self.api["bing"]
        search_url = "https://api.bing.microsoft.com/v7.0/search"
        headers = {"Ocp-Apim-Subscription-Key": subscription_key}
        params = {
            "q": query,
            "textDecorations": True,
            "textFormat": "HTML",
            "count": 10,
        }
        # A timeout keeps a stalled connection from hanging the agent forever.
        response = requests.get(search_url, headers=headers, params=params, timeout=30)
        response.raise_for_status()
        # "webPages" is absent from the payload when there are no web results.
        results = response.json().get("webPages", {}).get("value", [])
        return {
            "meta data": [
                {
                    "snippet": result.get("snippet", ""),
                    "title": result["name"],
                    "link": result["url"],
                }
                for result in results
            ]
        }

    def _google_search(self, query: str, **kwargs):
        """Search with the Google Custom Search JSON API; kwargs go to ``cse().list``.

        :return: {"meta data": [{"snippet", "title", "link"}, ...]}, empty list
            when the search has no results
        """
        credentials = self.api["google"]
        api_key = credentials["api_key"]
        cse_id = credentials["cse_id"]
        service = build("customsearch", "v1", developerKey=api_key)
        response = service.cse().list(q=query, cx=cse_id, num=10, **kwargs).execute()
        # "items" is omitted entirely when nothing matched.
        results = response.get("items", [])
        return {
            "meta data": [
                {
                    "snippet": result.get("snippet", ""),
                    "title": result["title"],
                    "link": result["link"],
                }
                for result in results
            ]
        }

    def func(self, agent, **kwargs) -> Dict:
        """Derive search keywords from the latest message, search, and build a prompt.

        :return: {"prompt": ...} containing up to five result snippets
        """
        query = (
            agent.long_term_memory[-1]["content"]
            if len(agent.long_term_memory) > 0
            else " "
        )
        response = agent.LLM.get_response(
            None,
            system_prompt=f"Please analyze the provided conversation and identify keywords that can be used for a search engine query. Format the output as <keywords>extracted keywords</keywords>:\nConversation:\n{query}",
            stream=False,
        )
        response = extract(response, "keywords")
        # Fall back to the raw message when no keywords were extracted.
        query = response if response else query
        search_results = self.search[self.engine_name](query=query, **kwargs)
        # Separate snippets so adjacent results do not run together.
        information = "\n".join(
            item["snippet"] for item in search_results["meta data"][:5]
        )
        return {
            "prompt": "You can refer to the following information to reply:\n"
            + information
        }

    def convert_search_engine_to(self, engine_name):
        """Switch the engine used by ``func`` ("google" or "bing")."""
        assert engine_name in WebSearchComponent.__ENGINE_NAME__
        self.engine_name = engine_name
class WebCrawlComponent(ToolComponent):
    """Open a single web page with Selenium and return the text of its <p> tags."""

    def __init__(self):
        super(WebCrawlComponent, self).__init__()

    def func(self, agent_dict) -> Dict:
        """
        :param agent_dict: mapping that must contain a "url" key
        :return: {"content": paragraph texts joined by newlines and stripped};
            whatever was collected before an error (possibly "") if loading or
            parsing failed, with the error printed
        """
        url = agent_dict["url"]
        print(f"crawling {url} ......")
        paragraphs = []
        # Crawling may eventually need site-specific handling (wiki, baidu, zhihu, ...).
        browser = webdriver.Chrome()
        try:
            browser.get(url)
            # Wait up to 20 seconds for the <body> element to appear.
            WebDriverWait(browser, 20).until(
                EC.presence_of_element_located((By.TAG_NAME, "body"))
            )
            soup = BeautifulSoup(browser.page_source, "html.parser")
            for node in soup.find_all("p"):
                paragraphs.append(node.get_text())
        except Exception as e:
            print("Error:", e)
        finally:
            # Always release the browser, even after an error.
            browser.quit()
        return {"content": "\n".join(paragraphs).strip()}
class MailComponent(ToolComponent):
    """Read or send Gmail messages through the Gmail API with OAuth user credentials.

    The active action is ``default_action`` unless a call's mail_dict carries an
    "action" key, which then also becomes the default for later calls.
    """

    __VALID_ACTION__ = ["read", "send"]

    def __init__(
        self, cfg_file: str, default_action: str = "read", name: str = "e-mail"
    ):
        """
        :param cfg_file: OAuth client-secrets file, e.g. '../config/google_mail.json'
        :param default_action: "read" or "send" (case-insensitive)
        :param name: display name of this tool
        """
        # Call the base initializer without arguments so this works whether or
        # not ToolComponent.__init__ accepts a name.
        super(MailComponent, self).__init__()
        self.name = name
        assert (
            default_action.lower() in self.__VALID_ACTION__
        ), f"Action `{default_action}` is not allowed! The valid action is in `{self.__VALID_ACTION__}`"
        self.action = default_action.lower()
        self.credential = self._login(cfg_file)

    def _login(self, cfg_file: str):
        """Return Gmail credentials, reusing/refreshing token.json or running the browser flow.

        Newly obtained or refreshed credentials are written to token.json in the
        current working directory.
        """
        SCOPES = [
            "https://www.googleapis.com/auth/gmail.readonly",
            "https://www.googleapis.com/auth/gmail.send",
        ]
        creds = None
        if os.path.exists("token.json"):
            print("Login Successfully!")
            creds = Credentials.from_authorized_user_file("token.json", SCOPES)
        if not creds or not creds.valid:
            if creds and creds.expired and creds.refresh_token:
                creds.refresh(Request())
            else:
                # Only the interactive flow actually needs the browser.
                print("Please authorize in an open browser.")
                flow = InstalledAppFlow.from_client_secrets_file(cfg_file, SCOPES)
                creds = flow.run_local_server(port=0)
            # Save the credentials for the next run
            with open("token.json", "w") as token:
                token.write(creds.to_json())
        return creds

    @staticmethod
    def _extract_plain_text(payload: dict) -> str:
        """Return the first text/plain body in a Gmail payload, searching nested parts; "" if none."""
        if payload.get("mimeType") == "text/plain":
            data = payload.get("body", {}).get("data")
            if data:
                # Pad to a multiple of 4 so unpadded base64url still decodes.
                padded = data + "=" * (-len(data) % 4)
                return base64.urlsafe_b64decode(padded).decode("utf-8", errors="replace")
        for part in payload.get("parts", []):
            text = MailComponent._extract_plain_text(part)
            if text:
                return text
        return ""

    def _summarize_message(self, msg: dict) -> dict:
        """Reduce a full-format Gmail message to sender, local time, subject and plain-text body."""
        headers = msg["payload"].get("headers", [])
        subject = next((h["value"] for h in headers if h["name"] == "Subject"), "")
        sender = ""
        for header in headers:
            if header["name"] == "From":
                addresses = re.findall(r"\b[\w\.-]+@[\w\.-]+\.\w+\b", header["value"])
                # Fall back to the raw header when no address pattern is found.
                sender = addresses[0] if addresses else header["value"]
                break
        return {
            "sender": sender,
            "time": datetime.fromtimestamp(int(msg["internalDate"]) / 1000).strftime(
                "%Y-%m-%d %H:%M:%S"
            ),
            "subject": subject,
            "body": self._extract_plain_text(msg["payload"]),
        }

    def _read(self, mail_dict: dict):
        """Search the mailbox and return matching messages.

        Optional mail_dict keys:
            state: "all" (default), "unread", "read" or "sent"
            time_between: 2-tuple/list; each item is "YYYY/MM/DD", "now", or an
                int day offset relative to the other item
            sender_mail: address to filter on
            only_both: if True, match mail from OR to sender_mail
            order_by_time: "descend" (default) or "ascend"
            include_word / exclude_word: term to require / exclude
            MAX_SEARCH_CNT: max messages fetched (default 50)
            number: max messages returned (default 10)
        :return: {"results": [...]} or None if nothing matched or an error occurred
        """
        credential = self.credential
        state = mail_dict.get("state")
        time_between = mail_dict.get("time_between")
        sender_mail = mail_dict.get("sender_mail")
        only_both = mail_dict.get("only_both", False)
        order_by_time = mail_dict.get("order_by_time", "descend")
        include_word = mail_dict.get("include_word")
        exclude_word = mail_dict.get("exclude_word")
        MAX_SEARCH_CNT = mail_dict.get("MAX_SEARCH_CNT", 50)
        number = mail_dict.get("number", 10)
        if state is None:
            state = "all"
        if time_between is not None:
            # JSON-decoded input yields lists, so accept lists as well as tuples.
            assert isinstance(time_between, (tuple, list))
            assert len(time_between) == 2
        assert state in ["all", "unread", "read", "sent"]
        if only_both:
            assert sender_mail is not None
        if sender_mail is not None:
            assert isinstance(sender_mail, str)
        assert credential
        assert order_by_time in ["descend", "ascend"]

        def generate_query():
            query = ""
            if state in ["unread", "read"]:
                query = f"is:{state}"
            if state in ["sent"]:
                query = f"in:{state}"
            if only_both:
                # Parenthesize so the OR cannot bind to neighbouring terms.
                query = f"{query} (from:{sender_mail} OR to:{sender_mail})"
            if sender_mail is not None and not only_both:
                query = f"{query} from:({sender_mail})"
            if include_word is not None:
                query = f"{query} {include_word}"
            if exclude_word is not None:
                query = f"{query} -{exclude_word}"
            if time_between is not None:
                TIME_FORMAT = "%Y/%m/%d"
                t1, t2 = time_between
                if t1 == "now":
                    t1 = datetime.now().strftime(TIME_FORMAT)
                if t2 == "now":
                    t2 = datetime.now().strftime(TIME_FORMAT)
                if isinstance(t1, str) and isinstance(t2, str):
                    t1 = datetime.strptime(t1, TIME_FORMAT)
                    t2 = datetime.strptime(t2, TIME_FORMAT)
                elif isinstance(t1, str) and isinstance(t2, int):
                    t1 = datetime.strptime(t1, TIME_FORMAT)
                    t2 = t1 + timedelta(days=t2)
                elif isinstance(t1, int) and isinstance(t2, str):
                    t2 = datetime.strptime(t2, TIME_FORMAT)
                    t1 = t2 + timedelta(days=t1)
                else:
                    assert False, "invalid time"
                if t1 > t2:
                    t1, t2 = t2, t1
                query = f"{query} after:{t1.strftime(TIME_FORMAT)} before:{t2.strftime(TIME_FORMAT)}"
            return query.strip()

        def sort_by_time(data: List[Dict]):
            return sorted(
                data,
                key=lambda x: datetime.strptime(x["time"], "%Y-%m-%d %H:%M:%S"),
                reverse=(order_by_time == "descend"),
            )

        # Sent messages normally carry the SENT label rather than INBOX, so a
        # "sent" search restricted to INBOX would come back (nearly) empty.
        label_ids = ["SENT"] if state == "sent" else ["INBOX"]
        try:
            service = build("gmail", "v1", credentials=credential)
            results = (
                service.users()
                .messages()
                .list(userId="me", labelIds=label_ids, q=generate_query())
                .execute()
            )
            messages = results.get("messages", [])
            if not messages:
                print("No eligible emails.")
                return None
            selected = messages[:MAX_SEARCH_CNT]
            email_data = []
            # Context manager closes the progress bar even if a fetch fails.
            with tqdm(total=len(selected)) as pbar:
                for message in selected:
                    msg = (
                        service.users()
                        .messages()
                        .get(
                            userId="me",
                            id=message["id"],
                            format="full",
                            metadataHeaders=None,
                        )
                        .execute()
                    )
                    email_data.append(self._summarize_message(msg))
                    pbar.update(1)
            email_data = sort_by_time(email_data)[0:number]
            return {"results": email_data}
        except Exception as e:
            print(e)
            return None

    def _send(self, mail_dict: dict):
        """Send a plain-text email.

        :param mail_dict: needs "recipient_mail", "subject" and "body"
        :return: {"state": True} on success, {"state": False} if the API raised HttpError
        """
        recipient_mail = mail_dict["recipient_mail"]
        subject = mail_dict["subject"]
        body = mail_dict["body"]
        credential = self.credential
        service = build("gmail", "v1", credentials=credential)
        message = MIMEMultipart()
        message["to"] = recipient_mail
        message["subject"] = subject
        message.attach(MIMEText(body, "plain"))
        raw_message = base64.urlsafe_b64encode(message.as_bytes()).decode("utf-8")
        try:
            service.users().messages().send(
                userId="me", body={"raw": raw_message}
            ).execute()
            return {"state": True}
        except HttpError as error:
            print(error)
            return {"state": False}

    def func(self, mail_dict: dict):
        """Dispatch to read/send; an "action" key (case-insensitive) switches the action."""
        if "action" in mail_dict:
            action = mail_dict["action"].lower()
            assert (
                action in self.__VALID_ACTION__
            ), f"Action `{mail_dict['action']}` is not allowed! The valid action is in `{self.__VALID_ACTION__}`"
            # Store the normalized name; the raw value (e.g. "READ") would not
            # match the dispatch table below.
            self.action = action
        functions = {"read": self._read, "send": self._send}
        return functions[self.action](mail_dict)

    def convert_action_to(self, action_name: str):
        """Change the default action ("read" or "send", case-insensitive)."""
        assert (
            action_name.lower() in self.__VALID_ACTION__
        ), f"Action `{action_name}` is not allowed! The valid action is in `{self.__VALID_ACTION__}`"
        self.action = action_name.lower()
class WeatherComponet(ToolComponent):
    """Look up weather for a city through the Weatherbit API."""

    def __init__(self, api_key, name="weather", TIME_FORMAT="%Y-%m-%d"):
        """
        :param api_key: Weatherbit API key
        :param name: display name of this tool
        :param TIME_FORMAT: strptime/strftime format of the dates in requests
        """
        # Call the base initializer without arguments so this works whether or
        # not ToolComponent.__init__ accepts a name.
        super(WeatherComponet, self).__init__()
        self.name = name
        self.TIME_FORMAT = TIME_FORMAT
        self.api_key = api_key

    def _parse(self, data):
        """Map each item of data["data"] to {description, temperature, ...} keyed by its "datetime"."""
        mapping = {
            "temp": "temperature",
            "max_temp": "max_temperature",
            "min_temp": "min_temperature",
            "precip": "accumulated_precipitation",
        }
        dict_data: dict = {}
        for item in data["data"]:
            entry = {}
            if "weather" in item:
                entry["description"] = item["weather"]["description"]
            for key, renamed in mapping.items():
                if key in item:
                    entry[renamed] = item[key]
            dict_data[item["datetime"]] = entry
        return dict_data

    def _query(self, city_name, country_code, start_date, end_date):
        """https://www.weatherbit.io/api/historical-weather-daily

        A (today, today + 1 day) range is served by the "current" endpoint;
        any other range by the daily history endpoint.
        """
        now = datetime.now()
        # Pass values as params so requests URL-encodes them (e.g. "New York").
        params = {"city": city_name, "country": country_code, "key": self.api_key}
        if start_date == now.strftime(self.TIME_FORMAT) and end_date == (
            now + timedelta(days=1)
        ).strftime(self.TIME_FORMAT):
            # today
            url = "https://api.weatherbit.io/v2.0/current"
        else:
            url = "https://api.weatherbit.io/v2.0/history/daily"
            params["start_date"] = start_date
            params["end_date"] = end_date
        response = requests.get(url, params=params, timeout=30)
        # Fail loudly on HTTP errors instead of a KeyError on the missing "data".
        response.raise_for_status()
        return self._parse(response.json())

    def func(self, weather_dict: Dict) -> Dict:
        """
        :param weather_dict: {"city_name": e.g. "Beijing", "country_code": e.g. "CN",
            "start_date": e.g. "2020-02-02", optional "end_date"}
        :return: parsed weather keyed by date
        """
        TIME_FORMAT = self.TIME_FORMAT
        city_name = weather_dict["city_name"]
        country_code = weather_dict["country_code"]
        # Normalize the date string through the configured format.
        start_date = datetime.strftime(
            datetime.strptime(weather_dict["start_date"], TIME_FORMAT),
            TIME_FORMAT,
        )
        end_date = weather_dict.get("end_date")
        if end_date is None:
            # Single day: end one day AFTER start_date. This is the
            # (today, today + 1) pairing _query uses to detect "today"; the
            # previous -1 day offset could never match it.
            end_date = datetime.strftime(
                datetime.strptime(start_date, TIME_FORMAT) + timedelta(days=1),
                TIME_FORMAT,
            )
        else:
            end_date = datetime.strftime(
                datetime.strptime(end_date, TIME_FORMAT),
                TIME_FORMAT,
            )
        if datetime.strptime(start_date, TIME_FORMAT) > datetime.strptime(
            end_date, TIME_FORMAT
        ):
            start_date, end_date = end_date, start_date
        assert start_date != end_date
        return self._query(city_name, country_code, start_date, end_date)
class TranslateComponent(ToolComponent):
    """Translate text with the Microsoft Translator v3 REST API."""

    __SUPPORT_LANGUAGE__ = [
        "af", "am", "ar", "as", "az", "ba", "bg", "bn", "bo", "bs",
        "ca", "cs", "cy", "da", "de", "dsb", "dv", "el", "en", "es",
        "et", "eu", "fa", "fi", "fil", "fj", "fo", "fr", "fr-CA", "ga",
        "gl", "gom", "gu", "ha", "he", "hi", "hr", "hsb", "ht", "hu",
        "hy", "id", "ig", "ikt", "is", "it", "iu", "iu-Latn", "ja", "ka",
        "kk", "km", "kmr", "kn", "ko", "ku", "ky", "ln", "lo", "lt",
        "lug", "lv", "lzh", "mai", "mg", "mi", "mk", "ml", "mn-Cyrl", "mn-Mong",
        "mr", "ms", "mt", "mww", "my", "nb", "ne", "nl", "nso", "nya",
        "or", "otq", "pa", "pl", "prs", "ps", "pt", "pt-PT", "ro", "ru",
        "run", "rw", "sd", "si", "sk", "sl", "sm", "sn", "so", "sq",
        "sr-Cyrl", "sr-Latn", "st", "sv", "sw", "ta", "te", "th", "ti", "tk",
        "tlh-Latn", "tlh-Piqd", "tn", "to", "tr", "tt", "ty", "ug", "uk", "ur",
        "uz", "vi", "xh", "yo", "yua", "yue", "zh-Hans", "zh-Hant", "zu",
    ]

    # Common region-style aliases mapped onto codes in __SUPPORT_LANGUAGE__.
    # The default target "zh-cn" is not itself in the list, so without this
    # mapping every call relying on the default failed the support check.
    __LANGUAGE_ALIASES__ = {
        "zh": "zh-Hans",
        "zh-cn": "zh-Hans",
        "zh-sg": "zh-Hans",
        "zh-tw": "zh-Hant",
        "zh-hk": "zh-Hant",
        "zh-mo": "zh-Hant",
    }

    def __init__(
        self, api_key, location, default_target_language="zh-cn", name="translate"
    ):
        """
        :param api_key: Translator subscription key
        :param location: Translator resource region (sent as Ocp-Apim-Subscription-Region)
        :param default_target_language: used when a call gives no "target_language"
        :param name: display name of this tool
        """
        # Call the base initializer without arguments so this works whether or
        # not ToolComponent.__init__ accepts a name.
        super(TranslateComponent, self).__init__()
        self.name = name
        self.api_key = api_key
        self.location = location
        self.default_target_language = default_target_language

    @classmethod
    def _normalize_language(cls, language):
        """Return the supported code for ``language`` (exact, alias, or case-insensitive), else None."""
        if language in cls.__SUPPORT_LANGUAGE__:
            return language
        if not isinstance(language, str):
            return None
        lowered = language.lower()
        if lowered in cls.__LANGUAGE_ALIASES__:
            return cls.__LANGUAGE_ALIASES__[lowered]
        for code in cls.__SUPPORT_LANGUAGE__:
            if code.lower() == lowered:
                return code
        return None

    def func(self, translate_dict: Dict) -> Dict:
        """
        :param translate_dict: {"content": text, optional "target_language": code}
        :return: {"result": translated text}
        """
        content = translate_dict["content"]
        requested = translate_dict.get("target_language", self.default_target_language)
        target_language = self._normalize_language(requested)
        assert (
            target_language is not None
        ), f"language `{requested}` is not supported."
        endpoint = "https://api.cognitive.microsofttranslator.com"
        path = "/translate"
        constructed_url = endpoint + path
        params = {"api-version": "3.0", "to": target_language}
        headers = {
            "Ocp-Apim-Subscription-Key": self.api_key,
            "Ocp-Apim-Subscription-Region": self.location,
            "Content-type": "application/json",
            "X-ClientTraceId": str(uuid.uuid4()),
        }
        body = [{"text": content}]
        response = requests.post(
            constructed_url, params=params, headers=headers, json=body, timeout=30
        )
        # Error payloads are not a list of translations; fail with the HTTP error.
        response.raise_for_status()
        # Use the parsed JSON directly; never eval() text received from the network.
        result = response.json()
        return {"result": result[0]["translations"][0]["text"]}
class APIComponent(ToolComponent):
    """Placeholder for a generic API-calling tool; not implemented yet."""

    def __init__(self):
        super().__init__()

    def func(self, agent) -> Dict:
        """Not implemented: always returns None."""
        return None
class FunctionComponent(ToolComponent):
    """Let the LLM choose and call one of a set of Python functions.

    functions: list of function specs; each needs "name" and
        "parameters"["properties"] (its keys are the argument names passed)
    function_call: forwarded to the LLM ("auto" by default)
    response_type: "response" or "prompt" - output key for the function's result
    your_function: optional {"name": ..., "content": <python source>} that
        defines an extra callable at construction time
    """

    def __init__(
        self,
        functions,
        function_call="auto",
        response_type="response",
        your_function=None,
    ):
        super().__init__()
        self.functions = functions
        self.function_call = function_call
        self.parameters = {}
        self.available_functions = {}
        self.response_type = response_type
        if your_function:
            function_name = your_function["name"]
            # SECURITY: exec() runs arbitrary code from configuration; only
            # load trusted configs.
            # Execute in an explicit namespace seeded with this module's globals
            # so the code can use its imports and the definition is reliably
            # retrievable; names exec() adds to a function's locals() are not
            # guaranteed to be visible afterwards (PEP 667, Python 3.13).
            namespace = dict(globals())
            exec(your_function["content"], namespace)
            self.available_functions[function_name] = namespace[function_name]
        for function in self.functions:
            name = function["name"]
            self.parameters[name] = list(function["parameters"]["properties"].keys())
            self.available_functions[name] = self._resolve_function(name)

    def _resolve_function(self, name):
        """Find a callable by name: registered ones, then module globals, then builtins.

        Raises NameError when the name is unknown.
        """
        if name in self.available_functions:
            return self.available_functions[name]
        module_globals = globals()
        if name in module_globals:
            return module_globals[name]
        import builtins

        if hasattr(builtins, name):
            return getattr(builtins, name)
        raise NameError(f"name '{name}' is not defined")

    @staticmethod
    def _message_content(message):
        """Return a memory entry's text; accepts dict entries and objects with ``.content``."""
        if isinstance(message, dict):
            return message["content"]
        return message.content

    def func(self, agent):
        """Ask the LLM (with function specs) and run the function it selects, if any.

        :return: {"response" or "prompt": function result} when a function was
            called, otherwise an empty dict
        """
        messages = agent.long_term_memory
        outputdict = {}
        query = self._message_content(messages[-1]) if len(messages) > 0 else " "
        relevant_history = get_relevant_history(
            query,
            messages[:-1],
            agent.chat_embeddings[:-1],
        )
        response_message = agent.LLM.get_response(
            messages,
            None,
            functions=self.functions,
            stream=False,
            function_call=self.function_call,
            relevant_history=relevant_history,
        )
        function_call = response_message.get("function_call")
        if function_call:
            function_name = function_call["name"]
            function_to_call = self.available_functions[function_name]
            function_args = json.loads(function_call["arguments"])
            # Pass exactly the declared parameters; missing ones arrive as None.
            input_args = {
                arg_name: function_args.get(arg_name)
                for arg_name in self.parameters[function_name]
            }
            function_response = function_to_call(**input_args)
            if self.response_type == "response":
                outputdict["response"] = function_response
            elif self.response_type == "prompt":
                outputdict["prompt"] = function_response
        return outputdict
class CodeComponent(ToolComponent):
    """Ask the LLM for the modified code and write it to output_code/<file_name>.

    file_name: target path inside the "output_code" directory
    keyword: tag the LLM wraps the code in, i.e. <keyword> ... </keyword>
    """

    def __init__(self, file_name, keyword) -> None:
        super().__init__()
        self.file_name = file_name
        self.keyword = keyword
        self.system_prompt = (
            "you need to extract the modified code as completely as possible."
        )
        self.last_prompt = (
            "Please strictly adhere to the following format for outputting: \n"
        )
        self.last_prompt += (
            f"<{self.keyword}> the content you need to extract </{self.keyword}>"
        )

    @staticmethod
    def _strip_code_fence(code):
        """Drop a leading ```lang fence line and a trailing ``` line, if present."""
        lines = code.split("\n")
        if lines and lines[0].strip().startswith("```"):
            lines = lines[1:]
        # Remove the LAST line itself; list.remove("```") would delete the
        # first matching line, which may be inside the code.
        if lines and lines[-1].strip() == "```":
            lines = lines[:-1]
        return "\n".join(lines)

    def func(self, agent):
        """Request the code, strip Markdown fences, and write it (UTF-8) to disk.

        :return: an empty dict
        """
        response = agent.LLM.get_response(
            agent.long_term_memory,
            self.system_prompt,
            self.last_prompt,
            stream=False,
        )
        code = extract(response, self.keyword)
        # Fall back to the whole reply when the tag is missing/empty.
        code = code if code else response
        file_name = "output_code/" + self.file_name
        # Create output_code and any subdirectories named in file_name.
        os.makedirs(os.path.dirname(file_name), exist_ok=True)
        with open(file_name, "w", encoding="utf-8") as f:
            f.write(self._strip_code_fence(code))
        return {}