# article_generation / appbf.py
# (Hugging Face Hub page residue — "inksiyu's picture", "Update appbf.py",
# commit f78b788 verified — kept as comments so the file parses as Python.)
import json
import logging
import os
import re
import subprocess

import requests
from flask import (
    Flask,
    Response,
    jsonify,
    render_template,
    request,
    stream_with_context,
)
# Root logging config: timestamped INFO-level messages, UTF-8 log encoding.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s', encoding='utf-8')
logger = logging.getLogger(__name__)
app = Flask(__name__)
def generate_content(messages, model, api_key, url):
    """Stream chat-completion deltas from an OpenAI-compatible endpoint.

    Args:
        messages: list of chat messages ({"role": ..., "content": ...}).
        model: model identifier sent in the request payload.
        api_key: value placed verbatim in the Authorization header
            (callers are expected to include any "Bearer " prefix).
        url: full /chat/completions endpoint URL.

    Yields:
        str: each content delta as it arrives from the SSE stream.
    """
    payload = {
        "model": model,
        "messages": messages,
        "stream": True,
    }
    # Fix: log the dict directly instead of the old dumps -> loads -> dumps
    # round-trip, and use lazy %-formatting.
    logger.info("Request payload: %s", json.dumps(payload, ensure_ascii=False))
    headers = {
        'Accept': 'application/json',
        'Authorization': api_key,
        'Content-Type': 'application/json',
    }
    response = requests.post(url, headers=headers,
                             data=json.dumps(payload), stream=True)
    for line in response.iter_lines():
        if not line:
            continue
        decoded_line = line.decode('utf-8')
        # Fix: the old check `"data: " in decoded_line` matched the prefix
        # anywhere in the line but always sliced off the first 6 chars;
        # SSE data lines start with the prefix, so use startswith.
        if decoded_line.startswith("data: "):
            data_str = decoded_line[len("data: "):]
            if data_str.strip() == "[DONE]":
                break
            try:
                data = json.loads(data_str)
            except json.JSONDecodeError as e:
                logger.error(f"JSON decode error: {e}")
                continue
            if "choices" in data:
                # Fix: guard the delta lookup so a chunk without a
                # "delta" key cannot raise KeyError mid-stream.
                delta = data["choices"][0].get("delta", {})
                if "content" in delta:
                    yield delta["content"]
        elif "error" in decoded_line:
            # Fix: the error line may not be valid JSON; never let the
            # diagnostic path itself raise.
            try:
                error_data = json.loads(decoded_line)
                logger.error(f"Error: {error_data}")
            except json.JSONDecodeError:
                logger.error(f"Error (unparseable): {decoded_line}")
            break
def write_to_file(file_path, outline, generated_content):
    """Write outline headings interleaved with their generated sections.

    Each heading from *outline* is written on its own line, followed by the
    matching entry of *generated_content* (stripped, then a blank line);
    falsy sections are skipped. Pairing stops at the shorter sequence.

    NOTE(review): this module-level helper is shadowed at import time by
    the /write_to_file route handler of the same name defined later in
    this file, so it is unreachable through the module namespace.
    """
    with open(file_path, 'w', encoding='utf-8') as fh:
        for heading, section in zip(outline, generated_content):
            fh.write(heading + "\n")
            if not section:
                continue
            fh.write(section.strip() + "\n\n")
@app.route('/')
def index():
    """Serve the front-end page for the article generator."""
    template_name = 'index.html'
    return render_template(template_name)
@app.route('/generate_outline', methods=['POST'])
def generate_outline():
    """Generate a two-level (#/##) Markdown outline for the requested title.

    If the model emits the forbidden "###" heading level, one retry is made
    with the offending draft and a correction prompt appended to the
    conversation.
    """
    data = request.get_json()
    outline_model = data['outline_model']
    url = (data['proxy_url'] or "https://api.openai.com/v1") + "/chat/completions"
    api_key = data['api_key']
    title = data['title']
    doc_type = data['doc_type']
    notice = data['notice']
    # Map the numeric doc_type onto (system prompt, Chinese label); unknown
    # values fall back to empty strings, matching the old if/elif chain.
    prompts_by_type = {
        "1": ("You are ChatGPT, a large language model by OpenAI, and you are good at writing academic papers.", "论文"),
        "2": ("You are ChatGPT, a large language model by OpenAI, and you are good at writing reports.", "报告"),
        "3": ("You are ChatGPT, a large language model by OpenAI. From now on, you need to play the role of a writer. You are good at writing articles.", "文章"),
    }
    system_prompt, doc_type_name = prompts_by_type.get(doc_type, ("", ""))
    outline_prompt = f'我想写一个题目是"{title}"的{doc_type_name},下面是具体要求:{notice}\n请你帮我列出一个详细具体的大纲,大纲需要列出各个标题,只能用Markdown中的#和##来区分标题的层级(例如#大标题1\n##小标题1\n##小标题2\n#大标题2\n##小标题3\n##小标题4\n##小标题5\n以此类推,标题数量不固定\n注意:有且仅有大纲中绝对禁止出现"###"的组合,大纲中只能通过#和##进行标题层次的划分,之后的输出待定)。你的输出仅可包含示例中的内容,无需任何其他的解释说明'
    retry_prompt = '前文已经要求过你不能输出###了,你误解了我的意思,请你直接重新生成新的outline,直接输出outline即可,无需道歉和其他任何解释'
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": outline_prompt},
    ]
    outline = "".join(generate_content(messages, outline_model, api_key, url))
    if "###" in outline:
        # Append the bad draft plus the correction request and regenerate.
        messages += [
            {"role": "assistant", "content": outline},
            {"role": "user", "content": retry_prompt},
        ]
        outline = "".join(generate_content(messages, outline_model, api_key, url))
    return jsonify({"outline": outline})
# Module-level flag toggled by the /pause and /resume routes; the streaming
# loop in /generate_content checks it between outline parts to stop early.
is_paused = False
@app.route('/generate_key_words', methods=['POST'])
def generate_key_words():
    """Run key_words.py as a subprocess to derive search keywords from the outline."""
    data = request.get_json()
    cmd = [
        'python', 'key_words.py',
        data['key_word_model'],
        data['api_key'],
        (data['proxy_url'] or "https://api.openai.com/v1") + "/chat/completions",
        data['outline'],
    ]
    completed = subprocess.run(cmd, capture_output=True, text=True)
    return jsonify({"search_key_word": completed.stdout.strip()})
@app.route('/search', methods=['POST'])
def search():
    """Run search.py with the selected keywords and return its stdout.

    Security fix: the previous version printed the Bing API key, the OpenAI
    API key, and the full command line (containing both secrets) to stdout.
    Secrets are now redacted from all log output, and logging goes through
    the module logger instead of bare print().
    """
    data = request.get_json()
    selected_key_words = data['selected_key_words']
    bing_api_key = data['bing_api_key']
    proxy_url = data['proxy_url']
    api_key = data['api_key']
    logger.info("Received selected key words: %s", ', '.join(selected_key_words))
    cmd = ['python', 'search.py', ','.join(selected_key_words), bing_api_key, api_key, proxy_url]
    # Do NOT log cmd itself — it contains both API keys.
    logger.info("Running search.py with %d keyword(s) (keys redacted)", len(selected_key_words))
    result = subprocess.run(cmd, capture_output=True, text=True, encoding='utf-8')
    search_result = result.stdout.strip()
    error_output = result.stderr.strip()
    logger.info("Output from search.py:\n%s", search_result)
    if error_output:
        logger.error("Error output from search.py:\n%s", error_output)
    return jsonify({"search_result": search_result})
@app.route('/qa', methods=['POST'])
def qa():
    """Delegate question-answer generation over the search results to the qa module."""
    payload = request.get_json()
    # Imported lazily, as in the original, so the module only loads on use.
    import qa
    result = qa.main(payload['search_result'], payload['api_key'], payload['proxy_url'])
    return jsonify({"qa_result": result})
@app.route('/pause', methods=['POST'])
def pause():
    """Raise the shared pause flag so the streaming generation loop stops."""
    global is_paused
    is_paused = True
    response_body = {"message": "Generation paused"}
    return jsonify(response_body)
@app.route('/resume', methods=['POST'])
def resume():
    """Clear the shared pause flag so generation may continue."""
    global is_paused
    is_paused = False
    response_body = {"message": "Generation resumed"}
    return jsonify(response_body)
@app.route('/generate_content', methods=['POST'])
def generate_article_content():
    """Stream article body text for each '##' sub-heading of the outline.

    Reads all generation parameters from the JSON request body and returns
    a streamed text/plain response produced by the inner generator. The
    module-level is_paused flag (set via /pause, cleared via /resume) is
    checked between outline parts to stop generation early.
    """
    global is_paused
    is_paused = False  # every new generation run starts un-paused
    data = request.get_json()
    expand_outline = data['expand_outline']  # 'y'/'n': whether to add ### sub-sub-headings
    optimize_token = data['optimize_token']  # 'y': trim history to save tokens per part
    outline = data['outline']
    content_model = data['content_model']
    api_key = data['api_key']
    url = (data['proxy_url'] or "https://api.openai.com/v1") + "/chat/completions"
    doc_type = data['doc_type']
    top_k = int(data['top_k'])
    similarity_threshold = float(data['similarity_threshold'])
    system_prompt = ""
    if doc_type == "1":
        system_prompt = "You are ChatGPT, a large language model by OpenAI, and you are good at writing academic papers."
    elif doc_type == "2":
        system_prompt = "You are ChatGPT, a large language model by OpenAI, and you are good at writing reports."
    elif doc_type == "3":
        system_prompt = "You are ChatGPT, a large language model by OpenAI. From now on, you need to play the role of a writer. You are good at writing articles."
    # Non-empty outline lines, processed one heading at a time.
    outline_parts = [part for part in outline.split("\n") if part.strip()]
    def generate():
        # The conversation starts with the system prompt and the full
        # outline (as an assistant message) so the model keeps global context.
        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "assistant", "content": outline}
        ]
        for part in outline_parts:
            if is_paused:
                break
            # Only '##' sub-headings are expanded into body text;
            # other lines (top-level '#' headings) just emit spacing below.
            if part.startswith("##"):
                search_queries = [part[2:].strip()]
                from getcsv import search_dataset
                search_results = search_dataset(search_queries, top_k=top_k, similarity_threshold=similarity_threshold)
                # Extract the matched QA pairs as this part's knowledge base.
                matched_qa = ""
                for result in search_results:
                    question = result['question']
                    answer = result['answer']
                    matched_qa += f"Question: {question}\nAnswer: {answer}\n\n"
                if expand_outline.lower() == 'n':
                    article_prompt = f'参考之前的整体大纲,以及已经生成的内容,为"{part}"部分扩写内容。在扩写中,严禁出现下一级小标题,要求语言通俗易懂,不要过于死板,善于使用各种修辞手法。在每次输出的开头注明小标题是哪个(要用##放到小标题前面)(不要输出其他小标题和其他小标题的内容),结尾必须换2行\n将以下内容作为你本次生成的知识库:\n{matched_qa}'
                else:
                    article_prompt = f'参考之前的整体大纲,以及已经生成的内容,为"{part}"部分扩写内容。你需要在##小标题的基础上再扩充###小小标题,要求语言通俗易懂,不要过于死板,善于使用各种修辞手法。在每次输出的开头注明小标题是哪个(要用##放到小标题前面),结尾必须换2行\n将以下内容作为你本次生成的知识库:\n{matched_qa}'
                if optimize_token.lower() == 'y':
                    # Trim history back to [system prompt, outline] to save tokens.
                    messages = messages[:2]
                messages.append({"role": "user", "content": article_prompt})
                article_content = ""
                for chunk in generate_content(messages, content_model, api_key, url):
                    # if is_paused:
                    #     break
                    article_content += chunk
                    yield chunk
                if optimize_token.lower() == 'n':
                    # Keep the generated part in the history for continuity.
                    messages.append({"role": "assistant", "content": article_content})
            else:
                # Top-level headings only contribute paragraph spacing.
                yield "\n\n"
    return Response(stream_with_context(generate()), mimetype='text/plain')
@app.route('/write_to_file', methods=['POST'])
def write_to_file():
    """Persist the generated article to <title>.txt in the working directory.

    NOTE(review): this route handler shadows the module-level helper of the
    same name defined earlier in this file; the helper becomes unreachable
    after import. Consider renaming one of them.

    Security fix: *title* comes from the client, so path separators and
    parent-directory components are stripped before it is used as a file
    name, preventing writes outside the working directory (path traversal).
    """
    data = request.get_json()
    title = data['title']
    content = data['content']
    # Reduce the client-supplied title to a bare file name component.
    safe_title = os.path.basename(title.replace('\\', '/')) or "untitled"
    file_path = f"{safe_title}.txt"
    with open(file_path, 'w', encoding='utf-8') as file:
        file.write(content)
    return jsonify({"message": "Content written to file successfully."})
if __name__ == "__main__":
    # Bind on all interfaces; 7860 is presumably chosen for the Hugging Face
    # Spaces convention (this file carries HF Hub page residue at the top).
    app.run(host="0.0.0.0", port=7860)