Shroominic commited on
Commit
e0cbbfe
·
1 Parent(s): c8866ac

🧹 black cleanup

Browse files
codeinterpreterapi/callbacks.py CHANGED
@@ -11,7 +11,7 @@ class CodeCallbackHandler(AsyncIteratorCallbackHandler):
11
  def __init__(self, session: "CodeInterpreterSession"):
12
  self.session = session
13
  super().__init__()
14
-
15
  async def on_agent_action(
16
  self,
17
  action: AgentAction,
@@ -26,4 +26,4 @@ class CodeCallbackHandler(AsyncIteratorCallbackHandler):
26
  f"⚙️ Running code: ```python\n{action.tool_input['code']}\n```" # type: ignore
27
  )
28
  else:
29
- raise ValueError(f"Unknown action: {action.tool}")
 
11
  def __init__(self, session: "CodeInterpreterSession"):
12
  self.session = session
13
  super().__init__()
14
+
15
  async def on_agent_action(
16
  self,
17
  action: AgentAction,
 
26
  f"⚙️ Running code: ```python\n{action.tool_input['code']}\n```" # type: ignore
27
  )
28
  else:
29
+ raise ValueError(f"Unknown action: {action.tool}")
codeinterpreterapi/chains/functions_agent.py CHANGED
@@ -105,9 +105,11 @@ def _format_intermediate_steps(
105
  messages.extend(_convert_agent_action_to_messages(agent_action, observation))
106
 
107
  return messages
108
-
109
 
110
- async def _parse_ai_message(message: BaseMessage, llm: BaseLanguageModel) -> Union[AgentAction, AgentFinish]:
 
 
 
111
  """Parse an AI message."""
112
  if not isinstance(message, AIMessage):
113
  raise TypeError(f"Expected an AI message got {type(message)}")
@@ -117,7 +119,7 @@ async def _parse_ai_message(message: BaseMessage, llm: BaseLanguageModel) -> Uni
117
  if function_call:
118
  function_call = message.additional_kwargs["function_call"]
119
  function_name = function_call["name"]
120
- try:
121
  _tool_input = json.loads(function_call["arguments"])
122
  except JSONDecodeError:
123
  if function_name == "python":
@@ -199,8 +201,9 @@ class OpenAIFunctionsAgent(BaseSingleActionAgent):
199
  def functions(self) -> List[dict]:
200
  return [dict(format_tool_to_openai_function(t)) for t in self.tools]
201
 
202
- def plan(self): raise NotImplementedError
203
-
 
204
  async def aplan(
205
  self,
206
  intermediate_steps: List[Tuple[AgentAction, str]],
@@ -229,7 +232,7 @@ class OpenAIFunctionsAgent(BaseSingleActionAgent):
229
  )
230
  agent_decision = await _parse_ai_message(predicted_message, self.llm)
231
  return agent_decision
232
-
233
  @classmethod
234
  def create_prompt(
235
  cls,
 
105
  messages.extend(_convert_agent_action_to_messages(agent_action, observation))
106
 
107
  return messages
 
108
 
109
+
110
+ async def _parse_ai_message(
111
+ message: BaseMessage, llm: BaseLanguageModel
112
+ ) -> Union[AgentAction, AgentFinish]:
113
  """Parse an AI message."""
114
  if not isinstance(message, AIMessage):
115
  raise TypeError(f"Expected an AI message got {type(message)}")
 
119
  if function_call:
120
  function_call = message.additional_kwargs["function_call"]
121
  function_name = function_call["name"]
122
+ try:
123
  _tool_input = json.loads(function_call["arguments"])
124
  except JSONDecodeError:
125
  if function_name == "python":
 
201
  def functions(self) -> List[dict]:
202
  return [dict(format_tool_to_openai_function(t)) for t in self.tools]
203
 
204
+ def plan(self):
205
+ raise NotImplementedError
206
+
207
  async def aplan(
208
  self,
209
  intermediate_steps: List[Tuple[AgentAction, str]],
 
232
  )
233
  agent_decision = await _parse_ai_message(predicted_message, self.llm)
234
  return agent_decision
235
+
236
  @classmethod
237
  def create_prompt(
238
  cls,
codeinterpreterapi/chains/modifications_check.py CHANGED
@@ -18,28 +18,27 @@ from langchain.schema import (
18
  prompt = ChatPromptTemplate(
19
  input_variables=["code"],
20
  messages=[
21
- SystemMessage(content=
22
- "The user will input some code and you will need to determine if the code makes any changes to the file system. \n"
23
  "With changes it means creating new files or modifying exsisting ones.\n"
24
  "Answer with a function call `determine_modifications` and list them inside.\n"
25
  "If the code does not make any changes to the file system, still answer with the function call but return an empty list.\n",
26
  ),
27
- HumanMessagePromptTemplate.from_template("{code}")
28
- ]
29
  )
30
 
31
  functions = [
32
  {
33
  "name": "determine_modifications",
34
- "description":
35
- "Based on code of the user determine if the code makes any changes to the file system. \n"
36
- "With changes it means creating new files or modifying exsisting ones.\n",
37
  "parameters": {
38
  "type": "object",
39
  "properties": {
40
  "modifications": {
41
  "type": "array",
42
- "items": { "type": "string" },
43
  "description": "The filenames that are modified by the code.",
44
  },
45
  },
@@ -50,30 +49,30 @@ functions = [
50
 
51
 
52
  async def get_file_modifications(
53
- code: str,
54
  llm: BaseLanguageModel,
55
  retry: int = 2,
56
  ) -> List[str] | None:
57
- if retry < 1:
58
  return None
59
  messages = prompt.format_prompt(code=code).to_messages()
60
  message = await llm.apredict_messages(messages, functions=functions)
61
-
62
  if not isinstance(message, AIMessage):
63
  raise OutputParserException("Expected an AIMessage")
64
-
65
  function_call = message.additional_kwargs.get("function_call", None)
66
-
67
  if function_call is None:
68
- return await get_file_modifications(code, llm, retry=retry-1)
69
- else:
70
  function_call = json.loads(function_call["arguments"])
71
  return function_call["modifications"]
72
-
73
 
74
  async def test():
75
  llm = ChatOpenAI(model="gpt-3.5-turbo-0613") # type: ignore
76
-
77
  code = """
78
  import matplotlib.pyplot as plt
79
 
@@ -89,15 +88,16 @@ async def test():
89
  """
90
 
91
  code2 = "import pandas as pd\n\n# Read the Excel file\ndata = pd.read_excel('Iris.xlsx')\n\n# Convert the data to CSV\ndata.to_csv('Iris.csv', index=False)"
92
-
93
  modifications = await get_file_modifications(code2, llm)
94
-
95
  print(modifications)
96
 
97
 
98
  if __name__ == "__main__":
99
  import asyncio
100
  from dotenv import load_dotenv
 
101
  load_dotenv()
102
-
103
  asyncio.run(test())
 
18
  prompt = ChatPromptTemplate(
19
  input_variables=["code"],
20
  messages=[
21
+ SystemMessage(
22
+ content="The user will input some code and you will need to determine if the code makes any changes to the file system. \n"
23
  "With changes it means creating new files or modifying exsisting ones.\n"
24
  "Answer with a function call `determine_modifications` and list them inside.\n"
25
  "If the code does not make any changes to the file system, still answer with the function call but return an empty list.\n",
26
  ),
27
+ HumanMessagePromptTemplate.from_template("{code}"),
28
+ ],
29
  )
30
 
31
  functions = [
32
  {
33
  "name": "determine_modifications",
34
+ "description": "Based on code of the user determine if the code makes any changes to the file system. \n"
35
+ "With changes it means creating new files or modifying exsisting ones.\n",
 
36
  "parameters": {
37
  "type": "object",
38
  "properties": {
39
  "modifications": {
40
  "type": "array",
41
+ "items": {"type": "string"},
42
  "description": "The filenames that are modified by the code.",
43
  },
44
  },
 
49
 
50
 
51
  async def get_file_modifications(
52
+ code: str,
53
  llm: BaseLanguageModel,
54
  retry: int = 2,
55
  ) -> List[str] | None:
56
+ if retry < 1:
57
  return None
58
  messages = prompt.format_prompt(code=code).to_messages()
59
  message = await llm.apredict_messages(messages, functions=functions)
60
+
61
  if not isinstance(message, AIMessage):
62
  raise OutputParserException("Expected an AIMessage")
63
+
64
  function_call = message.additional_kwargs.get("function_call", None)
65
+
66
  if function_call is None:
67
+ return await get_file_modifications(code, llm, retry=retry - 1)
68
+ else:
69
  function_call = json.loads(function_call["arguments"])
70
  return function_call["modifications"]
71
+
72
 
73
  async def test():
74
  llm = ChatOpenAI(model="gpt-3.5-turbo-0613") # type: ignore
75
+
76
  code = """
77
  import matplotlib.pyplot as plt
78
 
 
88
  """
89
 
90
  code2 = "import pandas as pd\n\n# Read the Excel file\ndata = pd.read_excel('Iris.xlsx')\n\n# Convert the data to CSV\ndata.to_csv('Iris.csv', index=False)"
91
+
92
  modifications = await get_file_modifications(code2, llm)
93
+
94
  print(modifications)
95
 
96
 
97
  if __name__ == "__main__":
98
  import asyncio
99
  from dotenv import load_dotenv
100
+
101
  load_dotenv()
102
+
103
  asyncio.run(test())
codeinterpreterapi/chains/remove_download_link.py CHANGED
@@ -8,51 +8,54 @@ from langchain.schema import (
8
  AIMessage,
9
  OutputParserException,
10
  SystemMessage,
11
- HumanMessage
12
  )
13
 
14
 
15
  prompt = ChatPromptTemplate(
16
  input_variables=["input_response"],
17
  messages=[
18
- SystemMessage(content=
19
- "The user will send you a response and you need to remove the download link from it.\n"
20
  "Reformat the remaining message so no whitespace or half sentences are still there.\n"
21
  "If the response does not contain a download link, return the response as is.\n"
22
  ),
23
- HumanMessage(content="The dataset has been successfully converted to CSV format. You can download the converted file [here](sandbox:/Iris.csv)."),
 
 
24
  AIMessage(content="The dataset has been successfully converted to CSV format."),
25
- HumanMessagePromptTemplate.from_template("{input_response}")
26
- ]
27
  )
28
 
29
 
30
  async def remove_download_link(
31
- input_response: str,
32
  llm: BaseLanguageModel,
33
  ) -> str:
34
  messages = prompt.format_prompt(input_response=input_response).to_messages()
35
  message = await llm.apredict_messages(messages)
36
-
37
  if not isinstance(message, AIMessage):
38
  raise OutputParserException("Expected an AIMessage")
39
-
40
  return message.content
41
-
42
 
43
  async def test():
44
  llm = ChatOpenAI(model="gpt-3.5-turbo-0613") # type: ignore
45
-
46
  example = "I have created the plot to your dataset.\n\nLink to the file [here](sandbox:/plot.png)."
47
-
48
  modifications = await remove_download_link(example, llm)
49
-
50
  print(modifications)
51
 
52
 
53
  if __name__ == "__main__":
54
  import asyncio
55
  import dotenv
 
56
  dotenv.load_dotenv()
57
-
58
  asyncio.run(test())
 
8
  AIMessage,
9
  OutputParserException,
10
  SystemMessage,
11
+ HumanMessage,
12
  )
13
 
14
 
15
  prompt = ChatPromptTemplate(
16
  input_variables=["input_response"],
17
  messages=[
18
+ SystemMessage(
19
+ content="The user will send you a response and you need to remove the download link from it.\n"
20
  "Reformat the remaining message so no whitespace or half sentences are still there.\n"
21
  "If the response does not contain a download link, return the response as is.\n"
22
  ),
23
+ HumanMessage(
24
+ content="The dataset has been successfully converted to CSV format. You can download the converted file [here](sandbox:/Iris.csv)."
25
+ ),
26
  AIMessage(content="The dataset has been successfully converted to CSV format."),
27
+ HumanMessagePromptTemplate.from_template("{input_response}"),
28
+ ],
29
  )
30
 
31
 
32
  async def remove_download_link(
33
+ input_response: str,
34
  llm: BaseLanguageModel,
35
  ) -> str:
36
  messages = prompt.format_prompt(input_response=input_response).to_messages()
37
  message = await llm.apredict_messages(messages)
38
+
39
  if not isinstance(message, AIMessage):
40
  raise OutputParserException("Expected an AIMessage")
41
+
42
  return message.content
43
+
44
 
45
  async def test():
46
  llm = ChatOpenAI(model="gpt-3.5-turbo-0613") # type: ignore
47
+
48
  example = "I have created the plot to your dataset.\n\nLink to the file [here](sandbox:/plot.png)."
49
+
50
  modifications = await remove_download_link(example, llm)
51
+
52
  print(modifications)
53
 
54
 
55
  if __name__ == "__main__":
56
  import asyncio
57
  import dotenv
58
+
59
  dotenv.load_dotenv()
60
+
61
  asyncio.run(test())
codeinterpreterapi/config.py CHANGED
@@ -9,10 +9,11 @@ class CodeInterpreterAPISettings(BaseSettings):
9
  """
10
  CodeInterpreter API Config
11
  """
 
12
  VERBOSE: bool = False
13
-
14
  CODEBOX_API_KEY: str | None = None
15
  OPENAI_API_KEY: str | None = None
16
-
17
 
18
  settings = CodeInterpreterAPISettings()
 
9
  """
10
  CodeInterpreter API Config
11
  """
12
+
13
  VERBOSE: bool = False
14
+
15
  CODEBOX_API_KEY: str | None = None
16
  OPENAI_API_KEY: str | None = None
17
+
18
 
19
  settings = CodeInterpreterAPISettings()
codeinterpreterapi/prompts/__init__.py CHANGED
@@ -1 +1 @@
1
- from .system_message import system_message as code_interpreter_system_message
 
1
+ from .system_message import system_message as code_interpreter_system_message
codeinterpreterapi/prompts/system_message.py CHANGED
@@ -1,7 +1,8 @@
1
  from langchain.schema import SystemMessage
2
 
3
 
4
- system_message = SystemMessage(content="""
 
5
  Assistant is designed to be able to assist with a wide range of tasks, from answering simple questions to providing in-depth explanations and discussions on a wide range of topics.
6
  As a language model, Assistant is able to generate human-like text based on the input it receives, allowing it to engage in natural-sounding conversations and provide responses that are coherent and relevant to the topic at hand.
7
  Assistant is constantly learning and improving, and its capabilities are constantly evolving.
@@ -13,4 +14,5 @@ The human also maybe thinks this code interpreter is for writing code but it is
13
  Tell the human if they use the code interpreter incorrectly.
14
  Already installed packages are: (numpy pandas matplotlib seaborn scikit-learn yfinance scipy statsmodels sympy bokeh plotly dash networkx).
15
  If you encounter an error, try again and fix the code.
16
- """)
 
 
1
  from langchain.schema import SystemMessage
2
 
3
 
4
+ system_message = SystemMessage(
5
+ content="""
6
  Assistant is designed to be able to assist with a wide range of tasks, from answering simple questions to providing in-depth explanations and discussions on a wide range of topics.
7
  As a language model, Assistant is able to generate human-like text based on the input it receives, allowing it to engage in natural-sounding conversations and provide responses that are coherent and relevant to the topic at hand.
8
  Assistant is constantly learning and improving, and its capabilities are constantly evolving.
 
14
  Tell the human if they use the code interpreter incorrectly.
15
  Already installed packages are: (numpy pandas matplotlib seaborn scikit-learn yfinance scipy statsmodels sympy bokeh plotly dash networkx).
16
  If you encounter an error, try again and fix the code.
17
+ """
18
+ )
codeinterpreterapi/schema/file.py CHANGED
@@ -5,53 +5,58 @@ from pydantic import BaseModel
5
  class File(BaseModel):
6
  name: str
7
  content: bytes
8
-
9
  @classmethod
10
  def from_path(cls, path: str):
11
  with open(path, "rb") as f:
12
  path = path.split("/")[-1]
13
  return cls(name=path, content=f.read())
14
-
15
- @classmethod
16
  async def afrom_path(cls, path: str):
17
- await asyncio.to_thread(cls.from_path, path)
18
-
19
  @classmethod
20
  def from_url(cls, url: str):
21
  import requests # type: ignore
 
22
  r = requests.get(url)
23
  return cls(name=url.split("/")[-1], content=r.content)
24
-
25
  @classmethod
26
  async def afrom_url(cls, url: str):
27
  import aiohttp
 
28
  async with aiohttp.ClientSession() as session:
29
  async with session.get(url) as r:
30
  return cls(name=url.split("/")[-1], content=await r.read())
31
-
32
  def save(self, path: str):
33
  with open(path, "wb") as f:
34
  f.write(self.content)
35
-
36
  async def asave(self, path: str):
37
  await asyncio.to_thread(self.save, path)
38
-
39
  def show_image(self):
40
  try:
41
  from PIL import Image # type: ignore
42
  except ImportError:
43
- print("Please install it with `pip install codeinterpreterapi[image_support]` to display images.")
 
 
44
  exit(1)
45
-
46
  from io import BytesIO
 
47
  img_io = BytesIO(self.content)
48
  img = Image.open(img_io)
49
-
50
  # Display the image
51
  img.show()
52
-
53
  def __str__(self):
54
  return self.name
55
-
56
  def __repr__(self):
57
- return f"File(name={self.name})"
 
5
  class File(BaseModel):
6
  name: str
7
  content: bytes
8
+
9
  @classmethod
10
  def from_path(cls, path: str):
11
  with open(path, "rb") as f:
12
  path = path.split("/")[-1]
13
  return cls(name=path, content=f.read())
14
+
15
+ @classmethod
16
  async def afrom_path(cls, path: str):
17
+ await asyncio.to_thread(cls.from_path, path)
18
+
19
  @classmethod
20
  def from_url(cls, url: str):
21
  import requests # type: ignore
22
+
23
  r = requests.get(url)
24
  return cls(name=url.split("/")[-1], content=r.content)
25
+
26
  @classmethod
27
  async def afrom_url(cls, url: str):
28
  import aiohttp
29
+
30
  async with aiohttp.ClientSession() as session:
31
  async with session.get(url) as r:
32
  return cls(name=url.split("/")[-1], content=await r.read())
33
+
34
  def save(self, path: str):
35
  with open(path, "wb") as f:
36
  f.write(self.content)
37
+
38
  async def asave(self, path: str):
39
  await asyncio.to_thread(self.save, path)
40
+
41
  def show_image(self):
42
  try:
43
  from PIL import Image # type: ignore
44
  except ImportError:
45
+ print(
46
+ "Please install it with `pip install codeinterpreterapi[image_support]` to display images."
47
+ )
48
  exit(1)
49
+
50
  from io import BytesIO
51
+
52
  img_io = BytesIO(self.content)
53
  img = Image.open(img_io)
54
+
55
  # Display the image
56
  img.show()
57
+
58
  def __str__(self):
59
  return self.name
60
+
61
  def __repr__(self):
62
+ return f"File(name={self.name})"
codeinterpreterapi/schema/input.py CHANGED
@@ -1,9 +1,9 @@
1
  from pydantic import BaseModel
2
 
3
 
4
- class CodeInput(BaseModel):
5
  code: str
6
-
7
 
8
- class FileInput(BaseModel):
 
9
  filename: str
 
1
  from pydantic import BaseModel
2
 
3
 
4
+ class CodeInput(BaseModel):
5
  code: str
 
6
 
7
+
8
+ class FileInput(BaseModel):
9
  filename: str
codeinterpreterapi/schema/response.py CHANGED
@@ -4,21 +4,21 @@ from .file import File
4
 
5
  class UserRequest(HumanMessage):
6
  files: list[File] = []
7
-
8
  def __str__(self):
9
  return self.content
10
-
11
  def __repr__(self):
12
  return f"UserRequest(content={self.content}, files={self.files})"
13
 
14
 
15
  class CodeInterpreterResponse(AIMessage):
16
  files: list[File] = []
17
- # final_code: str = "" TODO: implement
18
  # final_output: str = "" TODO: implement
19
-
20
  def __str__(self):
21
  return self.content
22
-
23
  def __repr__(self):
24
- return f"CodeInterpreterResponse(content={self.content}, files={self.files})"
 
4
 
5
  class UserRequest(HumanMessage):
6
  files: list[File] = []
7
+
8
  def __str__(self):
9
  return self.content
10
+
11
  def __repr__(self):
12
  return f"UserRequest(content={self.content}, files={self.files})"
13
 
14
 
15
  class CodeInterpreterResponse(AIMessage):
16
  files: list[File] = []
17
+ # final_code: str = "" TODO: implement
18
  # final_output: str = "" TODO: implement
19
+
20
  def __str__(self):
21
  return self.content
22
+
23
  def __repr__(self):
24
+ return f"CodeInterpreterResponse(content={self.content}, files={self.files})"
codeinterpreterapi/session.py CHANGED
@@ -9,7 +9,7 @@ from langchain.prompts.chat import MessagesPlaceholder
9
  from langchain.agents import AgentExecutor, BaseSingleActionAgent
10
  from langchain.memory import ConversationBufferMemory
11
 
12
- from codeinterpreterapi.schema import CodeInterpreterResponse, CodeInput, File, UserRequest # type: ignore
13
  from codeinterpreterapi.config import settings
14
  from codeinterpreterapi.chains.functions_agent import OpenAIFunctionsAgent
15
  from codeinterpreterapi.prompts import code_interpreter_system_message
@@ -18,8 +18,7 @@ from codeinterpreterapi.chains.modifications_check import get_file_modifications
18
  from codeinterpreterapi.chains.remove_download_link import remove_download_link
19
 
20
 
21
- class CodeInterpreterSession():
22
-
23
  def __init__(self, model=None, openai_api_key=None) -> None:
24
  self.codebox = CodeBox()
25
  self.tools: list[StructuredTool] = self._tools()
@@ -27,46 +26,46 @@ class CodeInterpreterSession():
27
  self.agent_executor: AgentExecutor = self._agent_executor()
28
  self.input_files: list[File] = []
29
  self.output_files: list[File] = []
30
-
31
  async def _init(self) -> None:
32
  await self.codebox.astart()
33
-
34
  async def _close(self) -> None:
35
  await self.codebox.astop()
36
-
37
  def _tools(self) -> list[StructuredTool]:
38
  return [
39
  StructuredTool(
40
- name = "python",
41
- description =
42
- # TODO: variables as context to the agent
43
- # TODO: current files as context to the agent
44
- "Input a string of code to a python interpreter (jupyter kernel). "
45
- "Variables are preserved between runs. ",
46
- func = self.codebox.run,
47
- coroutine = self.arun_handler,
48
- args_schema = CodeInput,
49
  ),
50
  ]
51
-
52
  def _llm(self, model: str | None, openai_api_key: str | None) -> BaseChatModel:
53
  if model is None:
54
  model = "gpt-4"
55
-
56
  if openai_api_key is None:
57
  if settings.OPENAI_API_KEY is None:
58
  raise ValueError("OpenAI API key missing.")
59
  else:
60
  openai_api_key = settings.OPENAI_API_KEY
61
-
62
  return ChatOpenAI(
63
  temperature=0.03,
64
- model=model,
65
  openai_api_key=openai_api_key,
66
  max_retries=3,
67
- request_timeout=60*3,
68
  ) # type: ignore
69
-
70
  def _agent(self) -> BaseSingleActionAgent:
71
  return OpenAIFunctionsAgent.from_llm_and_tools(
72
  llm=self.llm,
@@ -74,7 +73,7 @@ class CodeInterpreterSession():
74
  system_message=code_interpreter_system_message,
75
  extra_prompt_messages=[MessagesPlaceholder(variable_name="memory")],
76
  )
77
-
78
  def _agent_executor(self) -> AgentExecutor:
79
  return AgentExecutor.from_agent_and_tools(
80
  agent=self._agent(),
@@ -84,23 +83,23 @@ class CodeInterpreterSession():
84
  verbose=settings.VERBOSE,
85
  memory=ConversationBufferMemory(memory_key="memory", return_messages=True),
86
  )
87
-
88
  async def show_code(self, code: str) -> None:
89
- """ Callback function to show code to the user. """
90
  if settings.VERBOSE:
91
  print(code)
92
-
93
- def run_handler(self, code: str):
94
  raise NotImplementedError("Use arun_handler for now.")
95
-
96
  async def arun_handler(self, code: str):
97
- """ Run code in container and send the output to the user """
98
  # TODO: upload files
99
  output: CodeBoxOutput = await self.codebox.arun(code)
100
-
101
  if not isinstance(output.content, str):
102
  raise TypeError("Expected output.content to be a string.")
103
-
104
  if output.type == "image/png":
105
  filename = f"image-{uuid.uuid4()}.png"
106
  file_buffer = BytesIO(base64.b64decode(output.content))
@@ -108,58 +107,64 @@ class CodeInterpreterSession():
108
  self.output_files.append(File(name=filename, content=file_buffer.read()))
109
  return f"Image {filename} got send to the user."
110
 
111
- elif output.type == "error":
112
  if "ModuleNotFoundError" in output.content:
113
- if package := re.search(r"ModuleNotFoundError: No module named '(.*)'", output.content):
 
 
114
  await self.codebox.ainstall(package.group(1))
115
  return f"{package.group(1)} was missing but got installed now. Please try again."
116
  # TODO: preanalyze error to optimize next code generation
117
  print("Error:", output.content)
118
-
119
- elif (modifications := await get_file_modifications(code, self.llm)):
120
  for filename in modifications:
121
- if filename in [file.name for file in self.input_files]:
122
  continue
123
  fileb = await self.codebox.adownload(filename)
124
- if not fileb.content:
125
  continue
126
  file_buffer = BytesIO(fileb.content)
127
  file_buffer.name = filename
128
- self.output_files.append(File(name=filename, content=file_buffer.read()))
129
-
 
 
130
  return output.content
131
-
132
  async def input_handler(self, request: UserRequest):
133
- if not request.files:
134
  return
135
  if not request.content:
136
- request.content = "I uploaded, just text me back and confirm that you got the file(s)."
 
 
137
  request.content += "\n**The user uploaded the following files: **\n"
138
  for file in request.files:
139
  self.input_files.append(file)
140
  request.content += f"[Attachment: {file.name}]\n"
141
  await self.codebox.aupload(file.name, file.content)
142
  request.content += "**File(s) are now available in the cwd. **\n"
143
-
144
  async def output_handler(self, final_response: str) -> CodeInterpreterResponse:
145
- """ Embed images in the response """
146
  for file in self.output_files:
147
  if str(file.name) in final_response:
148
  # rm ![Any](file.name) from the response
149
  final_response = re.sub(rf"\n\n!\[.*\]\(.*\)", "", final_response)
150
-
151
  if self.output_files and re.search(rf"\n\[.*\]\(.*\)", final_response):
152
  final_response = await remove_download_link(final_response, self.llm)
153
-
154
  return CodeInterpreterResponse(content=final_response, files=self.output_files)
155
-
156
  async def generate_response(
157
- self,
158
- user_msg: str,
159
  files: list[File] = [],
160
  detailed_error: bool = False,
161
  ) -> CodeInterpreterResponse:
162
- """ Generate a Code Interpreter response based on the user's input."""
163
  user_request = UserRequest(content=user_msg, files=files)
164
  try:
165
  await self.input_handler(user_request)
@@ -168,20 +173,21 @@ class CodeInterpreterSession():
168
  except Exception as e:
169
  if settings.VERBOSE:
170
  import traceback
 
171
  traceback.print_exc()
172
  if detailed_error:
173
- return CodeInterpreterResponse(content=
174
- f"Error in CodeInterpreterSession: {e.__class__.__name__} - {e}"
175
  )
176
  else:
177
- return CodeInterpreterResponse(content=
178
- "Sorry, something went while generating your response."
179
  "Please try again or restart the session."
180
  )
181
 
182
  async def __aenter__(self) -> "CodeInterpreterSession":
183
  await self._init()
184
  return self
185
-
186
  async def __aexit__(self, exc_type, exc_value, traceback) -> None:
187
  await self._close()
 
9
  from langchain.agents import AgentExecutor, BaseSingleActionAgent
10
  from langchain.memory import ConversationBufferMemory
11
 
12
+ from codeinterpreterapi.schema import CodeInterpreterResponse, CodeInput, File, UserRequest # type: ignore
13
  from codeinterpreterapi.config import settings
14
  from codeinterpreterapi.chains.functions_agent import OpenAIFunctionsAgent
15
  from codeinterpreterapi.prompts import code_interpreter_system_message
 
18
  from codeinterpreterapi.chains.remove_download_link import remove_download_link
19
 
20
 
21
+ class CodeInterpreterSession:
 
22
  def __init__(self, model=None, openai_api_key=None) -> None:
23
  self.codebox = CodeBox()
24
  self.tools: list[StructuredTool] = self._tools()
 
26
  self.agent_executor: AgentExecutor = self._agent_executor()
27
  self.input_files: list[File] = []
28
  self.output_files: list[File] = []
29
+
30
  async def _init(self) -> None:
31
  await self.codebox.astart()
32
+
33
  async def _close(self) -> None:
34
  await self.codebox.astop()
35
+
36
  def _tools(self) -> list[StructuredTool]:
37
  return [
38
  StructuredTool(
39
+ name="python",
40
+ description=
41
+ # TODO: variables as context to the agent
42
+ # TODO: current files as context to the agent
43
+ "Input a string of code to a python interpreter (jupyter kernel). "
44
+ "Variables are preserved between runs. ",
45
+ func=self.codebox.run,
46
+ coroutine=self.arun_handler,
47
+ args_schema=CodeInput,
48
  ),
49
  ]
50
+
51
  def _llm(self, model: str | None, openai_api_key: str | None) -> BaseChatModel:
52
  if model is None:
53
  model = "gpt-4"
54
+
55
  if openai_api_key is None:
56
  if settings.OPENAI_API_KEY is None:
57
  raise ValueError("OpenAI API key missing.")
58
  else:
59
  openai_api_key = settings.OPENAI_API_KEY
60
+
61
  return ChatOpenAI(
62
  temperature=0.03,
63
+ model=model,
64
  openai_api_key=openai_api_key,
65
  max_retries=3,
66
+ request_timeout=60 * 3,
67
  ) # type: ignore
68
+
69
  def _agent(self) -> BaseSingleActionAgent:
70
  return OpenAIFunctionsAgent.from_llm_and_tools(
71
  llm=self.llm,
 
73
  system_message=code_interpreter_system_message,
74
  extra_prompt_messages=[MessagesPlaceholder(variable_name="memory")],
75
  )
76
+
77
  def _agent_executor(self) -> AgentExecutor:
78
  return AgentExecutor.from_agent_and_tools(
79
  agent=self._agent(),
 
83
  verbose=settings.VERBOSE,
84
  memory=ConversationBufferMemory(memory_key="memory", return_messages=True),
85
  )
86
+
87
  async def show_code(self, code: str) -> None:
88
+ """Callback function to show code to the user."""
89
  if settings.VERBOSE:
90
  print(code)
91
+
92
+ def run_handler(self, code: str):
93
  raise NotImplementedError("Use arun_handler for now.")
94
+
95
  async def arun_handler(self, code: str):
96
+ """Run code in container and send the output to the user"""
97
  # TODO: upload files
98
  output: CodeBoxOutput = await self.codebox.arun(code)
99
+
100
  if not isinstance(output.content, str):
101
  raise TypeError("Expected output.content to be a string.")
102
+
103
  if output.type == "image/png":
104
  filename = f"image-{uuid.uuid4()}.png"
105
  file_buffer = BytesIO(base64.b64decode(output.content))
 
107
  self.output_files.append(File(name=filename, content=file_buffer.read()))
108
  return f"Image {filename} got send to the user."
109
 
110
+ elif output.type == "error":
111
  if "ModuleNotFoundError" in output.content:
112
+ if package := re.search(
113
+ r"ModuleNotFoundError: No module named '(.*)'", output.content
114
+ ):
115
  await self.codebox.ainstall(package.group(1))
116
  return f"{package.group(1)} was missing but got installed now. Please try again."
117
  # TODO: preanalyze error to optimize next code generation
118
  print("Error:", output.content)
119
+
120
+ elif modifications := await get_file_modifications(code, self.llm):
121
  for filename in modifications:
122
+ if filename in [file.name for file in self.input_files]:
123
  continue
124
  fileb = await self.codebox.adownload(filename)
125
+ if not fileb.content:
126
  continue
127
  file_buffer = BytesIO(fileb.content)
128
  file_buffer.name = filename
129
+ self.output_files.append(
130
+ File(name=filename, content=file_buffer.read())
131
+ )
132
+
133
  return output.content
134
+
135
  async def input_handler(self, request: UserRequest):
136
+ if not request.files:
137
  return
138
  if not request.content:
139
+ request.content = (
140
+ "I uploaded, just text me back and confirm that you got the file(s)."
141
+ )
142
  request.content += "\n**The user uploaded the following files: **\n"
143
  for file in request.files:
144
  self.input_files.append(file)
145
  request.content += f"[Attachment: {file.name}]\n"
146
  await self.codebox.aupload(file.name, file.content)
147
  request.content += "**File(s) are now available in the cwd. **\n"
148
+
149
  async def output_handler(self, final_response: str) -> CodeInterpreterResponse:
150
+ """Embed images in the response"""
151
  for file in self.output_files:
152
  if str(file.name) in final_response:
153
  # rm ![Any](file.name) from the response
154
  final_response = re.sub(rf"\n\n!\[.*\]\(.*\)", "", final_response)
155
+
156
  if self.output_files and re.search(rf"\n\[.*\]\(.*\)", final_response):
157
  final_response = await remove_download_link(final_response, self.llm)
158
+
159
  return CodeInterpreterResponse(content=final_response, files=self.output_files)
160
+
161
  async def generate_response(
162
+ self,
163
+ user_msg: str,
164
  files: list[File] = [],
165
  detailed_error: bool = False,
166
  ) -> CodeInterpreterResponse:
167
+ """Generate a Code Interpreter response based on the user's input."""
168
  user_request = UserRequest(content=user_msg, files=files)
169
  try:
170
  await self.input_handler(user_request)
 
173
  except Exception as e:
174
  if settings.VERBOSE:
175
  import traceback
176
+
177
  traceback.print_exc()
178
  if detailed_error:
179
+ return CodeInterpreterResponse(
180
+ content=f"Error in CodeInterpreterSession: {e.__class__.__name__} - {e}"
181
  )
182
  else:
183
+ return CodeInterpreterResponse(
184
+ content="Sorry, something went while generating your response."
185
  "Please try again or restart the session."
186
  )
187
 
188
  async def __aenter__(self) -> "CodeInterpreterSession":
189
  await self._init()
190
  return self
191
+
192
  async def __aexit__(self, exc_type, exc_value, traceback) -> None:
193
  await self._close()
examples/convert_file.py CHANGED
@@ -8,13 +8,14 @@ async def main():
8
  files = [
9
  File.from_path("examples/iris.csv"),
10
  ]
11
-
12
  output = await session.generate_response(user_request, files=files)
13
  file = output.files[0]
14
-
15
- file.save("examples/iris.xlsx")
16
 
17
 
18
  if __name__ == "__main__":
19
  import asyncio
 
20
  asyncio.run(main())
 
8
  files = [
9
  File.from_path("examples/iris.csv"),
10
  ]
11
+
12
  output = await session.generate_response(user_request, files=files)
13
  file = output.files[0]
14
+
15
+ file.save("examples/iris.xlsx")
16
 
17
 
18
  if __name__ == "__main__":
19
  import asyncio
20
+
21
  asyncio.run(main())
examples/plot_sin_wave.py CHANGED
@@ -4,12 +4,14 @@ from codeinterpreterapi import CodeInterpreterSession
4
  async def main():
5
  async with CodeInterpreterSession() as session:
6
  user_request = "Plot a sin wave and show it to me."
7
-
8
  output = await session.generate_response(user_request)
9
-
10
  file = output.files[0]
11
  file.show_image()
12
 
 
13
  if __name__ == "__main__":
14
  import asyncio
 
15
  asyncio.run(main())
 
4
  async def main():
5
  async with CodeInterpreterSession() as session:
6
  user_request = "Plot a sin wave and show it to me."
7
+
8
  output = await session.generate_response(user_request)
9
+
10
  file = output.files[0]
11
  file.show_image()
12
 
13
+
14
  if __name__ == "__main__":
15
  import asyncio
16
+
17
  asyncio.run(main())
examples/show_bitcoin_chart.py CHANGED
@@ -6,9 +6,9 @@ async def main():
6
  async with CodeInterpreterSession() as session:
7
  currentdate = datetime.now().strftime("%Y-%m-%d")
8
  user_request = f"Plot the bitcoin chart of 2023 YTD (today is {currentdate})"
9
-
10
  output = await session.generate_response(user_request)
11
-
12
  file = output.files[0]
13
  file.show_image()
14
 
@@ -21,4 +21,5 @@ async def main():
21
 
22
  if __name__ == "__main__":
23
  import asyncio
 
24
  asyncio.run(main())
 
6
  async with CodeInterpreterSession() as session:
7
  currentdate = datetime.now().strftime("%Y-%m-%d")
8
  user_request = f"Plot the bitcoin chart of 2023 YTD (today is {currentdate})"
9
+
10
  output = await session.generate_response(user_request)
11
+
12
  file = output.files[0]
13
  file.show_image()
14
 
 
21
 
22
  if __name__ == "__main__":
23
  import asyncio
24
+
25
  asyncio.run(main())