# itda-multimodal-segmentation / db_multimodal_create.py
# leedoming's picture
# Create db_multimodal_create.py
# 466ea14 verified
import hashlib
import io
import json
import logging
import os
import time
import uuid
from concurrent.futures import ThreadPoolExecutor
from io import BytesIO

import chromadb
import numpy as np
import open_clip
import requests
import torch
from chromadb.utils.data_loaders import ImageLoader
from chromadb.utils.embedding_functions import OpenCLIPEmbeddingFunction
from PIL import Image
from tqdm import tqdm
from transformers import pipeline
# Logging setup: every record is written both to a log file and to the console.
logging.basicConfig(
    handlers=[
        logging.FileHandler('fashion_db_creation.log'),
        logging.StreamHandler(),
    ],
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
# Module-level logger shared by every function in this file.
logger = logging.getLogger(__name__)
def load_models():
    """Load the embedding and segmentation models used by the pipeline.

    Returns:
        A 5-tuple ``(model, preprocess_val, segmenter, device, resize_transform)``:
        the Marqo fashionSigLIP model (moved to ``device`` and switched to
        inference mode), its validation preprocessing transform, the
        clothes-segmentation pipeline, the selected torch device, and a
        224x224 resize + to-tensor transform.

    Raises:
        Exception: Any error raised while loading is logged and re-raised.
    """
    try:
        logger.info("Loading models...")

        # Image embedding model
        model, _, preprocess_val = open_clip.create_model_and_transforms(
            'hf-hub:Marqo/marqo-fashionSigLIP'
        )

        # Segmentation model
        segmenter = pipeline(model="mattmdjaga/segformer_b2_clothes")

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model.to(device)
        # open_clip hands the model back in training mode; switch to inference
        # mode so dropout / stochastic depth / norm statistics cannot perturb
        # the embeddings that get stored in the database.
        model.eval()

        # Extra transform that resizes to the 224x224 model input size.
        # It is returned for interface compatibility; no function visible in
        # this file applies it.
        from torchvision import transforms
        resize_transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
        ])

        return model, preprocess_val, segmenter, device, resize_transform
    except Exception as e:
        logger.error(f"Error loading models: {e}")
        raise
def process_segmentation(image, segmenter):
    """Return a binary mask for the largest foreground segment of ``image``.

    Args:
        image: Image passed straight to ``segmenter``.
        segmenter: Callable returning a list of dicts that each hold a
            ``'mask'`` (PIL image or array) and, optionally, a ``'label'``.

    Returns:
        A 2-D float ``numpy`` array whose values are exactly 0.0 or 1.0, or
        ``None`` when no segment is produced or segmentation fails.
    """
    try:
        output = segmenter(image)
        if not output:
            logger.warning("No segments found in image")
            return None

        # The background usually covers most of the picture, so picking the
        # largest segment blindly would select it instead of the garment.
        # Fall back to every segment only when nothing but background exists.
        candidates = [
            seg for seg in output
            if str(seg.get('label', '')).lower() != 'background'
        ] or list(output)

        masks = [np.asarray(seg['mask']) for seg in candidates]
        areas = [np.count_nonzero(m) for m in masks]
        mask = masks[int(np.argmax(areas))]

        if mask.ndim > 2:
            mask = mask[:, :, 0]

        # Segmentation pipelines return masks on a 0/255 scale, while the
        # callers in this file blend with ``mask`` and ``1 - mask`` and scale
        # by 255 for debugging, i.e. they expect 0/1. Binarise here.
        mask = (mask > 0).astype(float)

        logger.info(f"Successfully created mask with shape {mask.shape}")
        return mask
    except Exception as e:
        logger.error(f"Segmentation error: {str(e)}")
        return None
def load_image_from_url(url, max_retries=3):
    """Download an image and return it as an RGB PIL image.

    Args:
        url: Address of the image to fetch.
        max_retries: Maximum number of download attempts.

    Returns:
        The decoded image converted to RGB, or ``None`` when every attempt
        failed (HTTP error, timeout, undecodable payload, ...).
    """
    for attempt in range(max_retries):
        try:
            response = requests.get(url, timeout=10)
            response.raise_for_status()
            return Image.open(BytesIO(response.content)).convert('RGB')
        except Exception as e:
            logger.warning(f"Attempt {attempt + 1} failed: {str(e)}")
            if attempt < max_retries - 1:
                # Short pause before retrying. This needs the ``time`` module:
                # without that import the sleep itself raised NameError from
                # inside this handler and no retry ever took place.
                time.sleep(1)
            else:
                logger.error(f"Failed to load image from {url} after {max_retries} attempts")
    return None
def extract_features(image, mask, model, preprocess_val, device):
    """Embed an image using three mask-guided views and blend the results.

    The three views are the original image, the image with everything outside
    the mask painted white, and a square crop around the masked region. Their
    embeddings are combined with weights 0.2 / 0.3 / 0.5 and L2-normalised.

    Args:
        image: RGB PIL image.
        mask: Mask with the same height and width as ``image``. Values may be
            on a 0-1 or a 0-255 scale; a trailing channel axis is tolerated.
        model: Model exposing ``encode_image``.
        preprocess_val: Transform turning a PIL image into a model input tensor.
        device: Torch device the input tensors are moved to.

    Returns:
        A flat ``numpy`` array holding the combined embedding, or ``None`` if
        any step fails (the error is logged).
    """
    try:
        img_array = np.array(image)

        # Normalise the mask to a 2-D array in [0, 1]. A 0/255 mask would
        # otherwise push the blends below far outside the uint8 range and
        # corrupt the masked image.
        mask_2d = np.asarray(mask, dtype=float)
        if mask_2d.ndim > 2:
            mask_2d = mask_2d[:, :, 0]
        if mask_2d.max() > 1:
            mask_2d = mask_2d / 255.0
        mask_3channel = np.repeat(mask_2d[:, :, np.newaxis], 3, axis=2)

        # 1. Features from the original image
        image_tensor_original = preprocess_val(image).unsqueeze(0).to(device)

        # 2. Features from the masked image (white background)
        masked_img_white = img_array * mask_3channel + (1 - mask_3channel) * 255
        image_masked_white = Image.fromarray(masked_img_white.astype(np.uint8))
        image_tensor_masked = preprocess_val(image_masked_white).unsqueeze(0).to(device)

        # 3. Features from a crop around the masked region. The bounding box
        # is computed from the 2-D mask rather than an (H, W, 1) array.
        bbox = get_bbox_from_mask(mask_2d)
        cropped_img = crop_and_resize(img_array * mask_3channel, bbox)
        image_cropped = Image.fromarray(cropped_img.astype(np.uint8))
        image_tensor_cropped = preprocess_val(image_cropped).unsqueeze(0).to(device)

        with torch.no_grad():
            # Encode the three views
            features_original = model.encode_image(image_tensor_original)
            features_masked = model.encode_image(image_tensor_masked)
            features_cropped = model.encode_image(image_tensor_cropped)

            # Weighted combination; the crop dominates
            combined_features = (
                0.2 * features_original +
                0.3 * features_masked +
                0.5 * features_cropped
            )

            # L2 normalisation
            combined_features /= combined_features.norm(dim=-1, keepdim=True)

        return combined_features.cpu().numpy().flatten()
    except Exception as e:
        logger.error(f"Feature extraction error: {e}")
        return None
def get_bbox_from_mask(mask, padding=10):
    """Compute a padded bounding box around the non-zero region of a mask.

    Args:
        mask: 2-D array, or an array with extra trailing axes such as
            (H, W, 1); trailing axes are collapsed with a logical OR.
        padding: Number of pixels added on every side, clipped to the mask
            bounds.

    Returns:
        ``(rmin, rmax, cmin, cmax)`` as Python ints. ``rmax`` and ``cmax`` are
        exclusive, so ``array[rmin:rmax, cmin:cmax]`` contains the whole
        region plus padding.

    Raises:
        ValueError: If the mask has no non-zero element.
    """
    mask = np.asarray(mask)
    if mask.ndim > 2:
        mask = mask.any(axis=tuple(range(2, mask.ndim)))

    row_idx = np.flatnonzero(mask.any(axis=1))
    col_idx = np.flatnonzero(mask.any(axis=0))
    if row_idx.size == 0:
        raise ValueError("Cannot compute a bounding box from an empty mask")

    height, width = mask.shape
    # Add some margin around the region. The "+ 1" converts the inclusive
    # last index into an exclusive slice end; without it the crop would drop
    # the final row/column and the padding would be asymmetric.
    rmin = max(int(row_idx[0]) - padding, 0)
    rmax = min(int(row_idx[-1]) + 1 + padding, height)
    cmin = max(int(col_idx[0]) - padding, 0)
    cmax = min(int(col_idx[-1]) + 1 + padding, width)
    return rmin, rmax, cmin, cmax
def crop_and_resize(image, bbox):
    """Crop ``image`` to ``bbox`` and centre the crop on a white square canvas.

    Despite the name, no rescaling happens here: the crop keeps its pixel
    size and is only padded to a square.

    Args:
        image: Array indexed as ``image[row, col]`` with three channels.
        bbox: ``(rmin, rmax, cmin, cmax)`` used as slice bounds.

    Returns:
        A ``uint8`` array of shape ``(side, side, 3)`` where ``side`` is the
        longer edge of the crop and the padding is filled with 255.
    """
    top, bottom, left, right = bbox
    patch = image[top:bottom, left:right]
    patch_h, patch_w = patch.shape[:2]

    # White square canvas as large as the longer edge of the crop
    side = max(patch_h, patch_w)
    canvas = np.full((side, side, 3), 255, dtype=np.uint8)

    # Paste the crop in the middle of the canvas
    offset_y = (side - patch_h) // 2
    offset_x = (side - patch_w) // 2
    canvas[offset_y:offset_y + patch_h, offset_x:offset_x + patch_w] = patch
    return canvas
def process_item(item, model, preprocess_val, segmenter, device, resize_transform):
    """Process a single item from the JSON data into a database record.

    Args:
        item: Product dict; the image address is read from the
            ``'이미지 링크'`` key or, failing that, ``'이미지 URL'``.
        model: Embedding model forwarded to ``extract_features``.
        preprocess_val: Preprocessing transform forwarded to ``extract_features``.
        segmenter: Segmentation callable forwarded to ``process_segmentation``.
        device: Torch device forwarded to ``extract_features``.
        resize_transform: Accepted for interface compatibility; not used here.

    Returns:
        A dict with ``'id'``, ``'embedding'``, ``'metadata'`` and
        ``'image_uri'``, or ``None`` when any step fails (the reason is
        logged).
    """
    try:
        # Locate the image URL
        if '이미지 링크' in item:
            image_url = item['이미지 링크']
        elif '이미지 URL' in item:
            image_url = item['이미지 URL']
        else:
            logger.warning("No image URL found in item")
            return None

        # Build the metadata
        metadata = create_metadata(item)

        # Download the image
        image = load_image_from_url(image_url)
        if image is None:
            logger.warning(f"Failed to load image from {image_url}")
            return None

        # Run segmentation
        mask = process_segmentation(image, segmenter)
        if mask is None:
            logger.warning(f"Failed to create segmentation mask for {image_url}")
            return None

        # Extract the combined embedding
        features = extract_features(image, mask, model, preprocess_val, device)
        if features is None:
            logger.error(f"Feature extraction failed for {image_url}: Feature extraction failed")
            return None

        return {
            # Chroma only accepts string ids, while product IDs read from
            # JSON may be numeric; coerce here so one numeric id cannot make
            # the whole batch insert fail.
            'id': str(metadata['product_id']),
            'embedding': features.tolist(),
            'metadata': metadata,
            'image_uri': image_url
        }
    except Exception as e:
        logger.error(f"Error processing item: {str(e)}")
        return None
# Optional helper that stores debug images
def save_debug_images(image, mask, url):
    """Write the original image and its mask into the ``debug_images`` folder.

    Args:
        image: Object with a ``save(path)`` method (a PIL image).
        mask: Array that is multiplied by 255 and cast to ``uint8`` before
            saving, i.e. presumably on a 0-1 scale.
        url: Source URL; its last path component (query string removed) is
            used as the file name.

    Any failure is logged as a warning and otherwise ignored.
    """
    debug_dir = "debug_images"
    try:
        os.makedirs(debug_dir, exist_ok=True)

        # File name = last URL path component without the query string
        filename = url.rsplit('/', 1)[-1].partition('?')[0]

        # Store the original image, then the mask
        image.save(f"{debug_dir}/original_{filename}")
        mask_pixels = (mask * 255).astype(np.uint8)
        Image.fromarray(mask_pixels).save(f"{debug_dir}/mask_{filename}")
    except Exception as e:
        logger.warning(f"Failed to save debug images: {str(e)}")
def create_metadata(item):
    """Create standardized metadata from different JSON formats.

    Args:
        item: Product dict in either the Musinsa format (has ``'상품 ID'`` /
            ``'이미지 링크'``) or the 11st/Gmarket format (``'상품명'`` /
            ``'이미지 URL'``).

    Returns:
        A dict with the keys ``product_id``, ``brand``, ``name``, ``price``,
        ``discount``, ``category``, ``image_url`` and ``source``. Missing
        fields default to ``'unknown'``.
    """
    metadata = {}

    # Product ID
    if '상품 ID' in item:  # Musinsa format
        metadata['product_id'] = item['상품 ID']
    else:
        # 11st/Gmarket: derive a unique ID from the product name and URL.
        # A SHA-256 digest is used instead of the builtin hash() because
        # str hashes are salted per interpreter process, which made the
        # same product receive a different ID on every run.
        unique_string = f"{item.get('상품명', '')}{item.get('이미지 URL', '')}"
        metadata['product_id'] = hashlib.sha256(
            unique_string.encode('utf-8')
        ).hexdigest()

    # Remaining metadata fields
    metadata['brand'] = item.get('브랜드명', 'unknown')
    metadata['name'] = item.get('제품명') or item.get('상품명', 'unknown')
    metadata['price'] = (item.get('정가') or item.get('가격') or
                         item.get('판매가', 'unknown'))
    metadata['discount'] = item.get('할인율', 'unknown')

    if '카테고리' in item:
        if isinstance(item['카테고리'], list):
            metadata['category'] = '/'.join(item['카테고리'])
        else:
            metadata['category'] = item['카테고리']
    else:
        # 11st/Gmarket: try to infer the category from the product name
        name = metadata['name'].lower()
        categories = ['원피스', '셔츠', '블라우스', '니트', '가디건',
                      '스커트', '팬츠', '셋업', '아우터', '자켓']
        found_categories = [cat for cat in categories if cat in name]
        metadata['category'] = '/'.join(found_categories) if found_categories else 'unknown'

    metadata['image_url'] = (item.get('이미지 링크') or
                             item.get('이미지 URL', 'unknown'))

    # Record which shop the item came from
    if '이미지 링크' in item:
        metadata['source'] = 'musinsa'
    elif 'cdn.011st.com' in metadata['image_url']:
        metadata['source'] = '11st'
    elif 'gmarket' in metadata['image_url']:
        metadata['source'] = 'gmarket'
    else:
        metadata['source'] = 'unknown'

    return metadata
def create_multimodal_fashion_db(json_files):
    """Build the ``fashion_multimodal`` Chroma collection from JSON files.

    Every item of every file is processed with ``process_item`` on a small
    thread pool and the resulting embeddings are inserted one file at a time.
    An existing collection with the same name is dropped first.

    Args:
        json_files: Iterable of paths to UTF-8 JSON files, each containing a
            list of product dicts.

    Returns:
        ``True`` if at least one item was stored, ``False`` otherwise or when
        an unexpected error aborts the run (the error is logged).
    """
    try:
        logger.info("Starting multimodal fashion database creation")

        # Load the models
        model, preprocess_val, segmenter, device, resize_transform = load_models()

        # ChromaDB setup
        client = chromadb.PersistentClient(path="./fashion_multimodal_db")

        # Multimodal collection.
        # NOTE(review): the stored vectors are produced by the model from
        # load_models(), not by this default embedding function; confirm that
        # queries relying on the collection's own embedding function produce
        # compatible vectors.
        embedding_function = OpenCLIPEmbeddingFunction()
        data_loader = ImageLoader()

        try:
            client.delete_collection("fashion_multimodal")
            logger.info("Deleted existing collection")
        except Exception:
            # Narrowed from a bare ``except`` so Ctrl-C / SystemExit are not
            # swallowed while still tolerating a missing collection.
            logger.info("No existing collection to delete")

        collection = client.create_collection(
            name="fashion_multimodal",
            embedding_function=embedding_function,
            data_loader=data_loader,
            metadata={"description": "Fashion multimodal collection with advanced feature extraction"}
        )

        # Processing statistics
        stats = {
            'total_processed': 0,
            'successful': 0,
            'failed': 0,
            'feature_extraction_failed': 0,
            'duplicates': 0
        }
        # IDs already stored during this run
        seen_ids = set()

        # Process the JSON files
        for json_file in json_files:
            with open(json_file, 'r', encoding='utf-8') as f:
                data = json.load(f)

            logger.info(f"Processing {len(data)} items from {json_file}")

            with ThreadPoolExecutor(max_workers=4) as executor:
                futures = [
                    executor.submit(
                        process_item,
                        item, model, preprocess_val, segmenter, device, resize_transform
                    )
                    for item in data
                ]

                processed_items = []
                for future in tqdm(futures, desc=f"Processing {json_file}"):
                    stats['total_processed'] += 1
                    result = future.result()
                    if result is not None:
                        processed_items.append(result)
                        stats['successful'] += 1
                    else:
                        stats['failed'] += 1

            # Chroma rejects an entire add() call when the id list contains a
            # duplicate, so keep only the first occurrence of every id.
            unique_items = []
            batch_ids = set()
            for entry in processed_items:
                if entry['id'] in seen_ids or entry['id'] in batch_ids:
                    continue
                batch_ids.add(entry['id'])
                unique_items.append(entry)

            skipped = len(processed_items) - len(unique_items)
            if skipped:
                logger.warning(f"Skipped {skipped} items with duplicate ids in {json_file}")
                stats['successful'] -= skipped
                stats['duplicates'] += skipped

            # Add the batch to the database
            if unique_items:
                try:
                    collection.add(
                        ids=[entry['id'] for entry in unique_items],
                        embeddings=[entry['embedding'] for entry in unique_items],
                        metadatas=[entry['metadata'] for entry in unique_items],
                        uris=[entry['image_uri'] for entry in unique_items]
                    )
                    # Only remember ids that were actually stored
                    seen_ids.update(batch_ids)
                except Exception as e:
                    logger.error(f"Failed to add batch to collection: {str(e)}")
                    stats['failed'] += len(unique_items)
                    stats['successful'] -= len(unique_items)

        # Final statistics
        logger.info("Processing completed:")
        logger.info(f"Total processed: {stats['total_processed']}")
        logger.info(f"Successful: {stats['successful']}")
        logger.info(f"Failed: {stats['failed']}")
        logger.info(f"Skipped duplicates: {stats['duplicates']}")

        return stats['successful'] > 0
    except Exception as e:
        logger.error(f"Database creation error: {str(e)}")
        return False
if __name__ == "__main__":
    # Input files: one Musinsa ranking dump plus the 11st and Gmarket
    # best-seller dumps.
    json_files = [
        './musinsa_ranking_images_category_0920.json',
        './11st/11st_bagaccessory_20241017_172846.json',
        './11st/11st_best_abroad_bagaccessory_20241017_173300.json',
        './11st/11st_best_abroad_fashion_20241017_173144.json',
        './11st/11st_best_abroad_luxury_20241017_173343.json',
        './11st/11st_best_men_20241017_172534.json',
        './11st/11st_best_women_20241017_172127.json',
        './gmarket/gmarket_best_accessory_20241015_155921.json',
        './gmarket/gmarket_best_bag_20241015_155811.json',
        './gmarket/gmarket_best_brand_20241015_155530.json',
        './gmarket/gmarket_best_casual_20241015_155421.json',
        './gmarket/gmarket_best_men_20241015_155025.json',
        './gmarket/gmarket_best_shoe_20241015_155613.json',
        './gmarket/gmarket_best_women_20241015_154206.json',
    ]

    # Build the database and report the outcome on stdout
    success = create_multimodal_fashion_db(json_files)
    print(
        "Successfully created multimodal fashion database!"
        if success
        else "Failed to create database. Check the logs for details."
    )