# OminiControlRotation / evaluation_subject_driven.py
# (Hugging Face upload header removed: "nvan15 / Batch upload part 1 / b03742a")
import openai
import base64
from pathlib import Path
import random
import os
# GPT-4o prompt templates, one per evaluation dimension. Each asks for a
# 1-5 rating in a "Score: X / Reasoning: ..." format that parse_score()
# extracts. "modification" is a str.format template that takes the
# generation text prompt via {prompt}.
evaluation_prompts = {
"identity": """
Compare the original subject image with the generated image.
Rate on a scale of 1-5 how well the essential identifying features
are preserved (logos, brand marks, distinctive patterns).
Score: [1-5]
Reasoning: [explanation]
""",
"material": """
Evaluate the material quality and surface characteristics.
Rate on a scale of 1-5 how accurately materials are represented
(textures, reflections, surface properties).
Score: [1-5]
Reasoning: [explanation]
""",
"color": """
Assess color fidelity in regions NOT specified for modification.
Rate on a scale of 1-5 how consistent colors remain.
Score: [1-5]
Reasoning: [explanation]
""",
"appearance": """
Evaluate the overall realism and coherence of the generated image.
Rate on a scale of 1-5 how realistic and natural it appears.
Score: [1-5]
Reasoning: [explanation]
""",
"modification": """
Given the text prompt: "{prompt}"
Rate on a scale of 1-5 how well the specified changes are executed.
Score: [1-5]
Reasoning: [explanation]
"""
}
def encode_image(image_path):
    """Return the contents of the file at *image_path* as a base64 string."""
    raw_bytes = Path(image_path).read_bytes()
    return base64.b64encode(raw_bytes).decode('utf-8')
def evaluate_subject_driven_generation(
    original_image_path,
    generated_image_path,
    text_prompt,
    client
):
    """
    Evaluate a subject-driven generation using GPT-4o vision.

    Sends five independent chat-completion requests — identity, material,
    color, appearance, modification — and returns a dict mapping each
    metric name to its parsed 1-5 integer score (None if unparseable).

    Parameters:
        original_image_path: path to the reference subject image (PNG/JPEG).
        generated_image_path: path to the generated image to score.
        text_prompt: the text prompt used for generation (scored under
            "modification").
        client: an openai.Client instance.
    """
    # Encode both images once; each is reused across several requests.
    original_img = encode_image(original_image_path)
    generated_img = encode_image(generated_image_path)

    def _image_part(b64):
        # Inline a base64 PNG as a data URL, the format the vision API accepts.
        return {"type": "image_url",
                "image_url": {"url": f"data:image/png;base64,{b64}"}}

    def _text_part(text):
        return {"type": "text", "text": text}

    def _score(content):
        # One GPT-4o call per metric; parse_score extracts the 1-5 rating.
        response = client.chat.completions.create(
            model="gpt-4o",
            messages=[{"role": "user", "content": content}],
            max_tokens=300
        )
        return parse_score(response.choices[0].message.content)

    # Dict literal preserves both insertion order and request order.
    return {
        # 1. Identity preservation: compare original vs generated.
        'identity': _score([
            _text_part("Original subject image:"),
            _image_part(original_img),
            _text_part("Generated image:"),
            _image_part(generated_img),
            _text_part(evaluation_prompts["identity"]),
        ]),
        # 2. Material quality: generated image only.
        'material': _score([
            _text_part("Evaluate this generated image:"),
            _image_part(generated_img),
            _text_part(evaluation_prompts["material"]),
        ]),
        # 3. Color fidelity: compare original vs generated.
        'color': _score([
            _text_part("Original:"),
            _image_part(original_img),
            _text_part("Generated:"),
            _image_part(generated_img),
            _text_part(evaluation_prompts["color"]),
        ]),
        # 4. Natural appearance: generated image only.
        'appearance': _score([
            _image_part(generated_img),
            _text_part(evaluation_prompts["appearance"]),
        ]),
        # 5. Modification accuracy: conditioned on the text prompt.
        'modification': _score([
            _text_part(f"Text prompt: {text_prompt}"),
            _image_part(generated_img),
            _text_part(evaluation_prompts["modification"].format(prompt=text_prompt)),
        ]),
    }
def parse_score(response_text):
    """Extract the 1-5 integer score from a GPT-4o evaluation response.

    Tolerates common formatting variants the model actually produces:
    "Score: 4", "score: 4", "Score: [4]" (the literal template format),
    and "**Score:** 4" (markdown bold).

    Returns:
        The score as an int, or None when no score pattern is found.
    """
    import re
    # Optional colon, then any mix of whitespace / '*' / '[' before the digits.
    match = re.search(r'Score\s*:?[\s\*\[]*(\d+)', response_text, re.IGNORECASE)
    if match:
        return int(match.group(1))
    return None
# The 30 DreamBooth benchmark subjects, indexed by subject_id.
subject_names = [
    "backpack", "backpack_dog", "bear_plushie", "berry_bowl", "can",
    "candle", "cat", "cat2", "clock", "colorful_sneaker",
    "dog", "dog2", "dog3", "dog5", "dog6",
    "dog7", "dog8", "duck_toy", "fancy_boot", "grey_sloth_plushie",
    "monster_toy", "pink_sunglasses", "poop_emoji", "rc_car", "red_cartoon",
    "robot_toy", "shiny_sneaker", "teapot", "vase", "wolf_plushie"
]

# Subject ids that are inanimate objects. The remaining ids
# (6, 7, 10-16: cat, cat2, dog..dog8) are live subjects and receive
# "wearing"/outfit prompts 10-19 instead of additional scene prompts.
_OBJECT_SUBJECT_IDS = frozenset(
    [0, 1, 2, 3, 4, 5, 8, 9, 17, 18, 19, 20, 21, 22, 23, 24,
     25, 26, 27, 28, 29]
)

def get_prompt(subject_id, prompt_id):
    """Return the DreamBooth evaluation prompt for a subject/prompt pair.

    Prompts 0-9 (scenes) and 20-24 (attribute changes) are shared by all
    subjects; prompts 10-19 differ between object subjects (more scenes)
    and live subjects (outfits/accessories).

    Bug fix: ``subject_name`` was previously assigned only on the object
    branch, so live subjects (e.g. subject_id 6, "cat") raised NameError.

    Parameters:
        subject_id: index into ``subject_names`` (0-29).
        prompt_id: prompt index (0-24).
    """
    subject_name = subject_names[subject_id]

    shared_scenes = [
        f"a {subject_name} in the jungle",
        f"a {subject_name} in the snow",
        f"a {subject_name} on the beach",
        f"a {subject_name} on a cobblestone street",
        f"a {subject_name} on top of pink fabric",
        f"a {subject_name} on top of a wooden floor",
        f"a {subject_name} with a city in the background",
        f"a {subject_name} with a mountain in the background",
        f"a {subject_name} with a blue house in the background",
        f"a {subject_name} on top of a purple rug in a forest",
    ]
    if subject_id in _OBJECT_SUBJECT_IDS:
        middle = [
            f"a {subject_name} with a wheat field in the background",
            f"a {subject_name} with a tree and autumn leaves in the background",
            f"a {subject_name} with the Eiffel Tower in the background",
            f"a {subject_name} floating on top of water",
            f"a {subject_name} floating in an ocean of milk",
            f"a {subject_name} on top of green grass with sunflowers around it",
            f"a {subject_name} on top of a mirror",
            f"a {subject_name} on top of the sidewalk in a crowded street",
            f"a {subject_name} on top of a dirt road",
            f"a {subject_name} on top of a white rug",
        ]
    else:
        middle = [
            f"a {subject_name} wearing a red hat",
            f"a {subject_name} wearing a santa hat",
            f"a {subject_name} wearing a rainbow scarf",
            f"a {subject_name} wearing a black top hat and a monocle",
            f"a {subject_name} in a chef outfit",
            f"a {subject_name} in a firefighter outfit",
            f"a {subject_name} in a police outfit",
            f"a {subject_name} wearing pink glasses",
            f"a {subject_name} wearing a yellow shirt",
            f"a {subject_name} in a purple wizard outfit",
        ]
    attribute_changes = [
        f"a red {subject_name}",
        f"a purple {subject_name}",
        f"a shiny {subject_name}",
        f"a wet {subject_name}",
        f"a cube shaped {subject_name}",
    ]
    prompts = shared_scenes + middle + attribute_changes
    return prompts[prompt_id]
def batch_evaluate_dreambooth(client, generate_fn, dataset_path, output_csv):
    """
    Evaluate 750 (subject, prompt) pairs with 5 seeds each.

    For every (subject, prompt, seed) triple this generates an image with
    ``generate_fn`` and scores it with GPT-4o via
    ``evaluate_subject_driven_generation``. All rows are written to
    ``output_csv`` and per-subject / overall averages are printed.

    Parameters:
        client: an openai.Client used for the GPT-4o scoring calls.
        generate_fn: callable invoked as ``generate_fn(prompt=...,
            subject_image_path=..., output_image_path=..., seed=...)``.
        dataset_path: root folder containing one sub-folder of PNGs
            per subject name.
        output_csv: destination path for the results CSV.
    """
    import pandas as pd
    results_list = []
    # Iterate through the DreamBooth dataset.
    for subject_id in range(30):  # 30 subjects
        subject_name = subject_names[subject_id]
        subject_dir = Path(dataset_path) / subject_name

        # Reference image: first PNG in the subject folder, sorted for
        # determinism (hoisted out of the prompt loop — it never changes).
        original_files = sorted(subject_dir.glob("*.png"))
        if not original_files:
            raise ValueError(f"No original images found in {subject_dir}")
        original = str(original_files[0])

        generated_folder = str(subject_dir / "generated")
        os.makedirs(generated_folder, exist_ok=True)

        for prompt_id in range(25):  # 25 prompts per subject
            prompt = get_prompt(subject_id, prompt_id)
            for seed in range(5):  # 5 different seeds
                # os.path.join avoids the double slash the old f-string produced.
                generated = os.path.join(
                    generated_folder, f"gen_seed{seed}_prompt{prompt_id}.png"
                )
                generate_fn(
                    prompt=prompt,
                    subject_image_path=original,
                    output_image_path=generated,
                    seed=seed
                )
                scores = evaluate_subject_driven_generation(
                    original, generated, prompt, client
                )
                results_list.append({
                    'subject_id': subject_id,
                    'subject_name': subject_name,
                    'prompt_id': prompt_id,
                    'seed': seed,
                    'prompt': prompt,
                    **scores
                })
    # Save results.
    df = pd.DataFrame(results_list)
    df.to_csv(output_csv, index=False)
    # numeric_only=True: non-numeric columns (subject_name, prompt) make
    # GroupBy.mean() raise a TypeError on pandas >= 2.0.
    print(df.groupby('subject_id').mean(numeric_only=True))
    print("\nOverall averages:")
    print(df[['identity', 'material', 'color', 'appearance', 'modification']].mean())
def evaluate_omini_control():
    """
    Build the OminiControl subject-driven generator.

    Loads FLUX.1-schnell with the OminiControl subject LoRA (512px) on CUDA
    and returns a ``generate_fn`` whose signature matches the keyword call
    made by ``batch_evaluate_dreambooth``.

    Requires a GPU plus the diffusers / omini packages.
    """
    import torch
    from diffusers.pipelines import FluxPipeline
    from PIL import Image
    from omini.pipeline.flux_omini import Condition, generate, seed_everything

    pipe = FluxPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16
    )
    pipe = pipe.to("cuda")
    pipe.load_lora_weights(
        "Yuanshi/OminiControl",
        weight_name="omini/subject_512.safetensors",
        adapter_name="subject",
    )

    # Bug fix: parameters renamed to match the keyword arguments used by
    # batch_evaluate_dreambooth (prompt=, subject_image_path=,
    # output_image_path=, seed=); the old names (image_path, output_path)
    # raised TypeError when called that way.
    def generate_fn(prompt, subject_image_path, output_image_path, seed):
        seed_everything(seed)
        image = Image.open(subject_image_path).convert("RGB").resize((512, 512))
        condition = Condition.from_image(
            image,
            "subject", position_delta=(0, 32)
        )
        result_img = generate(
            pipe,
            prompt=prompt,
            conditions=[condition],
        ).images[0]
        result_img.save(output_image_path)

    return generate_fn
if __name__ == "__main__":
    # Legacy-style key assignment; openai.Client() below also reads
    # OPENAI_API_KEY from the environment on its own.
    openai.api_key = os.getenv("OPENAI_API_KEY")
    # Full benchmark run — left disabled; requires a GPU and the
    # OminiControl / diffusers stack:
    # client = openai.Client()
    # generate_fn = evaluate_omini_control()
    # dataset_path = "data/dreambooth"
    # output_csv = "evaluation_subject_driven_omini_control.csv"
    # batch_evaluate_dreambooth(
    #     client,
    #     generate_fn,
    #     dataset_path,
    #     output_csv
    # )
    # Smoke test: score one existing DreamBooth image pair with GPT-4o
    # (makes live API calls; costs tokens).
    result = evaluate_subject_driven_generation(
        "data/dreambooth/backpack/00.jpg",
        "data/dreambooth/backpack/01.jpg",
        "a backpack in the jungle",
        openai.Client()
    )
    print(result)