|
|
import openai |
|
|
import base64 |
|
|
from pathlib import Path |
|
|
import random |
|
|
import os |
|
|
|
|
|
|
|
|
|
|
|
# Prompt templates for the five GPT-4o evaluation dimensions.  Every template
# asks the model to emit a "Score: [1-5]" line followed by a "Reasoning:"
# line, which parse_score() later extracts from the response text.  The
# "modification" template contains a {prompt} placeholder that is filled with
# the generation text prompt via str.format() before being sent.
evaluation_prompts = {

    # Preservation of the subject's identifying features (sent with both the
    # original and the generated image).
    "identity": """
Compare the original subject image with the generated image.
Rate on a scale of 1-5 how well the essential identifying features
are preserved (logos, brand marks, distinctive patterns).
Score: [1-5]
Reasoning: [explanation]
""",

    # Fidelity of materials and surface properties (generated image only).
    "material": """
Evaluate the material quality and surface characteristics.
Rate on a scale of 1-5 how accurately materials are represented
(textures, reflections, surface properties).
Score: [1-5]
Reasoning: [explanation]
""",

    # Color consistency in unmodified regions (sent with both images).
    "color": """
Assess color fidelity in regions NOT specified for modification.
Rate on a scale of 1-5 how consistent colors remain.
Score: [1-5]
Reasoning: [explanation]
""",

    # Overall realism/coherence (generated image only).
    "appearance": """
Evaluate the overall realism and coherence of the generated image.
Rate on a scale of 1-5 how realistic and natural it appears.
Score: [1-5]
Reasoning: [explanation]
""",

    # Execution of the requested change; {prompt} is format()-ed in by the
    # caller before the request is sent.
    "modification": """
Given the text prompt: "{prompt}"
Rate on a scale of 1-5 how well the specified changes are executed.
Score: [1-5]
Reasoning: [explanation]
"""
}
|
|
|
|
|
|
|
|
def encode_image(image_path):
    """Return the contents of the image at *image_path* as a base64 string."""
    raw_bytes = Path(image_path).read_bytes()
    return base64.b64encode(raw_bytes).decode('utf-8')
|
|
|
|
|
def evaluate_subject_driven_generation(
    original_image_path,
    generated_image_path,
    text_prompt,
    client
):
    """
    Evaluate a subject-driven generation along five dimensions using GPT-4o.

    Parameters
    ----------
    original_image_path : str or Path
        Path to the reference subject image.
    generated_image_path : str or Path
        Path to the generated image under evaluation.
    text_prompt : str
        The text prompt that drove the generation (used by the
        'modification' dimension).
    client : object
        An OpenAI client exposing ``chat.completions.create``.

    Returns
    -------
    dict
        Maps each dimension name ('identity', 'material', 'color',
        'appearance', 'modification') to an int score in 1-5, or None when
        the model response contained no parsable "Score:" line.
    """
    original_img = encode_image(original_image_path)
    generated_img = encode_image(generated_image_path)

    def image_part(b64_data):
        # NOTE(review): the MIME type is hardcoded to image/png even though
        # callers may pass .jpg files — confirm the API tolerates the
        # mismatch or derive the type from the file extension.
        return {
            "type": "image_url",
            "image_url": {"url": f"data:image/png;base64,{b64_data}"},
        }

    # One message payload per evaluated dimension.  Dict insertion order
    # fixes the request order (identity, material, color, appearance,
    # modification), matching the previous hand-unrolled sequence.
    payloads = {
        'identity': [
            {"type": "text", "text": "Original subject image:"},
            image_part(original_img),
            {"type": "text", "text": "Generated image:"},
            image_part(generated_img),
            {"type": "text", "text": evaluation_prompts["identity"]},
        ],
        'material': [
            {"type": "text", "text": "Evaluate this generated image:"},
            image_part(generated_img),
            {"type": "text", "text": evaluation_prompts["material"]},
        ],
        'color': [
            {"type": "text", "text": "Original:"},
            image_part(original_img),
            {"type": "text", "text": "Generated:"},
            image_part(generated_img),
            {"type": "text", "text": evaluation_prompts["color"]},
        ],
        'appearance': [
            image_part(generated_img),
            {"type": "text", "text": evaluation_prompts["appearance"]},
        ],
        'modification': [
            {"type": "text", "text": f"Text prompt: {text_prompt}"},
            image_part(generated_img),
            {"type": "text", "text": evaluation_prompts["modification"].format(prompt=text_prompt)},
        ],
    }

    results = {}
    for dimension, content in payloads.items():
        response = client.chat.completions.create(
            model="gpt-4o",
            messages=[{
                "role": "user",
                "content": content,
            }],
            max_tokens=300,
        )
        results[dimension] = parse_score(response.choices[0].message.content)

    return results
|
|
|
|
|
def parse_score(response_text):
    """Extract the integer score from a GPT-4o evaluation response.

    Accepts the plain form ("Score: 4") as well as common markdown-decorated
    variants ("**Score:** 4", "**Score**: 4"), case-insensitively.  A
    verbatim echo of the template placeholder ("Score: [1-5]") is NOT
    treated as a score.

    Returns the score as an int, or None when no score line is found.
    """
    import re

    # \**  tolerates markdown bold markers around "Score" and the number.
    match = re.search(r'Score\**\s*:\s*\**\s*(\d+)', response_text,
                      re.IGNORECASE)
    if match:
        return int(match.group(1))
    return None
|
|
|
|
|
# DreamBooth dataset subject identifiers; list index == subject_id (0-29)
# as used throughout this module.
subject_names = [
    "backpack", "backpack_dog", "bear_plushie", "berry_bowl", "can",
    "candle", "cat", "cat2", "clock", "colorful_sneaker",
    "dog", "dog2", "dog3", "dog5", "dog6",
    "dog7", "dog8", "duck_toy", "fancy_boot", "grey_sloth_plushie",
    "monster_toy", "pink_sunglasses", "poop_emoji", "rc_car", "red_cartoon",
    "robot_toy", "shiny_sneaker", "teapot", "vase", "wolf_plushie"
]

# subject_ids that are inanimate objects and use the object prompt list;
# the remaining ids (cats and dogs: 6, 7, 10-16) use the outfit/accessory
# prompt list instead.
_OBJECT_SUBJECT_IDS = frozenset(
    [0, 1, 2, 3, 4, 5, 8, 9, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26,
     27, 28, 29]
)


def get_prompt(subject_id, prompt_id):
    """Return the DreamBooth evaluation prompt for (subject_id, prompt_id).

    Each subject has 25 prompts: 10 shared recontextualisation prompts,
    10 category-specific prompts (scenes for objects, outfits/accessories
    for live subjects), and 5 shared property-change prompts.

    Bug fix: ``subject_name`` was previously assigned only in the
    object-subject branch, so live subjects (e.g. 'cat', 'dog') raised
    NameError.
    """
    subject_name = subject_names[subject_id]

    # Prompts 0-9: recontextualisation, shared by both categories.
    shared_prefix = [
        f"a {subject_name} in the jungle",
        f"a {subject_name} in the snow",
        f"a {subject_name} on the beach",
        f"a {subject_name} on a cobblestone street",
        f"a {subject_name} on top of pink fabric",
        f"a {subject_name} on top of a wooden floor",
        f"a {subject_name} with a city in the background",
        f"a {subject_name} with a mountain in the background",
        f"a {subject_name} with a blue house in the background",
        f"a {subject_name} on top of a purple rug in a forest",
    ]

    # Prompts 10-19: category-specific.
    if subject_id in _OBJECT_SUBJECT_IDS:
        middle = [
            f"a {subject_name} with a wheat field in the background",
            f"a {subject_name} with a tree and autumn leaves in the background",
            f"a {subject_name} with the Eiffel Tower in the background",
            f"a {subject_name} floating on top of water",
            f"a {subject_name} floating in an ocean of milk",
            f"a {subject_name} on top of green grass with sunflowers around it",
            f"a {subject_name} on top of a mirror",
            f"a {subject_name} on top of the sidewalk in a crowded street",
            f"a {subject_name} on top of a dirt road",
            f"a {subject_name} on top of a white rug",
        ]
    else:
        middle = [
            f"a {subject_name} wearing a red hat",
            f"a {subject_name} wearing a santa hat",
            f"a {subject_name} wearing a rainbow scarf",
            f"a {subject_name} wearing a black top hat and a monocle",
            f"a {subject_name} in a chef outfit",
            f"a {subject_name} in a firefighter outfit",
            f"a {subject_name} in a police outfit",
            f"a {subject_name} wearing pink glasses",
            f"a {subject_name} wearing a yellow shirt",
            f"a {subject_name} in a purple wizard outfit",
        ]

    # Prompts 20-24: property changes, shared by both categories.
    shared_suffix = [
        f"a red {subject_name}",
        f"a purple {subject_name}",
        f"a shiny {subject_name}",
        f"a wet {subject_name}",
        f"a cube shaped {subject_name}",
    ]

    prompts = shared_prefix + middle + shared_suffix
    return prompts[prompt_id]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def batch_evaluate_dreambooth(client, generate_fn, dataset_path, output_csv):
    """
    Run the full DreamBooth evaluation grid and write per-run scores to CSV.

    Covers 30 subjects x 25 prompts x 5 seeds (750 prompt/subject pairs,
    3750 generations in total).  For each run a generation is produced via
    *generate_fn* and scored by GPT-4o via
    evaluate_subject_driven_generation.

    Parameters
    ----------
    client : OpenAI client used for scoring.
    generate_fn : callable invoked as
        generate_fn(prompt=..., subject_image_path=..., output_image_path=...,
        seed=...).
    dataset_path : root folder containing one sub-folder per subject name.
    output_csv : destination path for the per-run results CSV.

    Returns
    -------
    pandas.DataFrame with one row per (subject, prompt, seed) run.
    """
    import pandas as pd

    score_columns = ['identity', 'material', 'color', 'appearance', 'modification']
    results_list = []

    for subject_id in range(30):
        subject_name = subject_names[subject_id]

        # Resolve the reference image once per subject (previously re-globbed
        # for every prompt).  sorted() makes the pick deterministic; bare
        # glob order is filesystem-dependent.
        subject_dir = f"{dataset_path}/{subject_name}"
        original_files = sorted(Path(subject_dir).glob("*.png"))
        if len(original_files) == 0:
            raise ValueError(f"No original images found in {subject_dir}")
        original = str(original_files[0])

        # No trailing slash: avoids "generated//gen_..." in output paths.
        generated_folder = f"{dataset_path}/{subject_name}/generated"
        os.makedirs(generated_folder, exist_ok=True)

        for prompt_id in range(25):
            # The prompt depends only on (subject_id, prompt_id); hoisted
            # out of the seed loop.
            prompt = get_prompt(subject_id, prompt_id)

            for seed in range(5):
                generated = f"{generated_folder}/gen_seed{seed}_prompt{prompt_id}.png"

                generate_fn(
                    prompt=prompt,
                    subject_image_path=original,
                    output_image_path=generated,
                    seed=seed
                )

                scores = evaluate_subject_driven_generation(
                    original, generated, prompt, client
                )

                results_list.append({
                    'subject_id': subject_id,
                    'subject_name': subject_name,
                    'prompt_id': prompt_id,
                    'seed': seed,
                    'prompt': prompt,
                    **scores
                })

    df = pd.DataFrame(results_list)
    df.to_csv(output_csv, index=False)

    # Restrict aggregation to the numeric score columns: pandas >= 2.0
    # raises TypeError when mean() hits non-numeric columns such as
    # 'subject_name' / 'prompt'.
    print(df.groupby('subject_id')[score_columns].mean())
    print("\nOverall averages:")
    print(df[score_columns].mean())

    return df
|
|
|
|
|
|
|
|
def evaluate_omini_control():
    """
    Build and return a generate_fn for the OminiControl subject pipeline.

    Loads FLUX.1-schnell on CUDA with the OminiControl subject LoRA
    (512px weights) and returns a closure whose keyword signature matches
    the way batch_evaluate_dreambooth invokes it:
    generate_fn(prompt=..., subject_image_path=..., output_image_path=...,
    seed=...).

    Bug fix: the closure previously took (image_path, prompt, seed,
    output_path), so every keyword call from batch_evaluate_dreambooth
    raised TypeError.
    """
    import torch
    from diffusers.pipelines import FluxPipeline
    from PIL import Image

    from omini.pipeline.flux_omini import Condition, generate, seed_everything

    pipe = FluxPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16
    )
    pipe = pipe.to("cuda")
    pipe.load_lora_weights(
        "Yuanshi/OminiControl",
        weight_name="omini/subject_512.safetensors",
        adapter_name="subject",
    )

    def generate_fn(prompt, subject_image_path, output_image_path, seed):
        # Parameter names must match the keyword arguments used by
        # batch_evaluate_dreambooth.
        seed_everything(seed)

        # OminiControl subject conditioning expects a 512x512 RGB image.
        image = Image.open(subject_image_path).convert("RGB").resize((512, 512))
        condition = Condition.from_image(
            image,
            "subject", position_delta=(0, 32)
        )

        result_img = generate(
            pipe,
            prompt=prompt,
            conditions=[condition],
        ).images[0]

        result_img.save(output_image_path)

    return generate_fn
|
|
|
|
|
|
|
|
if __name__ == "__main__":

    # Smoke test: score a single original/generated pair for "backpack".
    # Pass the API key to the client explicitly — setting the legacy
    # module-level openai.api_key attribute has no effect on a separately
    # constructed Client instance.
    client = openai.Client(api_key=os.getenv("OPENAI_API_KEY"))

    result = evaluate_subject_driven_generation(
        "data/dreambooth/backpack/00.jpg",
        "data/dreambooth/backpack/01.jpg",
        "a backpack in the jungle",
        client
    )

    print(result)