import os
import requests
from tqdm import tqdm
from datasets import load_dataset
import numpy as np
from tensorflow.keras.applications.resnet50 import ResNet50, preprocess_input
from tensorflow.keras.preprocessing import image
from sklearn.neighbors import NearestNeighbors
import joblib
from PIL import UnidentifiedImageError, Image
import gradio as gr
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.layers import Dense, Dropout
from tensorflow.keras.preprocessing.text import Tokenizer
from tensorflow.keras.preprocessing.sequence import pad_sequences

# Load the dataset
dataset = load_dataset("thefcraft/civitai-stable-diffusion-337k")

# Filter out NSFW content and null models
dataset_filtered = dataset['train'].filter(
    lambda x: not x['nsfw'] and x['Model'] is not None and x['Model'].strip() != ''
)

# Take a subset of the filtered dataset
subset_size = 2700
dataset_subset = dataset_filtered.shuffle(seed=42).select(range(subset_size))

# Directory to save downloaded images
image_dir = 'civitai_images'
os.makedirs(image_dir, exist_ok=True)

# Load the ResNet50 model pretrained on ImageNet (global average pooling yields a 2048-d image feature)
cnn_model = ResNet50(weights='imagenet', include_top=False, pooling='avg')

# Text processing setup
max_words = 10000  # Maximum number of words to keep
max_len = 100      # Maximum length of each text sequence

# Initialize and fit the tokenizer on the prompts
tokenizer = Tokenizer(num_words=max_words)
prompts = [sample['prompt'] for sample in dataset_subset]
tokenizer.fit_on_texts(prompts)


# Create an MLP autoencoder for text processing; the 32-unit bottleneck is used as the text embedding
def create_mlp_model(input_dim):
    model = Sequential([
        Dense(256, activation='relu', input_dim=input_dim),
        Dropout(0.3),
        Dense(128, activation='relu'),
        Dropout(0.2),
        Dense(64, activation='relu'),
        Dense(32, activation='relu', name='text_embedding'),
        Dense(input_dim)  # reconstruction layer so the MSE targets match the input shape
    ])
    return model


# Function to extract text features (the 32-d bottleneck activations)
def extract_text_features(prompt):
    # Convert text to a padded integer sequence
    sequence = tokenizer.texts_to_sequences([prompt])
    padded = pad_sequences(sequence, maxlen=max_len)
    # Get features from the encoder part of the MLP
    return text_encoder.predict(padded)


# Function to extract image features
def extract_image_features(img_path, model):
    img = image.load_img(img_path, target_size=(224, 224))
    img_array = image.img_to_array(img)
    img_array = np.expand_dims(img_array, axis=0)
    img_array = preprocess_input(img_array)
    features = model.predict(img_array)
    return features.flatten()


# Prepare text data
text_sequences = tokenizer.texts_to_sequences(prompts)
padded_sequences = pad_sequences(text_sequences, maxlen=max_len)

# Create and train the MLP to reconstruct its own padded input sequences
mlp_model = create_mlp_model(max_len)
mlp_model.compile(optimizer='adam', loss='mse')
mlp_model.fit(padded_sequences, padded_sequences, epochs=5, batch_size=32, validation_split=0.2)

# Encoder sub-model that exposes the 32-d bottleneck as the text feature vector
text_encoder = Model(inputs=mlp_model.input, outputs=mlp_model.get_layer('text_embedding').output)

# Extract features for both images and text
image_features = []
text_features = []
image_paths = []
model_names = []

for sample in tqdm(dataset_subset):
    img_url = sample['url']
    model_name = sample['Model']
    prompt = sample['prompt']
    img_path = os.path.join(image_dir, os.path.basename(img_url))

    try:
        # Download and save the image
        response = requests.get(img_url)
        response.raise_for_status()
        if 'image' not in response.headers.get('Content-Type', ''):
            raise ValueError("URL does not contain an image")
        with open(img_path, 'wb') as f:
            f.write(response.content)

        # Extract image features
        img_features = extract_image_features(img_path, cnn_model)

        # Extract text features
        txt_features = extract_text_features(prompt)

        # Store features and metadata
        image_features.append(img_features)
        text_features.append(txt_features.flatten())
        image_paths.append(img_path)
        model_names.append(model_name)

    except (UnidentifiedImageError, ValueError, requests.exceptions.RequestException) as e:
        print(f"Error processing {img_url}: {e}")
        if os.path.exists(img_path):
            os.remove(img_path)

# Convert features to numpy arrays
image_features = np.array(image_features)
text_features = np.array(text_features)

# Combine image and text features into a single vector per sample
combined_features = np.concatenate([image_features, text_features], axis=1)

# Build the NearestNeighbors model
nbrs = NearestNeighbors(n_neighbors=5, algorithm='ball_tree').fit(combined_features)

# Save models and features
joblib.dump(nbrs, 'nearest_neighbors_model.pkl')
joblib.dump(tokenizer, 'tokenizer.pkl')
mlp_model.save('mlp_model.keras')  # save the Keras model natively rather than pickling it with joblib
np.save('combined_features.npy', combined_features)
np.save('image_paths.npy', image_paths)
np.save('model_names.npy', model_names)


# Function to get recommendations
def get_recommendations(img, prompt="", n_neighbors=5):
    # Process the input image; fall back to a zero vector if no image is provided
    if img is not None:
        img_path = "temp_input_image.png"
        img.save(img_path)
        img_features = extract_image_features(img_path, cnn_model)
    else:
        img_features = np.zeros(image_features.shape[1])

    # Process the input text
    txt_features = extract_text_features(prompt)

    # Combine features
    input_features = np.concatenate([img_features, txt_features.flatten()])

    # Get recommendations
    distances, indices = nbrs.kneighbors([input_features], n_neighbors=n_neighbors)
    recommended_images = [image_paths[idx] for idx in indices.flatten()]
    recommended_model_names = [model_names[idx] for idx in indices.flatten()]
    recommended_distances = distances.flatten()

    return [
        (Image.open(path), f'{name}, Distance: {dist:.2f}')
        for path, name, dist in zip(recommended_images, recommended_model_names, recommended_distances)
    ]


# Gradio interface
interface = gr.Interface(
    fn=get_recommendations,
    inputs=[
        gr.Image(type="pil"),
        gr.Textbox(label="Prompt")
    ],
    outputs=gr.Gallery(label="Recommended Images"),
    title="Image and Text Recommendation System",
    description="Upload an image and/or enter a prompt to get similar images with their model names and distances."
)

if __name__ == "__main__":
    interface.launch()
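
# --- Optional: reloading the saved artifacts in a later session (a minimal sketch) ---
# This assumes the files written above ('nearest_neighbors_model.pkl', 'tokenizer.pkl',
# 'mlp_model.keras', and the .npy arrays) are present in the working directory;
# adjust the paths if you saved them elsewhere.
#
# from tensorflow.keras.models import load_model
# nbrs = joblib.load('nearest_neighbors_model.pkl')
# tokenizer = joblib.load('tokenizer.pkl')
# mlp_model = load_model('mlp_model.keras')
# text_encoder = Model(inputs=mlp_model.input, outputs=mlp_model.get_layer('text_embedding').output)
# combined_features = np.load('combined_features.npy')
# image_paths = np.load('image_paths.npy').tolist()
# model_names = np.load('model_names.npy').tolist()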