# Scraped page header: team-ai / code_generate.py
# Commit 2459b19 by peichao.dong ("optimate ui"), 10 kB.
import re
from typing import List, Union
from langchain.chains import LLMChain
from langchain.chat_models import ChatOpenAI
from langchain.agents import tool, Tool, LLMSingleActionAgent, AgentExecutor, AgentOutputParser
from langchain.schema import AgentAction, AgentFinish
from langchain.agents import initialize_agent
from langchain.memory import ConversationBufferMemory
from langchain.prompts import StringPromptTemplate
from promopts import code_generate_agent_template
from promopts import API_LAYER_PROMPT, DOMAIN_LAYER_PROMPT, PERSISTENT_LAYER_PROMPT
def _layer_chain(prompt) -> LLMChain:
    """Return a new LLMChain pairing ``prompt`` with its own ChatOpenAI (temperature 0.1)."""
    return LLMChain(llm=ChatOpenAI(temperature=0.1), prompt=prompt)


# One generation chain per architectural layer. Each gets its own ChatOpenAI
# client; they are built in this order at import time.
domainLayerChain = _layer_chain(DOMAIN_LAYER_PROMPT)
persistentChain = _layer_chain(PERSISTENT_LAYER_PROMPT)
apiChain = _layer_chain(API_LAYER_PROMPT)
# Agent tool: delegates the request text to the domain-layer chain.
# NOTE: the docstring below doubles as the tool description that LangChain's
# ``@tool`` shows to the agent LLM, and the ``input`` parameter name feeds the
# tool's argument schema -- change either only deliberately.
@tool("Generate Domain Layer Code", return_direct=True)
def domainLayerCodeGenerator(input: str) -> str:
    '''useful for when you need to generate domain layer code'''
    return domainLayerChain.run(input)
# Agent tool: delegates the request text to the persistence-layer chain.
# NOTE: the docstring below doubles as the tool description that LangChain's
# ``@tool`` shows to the agent LLM, and the ``input`` parameter name feeds the
# tool's argument schema -- change either only deliberately.
@tool("Generate Persistent Layer Code", return_direct=True)
def persistentLayerCodeGenerator(input: str) -> str:
    '''useful for when you need to generate persistent layer code'''
    return persistentChain.run(input)
# Agent tool: delegates the request text to the API-layer chain.
# NOTE: the docstring below doubles as the tool description that LangChain's
# ``@tool`` shows to the agent LLM, and the ``input`` parameter name feeds the
# tool's argument schema -- change either only deliberately.
@tool("Generate API Layer Code", return_direct=True)
def apiLayerCodeGenerator(input: str) -> str:
    '''useful for when you need to generate API layer code'''
    return apiChain.run(input)
class CustomPromptTemplate(StringPromptTemplate):
    """String prompt template for the Thought/Action/Observation agent loop.

    On every ``format`` call it derives three template variables the caller
    does not pass in:

    * ``agent_scratchpad``: transcript of the previous steps;
    * ``tools``: one ``name: description`` line per tool;
    * ``tool_names``: comma-separated tool names.
    """

    # Raw template text, rendered with ``str.format``.
    template: str
    # Tools the agent may call; listed in the prompt in this order.
    tools: List[Tool]

    def format(self, **kwargs) -> str:
        """Render the template.

        ``kwargs`` must contain ``intermediate_steps``, an iterable of
        ``(action, observation)`` pairs. That key is removed and replaced by
        ``agent_scratchpad``. All other keyword arguments are forwarded
        unchanged to ``self.template.format``.

        Raises:
            KeyError: if ``intermediate_steps`` is missing.
        """
        steps = kwargs.pop("intermediate_steps")
        # Replay each step as "<action log>\nObservation: <obs>\nThought: " so
        # the model resumes its reasoning right after the latest observation.
        kwargs["agent_scratchpad"] = "".join(
            action.log + f"\nObservation: {observation}\nThought: "
            for action, observation in steps
        )
        kwargs["tools"] = "\n".join(f"{t.name}: {t.description}" for t in self.tools)
        kwargs["tool_names"] = ", ".join(t.name for t in self.tools)
        return self.template.format(**kwargs)
class CustomOutputParser(AgentOutputParser):
    """Convert a raw agent LLM completion into an AgentAction or AgentFinish."""

    def parse(self, llm_output: str) -> Union[AgentAction, AgentFinish]:
        """Parse one LLM completion.

        If the text contains ``"Final Answer:"``, returns an ``AgentFinish``
        whose ``output`` is the stripped text after the LAST such marker.
        Otherwise returns an ``AgentAction`` built from the ``Action:`` and
        ``Action Input:`` lines. The full completion is kept as ``log`` in
        both cases.

        Raises:
            ValueError: if neither a final answer nor an action is found.
        """
        _, marker, answer = llm_output.rpartition("Final Answer:")
        if marker:
            # return_values is kept to a single "output" key.
            return AgentFinish(
                return_values={"output": answer.strip()},
                log=llm_output,
            )

        # "Action[ N]: <tool>" followed on the next line by
        # "Action[ N] Input[ N]: <input>". DOTALL lets the input span several
        # lines, up to the end of the completion.
        found = re.search(
            r"Action\s*\d*\s*:(.*?)\nAction\s*\d*\s*Input\s*\d*\s*:[\s]*(.*)",
            llm_output,
            re.DOTALL,
        )
        if found is None:
            raise ValueError(f"Could not parse LLM output: `{llm_output}`")

        tool_name, tool_input = found.groups()
        # The tool name loses all surrounding whitespace. The input loses only
        # spaces (not newlines), then any surrounding double quotes.
        return AgentAction(
            tool=tool_name.strip(),
            tool_input=tool_input.strip(" ").strip('"'),
            log=llm_output,
        )
# chatllm=ChatOpenAI(temperature=0)
# code_genenrate_memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)
# code_generate_agent = initialize_agent(tools, chatllm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, memory=memory, verbose=True)
# agent = initialize_agent(
# tools=tools, llm=llm_chain, template=AGENT_PROMPT, stop=["\nObservation:"], agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=True)
# Tools available to the code-generation agent, one per architectural layer.
# This order is the order in which they are listed in the agent prompt.
tools = [
    domainLayerCodeGenerator,
    persistentLayerCodeGenerator,
    apiLayerCodeGenerator,
]
def code_agent_executor(*, temperature: float = 0.7, verbose: bool = True) -> AgentExecutor:
    """Build an agent executor that routes a request to a layer code generator.

    The agent is an ``LLMSingleActionAgent`` that uses ``CustomPromptTemplate``
    and ``CustomOutputParser``. It may only call the tools in the module-level
    ``tools`` list. Each call builds a fresh ChatOpenAI client, prompt, agent
    and executor.

    Args:
        temperature: Sampling temperature of the agent's own (tool-routing)
            LLM. Defaults to 0.7, the previously hard-coded value. Lower
            values make tool selection more repeatable.
        verbose: Passed to ``AgentExecutor``'s ``verbose`` flag. Defaults to
            True, as before.

    Returns:
        A configured ``AgentExecutor``.
    """
    prompt = CustomPromptTemplate(
        template=code_generate_agent_template,
        tools=tools,
        # ``agent_scratchpad``, ``tools`` and ``tool_names`` are produced inside
        # CustomPromptTemplate.format, so they are not declared here.
        # ``intermediate_steps`` must be declared because format() consumes it.
        input_variables=["input", "intermediate_steps"],
    )
    llm_chain = LLMChain(llm=ChatOpenAI(temperature=temperature), prompt=prompt)
    agent = LLMSingleActionAgent(
        llm_chain=llm_chain,
        output_parser=CustomOutputParser(),
        # Halt generation at the Observation line so the tool result, not the
        # model, supplies the observation.
        stop=["\nObservation:"],
        allowed_tools=[t.name for t in tools],
    )
    return AgentExecutor.from_agent_and_tools(agent=agent, tools=tools, verbose=verbose)
# if __name__ == "__main__":
# response = domainLayerChain.run("""FeatureConfig用于配置某个Feature中控制前端展示效果的配置项
# FeatureConfig主要属性包括:featureKey(feature标识)、data(配置数据)、saData(埋点数据)、status(状态)、标题、描述、创建时间、更新时间
# FeatureConfig中status为枚举值,取值范围为(DRAFT、PUBLISHED、DISABLED)
# FeatureConfig新增后status为DRAFT、执行发布操作后变为PUBLISHED、执行撤销操作后变为DISABLED
# 状态为DRAFT的FeatureConfig可以执行编辑、发布、撤销操作
# 发布后FeatureConfig变为PUBLISHED状态,可以执行撤销操作
# 撤销后FeatureConfig变为DISABLED状态,不可以执行编辑、发布、撤销操作
# """)
# print(response)
# response = persistentChain.run("""
# Entity:
# ```
# public class FeatureConfig {
# private FeatureConfigId id;
# private FeatureConfigDescription description;
# public enum FeatureConfigStatus {
# DRAFT, PUBLISHED, DISABLED;
# }
# public record FeatureConfigId(String id) {}
# public record FeatureKey(String key) {}
# public record FeatureConfigData(String data) {}
# public record FeatureConfigSaData(String saData) {}
# @Builder
# public record FeatureConfigDescription(FeatureKey featureKey, FeatureConfigData data, FeatureConfigSaData saData, String title, String description,
# FeatureConfigStatus status, LocalDateTime createTime, LocalDateTime updateTime) {}
# public void update(FeatureConfigDescription description) {
# this.title = description.title();
# this.description = description.description();
# this.updateTime = LocalDateTime.now();
# }
# public void publish() {
# this.status = FeatureConfigStatus.PUBLISHED;
# this.updateTime = LocalDateTime.now();
# }
# public void disable() {
# this.status = FeatureConfigStatus.DISABLED;
# this.updateTime = LocalDateTime.now();
# }
# }
# ```
# Association:
# ```
# public interface FeatureConfigs {
# Flux<FeatureConfig> findAllByFeatureKey(String featureKey);
# Mono<FeatureConfig> findById(FeatureConfigId id);
# Mono<FeatureConfig> save(FeatureConfig featureConfig);
# }
# ```
# """)
# print(response)
# response = apiChain.run("""
# Entity:
# ```
# public class FeatureConfig {
# private FeatureConfigId id;
# private FeatureConfigDescription description;
# public enum FeatureConfigStatus {
# DRAFT, PUBLISHED, DISABLED;
# }
# public record FeatureConfigId(String id) {}
# public record FeatureKey(String key) {}
# public record FeatureConfigData(String data) {}
# public record FeatureConfigSaData(String saData) {}
# @Builder
# public record FeatureConfigDescription(FeatureKey featureKey, FeatureConfigData data, FeatureConfigSaData saData, String title, String description,
# FeatureConfigStatus status, LocalDateTime createTime, LocalDateTime updateTime) {}
# public void update(FeatureConfigDescription description) {
# this.title = description.title();
# this.description = description.description();
# this.updateTime = LocalDateTime.now();
# }
# public void publish() {
# this.status = FeatureConfigStatus.PUBLISHED;
# this.updateTime = LocalDateTime.now();
# }
# public void disable() {
# this.status = FeatureConfigStatus.DISABLED;
# this.updateTime = LocalDateTime.now();
# }
# }
# ```
# Association:
# ```
# public interface FeatureConfigs {
# Flux<FeatureConfig> findAllByFeatureKey(String featureKey);
# Mono<FeatureConfig> findById(FeatureConfigId id);
# Mono<FeatureConfig> save(FeatureConfig featureConfig);
# Mono<Void> update(FeatureConfigId id, FeatureConfigDescription description);
# Mono<Void> publish(FeatureConfigId id);
# Mono<Void> disable(FeatureConfigId id);
# }
# ```
# """)
# print(response)
# if __name__ == "code_generate":
# response = code_agent_executor().run("""
# 根据如下需求generate domain layer code:
# ---
# FeatureConfig用于配置某个Feature中控制前端展示效果的配置项
# FeatureConfig主要属性包括:featureKey(feature标识)、data(配置数据)、saData(埋点数据)、status(状态)、标题、描述、创建时间、更新时间
# FeatureConfig中status为枚举值,取值范围为(DRAFT、PUBLISHED、DISABLED)
# FeatureConfig新增后status为DRAFT、执行发布操作后变为PUBLISHED、执行撤销操作后变为DISABLED
# 状态为DRAFT的FeatureConfig可以执行编辑、发布、撤销操作
# 发布后FeatureConfig变为PUBLISHED状态,可以执行撤销操作
# 撤销后FeatureConfig变为DISABLED状态,不可以执行编辑、发布、撤销操作
# ---
# """)
# print(response)