davidberenstein1957 committed
Commit e34d069 · 1 Parent(s): db46f9e

chore: clean up unnecessary code for running the benchmark

.env.example DELETED
@@ -1,4 +0,0 @@
- FIREWORKS_API_TOKEN=your_fireworks_api_token_here
- REPLICATE_API_TOKEN=your_replicate_api_token_here
- FAL_KEY=your_fal_key_here
- TOGETHER_API_KEY=your_together_api_key_here
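These placeholders were consumed via python-dotenv in the API wrappers below; a minimal sketch of that pattern (the variable name is one of those listed above, the value comes from your local .env file):

    import os

    from dotenv import load_dotenv

    load_dotenv()  # reads key=value pairs from a local .env file into the environment
    api_key = os.getenv("REPLICATE_API_TOKEN")
    if not api_key:
        raise ValueError("REPLICATE_API_TOKEN not found in environment variables")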
api/__init__.py DELETED
@@ -1,66 +0,0 @@
- from typing import Type
-
- from api.aws import AWSBedrockAPI
- from api.baseline import BaselineAPI
- from api.fal import FalAPI
- from api.fireworks import FireworksAPI
- from api.flux import FluxAPI
- from api.pruna import PrunaAPI
- from api.pruna_dev import PrunaDevAPI
- from api.replicate import ReplicateAPI
- from api.together import TogetherAPI
-
- __all__ = [
-     "create_api",
-     "FluxAPI",
-     "BaselineAPI",
-     "FireworksAPI",
-     "PrunaAPI",
-     "ReplicateAPI",
-     "TogetherAPI",
-     "FalAPI",
-     "PrunaDevAPI",
- ]
-
-
- def create_api(api_type: str) -> FluxAPI:
-     """
-     Factory function to create API instances.
-
-     Args:
-         api_type (str): The type of API to create. Must be one of:
-             - "baseline"
-             - "fireworks"
-             - "pruna_speed_mode" (where speed_mode is the desired speed mode)
-             - "replicate"
-             - "together"
-             - "fal"
-             - "aws"
-
-     Returns:
-         FluxAPI: An instance of the requested API implementation
-
-     Raises:
-         ValueError: If an invalid API type is provided
-     """
-     if api_type == "pruna_dev":
-         return PrunaDevAPI()
-     if api_type.startswith("pruna_"):
-         speed_mode = api_type[6:]  # Remove "pruna_" prefix
-         return PrunaAPI(speed_mode)
-
-     api_map: dict[str, Type[FluxAPI]] = {
-         "baseline": BaselineAPI,
-         "fireworks": FireworksAPI,
-         "replicate": ReplicateAPI,
-         "together": TogetherAPI,
-         "fal": FalAPI,
-         "aws": AWSBedrockAPI,
-     }
-
-     if api_type not in api_map:
-         raise ValueError(
-             f"Invalid API type: {api_type}. Must be one of {list(api_map.keys())} or start with 'pruna_'"
-         )
-
-     return api_map[api_type]()
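For context, a minimal sketch of how this factory was used before the cleanup (the speed-mode value "juiced" is illustrative; any "pruna_<mode>" string is prefix-dispatched to PrunaAPI):

    from api import create_api

    api = create_api("replicate")          # looked up in api_map
    fast_api = create_api("pruna_juiced")  # routed to PrunaAPI("juiced")
    print(api.name, fast_api.name)         # "replicate_go_fast", "pruna_juiced"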
api/aws.py DELETED
@@ -1,69 +0,0 @@
- import base64
- import json
- import os
- import time
- from pathlib import Path
-
- import boto3
- from dotenv import load_dotenv
-
- from api.flux import FluxAPI
-
-
- class AWSBedrockAPI(FluxAPI):
-     def __init__(self):
-         load_dotenv()
-         # AWS credentials should be set via environment variables
-         # AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY, and AWS_SESSION_TOKEN
-         os.environ["AWS_ACCESS_KEY_ID"] = ""
-         os.environ["AWS_SECRET_ACCESS_KEY"] = ""
-         os.environ["AWS_SESSION_TOKEN"] = ""
-         os.environ["AWS_REGION"] = "us-east-1"
-         self._client = boto3.client("bedrock-runtime")
-         self._model_id = "amazon.nova-canvas-v1:0"
-
-     @property
-     def name(self) -> str:
-         return "aws_nova_canvas"
-
-     def generate_image(self, prompt: str, save_path: Path) -> float:
-         start_time = time.time()
-         # Format the request payload
-         native_request = {
-             "taskType": "TEXT_IMAGE",
-             "textToImageParams": {"text": prompt},
-             "imageGenerationConfig": {
-                 "seed": 0,
-                 "quality": "standard",
-                 "height": 1024,
-                 "width": 1024,
-                 "numberOfImages": 1,
-             },
-         }
-
-         try:
-             # Convert request to JSON and invoke the model
-             request = json.dumps(native_request)
-             response = self._client.invoke_model(modelId=self._model_id, body=request)
-
-             # Process the response
-             model_response = json.loads(response["body"].read())
-             if not model_response.get("images"):
-                 raise Exception("No images returned from AWS Bedrock API")
-
-             # Save the image
-             base64_image_data = model_response["images"][0]
-             self._save_image_from_base64(base64_image_data, save_path)
-
-         except Exception as e:
-             raise Exception(f"Error generating image with AWS Bedrock: {str(e)}")
-
-         end_time = time.time()
-         return end_time - start_time
-
-     def _save_image_from_base64(self, base64_data: str, save_path: Path):
-         """Save a base64 encoded image to the specified path."""
-         save_path.parent.mkdir(parents=True, exist_ok=True)
-         image_data = base64.b64decode(base64_data)
-         with open(save_path, "wb") as f:
-             f.write(image_data)
api/baseline.py DELETED
@@ -1,55 +0,0 @@
- import os
- import time
- from pathlib import Path
- from typing import Any
-
- from dotenv import load_dotenv
-
- from api.flux import FluxAPI
-
-
- class BaselineAPI(FluxAPI):
-     """
-     As our baseline, we use the Replicate API with go_fast=False.
-     """
-
-     def __init__(self):
-         load_dotenv()
-         self._api_key = os.getenv("REPLICATE_API_TOKEN")
-         if not self._api_key:
-             raise ValueError("REPLICATE_API_TOKEN not found in environment variables")
-
-     @property
-     def name(self) -> str:
-         return "baseline"
-
-     def generate_image(self, prompt: str, save_path: Path) -> float:
-         import replicate
-
-         start_time = time.time()
-         result = replicate.run(
-             "black-forest-labs/flux-dev",
-             input={
-                 "prompt": prompt,
-                 "go_fast": False,
-                 "guidance": 3.5,
-                 "num_outputs": 1,
-                 "aspect_ratio": "1:1",
-                 "output_format": "png",
-                 "num_inference_steps": 28,
-                 "seed": 0,
-             },
-         )
-         end_time = time.time()
-
-         if result and len(result) > 0:
-             self._save_image_from_result(result[0], save_path)
-         else:
-             raise Exception("No result returned from Replicate API")
-
-         return end_time - start_time
-
-     def _save_image_from_result(self, result: Any, save_path: Path):
-         save_path.parent.mkdir(parents=True, exist_ok=True)
-         with open(save_path, "wb") as f:
-             f.write(result.read())
api/fal.py DELETED
@@ -1,48 +0,0 @@
- import time
- from io import BytesIO
- from pathlib import Path
- from typing import Any
-
- import fal_client
- import requests
- from PIL import Image
-
- from api.flux import FluxAPI
-
-
- class FalAPI(FluxAPI):
-     @property
-     def name(self) -> str:
-         return "fal"
-
-     def generate_image(self, prompt: str, save_path: Path) -> float:
-         start_time = time.time()
-         result = fal_client.subscribe(
-             "fal-ai/flux/dev",
-             arguments={
-                 "seed": 0,
-                 "prompt": prompt,
-                 "image_size": "square_hd",  # 1024x1024 image
-                 "num_images": 1,
-                 "guidance_scale": 3.5,
-                 "num_inference_steps": 28,
-                 "enable_safety_checker": True,
-             },
-         )
-         end_time = time.time()
-
-         url = result["images"][0]["url"]
-         self._save_image_from_url(url, save_path)
-
-         return end_time - start_time
-
-     def _save_image_from_url(self, url: str, save_path: Path):
-         response = requests.get(url)
-         image = Image.open(BytesIO(response.content))
-         save_path.parent.mkdir(parents=True, exist_ok=True)
-         image.save(save_path)
-
-     def _save_image_from_result(self, result: Any, save_path: Path):
-         save_path.parent.mkdir(parents=True, exist_ok=True)
-         with open(save_path, "wb") as f:
-             f.write(result.content)
api/fireworks.py DELETED
@@ -1,53 +0,0 @@
- import os
- import time
- from pathlib import Path
- from typing import Any
-
- import requests
- from dotenv import load_dotenv
-
- from api.flux import FluxAPI
-
-
- class FireworksAPI(FluxAPI):
-     def __init__(self):
-         load_dotenv()
-         self._api_key = os.getenv("FIREWORKS_API_TOKEN")
-         if not self._api_key:
-             raise ValueError("FIREWORKS_API_TOKEN not found in environment variables")
-         self._url = "https://api.fireworks.ai/inference/v1/workflows/accounts/fireworks/models/flux-1-dev-fp8/text_to_image"
-
-     @property
-     def name(self) -> str:
-         return "fireworks_fp8"
-
-     def generate_image(self, prompt: str, save_path: Path) -> float:
-         start_time = time.time()
-
-         headers = {
-             "Content-Type": "application/json",
-             "Accept": "image/jpeg",
-             "Authorization": f"Bearer {self._api_key}",
-         }
-         data = {
-             "prompt": prompt,
-             "aspect_ratio": "1:1",
-             "guidance_scale": 3.5,
-             "num_inference_steps": 28,
-             "seed": 0,
-         }
-         result = requests.post(self._url, headers=headers, json=data)
-
-         end_time = time.time()
-
-         if result.status_code == 200:
-             self._save_image_from_result(result, save_path)
-         else:
-             raise Exception(f"Error: {result.status_code} {result.text}")
-
-         return end_time - start_time
-
-     def _save_image_from_result(self, result: Any, save_path: Path):
-         save_path.parent.mkdir(parents=True, exist_ok=True)
-         with open(save_path, "wb") as f:
-             f.write(result.content)
api/flux.py DELETED
@@ -1,35 +0,0 @@
- from abc import ABC, abstractmethod
- from pathlib import Path
-
-
- class FluxAPI(ABC):
-     """
-     Abstract base class for Flux API implementations.
-
-     This class defines the common interface for all Flux API implementations.
-     """
-
-     @property
-     @abstractmethod
-     def name(self) -> str:
-         """
-         The name of the API implementation.
-
-         Returns:
-             str: The name of the specific API implementation
-         """
-         pass
-
-     @abstractmethod
-     def generate_image(self, prompt: str, save_path: Path) -> float:
-         """
-         Generate an image based on the prompt and save it to the specified path.
-
-         Args:
-             prompt (str): The text prompt to generate the image from
-             save_path (Path): The path where the generated image should be saved
-
-         Returns:
-             float: The time taken for the API call in seconds
-         """
-         pass
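For orientation, a minimal sketch of what implementing this interface looked like (the DummyAPI class is illustrative and was never part of the repository):

    import time
    from pathlib import Path

    from api.flux import FluxAPI


    class DummyAPI(FluxAPI):
        @property
        def name(self) -> str:
            return "dummy"

        def generate_image(self, prompt: str, save_path: Path) -> float:
            start = time.time()
            save_path.parent.mkdir(parents=True, exist_ok=True)
            save_path.write_bytes(b"")  # a real provider writes the generated PNG here
            return time.time() - start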
api/pruna.py DELETED
@@ -1,53 +0,0 @@
- import os
- import time
- from pathlib import Path
- from typing import Any
-
- import replicate
- from dotenv import load_dotenv
-
- from api.flux import FluxAPI
-
-
- class PrunaAPI(FluxAPI):
-     def __init__(self, speed_mode: str):
-         self._speed_mode = speed_mode
-         self._speed_mode_name = (
-             speed_mode.split(" ")[0].strip().lower().replace(" ", "_")
-         )
-         load_dotenv()
-         self._api_key = os.getenv("REPLICATE_API_TOKEN")
-         if not self._api_key:
-             raise ValueError("REPLICATE_API_TOKEN not found in environment variables")
-
-     @property
-     def name(self) -> str:
-         return f"pruna_{self._speed_mode_name}"
-
-     def generate_image(self, prompt: str, save_path: Path) -> float:
-         start_time = time.time()
-         result = replicate.run(
-             "prunaai/flux.1-juiced:58977759ff2870cc010597ae75f4d87866d169b248e02b6e86c4e1bf8afe2410",
-             input={
-                 "seed": 0,
-                 "prompt": prompt,
-                 "guidance": 3.5,
-                 "num_outputs": 1,
-                 "aspect_ratio": "1:1",
-                 "output_format": "png",
-                 "speed_mode": self._speed_mode,
-                 "num_inference_steps": 28,
-             },
-         )
-         end_time = time.time()
-
-         if result:
-             self._save_image_from_result(result, save_path)
-         else:
-             raise Exception("No result returned from Replicate API")
-         return end_time - start_time
-
-     def _save_image_from_result(self, result: Any, save_path: Path):
-         save_path.parent.mkdir(parents=True, exist_ok=True)
-         with open(save_path, "wb") as f:
-             f.write(result.read())
api/pruna_dev.py DELETED
@@ -1,49 +0,0 @@
- import os
- import time
- from pathlib import Path
- from typing import Any
-
- import replicate
- from dotenv import load_dotenv
-
- from api.flux import FluxAPI
-
-
- class PrunaDevAPI(FluxAPI):
-     def __init__(self):
-         load_dotenv()
-         self._api_key = os.getenv("REPLICATE_API_TOKEN")
-         if not self._api_key:
-             raise ValueError("REPLICATE_API_TOKEN not found in environment variables")
-
-     @property
-     def name(self) -> str:
-         return "pruna_dev"
-
-     def generate_image(self, prompt: str, save_path: Path) -> float:
-         start_time = time.time()
-         result = replicate.run(
-             "prunaai/flux.1-dev:938a4eb31a87d65fb7b23fc300fb5b7ab88a36844bb26e54e1d1dec7acf4eefe",
-             input={
-                 "seed": 0,
-                 "prompt": prompt,
-                 "guidance": 3.5,
-                 "num_outputs": 1,
-                 "aspect_ratio": "1:1",
-                 "output_format": "png",
-                 "speed_mode": "Juiced 🔥 (default)",
-                 "num_inference_steps": 28,
-             },
-         )
-         end_time = time.time()
-
-         if result:
-             self._save_image_from_result(result, save_path)
-         else:
-             raise Exception("No result returned from Replicate API")
-         return end_time - start_time
-
-     def _save_image_from_result(self, result: Any, save_path: Path):
-         save_path.parent.mkdir(parents=True, exist_ok=True)
-         with open(save_path, "wb") as f:
-             f.write(result.read())
api/replicate.py DELETED
@@ -1,48 +0,0 @@
- import os
- import time
- from pathlib import Path
- from typing import Any
-
- import replicate
- from dotenv import load_dotenv
-
- from api.flux import FluxAPI
-
-
- class ReplicateAPI(FluxAPI):
-     def __init__(self):
-         load_dotenv()
-         self._api_key = os.getenv("REPLICATE_API_TOKEN")
-         if not self._api_key:
-             raise ValueError("REPLICATE_API_TOKEN not found in environment variables")
-
-     @property
-     def name(self) -> str:
-         return "replicate_go_fast"
-
-     def generate_image(self, prompt: str, save_path: Path) -> float:
-         start_time = time.time()
-         result = replicate.run(
-             "black-forest-labs/flux-dev",
-             input={
-                 "seed": 0,
-                 "prompt": prompt,
-                 "go_fast": True,
-                 "guidance": 3.5,
-                 "num_outputs": 1,
-                 "aspect_ratio": "1:1",
-                 "output_format": "png",
-                 "num_inference_steps": 28,
-             },
-         )
-         end_time = time.time()
-         if result and len(result) > 0:
-             self._save_image_from_result(result[0], save_path)
-         else:
-             raise Exception("No result returned from Replicate API")
-         return end_time - start_time
-
-     def _save_image_from_result(self, result: Any, save_path: Path):
-         save_path.parent.mkdir(parents=True, exist_ok=True)
-         with open(save_path, "wb") as f:
-             f.write(result.read())
api/replicate_wan.py DELETED
@@ -1,48 +0,0 @@
- import os
- import time
- from pathlib import Path
- from typing import Any
-
- import replicate
- from dotenv import load_dotenv
-
- from api.flux import FluxAPI
-
-
- class ReplicateAPI(FluxAPI):
-     def __init__(self):
-         load_dotenv()
-         self._api_key = os.getenv("REPLICATE_API_TOKEN")
-         if not self._api_key:
-             raise ValueError("REPLICATE_API_TOKEN not found in environment variables")
-
-     @property
-     def name(self) -> str:
-         return "replicate_go_fast"
-
-     def generate_image(self, prompt: str, save_path: Path) -> float:
-         start_time = time.time()
-         result = replicate.run(
-             "black-forest-labs/flux-dev",
-             input={
-                 "seed": 0,
-                 "prompt": prompt,
-                 "go_fast": True,
-                 "guidance": 3.5,
-                 "num_outputs": 1,
-                 "aspect_ratio": "1:1",
-                 "output_format": "png",
-                 "num_inference_steps": 28,
-             },
-         )
-         end_time = time.time()
-         if result and len(result) > 0:
-             self._save_image_from_result(result[0], save_path)
-         else:
-             raise Exception("No result returned from Replicate API")
-         return end_time - start_time
-
-     def _save_image_from_result(self, result: Any, save_path: Path):
-         save_path.parent.mkdir(parents=True, exist_ok=True)
-         with open(save_path, "wb") as f:
-             f.write(result.read())
api/together.py DELETED
@@ -1,47 +0,0 @@
- import base64
- import io
- import time
- from pathlib import Path
- from typing import Any
-
- from dotenv import load_dotenv
- from PIL import Image
- from together import Together
-
- from api.flux import FluxAPI
-
-
- class TogetherAPI(FluxAPI):
-     def __init__(self):
-         load_dotenv()
-         self._client = Together()
-
-     @property
-     def name(self) -> str:
-         return "together"
-
-     def generate_image(self, prompt: str, save_path: Path) -> float:
-         start_time = time.time()
-         result = self._client.images.generate(
-             prompt=prompt,
-             model="black-forest-labs/FLUX.1-dev",
-             width=1024,
-             height=1024,
-             steps=28,
-             n=1,
-             seed=0,
-             response_format="b64_json",
-         )
-         end_time = time.time()
-         if result and hasattr(result, "data") and len(result.data) > 0:
-             self._save_image_from_result(result, save_path)
-         else:
-             raise Exception("No result returned from Together API")
-         return end_time - start_time
-
-     def _save_image_from_result(self, result: Any, save_path: Path):
-         save_path.parent.mkdir(parents=True, exist_ok=True)
-         b64_str = result.data[0].b64_json
-         image_data = base64.b64decode(b64_str)
-         image = Image.open(io.BytesIO(image_data))
-         image.save(save_path)
dashboard/app.py → app.py RENAMED
File without changes
benchmark/__init__.py DELETED
@@ -1,45 +0,0 @@
- from typing import Type
-
- from benchmark.draw_bench import DrawBenchPrompts
- from benchmark.genai_bench import GenAIBenchPrompts
- from benchmark.geneval import GenEvalPrompts
- from benchmark.hps import HPSPrompts
- from benchmark.parti import PartiPrompts
-
-
- def create_benchmark(
-     benchmark_type: str,
- ) -> Type[
-     DrawBenchPrompts | GenAIBenchPrompts | GenEvalPrompts | HPSPrompts | PartiPrompts
- ]:
-     """
-     Factory function to create benchmark instances.
-
-     Args:
-         benchmark_type (str): The type of benchmark to create. Must be one of:
-             - "draw_bench"
-             - "genai_bench"
-             - "geneval"
-             - "hps"
-             - "parti"
-
-     Returns:
-         An instance of the requested benchmark implementation
-
-     Raises:
-         ValueError: If an invalid benchmark type is provided
-     """
-     benchmark_map = {
-         "draw_bench": DrawBenchPrompts,
-         "genai_bench": GenAIBenchPrompts,
-         "geneval": GenEvalPrompts,
-         "hps": HPSPrompts,
-         "parti": PartiPrompts,
-     }
-
-     if benchmark_type not in benchmark_map:
-         raise ValueError(
-             f"Invalid benchmark type: {benchmark_type}. Must be one of {list(benchmark_map.keys())}"
-         )
-
-     return benchmark_map[benchmark_type]()
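As a quick reference, the benchmark objects returned by this factory all expose name, size, metrics, and iteration over (prompt, relative image path) pairs, except geneval, which yields metadata dicts and folder names; a sketch using the "draw_bench" key from the map above:

    from benchmark import create_benchmark

    benchmark = create_benchmark("draw_bench")
    print(benchmark.name, benchmark.size, benchmark.metrics)
    for prompt, image_path in benchmark:
        print(prompt, image_path)  # first DrawBench prompt and its target filename "0.png"
        break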
benchmark/draw_bench.py DELETED
@@ -1,25 +0,0 @@
- from pathlib import Path
- from typing import Iterator, List, Tuple
-
- from datasets import load_dataset
-
-
- class DrawBenchPrompts:
-     def __init__(self):
-         self.dataset = load_dataset("shunk031/DrawBench")["test"]
-
-     def __iter__(self) -> Iterator[Tuple[str, Path]]:
-         for i, row in enumerate(self.dataset):
-             yield row["prompts"], Path(f"{i}.png")
-
-     @property
-     def name(self) -> str:
-         return "draw_bench"
-
-     @property
-     def size(self) -> int:
-         return len(self.dataset)
-
-     @property
-     def metrics(self) -> List[str]:
-         return ["image_reward"]
benchmark/genai_bench.py DELETED
@@ -1,39 +0,0 @@
- from pathlib import Path
- from typing import Iterator, List, Tuple
-
- import requests
-
-
- class GenAIBenchPrompts:
-     def __init__(self):
-         super().__init__()
-         self._download_genai_bench_files()
-         prompts_path = Path("downloads/genai_bench/prompts.txt")
-         with open(prompts_path, "r") as f:
-             self.prompts = [line.strip() for line in f if line.strip()]
-
-     def __iter__(self) -> Iterator[Tuple[str, Path]]:
-         for i, prompt in enumerate(self.prompts):
-             yield prompt, Path(f"{i}.png")
-
-     def _download_genai_bench_files(self) -> None:
-         folder_name = Path("downloads/genai_bench")
-         folder_name.mkdir(parents=True, exist_ok=True)
-         prompts_url = "https://huggingface.co/datasets/zhiqiulin/GenAI-Bench-527/raw/main/prompts.txt"
-         prompts_path = folder_name / "prompts.txt"
-         if not prompts_path.exists():
-             response = requests.get(prompts_url)
-             with open(prompts_path, "w") as f:
-                 f.write(response.text)
-
-     @property
-     def name(self) -> str:
-         return "genai_bench"
-
-     @property
-     def size(self) -> int:
-         return len(self.prompts)
-
-     @property
-     def metrics(self) -> List[str]:
-         return ["vqa"]
benchmark/geneval.py DELETED
@@ -1,44 +0,0 @@
- import json
- from pathlib import Path
- from typing import Any, Dict, Iterator, List, Tuple
-
- import requests
-
-
- class GenEvalPrompts:
-     def __init__(self):
-         super().__init__()
-         self._download_geneval_file()
-         metadata_path = Path("downloads/geneval/evaluation_metadata.jsonl")
-         self.entries: List[Dict[str, Any]] = []
-         with open(metadata_path, "r") as f:
-             for line in f:
-                 if line.strip():
-                     self.entries.append(json.loads(line))
-
-     def __iter__(self) -> Iterator[Tuple[Dict[str, Any], Path]]:
-         for i, entry in enumerate(self.entries):
-             folder_name = f"{i:05d}"
-             yield entry, folder_name
-
-     def _download_geneval_file(self) -> None:
-         folder_name = Path("downloads/geneval")
-         folder_name.mkdir(parents=True, exist_ok=True)
-         metadata_url = "https://raw.githubusercontent.com/djghosh13/geneval/main/prompts/evaluation_metadata.jsonl"
-         metadata_path = folder_name / "evaluation_metadata.jsonl"
-         if not metadata_path.exists():
-             response = requests.get(metadata_url)
-             with open(metadata_path, "w") as f:
-                 f.write(response.text)
-
-     @property
-     def name(self) -> str:
-         return "geneval"
-
-     @property
-     def size(self) -> int:
-         return len(self.entries)
-
-     @property
-     def metrics(self) -> List[str]:
-         raise NotImplementedError("GenEval requires custom evaluation, see README.md")
benchmark/hps.py DELETED
@@ -1,56 +0,0 @@
- import json
- import os
- from pathlib import Path
- from typing import Dict, Iterator, List, Tuple
-
- import huggingface_hub
-
-
- class HPSPrompts:
-     def __init__(self):
-         super().__init__()
-         self.hps_prompt_files = [
-             "anime.json",
-             "concept-art.json",
-             "paintings.json",
-             "photo.json",
-         ]
-         self._download_benchmark_prompts()
-         self.prompts: Dict[str, str] = {}
-         self._size = 0
-         for file in self.hps_prompt_files:
-             category = file.replace(".json", "")
-             with open(os.path.join("downloads/hps", file), "r") as f:
-                 prompts = json.load(f)
-                 for i, prompt in enumerate(prompts):
-                     if i == 100:
-                         break
-                     filename = f"{category}_{i:03d}.png"
-                     self.prompts[filename] = prompt
-                     self._size += 1
-
-     def __iter__(self) -> Iterator[Tuple[str, Path]]:
-         for filename, prompt in self.prompts.items():
-             yield prompt, Path(filename)
-
-     @property
-     def name(self) -> str:
-         return "hps"
-
-     @property
-     def size(self) -> int:
-         return self._size
-
-     def _download_benchmark_prompts(self) -> None:
-         folder_name = Path("downloads/hps")
-         folder_name.mkdir(parents=True, exist_ok=True)
-         for file in self.hps_prompt_files:
-             file_name = huggingface_hub.hf_hub_download(
-                 "zhwang/HPDv2", file, subfolder="benchmark", repo_type="dataset"
-             )
-             if not os.path.exists(os.path.join(folder_name, file)):
-                 os.symlink(file_name, os.path.join(folder_name, file))
-
-     @property
-     def metrics(self) -> List[str]:
-         return ["hps"]
benchmark/metrics/__init__.py DELETED
@@ -1,43 +0,0 @@
- from typing import Type
-
- from benchmark.metrics.arniqa import ARNIQAMetric
- from benchmark.metrics.clip import CLIPMetric
- from benchmark.metrics.clip_iqa import CLIPIQAMetric
- from benchmark.metrics.image_reward import ImageRewardMetric
- from benchmark.metrics.sharpness import SharpnessMetric
- from benchmark.metrics.vqa import VQAMetric
- # from benchmark.metrics.hps import HPSMetric
-
-
- def create_metric(
-     metric_type: str,
- ) -> Type[
-     ARNIQAMetric | CLIPMetric | CLIPIQAMetric | ImageRewardMetric | SharpnessMetric | VQAMetric
- ]:
-     """
-     Factory function to create metric instances.
-
-     Args:
-         metric_type (str): The type of metric to create. Must be one of:
-             - "arniqa"
-             - "clip"
-             - "clip_iqa"
-             - "image_reward"
-             - "sharpness"
-             - "vqa"
-             - "hps"
-
-     Returns:
-         An instance of the requested metric implementation
-
-     Raises:
-         ValueError: If an invalid metric type is provided
-     """
-     metric_map = {
-         "arniqa": ARNIQAMetric,
-         "clip": CLIPMetric,
-         "clip_iqa": CLIPIQAMetric,
-         "image_reward": ImageRewardMetric,
-         "sharpness": SharpnessMetric,
-         "vqa": VQAMetric,
-         # "hps": HPSMetric,
-     }
-
-     if metric_type not in metric_map:
-         raise ValueError(
-             f"Invalid metric type: {metric_type}. Must be one of {list(metric_map.keys())}"
-         )
-
-     return metric_map[metric_type]()
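For reference, a small sketch of how evaluate.py drove these metric objects (the blank placeholder image stands in for a generated sample):

    from PIL import Image

    from benchmark.metrics import create_metric

    metric = create_metric("clip_iqa")
    image = Image.new("RGB", (1024, 1024), "white")  # placeholder; evaluate.py loads the saved PNG
    scores = metric.compute_score(image, "a photo of a red bicycle")
    print(scores)  # a dict keyed by the metric name, e.g. {"clip_iqa": ...}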
benchmark/metrics/arniqa.py DELETED
@@ -1,36 +0,0 @@
- from typing import Dict
-
- import numpy as np
- import torch
- from PIL import Image
- from torchmetrics.image.arniqa import ARNIQA
-
-
- class ARNIQAMetric:
-     def __init__(self):
-         self.device = torch.device(
-             "cuda"
-             if torch.cuda.is_available()
-             else "mps"
-             if torch.backends.mps.is_available()
-             else "cpu"
-         )
-         self.metric = ARNIQA(
-             regressor_dataset="koniq10k",
-             reduction="mean",
-             normalize=True,
-             autocast=False,
-         )
-         self.metric.to(self.device)
-
-     @property
-     def name(self) -> str:
-         return "arniqa"
-
-     def compute_score(self, image: Image.Image, prompt: str) -> Dict[str, float]:
-         image_tensor = (
-             torch.from_numpy(np.array(image)).permute(2, 0, 1).float() / 255.0
-         )
-         image_tensor = image_tensor.unsqueeze(0).to(self.device)
-         score = self.metric(image_tensor)
-         return {"arniqa": score.item()}
benchmark/metrics/clip.py DELETED
@@ -1,29 +0,0 @@
- from typing import Dict
-
- import numpy as np
- import torch
- from PIL import Image
- from torchmetrics.multimodal.clip_score import CLIPScore
-
-
- class CLIPMetric:
-     def __init__(self, model_name_or_path: str = "openai/clip-vit-large-patch14"):
-         self.device = torch.device(
-             "cuda"
-             if torch.cuda.is_available()
-             else "mps"
-             if torch.backends.mps.is_available()
-             else "cpu"
-         )
-         self.metric = CLIPScore(model_name_or_path="openai/clip-vit-large-patch14")
-         self.metric.to(self.device)
-
-     @property
-     def name(self) -> str:
-         return "clip"
-
-     def compute_score(self, image: Image.Image, prompt: str) -> Dict[str, float]:
-         image_tensor = torch.from_numpy(np.array(image)).permute(2, 0, 1).float()
-         image_tensor = image_tensor.to(self.device)
-         score = self.metric(image_tensor, prompt)
-         return {"clip": score.item()}
benchmark/metrics/clip_iqa.py DELETED
@@ -1,32 +0,0 @@
- from typing import Dict
-
- import numpy as np
- import torch
- from PIL import Image
- from torchmetrics.multimodal import CLIPImageQualityAssessment
-
-
- class CLIPIQAMetric:
-     def __init__(self):
-         self.device = torch.device(
-             "cuda"
-             if torch.cuda.is_available()
-             else "mps"
-             if torch.backends.mps.is_available()
-             else "cpu"
-         )
-         self.metric = CLIPImageQualityAssessment(
-             model_name_or_path="clip_iqa", data_range=255.0, prompts=("quality",)
-         )
-         self.metric.to(self.device)
-
-     @property
-     def name(self) -> str:
-         return "clip_iqa"
-
-     def compute_score(self, image: Image.Image, prompt: str) -> Dict[str, float]:
-         image_tensor = torch.from_numpy(np.array(image)).permute(2, 0, 1).float()
-         image_tensor = image_tensor.unsqueeze(0)
-         image_tensor = image_tensor.to(self.device)
-         scores = self.metric(image_tensor)
-         return {"clip_iqa": scores.item()}
benchmark/metrics/hps.py DELETED
@@ -1,92 +0,0 @@
- import os
- from typing import Dict
-
- import huggingface_hub
- import torch
- from hpsv2.src.open_clip import create_model_and_transforms, get_tokenizer
- from hpsv2.utils import hps_version_map, root_path
- from PIL import Image
-
-
- class HPSMetric:
-     def __init__(self):
-         self.hps_version = "v2.1"
-         self.device = torch.device(
-             "cuda"
-             if torch.cuda.is_available()
-             else "mps"
-             if torch.backends.mps.is_available()
-             else "cpu"
-         )
-         self.model_dict = {}
-         self._initialize_model()
-
-     def _initialize_model(self):
-         if not self.model_dict:
-             model, preprocess_train, preprocess_val = create_model_and_transforms(
-                 "ViT-H-14",
-                 "laion2B-s32B-b79K",
-                 precision="amp",
-                 device=self.device,
-                 jit=False,
-                 force_quick_gelu=False,
-                 force_custom_text=False,
-                 force_patch_dropout=False,
-                 force_image_size=None,
-                 pretrained_image=False,
-                 image_mean=None,
-                 image_std=None,
-                 light_augmentation=True,
-                 aug_cfg={},
-                 output_dict=True,
-                 with_score_predictor=False,
-                 with_region_predictor=False,
-             )
-             self.model_dict["model"] = model
-             self.model_dict["preprocess_val"] = preprocess_val
-
-             # Load checkpoint
-             if not os.path.exists(root_path):
-                 os.makedirs(root_path)
-             cp = huggingface_hub.hf_hub_download(
-                 "xswu/HPSv2", hps_version_map[self.hps_version]
-             )
-
-             checkpoint = torch.load(cp, map_location=self.device)
-             model.load_state_dict(checkpoint["state_dict"])
-             self.tokenizer = get_tokenizer("ViT-H-14")
-             model = model.to(self.device)
-             model.eval()
-
-     @property
-     def name(self) -> str:
-         return "hps"
-
-     def compute_score(
-         self,
-         image: Image.Image,
-         prompt: str,
-     ) -> Dict[str, float]:
-         model = self.model_dict["model"]
-         preprocess_val = self.model_dict["preprocess_val"]
-
-         with torch.no_grad():
-             # Process the image
-             image_tensor = (
-                 preprocess_val(image)
-                 .unsqueeze(0)
-                 .to(device=self.device, non_blocking=True)
-             )
-             # Process the prompt
-             text = self.tokenizer([prompt]).to(device=self.device, non_blocking=True)
-             # Calculate the HPS
-             with torch.cuda.amp.autocast():
-                 outputs = model(image_tensor, text)
-                 image_features, text_features = (
-                     outputs["image_features"],
-                     outputs["text_features"],
-                 )
-                 logits_per_image = image_features @ text_features.T
-                 hps_score = torch.diagonal(logits_per_image).cpu().numpy()
-
-         return {"hps": float(hps_score[0])}
benchmark/metrics/image_reward.py DELETED
@@ -1,35 +0,0 @@
- import os
- import tempfile
- from typing import Dict
-
- import ImageReward as RM
- import torch
- from PIL import Image
-
-
- class ImageRewardMetric:
-     def __init__(self):
-         self.device = torch.device(
-             "cuda"
-             if torch.cuda.is_available()
-             else "mps"
-             if torch.backends.mps.is_available()
-             else "cpu"
-         )
-
-         self.model = RM.load("ImageReward-v1.0", device=str(self.device))
-
-     @property
-     def name(self) -> str:
-         return "image_reward"
-
-     def compute_score(
-         self,
-         image: Image.Image,
-         prompt: str,
-     ) -> Dict[str, float]:
-         with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
-             image.save(tmp.name)
-             score = self.model.score(prompt, [tmp.name])
-         os.unlink(tmp.name)
-         return {"image_reward": score}
benchmark/metrics/sharpness.py DELETED
@@ -1,24 +0,0 @@
- from typing import Dict
-
- import cv2
- import numpy as np
- from PIL import Image
-
-
- class SharpnessMetric:
-     def __init__(self):
-         self.kernel_size = 3
-
-     @property
-     def name(self) -> str:
-         return "sharpness"
-
-     def compute_score(
-         self,
-         image: Image.Image,
-         prompt: str,
-     ) -> Dict[str, float]:
-         img = np.array(image.convert('L'))
-         laplacian = cv2.Laplacian(img, cv2.CV_64F, ksize=self.kernel_size)
-         sharpness = laplacian.var()
-         return {"sharpness": float(sharpness)}
benchmark/metrics/vqa.py DELETED
@@ -1,31 +0,0 @@
- from pathlib import Path
- from typing import Dict
-
- import t2v_metrics
- import torch
-
-
- class VQAMetric:
-     def __init__(self):
-         self.device = torch.device(
-             "cuda"
-             if torch.cuda.is_available()
-             else "mps"
-             if torch.backends.mps.is_available()
-             else "cpu"
-         )
-         self.metric = t2v_metrics.VQAScore(
-             model="clip-flant5-xxl", device=str(self.device)
-         )
-
-     @property
-     def name(self) -> str:
-         return "vqa_score"
-
-     def compute_score(
-         self,
-         image_path: Path,
-         prompt: str,
-     ) -> Dict[str, float]:
-         score = self.metric(images=[str(image_path)], texts=[prompt])
-         return {"vqa": score[0][0].item()}
benchmark/parti.py DELETED
@@ -1,28 +0,0 @@
- from pathlib import Path
- from typing import Iterator, List, Tuple
-
- from datasets import load_dataset
-
-
- class PartiPrompts:
-     def __init__(self):
-         dataset = load_dataset("nateraw/parti-prompts")["train"]
-         shuffled_dataset = dataset.shuffle(seed=42)
-         selected_dataset = shuffled_dataset.select(range(800))
-         self.prompts = [row["Prompt"] for row in selected_dataset]
-
-     def __iter__(self) -> Iterator[Tuple[str, Path]]:
-         for i, prompt in enumerate(self.prompts):
-             yield prompt, Path(f"{i}.png")
-
-     @property
-     def name(self) -> str:
-         return "parti"
-
-     @property
-     def size(self) -> int:
-         return len(self.prompts)
-
-     @property
-     def metrics(self) -> List[str]:
-         return ["arniqa", "clip", "clip_iqa", "sharpness"]
dashboard/requirements.txt DELETED
@@ -1,3 +0,0 @@
-
- -e ..
-
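The single "-e .." line installed the repository root (the inferbench package defined in pyproject.toml) in editable mode so the dashboard could import the api and benchmark packages; with dashboard/app.py moved to the repository root in this commit, the separate requirements file goes away.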
{dashboard/data → data}/text_to_image.jsonl RENAMED
File without changes
evaluate.py DELETED
@@ -1,124 +0,0 @@
- import argparse
- import json
- import warnings
- from pathlib import Path
- from typing import Dict
-
- import numpy as np
- from PIL import Image
- from tqdm import tqdm
-
- from benchmark import create_benchmark
- from benchmark.metrics import create_metric
-
- warnings.filterwarnings("ignore", category=FutureWarning)
-
-
- def evaluate_benchmark(
-     benchmark_type: str, api_type: str, images_dir: Path = Path("images")
- ) -> Dict:
-     """
-     Evaluate a benchmark's images using its specific metrics.
-
-     Args:
-         benchmark_type (str): Type of benchmark to evaluate
-         api_type (str): Type of API used to generate images
-         images_dir (Path): Base directory containing generated images
-
-     Returns:
-         Dict containing evaluation results
-     """
-     benchmark = create_benchmark(benchmark_type)
-
-     benchmark_dir = images_dir / api_type / benchmark_type
-     metadata_file = benchmark_dir / "metadata.jsonl"
-
-     if not metadata_file.exists():
-         raise FileNotFoundError(
-             f"No metadata file found for {api_type}/{benchmark_type}. Please run sample.py first."
-         )
-
-     metadata = []
-     with open(metadata_file, "r") as f:
-         for line in f:
-             metadata.append(json.loads(line))
-
-     metrics = {
-         metric_type: create_metric(metric_type) for metric_type in benchmark.metrics
-     }
-
-     results = {
-         "api": api_type,
-         "benchmark": benchmark_type,
-         "metrics": {metric: 0.0 for metric in benchmark.metrics},
-         "total_images": len(metadata),
-     }
-     inference_times = []
-
-     for entry in tqdm(metadata):
-         image_path = benchmark_dir / entry["filepath"]
-         if not image_path.exists():
-             continue
-
-         for metric_type, metric in metrics.items():
-             try:
-                 if metric_type == "vqa":
-                     score = metric.compute_score(image_path, entry["prompt"])
-                 else:
-                     image = Image.open(image_path)
-                     score = metric.compute_score(image, entry["prompt"])
-                 results["metrics"][metric_type] += score[metric_type]
-             except Exception as e:
-                 print(f"Error computing {metric_type} for {image_path}: {str(e)}")
-
-         inference_times.append(entry["inference_time"])
-
-     for metric in results["metrics"]:
-         results["metrics"][metric] /= len(metadata)
-     results["median_inference_time"] = np.median(inference_times).item()
-
-     return results
-
-
- def main():
-     parser = argparse.ArgumentParser(
-         description="Evaluate generated images using benchmark-specific metrics"
-     )
-     parser.add_argument("api_type", help="Type of API to evaluate")
-     parser.add_argument(
-         "benchmarks", nargs="+", help="List of benchmark types to evaluate"
-     )
-
-     args = parser.parse_args()
-
-     results_dir = Path("evaluation_results")
-     results_dir.mkdir(exist_ok=True)
-
-     results_file = results_dir / f"{args.api_type}.jsonl"
-     existing_results = set()
-
-     if results_file.exists():
-         with open(results_file, "r") as f:
-             for line in f:
-                 result = json.loads(line)
-                 existing_results.add(result["benchmark"])
-
-     for benchmark_type in args.benchmarks:
-         if benchmark_type in existing_results:
-             print(f"Skipping {args.api_type}/{benchmark_type} - already evaluated")
-             continue
-
-         try:
-             print(f"Evaluating {args.api_type}/{benchmark_type}")
-             results = evaluate_benchmark(benchmark_type, args.api_type)
-
-             # Append results to file
-             with open(results_file, "a") as f:
-                 f.write(json.dumps(results) + "\n")
-
-         except Exception as e:
-             print(f"Error evaluating {args.api_type}/{benchmark_type}: {str(e)}")
-
-
- if __name__ == "__main__":
-     main()
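For reference, the removed script was a small CLI of the form "python evaluate.py <api_type> <benchmark> [<benchmark> ...]", for example "python evaluate.py replicate draw_bench parti" (names drawn from the factory maps above), which appended one JSON line of aggregated scores per benchmark to evaluation_results/<api_type>.jsonl.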
evaluation_results/.gitkeep DELETED
File without changes
images/.gitkeep DELETED
File without changes
pyproject.toml DELETED
@@ -1,43 +0,0 @@
- [project]
- name = "inferbench"
- version = "0.1.0"
- description = "InferBench - AI inference benchmarking tool"
- requires-python = ">=3.12"
- dependencies = [
-     "numpy",
-     "opencv-python",
-     "pillow",
-     "python-dotenv",
-     "requests",
-     "tqdm",
-     "datasets==3.6.0",
-     "fal-client>=0.5.9",
-     "hpsv2>=1.2.0",
-     "huggingface-hub>=0.30.2",
-     "image-reward>=1.5",
-     "replicate>=1.0.4",
-     "t2v-metrics>=1.2",
-     "together>=1.5.5",
-     "torch>=2.7.0",
-     "torchmetrics>=1.7.1",
-     "clip",
-     "diffusers<=0.31",
-     "piq>=0.8.0",
-     "boto3>=1.39.4",
-     "gradio>=5.37.0",
-     "gradio-leaderboard>=0.0.14",
- ]
-
- [build-system]
- requires = ["setuptools>=61.0", "wheel"]
- build-backend = "setuptools.build_meta"
-
- [tool.setuptools.packages.find]
- include = ["api*", "benchmark*", "dashboard*"]
- exclude = ["images*", "evaluation_results*", "*.tests*"]
-
- [tool.hatch.build.targets.wheel]
- packages = ["api", "benchmark", "dashboard"]
-
- [tool.uv]
- dev-dependencies = []
sample.py DELETED
@@ -1,125 +0,0 @@
- import argparse
- import json
- from pathlib import Path
- from typing import List
-
- from tqdm import tqdm
-
- from api import create_api
- from benchmark import create_benchmark
-
-
- def generate_images(api_type: str, benchmarks: List[str]):
-     images_dir = Path("images")
-     api = create_api(api_type)
-
-     api_dir = images_dir / api_type
-     api_dir.mkdir(parents=True, exist_ok=True)
-
-     for benchmark_type in tqdm(benchmarks, desc="Processing benchmarks"):
-         print(f"\nProcessing benchmark: {benchmark_type}")
-
-         benchmark = create_benchmark(benchmark_type)
-
-         if benchmark_type == "geneval":
-             benchmark_dir = api_dir / benchmark_type
-             benchmark_dir.mkdir(parents=True, exist_ok=True)
-
-             metadata_file = benchmark_dir / "metadata.jsonl"
-             existing_metadata = {}
-             if metadata_file.exists():
-                 with open(metadata_file, "r") as f:
-                     for line in f:
-                         entry = json.loads(line)
-                         existing_metadata[entry["filepath"]] = entry
-
-             with open(metadata_file, "a") as f:
-                 for metadata, folder_name in tqdm(
-                     benchmark,
-                     desc=f"Generating images for {benchmark_type}",
-                     leave=False,
-                 ):
-                     sample_path = benchmark_dir / folder_name
-                     samples_path = sample_path / "samples"
-                     samples_path.mkdir(parents=True, exist_ok=True)
-                     image_path = samples_path / "0000.png"
-
-                     if image_path.exists():
-                         continue
-
-                     try:
-                         inference_time = api.generate_image(
-                             metadata["prompt"], image_path
-                         )
-
-                         metadata_entry = {
-                             "filepath": str(image_path),
-                             "prompt": metadata["prompt"],
-                             "inference_time": inference_time,
-                         }
-
-                         f.write(json.dumps(metadata_entry) + "\n")
-
-                     except Exception as e:
-                         print(
-                             f"\nError generating image for prompt: {metadata['prompt']}"
-                         )
-                         print(f"Error: {str(e)}")
-                         continue
-         else:
-             benchmark_dir = api_dir / benchmark_type
-             benchmark_dir.mkdir(parents=True, exist_ok=True)
-
-             metadata_file = benchmark_dir / "metadata.jsonl"
-             existing_metadata = {}
-             if metadata_file.exists():
-                 with open(metadata_file, "r") as f:
-                     for line in f:
-                         entry = json.loads(line)
-                         existing_metadata[entry["filepath"]] = entry
-
-             with open(metadata_file, "a") as f:
-                 for prompt, image_path in tqdm(
-                     benchmark,
-                     desc=f"Generating images for {benchmark_type}",
-                     leave=False,
-                 ):
-                     if image_path in existing_metadata:
-                         continue
-
-                     full_image_path = benchmark_dir / image_path
-
-                     if full_image_path.exists():
-                         continue
-
-                     try:
-                         inference_time = api.generate_image(prompt, full_image_path)
-
-                         metadata_entry = {
-                             "filepath": str(image_path),
-                             "prompt": prompt,
-                             "inference_time": inference_time,
-                         }
-
-                         f.write(json.dumps(metadata_entry) + "\n")
-
-                     except Exception as e:
-                         print(f"\nError generating image for prompt: {prompt}")
-                         print(f"Error: {str(e)}")
-                         continue
-
-
- def main():
-     parser = argparse.ArgumentParser(
-         description="Generate images for specified benchmarks using a given API"
-     )
-     parser.add_argument("api_type", help="Type of API to use for image generation")
-     parser.add_argument("benchmarks", nargs="+", help="List of benchmark types to run")
-
-     args = parser.parse_args()
-
-     generate_images(args.api_type, args.benchmarks)
-
-
- if __name__ == "__main__":
-     main()
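For reference, the removed sampling script was likewise a small CLI of the form "python sample.py <api_type> <benchmark> [<benchmark> ...]", for example "python sample.py fal draw_bench hps", which wrote generated images plus a metadata.jsonl of file path, prompt, and inference time under images/<api_type>/<benchmark_type>/.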