msgxai-hg-api / test_endpoint.py
msgxai's picture
chore: fix metadata properties for HF Inference Endpoint
557227d
#!/usr/bin/env python3
# This script demonstrates how to test your Hugging Face Inference Endpoint.
# Pass your API token and endpoint URL with the --token and --url options.
import requests
import json
import base64
from PIL import Image
import io
import argparse
import os
def test_inference_endpoint(api_token, api_url, prompt, negative_prompt=None,
                            seed=None, inference_steps=30, guidance_scale=7,
                            width=1024, height=768, output_dir="generated_images",
                            timeout=300):
    """
    Test a Hugging Face Inference Endpoint for image generation.

    Sends one text-to-image request, decodes the base64-encoded image in the
    first element of the JSON response and saves it as
    ``generated_<seed>.png`` inside ``output_dir``.

    Args:
        api_token (str): Your Hugging Face API token
        api_url (str): The URL of your inference endpoint
        prompt (str): The text prompt for image generation
        negative_prompt (str, optional): Negative prompt to guide generation
        seed (int, optional): Random seed for reproducibility. ``0`` is a
            valid seed and is sent; only ``None`` omits the parameter.
        inference_steps (int): Number of inference steps
        guidance_scale (float): Guidance scale for generation
        width (int): Image width
        height (int): Image height
        output_dir (str): Directory to save generated images
        timeout (float or None): Seconds to wait for the HTTP response before
            giving up; ``None`` waits indefinitely.

    Returns:
        PIL.Image.Image or None: The generated image on success; ``None`` if
        the request failed or the response could not be used (the reason is
        printed).
    """
    # Create output directory if it doesn't exist
    os.makedirs(output_dir, exist_ok=True)

    headers = {
        "Authorization": f"Bearer {api_token}",
        "Content-Type": "application/json",
    }

    parameters = {
        "width": width,
        "height": height,
        "inference_steps": inference_steps,
        "guidance_scale": guidance_scale,
    }
    if negative_prompt:
        parameters["negative_prompt"] = negative_prompt
    # Compare against None: a plain truthiness test would silently drop seed=0.
    if seed is not None:
        parameters["seed"] = seed

    payload = {"inputs": prompt, "parameters": parameters}

    print(f"Sending request to {api_url}...")
    print(f"Prompt: '{prompt}'")

    try:
        # Without a timeout, requests can block forever on a stalled connection.
        response = requests.post(
            api_url, headers=headers, json=payload, timeout=timeout
        )
    except requests.RequestException as e:
        print(f"Error: {e}")
        return None

    if response.status_code != 200:
        print(f"Error: {response.status_code} - {response.text}")
        return None

    try:
        result = response.json()
    except ValueError as e:
        # Body was not valid JSON (e.g. an HTML error page from a proxy).
        print(f"Error: {e}")
        return None

    # Check for error in the response
    if isinstance(result, dict) and "error" in result:
        print(f"API Error: {result['error']}")
        return None

    if not (isinstance(result, list) and result):
        print("Unexpected response format:", result)
        return None

    item = result[0]
    if not isinstance(item, dict) or "generated_image" not in item:
        print("Response doesn't contain 'generated_image' field")
        return None

    used_seed = item.get("seed", "unknown_seed")
    # Files are named by seed only, so a response without a seed (or a repeated
    # seed) overwrites the previous image of the same name.
    filename = os.path.join(output_dir, f"generated_{used_seed}.png")
    try:
        image_bytes = base64.b64decode(item["generated_image"])
        image = Image.open(io.BytesIO(image_bytes))
        image.save(filename)
    except (TypeError, ValueError, OSError) as e:
        # Bad base64 raises binascii.Error (a ValueError); an unreadable image
        # (PIL UnidentifiedImageError) and file-write failures are OSErrors.
        print(f"Error: {e}")
        return None

    print(f"Image saved to {filename}")
    print(f"Seed: {used_seed}")
    return image
if __name__ == "__main__":
    # Parse command line arguments
    parser = argparse.ArgumentParser(
        description="Test Hugging Face Inference Endpoints for image generation"
    )
    # Allow the token to come from $HF_TOKEN so it need not be typed on the
    # command line, where it leaks into shell history and process listings.
    env_token = os.environ.get("HF_TOKEN")
    parser.add_argument(
        "--token",
        default=env_token,
        required=env_token is None,
        help="Your Hugging Face API token (defaults to $HF_TOKEN)",
    )
    parser.add_argument("--url", required=True, help="URL of your inference endpoint")
    parser.add_argument("--prompt", required=True, help="Text prompt for image generation")
    parser.add_argument("--negative_prompt", help="Negative prompt")
    parser.add_argument("--seed", type=int, help="Random seed for reproducibility")
    parser.add_argument("--steps", type=int, default=30, help="Number of inference steps")
    # Float default to match type=float (argparse does not convert non-string defaults).
    parser.add_argument("--guidance", type=float, default=7.0, help="Guidance scale")
    parser.add_argument("--width", type=int, default=1024, help="Image width")
    parser.add_argument("--height", type=int, default=768, help="Image height")
    parser.add_argument("--output_dir", default="generated_images", help="Directory to save generated images")
    args = parser.parse_args()

    # Call the test function with provided arguments
    image = test_inference_endpoint(
        api_token=args.token,
        api_url=args.url,
        prompt=args.prompt,
        negative_prompt=args.negative_prompt,
        seed=args.seed,
        inference_steps=args.steps,
        guidance_scale=args.guidance,
        width=args.width,
        height=args.height,
        output_dir=args.output_dir,
    )
    # test_inference_endpoint returns None on any failure; report it through
    # the exit status so shells and CI can detect a failed endpoint test.
    if image is None:
        raise SystemExit(1)