Sagent / writer /writer.py
damin
重写字符串匹配
c3ef9a9
import os
import re
from bs4 import BeautifulSoup
import getpass
from quart import Quart, request
from prompts import *
from langchain_openai import ChatOpenAI
from typing_extensions import TypedDict
from typing import Annotated
from langgraph.graph.message import add_messages
from typing import List, Tuple, Annotated, TypedDict, Dict
import operator
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.pydantic_v1 import BaseModel, Field
import aiohttp
import json
from typing import Literal
from langgraph.graph import StateGraph
import asyncio
from sugar import timeit
from config import load_config_to_env
from graph_base import Plan, GraphState
# Load project configuration first: the statements below read API keys and
# endpoints from os.environ (presumably populated by this call).
load_config_to_env()

# Run configuration passed to the compiled graph by write().
config = {"recursion_limit": 500}

# Single chat model shared by every chain defined in this module.
model = ChatOpenAI(model="gpt-4o", temperature=0)

# Headers attached to every request sent to the Smodin API.
SMODIN_HEADERS = {
    "Host": "api.smodin.io",
    "X-Api-Key": os.environ.get("SMODIN_API_KEY"),
    "accept": "application/json",
    "content-type": "application/json",
}

# Planner chain: prompt piped into the model with structured ``Plan`` output.
absract_planner = absract_plan_prompt | model.with_structured_output(Plan)
# Disabled:
# pre_planner = pre_plan_prompt | model.with_structured_output(Plan)

# Both chains currently pipe the same designer prompt into the shared model.
designer_model = designer_prompt | model
no_conclusion_model = designer_prompt | model
async def abstract_plan_step(state: GraphState):
    """Ask the planner chain for an outline and store its steps as the plan.

    The request text in ``state["input"]`` is sent as a single user message;
    the ``steps`` attribute of the structured result becomes ``plan``.
    """
    user_turn = ("user", state["input"])
    outline = await absract_planner.ainvoke({"messages": [user_turn]})
    return {"plan": outline.steps}
async def deliver(state: GraphState):
    """Emit the plan entry selected by ``state["time"]`` as the next message."""
    step_index = state["time"]
    return {"messages": [state["plan"][step_index]]}
def _strip_boilerplate(html: str) -> str:
    """Drop introduction headings and, outside a conclusion section, summaries."""
    html = html.replace('<h2>Introduction: ', '')
    html = re.sub(r'<h2>Introduction:.*?</h2>|<h2>Introduction</h2>', '', html)

    # A section that opens with the Conclusion heading keeps its summary text.
    if re.match(r'^\s*<h2>Conclusion</h2>', html, re.IGNORECASE):
        return html

    # The original pattern began with ``\b`` directly before ``<p>``. A word
    # boundary there only exists when a word character precedes the tag, so
    # paragraphs following another tag or a newline were never removed. The
    # boundary now sits after the keyword, and DOTALL lets the paragraph span
    # several lines (as the <h3>Conclusion pattern already did).
    html = re.sub(
        r'<p>(?:Overall|Ultimately)\b.*?</p>', '', html, flags=re.DOTALL
    ).strip()
    html = re.sub(
        r'<h3>Conclusion.*?</p>', '', html, flags=re.IGNORECASE | re.DOTALL
    ).strip()
    return html


async def _rewrite_paragraph(session, text):
    """Return the Smodin rewrite of *text*, or None if the status is not 201."""
    payload = {
        "language": "en",
        "text": text,
        "strength": 4,
    }
    async with session.post(
        os.environ.get("SMODIN_BASE_URL"), headers=SMODIN_HEADERS, json=payload
    ) as res:
        if res.status != 201:
            print("Error:", res.status, await res.text())
            return None
        result = await res.json()
        return result["rewrites"][0]["rewrite"]


async def designer(state: GraphState):
    """Write one section from the latest outline and rewrite its paragraphs.

    The last entry of ``state["messages"]`` is used as the outline of the
    section; the entry before it, when there is one, is passed to the model
    as previous history. The model's HTML is cleaned, every ``<p>`` is sent
    to the Smodin API, and successfully rewritten text replaces the
    paragraph's content. Paragraphs whose request is rejected keep their
    original text and the error is printed.

    Args:
        state: Graph state; only ``state["messages"]`` is read.

    Returns:
        ``{"messages": [html]}`` where ``html`` is the processed section.
        ``write()`` reads ``messages[0]`` of this node's update, so the list
        must not be empty.

    NOTE(review): no node visible in this file advances ``state["time"]``,
    which ``deliver`` and ``should_end`` depend on. Confirm where the counter
    is incremented; the right fix depends on how ``GraphState`` declares it.
    """
    history = state["messages"]
    outline = history[-1]
    # The original indexed [-2] on a one-element list, which always raised
    # IndexError. The first section has no earlier entry to refer to.
    # NOTE(review): assumes the previous state message is the intended
    # "previous history" -- confirm against the designer prompt.
    previous = history[-2] if len(history) > 1 else ""

    response = await designer_model.ainvoke(
        {"messages": [("user", f"Previous history{previous},Outline of this section{outline}")]}
    )
    soup = BeautifulSoup(_strip_boilerplate(response.content), 'html.parser')

    paragraphs = soup.find_all('p')
    if paragraphs:
        # One session for the whole section instead of one per paragraph.
        async with aiohttp.ClientSession() as session:
            for p in paragraphs:
                rewritten_text = await _rewrite_paragraph(session, p.get_text())
                if rewritten_text is not None:
                    p.string = rewritten_text

    return {"messages": [str(soup)]}
# def route(state: GraphState) -> Literal["__end__", "action"]:
# # If there is no function call, then we finish
# if state["time"] < len(state["plan"]):
# return "action"
# else:
# return "__end__"
def should_end(state: GraphState) -> Literal["__end__", "deliver"]:
    """Route back to ``deliver`` while plan entries remain, otherwise finish."""
    steps_remaining = state["time"] < len(state["plan"])
    return "deliver" if steps_remaining else "__end__"
# Build the workflow graph over the shared GraphState.
workflow = StateGraph(GraphState)
# Register the nodes: outline planning, section writing ("action") and the
# hand-off of the next plan entry ("deliver").
workflow.add_node("abstract_planner", abstract_plan_step)
# workflow.add_node("planner", plan_step)
workflow.add_node("action", designer)
workflow.add_node("deliver", deliver)
# Execution starts at the planner.
workflow.set_entry_point("abstract_planner")
# After every "action" run, should_end picks either "deliver" or "__end__".
workflow.add_conditional_edges("action", should_end)
# Fixed edges: abstract_planner -> deliver -> action.
workflow.add_edge("abstract_planner", "deliver")
# workflow.add_edge("planner", "deliver")
workflow.add_edge("deliver", "action")
# Compiled graph; write() streams its events.
app = workflow.compile()
async def write(inputs):
    """Run the writing graph and return the title followed by every section.

    Args:
        inputs: Initial graph state. ``inputs["input"]`` is the request
            string; the text between ``标题:`` and the next comma is taken
            as the title.

    Returns:
        The title (empty when the pattern does not match) concatenated with
        the first message of every update emitted by the ``action`` node, in
        arrival order.
    """
    # re.search returns a Match object or None, never a str. Adding it to a
    # string raised TypeError, so the captured group is extracted explicitly.
    title_match = re.search(r"标题:(.*?)(?=,)", inputs["input"])
    parts = [title_match.group(1)] if title_match else []

    async for event in app.astream(inputs, config=config):
        for node_name, update in event.items():
            if node_name != "action":
                continue
            section_messages = update["messages"]
            # Skip updates that carry no text instead of failing on [0].
            if section_messages:
                parts.append(section_messages[0])

    return "".join(parts)
if __name__ == "__main__":
    def get_user_input():
        """Return the initial graph state for a hard-coded sample request."""
        # Interactive prompts are disabled; fixed sample values are used.
        # title = input("Enter the title: ")
        # core_keywords = input("Enter the core keywords: ")
        # related_keywords = input("Enter the related keywords: ")
        title, core_keywords, related_keywords = "iPhone", "Apple", "Jobs"
        request = f"标题:{title},核心关键词:{core_keywords},相关关键词:{related_keywords}"
        return {"time": 0, "input": request}

    result = asyncio.run(write(get_user_input()))
    print(result)