Shroominic commited on
Commit
000213d
·
1 Parent(s): d7fa612

🗂️ move actual prompts out of chains

Browse files
codeinterpreterapi/chains/modifications_check.py CHANGED
@@ -1,51 +1,11 @@
1
  import json
2
- from json import JSONDecodeError
3
  from typing import List
4
 
5
  from langchain.base_language import BaseLanguageModel
6
  from langchain.chat_models.openai import ChatOpenAI
7
- from langchain.prompts.chat import (
8
- ChatPromptTemplate,
9
- HumanMessagePromptTemplate,
10
- )
11
- from langchain.schema import (
12
- AIMessage,
13
- OutputParserException,
14
- SystemMessage,
15
- )
16
-
17
-
18
- prompt = ChatPromptTemplate(
19
- input_variables=["code"],
20
- messages=[
21
- SystemMessage(
22
- content="The user will input some code and you will need to determine if the code makes any changes to the file system. \n"
23
- "With changes it means creating new files or modifying exsisting ones.\n"
24
- "Answer with a function call `determine_modifications` and list them inside.\n"
25
- "If the code does not make any changes to the file system, still answer with the function call but return an empty list.\n",
26
- ),
27
- HumanMessagePromptTemplate.from_template("{code}"),
28
- ],
29
- )
30
-
31
- functions = [
32
- {
33
- "name": "determine_modifications",
34
- "description": "Based on code of the user determine if the code makes any changes to the file system. \n"
35
- "With changes it means creating new files or modifying exsisting ones.\n",
36
- "parameters": {
37
- "type": "object",
38
- "properties": {
39
- "modifications": {
40
- "type": "array",
41
- "items": {"type": "string"},
42
- "description": "The filenames that are modified by the code.",
43
- },
44
- },
45
- "required": ["modifications"],
46
- },
47
- }
48
- ]
49
 
50
 
51
  async def get_file_modifications(
@@ -55,8 +15,8 @@ async def get_file_modifications(
55
  ) -> List[str] | None:
56
  if retry < 1:
57
  return None
58
- messages = prompt.format_prompt(code=code).to_messages()
59
- message = await llm.apredict_messages(messages, functions=functions)
60
 
61
  if not isinstance(message, AIMessage):
62
  raise OutputParserException("Expected an AIMessage")
@@ -71,7 +31,7 @@ async def get_file_modifications(
71
 
72
 
73
  async def test():
74
- llm = ChatOpenAI(model="gpt-3.5-turbo-0613") # type: ignore
75
 
76
  code = """
77
  import matplotlib.pyplot as plt
@@ -87,17 +47,12 @@ async def test():
87
  plt.show()
88
  """
89
 
90
- code2 = "import pandas as pd\n\n# Read the Excel file\ndata = pd.read_excel('Iris.xlsx')\n\n# Convert the data to CSV\ndata.to_csv('Iris.csv', index=False)"
91
-
92
- modifications = await get_file_modifications(code2, llm)
93
-
94
- print(modifications)
95
 
96
 
97
  if __name__ == "__main__":
98
  import asyncio
99
  from dotenv import load_dotenv
100
-
101
  load_dotenv()
102
 
103
  asyncio.run(test())
 
1
  import json
 
2
  from typing import List
3
 
4
  from langchain.base_language import BaseLanguageModel
5
  from langchain.chat_models.openai import ChatOpenAI
6
+ from langchain.schema import AIMessage, OutputParserException
7
+
8
+ from codeinterpreterapi.prompts import determine_modifications_function, determine_modifications_prompt
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
 
10
 
11
  async def get_file_modifications(
 
15
  ) -> List[str] | None:
16
  if retry < 1:
17
  return None
18
+ messages = determine_modifications_prompt.format_prompt(code=code).to_messages()
19
+ message = await llm.apredict_messages(messages, functions=[determine_modifications_function])
20
 
21
  if not isinstance(message, AIMessage):
22
  raise OutputParserException("Expected an AIMessage")
 
31
 
32
 
33
  async def test():
34
+ llm = ChatOpenAI(model="gpt-3.5") # type: ignore
35
 
36
  code = """
37
  import matplotlib.pyplot as plt
 
47
  plt.show()
48
  """
49
 
50
+ print(await get_file_modifications(code, llm))
 
 
 
 
51
 
52
 
53
  if __name__ == "__main__":
54
  import asyncio
55
  from dotenv import load_dotenv
 
56
  load_dotenv()
57
 
58
  asyncio.run(test())
codeinterpreterapi/chains/remove_download_link.py CHANGED
@@ -1,39 +1,15 @@
1
  from langchain.base_language import BaseLanguageModel
2
  from langchain.chat_models.openai import ChatOpenAI
3
- from langchain.prompts.chat import (
4
- ChatPromptTemplate,
5
- HumanMessagePromptTemplate,
6
- )
7
- from langchain.schema import (
8
- AIMessage,
9
- OutputParserException,
10
- SystemMessage,
11
- HumanMessage,
12
- )
13
-
14
-
15
- prompt = ChatPromptTemplate(
16
- input_variables=["input_response"],
17
- messages=[
18
- SystemMessage(
19
- content="The user will send you a response and you need to remove the download link from it.\n"
20
- "Reformat the remaining message so no whitespace or half sentences are still there.\n"
21
- "If the response does not contain a download link, return the response as is.\n"
22
- ),
23
- HumanMessage(
24
- content="The dataset has been successfully converted to CSV format. You can download the converted file [here](sandbox:/Iris.csv)."
25
- ),
26
- AIMessage(content="The dataset has been successfully converted to CSV format."),
27
- HumanMessagePromptTemplate.from_template("{input_response}"),
28
- ],
29
- )
30
 
31
 
32
  async def remove_download_link(
33
  input_response: str,
34
  llm: BaseLanguageModel,
35
  ) -> str:
36
- messages = prompt.format_prompt(input_response=input_response).to_messages()
37
  message = await llm.apredict_messages(messages)
38
 
39
  if not isinstance(message, AIMessage):
@@ -47,15 +23,12 @@ async def test():
47
 
48
  example = "I have created the plot to your dataset.\n\nLink to the file [here](sandbox:/plot.png)."
49
 
50
- modifications = await remove_download_link(example, llm)
51
-
52
- print(modifications)
53
 
54
 
55
  if __name__ == "__main__":
56
  import asyncio
57
- import dotenv
58
-
59
- dotenv.load_dotenv()
60
 
61
  asyncio.run(test())
 
1
  from langchain.base_language import BaseLanguageModel
2
  from langchain.chat_models.openai import ChatOpenAI
3
+ from langchain.schema import AIMessage, OutputParserException
4
+
5
+ from codeinterpreterapi.prompts import remove_dl_link_prompt
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
 
7
 
8
  async def remove_download_link(
9
  input_response: str,
10
  llm: BaseLanguageModel,
11
  ) -> str:
12
+ messages = remove_dl_link_prompt.format_prompt(input_response=input_response).to_messages()
13
  message = await llm.apredict_messages(messages)
14
 
15
  if not isinstance(message, AIMessage):
 
23
 
24
  example = "I have created the plot to your dataset.\n\nLink to the file [here](sandbox:/plot.png)."
25
 
26
+ print(await remove_download_link(example, llm))
 
 
27
 
28
 
29
  if __name__ == "__main__":
30
  import asyncio
31
+ from dotenv import load_dotenv
32
+ load_dotenv()
 
33
 
34
  asyncio.run(test())
codeinterpreterapi/prompts/__init__.py CHANGED
@@ -1 +1,3 @@
1
  from .system_message import system_message as code_interpreter_system_message
 
 
 
1
  from .system_message import system_message as code_interpreter_system_message
2
+ from .modifications_check import determine_modifications_function, determine_modifications_prompt
3
+ from .remove_dl_link import remove_dl_link_prompt
codeinterpreterapi/prompts/modifications_check.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ from langchain.schema import SystemMessage
3
+ from langchain.prompts.chat import ChatPromptTemplate, HumanMessagePromptTemplate
4
+
5
+
6
+ determine_modifications_prompt = ChatPromptTemplate(
7
+ input_variables=["code"],
8
+ messages=[
9
+ SystemMessage(
10
+ content="The user will input some code and you will need to determine if the code makes any changes to the file system. \n"
11
+ "With changes it means creating new files or modifying exsisting ones.\n"
12
+ "Answer with a function call `determine_modifications` and list them inside.\n"
13
+ "If the code does not make any changes to the file system, still answer with the function call but return an empty list.\n",
14
+ ),
15
+ HumanMessagePromptTemplate.from_template("{code}"),
16
+ ],
17
+ )
18
+
19
+
20
+ determine_modifications_function = {
21
+ "name": "determine_modifications",
22
+ "description": "Based on code of the user determine if the code makes any changes to the file system. \n"
23
+ "With changes it means creating new files or modifying exsisting ones.\n",
24
+ "parameters": {
25
+ "type": "object",
26
+ "properties": {
27
+ "modifications": {
28
+ "type": "array",
29
+ "items": {"type": "string"},
30
+ "description": "The filenames that are modified by the code.",
31
+ },
32
+ },
33
+ "required": ["modifications"],
34
+ },
35
+ }
codeinterpreterapi/prompts/remove_dl_link.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from langchain.prompts.chat import (
2
+ ChatPromptTemplate,
3
+ HumanMessagePromptTemplate,
4
+ )
5
+ from langchain.schema import (
6
+ AIMessage,
7
+ SystemMessage,
8
+ HumanMessage,
9
+ )
10
+
11
+
12
+ remove_dl_link_prompt = ChatPromptTemplate(
13
+ input_variables=["input_response"],
14
+ messages=[
15
+ SystemMessage(
16
+ content="The user will send you a response and you need to remove the download link from it.\n"
17
+ "Reformat the remaining message so no whitespace or half sentences are still there.\n"
18
+ "If the response does not contain a download link, return the response as is.\n"
19
+ ),
20
+ HumanMessage(
21
+ content="The dataset has been successfully converted to CSV format. You can download the converted file [here](sandbox:/Iris.csv)."
22
+ ),
23
+ AIMessage(content="The dataset has been successfully converted to CSV format."),
24
+ HumanMessagePromptTemplate.from_template("{input_response}"),
25
+ ],
26
+ )