yasserrmd's picture
Update app.py
e823d14 verified
raw
history blame
3.43 kB
import logging
import os
import re

from fastapi import FastAPI, Request
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import HTMLResponse, JSONResponse
from fastapi.staticfiles import StaticFiles
from groq import Groq
from huggingface_hub import InferenceClient
from pydantic import BaseModel
# Initialize FastAPI app
app = FastAPI()
# Serve files from the local "static" directory under the /static URL prefix
app.mount("/static", StaticFiles(directory="static"), name="static")
# Initialize the Groq client used to request chat completions.
# NOTE(review): Groq() is constructed without arguments, so it presumably reads
# its API key from the environment (GROQ_API_KEY) — confirm in the deployment.
client = Groq()
# Pydantic model for API input
class InfographicRequest(BaseModel):
    """Request body accepted by ``POST /generate``."""

    # Free-text description of the desired infographic; substituted into
    # PROMPT_TEMPLATE as the ``{description}`` field.
    description: str
# Prompt configuration, read once at import time from environment variables.
# Each value is None when its variable is not set.
SYSTEM_INSTRUCT = os.environ.get("SYSTEM_INSTRUCTOR")  # system-message text
PROMPT_TEMPLATE = os.environ.get("PROMPT_TEMPLATE")  # user prompt with a {description} field
async def extract_code_blocks(markdown_text):
    """Collect the bodies of all triple-backtick fenced blocks in *markdown_text*.

    Everything on the opening fence line (e.g. an ``html`` info string) and the
    newline ending it are dropped; each result is the raw text between that
    newline and the next closing fence, in document order.

    Declared ``async`` so callers ``await`` it; it performs no I/O itself.

    Args:
        markdown_text (str): Markdown produced by the model.

    Returns:
        list: The extracted code-block bodies (empty if there are none).
    """
    # DOTALL lets a block body span multiple lines.
    return re.findall(r'```.*?\n(.*?)```', markdown_text, flags=re.DOTALL)
def generate_infographic(
    request: InfographicRequest,
    *,
    model: str = "llama-3.1-70b-versatile",
    temperature: float = 0.5,
    max_tokens: int = 5000,
) -> str:
    """Request infographic markup for ``request.description`` from Groq.

    ``PROMPT_TEMPLATE`` is formatted with the description and sent as the user
    message, with ``SYSTEM_INSTRUCT`` as the system message. The call is
    synchronous (blocking network I/O).

    Args:
        request: Payload carrying the user's free-text ``description``.
        model: Groq model id; defaults to the model this app was written for.
        temperature: Sampling temperature forwarded to the API.
        max_tokens: Completion length cap forwarded to the API.

    Returns:
        The raw completion text, or ``""`` if the API returned no content.

    Raises:
        RuntimeError: If the ``PROMPT_TEMPLATE`` environment variable was unset.
    """
    if PROMPT_TEMPLATE is None:
        # Otherwise this surfaces as an opaque
        # "'NoneType' object has no attribute 'format'".
        raise RuntimeError("PROMPT_TEMPLATE environment variable is not set")
    prompt = PROMPT_TEMPLATE.format(description=request.description)
    completion = client.chat.completions.create(
        model=model,
        messages=[
            {"role": "system", "content": SYSTEM_INSTRUCT},
            {"role": "user", "content": prompt},
        ],
        temperature=temperature,
        max_tokens=max_tokens,
        top_p=1,
        stream=False,
        stop=None,
    )
    # The SDK types message.content as optional; normalize so callers can
    # always treat the result as a string.
    generated_text = completion.choices[0].message.content or ""
    logging.getLogger(__name__).debug("Generated completion:\n%s", generated_text)
    return generated_text
# Route to serve the HTML template
@app.get("/", response_class=HTMLResponse)
async def serve_frontend():
    """Serve the single-page frontend stored at ``static/infographic_gen.html``."""
    # ``with`` closes the file handle deterministically (the original leaked it),
    # and an explicit encoding avoids depending on the platform locale.
    with open("static/infographic_gen.html", encoding="utf-8") as page:
        return HTMLResponse(page.read())
# Route to handle infographic generation.
# The handler must NOT be named ``generate_infographic``: that rebinds the
# module-level helper, so the call below would invoke this coroutine function
# instead of the completion helper. ``name=`` keeps the original route name
# (and therefore the OpenAPI operationId / url_path_for name) unchanged.
@app.post("/generate", name="generate_infographic")
async def generate_infographic_endpoint(request: InfographicRequest):
    """Generate an infographic and return the first fenced code block as HTML.

    Returns ``{"html": ...}`` on success, or ``{"error": "No generation"}``
    with status 500 when the model output contains no fenced code block.
    """
    # The helper performs a blocking network call; run it in the threadpool so
    # the event loop keeps serving other requests meanwhile.
    generated_text = await run_in_threadpool(generate_infographic, request)
    # Guard against an empty/None completion so the regex always gets a str.
    code_blocks = await extract_code_blocks(generated_text or "")
    if code_blocks:
        return JSONResponse(content={"html": code_blocks[0]})
    return JSONResponse(content={"error": "No generation"}, status_code=500)