Virtual-Store / main_code_script.py
burman-ai's picture
Upload 2 files
e7bcb12 verified
# Install the necessary libraries (list them in your requirements.txt):
# pillow opencv-python numpy mediapipe transformers diffusers accelerate torch
# Example install command: pip install pillow opencv-python numpy mediapipe transformers diffusers accelerate torch
from PIL import Image
import cv2
import mediapipe as mp
import numpy as np
from transformers import pipeline
from diffusers import StableDiffusionInpaintPipeline
import torch
# --- 1. Pose Estimation (using Mediapipe) ---
def estimate_pose(image_path):
    """Run Mediapipe pose estimation on the image stored at ``image_path``.

    Args:
        image_path: Path to the input image; it is decoded with ``cv2.imread``.

    Returns:
        A ``(results, image)`` tuple. ``results`` is the object returned by
        ``Pose.process`` (landmarks are in ``results.pose_landmarks``;
        segmentation output is requested via ``enable_segmentation=True``).
        ``image`` is the array returned by ``cv2.imread``.
        Returns ``(None, None)`` when no pose landmarks are detected.

    Raises:
        ValueError: If the image cannot be read from ``image_path``.
    """
    # cv2.imread reports failure by returning None rather than raising, so
    # check here instead of letting cvtColor fail with an opaque error.
    image = cv2.imread(image_path)
    if image is None:
        raise ValueError(f"Could not read image: {image_path!r}")

    mp_pose = mp.solutions.pose
    with mp_pose.Pose(
        static_image_mode=True,
        model_complexity=2,
        enable_segmentation=True,
        min_detection_confidence=0.5,
    ) as pose:
        # The frame is converted from BGR to RGB before being processed.
        results = pose.process(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))

    if not results.pose_landmarks:
        return None, None  # No person found in the image.
    return results, image
# --- 2. Clothing Segmentation (Example - using a placeholder function) ---
def segment_clothing(image, results, threshold=0.5):
    """Build a binary mask image from a pose result's segmentation mask.

    This is a simplified placeholder: the mask marks where the person is
    present, not specifically the clothing. A real implementation would use
    a pre-trained clothing segmentation model.

    Args:
        image: Unused; kept so existing callers keep working.
        results: Object exposing a ``segmentation_mask`` array (as returned
            by ``estimate_pose``).
        threshold: Mask values strictly greater than this become 255 in the
            output; all other pixels become 0.

    Returns:
        A PIL image in mode ``"L"`` holding the thresholded mask.

    Raises:
        ValueError: If ``results.segmentation_mask`` is ``None``.
    """
    segmentation_mask = results.segmentation_mask
    if segmentation_mask is None:
        # Without this check the comparison below fails with a confusing
        # TypeError about comparing None with a float.
        raise ValueError(
            "results.segmentation_mask is None; segmentation output is required"
        )

    # Threshold the soft mask into a 0/255 uint8 array.
    binary_mask = (segmentation_mask > threshold).astype(np.uint8) * 255
    return Image.fromarray(binary_mask).convert("L")
# --- 3. Image Inpainting (Replacing Clothing - using Stable Diffusion Inpainting) ---
def inpaint_clothing(image, mask_img, garment_image_path, device="cuda" if torch.cuda.is_available() else "cpu"):
    """Repaint the masked region of ``image`` with Stable Diffusion inpainting.

    NOTE(review): the garment image is only opened to validate the path; it
    does not condition the generation yet. The output is driven solely by the
    fixed text prompt below. Using the garment as a real reference needs a
    different conditioning technique.

    Args:
        image: PIL image of the person.
        mask_img: PIL mask image marking the region to repaint.
        garment_image_path: Path to the garment image (validated only).
        device: Device the pipeline is moved to. The default is evaluated
            once, when the function is defined.

    Returns:
        The first image produced by the pipeline. Both inputs are resized to
        512x512 before the call, without preserving aspect ratio.
    """
    # Fail fast on a missing or undecodable garment file before the model is
    # loaded.
    with Image.open(garment_image_path) as garment_file:
        garment_file.convert("RGB")

    # Half precision is only used on CUDA; other devices get float32 because
    # float16 inference is generally unsupported or very slow on CPU.
    dtype = torch.float16 if str(device).startswith("cuda") else torch.float32
    pipe = StableDiffusionInpaintPipeline.from_pretrained(
        "stabilityai/stable-diffusion-2-inpainting",
        torch_dtype=dtype,
    )
    pipe = pipe.to(device)

    # Image and mask must have the same size for inpainting.
    image = image.resize((512, 512))
    mask_img = mask_img.resize((512, 512))

    prompt = "A photo of a person wearing the uploaded garment"
    return pipe(prompt=prompt, image=image, mask_image=mask_img).images[0]
# --- 4. Main Function (Putting it all together) ---
def change_clothing(image_path, garment_image_path):
    """Run the full pipeline: pose estimation, masking, then inpainting.

    Args:
        image_path: Path to the photo of the person.
        garment_image_path: Path to the garment image, forwarded to
            ``inpaint_clothing``.

    Returns:
        The image returned by ``inpaint_clothing``, or ``None`` (after
        printing a message) when no pose is detected.
    """
    # Open the file with PIL first so that a missing or unidentifiable image
    # still fails up front with PIL's own exception. Image.open is lazy, so
    # this does not decode the pixels.
    with Image.open(image_path):
        pass

    results, cv2_image = estimate_pose(image_path)
    if results is None:
        print("No pose detected.")
        return None

    # Build the PIL image from the exact frame that pose estimation decoded.
    # OpenCV applies EXIF orientation when reading while Image.open does not,
    # so decoding the file twice can leave the image and the mask misaligned.
    image = Image.fromarray(cv2.cvtColor(cv2_image, cv2.COLOR_BGR2RGB))

    mask_img = segment_clothing(image, results)
    return inpaint_clothing(image, mask_img, garment_image_path)
# --- Example Usage ---
if __name__ == "__main__":
    # Example usage: replace both paths with your own files.
    input_image_path = "person.jpg"  # Photo of the person
    garment_image_path = "garment.jpg"  # Garment image
    modified_image = change_clothing(input_image_path, garment_image_path)
    # change_clothing signals failure with None, so test for that sentinel
    # explicitly rather than relying on the image object's truthiness.
    if modified_image is not None:
        modified_image.save("modified_image.jpg")
        print("Clothing changed and saved to modified_image.jpg")
    else:
        print("Failed to change clothing.")