Technologic101 committed on
Commit
fe041cb
·
1 Parent(s): 361875d

task: set up chainlit langgraph

Browse files
Files changed (5) hide show
  1. .gitignore +13 -5
  2. src/agents/designer.py +16 -20
  3. src/agents/workflow.py +59 -39
  4. src/app.py +15 -8
  5. src/chains/design_rag.py +58 -67
.gitignore CHANGED
@@ -1,17 +1,25 @@
 
 
 
 
1
  .env
 
 
 
 
 
 
 
 
2
  .env.local
3
  .env.development.local
4
  .env.test.local
5
  .env.production.local
6
 
7
- .venv
8
  .files
9
 
10
  .ipynb_checkpoints
11
 
12
- .chainlit
13
  .chainlit/cache
14
 
15
- /data
16
-
17
- __pycache__
 
1
+ # Python
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
  .env
6
+ .venv/
7
+ venv/
8
+
9
+ # Project specific
10
+ designs/
11
+ .chainlit/
12
+ dist/
13
+
14
  .env.local
15
  .env.development.local
16
  .env.test.local
17
  .env.production.local
18
 
 
19
  .files
20
 
21
  .ipynb_checkpoints
22
 
 
23
  .chainlit/cache
24
 
25
+ /data
 
 
src/agents/designer.py CHANGED
@@ -1,19 +1,18 @@
1
- from typing import List, Dict
2
- from langchain_core.messages import HumanMessage, AIMessage
3
  from chains.design_rag import DesignRAG
4
  from .workflow import create_graph, AgentState
5
  from pathlib import Path
6
 
7
  class DesignerAgent:
8
- def __init__(self, rag: DesignRAG):
9
  self.rag = rag
10
  self.workflow = create_graph(rag)
11
- self.state: AgentState = {
12
- "messages": [],
13
- "html_content": self._load_default_html(),
14
- "style_requirements": {},
15
- "css_output": None
16
- }
17
 
18
  def _load_default_html(self) -> str:
19
  """Load default CSS Zen Garden HTML"""
@@ -28,14 +27,11 @@ class DesignerAgent:
28
 
29
  async def process(self, message: str) -> str:
30
  """Process a message through the workflow"""
31
- # Add message to state
32
- self.state["messages"].append(HumanMessage(content=message))
33
-
34
- # Run workflow
35
- next_state = await self.workflow.invoke(self.state)
36
-
37
- # Update state
38
- self.state = next_state
39
-
40
- # Return last response
41
- return self.state["messages"][-1].content
 
1
+ from langchain_core.messages import HumanMessage
 
2
  from chains.design_rag import DesignRAG
3
  from .workflow import create_graph, AgentState
4
  from pathlib import Path
5
 
6
  class DesignerAgent:
7
+ def __init__(self, rag: DesignRAG) -> None:
8
  self.rag = rag
9
  self.workflow = create_graph(rag)
10
+ self.state: AgentState = AgentState(
11
+ messages=[],
12
+ html_content=self._load_default_html(),
13
+ style_requirements={},
14
+ css_output=None
15
+ )
16
 
17
  def _load_default_html(self) -> str:
18
  """Load default CSS Zen Garden HTML"""
 
27
 
28
  async def process(self, message: str) -> str:
29
  """Process a message through the workflow"""
30
+ try:
31
+ self.state["messages"].append(HumanMessage(content=message))
32
+ next_state = await self.workflow.ainvoke(self.state)
33
+ self.state = next_state
34
+ return self.state["messages"][-1].content
35
+ except Exception as e:
36
+ print(f"Error in process: {str(e)}")
37
+ return "I encountered an error processing your message."
 
 
 
src/agents/workflow.py CHANGED
@@ -1,4 +1,4 @@
1
- from typing import Dict, List, Annotated, TypedDict
2
  from langchain_core.messages import HumanMessage, AIMessage
3
  from langgraph.graph import StateGraph, END
4
  from langchain_core.messages import BaseMessage
@@ -6,6 +6,7 @@ from chains.design_rag import DesignRAG
6
  from langchain.prompts import ChatPromptTemplate
7
  import chainlit as cl
8
  import json
 
9
 
10
  # Define state types
11
  class AgentState(TypedDict):
@@ -14,35 +15,47 @@ class AgentState(TypedDict):
14
  style_requirements: Dict
15
  css_output: str | None
16
 
17
- # Node functions
18
- async def conversation_node(state: AgentState, rag: DesignRAG):
 
 
 
 
 
 
 
 
 
19
  """Handle conversation and requirement gathering"""
20
- # Get last message
21
  last_message = state["messages"][-1]
 
 
22
 
23
  if not isinstance(last_message, HumanMessage):
24
- return {"messages": state["messages"]}
25
 
26
- # Check for style requirements readiness
27
  if "!generate" in last_message.content.lower():
28
- # Extract style requirements from conversation
29
- requirements = await extract_requirements(state["messages"], rag)
30
- return {
31
- "messages": state["messages"],
32
- "style_requirements": requirements,
33
- "next": "generate_css"
34
- }
 
35
 
36
- # Normal conversation - get context and respond
37
- response = await rag.query(last_message.content)
38
- state["messages"].append(AIMessage(content=response))
39
 
40
- return {"messages": state["messages"]}
 
 
 
 
 
41
 
42
- async def extract_requirements(
43
- messages: List[BaseMessage],
44
- rag: DesignRAG
45
- ) -> Dict:
46
  """Extract style requirements from conversation history"""
47
  # Combine messages into context
48
  context = "\n".join([
@@ -65,8 +78,8 @@ async def extract_requirements(
65
  }}
66
  """
67
 
68
- # Get requirements through RAG system
69
- response = await rag.llm.ainvoke(prompt)
70
  return json.loads(response.content)
71
 
72
  async def generate_css_node(state: AgentState, rag: DesignRAG) -> AgentState:
@@ -107,37 +120,44 @@ async def generate_css_node(state: AgentState, rag: DesignRAG) -> AgentState:
107
  )
108
  )
109
 
110
- # Store generated CSS
111
- state["css_output"] = response.content
112
-
113
- # Add completion message
114
- state["messages"].append(AIMessage(content="""I've generated the CSS based on your requirements.
115
  Here's what I created:
116
 
117
  ```css
118
- {css}
119
  ```
120
 
121
  Would you like me to explain any part of the design or make any adjustments?
122
- """.format(css=response.content)))
 
 
123
 
124
- return state
 
 
 
 
 
125
 
126
  def create_graph(rag: DesignRAG) -> StateGraph:
127
  """Create the workflow graph"""
128
- # Create graph
129
  workflow = StateGraph(AgentState)
130
 
131
- # Add nodes
132
- workflow.add_node("conversation", lambda s: conversation_node(s, rag))
133
  workflow.add_node("generate_css", lambda s: generate_css_node(s, rag))
134
 
135
  # Add edges
136
- workflow.add_edge("conversation", "conversation")
137
- workflow.add_edge("conversation", "generate_css")
138
- workflow.add_edge("generate_css", END)
 
 
 
 
 
 
 
139
 
140
- # Set entry point
141
  workflow.set_entry_point("conversation")
142
-
143
  return workflow.compile()
 
1
+ from typing import Dict, List, Annotated, TypedDict, Callable, Any
2
  from langchain_core.messages import HumanMessage, AIMessage
3
  from langgraph.graph import StateGraph, END
4
  from langchain_core.messages import BaseMessage
 
6
  from langchain.prompts import ChatPromptTemplate
7
  import chainlit as cl
8
  import json
9
+ from langchain_openai import ChatOpenAI
10
 
11
  # Define state types
12
  class AgentState(TypedDict):
 
15
  style_requirements: Dict
16
  css_output: str | None
17
 
18
+ # At the top level
19
+ llm = ChatOpenAI(temperature=0.2)
20
+
21
+ def should_end(state: AgentState) -> str:
22
+ """Determine if the conversation should end"""
23
+ if state["messages"] and isinstance(state["messages"][-1], HumanMessage):
24
+ if "!done" in state["messages"][-1].content.lower():
25
+ return "end"
26
+ return "conversation"
27
+
28
+ async def conversation_node(state: AgentState) -> AgentState:
29
  """Handle conversation and requirement gathering"""
 
30
  last_message = state["messages"][-1]
31
+ print("last message received")
32
+ print(last_message)
33
 
34
  if not isinstance(last_message, HumanMessage):
35
+ return state
36
 
 
37
  if "!generate" in last_message.content.lower():
38
+ requirements = await extract_requirements(state["messages"])
39
+ await cl.Message("I'll generate CSS based on your requirements...").send()
40
+ return AgentState(
41
+ messages=state["messages"],
42
+ html_content=state["html_content"],
43
+ style_requirements=requirements,
44
+ css_output=state["css_output"]
45
+ )
46
 
47
+ # Normal conversation - just acknowledge and guide
48
+ response = "I understand. Tell me more about the design you're looking for, or type !generate when you're ready to create the CSS."
49
+ await cl.Message(content=response).send()
50
 
51
+ return AgentState(
52
+ messages=[*state["messages"], AIMessage(content=response)],
53
+ html_content=state["html_content"],
54
+ style_requirements=state.get("style_requirements", {}),
55
+ css_output=state.get("css_output")
56
+ )
57
 
58
+ async def extract_requirements(messages: List[BaseMessage]) -> Dict:
 
 
 
59
  """Extract style requirements from conversation history"""
60
  # Combine messages into context
61
  context = "\n".join([
 
78
  }}
79
  """
80
 
81
+ # Get requirements through LLM
82
+ response = await llm.ainvoke(prompt)
83
  return json.loads(response.content)
84
 
85
  async def generate_css_node(state: AgentState, rag: DesignRAG) -> AgentState:
 
120
  )
121
  )
122
 
123
+ css_message = f"""I've generated the CSS based on your requirements.
 
 
 
 
124
  Here's what I created:
125
 
126
  ```css
127
+ {response.content}
128
  ```
129
 
130
  Would you like me to explain any part of the design or make any adjustments?
131
+ Type !done when you're satisfied with the result."""
132
+
133
+ await cl.Message(content=css_message).send()
134
 
135
+ return AgentState(
136
+ messages=[*state["messages"], AIMessage(content=css_message)],
137
+ html_content=state["html_content"],
138
+ style_requirements=state["style_requirements"],
139
+ css_output=response.content
140
+ )
141
 
142
  def create_graph(rag: DesignRAG) -> StateGraph:
143
  """Create the workflow graph"""
 
144
  workflow = StateGraph(AgentState)
145
 
146
+ # Add nodes directly
147
+ workflow.add_node("conversation", conversation_node)
148
  workflow.add_node("generate_css", lambda s: generate_css_node(s, rag))
149
 
150
  # Add edges
151
+ workflow.add_conditional_edges(
152
+ "conversation",
153
+ should_end, # Use our existing should_end function
154
+ {
155
+ "conversation": "conversation",
156
+ "generate_css": "generate_css",
157
+ "end": END
158
+ }
159
+ )
160
+ workflow.add_edge("generate_css", "conversation")
161
 
 
162
  workflow.set_entry_point("conversation")
 
163
  return workflow.compile()
src/app.py CHANGED
@@ -1,23 +1,30 @@
1
  import chainlit as cl
2
- from langchain_core.messages import HumanMessage, AIMessage
3
  from chains.design_rag import DesignRAG
4
  from agents.designer import DesignerAgent
5
 
 
 
 
 
6
  @cl.on_chat_start
7
- async def start():
8
  """Initialize the chat session"""
9
- # Initialize RAG system
10
- design_rag = DesignRAG()
11
- # Initialize designer agent
12
- designer = DesignerAgent(rag=design_rag)
13
-
14
  # Store in user session
15
  cl.user_session.set("designer", designer)
 
 
 
 
 
16
 
17
  @cl.on_message
18
- async def main(message: cl.Message):
19
  """Handle incoming messages"""
20
  designer = cl.user_session.get("designer")
 
 
 
 
21
 
22
  # Process message through designer agent
23
  response = await designer.process(message.content)
 
1
  import chainlit as cl
 
2
  from chains.design_rag import DesignRAG
3
  from agents.designer import DesignerAgent
4
 
5
+ # Initialize these once at module level
6
+ design_rag = DesignRAG()
7
+ designer = DesignerAgent(rag=design_rag)
8
+
9
  @cl.on_chat_start
10
+ async def start() -> None:
11
  """Initialize the chat session"""
 
 
 
 
 
12
  # Store in user session
13
  cl.user_session.set("designer", designer)
14
+
15
+ # Send welcome message
16
+ await cl.Message(
17
+ content="Welcome! I'm here to help you imagine a unique design. What style are you looking for?"
18
+ ).send()
19
 
20
  @cl.on_message
21
+ async def main(message: cl.Message) -> None:
22
  """Handle incoming messages"""
23
  designer = cl.user_session.get("designer")
24
+ if designer is None:
25
+ # Reinitialize if missing
26
+ designer = DesignerAgent(rag=design_rag)
27
+ cl.user_session.set("designer", designer)
28
 
29
  # Process message through designer agent
30
  response = await designer.process(message.content)
src/chains/design_rag.py CHANGED
@@ -5,7 +5,8 @@ from langchain_community.vectorstores import FAISS
5
  from langchain.prompts import ChatPromptTemplate
6
  from pathlib import Path
7
  import json
8
- from typing import Dict
 
9
 
10
  class DesignRAG:
11
  def __init__(self):
@@ -18,78 +19,68 @@ class DesignRAG:
18
  # Create retriever
19
  self.retriever = self.vector_store.as_retriever(
20
  search_type="similarity",
21
- search_kwargs={"k": 3}
22
  )
23
 
24
  # Create LLM
25
- self.llm = ChatOpenAI(temperature=0.7)
26
-
27
- # Create the RAG chain
28
- self.chain = self._create_chain()
29
 
30
- def _create_vector_store(self):
31
  """Create FAISS vector store from design metadata"""
32
- designs_dir = Path("designs")
33
- documents = []
34
-
35
- # Load all metadata files
36
- for design_dir in designs_dir.glob("**/metadata.json"):
37
- with open(design_dir, "r") as f:
38
- metadata = json.load(f)
39
-
40
- # Create document text from metadata
41
- text = f"""
42
- Design {metadata['id']}:
43
- Description: {metadata.get('description', '')}
44
- Categories: {', '.join(metadata.get('categories', []))}
45
- Visual Characteristics: {', '.join(metadata.get('visual_characteristics', []))}
46
- """
47
 
48
- # Load associated CSS
49
- css_path = design_dir.parent / "style.css"
50
- if css_path.exists():
51
- with open(css_path, "r") as f:
52
- css = f.read()
53
- text += f"\nCSS:\n{css}"
54
 
55
- documents.append({
56
- "page_content": text,
57
- "metadata": {
58
- "id": metadata["id"],
59
- "url": metadata.get("url", ""),
60
- }
61
- })
62
-
63
- # Create and return vector store
64
- return FAISS.from_documents(documents, self.embeddings)
65
-
66
- def _create_chain(self):
67
- """Create the RAG processing chain"""
68
- # Define prompt template
69
- template = """You are a design assistant helping to find and adapt CSS Zen Garden designs.
70
- Use the following similar designs to inform your response:
71
-
72
- {context}
73
-
74
- Based on these examples, help the user with their request:
75
- {question}
76
- """
77
-
78
- prompt = ChatPromptTemplate.from_template(template)
79
-
80
- # Create and return chain
81
- chain = (
82
- {"context": self.retriever, "question": RunnablePassthrough()}
83
- | prompt
84
- | self.llm
85
- | StrOutputParser()
86
- )
87
-
88
- return chain
89
-
90
- async def query(self, question: str) -> str:
91
- """Process a query through the RAG system"""
92
- return await self.chain.ainvoke(question)
 
 
 
 
 
 
 
 
 
 
93
 
94
  async def query_similar_designs(self, requirements: Dict) -> str:
95
  """Find similar designs based on requirements"""
@@ -103,7 +94,7 @@ class DesignRAG:
103
  """
104
 
105
  # Get similar documents
106
- docs = self.retriever.get_relevant_documents(query)
107
 
108
  # Format examples
109
  examples = []
 
5
  from langchain.prompts import ChatPromptTemplate
6
  from pathlib import Path
7
  import json
8
+ from typing import Dict, List, Optional
9
+ from langchain_core.documents import Document
10
 
11
  class DesignRAG:
12
  def __init__(self):
 
19
  # Create retriever
20
  self.retriever = self.vector_store.as_retriever(
21
  search_type="similarity",
22
+ search_kwargs={"k": 5}
23
  )
24
 
25
  # Create LLM
26
+ self.llm = ChatOpenAI(temperature=0.2)
 
 
 
27
 
28
+ def _create_vector_store(self) -> FAISS:
29
  """Create FAISS vector store from design metadata"""
30
+ try:
31
+ # Update path to look in data/designs
32
+ designs_dir = Path("data/designs")
 
 
 
 
 
 
 
 
 
 
 
 
33
 
34
+ documents = []
 
 
 
 
 
35
 
36
+ # Load all metadata files
37
+ for design_dir in designs_dir.glob("**/metadata.json"):
38
+ try:
39
+ with open(design_dir, "r") as f:
40
+ metadata = json.load(f)
41
+
42
+ # Create document text from metadata with safe gets
43
+ text = f"""
44
+ Design {metadata.get('id', 'unknown')}:
45
+ Description: {metadata.get('description', 'No description available')}
46
+ Categories: {', '.join(metadata.get('categories', []))}
47
+ Visual Characteristics: {', '.join(metadata.get('visual_characteristics', []))}
48
+ """
49
+
50
+ # Load associated CSS
51
+ css_path = design_dir.parent / "style.css"
52
+ if css_path.exists():
53
+ with open(css_path, "r") as f:
54
+ css = f.read()
55
+ text += f"\nCSS:\n{css}"
56
+
57
+ # Create Document object with minimal metadata
58
+ documents.append(
59
+ Document(
60
+ page_content=text.strip(),
61
+ metadata={
62
+ "id": metadata.get('id', 'unknown'),
63
+ "path": str(design_dir.parent)
64
+ }
65
+ )
66
+ )
67
+ except Exception as e:
68
+ print(f"Error processing design {design_dir}: {e}")
69
+ continue
70
+
71
+ if not documents:
72
+ print("Warning: No valid design documents found")
73
+ # Create empty vector store with a placeholder document
74
+ return FAISS.from_documents(
75
+ [Document(page_content="No designs available", metadata={"id": "placeholder"})],
76
+ self.embeddings
77
+ )
78
+
79
+ # Create and return vector store
80
+ return FAISS.from_documents(documents, self.embeddings)
81
+ except Exception as e:
82
+ print(f"Error creating vector store: {str(e)}")
83
+ raise
84
 
85
  async def query_similar_designs(self, requirements: Dict) -> str:
86
  """Find similar designs based on requirements"""
 
94
  """
95
 
96
  # Get similar documents
97
+ docs = await self.retriever.get_relevant_documents(query)
98
 
99
  # Format examples
100
  examples = []