Technologic101 commited on
Commit
ff20ed1
·
1 Parent(s): 8515119

task: refactors as LangGraph

Browse files
Files changed (4) hide show
  1. src/graph.ipynb +0 -317
  2. src/graph.py +38 -5
  3. src/nodes/analyzer.py +1 -0
  4. src/nodes/designer.py +122 -0
src/graph.ipynb DELETED
@@ -1,317 +0,0 @@
1
- {
2
- "cells": [
3
- {
4
- "cell_type": "code",
5
- "execution_count": 1,
6
- "metadata": {},
7
- "outputs": [],
8
- "source": [
9
- "import os\n",
10
- "import getpass\n",
11
- "\n",
12
- "\n",
13
- "os.environ[\"OPENAI_API_KEY\"] = getpass.getpass(\"OpenAI API Key:\")"
14
- ]
15
- },
16
- {
17
- "cell_type": "markdown",
18
- "metadata": {},
19
- "source": [
20
- "Add tools later"
21
- ]
22
- },
23
- {
24
- "cell_type": "code",
25
- "execution_count": 2,
26
- "metadata": {},
27
- "outputs": [
28
- {
29
- "name": "stdout",
30
- "output_type": "stream",
31
- "text": [
32
- "Loaded 82 design documents\n",
33
- "Testing RAG retriever with requirements:\n",
34
- "\n",
35
- "Retrieved Designs:\n",
36
- "----------------------------------------\n",
37
- "Generated query: \"vintage classic easy to use grandmother love design\"\n",
38
- "Design 180:\n",
39
- "Description: This design employs a vintage newspaper aesthetic with a classic serif typography that evokes an old-world charm, utilizing sepia-toned paper backgrounds to enhance its nostalgic feel. The layout is text-heavy with a deliberate obfuscation, reflecting a layered collage effect. Its balanced placement keeps the focus central, inviting closer inspection and interaction.\n",
40
- "Categories: Vintage, Nostalgic, Typography, Collage, Editorial\n",
41
- "Visual Characteristics: Sepia tone, Serif typography, Textured background, Layered elements, Central focus\n",
42
- "URL: https://csszengarden.com/180\n",
43
- "\n",
44
- "Design 182:\n",
45
- "Description: The design creatively utilizes a retro theme with vinyl records as the prominent visual element to evoke a sense of nostalgia and classic style, complemented by a muted green color palette that brings harmony and balance. Handwritten and vintage-style typography enhance the retro aesthetic, while background illustrations and decorative elements like stars add whimsy and depth to the composition.\n",
46
- "Categories: Retro, Nostalgic, Music-themed, Decorative, Vintage\n",
47
- "Visual Characteristics: Vinyl Records, Muted Green Palette, Handwritten Typography, Background Illustrations, Decorative Elements\n",
48
- "URL: https://csszengarden.com/182\n",
49
- "\n",
50
- "Design 194:\n",
51
- "Description: This design exudes a minimalist elegance with a muted, earthy color palette and a clean layout, embodying a sense of calm and sophistication. The subtle use of textures and classic serif typography enhances the refined aesthetic, while the centered alignment and generous spacing contribute to a relaxed readability. The incorporation of a delicate floral illustration adds a touch of organic charm, making the design feel both timeless and inviting.\n",
52
- "Categories: Minimalism, Elegant, Organic, Sophisticated, Classic\n",
53
- "Visual Characteristics: Muted Color Palette, Serif Typography, Centered Layout, Generous Spacing, Floral Illustration\n",
54
- "URL: https://csszengarden.com/194\n",
55
- "\n",
56
- "Design 212:\n",
57
- "Description: The design features a retro aesthetic using a muted color palette of browns and creams, creating a nostalgic and vintage feel. The asymmetrical layout and bold typography contribute to the visual hierarchy, guiding the viewer through the content effortlessly. Illustrations with a mid-century modern style add character, merging traditional design elements with contemporary functionality.\n",
58
- "Categories: Retro, Typography, Illustration, Vintage Style, Educational\n",
59
- "Visual Characteristics: Muted Color Palette, Asymmetrical Layout, Bold Typography, Retro Illustrations, Functional Design\n",
60
- "URL: https://csszengarden.com/212\n"
61
- ]
62
- }
63
- ],
64
- "source": [
65
- "#from tools.design_retriever import DesignRetrieverTool\n",
66
- "from chains.design_rag import DesignRAG\n",
67
- "\n",
68
- "# Initialize DesignRAG and create the tool\n",
69
- "design_rag = DesignRAG()\n",
70
- "#design_retriever = DesignRetrieverTool(rag=design_rag)\n",
71
- "\n",
72
- "test_requirements = {\n",
73
- " \"I want a design that is vintage and classic, something easy to use that a grandmother would love\"\n",
74
- " }\n",
75
- "\n",
76
- "# Test the retriever\n",
77
- "async def test_rag():\n",
78
- " print(\"Testing RAG retriever with requirements:\")\n",
79
- " print(\"\\nRetrieved Designs:\")\n",
80
- " print(\"----------------------------------------\")\n",
81
- " \n",
82
- " results = await design_rag.query_similar_designs(test_requirements, 2)\n",
83
- " print(results)\n",
84
- "\n",
85
- "# Run the test\n",
86
- "await test_rag()\n"
87
- ]
88
- },
89
- {
90
- "cell_type": "markdown",
91
- "metadata": {},
92
- "source": [
93
- "Pick a model good for chat and tools"
94
- ]
95
- },
96
- {
97
- "cell_type": "code",
98
- "execution_count": null,
99
- "metadata": {},
100
- "outputs": [
101
- {
102
- "data": {
103
- "text/plain": [
104
- "RunnableBinding(bound=ChatOpenAI(client=<openai.resources.chat.completions.completions.Completions object at 0x1245518d0>, async_client=<openai.resources.chat.completions.completions.AsyncCompletions object at 0x124548e50>, root_client=<openai.OpenAI object at 0x1108f9310>, root_async_client=<openai.AsyncOpenAI object at 0x115d92090>, model_name='gpt-4o', temperature=0.0, model_kwargs={}, openai_api_key=SecretStr('**********'), streaming=True), kwargs={'tools': [{'type': 'function', 'function': {'name': 'design_retriever', 'description': 'Retrieves similar designs based on style requirements', 'parameters': {'properties': {'requirements': {'type': 'object'}, 'num_examples': {'default': 3, 'type': 'integer'}}, 'required': ['requirements'], 'type': 'object'}}}]}, config={}, config_factories=[])"
105
- ]
106
- },
107
- "execution_count": 11,
108
- "metadata": {},
109
- "output_type": "execute_result"
110
- }
111
- ],
112
- "source": [
113
- "from langchain_openai import ChatOpenAI\n",
114
- "\n",
115
- "model = ChatOpenAI(\n",
116
- " model=\"gpt-4o\", \n",
117
- " temperature=0,\n",
118
- " streaming=True\n",
119
- ")\n",
120
- "\n",
121
- "model.bind_tools(tool_belt)"
122
- ]
123
- },
124
- {
125
- "cell_type": "markdown",
126
- "metadata": {},
127
- "source": [
128
- "Initialize state\n"
129
- ]
130
- },
131
- {
132
- "cell_type": "code",
133
- "execution_count": 12,
134
- "metadata": {},
135
- "outputs": [],
136
- "source": [
137
- "from typing import TypedDict, Annotated\n",
138
- "from langgraph.graph.message import add_messages\n",
139
- "\n",
140
- "class AgentState(TypedDict):\n",
141
- " messages: Annotated[list, add_messages]"
142
- ]
143
- },
144
- {
145
- "cell_type": "markdown",
146
- "metadata": {},
147
- "source": [
148
- "Set up the nodes and graph\n"
149
- ]
150
- },
151
- {
152
- "cell_type": "code",
153
- "execution_count": 16,
154
- "metadata": {},
155
- "outputs": [],
156
- "source": [
157
- "from langgraph.prebuilt import ToolNode\n",
158
- "from langgraph.graph import StateGraph, END\n",
159
- "from langchain_core.messages import HumanMessage, SystemMessage\n",
160
- "\n",
161
- "system_message = SystemMessage(content=\"\"\"You are a helpful design assistant that can retrieve and analyze design examples. \n",
162
- "When a user describes their design preferences or requirements, use the design_retriever tool to find relevant examples.\n",
163
- "\n",
164
- "Always use the design_retriever tool when:\n",
165
- "- A user describes specific design requirements\n",
166
- "- A user asks to see similar designs\n",
167
- "- You need to find design inspiration based on user preferences\n",
168
- "\n",
169
- "Format the requirements as a dictionary with these keys:\n",
170
- "- style_description: Brief description of desired visual style\n",
171
- "- key_elements: List of important visual elements\n",
172
- "- color_scheme: Description of colors\n",
173
- "- layout_preferences: Layout requirements\n",
174
- "- mood: Desired emotional impact\n",
175
- "\"\"\")\n",
176
- "\n",
177
- "def call_model(state):\n",
178
- " messages = [system_message] + state[\"messages\"]\n",
179
- " response = model.invoke(messages)\n",
180
- " return {\"messages\" : [response]}\n",
181
- "\n",
182
- "tool_node = ToolNode(tool_belt)\n",
183
- "\n",
184
- "uncompiled_graph = StateGraph(AgentState)\n",
185
- "\n",
186
- "uncompiled_graph.add_node(\"agent\", call_model)\n",
187
- "uncompiled_graph.add_node(\"action\", tool_node)\n",
188
- "uncompiled_graph.set_entry_point(\"agent\")\n",
189
- "\n",
190
- "\n",
191
- "def should_continue(state):\n",
192
- " last_message = state[\"messages\"][-1]\n",
193
- "\n",
194
- " if last_message.tool_calls:\n",
195
- " return \"action\"\n",
196
- "\n",
197
- " return END\n",
198
- "\n",
199
- "uncompiled_graph.add_conditional_edges(\n",
200
- " \"agent\",\n",
201
- " should_continue\n",
202
- ")\n",
203
- "uncompiled_graph.add_edge(\"action\", \"agent\")\n",
204
- "\n",
205
- "graph = uncompiled_graph.compile()\n",
206
- "\n",
207
- "#formatted chain\n",
208
- "\n",
209
- "def convert_inputs(input_object):\n",
210
- " return {\"messages\" : [HumanMessage(content=input_object[\"question\"])]}\n",
211
- "\n",
212
- "def parse_output(input_state):\n",
213
- " return input_state[\"messages\"][-1].content\n",
214
- "\n",
215
- "graph_chain = convert_inputs | graph | parse_output\n",
216
- "\n"
217
- ]
218
- },
219
- {
220
- "cell_type": "markdown",
221
- "metadata": {},
222
- "source": [
223
- "Try it out!"
224
- ]
225
- },
226
- {
227
- "cell_type": "code",
228
- "execution_count": null,
229
- "metadata": {},
230
- "outputs": [
231
- {
232
- "name": "stdout",
233
- "output_type": "stream",
234
- "text": [
235
- "Receiving update from node: 'agent'\n",
236
- "[AIMessage(content=\"Hello! I'm here and ready to help you with any design needs or questions you might have. How can I assist you today?\", additional_kwargs={}, response_metadata={'finish_reason': 'stop', 'model_name': 'gpt-4o-2024-08-06', 'system_fingerprint': 'fp_f9f4fb6dbf'}, id='run-4edce0b5-fdec-4d5d-a4a6-92430faca51a-0')]\n",
237
- "\n",
238
- "\n",
239
- "\n"
240
- ]
241
- }
242
- ],
243
- "source": [
244
- "from langchain_core.messages import HumanMessage\n",
245
- "\n",
246
- "async for chunk in graph.astream({\"messages\" : [HumanMessage(content=\"Hello, how are you?\")]}, stream_mode=\"updates\"):\n",
247
- " for node, values in chunk.items():\n",
248
- " print(f\"Receiving update from node: '{node}'\")\n",
249
- " print(values[\"messages\"])\n",
250
- " print(\"\\n\\n\")"
251
- ]
252
- },
253
- {
254
- "cell_type": "markdown",
255
- "metadata": {},
256
- "source": [
257
- "Let's see if the RAG tool works."
258
- ]
259
- },
260
- {
261
- "cell_type": "code",
262
- "execution_count": 18,
263
- "metadata": {},
264
- "outputs": [
265
- {
266
- "name": "stdout",
267
- "output_type": "stream",
268
- "text": [
269
- "Receiving update from node: 'agent'\n",
270
- "[AIMessage(content=\"To find a design that matches your description, I'll use the design_retriever tool. Here are the requirements based on your description:\\n\\n- style_description: Monochromatic with subtle accents\\n- key_elements: Grid-based layout, clear hierarchy\\n- color_scheme: Monochromatic with subtle accent colors\\n- layout_preferences: Grid-based\\n- mood: Professional and sophisticated\\n\\nLet's find some examples for you.\", additional_kwargs={}, response_metadata={'finish_reason': 'stop', 'model_name': 'gpt-4o-2024-08-06', 'system_fingerprint': 'fp_f9f4fb6dbf'}, id='run-8fa2e4af-671c-4c75-82fd-a7b3d6237e54-0')]\n",
271
- "\n",
272
- "\n",
273
- "\n"
274
- ]
275
- }
276
- ],
277
- "source": [
278
- "# Create a test message\n",
279
- "from langchain_core.messages import HumanMessage\n",
280
- "\n",
281
- "test_message = HumanMessage(\n",
282
- " content=\"\"\"I want to see a design matching this description: \n",
283
- " I want it to use a monochromatic color scheme with subtle accent colors. \n",
284
- " The layout should be grid-based with clear hierarchy. \n",
285
- " The overall mood should be professional and sophisticated.\"\"\"\n",
286
- ")\n",
287
- "\n",
288
- "async for chunk in graph.astream({\"messages\" : [test_message]}, stream_mode=\"updates\"):\n",
289
- " for node, values in chunk.items():\n",
290
- " print(f\"Receiving update from node: '{node}'\")\n",
291
- " print(values[\"messages\"])\n",
292
- " print(\"\\n\\n\")"
293
- ]
294
- }
295
- ],
296
- "metadata": {
297
- "kernelspec": {
298
- "display_name": ".venv",
299
- "language": "python",
300
- "name": "python3"
301
- },
302
- "language_info": {
303
- "codemirror_mode": {
304
- "name": "ipython",
305
- "version": 3
306
- },
307
- "file_extension": ".py",
308
- "mimetype": "text/x-python",
309
- "name": "python",
310
- "nbconvert_exporter": "python",
311
- "pygments_lexer": "ipython3",
312
- "version": "3.11.11"
313
- }
314
- },
315
- "nbformat": 4,
316
- "nbformat_minor": 2
317
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/graph.py CHANGED
@@ -1,10 +1,10 @@
1
  from typing import Annotated
2
-
3
  from typing_extensions import TypedDict
4
-
5
  from langgraph.graph import StateGraph, START, END
6
  from langgraph.graph.message import add_messages
7
-
 
 
8
 
9
  class State(TypedDict):
10
  # Messages have the type "list". The `add_messages` function
@@ -12,7 +12,40 @@ class State(TypedDict):
12
  # (in this case, it appends messages to the list, rather than overwriting them)
13
  messages: Annotated[list, add_messages]
14
 
15
-
16
- graph_builder = StateGraph(State)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
 
18
 
 
1
  from typing import Annotated
 
2
  from typing_extensions import TypedDict
 
3
  from langgraph.graph import StateGraph, START, END
4
  from langgraph.graph.message import add_messages
5
+ from langgraph.prebuilt import ToolInvoker
6
+ from nodes.designer import DesignerNode
7
+ from langchain.tools.render import format_tool_to_openai_function
8
 
9
  class State(TypedDict):
10
  # Messages have the type "list". The `add_messages` function
 
12
  # (in this case, it appends messages to the list, rather than overwriting them)
13
  messages: Annotated[list, add_messages]
14
 
15
+ def create_graph():
16
+ # Initialize nodes
17
+ designer = DesignerNode()
18
+
19
+ # Create graph
20
+ graph = StateGraph(State)
21
+
22
+ # Add designer node
23
+ graph.add_node("designer", designer)
24
+
25
+ # Create tool invoker node with designer's tools
26
+ tools = designer.get_available_tools()
27
+ tool_executor = ToolInvoker(tools=tools)
28
+ graph.add_node("tools", tool_executor)
29
+
30
+ # Add edges
31
+ graph.add_edge(START, "designer")
32
+
33
+ # Add conditional edges based on tool calls
34
+ graph.add_conditional_edges(
35
+ "designer",
36
+ lambda state: "tools" if state["messages"][-1].get("tool_calls") else END,
37
+ {
38
+ "tools": "tools",
39
+ END: END
40
+ }
41
+ )
42
+
43
+ # After tool execution, return to designer
44
+ graph.add_edge("tools", "designer")
45
+
46
+ return graph.compile()
47
+
48
+ # Create the graph
49
+ graph = create_graph()
50
 
51
 
src/nodes/analyzer.py CHANGED
@@ -0,0 +1 @@
 
 
1
+
src/nodes/designer.py CHANGED
@@ -0,0 +1,122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, List
2
+ from anthropic import AsyncAnthropic
3
+ import json
4
+ from langchain_core.tools import tool
5
+ from langchain_core.messages import BaseMessage, HumanMessage, AIMessage
6
+ from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
7
+ from nodes.design_rag import DesignRAG
8
+
9
+ class DesignerNode:
10
+ """Main conversation node for discussing design requirements and retrieving examples"""
11
+
12
+ def __init__(self):
13
+ self.client = AsyncAnthropic()
14
+ self.rag = DesignRAG()
15
+
16
+ # Define the conversation prompt
17
+ self.prompt = ChatPromptTemplate.from_messages([
18
+ ("system", """You are an expert design assistant helping users find design inspiration.
19
+ Your goal is to understand their design needs and requirements through conversation.
20
+
21
+ Guidelines:
22
+ 1. Focus on understanding visual design requirements, not implementation
23
+ 2. Ask clarifying questions about style, mood, and visual elements
24
+ 3. When the user asks to see examples, use the retrieve_design_examples tool
25
+ 4. Track both must-have requirements and nice-to-have preferences
26
+ 5. When showing examples, explain how they match the requirements
27
+
28
+ Available tools:
29
+ - retrieve_design_examples: Find relevant design examples based on conversation
30
+
31
+ When the user asks to see examples, ALWAYS use the retrieve_design_examples tool.
32
+ Format tool calls using the exact function name and parameters.
33
+ """),
34
+ MessagesPlaceholder(variable_name="chat_history"),
35
+ ("human", "{input}"),
36
+ ])
37
+
38
+ @tool()
39
+ async def retrieve_design_examples(self, conversation: List[str], num_examples: int = 1) -> str:
40
+ """
41
+ Find and retrieve relevant design examples based on the conversation history.
42
+
43
+ Args:
44
+ conversation: List of conversation messages
45
+ num_examples: Number of examples to retrieve (default: 1)
46
+
47
+ Returns:
48
+ String containing design examples and their details
49
+ """
50
+ return await self.rag.query_similar_designs(conversation, num_examples)
51
+
52
+ def get_available_tools(self):
53
+ """Return list of available tools"""
54
+ return [self.retrieve_design_examples]
55
+
56
+ async def __call__(self, state: Dict) -> Dict:
57
+ """Process messages and manage design discussion"""
58
+ messages = state.get("messages", [])
59
+
60
+ # Convert messages to chat history format
61
+ chat_history = []
62
+ for msg in messages[:-1]: # Exclude the last message which is the current input
63
+ if isinstance(msg, dict):
64
+ role = msg.get("role", "user")
65
+ content = msg.get("content", "")
66
+ chat_history.append(
67
+ HumanMessage(content=content) if role == "user"
68
+ else AIMessage(content=content)
69
+ )
70
+ elif isinstance(msg, BaseMessage):
71
+ chat_history.append(msg)
72
+
73
+ # Get the current input message
74
+ current_input = messages[-1].get("content") if isinstance(messages[-1], dict) else messages[-1].content
75
+
76
+ # Get response from Claude
77
+ response = await self.client.messages.create(
78
+ model="claude-3-haiku-20240307",
79
+ max_tokens=500,
80
+ messages=[{
81
+ "role": "user",
82
+ "content": self.prompt.format(
83
+ chat_history=chat_history,
84
+ input=current_input
85
+ )
86
+ }]
87
+ )
88
+
89
+ response_text = response.content[0].text
90
+
91
+ # Check if response indicates need for examples
92
+ should_retrieve = (
93
+ "retrieve_design_examples" in response_text or
94
+ any(phrase in current_input.lower()
95
+ for phrase in ["show example", "find design", "get example"])
96
+ )
97
+
98
+ if should_retrieve:
99
+ # Create tool call message
100
+ state["messages"].append({
101
+ "role": "assistant",
102
+ "content": response_text,
103
+ "tool_calls": [{
104
+ "type": "function",
105
+ "function": {
106
+ "name": "retrieve_design_examples",
107
+ "arguments": json.dumps({
108
+ "conversation": [msg.get("content", msg) if isinstance(msg, dict) else msg
109
+ for msg in messages],
110
+ "num_examples": 1
111
+ })
112
+ }
113
+ }]
114
+ })
115
+ else:
116
+ # Regular response without tool calls
117
+ state["messages"].append({
118
+ "role": "assistant",
119
+ "content": response_text
120
+ })
121
+
122
+ return state