"""Streamlit demo: detect attributes of the person in an uploaded photo, turn
them into a text prompt, enhance that prompt, and generate a new image from it.
"""

import streamlit as st
import torch
from diffusers import StableDiffusionPipeline
from PIL import Image, ImageOps, UnidentifiedImageError
from transformers import (
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    ViTForImageClassification,
    ViTImageProcessor,
    pipeline,
)

# Half precision is intended for GPU execution; on CPU many float16 ops are
# unimplemented or very slow. The diffusion pipeline therefore uses float16
# only when CUDA is available and falls back to float32 on CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
SD_DTYPE = torch.float16 if DEVICE == "cuda" else torch.float32


def _load_vit_classifier(repo_id):
    """Return ``(model, image_processor)`` for a ViT image-classification checkpoint."""
    model = ViTForImageClassification.from_pretrained(repo_id)
    # ViTImageProcessor supersedes the deprecated ViTFeatureExtractor and reads
    # the same preprocessing config from the checkpoint.
    processor = ViTImageProcessor.from_pretrained(repo_id)
    return model, processor


# Load models
@st.cache_resource
def load_models():
    """Load every model used by the app.

    ``st.cache_resource`` keeps the returned objects across script reruns and
    sessions, so downloading/initialising happens once instead of on every
    widget interaction.

    Returns:
        Tuple ``(age_model, age_transforms, gender_model, gender_transforms,
        emotion_model, emotion_transforms, object_detector, action_model,
        action_transforms, prompt_enhancer, pipe)``.
    """
    age_model, age_transforms = _load_vit_classifier('nateraw/vit-age-classifier')
    gender_model, gender_transforms = _load_vit_classifier('rizvandwiki/gender-classification-2')
    emotion_model, emotion_transforms = _load_vit_classifier('dima806/facial_emotions_image_detection')
    object_detector = pipeline("object-detection", model="facebook/detr-resnet-50")
    action_model, action_transforms = _load_vit_classifier(
        'rvv-karma/Human-Action-Recognition-VIT-Base-patch16-224'
    )

    prompt_enhancer_tokenizer = AutoTokenizer.from_pretrained("gokaygokay/Flux-Prompt-Enhance")
    prompt_enhancer_model = AutoModelForSeq2SeqLM.from_pretrained("gokaygokay/Flux-Prompt-Enhance")
    prompt_enhancer = pipeline(
        'text2text-generation',
        model=prompt_enhancer_model,
        tokenizer=prompt_enhancer_tokenizer,
        repetition_penalty=1.2,
        device="cpu",
    )

    # Load BK-SDM-Tiny for image generation, placed on DEVICE with a dtype that
    # device can actually execute (see DEVICE / SD_DTYPE above).
    pipe = StableDiffusionPipeline.from_pretrained("nota-ai/bk-sdm-tiny", torch_dtype=SD_DTYPE)
    pipe = pipe.to(DEVICE)

    return (
        age_model, age_transforms,
        gender_model, gender_transforms,
        emotion_model, emotion_transforms,
        object_detector,
        action_model, action_transforms,
        prompt_enhancer,
        pipe,
    )


models = load_models()
(
    age_model, age_transforms,
    gender_model, gender_transforms,
    emotion_model, emotion_transforms,
    object_detector,
    action_model, action_transforms,
    prompt_enhancer,
    pipe,
) = models


def predict(image, model, transforms):
    """Classify ``image`` with ``model`` and return the index of the top class.

    Args:
        image: PIL image; converted to RGB if it is in any other mode.
        model: image-classification model whose output exposes ``.logits``.
        transforms: matching image processor producing PyTorch tensors.

    Returns:
        int index into ``model.config.id2label``.
    """
    # Convert the image to RGB format if necessary
    if image.mode != 'RGB':
        image = image.convert('RGB')
    inputs = transforms(images=[image], return_tensors='pt')
    # Inference only: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        logits = model(**inputs).logits
    # softmax is monotonic, so the argmax of the logits is the same class.
    return logits.argmax(1).item()


def detect_attributes(image):
    """Run every classifier and the object detector on ``image``.

    Returns:
        dict with keys ``'age'``, ``'gender'``, ``'emotion'``, ``'action'``
        (label strings from each model's ``id2label``) and ``'objects'``
        (list of detected object labels, one per detection).
    """
    # Convert once here instead of once per classifier call.
    if image.mode != 'RGB':
        image = image.convert('RGB')

    classifiers = (
        ('age', age_model, age_transforms),
        ('gender', gender_model, gender_transforms),
        ('emotion', emotion_model, emotion_transforms),
        ('action', action_model, action_transforms),
    )
    attributes = {
        key: model.config.id2label[predict(image, model, transforms)]
        for key, model, transforms in classifiers
    }
    attributes['objects'] = [obj['label'] for obj in object_detector(image)]
    return attributes


def _humanize(label):
    """Render a model label as prose; underscores (snake_case labels) become spaces."""
    return str(label).replace('_', ' ')


def generate_prompt(attributes):
    """Build a text-to-image prompt from the dict returned by ``detect_attributes``."""
    age = _humanize(attributes['age'])
    gender = _humanize(attributes['gender'])
    emotion = _humanize(attributes['emotion'])
    action = _humanize(attributes['action'])
    prompt = f"A {age} year old {gender} person feeling {emotion} while {action}."
    # The detector yields one entry per detection (e.g. three people -> three
    # "person" labels); mention each object class once, in first-seen order.
    objects = list(dict.fromkeys(_humanize(obj) for obj in attributes['objects']))
    if objects:
        prompt += f" Image has {', '.join(objects)}."
    return prompt


def enhance_prompt(prompt):
    """Rewrite ``prompt`` with the Flux-Prompt-Enhance seq2seq model."""
    prefix = "enhance prompt: "
    enhanced = prompt_enhancer(prefix + prompt, max_length=256)
    return enhanced[0]['generated_text']


@st.cache_data
def generate_image(prompt):
    """Generate an image for ``prompt`` with BK-SDM-Tiny (cached per prompt)."""
    with torch.no_grad():
        image = pipe(prompt, num_inference_steps=50).images[0]
    return image


st.title("Image Attribute Detection and Image Generation")

uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])

if uploaded_file is not None:
    try:
        # Apply the EXIF orientation tag so photos taken rotated (typical for
        # phones) are analysed upright.
        image = ImageOps.exif_transpose(Image.open(uploaded_file))
    except UnidentifiedImageError:
        st.error("The uploaded file could not be read as an image.")
        st.stop()

    # NOTE(review): newer Streamlit releases deprecate use_column_width in
    # favour of use_container_width; kept here for older-version compatibility.
    st.image(image, caption='Uploaded Image', use_column_width=True)

    if st.button('Analyze Image'):
        with st.spinner('Detecting attributes...'):
            attributes = detect_attributes(image)

        st.write("Detected Attributes:")
        for key, value in attributes.items():
            st.write(f"{key.capitalize()}: {value}")

        with st.spinner('Generating prompt...'):
            initial_prompt = generate_prompt(attributes)
            enhanced_prompt = enhance_prompt(initial_prompt)

        st.write("Initial Prompt:")
        st.write(initial_prompt)
        st.write("Enhanced Prompt:")
        st.write(enhanced_prompt)

        with st.spinner('Generating image...'):
            generated_image = generate_image(enhanced_prompt)
        st.image(generated_image, caption='Generated Image', use_column_width=True)