File size: 6,287 Bytes
9e3bd6c
d3a3bf1
69a2277
f29389b
433cace
 
acc4043
 
 
 
 
69a2277
f29389b
acc4043
f29389b
 
69a2277
acc4043
69a2277
f29389b
acc4043
433cace
 
 
8cafce9
 
9362fe6
8cafce9
acc4043
433cace
 
acc4043
433cace
 
 
 
 
 
9362fe6
 
433cace
 
 
 
 
 
 
9362fe6
f29389b
 
69a2277
9362fe6
69a2277
054edad
 
69a2277
 
 
 
 
433cace
69a2277
433cace
9362fe6
 
 
 
 
 
 
433cace
69a2277
 
 
f29389b
433cace
69a2277
 
 
 
acc4043
 
 
9362fe6
acc4043
 
9362fe6
 
 
 
 
 
acc4043
 
69a2277
 
 
 
054edad
 
 
69a2277
 
 
 
f29389b
9362fe6
f29389b
054edad
 
9362fe6
054edad
9362fe6
 
 
054edad
69a2277
054edad
 
 
 
9362fe6
054edad
 
 
 
 
 
9362fe6
054edad
 
 
 
 
 
 
9362fe6
054edad
 
8d8a928
 
9362fe6
8d8a928
acc4043
 
 
 
 
8d8a928
 
acc4043
 
8d8a928
 
 
acc4043
8d8a928
 
acc4043
 
8d8a928
9362fe6
8d8a928
acc4043
 
045059f
9362fe6
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
import os
import json
from dotenv import load_dotenv
from openai import OpenAI
from PIL import Image
import torch
from transformers import (
    BlipProcessor,
    BlipForConditionalGeneration,
    CLIPTokenizer
)

# ----------------------------
# πŸ” Load API Keys & Setup
# ----------------------------
load_dotenv()  # pull OPENAI_API_KEY (and friends) from a local .env file
# NOTE(review): if OPENAI_API_KEY is unset, api_key is None and the first API call fails — confirm deployment always provides it
client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
device = "cuda" if torch.cuda.is_available() else "cpu"  # prefer GPU when available

# ----------------------------
# πŸ“Έ Load BLIP Captioning Model
# ----------------------------
# BLIP produces a natural-language caption for the uploaded product image.
processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base").to(device)

# ----------------------------
# 🧠 Load CLIP Tokenizer (for token check)
# ----------------------------
# Used only to count prompt tokens against CLIP's 77-token limit.
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")

# ----------------------------
# πŸ“Έ Generate Caption from Product Image
# ----------------------------
def generate_blip_caption(image: Image.Image) -> str:
    """Caption a product image with BLIP.

    Returns a short natural-language caption, or the generic fallback
    "a product image" if captioning fails for any reason.
    """
    try:
        inputs = processor(images=image, return_tensors="pt").to(device)
        out = blip_model.generate(**inputs, max_length=50)
        caption = processor.decode(out[0], skip_special_tokens=True)
        # BLIP sometimes stutters ("a a red red shoe"); drop only *consecutive*
        # repeats. The previous dict.fromkeys() approach removed every repeated
        # word globally, which mangled valid captions (e.g. "side by side").
        words = caption.split()
        caption = " ".join(
            w for i, w in enumerate(words) if i == 0 or w != words[i - 1]
        )
        print(f"πŸ–ΌοΈ BLIP Caption: {caption}")
        return caption
    except Exception as e:
        print("❌ BLIP Captioning Error:", e)
        return "a product image"

# ----------------------------
# 🧠 GPT Scene Planning with Caption + Visual Style
# ----------------------------
# System prompt steering GPT to emit a machine-parseable scene plan (raw JSON only);
# consumed by extract_scene_plan() below.
SCENE_SYSTEM_INSTRUCTIONS = """
You are a scene planning assistant for an AI image generation system.
Your job is to take a caption from a product image, a visual style hint, and a user prompt, then return a structured JSON with:
- scene (environment, setting)
- subject (main_actor)
- objects (main_product or items)
- layout (foreground/background elements and their placement)
- rules (validation rules to ensure visual correctness)
Respond ONLY in raw JSON format. Do NOT include explanations.
"""

def extract_scene_plan(prompt: str, image: Image.Image) -> dict:
    """Build a structured scene plan from a product image and a user prompt.

    Captions the image with BLIP, asks GPT for a JSON scene plan, appends the
    exchange to logs/scene_plans.jsonl, and parses the JSON. Returns a minimal
    default plan if anything fails (API error, unparseable JSON, ...).
    """
    try:
        caption = generate_blip_caption(image)
        # Heuristic style hint — assumes footwear-style catalogs; TODO confirm for other product types.
        visual_hint = caption if "shoe" in caption or "product" in caption else "low-top product photo on white background"

        merged_prompt = (
            f"Image Caption: {caption}\n"
            f"Image Visual Style: {visual_hint}\n"
            f"User Prompt: {prompt}"
        )

        response = client.chat.completions.create(
            model="gpt-4o-mini-2024-07-18",
            messages=[
                {"role": "system", "content": SCENE_SYSTEM_INSTRUCTIONS},
                {"role": "user", "content": merged_prompt}
            ],
            temperature=0.3,
            max_tokens=500
        )
        content = response.choices[0].message.content.strip()
        print("🧠 Scene Plan (Raw):", content)

        # Despite the "raw JSON only" instruction, models frequently wrap the
        # payload in markdown fences (```json ... ```); strip them so
        # json.loads() below doesn't fall through to the fallback plan.
        if content.startswith("```"):
            first_newline = content.find("\n")
            if first_newline != -1:
                content = content[first_newline + 1:]
            if content.endswith("```"):
                content = content[:-3]
            content = content.strip()

        # Append-only audit log of every scene-planning exchange.
        os.makedirs("logs", exist_ok=True)
        with open("logs/scene_plans.jsonl", "a", encoding="utf-8") as f:
            f.write(json.dumps({
                "caption": caption,
                "visual_hint": visual_hint,
                "prompt": prompt,
                "scene_plan": content
            }) + "\n")

        return json.loads(content)

    except Exception as e:
        print("❌ extract_scene_plan() Error:", e)
        # Safe default so downstream generation still receives a usable plan.
        return {
            "scene": {"environment": "studio", "setting": "plain white background"},
            "subject": {"main_actor": "a product"},
            "objects": {"main_product": "product"},
            "layout": {},
            "rules": {}
        }

# ----------------------------
# ✨ Enriched Prompt Generation (GPT, 77-token safe)
# ----------------------------
# System prompt for turning a scene plan into one enriched image-generation
# prompt; consumed by generate_prompt_variations_from_scene() below.
ENRICHED_PROMPT_INSTRUCTIONS = """
You are a prompt engineer for an AI image generation model.
Given a structured scene plan and a user prompt, generate a single natural-language enriched prompt that:
1. Describes the subject, product, setting, and layout clearly
2. Uses natural, photo-realistic language
3. Stays strictly under 77 tokens (CLIP token limit)
Return ONLY the enriched prompt string. No explanations.
"""

def generate_prompt_variations_from_scene(scene_plan: dict, base_prompt: str, n: int = 3) -> list:
    """Generate n enriched prompt variations from a scene plan.

    Each variation is requested independently from GPT and hard-capped at
    CLIP's 77-token limit. On any API failure the raw base_prompt is used
    for that variation, so the returned list always has length n.
    """
    prompts = []
    for _ in range(n):
        try:
            user_input = f"Scene Plan:\n{json.dumps(scene_plan)}\n\nUser Prompt:\n{base_prompt}"
            response = client.chat.completions.create(
                model="gpt-4o-mini-2024-07-18",
                messages=[
                    {"role": "system", "content": ENRICHED_PROMPT_INSTRUCTIONS},
                    {"role": "user", "content": user_input}
                ],
                temperature=0.4,
                max_tokens=100
            )
            enriched = response.choices[0].message.content.strip()
            token_count = len(tokenizer(enriched)["input_ids"])
            # The instructions ask for <77 tokens but the model can overshoot;
            # previously the count was measured and ignored. Enforce the CLIP
            # limit here so the image model never silently truncates the prompt.
            if token_count > 77:
                ids = tokenizer(enriched, truncation=True, max_length=77)["input_ids"]
                enriched = tokenizer.decode(ids, skip_special_tokens=True).strip()
                token_count = len(tokenizer(enriched)["input_ids"])
            print(f"πŸ“ Enriched Prompt ({token_count} tokens): {enriched}")
            prompts.append(enriched)
        except Exception as e:
            print("⚠️ Prompt fallback:", e)
            prompts.append(base_prompt)
    return prompts

# ----------------------------
# ❌ Negative Prompt Generator
# ----------------------------
# System prompt for deriving a negative prompt from a scene plan;
# consumed by generate_negative_prompt_from_scene() below.
NEGATIVE_SYSTEM_PROMPT = """
You are a prompt engineer. Given a structured scene plan, generate a short negative prompt
to suppress unwanted visual elements such as: distortion, blurriness, poor anatomy,
logo errors, background noise, or low realism.
Return a single comma-separated list. No intro text.
"""

def generate_negative_prompt_from_scene(scene_plan: dict) -> str:
    """Ask GPT for a comma-separated negative prompt tailored to the scene plan.

    Falls back to a generic negative prompt if the API call fails for any reason.
    """
    request_messages = [
        {"role": "system", "content": NEGATIVE_SYSTEM_PROMPT},
        {"role": "user", "content": json.dumps(scene_plan)},
    ]
    try:
        completion = client.chat.completions.create(
            model="gpt-4o-mini-2024-07-18",
            messages=request_messages,
            temperature=0.2,
            max_tokens=100
        )
        return completion.choices[0].message.content.strip()
    except Exception as err:
        print("❌ Negative Prompt Error:", err)
        return "blurry, distorted, low quality, deformed, watermark"