import gradio as gr
import pandas as pd
import torch
import yaml
from pathlib import Path
from PIL import Image
from transformers import AutoProcessor, Blip2ForConditionalGeneration

import clip
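
# Illustrative layout of the expected config.yml, inferred from how the keys
# are read below. The model names and category entries are placeholders, not
# values shipped with this file:
#
#   config:
#     models:
#       blip:
#         model_name: Salesforce/blip2-opt-2.7b   # any BLIP-2 checkpoint
#       clip:
#         model_name: ViT-B/32                    # any name accepted by clip.load
#     paths:
#       images: ./images
#       output: ./sorted
#   categories:
#     cats:
#       description: a photo of a cat
#       path: cats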

# Configuration loading and validation
def load_config(path):
    try:
        with open(path) as file:
            config = yaml.safe_load(file)
        # Validate that the necessary top-level sections are present
        necessary_keys = ['categories', 'config']
        for key in necessary_keys:
            if key not in config:
                raise ValueError(f'Missing necessary config section: {key}')
        return config
    except FileNotFoundError:
        print("Error: config.yml file not found.")
        raise
    except ValueError as e:
        print(str(e))
        raise

config = load_config('config.yml')
categories = config['categories']

device = "cuda" if torch.cuda.is_available() else "cpu"
# float16 halves memory on GPU; fall back to float32 on CPU, where fp16
# generation is poorly supported
dtype = torch.float16 if device == "cuda" else torch.float32
print(f"Using device: {device}")

# Initialize the BLIP-2 captioning model and the CLIP similarity model
processor = AutoProcessor.from_pretrained(config['config']['models']['blip']['model_name'])
blip_model = Blip2ForConditionalGeneration.from_pretrained(
    config['config']['models']['blip']['model_name'], torch_dtype=dtype
)
blip_model.to(device)
model, preprocess = clip.load(config['config']['models']['clip']['model_name'], device=device)
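
# Index of the image currently shown in the Gradio UI; advanced by the
# "Next Image" callback below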
current_index = 0

# Precompute a normalized CLIP text embedding for each category description,
# so image-to-category similarity reduces to a dot product
for category_name, category_details in categories.items():
    print(f"Precomputing embeddings for category: {category_name}; {category_details}")
    with torch.no_grad():
        embeddings_tensor = model.encode_text(clip.tokenize(category_details['description']).to(device))
        embeddings_tensor /= embeddings_tensor.norm(dim=-1, keepdim=True)
    category_details['embeddings'] = embeddings_tensor.cpu().numpy()

def load_image(path):
    try:
        image = Image.open(path)
        image_input = preprocess(image).unsqueeze(0).to(device)
        return image, image_input
    except Exception as e:
        print(f"Error loading image {path}: {e}")
        return None, None

def predict_category(image_input, caption_input=None):
    if image_input is None:
        return None, None
    with torch.no_grad():
        image_features = model.encode_image(image_input)
        if caption_input is not None:
            # Stack the caption embedding alongside the image embedding; the
            # per-category dot products are then summed over both rows, so the
            # caption and the image jointly vote for a category
            caption_input = clip.tokenize(caption_input).to(device)
            text_features = model.encode_text(caption_input)
            image_features = torch.cat([image_features, text_features])
        image_features /= image_features.norm(dim=-1, keepdim=True)
        image_features = image_features.cpu().numpy()
    best_category = None
    best_similarity = float('-inf')
    for category_name, category_details in categories.items():
        similarity = (image_features * category_details['embeddings']).sum()
        if similarity > best_similarity:
            best_similarity = similarity
            best_category = category_name
    return best_category, image_features
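
# Illustrative usage (not executed here; "example.jpg" is a placeholder):
#   img, tensor = load_image(Path("example.jpg"))
#   category, feats = predict_category(tensor, "a photo of a dog")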

image_dir = Path(config['config']['paths']['images'])
image_files = [f for f in image_dir.glob('*') if f.suffix.lower() in ['.png', '.jpg', '.jpeg']]

images_df = pd.DataFrame(columns=['image_path', 'image_embedding', 'predicted_category', 'generated_text'])
for image_path in image_files:
    img, image_input = load_image(image_path)
    if img is not None:
        # Caption the image with BLIP-2, then let CLIP score the image and
        # its caption against every category
        blip_input = processor(images=img, return_tensors="pt").to(device, dtype)
        predicted_ids = blip_model.generate(**blip_input, max_new_tokens=10)
        generated_text = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0].strip()
        predicted_category, image_features = predict_category(image_input, generated_text)
        # Turn the caption into a filename for the sorted copy of the image
        generated_text = generated_text.replace(" ", "_") + image_path.suffix
        new_row = {
            'image_path': str(image_path),
            'image_embedding': image_features if image_features is not None else None,
            'predicted_category': predicted_category,
            'generated_text': generated_text,
        }
        # Append via direct indexing at the next free position
        images_df.loc[len(images_df)] = new_row

print(images_df.head())

# Gradio interface setup and launch
def next_image_and_prediction(user_choice):
    global current_index
    if images_df.empty:
        return None, "No more images", None
    # Record the reviewer's choice for the current image, then wrap around to
    # the next one
    images_df.loc[current_index, 'predicted_category'] = user_choice
    current_index = (current_index + 1) % len(images_df)
    next_img_path = images_df.loc[current_index, 'image_path']
    predicted_category = images_df.loc[current_index, 'predicted_category']
    predicted_filename = images_df.loc[current_index, 'generated_text']
    print(f"Next image: {next_img_path}, Predicted category: {predicted_category}")
    return next_img_path, predicted_category, predicted_filename
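
# Move each reviewed image into its category folder under the configured
# output root, renaming it to its caption-derived filename. Path.rename
# overwrites an existing target on POSIX and raises on Windows, so duplicate
# captions can collide.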
def move_images_to_category_folder():
    for index, row in images_df.iterrows():
        image_path = Path(row['image_path'])
        category_name = row['predicted_category']
        if category_name in categories:
            category_path = Path(categories[category_name]['path'])
            category_dir = Path(config['config']['paths']['output']) / category_path
            category_dir.mkdir(parents=True, exist_ok=True)
            new_image_path = category_dir / row['generated_text']
            image_path.rename(new_image_path)
            print(f"Moved {image_path} to {new_image_path}")
        else:
            print(f"Category {category_name} not found in categories.")

# Seed the interface with the first image so the reviewer starts with a
# prediction already filled in; values are set at construction time because
# assigning component.value after creation is unreliable
initial_image, initial_category, initial_filename = None, None, None
if not images_df.empty:
    initial_image = images_df.loc[0, 'image_path']
    initial_category = images_df.loc[0, 'predicted_category']
    initial_filename = images_df.loc[0, 'generated_text']

with gr.Blocks() as blocks:
    image_block = gr.Image(label="Image", type="filepath", value=initial_image, height=300, width=300)
    filename = gr.Textbox(label="Filename", type="text", value=initial_filename)
    next_button = gr.Button("Next Image")
    category_dropdown = gr.Dropdown(label="Category", choices=list(categories.keys()), type="value", value=initial_category)
    submit_button = gr.Button("Submit")
    submit_button.click(fn=move_images_to_category_folder, inputs=[], outputs=[])
    next_button.click(fn=next_image_and_prediction, inputs=category_dropdown, outputs=[image_block, category_dropdown, filename])

blocks.launch()