primerz committed on
Commit 6026000 · verified · 1 Parent(s): ef222df

Upload 6 files

Files changed (6)
  1. app.py +157 -0
  2. config.py +47 -0
  3. generator.py +156 -0
  4. model.py +166 -0
  5. requirements.txt +14 -0
  6. utils.py +77 -0
app.py ADDED
@@ -0,0 +1,157 @@
+ import gradio as gr
+ import spaces
+ import torch
+ from model import ModelHandler
+ from generator import Generator
+ from config import Config
+
+ # 1. Initialize Models Globally
+ print("Initializing Application...")
+ handler = ModelHandler()
+ handler.load_models()
+ gen = Generator(handler)
+
+ # 2. Define GPU-enabled Inference Function
+ @spaces.GPU(duration=20)
+ def process_img(
+     image,
+     prompt,
+     negative_prompt,
+     cfg_scale,  # <-- RE-ENABLED
+     steps,
+     img_strength,
+     depth_strength,
+     edge_strength,
+     seed
+ ):
+     if image is None:
+         raise gr.Error("Please upload an image first.")
+
+     try:
+         print("--- Starting Generation ---")
+         result = gen.predict(
+             image,
+             prompt,
+             negative_prompt=negative_prompt,
+             guidance_scale=cfg_scale,  # <-- RE-ENABLED
+             num_inference_steps=steps,
+             img2img_strength=img_strength,
+             depth_strength=depth_strength,
+             lineart_strength=edge_strength,
+             seed=seed
+         )
+         print("--- Generation Complete ---")
+         return result
+
+     except Exception as e:
+         print(f"Error during generation: {e}")
+         raise gr.Error(f"An error occurred: {str(e)}")
+
+ # 3. Build Gradio Interface
+ with gr.Blocks(title="Face To Style", theme=gr.themes.Soft()) as demo:
+     gr.Markdown(
+         """
+         # 🎮 Face to Style
+         Upload any image. If there is a face, we'll keep the identity. If not, we'll stylize the scene!
+         """
+     )
+
+     with gr.Row():
+         with gr.Column(scale=2):
+             input_img = gr.Image(type="pil", label="Input Image")
+             prompt = gr.Textbox(
+                 label="Prompt (Optional)",
+                 placeholder="Leave empty for auto-captioning...",
+                 info=f"The trigger words '{Config.STYLE_TRIGGER}' are added automatically."
+             )
+
+             negative_prompt = gr.Textbox(
+                 label="Negative Prompt (Optional)",
+                 placeholder="e.g., blurry, text, watermark, bad art...",
+                 value=Config.DEFAULT_NEGATIVE_PROMPT
+             )
+
+             with gr.Accordion("Advanced Settings", open=False):
+                 seed = gr.Number(
+                     label="Seed",
+                     value=-1,
+                     info="-1 for random",
+                     precision=0
+                 )
+
+                 # --- RE-ENABLED CFG/GUIDANCE SLIDER ---
+                 cfg_scale = gr.Slider(
+                     elem_id="cfg_scale",
+                     minimum=1.0,
+                     maximum=10.0,  # Range for TCD+Style
+                     step=0.1,
+                     value=Config.CFG_SCALE,  # Default 4.0
+                     label="Style Strength (Guidance)"
+                 )
+
+                 steps = gr.Slider(
+                     elem_id="steps",
+                     minimum=1,
+                     maximum=20,
+                     step=1,
+                     value=8,  # TCD default
+                     label="Number of Steps"
+                 )
+                 img_strength = gr.Slider(
+                     elem_id="img_strength",
+                     minimum=0.1,
+                     maximum=1.0,
+                     step=0.05,
+                     value=Config.IMG_STRENGTH,
+                     label="Image Strength (Img2Img)"
+                 )
+                 depth_strength = gr.Slider(
+                     elem_id="depth_strength",
+                     minimum=0.0,
+                     maximum=1.0,
+                     step=0.05,
+                     value=Config.DEPTH_STRENGTH,
+                     label="DepthMap Strength"
+                 )
+                 edge_strength = gr.Slider(
+                     elem_id="edge_strength",
+                     minimum=0.0,
+                     maximum=1.0,
+                     step=0.05,
+                     value=Config.EDGE_STRENGTH,
+                     label="EdgeMap Strength (LineArt)"
+                 )
+
+             run_btn = gr.Button("Generate", variant="primary")
+
+         with gr.Column(scale=1):
+             output_img = gr.Image(label="Styled Result")
+
+     # Event Handler
+     all_inputs = [
+         input_img,
+         prompt,
+         negative_prompt,
+         cfg_scale,  # <-- RE-ENABLED
+         steps,
+         img_strength,
+         depth_strength,
+         edge_strength,
+         seed
+     ]
+
+     run_btn.click(
+         fn=process_img,
+         inputs=all_inputs,
+         outputs=[output_img]
+     )
+
+
+ # 4. Launch the App
+ if __name__ == "__main__":
+     demo.queue(max_size=20, api_open=True)
+     demo.launch(
+         server_name="0.0.0.0",
+         server_port=7860,
+         show_api=True
+     )
config.py ADDED
@@ -0,0 +1,47 @@
+ import torch
+
+ class Config:
+     # Hardware
+     DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+     DTYPE = torch.float16 if DEVICE == "cuda" else torch.float32
+
+     # --- UPDATED: New Base Model & Style LoRA ---
+     # Assuming these are in the 'primerz/pixagram' repo or a new one.
+     # If they are in a different repo, change REPO_ID.
+     REPO_ID = "primerz/pixagram"
+     CHECKPOINT_FILENAME = "reality.safetensors"
+     LORA_FILENAME = "retroart.safetensors"
+     LORA_STRENGTH = 1.25  # TCD works well with 1.0
+
+     # Trigger Words for the LoRA
+     STYLE_TRIGGER = "p1x3l4rt, pixel art"
+
+     # Default Negative Prompt (Updated for general use)
+     DEFAULT_NEGATIVE_PROMPT = "Ugly, artifacts, blurry, deformed, photo-realistic, photo, photography, realistic, low-quality, pixelart, text."
+     # --- END UPDATED ---
+
+     # InstantID Assets
+     INSTANTID_REPO = "InstantX/InstantID"
+
+     # ControlNet Repos
+     CN_ZOE_REPO = "diffusers/controlnet-zoe-depth-sdxl-1.0"
+     CN_LINEART_REPO = "ShermanG/ControlNet-Standard-Lineart-for-SDXL"
+
+     # Preprocessor (Annotator) Repo
+     ANNOTATOR_REPO = "lllyasviel/Annotators"
+
+     # Captioning Model
+     CAPTIONER_REPO = "Salesforce/blip-image-captioning-base"
+
+     # InsightFace Model (HF Hub mirror)
+     ANTELOPEV2_REPO = "DIAMONIK7777/antelopev2"
+     ANTELOPEV2_ROOT = "."  # Parent folder
+     ANTELOPEV2_NAME = "antelopev2"
+
+     # Gradio Parameters
+     # --- FIX: Style LoRA needs non-zero CFG to activate. ---
+     CFG_SCALE = 4.0  # Was 0.0. This activates the prompt trigger.
+     STEPS_NUMBER = 4
+     IMG_STRENGTH = 0.8
+     DEPTH_STRENGTH = 0.8
+     EDGE_STRENGTH = 0.8
generator.py ADDED
@@ -0,0 +1,156 @@
+ import torch
+ from config import Config
+ from utils import get_caption, draw_kps  # Removed resize_image_to_1mp
+ from PIL import Image
+
+ class Generator:
+     def __init__(self, model_handler):
+         self.mh = model_handler
+
+     def smart_crop_and_resize(self, image):
+         """
+         Analyzes the aspect ratio and snaps to the best SDXL resolution bucket.
+         Performs a center crop to match the target ratio, then resizes.
+         """
+         w, h = image.size
+         aspect_ratio = w / h
+
+         # 1. Determine Target Resolution (SDXL Buckets)
+         if 0.85 <= aspect_ratio <= 1.15:
+             target_w, target_h = 1024, 1024
+             print("Snap to Bucket: Square (1024x1024)")
+         elif aspect_ratio < 0.85:
+             if aspect_ratio < 0.72:
+                 target_w, target_h = 832, 1216  # Tall Portrait
+                 print("Snap to Bucket: Tall Portrait (832x1216)")
+             else:
+                 target_w, target_h = 896, 1152  # Standard Portrait
+                 print("Snap to Bucket: Portrait (896x1152)")
+         else:  # aspect_ratio > 1.15
+             if aspect_ratio > 1.35:
+                 target_w, target_h = 1216, 832  # Wide Landscape
+                 print("Snap to Bucket: Wide Landscape (1216x832)")
+             else:
+                 target_w, target_h = 1152, 896  # Standard Landscape
+                 print("Snap to Bucket: Landscape (1152x896)")
+
+         # 2. Center Crop to Target Aspect Ratio
+         target_ar = target_w / target_h
+
+         if aspect_ratio > target_ar:
+             new_w = int(h * target_ar)
+             offset = (w - new_w) // 2
+             crop_box = (offset, 0, offset + new_w, h)
+         else:
+             new_h = int(w / target_ar)
+             offset = (h - new_h) // 2
+             crop_box = (0, offset, w, offset + new_h)
+
+         cropped_img = image.crop(crop_box)
+
+         # 3. Resize to Exact Target Resolution
+         final_img = cropped_img.resize((target_w, target_h), Image.LANCZOS)
+         return final_img
+
+     def prepare_control_images(self, image, width, height):
+         """
+         Generates conditioning maps, ensuring they are resized
+         to the exact target dimensions (width, height).
+         """
+         print(f"Generating control maps for {width}x{height}...")
+         depth_map_raw = self.mh.leres_detector(image)
+         lineart_map_raw = self.mh.lineart_anime_detector(image)
+         depth_map = depth_map_raw.resize((width, height), Image.LANCZOS)
+         lineart_map = lineart_map_raw.resize((width, height), Image.LANCZOS)
+         return depth_map, lineart_map
+
+     def predict(
+         self,
+         input_image,
+         user_prompt="",
+         negative_prompt="",
+         # --- TCD Optimized Defaults ---
+         guidance_scale=4.0,  # <-- FIX: Set to non-zero default
+         num_inference_steps=8,
+         img2img_strength=0.9,
+         # ----------------------------
+         depth_strength=0.3,
+         lineart_strength=0.3,
+         seed=-1
+     ):
+         # 1. Pre-process Inputs (Using Smart Crop)
+         print("Processing Input...")
+         processed_image = self.smart_crop_and_resize(input_image)
+         target_width, target_height = processed_image.size
+
+         # 2. Get Face Info
+         face_info = self.mh.get_face_info(processed_image)
+
+         # 3. Generate Prompt
+         if not user_prompt.strip():
+             try:
+                 generated_caption = get_caption(processed_image)
+                 final_prompt = f"{Config.STYLE_TRIGGER}, {generated_caption}"
+             except Exception as e:
+                 print(f"Captioning failed: {e}, using default prompt.")
+                 final_prompt = f"{Config.STYLE_TRIGGER}, a beautiful image"
+         else:
+             final_prompt = f"{Config.STYLE_TRIGGER}, {user_prompt}"
+
+         print(f"Prompt: {final_prompt}")
+         print(f"Negative Prompt: {negative_prompt}")
+
+         # 4. Generate Control Maps
+         print("Generating Control Maps (Depth, LineArt)...")
+         depth_map, lineart_map = self.prepare_control_images(processed_image, target_width, target_height)
+
+         # 5. Logic for Face vs No-Face
+         if face_info is not None:
+             print("Face detected: Applying InstantID with keypoints.")
+             face_emb = torch.tensor(
+                 face_info['embedding'],
+                 dtype=Config.DTYPE,
+                 device=Config.DEVICE
+             ).unsqueeze(0)
+             face_kps = draw_kps(processed_image, face_info['kps'])
+             controlnet_conditioning_scale = [0.8, depth_strength, lineart_strength]
+             self.mh.pipeline.set_ip_adapter_scale(0.8)
+         else:
+             print("No face detected: Disabling InstantID.")
+             face_emb = torch.zeros((1, 512), dtype=Config.DTYPE, device=Config.DEVICE)
+             face_kps = Image.new('RGB', (target_width, target_height), (0, 0, 0))
+             controlnet_conditioning_scale = [0.0, depth_strength, lineart_strength]
+             self.mh.pipeline.set_ip_adapter_scale(0.0)
+
+         control_guidance_end = [0.3, 0.6, 0.6]
+
+         if seed == -1 or seed is None:
+             seed = torch.Generator().seed()
+         generator = torch.Generator(device=Config.DEVICE).manual_seed(int(seed))
+         print(f"Using seed: {seed}")
+
+         # 6. Run Inference
+         print("Running pipeline...")
+         result = self.mh.pipeline(
+             prompt=final_prompt,
+             negative_prompt=negative_prompt,
+             image=processed_image,
+             control_image=[face_kps, depth_map, lineart_map],
+             image_embeds=face_emb,
+             generator=generator,
+
+             strength=img2img_strength,
+             guidance_scale=guidance_scale,  # <-- Will use non-zero value
+             num_inference_steps=num_inference_steps,
+
+             controlnet_conditioning_scale=controlnet_conditioning_scale,
+             control_guidance_end=control_guidance_end,
+             clip_skip=0,
+
+             # --- TCD Specific Parameter ---
+             eta=0.45,  # Gamma/Stochasticity
+             # ------------------------------
+
+         ).images[0]
+
+         return result
model.py ADDED
@@ -0,0 +1,166 @@
+ import torch
+ import os
+ import cv2
+ import numpy as np
+ from config import Config
+
+ from diffusers import (
+     ControlNetModel,
+     TCDScheduler,
+ )
+ from diffusers.models.controlnets.multicontrolnet import MultiControlNetModel
+
+ # Import the custom pipeline from the local file
+ from pipeline_stable_diffusion_xl_instantid_img2img import StableDiffusionXLInstantIDImg2ImgPipeline
+
+ from huggingface_hub import snapshot_download, hf_hub_download
+ from insightface.app import FaceAnalysis
+ from controlnet_aux import LeresDetector, LineartAnimeDetector
+
+ class ModelHandler:
+     def __init__(self):
+         self.pipeline = None
+         self.app = None  # InsightFace
+         self.leres_detector = None
+         self.lineart_anime_detector = None
+         self.face_analysis_loaded = False
+
+     def load_face_analysis(self):
+         """
+         Load the face analysis model.
+         Downloads it from the HF Hub to the path insightface expects.
+         """
+         print("Loading face analysis model...")
+
+         model_path = os.path.join(Config.ANTELOPEV2_ROOT, "models", Config.ANTELOPEV2_NAME)
+
+         if not os.path.exists(os.path.join(model_path, "scrfd_10g_bnkps.onnx")):
+             print(f"Downloading AntelopeV2 models from {Config.ANTELOPEV2_REPO} to {model_path}...")
+             try:
+                 snapshot_download(
+                     repo_id=Config.ANTELOPEV2_REPO,
+                     local_dir=model_path,  # Download to the expected path
+                 )
+             except Exception as e:
+                 print(f" [ERROR] Failed to download AntelopeV2 models: {e}")
+                 return False
+
+         try:
+             self.app = FaceAnalysis(
+                 name=Config.ANTELOPEV2_NAME,
+                 root=Config.ANTELOPEV2_ROOT,
+                 providers=['CUDAExecutionProvider', 'CPUExecutionProvider']
+             )
+             self.app.prepare(ctx_id=0, det_size=(640, 640))
+             print(" [OK] Face analysis model loaded successfully.")
+             return True
+
+         except Exception as e:
+             print(f" [WARNING] Face detection system failed to initialize: {e}")
+             return False
+
+     def load_models(self):
+         # 1. Load Face Analysis
+         self.face_analysis_loaded = self.load_face_analysis()
+
+         # 2. Load ControlNets
+         print("Loading ControlNets (InstantID, Zoe, LineArt)...")
+         cn_instantid = ControlNetModel.from_pretrained(
+             Config.INSTANTID_REPO,
+             subfolder="ControlNetModel",
+             torch_dtype=Config.DTYPE
+         )
+         cn_zoe = ControlNetModel.from_pretrained(Config.CN_ZOE_REPO, torch_dtype=Config.DTYPE)
+         cn_lineart = ControlNetModel.from_pretrained(Config.CN_LINEART_REPO, torch_dtype=Config.DTYPE)
+
+         print("Wrapping ControlNets in MultiControlNetModel...")
+         controlnet_list = [cn_instantid, cn_zoe, cn_lineart]
+         controlnet = MultiControlNetModel(controlnet_list)
+
+         # 3. Load SDXL Pipeline (from 'reality.safetensors')
+         print(f"Loading SDXL Pipeline ({Config.CHECKPOINT_FILENAME})...")
+
+         checkpoint_local_path = os.path.join("./models", Config.CHECKPOINT_FILENAME)
+         if not os.path.exists(checkpoint_local_path):
+             print(f"Downloading checkpoint to {checkpoint_local_path}...")
+             hf_hub_download(
+                 repo_id=Config.REPO_ID,
+                 filename=Config.CHECKPOINT_FILENAME,
+                 local_dir="./models",
+                 local_dir_use_symlinks=False
+             )
+
+         print(f"Loading pipeline from local file: {checkpoint_local_path}")
+         self.pipeline = StableDiffusionXLInstantIDImg2ImgPipeline.from_single_file(
+             checkpoint_local_path,
+             controlnet=controlnet,
+             torch_dtype=Config.DTYPE,
+             use_safetensors=True
+         )
+
+         self.pipeline.to(Config.DEVICE)
+
+         try:
+             self.pipeline.enable_xformers_memory_efficient_attention()
+             print(" [OK] xFormers memory efficient attention enabled.")
+         except Exception as e:
+             print(f" [WARNING] Failed to enable xFormers: {e}")
+
+         # 4. Set TCD Scheduler
+         print("Configuring TCDScheduler...")
+         self.pipeline.scheduler = TCDScheduler.from_config(self.pipeline.scheduler.config)
+         print(" [OK] TCDScheduler configured from the existing scheduler config.")
+
+         # 5. Load Adapters
+         print("Loading Adapters...")
+
+         # 5a. Load and Fuse Style LoRA
+         print(f"Loading and Fusing Style LoRA ({Config.LORA_FILENAME})...")
+         style_lora_path = os.path.join("./models", Config.LORA_FILENAME)
+         if not os.path.exists(style_lora_path):
+             hf_hub_download(
+                 repo_id=Config.REPO_ID,
+                 filename=Config.LORA_FILENAME,
+                 local_dir="./models",
+                 local_dir_use_symlinks=False
+             )
+         self.pipeline.load_lora_weights("./models", weight_name=Config.LORA_FILENAME)
+         self.pipeline.fuse_lora(lora_scale=Config.LORA_STRENGTH)
+         print(" [OK] Style LoRA fused.")
+
+         # 5b. Load IP-Adapter (for InstantID) - must be loaded AFTER fusing
+         ip_adapter_filename = "ip-adapter.bin"
+         ip_adapter_local_path = os.path.join("./models", ip_adapter_filename)
+         if not os.path.exists(ip_adapter_local_path):
+             hf_hub_download(
+                 repo_id=Config.INSTANTID_REPO,
+                 filename=ip_adapter_filename,
+                 local_dir="./models",
+                 local_dir_use_symlinks=False
+             )
+         self.pipeline.load_ip_adapter_instantid(ip_adapter_local_path)
+         print(" [OK] IP-Adapter loaded.")
+
+         # 6. Load Preprocessors
+         print("Loading Preprocessors (LeReS, LineArtAnime)...")
+         self.leres_detector = LeresDetector.from_pretrained(Config.ANNOTATOR_REPO)
+         self.lineart_anime_detector = LineartAnimeDetector.from_pretrained(Config.ANNOTATOR_REPO)
+
+         print("--- All models loaded successfully ---")
+
+     def get_face_info(self, image):
+         """Extracts the largest face and returns the insightface result object."""
+         if not self.face_analysis_loaded:
+             return None
+         try:
+             cv2_img = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
+             faces = self.app.get(cv2_img)
+             if len(faces) == 0:
+                 return None
+             faces = sorted(faces, key=lambda x: (x['bbox'][2] - x['bbox'][0]) * (x['bbox'][3] - x['bbox'][1]), reverse=True)
+             return faces[0]
+         except Exception as e:
+             print(f"Face embedding extraction failed: {e}")
+             return None
requirements.txt ADDED
@@ -0,0 +1,14 @@
+ diffusers>=0.27.0
+ transformers
+ accelerate
+ peft
+ torch
+ opencv-python-headless
+ Pillow
+ insightface
+ onnxruntime
+ gradio>=4.0.0
+ controlnet_aux
+ huggingface_hub
+ mediapipe
+ timm
utils.py ADDED
@@ -0,0 +1,77 @@
+ from PIL import Image
+ from transformers import BlipProcessor, BlipForConditionalGeneration
+ import torch
+ from config import Config
+ import cv2
+ import numpy as np
+ import math
+
+ # Simple global caching for the captioner
+ captioner_processor = None
+ captioner_model = None
+
+ def resize_image_to_1mp(image):
+     """Resizes an image to approx. 1MP (e.g., 1024x1024), preserving aspect ratio."""
+     image = image.convert("RGB")
+     w, h = image.size
+     target_pixels = 1024 * 1024
+     aspect_ratio = w / h
+
+     # Calculate new dimensions
+     new_h = int((target_pixels / aspect_ratio) ** 0.5)
+     new_w = int(new_h * aspect_ratio)
+
+     # Ensure divisibility by 48 for efficiency
+     new_w = (new_w // 48) * 48
+     new_h = (new_h // 48) * 48
+
+     if new_w == 0 or new_h == 0:
+         new_w, new_h = 1024, 1024  # Fallback
+
+     return image.resize((new_w, new_h), Image.LANCZOS)
+
+ def get_caption(image):
+     """Generates a caption for the image if one isn't provided."""
+     global captioner_processor, captioner_model
+
+     if captioner_model is None:
+         print("Loading Captioner (BLIP)...")
+         captioner_processor = BlipProcessor.from_pretrained(Config.CAPTIONER_REPO)
+         captioner_model = BlipForConditionalGeneration.from_pretrained(Config.CAPTIONER_REPO).to(Config.DEVICE)
+
+     inputs = captioner_processor(image, return_tensors="pt").to(Config.DEVICE)
+     out = captioner_model.generate(**inputs)
+     caption = captioner_processor.decode(out[0], skip_special_tokens=True)
+     return caption
+
+ # --- ADDED: Keypoint visualization helper (draw_kps from InstantID) ---
+ def draw_kps(image_pil, kps, color_list=[(255, 0, 0), (0, 255, 0), (0, 0, 255), (255, 255, 0), (255, 0, 255)]):
+     stickwidth = 4
+     limbSeq = np.array([[0, 2], [1, 2], [3, 2], [4, 2]])
+     kps = np.array(kps)
+
+     w, h = image_pil.size
+     out_img = np.zeros([h, w, 3])
+
+     for i in range(len(limbSeq)):
+         index = limbSeq[i]
+         color = color_list[index[0]]
+
+         x = kps[index][:, 0]
+         y = kps[index][:, 1]
+         length = ((x[0] - x[1]) ** 2 + (y[0] - y[1]) ** 2) ** 0.5
+         angle = math.degrees(math.atan2(y[0] - y[1], x[0] - x[1]))
+         polygon = cv2.ellipse2Poly(
+             (int(np.mean(x)), int(np.mean(y))), (int(length / 2), stickwidth), int(angle), 0, 360, 1
+         )
+         out_img = cv2.fillConvexPoly(out_img.copy(), polygon, color)
+     out_img = (out_img * 0.6).astype(np.uint8)
+
+     for idx_kp, kp in enumerate(kps):
+         color = color_list[idx_kp]
+         x, y = kp
+         out_img = cv2.circle(out_img.copy(), (int(x), int(y)), 10, color, -1)
+
+     out_img_pil = Image.fromarray(out_img.astype(np.uint8))
+     return out_img_pil
+ # --- END ADDED ---