pvanand committed on
Commit
5e2ed5c
·
verified ·
1 Parent(s): 6ef6156

Update document_generator.py

Browse files
Files changed (1) hide show
  1. document_generator.py +33 -29
document_generator.py CHANGED
@@ -40,7 +40,7 @@ FORMAT YOUR OUTPUT AS MARKDOWN ENCLOSED IN <response></response> tags
40
  DOCUMENT_SECTION_PROMPT_USER = """<prompt>Output the content for the section "{section_or_subsection_title}" formatted as markdown. Follow this instruction: {content_instruction}</prompt>"""
41
 
42
  # File: app.py
43
- import os
44
  import json
45
  import re
46
  import time
@@ -52,16 +52,17 @@ import functools
52
  from fastapi import APIRouter, HTTPException
53
  from pydantic import BaseModel
54
  from fastapi_cache.decorator import cache
 
55
 
56
  logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
57
  logger = logging.getLogger(__name__)
58
 
59
  def log_execution(func: Callable) -> Callable:
60
  @functools.wraps(func)
61
- def wrapper(*args: Any, **kwargs: Any) -> Any:
62
  logger.info(f"Executing {func.__name__}")
63
  try:
64
- result = func(*args, **kwargs)
65
  logger.info(f"{func.__name__} completed successfully")
66
  return result
67
  except Exception as e:
@@ -77,7 +78,7 @@ class AIClient:
77
  )
78
 
79
  @log_execution
80
- def generate_response(
81
  self,
82
  messages: List[Dict[str, str]],
83
  model: str = "openai/gpt-4o-mini",
@@ -85,12 +86,14 @@ class AIClient:
85
  ) -> Optional[str]:
86
  if not messages:
87
  return None
88
- response = self.client.chat.completions.create(
 
 
89
  model=model,
90
  messages=messages,
91
  max_tokens=max_tokens,
92
  stream=False
93
- )
94
  return response.choices[0].message.content
95
 
96
  class DocumentGenerator:
@@ -120,14 +123,14 @@ class DocumentGenerator:
120
  return content.lstrip()
121
 
122
  @log_execution
123
- def generate_document_outline(self, query: str, max_retries: int = 3) -> Optional[Dict]:
124
  messages = [
125
  {"role": "system", "content": DOCUMENT_OUTLINE_PROMPT_SYSTEM},
126
  {"role": "user", "content": DOCUMENT_OUTLINE_PROMPT_USER.format(query=query)}
127
  ]
128
 
129
  for attempt in range(max_retries):
130
- outline_response = self.ai_client.generate_response(messages, model="openai/gpt-4o")
131
  outline_json_text = self.extract_between_tags(outline_response, "output")
132
 
133
  try:
@@ -142,7 +145,7 @@ class DocumentGenerator:
142
  return None
143
 
144
  @log_execution
145
- def generate_content(self, title: str, content_instruction: str, section_number: str) -> str:
146
  self.content_messages.append({
147
  "role": "user",
148
  "content": DOCUMENT_SECTION_PROMPT_USER.format(
@@ -150,7 +153,7 @@ class DocumentGenerator:
150
  content_instruction=content_instruction
151
  )
152
  })
153
- section_response = self.ai_client.generate_response(self.content_messages)
154
  content = self.extract_between_tags(section_response, "response")
155
  content = self.remove_duplicate_title(content, title, section_number)
156
  self.content_messages.append({
@@ -160,7 +163,7 @@ class DocumentGenerator:
160
  return content
161
 
162
  @log_execution
163
- def generate_full_document(self, document_outline: Dict, query: str) -> Dict:
164
  self.document_outline = document_outline
165
 
166
  overall_objective = query
@@ -181,16 +184,21 @@ class DocumentGenerator:
181
  section_number = section.get("SectionNumber", "")
182
  content_instruction = section.get("Content", "")
183
  logger.info(f"Generating content for section: {section_title}")
184
- section["Content"] = self.generate_content(section_title, content_instruction, section_number)
 
185
 
186
  for subsection in section.get("Subsections", []):
187
  subsection_title = subsection.get("Title", "")
188
  subsection_number = subsection.get("SectionNumber", "")
189
  subsection_content_instruction = subsection.get("Content", "")
190
  logger.info(f"Generating content for subsection: {subsection_title}")
191
- subsection["Content"] = self.generate_content(subsection_title, subsection_content_instruction, subsection_number)
 
192
 
193
- return self.document_outline
 
 
 
194
 
195
  class MarkdownConverter:
196
  @staticmethod
@@ -258,9 +266,6 @@ class MarkdownDocumentRequest(BaseModel):
258
  json_document: Dict
259
  query: str
260
 
261
- class MarkdownDocumentResponse(BaseModel):
262
- markdown_document: str
263
-
264
  @cache(expire=600*24*7)
265
  @router.post("/generate-document/json", response_model=JsonDocumentResponse)
266
  async def generate_document_outline_endpoint(request: DocumentRequest):
@@ -269,7 +274,7 @@ async def generate_document_outline_endpoint(request: DocumentRequest):
269
 
270
  try:
271
  # Generate the document outline
272
- json_document = document_generator.generate_document_outline(request.query)
273
 
274
  if json_document is None:
275
  raise HTTPException(status_code=500, detail="Failed to generate a valid document outline")
@@ -278,21 +283,20 @@ async def generate_document_outline_endpoint(request: DocumentRequest):
278
  except Exception as e:
279
  raise HTTPException(status_code=500, detail=str(e))
280
 
281
- @router.post("/generate-document/markdown", response_model=MarkdownDocumentResponse)
282
  async def generate_markdown_document_endpoint(request: MarkdownDocumentRequest):
283
  ai_client = AIClient()
284
  document_generator = DocumentGenerator(ai_client)
285
 
286
- try:
287
- # Generate the full document content
288
- full_document = document_generator.generate_full_document(request.json_document, request.query)
289
-
290
- # Convert to Markdown
291
- markdown_document = MarkdownConverter.convert_to_markdown(full_document["Document"])
292
-
293
- return MarkdownDocumentResponse(markdown_document=markdown_document)
294
- except Exception as e:
295
- raise HTTPException(status_code=500, detail=str(e))
296
 
297
  @router.post("/generate-document-test", response_model=MarkdownDocumentResponse)
298
  async def test_generate_document_endpoint(request: DocumentRequest):
 
40
  DOCUMENT_SECTION_PROMPT_USER = """<prompt>Output the content for the section "{section_or_subsection_title}" formatted as markdown. Follow this instruction: {content_instruction}</prompt>"""
41
 
42
  # File: app.py
43
+ import os
44
  import json
45
  import re
46
  import time
 
52
  from fastapi import APIRouter, HTTPException
53
  from pydantic import BaseModel
54
  from fastapi_cache.decorator import cache
55
+ from starlette.responses import StreamingResponse
56
 
57
  logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
58
  logger = logging.getLogger(__name__)
59
 
60
  def log_execution(func: Callable) -> Callable:
61
  @functools.wraps(func)
62
+ async def wrapper(*args: Any, **kwargs: Any) -> Any:
63
  logger.info(f"Executing {func.__name__}")
64
  try:
65
+ result = await func(*args, **kwargs)
66
  logger.info(f"{func.__name__} completed successfully")
67
  return result
68
  except Exception as e:
 
78
  )
79
 
80
  @log_execution
81
+ async def generate_response(
82
  self,
83
  messages: List[Dict[str, str]],
84
  model: str = "openai/gpt-4o-mini",
 
86
  ) -> Optional[str]:
87
  if not messages:
88
  return None
89
+ loop = asyncio.get_event_loop()
90
+ response = await loop.run_in_executor(None, functools.partial(
91
+ self.client.chat.completions.create,
92
  model=model,
93
  messages=messages,
94
  max_tokens=max_tokens,
95
  stream=False
96
+ ))
97
  return response.choices[0].message.content
98
 
99
  class DocumentGenerator:
 
123
  return content.lstrip()
124
 
125
  @log_execution
126
+ async def generate_document_outline(self, query: str, max_retries: int = 3) -> Optional[Dict]:
127
  messages = [
128
  {"role": "system", "content": DOCUMENT_OUTLINE_PROMPT_SYSTEM},
129
  {"role": "user", "content": DOCUMENT_OUTLINE_PROMPT_USER.format(query=query)}
130
  ]
131
 
132
  for attempt in range(max_retries):
133
+ outline_response = await self.ai_client.generate_response(messages, model="openai/gpt-4o")
134
  outline_json_text = self.extract_between_tags(outline_response, "output")
135
 
136
  try:
 
145
  return None
146
 
147
  @log_execution
148
+ async def generate_content(self, title: str, content_instruction: str, section_number: str) -> str:
149
  self.content_messages.append({
150
  "role": "user",
151
  "content": DOCUMENT_SECTION_PROMPT_USER.format(
 
153
  content_instruction=content_instruction
154
  )
155
  })
156
+ section_response = await self.ai_client.generate_response(self.content_messages)
157
  content = self.extract_between_tags(section_response, "response")
158
  content = self.remove_duplicate_title(content, title, section_number)
159
  self.content_messages.append({
 
163
  return content
164
 
165
  @log_execution
166
+ async def generate_full_document(self, document_outline: Dict, query: str):
167
  self.document_outline = document_outline
168
 
169
  overall_objective = query
 
184
  section_number = section.get("SectionNumber", "")
185
  content_instruction = section.get("Content", "")
186
  logger.info(f"Generating content for section: {section_title}")
187
+ section["Content"] = await self.generate_content(section_title, content_instruction, section_number)
188
+ yield json.dumps({"type": "document_section", "content": section}) + "\n"
189
 
190
  for subsection in section.get("Subsections", []):
191
  subsection_title = subsection.get("Title", "")
192
  subsection_number = subsection.get("SectionNumber", "")
193
  subsection_content_instruction = subsection.get("Content", "")
194
  logger.info(f"Generating content for subsection: {subsection_title}")
195
+ subsection["Content"] = await self.generate_content(subsection_title, subsection_content_instruction, subsection_number)
196
+ yield json.dumps({"type": "document_subsection", "content": subsection}) + "\n"
197
 
198
+ # Generate the complete markdown document
199
+ full_document = self.document_outline
200
+ markdown_document = MarkdownConverter.convert_to_markdown(full_document["Document"])
201
+ yield json.dumps({"type": "complete_document", "content": markdown_document}) + "\n"
202
 
203
  class MarkdownConverter:
204
  @staticmethod
 
266
  json_document: Dict
267
  query: str
268
 
 
 
 
269
  @cache(expire=600*24*7)
270
  @router.post("/generate-document/json", response_model=JsonDocumentResponse)
271
  async def generate_document_outline_endpoint(request: DocumentRequest):
 
274
 
275
  try:
276
  # Generate the document outline
277
+ json_document = await document_generator.generate_document_outline(request.query)
278
 
279
  if json_document is None:
280
  raise HTTPException(status_code=500, detail="Failed to generate a valid document outline")
 
283
  except Exception as e:
284
  raise HTTPException(status_code=500, detail=str(e))
285
 
286
+ @router.post("/generate-document/markdown")
287
  async def generate_markdown_document_endpoint(request: MarkdownDocumentRequest):
288
  ai_client = AIClient()
289
  document_generator = DocumentGenerator(ai_client)
290
 
291
+ async def event_stream():
292
+ try:
293
+ # Generate the full document content and stream it
294
+ async for section in document_generator.generate_full_document(request.json_document, request.query):
295
+ yield section
296
+ except Exception as e:
297
+ yield json.dumps({"type": "error", "message": str(e)}) + "\n"
298
+
299
+ return StreamingResponse(event_stream(), media_type="application/json")
 
300
 
301
  @router.post("/generate-document-test", response_model=MarkdownDocumentResponse)
302
  async def test_generate_document_endpoint(request: DocumentRequest):