# ugly-holiday-card-generator / test_endpoint.py
# Mikiko Bazeley
# Refactored and removed controlnet
# 57eccf2
import os
import requests
from dotenv import load_dotenv
from PIL import Image
from io import BytesIO
# Locate the .env file inside the "env" directory next to this script and load
# it; override=True lets values from the file replace variables already set in
# the process environment.
dotenv_path = os.path.join(os.path.dirname(__file__), 'env', '.env')
load_dotenv(dotenv_path, override=True)

# The Fireworks API key is required; fail fast when it is missing or empty.
api_key = os.environ.get("FIREWORKS_API_KEY")
if api_key is None or api_key == "":
    raise ValueError("API key not found. Make sure FIREWORKS_API_KEY is set in the .env file.")
# Ask the user what to generate; a prompt that is empty or only whitespace is
# rejected before any API call is made.
prompt = input("Enter a prompt for image generation: ")
if prompt.strip() == "":
    raise ValueError("Prompt cannot be empty!")
# Select the FLUX.1 model to call. Two variants are configured here:
#   "flux-1-dev-fp8"     -- run with 30 inference steps
#   "flux-1-schnell-fp8" -- run with 4 inference steps
# To switch models, swap which of the two assignments below is commented out.
# model_path = "flux-1-schnell-fp8"
model_path = "flux-1-dev-fp8"

# Fireworks workflows endpoint for text-to-image generation with the chosen model.
url = f"https://api.fireworks.ai/inference/v1/workflows/accounts/fireworks/models/{model_path}/text_to_image"
# Request headers: bearer-token authentication, a JSON request body, and a
# JPEG image requested in the response.
headers = {"Authorization": f"Bearer {api_key}"}
headers.update({"Content-Type": "application/json", "Accept": "image/jpeg"})
# JSON payload for the text-to-image request.
data = {
    "prompt": prompt,  # the user-provided prompt
    "aspect_ratio": "16:9",
    "guidance_scale": 3.5,
    # Dev models get 30 steps, schnell models 4. Match on the prefix so that
    # variant ids such as "flux-1-dev-fp8" are recognized as dev models; an
    # exact comparison against "flux-1-dev" would never match those ids.
    "num_inference_steps": 30 if model_path.startswith("flux-1-dev") else 4,
    "seed": 0,  # fixed seed value sent with every request
}
# Send the generation request. requests waits indefinitely by default, so bound
# the wait explicitly; generation can be slow, hence the generous limit (seconds).
response = requests.post(url, headers=headers, json=data, timeout=120)

if response.status_code == 200:
    # The response body holds the encoded image (the request asked for image/jpeg).
    # Use a context manager so the PIL image is closed once it has been saved.
    with Image.open(BytesIO(response.content)) as img:
        img.save("output_image.jpg")
    print("Image saved successfully as output_image.jpg.")
else:
    # Report the HTTP status code together with the server's error body.
    print(f"Error: {response.status_code}, {response.text}")