Spaces:
Runtime error
Runtime error
Commit
·
fe041cb
1
Parent(s):
361875d
task: set up chainlit langgraph
Browse files- .gitignore +13 -5
- src/agents/designer.py +16 -20
- src/agents/workflow.py +59 -39
- src/app.py +15 -8
- src/chains/design_rag.py +58 -67
.gitignore
CHANGED
@@ -1,17 +1,25 @@
|
|
|
|
|
|
|
|
|
|
1 |
.env
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2 |
.env.local
|
3 |
.env.development.local
|
4 |
.env.test.local
|
5 |
.env.production.local
|
6 |
|
7 |
-
.venv
|
8 |
.files
|
9 |
|
10 |
.ipynb_checkpoints
|
11 |
|
12 |
-
.chainlit
|
13 |
.chainlit/cache
|
14 |
|
15 |
-
/data
|
16 |
-
|
17 |
-
__pycache__
|
|
|
1 |
+
# Python
|
2 |
+
__pycache__/
|
3 |
+
*.py[cod]
|
4 |
+
*$py.class
|
5 |
.env
|
6 |
+
.venv/
|
7 |
+
venv/
|
8 |
+
|
9 |
+
# Project specific
|
10 |
+
designs/
|
11 |
+
.chainlit/
|
12 |
+
dist/
|
13 |
+
|
14 |
.env.local
|
15 |
.env.development.local
|
16 |
.env.test.local
|
17 |
.env.production.local
|
18 |
|
|
|
19 |
.files
|
20 |
|
21 |
.ipynb_checkpoints
|
22 |
|
|
|
23 |
.chainlit/cache
|
24 |
|
25 |
+
/data
|
|
|
|
src/agents/designer.py
CHANGED
@@ -1,19 +1,18 @@
|
|
1 |
-
from
|
2 |
-
from langchain_core.messages import HumanMessage, AIMessage
|
3 |
from chains.design_rag import DesignRAG
|
4 |
from .workflow import create_graph, AgentState
|
5 |
from pathlib import Path
|
6 |
|
7 |
class DesignerAgent:
|
8 |
-
def __init__(self, rag: DesignRAG):
|
9 |
self.rag = rag
|
10 |
self.workflow = create_graph(rag)
|
11 |
-
self.state: AgentState =
|
12 |
-
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
-
|
17 |
|
18 |
def _load_default_html(self) -> str:
|
19 |
"""Load default CSS Zen Garden HTML"""
|
@@ -28,14 +27,11 @@ class DesignerAgent:
|
|
28 |
|
29 |
async def process(self, message: str) -> str:
|
30 |
"""Process a message through the workflow"""
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
# Return last response
|
41 |
-
return self.state["messages"][-1].content
|
|
|
1 |
+
from langchain_core.messages import HumanMessage
|
|
|
2 |
from chains.design_rag import DesignRAG
|
3 |
from .workflow import create_graph, AgentState
|
4 |
from pathlib import Path
|
5 |
|
6 |
class DesignerAgent:
|
7 |
+
def __init__(self, rag: DesignRAG) -> None:
|
8 |
self.rag = rag
|
9 |
self.workflow = create_graph(rag)
|
10 |
+
self.state: AgentState = AgentState(
|
11 |
+
messages=[],
|
12 |
+
html_content=self._load_default_html(),
|
13 |
+
style_requirements={},
|
14 |
+
css_output=None
|
15 |
+
)
|
16 |
|
17 |
def _load_default_html(self) -> str:
|
18 |
"""Load default CSS Zen Garden HTML"""
|
|
|
27 |
|
28 |
async def process(self, message: str) -> str:
|
29 |
"""Process a message through the workflow"""
|
30 |
+
try:
|
31 |
+
self.state["messages"].append(HumanMessage(content=message))
|
32 |
+
next_state = await self.workflow.ainvoke(self.state)
|
33 |
+
self.state = next_state
|
34 |
+
return self.state["messages"][-1].content
|
35 |
+
except Exception as e:
|
36 |
+
print(f"Error in process: {str(e)}")
|
37 |
+
return "I encountered an error processing your message."
|
|
|
|
|
|
src/agents/workflow.py
CHANGED
@@ -1,4 +1,4 @@
|
|
1 |
-
from typing import Dict, List, Annotated, TypedDict
|
2 |
from langchain_core.messages import HumanMessage, AIMessage
|
3 |
from langgraph.graph import StateGraph, END
|
4 |
from langchain_core.messages import BaseMessage
|
@@ -6,6 +6,7 @@ from chains.design_rag import DesignRAG
|
|
6 |
from langchain.prompts import ChatPromptTemplate
|
7 |
import chainlit as cl
|
8 |
import json
|
|
|
9 |
|
10 |
# Define state types
|
11 |
class AgentState(TypedDict):
|
@@ -14,35 +15,47 @@ class AgentState(TypedDict):
|
|
14 |
style_requirements: Dict
|
15 |
css_output: str | None
|
16 |
|
17 |
-
#
|
18 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
19 |
"""Handle conversation and requirement gathering"""
|
20 |
-
# Get last message
|
21 |
last_message = state["messages"][-1]
|
|
|
|
|
22 |
|
23 |
if not isinstance(last_message, HumanMessage):
|
24 |
-
return
|
25 |
|
26 |
-
# Check for style requirements readiness
|
27 |
if "!generate" in last_message.content.lower():
|
28 |
-
|
29 |
-
|
30 |
-
return
|
31 |
-
|
32 |
-
"
|
33 |
-
|
34 |
-
|
|
|
35 |
|
36 |
-
# Normal conversation -
|
37 |
-
response =
|
38 |
-
|
39 |
|
40 |
-
return
|
|
|
|
|
|
|
|
|
|
|
41 |
|
42 |
-
async def extract_requirements(
|
43 |
-
messages: List[BaseMessage],
|
44 |
-
rag: DesignRAG
|
45 |
-
) -> Dict:
|
46 |
"""Extract style requirements from conversation history"""
|
47 |
# Combine messages into context
|
48 |
context = "\n".join([
|
@@ -65,8 +78,8 @@ async def extract_requirements(
|
|
65 |
}}
|
66 |
"""
|
67 |
|
68 |
-
# Get requirements through
|
69 |
-
response = await
|
70 |
return json.loads(response.content)
|
71 |
|
72 |
async def generate_css_node(state: AgentState, rag: DesignRAG) -> AgentState:
|
@@ -107,37 +120,44 @@ async def generate_css_node(state: AgentState, rag: DesignRAG) -> AgentState:
|
|
107 |
)
|
108 |
)
|
109 |
|
110 |
-
|
111 |
-
state["css_output"] = response.content
|
112 |
-
|
113 |
-
# Add completion message
|
114 |
-
state["messages"].append(AIMessage(content="""I've generated the CSS based on your requirements.
|
115 |
Here's what I created:
|
116 |
|
117 |
```css
|
118 |
-
{
|
119 |
```
|
120 |
|
121 |
Would you like me to explain any part of the design or make any adjustments?
|
122 |
-
"""
|
|
|
|
|
123 |
|
124 |
-
return
|
|
|
|
|
|
|
|
|
|
|
125 |
|
126 |
def create_graph(rag: DesignRAG) -> StateGraph:
|
127 |
"""Create the workflow graph"""
|
128 |
-
# Create graph
|
129 |
workflow = StateGraph(AgentState)
|
130 |
|
131 |
-
# Add nodes
|
132 |
-
workflow.add_node("conversation",
|
133 |
workflow.add_node("generate_css", lambda s: generate_css_node(s, rag))
|
134 |
|
135 |
# Add edges
|
136 |
-
workflow.
|
137 |
-
|
138 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
139 |
|
140 |
-
# Set entry point
|
141 |
workflow.set_entry_point("conversation")
|
142 |
-
|
143 |
return workflow.compile()
|
|
|
1 |
+
from typing import Dict, List, Annotated, TypedDict, Callable, Any
|
2 |
from langchain_core.messages import HumanMessage, AIMessage
|
3 |
from langgraph.graph import StateGraph, END
|
4 |
from langchain_core.messages import BaseMessage
|
|
|
6 |
from langchain.prompts import ChatPromptTemplate
|
7 |
import chainlit as cl
|
8 |
import json
|
9 |
+
from langchain_openai import ChatOpenAI
|
10 |
|
11 |
# Define state types
|
12 |
class AgentState(TypedDict):
|
|
|
15 |
style_requirements: Dict
|
16 |
css_output: str | None
|
17 |
|
18 |
+
# At the top level
|
19 |
+
llm = ChatOpenAI(temperature=0.2)
|
20 |
+
|
21 |
+
def should_end(state: AgentState) -> str:
|
22 |
+
"""Determine if the conversation should end"""
|
23 |
+
if state["messages"] and isinstance(state["messages"][-1], HumanMessage):
|
24 |
+
if "!done" in state["messages"][-1].content.lower():
|
25 |
+
return "end"
|
26 |
+
return "conversation"
|
27 |
+
|
28 |
+
async def conversation_node(state: AgentState) -> AgentState:
|
29 |
"""Handle conversation and requirement gathering"""
|
|
|
30 |
last_message = state["messages"][-1]
|
31 |
+
print("last message received")
|
32 |
+
print(last_message)
|
33 |
|
34 |
if not isinstance(last_message, HumanMessage):
|
35 |
+
return state
|
36 |
|
|
|
37 |
if "!generate" in last_message.content.lower():
|
38 |
+
requirements = await extract_requirements(state["messages"])
|
39 |
+
await cl.Message("I'll generate CSS based on your requirements...").send()
|
40 |
+
return AgentState(
|
41 |
+
messages=state["messages"],
|
42 |
+
html_content=state["html_content"],
|
43 |
+
style_requirements=requirements,
|
44 |
+
css_output=state["css_output"]
|
45 |
+
)
|
46 |
|
47 |
+
# Normal conversation - just acknowledge and guide
|
48 |
+
response = "I understand. Tell me more about the design you're looking for, or type !generate when you're ready to create the CSS."
|
49 |
+
await cl.Message(content=response).send()
|
50 |
|
51 |
+
return AgentState(
|
52 |
+
messages=[*state["messages"], AIMessage(content=response)],
|
53 |
+
html_content=state["html_content"],
|
54 |
+
style_requirements=state.get("style_requirements", {}),
|
55 |
+
css_output=state.get("css_output")
|
56 |
+
)
|
57 |
|
58 |
+
async def extract_requirements(messages: List[BaseMessage]) -> Dict:
|
|
|
|
|
|
|
59 |
"""Extract style requirements from conversation history"""
|
60 |
# Combine messages into context
|
61 |
context = "\n".join([
|
|
|
78 |
}}
|
79 |
"""
|
80 |
|
81 |
+
# Get requirements through LLM
|
82 |
+
response = await llm.ainvoke(prompt)
|
83 |
return json.loads(response.content)
|
84 |
|
85 |
async def generate_css_node(state: AgentState, rag: DesignRAG) -> AgentState:
|
|
|
120 |
)
|
121 |
)
|
122 |
|
123 |
+
css_message = f"""I've generated the CSS based on your requirements.
|
|
|
|
|
|
|
|
|
124 |
Here's what I created:
|
125 |
|
126 |
```css
|
127 |
+
{response.content}
|
128 |
```
|
129 |
|
130 |
Would you like me to explain any part of the design or make any adjustments?
|
131 |
+
Type !done when you're satisfied with the result."""
|
132 |
+
|
133 |
+
await cl.Message(content=css_message).send()
|
134 |
|
135 |
+
return AgentState(
|
136 |
+
messages=[*state["messages"], AIMessage(content=css_message)],
|
137 |
+
html_content=state["html_content"],
|
138 |
+
style_requirements=state["style_requirements"],
|
139 |
+
css_output=response.content
|
140 |
+
)
|
141 |
|
142 |
def create_graph(rag: DesignRAG) -> StateGraph:
|
143 |
"""Create the workflow graph"""
|
|
|
144 |
workflow = StateGraph(AgentState)
|
145 |
|
146 |
+
# Add nodes directly
|
147 |
+
workflow.add_node("conversation", conversation_node)
|
148 |
workflow.add_node("generate_css", lambda s: generate_css_node(s, rag))
|
149 |
|
150 |
# Add edges
|
151 |
+
workflow.add_conditional_edges(
|
152 |
+
"conversation",
|
153 |
+
should_end, # Use our existing should_end function
|
154 |
+
{
|
155 |
+
"conversation": "conversation",
|
156 |
+
"generate_css": "generate_css",
|
157 |
+
"end": END
|
158 |
+
}
|
159 |
+
)
|
160 |
+
workflow.add_edge("generate_css", "conversation")
|
161 |
|
|
|
162 |
workflow.set_entry_point("conversation")
|
|
|
163 |
return workflow.compile()
|
src/app.py
CHANGED
@@ -1,23 +1,30 @@
|
|
1 |
import chainlit as cl
|
2 |
-
from langchain_core.messages import HumanMessage, AIMessage
|
3 |
from chains.design_rag import DesignRAG
|
4 |
from agents.designer import DesignerAgent
|
5 |
|
|
|
|
|
|
|
|
|
6 |
@cl.on_chat_start
|
7 |
-
async def start():
|
8 |
"""Initialize the chat session"""
|
9 |
-
# Initialize RAG system
|
10 |
-
design_rag = DesignRAG()
|
11 |
-
# Initialize designer agent
|
12 |
-
designer = DesignerAgent(rag=design_rag)
|
13 |
-
|
14 |
# Store in user session
|
15 |
cl.user_session.set("designer", designer)
|
|
|
|
|
|
|
|
|
|
|
16 |
|
17 |
@cl.on_message
|
18 |
-
async def main(message: cl.Message):
|
19 |
"""Handle incoming messages"""
|
20 |
designer = cl.user_session.get("designer")
|
|
|
|
|
|
|
|
|
21 |
|
22 |
# Process message through designer agent
|
23 |
response = await designer.process(message.content)
|
|
|
1 |
import chainlit as cl
|
|
|
2 |
from chains.design_rag import DesignRAG
|
3 |
from agents.designer import DesignerAgent
|
4 |
|
5 |
+
# Initialize these once at module level
|
6 |
+
design_rag = DesignRAG()
|
7 |
+
designer = DesignerAgent(rag=design_rag)
|
8 |
+
|
9 |
@cl.on_chat_start
|
10 |
+
async def start() -> None:
|
11 |
"""Initialize the chat session"""
|
|
|
|
|
|
|
|
|
|
|
12 |
# Store in user session
|
13 |
cl.user_session.set("designer", designer)
|
14 |
+
|
15 |
+
# Send welcome message
|
16 |
+
await cl.Message(
|
17 |
+
content="Welcome! I'm here to help you imagine a unique design. What style are you looking for?"
|
18 |
+
).send()
|
19 |
|
20 |
@cl.on_message
|
21 |
+
async def main(message: cl.Message) -> None:
|
22 |
"""Handle incoming messages"""
|
23 |
designer = cl.user_session.get("designer")
|
24 |
+
if designer is None:
|
25 |
+
# Reinitialize if missing
|
26 |
+
designer = DesignerAgent(rag=design_rag)
|
27 |
+
cl.user_session.set("designer", designer)
|
28 |
|
29 |
# Process message through designer agent
|
30 |
response = await designer.process(message.content)
|
src/chains/design_rag.py
CHANGED
@@ -5,7 +5,8 @@ from langchain_community.vectorstores import FAISS
|
|
5 |
from langchain.prompts import ChatPromptTemplate
|
6 |
from pathlib import Path
|
7 |
import json
|
8 |
-
from typing import Dict
|
|
|
9 |
|
10 |
class DesignRAG:
|
11 |
def __init__(self):
|
@@ -18,78 +19,68 @@ class DesignRAG:
|
|
18 |
# Create retriever
|
19 |
self.retriever = self.vector_store.as_retriever(
|
20 |
search_type="similarity",
|
21 |
-
search_kwargs={"k":
|
22 |
)
|
23 |
|
24 |
# Create LLM
|
25 |
-
self.llm = ChatOpenAI(temperature=0.
|
26 |
-
|
27 |
-
# Create the RAG chain
|
28 |
-
self.chain = self._create_chain()
|
29 |
|
30 |
-
def _create_vector_store(self):
|
31 |
"""Create FAISS vector store from design metadata"""
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
# Load all metadata files
|
36 |
-
for design_dir in designs_dir.glob("**/metadata.json"):
|
37 |
-
with open(design_dir, "r") as f:
|
38 |
-
metadata = json.load(f)
|
39 |
-
|
40 |
-
# Create document text from metadata
|
41 |
-
text = f"""
|
42 |
-
Design {metadata['id']}:
|
43 |
-
Description: {metadata.get('description', '')}
|
44 |
-
Categories: {', '.join(metadata.get('categories', []))}
|
45 |
-
Visual Characteristics: {', '.join(metadata.get('visual_characteristics', []))}
|
46 |
-
"""
|
47 |
|
48 |
-
|
49 |
-
css_path = design_dir.parent / "style.css"
|
50 |
-
if css_path.exists():
|
51 |
-
with open(css_path, "r") as f:
|
52 |
-
css = f.read()
|
53 |
-
text += f"\nCSS:\n{css}"
|
54 |
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
"
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
-
|
90 |
-
|
91 |
-
|
92 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
93 |
|
94 |
async def query_similar_designs(self, requirements: Dict) -> str:
|
95 |
"""Find similar designs based on requirements"""
|
@@ -103,7 +94,7 @@ class DesignRAG:
|
|
103 |
"""
|
104 |
|
105 |
# Get similar documents
|
106 |
-
docs = self.retriever.get_relevant_documents(query)
|
107 |
|
108 |
# Format examples
|
109 |
examples = []
|
|
|
5 |
from langchain.prompts import ChatPromptTemplate
|
6 |
from pathlib import Path
|
7 |
import json
|
8 |
+
from typing import Dict, List, Optional
|
9 |
+
from langchain_core.documents import Document
|
10 |
|
11 |
class DesignRAG:
|
12 |
def __init__(self):
|
|
|
19 |
# Create retriever
|
20 |
self.retriever = self.vector_store.as_retriever(
|
21 |
search_type="similarity",
|
22 |
+
search_kwargs={"k": 5}
|
23 |
)
|
24 |
|
25 |
# Create LLM
|
26 |
+
self.llm = ChatOpenAI(temperature=0.2)
|
|
|
|
|
|
|
27 |
|
28 |
+
def _create_vector_store(self) -> FAISS:
|
29 |
"""Create FAISS vector store from design metadata"""
|
30 |
+
try:
|
31 |
+
# Update path to look in data/designs
|
32 |
+
designs_dir = Path("data/designs")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
33 |
|
34 |
+
documents = []
|
|
|
|
|
|
|
|
|
|
|
35 |
|
36 |
+
# Load all metadata files
|
37 |
+
for design_dir in designs_dir.glob("**/metadata.json"):
|
38 |
+
try:
|
39 |
+
with open(design_dir, "r") as f:
|
40 |
+
metadata = json.load(f)
|
41 |
+
|
42 |
+
# Create document text from metadata with safe gets
|
43 |
+
text = f"""
|
44 |
+
Design {metadata.get('id', 'unknown')}:
|
45 |
+
Description: {metadata.get('description', 'No description available')}
|
46 |
+
Categories: {', '.join(metadata.get('categories', []))}
|
47 |
+
Visual Characteristics: {', '.join(metadata.get('visual_characteristics', []))}
|
48 |
+
"""
|
49 |
+
|
50 |
+
# Load associated CSS
|
51 |
+
css_path = design_dir.parent / "style.css"
|
52 |
+
if css_path.exists():
|
53 |
+
with open(css_path, "r") as f:
|
54 |
+
css = f.read()
|
55 |
+
text += f"\nCSS:\n{css}"
|
56 |
+
|
57 |
+
# Create Document object with minimal metadata
|
58 |
+
documents.append(
|
59 |
+
Document(
|
60 |
+
page_content=text.strip(),
|
61 |
+
metadata={
|
62 |
+
"id": metadata.get('id', 'unknown'),
|
63 |
+
"path": str(design_dir.parent)
|
64 |
+
}
|
65 |
+
)
|
66 |
+
)
|
67 |
+
except Exception as e:
|
68 |
+
print(f"Error processing design {design_dir}: {e}")
|
69 |
+
continue
|
70 |
+
|
71 |
+
if not documents:
|
72 |
+
print("Warning: No valid design documents found")
|
73 |
+
# Create empty vector store with a placeholder document
|
74 |
+
return FAISS.from_documents(
|
75 |
+
[Document(page_content="No designs available", metadata={"id": "placeholder"})],
|
76 |
+
self.embeddings
|
77 |
+
)
|
78 |
+
|
79 |
+
# Create and return vector store
|
80 |
+
return FAISS.from_documents(documents, self.embeddings)
|
81 |
+
except Exception as e:
|
82 |
+
print(f"Error creating vector store: {str(e)}")
|
83 |
+
raise
|
84 |
|
85 |
async def query_similar_designs(self, requirements: Dict) -> str:
|
86 |
"""Find similar designs based on requirements"""
|
|
|
94 |
"""
|
95 |
|
96 |
# Get similar documents
|
97 |
+
docs = await self.retriever.get_relevant_documents(query)
|
98 |
|
99 |
# Format examples
|
100 |
examples = []
|