Commit e34d069
Parent(s): db46f9e
chore: cleanup unnecessary code for running benchmark
Files changed:
- .env.example +0 -4
- api/__init__.py +0 -66
- api/aws.py +0 -69
- api/baseline.py +0 -55
- api/fal.py +0 -48
- api/fireworks.py +0 -53
- api/flux.py +0 -35
- api/pruna.py +0 -53
- api/pruna_dev.py +0 -49
- api/replicate.py +0 -48
- api/replicate_wan.py +0 -48
- api/together.py +0 -47
- dashboard/app.py → app.py +0 -0
- benchmark/__init__.py +0 -45
- benchmark/draw_bench.py +0 -25
- benchmark/genai_bench.py +0 -39
- benchmark/geneval.py +0 -44
- benchmark/hps.py +0 -56
- benchmark/metrics/__init__.py +0 -43
- benchmark/metrics/arniqa.py +0 -36
- benchmark/metrics/clip.py +0 -29
- benchmark/metrics/clip_iqa.py +0 -32
- benchmark/metrics/hps.py +0 -92
- benchmark/metrics/image_reward.py +0 -35
- benchmark/metrics/sharpness.py +0 -24
- benchmark/metrics/vqa.py +0 -31
- benchmark/parti.py +0 -28
- dashboard/requirements.txt +0 -3
- {dashboard/data → data}/text_to_image.jsonl +0 -0
- evaluate.py +0 -124
- evaluation_results/.gitkeep +0 -0
- images/.gitkeep +0 -0
- pyproject.toml +0 -43
- sample.py +0 -125
.env.example
DELETED
@@ -1,4 +0,0 @@
-FIREWORKS_API_TOKEN=your_fireworks_api_token_here
-REPLICATE_API_TOKEN=your_replicate_api_token_here
-FAL_KEY=your_fal_key_here
-TOGETHER_API_KEY=your_together_api_key_here
api/__init__.py
DELETED
@@ -1,66 +0,0 @@
-from typing import Type
-
-from api.aws import AWSBedrockAPI
-from api.baseline import BaselineAPI
-from api.fal import FalAPI
-from api.fireworks import FireworksAPI
-from api.flux import FluxAPI
-from api.pruna import PrunaAPI
-from api.pruna_dev import PrunaDevAPI
-from api.replicate import ReplicateAPI
-from api.together import TogetherAPI
-
-__all__ = [
-    "create_api",
-    "FluxAPI",
-    "BaselineAPI",
-    "FireworksAPI",
-    "PrunaAPI",
-    "ReplicateAPI",
-    "TogetherAPI",
-    "FalAPI",
-    "PrunaDevAPI",
-]
-
-
-def create_api(api_type: str) -> FluxAPI:
-    """
-    Factory function to create API instances.
-
-    Args:
-        api_type (str): The type of API to create. Must be one of:
-            - "baseline"
-            - "fireworks"
-            - "pruna_speed_mode" (where speed_mode is the desired speed mode)
-            - "replicate"
-            - "together"
-            - "fal"
-            - "aws"
-
-    Returns:
-        FluxAPI: An instance of the requested API implementation
-
-    Raises:
-        ValueError: If an invalid API type is provided
-    """
-    if api_type == "pruna_dev":
-        return PrunaDevAPI()
-    if api_type.startswith("pruna_"):
-        speed_mode = api_type[6:]  # Remove "pruna_" prefix
-        return PrunaAPI(speed_mode)
-
-    api_map: dict[str, Type[FluxAPI]] = {
-        "baseline": BaselineAPI,
-        "fireworks": FireworksAPI,
-        "replicate": ReplicateAPI,
-        "together": TogetherAPI,
-        "fal": FalAPI,
-        "aws": AWSBedrockAPI,
-    }
-
-    if api_type not in api_map:
-        raise ValueError(
-            f"Invalid API type: {api_type}. Must be one of {list(api_map.keys())} or start with 'pruna_'"
-        )
-
-    return api_map[api_type]()
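The deleted factory resolves names in two steps: an exact match ("pruna_dev", or any key in api_map) picks a wrapper class directly, while any other string beginning with "pruna_" has the prefix stripped and the remainder passed to PrunaAPI as its speed mode. A short usage sketch of that dispatch (the "pruna_juiced" speed-mode string is illustrative, not a value the diff guarantees):

from api import create_api

api = create_api("fal")           # -> FalAPI()
api = create_api("pruna_dev")     # exact match, checked before the prefix rule -> PrunaDevAPI()
api = create_api("pruna_juiced")  # prefix rule -> PrunaAPI("juiced")
print(api.name)                   # each wrapper reports its own name, here "pruna_juiced"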
api/aws.py
DELETED
@@ -1,69 +0,0 @@
-import base64
-import json
-import os
-import time
-from pathlib import Path
-
-import boto3
-from dotenv import load_dotenv
-
-from api.flux import FluxAPI
-
-
-class AWSBedrockAPI(FluxAPI):
-    def __init__(self):
-        load_dotenv()
-        # AWS credentials should be set via environment variables
-        # AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY, and AWS_SESSION_TOKEN
-        os.environ["AWS_ACCESS_KEY_ID"] = ""
-        os.environ["AWS_SECRET_ACCESS_KEY"] = ""
-        os.environ["AWS_SESSION_TOKEN"] = ""
-        os.environ["AWS_REGION"] = "us-east-1"
-        self._client = boto3.client("bedrock-runtime")
-        self._model_id = "amazon.nova-canvas-v1:0"
-
-    @property
-    def name(self) -> str:
-        return "aws_nova_canvas"
-
-    def generate_image(self, prompt: str, save_path: Path) -> float:
-        start_time = time.time()
-        # Format the request payload
-        native_request = {
-            "taskType": "TEXT_IMAGE",
-            "textToImageParams": {"text": prompt},
-            "imageGenerationConfig": {
-                "seed": 0,
-                "quality": "standard",
-                "height": 1024,
-                "width": 1024,
-                "numberOfImages": 1,
-            },
-        }
-
-        try:
-            # Convert request to JSON and invoke the model
-            request = json.dumps(native_request)
-            response = self._client.invoke_model(modelId=self._model_id, body=request)
-
-            # Process the response
-            model_response = json.loads(response["body"].read())
-            if not model_response.get("images"):
-                raise Exception("No images returned from AWS Bedrock API")
-
-            # Save the image
-            base64_image_data = model_response["images"][0]
-            self._save_image_from_base64(base64_image_data, save_path)
-
-        except Exception as e:
-            raise Exception(f"Error generating image with AWS Bedrock: {str(e)}")
-
-        end_time = time.time()
-        return end_time - start_time
-
-    def _save_image_from_base64(self, base64_data: str, save_path: Path):
-        """Save a base64 encoded image to the specified path."""
-        save_path.parent.mkdir(parents=True, exist_ok=True)
-        image_data = base64.b64decode(base64_data)
-        with open(save_path, "wb") as f:
-            f.write(image_data)
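One quirk in the deleted AWSBedrockAPI constructor: right after load_dotenv() it assigns empty strings to AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY, and AWS_SESSION_TOKEN, which overrides the very environment variables its comment says the credentials should come from. A minimal sketch of the same setup actually deferring to the environment (an assumption about the intended behavior, not the committed code):

import os

import boto3
from dotenv import load_dotenv

load_dotenv()  # pulls AWS_* credentials from .env into the environment
os.environ.setdefault("AWS_REGION", "us-east-1")  # keep the region default
# boto3 reads AWS_ACCESS_KEY_ID / AWS_SECRET_ACCESS_KEY / AWS_SESSION_TOKEN
# from the environment on its own, so no explicit assignment is needed.
client = boto3.client("bedrock-runtime")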
api/baseline.py
DELETED
@@ -1,55 +0,0 @@
-import os
-import time
-from pathlib import Path
-from typing import Any
-
-from dotenv import load_dotenv
-
-from api.flux import FluxAPI
-
-
-class BaselineAPI(FluxAPI):
-    """
-    As our baseline, we use the Replicate API with go_fast=False.
-    """
-
-    def __init__(self):
-        load_dotenv()
-        self._api_key = os.getenv("REPLICATE_API_TOKEN")
-        if not self._api_key:
-            raise ValueError("REPLICATE_API_TOKEN not found in environment variables")
-
-    @property
-    def name(self) -> str:
-        return "baseline"
-
-    def generate_image(self, prompt: str, save_path: Path) -> float:
-        import replicate
-
-        start_time = time.time()
-        result = replicate.run(
-            "black-forest-labs/flux-dev",
-            input={
-                "prompt": prompt,
-                "go_fast": False,
-                "guidance": 3.5,
-                "num_outputs": 1,
-                "aspect_ratio": "1:1",
-                "output_format": "png",
-                "num_inference_steps": 28,
-                "seed": 0,
-            },
-        )
-        end_time = time.time()
-
-        if result and len(result) > 0:
-            self._save_image_from_result(result[0], save_path)
-        else:
-            raise Exception("No result returned from Replicate API")
-
-        return end_time - start_time
-
-    def _save_image_from_result(self, result: Any, save_path: Path):
-        save_path.parent.mkdir(parents=True, exist_ok=True)
-        with open(save_path, "wb") as f:
-            f.write(result.read())
api/fal.py
DELETED
@@ -1,48 +0,0 @@
-import time
-from io import BytesIO
-from pathlib import Path
-from typing import Any
-
-import fal_client
-import requests
-from PIL import Image
-
-from api.flux import FluxAPI
-
-
-class FalAPI(FluxAPI):
-    @property
-    def name(self) -> str:
-        return "fal"
-
-    def generate_image(self, prompt: str, save_path: Path) -> float:
-        start_time = time.time()
-        result = fal_client.subscribe(
-            "fal-ai/flux/dev",
-            arguments={
-                "seed": 0,
-                "prompt": prompt,
-                "image_size": "square_hd",  # 1024x1024 image
-                "num_images": 1,
-                "guidance_scale": 3.5,
-                "num_inference_steps": 28,
-                "enable_safety_checker": True,
-            },
-        )
-        end_time = time.time()
-
-        url = result["images"][0]["url"]
-        self._save_image_from_url(url, save_path)
-
-        return end_time - start_time
-
-    def _save_image_from_url(self, url: str, save_path: Path):
-        response = requests.get(url)
-        image = Image.open(BytesIO(response.content))
-        save_path.parent.mkdir(parents=True, exist_ok=True)
-        image.save(save_path)
-
-    def _save_image_from_result(self, result: Any, save_path: Path):
-        save_path.parent.mkdir(parents=True, exist_ok=True)
-        with open(save_path, "wb") as f:
-            f.write(result.content)
api/fireworks.py
DELETED
@@ -1,53 +0,0 @@
-import os
-import time
-from pathlib import Path
-from typing import Any
-
-import requests
-from dotenv import load_dotenv
-
-from api.flux import FluxAPI
-
-
-class FireworksAPI(FluxAPI):
-    def __init__(self):
-        load_dotenv()
-        self._api_key = os.getenv("FIREWORKS_API_TOKEN")
-        if not self._api_key:
-            raise ValueError("FIREWORKS_API_TOKEN not found in environment variables")
-        self._url = "https://api.fireworks.ai/inference/v1/workflows/accounts/fireworks/models/flux-1-dev-fp8/text_to_image"
-
-    @property
-    def name(self) -> str:
-        return "fireworks_fp8"
-
-    def generate_image(self, prompt: str, save_path: Path) -> float:
-        start_time = time.time()
-
-        headers = {
-            "Content-Type": "application/json",
-            "Accept": "image/jpeg",
-            "Authorization": f"Bearer {self._api_key}",
-        }
-        data = {
-            "prompt": prompt,
-            "aspect_ratio": "1:1",
-            "guidance_scale": 3.5,
-            "num_inference_steps": 28,
-            "seed": 0,
-        }
-        result = requests.post(self._url, headers=headers, json=data)
-
-        end_time = time.time()
-
-        if result.status_code == 200:
-            self._save_image_from_result(result, save_path)
-        else:
-            raise Exception(f"Error: {result.status_code} {result.text}")
-
-        return end_time - start_time
-
-    def _save_image_from_result(self, result: Any, save_path: Path):
-        save_path.parent.mkdir(parents=True, exist_ok=True)
-        with open(save_path, "wb") as f:
-            f.write(result.content)
api/flux.py
DELETED
@@ -1,35 +0,0 @@
-from abc import ABC, abstractmethod
-from pathlib import Path
-
-
-class FluxAPI(ABC):
-    """
-    Abstract base class for Flux API implementations.
-
-    This class defines the common interface for all Flux API implementations.
-    """
-
-    @property
-    @abstractmethod
-    def name(self) -> str:
-        """
-        The name of the API implementation.
-
-        Returns:
-            str: The name of the specific API implementation
-        """
-        pass
-
-    @abstractmethod
-    def generate_image(self, prompt: str, save_path: Path) -> float:
-        """
-        Generate an image based on the prompt and save it to the specified path.
-
-        Args:
-            prompt (str): The text prompt to generate the image from
-            save_path (Path): The path where the generated image should be saved
-
-        Returns:
-            float: The time taken for the API call in seconds
-        """
-        pass
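FluxAPI is the whole contract the rest of the tooling relies on: a name used to build output directories, and a generate_image(prompt, save_path) that writes a file and returns the call's latency in seconds. A minimal hypothetical implementation (DummyAPI and its blank output are illustrative only, not part of the repository) shows the shape every wrapper in the deleted api/ package follows:

import time
from pathlib import Path

from PIL import Image

from api.flux import FluxAPI


class DummyAPI(FluxAPI):
    """Hypothetical no-op backend that writes a blank 1024x1024 image."""

    @property
    def name(self) -> str:
        return "dummy"

    def generate_image(self, prompt: str, save_path: Path) -> float:
        start_time = time.time()
        save_path.parent.mkdir(parents=True, exist_ok=True)
        Image.new("RGB", (1024, 1024), "white").save(save_path)
        return time.time() - start_time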
api/pruna.py
DELETED
@@ -1,53 +0,0 @@
-import os
-import time
-from pathlib import Path
-from typing import Any
-
-import replicate
-from dotenv import load_dotenv
-
-from api.flux import FluxAPI
-
-
-class PrunaAPI(FluxAPI):
-    def __init__(self, speed_mode: str):
-        self._speed_mode = speed_mode
-        self._speed_mode_name = (
-            speed_mode.split(" ")[0].strip().lower().replace(" ", "_")
-        )
-        load_dotenv()
-        self._api_key = os.getenv("REPLICATE_API_TOKEN")
-        if not self._api_key:
-            raise ValueError("REPLICATE_API_TOKEN not found in environment variables")
-
-    @property
-    def name(self) -> str:
-        return f"pruna_{self._speed_mode_name}"
-
-    def generate_image(self, prompt: str, save_path: Path) -> float:
-        start_time = time.time()
-        result = replicate.run(
-            "prunaai/flux.1-juiced:58977759ff2870cc010597ae75f4d87866d169b248e02b6e86c4e1bf8afe2410",
-            input={
-                "seed": 0,
-                "prompt": prompt,
-                "guidance": 3.5,
-                "num_outputs": 1,
-                "aspect_ratio": "1:1",
-                "output_format": "png",
-                "speed_mode": self._speed_mode,
-                "num_inference_steps": 28,
-            },
-        )
-        end_time = time.time()
-
-        if result:
-            self._save_image_from_result(result, save_path)
-        else:
-            raise Exception("No result returned from Replicate API")
-        return end_time - start_time
-
-    def _save_image_from_result(self, result: Any, save_path: Path):
-        save_path.parent.mkdir(parents=True, exist_ok=True)
-        with open(save_path, "wb") as f:
-            f.write(result.read())
api/pruna_dev.py
DELETED
@@ -1,49 +0,0 @@
-import os
-import time
-from pathlib import Path
-from typing import Any
-
-import replicate
-from dotenv import load_dotenv
-
-from api.flux import FluxAPI
-
-
-class PrunaDevAPI(FluxAPI):
-    def __init__(self):
-        load_dotenv()
-        self._api_key = os.getenv("REPLICATE_API_TOKEN")
-        if not self._api_key:
-            raise ValueError("REPLICATE_API_TOKEN not found in environment variables")
-
-    @property
-    def name(self) -> str:
-        return "pruna_dev"
-
-    def generate_image(self, prompt: str, save_path: Path) -> float:
-        start_time = time.time()
-        result = replicate.run(
-            "prunaai/flux.1-dev:938a4eb31a87d65fb7b23fc300fb5b7ab88a36844bb26e54e1d1dec7acf4eefe",
-            input={
-                "seed": 0,
-                "prompt": prompt,
-                "guidance": 3.5,
-                "num_outputs": 1,
-                "aspect_ratio": "1:1",
-                "output_format": "png",
-                "speed_mode": "Juiced 🔥 (default)",
-                "num_inference_steps": 28,
-            },
-        )
-        end_time = time.time()
-
-        if result:
-            self._save_image_from_result(result, save_path)
-        else:
-            raise Exception("No result returned from Replicate API")
-        return end_time - start_time
-
-    def _save_image_from_result(self, result: Any, save_path: Path):
-        save_path.parent.mkdir(parents=True, exist_ok=True)
-        with open(save_path, "wb") as f:
-            f.write(result.read())
api/replicate.py
DELETED
@@ -1,48 +0,0 @@
-import os
-import time
-from pathlib import Path
-from typing import Any
-
-import replicate
-from dotenv import load_dotenv
-
-from api.flux import FluxAPI
-
-
-class ReplicateAPI(FluxAPI):
-    def __init__(self):
-        load_dotenv()
-        self._api_key = os.getenv("REPLICATE_API_TOKEN")
-        if not self._api_key:
-            raise ValueError("REPLICATE_API_TOKEN not found in environment variables")
-
-    @property
-    def name(self) -> str:
-        return "replicate_go_fast"
-
-    def generate_image(self, prompt: str, save_path: Path) -> float:
-        start_time = time.time()
-        result = replicate.run(
-            "black-forest-labs/flux-dev",
-            input={
-                "seed": 0,
-                "prompt": prompt,
-                "go_fast": True,
-                "guidance": 3.5,
-                "num_outputs": 1,
-                "aspect_ratio": "1:1",
-                "output_format": "png",
-                "num_inference_steps": 28,
-            },
-        )
-        end_time = time.time()
-        if result and len(result) > 0:
-            self._save_image_from_result(result[0], save_path)
-        else:
-            raise Exception("No result returned from Replicate API")
-        return end_time - start_time
-
-    def _save_image_from_result(self, result: Any, save_path: Path):
-        save_path.parent.mkdir(parents=True, exist_ok=True)
-        with open(save_path, "wb") as f:
-            f.write(result.read())
api/replicate_wan.py
DELETED
@@ -1,48 +0,0 @@
-import os
-import time
-from pathlib import Path
-from typing import Any
-
-import replicate
-from dotenv import load_dotenv
-
-from api.flux import FluxAPI
-
-
-class ReplicateAPI(FluxAPI):
-    def __init__(self):
-        load_dotenv()
-        self._api_key = os.getenv("REPLICATE_API_TOKEN")
-        if not self._api_key:
-            raise ValueError("REPLICATE_API_TOKEN not found in environment variables")
-
-    @property
-    def name(self) -> str:
-        return "replicate_go_fast"
-
-    def generate_image(self, prompt: str, save_path: Path) -> float:
-        start_time = time.time()
-        result = replicate.run(
-            "black-forest-labs/flux-dev",
-            input={
-                "seed": 0,
-                "prompt": prompt,
-                "go_fast": True,
-                "guidance": 3.5,
-                "num_outputs": 1,
-                "aspect_ratio": "1:1",
-                "output_format": "png",
-                "num_inference_steps": 28,
-            },
-        )
-        end_time = time.time()
-        if result and len(result) > 0:
-            self._save_image_from_result(result[0], save_path)
-        else:
-            raise Exception("No result returned from Replicate API")
-        return end_time - start_time
-
-    def _save_image_from_result(self, result: Any, save_path: Path):
-        save_path.parent.mkdir(parents=True, exist_ok=True)
-        with open(save_path, "wb") as f:
-            f.write(result.read())
api/together.py
DELETED
@@ -1,47 +0,0 @@
-import base64
-import io
-import time
-from pathlib import Path
-from typing import Any
-
-from dotenv import load_dotenv
-from PIL import Image
-from together import Together
-
-from api.flux import FluxAPI
-
-
-class TogetherAPI(FluxAPI):
-    def __init__(self):
-        load_dotenv()
-        self._client = Together()
-
-    @property
-    def name(self) -> str:
-        return "together"
-
-    def generate_image(self, prompt: str, save_path: Path) -> float:
-        start_time = time.time()
-        result = self._client.images.generate(
-            prompt=prompt,
-            model="black-forest-labs/FLUX.1-dev",
-            width=1024,
-            height=1024,
-            steps=28,
-            n=1,
-            seed=0,
-            response_format="b64_json",
-        )
-        end_time = time.time()
-        if result and hasattr(result, "data") and len(result.data) > 0:
-            self._save_image_from_result(result, save_path)
-        else:
-            raise Exception("No result returned from Together API")
-        return end_time - start_time
-
-    def _save_image_from_result(self, result: Any, save_path: Path):
-        save_path.parent.mkdir(parents=True, exist_ok=True)
-        b64_str = result.data[0].b64_json
-        image_data = base64.b64decode(b64_str)
-        image = Image.open(io.BytesIO(image_data))
-        image.save(save_path)
dashboard/app.py → app.py
RENAMED
File without changes
benchmark/__init__.py
DELETED
@@ -1,45 +0,0 @@
-from typing import Type
-
-from benchmark.draw_bench import DrawBenchPrompts
-from benchmark.genai_bench import GenAIBenchPrompts
-from benchmark.geneval import GenEvalPrompts
-from benchmark.hps import HPSPrompts
-from benchmark.parti import PartiPrompts
-
-
-def create_benchmark(
-    benchmark_type: str,
-) -> Type[
-    DrawBenchPrompts | GenAIBenchPrompts | GenEvalPrompts | HPSPrompts | PartiPrompts
-]:
-    """
-    Factory function to create benchmark instances.
-
-    Args:
-        benchmark_type (str): The type of benchmark to create. Must be one of:
-            - "draw_bench"
-            - "genai_bench"
-            - "geneval"
-            - "hps"
-            - "parti"
-
-    Returns:
-        An instance of the requested benchmark implementation
-
-    Raises:
-        ValueError: If an invalid benchmark type is provided
-    """
-    benchmark_map = {
-        "draw_bench": DrawBenchPrompts,
-        "genai_bench": GenAIBenchPrompts,
-        "geneval": GenEvalPrompts,
-        "hps": HPSPrompts,
-        "parti": PartiPrompts,
-    }
-
-    if benchmark_type not in benchmark_map:
-        raise ValueError(
-            f"Invalid benchmark type: {benchmark_type}. Must be one of {list(benchmark_map.keys())}"
-        )
-
-    return benchmark_map[benchmark_type]()
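Every benchmark class below implements the same informal protocol: iterating yields (prompt, relative image path) pairs (GenEval instead yields its metadata dict and a folder name), while the name, size, and metrics properties describe the run. A usage sketch pieced together from those signatures:

from benchmark import create_benchmark

benchmark = create_benchmark("parti")
print(benchmark.name, benchmark.size, benchmark.metrics)

for prompt, image_path in benchmark:
    # image_path is relative; the sampling script joins it onto images/<api>/<benchmark>/
    print(prompt, image_path)
    break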
benchmark/draw_bench.py
DELETED
@@ -1,25 +0,0 @@
-from pathlib import Path
-from typing import Iterator, List, Tuple
-
-from datasets import load_dataset
-
-
-class DrawBenchPrompts:
-    def __init__(self):
-        self.dataset = load_dataset("shunk031/DrawBench")["test"]
-
-    def __iter__(self) -> Iterator[Tuple[str, Path]]:
-        for i, row in enumerate(self.dataset):
-            yield row["prompts"], Path(f"{i}.png")
-
-    @property
-    def name(self) -> str:
-        return "draw_bench"
-
-    @property
-    def size(self) -> int:
-        return len(self.dataset)
-
-    @property
-    def metrics(self) -> List[str]:
-        return ["image_reward"]
benchmark/genai_bench.py
DELETED
@@ -1,39 +0,0 @@
-from pathlib import Path
-from typing import Iterator, List, Tuple
-
-import requests
-
-
-class GenAIBenchPrompts:
-    def __init__(self):
-        super().__init__()
-        self._download_genai_bench_files()
-        prompts_path = Path("downloads/genai_bench/prompts.txt")
-        with open(prompts_path, "r") as f:
-            self.prompts = [line.strip() for line in f if line.strip()]
-
-    def __iter__(self) -> Iterator[Tuple[str, Path]]:
-        for i, prompt in enumerate(self.prompts):
-            yield prompt, Path(f"{i}.png")
-
-    def _download_genai_bench_files(self) -> None:
-        folder_name = Path("downloads/genai_bench")
-        folder_name.mkdir(parents=True, exist_ok=True)
-        prompts_url = "https://huggingface.co/datasets/zhiqiulin/GenAI-Bench-527/raw/main/prompts.txt"
-        prompts_path = folder_name / "prompts.txt"
-        if not prompts_path.exists():
-            response = requests.get(prompts_url)
-            with open(prompts_path, "w") as f:
-                f.write(response.text)
-
-    @property
-    def name(self) -> str:
-        return "genai_bench"
-
-    @property
-    def size(self) -> int:
-        return len(self.prompts)
-
-    @property
-    def metrics(self) -> List[str]:
-        return ["vqa"]
benchmark/geneval.py
DELETED
@@ -1,44 +0,0 @@
-import json
-from pathlib import Path
-from typing import Any, Dict, Iterator, List, Tuple
-
-import requests
-
-
-class GenEvalPrompts:
-    def __init__(self):
-        super().__init__()
-        self._download_geneval_file()
-        metadata_path = Path("downloads/geneval/evaluation_metadata.jsonl")
-        self.entries: List[Dict[str, Any]] = []
-        with open(metadata_path, "r") as f:
-            for line in f:
-                if line.strip():
-                    self.entries.append(json.loads(line))
-
-    def __iter__(self) -> Iterator[Tuple[Dict[str, Any], Path]]:
-        for i, entry in enumerate(self.entries):
-            folder_name = f"{i:05d}"
-            yield entry, folder_name
-
-    def _download_geneval_file(self) -> None:
-        folder_name = Path("downloads/geneval")
-        folder_name.mkdir(parents=True, exist_ok=True)
-        metadata_url = "https://raw.githubusercontent.com/djghosh13/geneval/main/prompts/evaluation_metadata.jsonl"
-        metadata_path = folder_name / "evaluation_metadata.jsonl"
-        if not metadata_path.exists():
-            response = requests.get(metadata_url)
-            with open(metadata_path, "w") as f:
-                f.write(response.text)
-
-    @property
-    def name(self) -> str:
-        return "geneval"
-
-    @property
-    def size(self) -> int:
-        return len(self.entries)
-
-    @property
-    def metrics(self) -> List[str]:
-        raise NotImplementedError("GenEval requires custom evaluation, see README.md")
benchmark/hps.py
DELETED
@@ -1,56 +0,0 @@
-import json
-import os
-from pathlib import Path
-from typing import Dict, Iterator, List, Tuple
-
-import huggingface_hub
-
-
-class HPSPrompts:
-    def __init__(self):
-        super().__init__()
-        self.hps_prompt_files = [
-            "anime.json",
-            "concept-art.json",
-            "paintings.json",
-            "photo.json",
-        ]
-        self._download_benchmark_prompts()
-        self.prompts: Dict[str, str] = {}
-        self._size = 0
-        for file in self.hps_prompt_files:
-            category = file.replace(".json", "")
-            with open(os.path.join("downloads/hps", file), "r") as f:
-                prompts = json.load(f)
-            for i, prompt in enumerate(prompts):
-                if i == 100:
-                    break
-                filename = f"{category}_{i:03d}.png"
-                self.prompts[filename] = prompt
-                self._size += 1
-
-    def __iter__(self) -> Iterator[Tuple[str, Path]]:
-        for filename, prompt in self.prompts.items():
-            yield prompt, Path(filename)
-
-    @property
-    def name(self) -> str:
-        return "hps"
-
-    @property
-    def size(self) -> int:
-        return self._size
-
-    def _download_benchmark_prompts(self) -> None:
-        folder_name = Path("downloads/hps")
-        folder_name.mkdir(parents=True, exist_ok=True)
-        for file in self.hps_prompt_files:
-            file_name = huggingface_hub.hf_hub_download(
-                "zhwang/HPDv2", file, subfolder="benchmark", repo_type="dataset"
-            )
-            if not os.path.exists(os.path.join(folder_name, file)):
-                os.symlink(file_name, os.path.join(folder_name, file))
-
-    @property
-    def metrics(self) -> List[str]:
-        return ["hps"]
benchmark/metrics/__init__.py
DELETED
@@ -1,43 +0,0 @@
-from typing import Type
-
-from benchmark.metrics.arniqa import ARNIQAMetric
-from benchmark.metrics.clip import CLIPMetric
-from benchmark.metrics.clip_iqa import CLIPIQAMetric
-from benchmark.metrics.image_reward import ImageRewardMetric
-from benchmark.metrics.sharpness import SharpnessMetric
-from benchmark.metrics.vqa import VQAMetric
-# from benchmark.metrics.hps import HPSMetric
-
-
-def create_metric(metric_type: str) -> Type[ARNIQAMetric | CLIPMetric | CLIPIQAMetric | ImageRewardMetric | SharpnessMetric | VQAMetric]:
-    """
-    Factory function to create metric instances.
-
-    Args:
-        metric_type (str): The type of metric to create. Must be one of:
-            - "arniqa"
-            - "clip"
-            - "clip_iqa"
-            - "image_reward"
-            - "sharpness"
-            - "vqa"
-            - "hps"
-
-    Returns:
-        An instance of the requested metric implementation
-
-    Raises:
-        ValueError: If an invalid metric type is provided
-    """
-    metric_map = {
-        "arniqa": ARNIQAMetric,
-        "clip": CLIPMetric,
-        "clip_iqa": CLIPIQAMetric,
-        "image_reward": ImageRewardMetric,
-        "sharpness": SharpnessMetric,
-        "vqa": VQAMetric,
-        # "hps": HPSMetric,
-    }
-
-    if metric_type not in metric_map:
-        raise ValueError(f"Invalid metric type: {metric_type}. Must be one of {list(metric_map.keys())}")
-
-    return metric_map[metric_type]()
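The metric classes that follow share a two-method surface: a name property and compute_score(image, prompt) returning a single-entry dict keyed by the metric name (the VQA metric takes an image path instead of a PIL image). A usage sketch with placeholder inputs (the file path and prompt are illustrative):

from PIL import Image

from benchmark.metrics import create_metric

metric = create_metric("clip")
image = Image.open("images/baseline/parti/0.png")  # placeholder path
print(metric.compute_score(image, "a red bicycle leaning against a wall"))
# -> {"clip": <score>}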
benchmark/metrics/arniqa.py
DELETED
@@ -1,36 +0,0 @@
-from typing import Dict
-
-import numpy as np
-import torch
-from PIL import Image
-from torchmetrics.image.arniqa import ARNIQA
-
-
-class ARNIQAMetric:
-    def __init__(self):
-        self.device = torch.device(
-            "cuda"
-            if torch.cuda.is_available()
-            else "mps"
-            if torch.backends.mps.is_available()
-            else "cpu"
-        )
-        self.metric = ARNIQA(
-            regressor_dataset="koniq10k",
-            reduction="mean",
-            normalize=True,
-            autocast=False,
-        )
-        self.metric.to(self.device)
-
-    @property
-    def name(self) -> str:
-        return "arniqa"
-
-    def compute_score(self, image: Image.Image, prompt: str) -> Dict[str, float]:
-        image_tensor = (
-            torch.from_numpy(np.array(image)).permute(2, 0, 1).float() / 255.0
-        )
-        image_tensor = image_tensor.unsqueeze(0).to(self.device)
-        score = self.metric(image_tensor)
-        return {"arniqa": score.item()}
benchmark/metrics/clip.py
DELETED
@@ -1,29 +0,0 @@
-from typing import Dict
-
-import numpy as np
-import torch
-from PIL import Image
-from torchmetrics.multimodal.clip_score import CLIPScore
-
-
-class CLIPMetric:
-    def __init__(self, model_name_or_path: str = "openai/clip-vit-large-patch14"):
-        self.device = torch.device(
-            "cuda"
-            if torch.cuda.is_available()
-            else "mps"
-            if torch.backends.mps.is_available()
-            else "cpu"
-        )
-        self.metric = CLIPScore(model_name_or_path="openai/clip-vit-large-patch14")
-        self.metric.to(self.device)
-
-    @property
-    def name(self) -> str:
-        return "clip"
-
-    def compute_score(self, image: Image.Image, prompt: str) -> Dict[str, float]:
-        image_tensor = torch.from_numpy(np.array(image)).permute(2, 0, 1).float()
-        image_tensor = image_tensor.to(self.device)
-        score = self.metric(image_tensor, prompt)
-        return {"clip": score.item()}
benchmark/metrics/clip_iqa.py
DELETED
@@ -1,32 +0,0 @@
-from typing import Dict
-
-import numpy as np
-import torch
-from PIL import Image
-from torchmetrics.multimodal import CLIPImageQualityAssessment
-
-
-class CLIPIQAMetric:
-    def __init__(self):
-        self.device = torch.device(
-            "cuda"
-            if torch.cuda.is_available()
-            else "mps"
-            if torch.backends.mps.is_available()
-            else "cpu"
-        )
-        self.metric = CLIPImageQualityAssessment(
-            model_name_or_path="clip_iqa", data_range=255.0, prompts=("quality",)
-        )
-        self.metric.to(self.device)
-
-    @property
-    def name(self) -> str:
-        return "clip_iqa"
-
-    def compute_score(self, image: Image.Image, prompt: str) -> Dict[str, float]:
-        image_tensor = torch.from_numpy(np.array(image)).permute(2, 0, 1).float()
-        image_tensor = image_tensor.unsqueeze(0)
-        image_tensor = image_tensor.to(self.device)
-        scores = self.metric(image_tensor)
-        return {"clip_iqa": scores.item()}
benchmark/metrics/hps.py
DELETED
@@ -1,92 +0,0 @@
-import os
-from typing import Dict
-
-import huggingface_hub
-import torch
-from hpsv2.src.open_clip import create_model_and_transforms, get_tokenizer
-from hpsv2.utils import hps_version_map, root_path
-from PIL import Image
-
-
-class HPSMetric:
-    def __init__(self):
-        self.hps_version = "v2.1"
-        self.device = torch.device(
-            "cuda"
-            if torch.cuda.is_available()
-            else "mps"
-            if torch.backends.mps.is_available()
-            else "cpu"
-        )
-        self.model_dict = {}
-        self._initialize_model()
-
-    def _initialize_model(self):
-        if not self.model_dict:
-            model, preprocess_train, preprocess_val = create_model_and_transforms(
-                "ViT-H-14",
-                "laion2B-s32B-b79K",
-                precision="amp",
-                device=self.device,
-                jit=False,
-                force_quick_gelu=False,
-                force_custom_text=False,
-                force_patch_dropout=False,
-                force_image_size=None,
-                pretrained_image=False,
-                image_mean=None,
-                image_std=None,
-                light_augmentation=True,
-                aug_cfg={},
-                output_dict=True,
-                with_score_predictor=False,
-                with_region_predictor=False,
-            )
-            self.model_dict["model"] = model
-            self.model_dict["preprocess_val"] = preprocess_val
-
-            # Load checkpoint
-            if not os.path.exists(root_path):
-                os.makedirs(root_path)
-            cp = huggingface_hub.hf_hub_download(
-                "xswu/HPSv2", hps_version_map[self.hps_version]
-            )
-
-            checkpoint = torch.load(cp, map_location=self.device)
-            model.load_state_dict(checkpoint["state_dict"])
-            self.tokenizer = get_tokenizer("ViT-H-14")
-            model = model.to(self.device)
-            model.eval()
-
-    @property
-    def name(self) -> str:
-        return "hps"
-
-    def compute_score(
-        self,
-        image: Image.Image,
-        prompt: str,
-    ) -> Dict[str, float]:
-        model = self.model_dict["model"]
-        preprocess_val = self.model_dict["preprocess_val"]
-
-        with torch.no_grad():
-            # Process the image
-            image_tensor = (
-                preprocess_val(image)
-                .unsqueeze(0)
-                .to(device=self.device, non_blocking=True)
-            )
-            # Process the prompt
-            text = self.tokenizer([prompt]).to(device=self.device, non_blocking=True)
-            # Calculate the HPS
-            with torch.cuda.amp.autocast():
-                outputs = model(image_tensor, text)
-                image_features, text_features = (
-                    outputs["image_features"],
-                    outputs["text_features"],
-                )
-                logits_per_image = image_features @ text_features.T
-                hps_score = torch.diagonal(logits_per_image).cpu().numpy()
-
-        return {"hps": float(hps_score[0])}
benchmark/metrics/image_reward.py
DELETED
@@ -1,35 +0,0 @@
-import os
-import tempfile
-from typing import Dict
-
-import ImageReward as RM
-import torch
-from PIL import Image
-
-
-class ImageRewardMetric:
-    def __init__(self):
-        self.device = torch.device(
-            "cuda"
-            if torch.cuda.is_available()
-            else "mps"
-            if torch.backends.mps.is_available()
-            else "cpu"
-        )
-
-        self.model = RM.load("ImageReward-v1.0", device=str(self.device))
-
-    @property
-    def name(self) -> str:
-        return "image_reward"
-
-    def compute_score(
-        self,
-        image: Image.Image,
-        prompt: str,
-    ) -> Dict[str, float]:
-        with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
-            image.save(tmp.name)
-            score = self.model.score(prompt, [tmp.name])
-        os.unlink(tmp.name)
-        return {"image_reward": score}
benchmark/metrics/sharpness.py
DELETED
@@ -1,24 +0,0 @@
-from typing import Dict
-
-import cv2
-import numpy as np
-from PIL import Image
-
-
-class SharpnessMetric:
-    def __init__(self):
-        self.kernel_size = 3
-
-    @property
-    def name(self) -> str:
-        return "sharpness"
-
-    def compute_score(
-        self,
-        image: Image.Image,
-        prompt: str,
-    ) -> Dict[str, float]:
-        img = np.array(image.convert('L'))
-        laplacian = cv2.Laplacian(img, cv2.CV_64F, ksize=self.kernel_size)
-        sharpness = laplacian.var()
-        return {"sharpness": float(sharpness)}
benchmark/metrics/vqa.py
DELETED
@@ -1,31 +0,0 @@
-from pathlib import Path
-from typing import Dict
-
-import t2v_metrics
-import torch
-
-
-class VQAMetric:
-    def __init__(self):
-        self.device = torch.device(
-            "cuda"
-            if torch.cuda.is_available()
-            else "mps"
-            if torch.backends.mps.is_available()
-            else "cpu"
-        )
-        self.metric = t2v_metrics.VQAScore(
-            model="clip-flant5-xxl", device=str(self.device)
-        )
-
-    @property
-    def name(self) -> str:
-        return "vqa_score"
-
-    def compute_score(
-        self,
-        image_path: Path,
-        prompt: str,
-    ) -> Dict[str, float]:
-        score = self.metric(images=[str(image_path)], texts=[prompt])
-        return {"vqa": score[0][0].item()}
benchmark/parti.py
DELETED
@@ -1,28 +0,0 @@
-from pathlib import Path
-from typing import Iterator, List, Tuple
-
-from datasets import load_dataset
-
-
-class PartiPrompts:
-    def __init__(self):
-        dataset = load_dataset("nateraw/parti-prompts")["train"]
-        shuffled_dataset = dataset.shuffle(seed=42)
-        selected_dataset = shuffled_dataset.select(range(800))
-        self.prompts = [row["Prompt"] for row in selected_dataset]
-
-    def __iter__(self) -> Iterator[Tuple[str, Path]]:
-        for i, prompt in enumerate(self.prompts):
-            yield prompt, Path(f"{i}.png")
-
-    @property
-    def name(self) -> str:
-        return "parti"
-
-    @property
-    def size(self) -> int:
-        return len(self.prompts)
-
-    @property
-    def metrics(self) -> List[str]:
-        return ["arniqa", "clip", "clip_iqa", "sharpness"]
dashboard/requirements.txt
DELETED
@@ -1,3 +0,0 @@
-
--e ..
-
{dashboard/data → data}/text_to_image.jsonl
RENAMED
File without changes
evaluate.py
DELETED
@@ -1,124 +0,0 @@
-import argparse
-import json
-import warnings
-from pathlib import Path
-from typing import Dict
-
-import numpy as np
-from PIL import Image
-from tqdm import tqdm
-
-from benchmark import create_benchmark
-from benchmark.metrics import create_metric
-
-warnings.filterwarnings("ignore", category=FutureWarning)
-
-
-def evaluate_benchmark(
-    benchmark_type: str, api_type: str, images_dir: Path = Path("images")
-) -> Dict:
-    """
-    Evaluate a benchmark's images using its specific metrics.
-
-    Args:
-        benchmark_type (str): Type of benchmark to evaluate
-        api_type (str): Type of API used to generate images
-        images_dir (Path): Base directory containing generated images
-
-    Returns:
-        Dict containing evaluation results
-    """
-    benchmark = create_benchmark(benchmark_type)
-
-    benchmark_dir = images_dir / api_type / benchmark_type
-    metadata_file = benchmark_dir / "metadata.jsonl"
-
-    if not metadata_file.exists():
-        raise FileNotFoundError(
-            f"No metadata file found for {api_type}/{benchmark_type}. Please run sample.py first."
-        )
-
-    metadata = []
-    with open(metadata_file, "r") as f:
-        for line in f:
-            metadata.append(json.loads(line))
-
-    metrics = {
-        metric_type: create_metric(metric_type) for metric_type in benchmark.metrics
-    }
-
-    results = {
-        "api": api_type,
-        "benchmark": benchmark_type,
-        "metrics": {metric: 0.0 for metric in benchmark.metrics},
-        "total_images": len(metadata),
-    }
-    inference_times = []
-
-    for entry in tqdm(metadata):
-        image_path = benchmark_dir / entry["filepath"]
-        if not image_path.exists():
-            continue
-
-        for metric_type, metric in metrics.items():
-            try:
-                if metric_type == "vqa":
-                    score = metric.compute_score(image_path, entry["prompt"])
-                else:
-                    image = Image.open(image_path)
-                    score = metric.compute_score(image, entry["prompt"])
-                results["metrics"][metric_type] += score[metric_type]
-            except Exception as e:
-                print(f"Error computing {metric_type} for {image_path}: {str(e)}")
-
-        inference_times.append(entry["inference_time"])
-
-    for metric in results["metrics"]:
-        results["metrics"][metric] /= len(metadata)
-    results["median_inference_time"] = np.median(inference_times).item()
-
-    return results
-
-
-def main():
-    parser = argparse.ArgumentParser(
-        description="Evaluate generated images using benchmark-specific metrics"
-    )
-    parser.add_argument("api_type", help="Type of API to evaluate")
-    parser.add_argument(
-        "benchmarks", nargs="+", help="List of benchmark types to evaluate"
-    )
-
-    args = parser.parse_args()
-
-    results_dir = Path("evaluation_results")
-    results_dir.mkdir(exist_ok=True)
-
-    results_file = results_dir / f"{args.api_type}.jsonl"
-    existing_results = set()
-
-    if results_file.exists():
-        with open(results_file, "r") as f:
-            for line in f:
-                result = json.loads(line)
-                existing_results.add(result["benchmark"])
-
-    for benchmark_type in args.benchmarks:
-        if benchmark_type in existing_results:
-            print(f"Skipping {args.api_type}/{benchmark_type} - already evaluated")
-            continue
-
-        try:
-            print(f"Evaluating {args.api_type}/{benchmark_type}")
-            results = evaluate_benchmark(benchmark_type, args.api_type)
-
-            # Append results to file
-            with open(results_file, "a") as f:
-                f.write(json.dumps(results) + "\n")
-
-        except Exception as e:
-            print(f"Error evaluating {args.api_type}/{benchmark_type}: {str(e)}")
-
-
-if __name__ == "__main__":
-    main()
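The script is driven from the command line as python evaluate.py <api_type> <benchmark> [<benchmark> ...], but evaluate_benchmark can also be called directly. A sketch of a single programmatic run (paths assume images were already generated under images/baseline/parti/ by sample.py):

from evaluate import evaluate_benchmark

results = evaluate_benchmark("parti", "baseline")
print(results["metrics"])                # per-metric averages over all metadata entries
print(results["median_inference_time"])  # seconds, median over recorded API latencies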
evaluation_results/.gitkeep
DELETED
File without changes
images/.gitkeep
DELETED
File without changes
pyproject.toml
DELETED
@@ -1,43 +0,0 @@
-[project]
-name = "inferbench"
-version = "0.1.0"
-description = "InferBench - AI inference benchmarking tool"
-requires-python = ">=3.12"
-dependencies = [
-    "numpy",
-    "opencv-python",
-    "pillow",
-    "python-dotenv",
-    "requests",
-    "tqdm",
-    "datasets==3.6.0",
-    "fal-client>=0.5.9",
-    "hpsv2>=1.2.0",
-    "huggingface-hub>=0.30.2",
-    "image-reward>=1.5",
-    "replicate>=1.0.4",
-    "t2v-metrics>=1.2",
-    "together>=1.5.5",
-    "torch>=2.7.0",
-    "torchmetrics>=1.7.1",
-    "clip",
-    "diffusers<=0.31",
-    "piq>=0.8.0",
-    "boto3>=1.39.4",
-    "gradio>=5.37.0",
-    "gradio-leaderboard>=0.0.14",
-]
-
-[build-system]
-requires = ["setuptools>=61.0", "wheel"]
-build-backend = "setuptools.build_meta"
-
-[tool.setuptools.packages.find]
-include = ["api*", "benchmark*", "dashboard*"]
-exclude = ["images*", "evaluation_results*", "*.tests*"]
-
-[tool.hatch.build.targets.wheel]
-packages = ["api", "benchmark", "dashboard"]
-
-[tool.uv]
-dev-dependencies = []
sample.py
DELETED
@@ -1,125 +0,0 @@
-import argparse
-import json
-from pathlib import Path
-from typing import List
-
-from tqdm import tqdm
-
-from api import create_api
-from benchmark import create_benchmark
-
-
-def generate_images(api_type: str, benchmarks: List[str]):
-    images_dir = Path("images")
-    api = create_api(api_type)
-
-    api_dir = images_dir / api_type
-    api_dir.mkdir(parents=True, exist_ok=True)
-
-    for benchmark_type in tqdm(benchmarks, desc="Processing benchmarks"):
-        print(f"\nProcessing benchmark: {benchmark_type}")
-
-        benchmark = create_benchmark(benchmark_type)
-
-        if benchmark_type == "geneval":
-            benchmark_dir = api_dir / benchmark_type
-            benchmark_dir.mkdir(parents=True, exist_ok=True)
-
-            metadata_file = benchmark_dir / "metadata.jsonl"
-            existing_metadata = {}
-            if metadata_file.exists():
-                with open(metadata_file, "r") as f:
-                    for line in f:
-                        entry = json.loads(line)
-                        existing_metadata[entry["filepath"]] = entry
-
-            with open(metadata_file, "a") as f:
-                for metadata, folder_name in tqdm(
-                    benchmark,
-                    desc=f"Generating images for {benchmark_type}",
-                    leave=False,
-                ):
-                    sample_path = benchmark_dir / folder_name
-                    samples_path = sample_path / "samples"
-                    samples_path.mkdir(parents=True, exist_ok=True)
-                    image_path = samples_path / "0000.png"
-
-                    if image_path.exists():
-                        continue
-
-                    try:
-                        inference_time = api.generate_image(
-                            metadata["prompt"], image_path
-                        )
-
-                        metadata_entry = {
-                            "filepath": str(image_path),
-                            "prompt": metadata["prompt"],
-                            "inference_time": inference_time,
-                        }
-
-                        f.write(json.dumps(metadata_entry) + "\n")
-
-                    except Exception as e:
-                        print(
-                            f"\nError generating image for prompt: {metadata['prompt']}"
-                        )
-                        print(f"Error: {str(e)}")
-                        continue
-        else:
-            benchmark_dir = api_dir / benchmark_type
-            benchmark_dir.mkdir(parents=True, exist_ok=True)
-
-            metadata_file = benchmark_dir / "metadata.jsonl"
-            existing_metadata = {}
-            if metadata_file.exists():
-                with open(metadata_file, "r") as f:
-                    for line in f:
-                        entry = json.loads(line)
-                        existing_metadata[entry["filepath"]] = entry
-
-            with open(metadata_file, "a") as f:
-                for prompt, image_path in tqdm(
-                    benchmark,
-                    desc=f"Generating images for {benchmark_type}",
-                    leave=False,
-                ):
-                    if image_path in existing_metadata:
-                        continue
-
-                    full_image_path = benchmark_dir / image_path
-
-                    if full_image_path.exists():
-                        continue
-
-                    try:
-                        inference_time = api.generate_image(prompt, full_image_path)
-
-                        metadata_entry = {
-                            "filepath": str(image_path),
-                            "prompt": prompt,
-                            "inference_time": inference_time,
-                        }
-
-                        f.write(json.dumps(metadata_entry) + "\n")
-
-                    except Exception as e:
-                        print(f"\nError generating image for prompt: {prompt}")
-                        print(f"Error: {str(e)}")
-                        continue
-
-
-def main():
-    parser = argparse.ArgumentParser(
-        description="Generate images for specified benchmarks using a given API"
-    )
-    parser.add_argument("api_type", help="Type of API to use for image generation")
-    parser.add_argument("benchmarks", nargs="+", help="List of benchmark types to run")
-
-    args = parser.parse_args()
-
-    generate_images(args.api_type, args.benchmarks)
-
-
-if __name__ == "__main__":
-    main()
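As with evaluate.py, the entry point is a CLI (python sample.py <api_type> <benchmark> [<benchmark> ...]), and generate_images is importable for programmatic use. A sketch of an equivalent direct call (requires the relevant API token in the environment or .env):

from sample import generate_images

# Writes PNGs plus a metadata.jsonl under images/baseline/<benchmark>/ for each benchmark.
generate_images("baseline", ["parti", "draw_bench"])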