vision-web-app / app.py
David Ko · Integrate vector DB storage and search into the Gradio UI · c28eadf
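"""Gradio demo app for multi-model object detection and image classification.

Loads YOLOv8, DETR, and ViT models, exposes them through a Gradio Blocks UI,
and stores/searches detection results in a vector DB through a companion REST
API assumed to run on the same host and port (see the /api/* endpoints below).
"""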
import gradio as gr
import torch
from PIL import Image
import numpy as np
import requests
import base64
import time
import uuid
from io import BytesIO
# Model initialization
print("Loading models... This may take a moment.")
# YOLOv8 model
yolo_model = None
try:
from ultralytics import YOLO
yolo_model = YOLO("yolov8n.pt") # Using the nano model for faster inference
print("YOLOv8 model loaded successfully")
except Exception as e:
print("Error loading YOLOv8 model:", e)
yolo_model = None
# DETR model (DEtection TRansformer)
detr_processor = None
detr_model = None
try:
from transformers import DetrImageProcessor, DetrForObjectDetection
# Load the DETR image processor
# DetrImageProcessor: Handles preprocessing of images for DETR model
# - Resizes images to appropriate dimensions
# - Normalizes pixel values
# - Converts images to tensors
# - Handles batch processing
detr_processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50")
# Load the DETR object detection model
# DetrForObjectDetection: The actual object detection model
# - Uses ResNet-50 as backbone
# - Transformer-based architecture for object detection
# - Predicts bounding boxes and object classes
# - Pre-trained on COCO dataset by Facebook AI Research
detr_model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50")
print("DETR model loaded successfully")
except Exception as e:
print("Error loading DETR model:", e)
detr_processor = None
detr_model = None
# ViT model
vit_processor = None
vit_model = None
try:
from transformers import ViTImageProcessor, ViTForImageClassification
vit_processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224")
vit_model = ViTForImageClassification.from_pretrained("google/vit-base-patch16-224")
print("ViT model loaded successfully")
except Exception as e:
print("Error loading ViT model:", e)
vit_processor = None
vit_model = None
# Select the compute device and move the transformer models onto it,
# so the reported device below matches where inference actually runs
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
if detr_model is not None:
    detr_model.to(device)
if vit_model is not None:
    vit_model.to(device)
# Save detected objects to the vector DB
def save_objects_to_vector_db(image, detection_results, model_type='yolo'):
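    """Send an image and its detection results to the vector-DB API.

    Assumes a companion REST service on localhost:7860 exposing
    /api/add-detected-objects (YOLO/DETR) and /api/add-image (ViT);
    adjust the URLs if that service runs elsewhere.
    """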
    if image is None or detection_results is None:
        return "No image or detection results available."
try:
        # Encode the image as base64
buffered = BytesIO()
image.save(buffered, format="JPEG")
img_str = base64.b64encode(buffered.getvalue()).decode('utf-8')
        # Call a different API endpoint depending on the model type
if model_type in ['yolo', 'detr']:
            # Extract object information
objects = []
for obj in detection_results['objects']:
objects.append({
"class": obj['class'],
"confidence": obj['confidence'],
"bbox": obj['bbox']
})
            # Build the API request payload
data = {
"image": img_str,
"objects": objects,
"image_id": str(uuid.uuid4())
}
            # Call the API
            response = requests.post(
                "http://localhost:7860/api/add-detected-objects",
                json=data,
                timeout=10
            )
            if response.status_code == 200:
                result = response.json()
                if 'error' in result:
                    return f"Error: {result['error']}"
                return f"Saved {len(objects)} objects to the vector DB! IDs: {result.get('ids', 'unknown')}"
elif model_type == 'vit':
            # Store the ViT classification result
data = {
"image": img_str,
"metadata": {
"model": "vit",
"classifications": detection_results.get('classifications', [])
}
}
            # Call the API
            response = requests.post(
                "http://localhost:7860/api/add-image",
                json=data,
                timeout=10
            )
            if response.status_code == 200:
                result = response.json()
                if 'error' in result:
                    return f"Error: {result['error']}"
                return f"Saved the image and classification results to the vector DB! ID: {result.get('id', 'unknown')}"
        else:
            return "Unsupported model type."
        # Reaching here means the API call returned a non-200 status
        return f"API error: {response.status_code}"
except Exception as e:
return f"오류 발생: {str(e)}"
# Search the vector DB for similar objects
def search_similar_objects(image=None, class_name=None):
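    """Query the vector-DB API for objects similar to an image or a class name.

    Exactly one of `image` or `class_name` should be provided. Uses the same
    assumed companion service via /api/search-similar-objects, and returns a
    (status message, list of {'image', 'info'} dicts) pair.
    """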
try:
data = {}
if image is not None:
            # Encode the image as base64
buffered = BytesIO()
image.save(buffered, format="JPEG")
img_str = base64.b64encode(buffered.getvalue()).decode('utf-8')
data["image"] = img_str
data["n_results"] = 5
elif class_name is not None and class_name.strip():
data["class_name"] = class_name.strip()
data["n_results"] = 5
else:
return "이미지나 클래스 이름 중 하나는 제공해야 합니다.", []
# API 호출
response = requests.post(
"http://localhost:7860/api/search-similar-objects",
json=data
)
if response.status_code == 200:
results = response.json()
if isinstance(results, dict) and 'error' in results:
return f"오류 발생: {results['error']}", []
# 결과 포맷팅
formatted_results = []
for i, result in enumerate(results):
similarity = (1 - result.get('distance', 0)) * 100
img_data = result.get('image', '')
                # Convert the image data back into a PIL image
if img_data:
try:
img_bytes = base64.b64decode(img_data)
img = Image.open(BytesIO(img_bytes))
except Exception:
img = None
else:
img = None
                # Extract metadata
                metadata = result.get('metadata', {})
                class_name = metadata.get('class', 'N/A')
                raw_conf = metadata.get('confidence')
                confidence = f"{raw_conf * 100:.2f}%" if raw_conf else 'N/A'
                formatted_results.append({
                    'image': img,
                    'info': f"Result #{i+1} | Similarity: {similarity:.2f}% | Class: {class_name} | Confidence: {confidence} | ID: {result.get('id', 'N/A')}"
                })
return f"{len(formatted_results)}개의 유사 객체를 찾았습니다.", formatted_results
else:
return f"API 오류: {response.status_code}", []
except Exception as e:
return f"오류 발생: {str(e)}", []
# Define model inference functions
def process_yolo(image):
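    """Run YOLOv8 detection on an image.

    Returns (annotated image, result text, detection dict for the vector DB).
    """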
if yolo_model is None:
return None, "YOLOv8 model not loaded", None
    # Measure inference time
    start_time = time.time()
# Convert to numpy if it's a PIL image
if isinstance(image, Image.Image):
image_np = np.array(image)
else:
image_np = image
# Run inference
results = yolo_model(image_np)
# Process results
result_image = results[0].plot()
result_image = Image.fromarray(result_image)
# Get detection information
boxes = results[0].boxes
class_names = results[0].names
# Format detection results
detections = []
detection_objects = {'objects': []}
for box in boxes:
class_id = int(box.cls[0].item())
class_name = class_names[class_id]
confidence = round(box.conf[0].item(), 2)
bbox = box.xyxy[0].tolist()
bbox = [round(x) for x in bbox]
detections.append("{}: {} at {}".format(class_name, confidence, bbox))
        # Collect object info for the vector DB
detection_objects['objects'].append({
'class': class_name,
'confidence': confidence,
'bbox': bbox
})
# Calculate inference time
inference_time = time.time() - start_time
# Add inference time and device info to detection text
device_info = "GPU" if torch.cuda.is_available() else "CPU"
performance_info = f"\n\nInference time: {inference_time:.3f} seconds on {device_info}"
detection_text = "\n".join(detections) if detections else "No objects detected"
detection_text += performance_info
    return result_image, detection_text, detection_objects
def process_detr(image):
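    """Run DETR detection on a PIL image and draw boxes with matplotlib.

    Returns (annotated image, result text).
    """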
if detr_model is None or detr_processor is None:
return None, "DETR model not loaded"
    # Measure inference time
    start_time = time.time()
    # Prepare the image for the model and move tensors to the selected device
    inputs = detr_processor(images=image, return_tensors="pt").to(device)
# Run inference
with torch.no_grad():
outputs = detr_model(**inputs)
# Convert outputs to image with bounding boxes
# Create tensor with original image dimensions (height, width)
# image.size[::-1] reverses the (width, height) to (height, width) as required by DETR
target_sizes = torch.tensor([image.size[::-1]])
# Process raw model outputs into usable detection results
# - Maps predictions back to original image size
# - Filters detections using confidence threshold (0.9)
# - Returns a dictionary with 'scores', 'labels', and 'boxes' keys
# - [0] extracts results for the first (and only) image in the batch
results = detr_processor.post_process_object_detection(
outputs, target_sizes=target_sizes, threshold=0.9
)[0]
# Create a copy of the image to draw on
result_image = image.copy()
    import matplotlib
    matplotlib.use("Agg")  # non-interactive backend; safer inside a server process
    import matplotlib.pyplot as plt
    from matplotlib.patches import Rectangle
# Create figure and axes
fig, ax = plt.subplots(1)
ax.imshow(result_image)
# Format detection results
detections = []
for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
box = [round(i) for i in box.tolist()]
class_name = detr_model.config.id2label[label.item()]
confidence = round(score.item(), 2)
# Draw rectangle
rect = Rectangle((box[0], box[1]), box[2] - box[0], box[3] - box[1],
linewidth=2, edgecolor='r', facecolor='none')
ax.add_patch(rect)
# Add label
plt.text(box[0], box[1], "{}: {}".format(class_name, confidence),
bbox=dict(facecolor='white', alpha=0.8))
detections.append("{}: {} at {}".format(class_name, confidence, box))
    # Save the figure to an in-memory PNG
    buf = BytesIO()
plt.tight_layout()
plt.axis('off')
plt.savefig(buf, format='png', bbox_inches='tight', pad_inches=0)
buf.seek(0)
result_image = Image.open(buf)
plt.close(fig)
# Calculate inference time
inference_time = time.time() - start_time
# Add inference time and device info to detection text
device_info = "GPU" if torch.cuda.is_available() else "CPU"
performance_info = f"\n\nInference time: {inference_time:.3f} seconds on {device_info}"
detection_text = "\n".join(detections) if detections else "No objects detected"
detection_text += performance_info
return result_image, detection_text
def process_vit(image):
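    """Run ViT classification on an image and return the top-5 predictions as text."""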
if vit_model is None or vit_processor is None:
return "ViT model not loaded"
    # Measure inference time
    start_time = time.time()
    # Prepare the image for the model and move tensors to the selected device
    inputs = vit_processor(images=image, return_tensors="pt").to(device)
# Run inference
with torch.no_grad():
outputs = vit_model(**inputs)
# Extract raw logits (unnormalized scores) from model output
# Hugging Face models return logits directly, not probabilities
logits = outputs.logits
# Get the predicted class
# argmax(-1) finds the index with highest score across the last dimension (class dimension)
# item() converts the tensor value to a Python scalar
predicted_class_idx = logits.argmax(-1).item()
# Map the class index to human-readable label using the model's configuration
prediction = vit_model.config.id2label[predicted_class_idx]
# Get top 5 predictions
# Apply softmax to convert raw logits to probabilities
# softmax normalizes the exponentials of logits so they sum to 1.0
# dim=-1 applies softmax along the class dimension
# Shape before softmax: [1, num_classes] (batch_size=1, num_classes=1000)
# [0] extracts the first (and only) item from the batch dimension
# Shape after [0]: [num_classes] (a 1D tensor with 1000 class probabilities)
probs = torch.nn.functional.softmax(logits, dim=-1)[0]
# Get the values and indices of the 5 highest probabilities
top5_prob, top5_indices = torch.topk(probs, 5)
results = []
for i, (prob, idx) in enumerate(zip(top5_prob, top5_indices)):
class_name = vit_model.config.id2label[idx.item()]
results.append("{}. {}: {:.3f}".format(i+1, class_name, prob.item()))
# Calculate inference time
inference_time = time.time() - start_time
# Add inference time and device info to results
device_info = "GPU" if torch.cuda.is_available() else "CPU"
performance_info = f"\n\nInference time: {inference_time:.3f} seconds on {device_info}"
result_text = "\n".join(results)
result_text += performance_info
return result_text
# Define Gradio interface
with gr.Blocks(title="Object Detection Demo") as demo:
gr.Markdown("""
# Multi-Model Object Detection Demo
This demo showcases three different object detection and image classification models:
- **YOLOv8**: Fast and accurate object detection
- **DETR**: DEtection TRansformer for object detection
- **ViT**: Vision Transformer for image classification
Upload an image to see how each model performs!
""")
with gr.Row():
input_image = gr.Image(type="pil", label="Input Image")
with gr.Row():
yolo_button = gr.Button("Detect with YOLOv8")
detr_button = gr.Button("Detect with DETR")
vit_button = gr.Button("Classify with ViT")
with gr.Row():
with gr.Column():
yolo_output = gr.Image(type="pil", label="YOLOv8 Detection")
yolo_text = gr.Textbox(label="YOLOv8 Results")
with gr.Column():
detr_output = gr.Image(type="pil", label="DETR Detection")
detr_text = gr.Textbox(label="DETR Results")
with gr.Column():
vit_text = gr.Textbox(label="ViT Classification Results")
    # Vector DB save and search UI
    with gr.Row():
        with gr.Column():
            gr.Markdown("### Save to Vector DB")
            save_yolo_button = gr.Button("Save YOLOv8 Detections", variant="primary")
            save_detr_button = gr.Button("Save DETR Detections", variant="primary")
            save_vit_button = gr.Button("Save ViT Classifications", variant="primary")
            save_result = gr.Textbox(label="Vector DB Save Result")
        with gr.Column():
            gr.Markdown("### Search the Vector DB")
            search_class = gr.Textbox(label="Search by Class Name")
            search_button = gr.Button("Search", variant="secondary")
            search_result_text = gr.Textbox(label="Search Result Info")
            search_result_gallery = gr.Gallery(label="Search Results", columns=5, height=300)
    # State holders for the latest detection/classification results
yolo_detection_state = gr.State(None)
detr_detection_state = gr.State(None)
vit_classification_state = gr.State(None)
# Set up event handlers
yolo_button.click(
fn=process_yolo,
inputs=input_image,
outputs=[yolo_output, yolo_text, yolo_detection_state]
)
    # DETR handler that also captures parsed detections in state
def process_detr_with_state(image):
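        """Run process_detr, then parse its result text into structured objects.

        Parsing the display text is fragile but avoids changing process_detr's
        return signature; detection lines look like "label: 0.97 at [x1, y1, x2, y2]".
        """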
result_image, result_text = process_detr(image)
        # Parse object info out of the result text
        detection_results = {"objects": []}
        lines = result_text.split('\n')
for line in lines:
if ': ' in line and ' at ' in line:
try:
class_conf, location = line.split(' at ')
class_name, confidence = class_conf.split(': ')
confidence = float(confidence)
                    # Extract the bounding box
bbox_str = location.strip('[]').split(', ')
bbox = [int(coord) for coord in bbox_str]
detection_results["objects"].append({
"class": class_name,
"confidence": confidence,
"bbox": bbox
})
except Exception:
pass
return result_image, result_text, detection_results
    # ViT handler that also captures parsed classifications in state
def process_vit_with_state(image):
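        """Run process_vit, then parse its top-5 text into structured results."""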
result_text = process_vit(image)
        # Parse classification info out of the result text
        classifications = []
        lines = result_text.split('\n')
for line in lines:
if '. ' in line and ': ' in line:
try:
rank_class, confidence = line.split(': ')
_, class_name = rank_class.split('. ')
confidence = float(confidence)
classifications.append({
"class": class_name,
"confidence": confidence
})
except Exception:
pass
return result_text, {"classifications": classifications}
detr_button.click(
fn=process_detr_with_state,
inputs=input_image,
outputs=[detr_output, detr_text, detr_detection_state]
)
vit_button.click(
fn=process_vit_with_state,
inputs=input_image,
outputs=[vit_text, vit_classification_state]
)
    # Vector DB save button handlers
save_yolo_button.click(
fn=lambda img, det: save_objects_to_vector_db(img, det, 'yolo'),
inputs=[input_image, yolo_detection_state],
outputs=save_result
)
save_detr_button.click(
fn=lambda img, det: save_objects_to_vector_db(img, det, 'detr'),
inputs=[input_image, detr_detection_state],
outputs=save_result
)
save_vit_button.click(
fn=lambda img, det: save_objects_to_vector_db(img, det, 'vit'),
inputs=[input_image, vit_classification_state],
outputs=save_result
)
    # Search button handler
def format_search_results(result_text, results):
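        """Convert search results into (image, caption) pairs for gr.Gallery."""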
images = []
captions = []
for result in results:
if result.get('image'):
images.append(result['image'])
captions.append(result['info'])
return result_text, [(img, cap) for img, cap in zip(images, captions)]
    def search_by_class(class_name):
        result_text, results = search_similar_objects(class_name=class_name)
        return format_search_results(result_text, results)

    search_button.click(
        fn=search_by_class,
        inputs=search_class,
        outputs=[search_result_text, search_result_gallery]
    )
# Launch the app
if __name__ == "__main__":
demo.launch()