# Page header captured together with this file (not part of the program):
#   EdBoy2202's picture
#   Update app.py
#   3ad8ff2 verified
import streamlit as st
from PIL import Image
import torch
from transformers import (
ViTFeatureExtractor,
ViTForImageClassification,
pipeline,
AutoTokenizer,
AutoModelForSeq2SeqLM
)
from diffusers import StableDiffusionPipeline
# Load models
@st.cache_resource
def load_models():
    """Load every model used by the app and return them as one tuple.

    Decorated with ``st.cache_resource`` so the (expensive) loading is
    performed once and the same objects are reused on later script reruns.

    Returns:
        An 11-tuple, in this exact order:
        ``(age_model, age_transforms, gender_model, gender_transforms,
        emotion_model, emotion_transforms, object_detector,
        action_model, action_transforms, prompt_enhancer, pipe)``.
    """

    def load_classifier(repo_id):
        # Each ViT classifier ships with a matching feature extractor under
        # the same repository id.
        model = ViTForImageClassification.from_pretrained(repo_id)
        extractor = ViTFeatureExtractor.from_pretrained(repo_id)
        return model, extractor

    age_model, age_transforms = load_classifier('nateraw/vit-age-classifier')
    gender_model, gender_transforms = load_classifier('rizvandwiki/gender-classification-2')
    emotion_model, emotion_transforms = load_classifier('dima806/facial_emotions_image_detection')

    object_detector = pipeline("object-detection", model="facebook/detr-resnet-50")

    action_model, action_transforms = load_classifier(
        'rvv-karma/Human-Action-Recognition-VIT-Base-patch16-224'
    )

    prompt_enhancer_tokenizer = AutoTokenizer.from_pretrained("gokaygokay/Flux-Prompt-Enhance")
    prompt_enhancer_model = AutoModelForSeq2SeqLM.from_pretrained("gokaygokay/Flux-Prompt-Enhance")
    prompt_enhancer = pipeline(
        'text2text-generation',
        model=prompt_enhancer_model,
        tokenizer=prompt_enhancer_tokenizer,
        repetition_penalty=1.2,
        device="cpu",
    )

    # Load BK-SDM-Tiny for image generation.
    # Half precision is only appropriate on a GPU: on CPU many float16
    # kernels are missing or extremely slow, so fall back to float32 there
    # and place the pipeline explicitly on the device that matches its dtype.
    has_cuda = torch.cuda.is_available()
    pipe = StableDiffusionPipeline.from_pretrained(
        "nota-ai/bk-sdm-tiny",
        torch_dtype=torch.float16 if has_cuda else torch.float32,
    )
    pipe = pipe.to("cuda" if has_cuda else "cpu")

    return (age_model, age_transforms, gender_model, gender_transforms,
            emotion_model, emotion_transforms, object_detector,
            action_model, action_transforms, prompt_enhancer, pipe)
# Fetch the model tuple once and bind each element to the module-level name
# the rest of the script refers to. The order mirrors load_models()'s return.
models = load_models()
(
    age_model,
    age_transforms,
    gender_model,
    gender_transforms,
    emotion_model,
    emotion_transforms,
    object_detector,
    action_model,
    action_transforms,
    prompt_enhancer,
    pipe,
) = models
def predict(image, model, transforms):
    """Classify a single image and return the index of the top class.

    Args:
        image: Image object exposing ``mode`` and ``convert`` (a PIL image
            in this app). Converted to RGB first when it is in another mode.
        model: Callable invoked as ``model(**inputs)`` whose result has a
            ``logits`` attribute with one row per input image.
        transforms: Callable invoked as
            ``transforms(images=[image], return_tensors='pt')`` that returns
            the keyword inputs for ``model``.

    Returns:
        The integer index of the highest-scoring class for the image.
    """
    # Convert the image to RGB format if necessary
    if image.mode != 'RGB':
        image = image.convert('RGB')

    inputs = transforms(images=[image], return_tensors='pt')

    # Pure inference: disable autograd so no graph is built and no
    # activations are retained for a backward pass that never happens.
    with torch.no_grad():
        logits = model(**inputs).logits

    # Softmax is monotonic, so the argmax of the raw logits selects the same
    # class as the argmax of the probabilities.
    return logits.argmax(1).item()
def detect_attributes(image):
    """Run every classifier and the object detector on ``image``.

    Args:
        image: The image to analyse; passed unchanged to ``predict`` and to
            ``object_detector``.

    Returns:
        A dict with the keys ``'age'``, ``'gender'``, ``'emotion'`` and
        ``'action'`` (each the ``id2label`` name of the predicted class) and
        ``'objects'`` (the list of ``'label'`` values from the detector
        results, in detector order).
    """
    classifiers = (
        ('age', age_model, age_transforms),
        ('gender', gender_model, gender_transforms),
        ('emotion', emotion_model, emotion_transforms),
        ('action', action_model, action_transforms),
    )

    # Run all classifiers first, then the detector, and only afterwards map
    # the class indices to their human-readable labels.
    class_ids = {
        name: predict(image, model, transforms)
        for name, model, transforms in classifiers
    }
    detections = object_detector(image)

    attributes = {
        name: model.config.id2label[class_ids[name]]
        for name, model, _ in classifiers
    }
    attributes['objects'] = [detection['label'] for detection in detections]
    return attributes
def generate_prompt(attributes):
    """Turn detected attributes into a one-line text prompt.

    Args:
        attributes: Mapping with the keys ``'age'``, ``'gender'``,
            ``'emotion'``, ``'action'`` and ``'objects'`` (a sequence of
            strings, possibly empty).

    Returns:
        The prompt string. The object sentence is appended only when
        ``attributes['objects']`` is non-empty; every sentence ends with a
        trailing space.
    """
    segments = [
        f"A {attributes['age']} year old {attributes['gender']} person feeling {attributes['emotion']} ",
        f"while {attributes['action']}. ",
    ]

    detected = attributes['objects']
    if detected:
        segments.append(f"Image has {', '.join(detected)}. ")

    return "".join(segments)
def enhance_prompt(prompt):
    """Rewrite ``prompt`` with the prompt-enhancer text2text pipeline.

    Args:
        prompt: The plain prompt text to enhance.

    Returns:
        The ``'generated_text'`` field of the first result returned by
        ``prompt_enhancer``.
    """
    # The enhancer is called with the task prefix prepended to the prompt.
    request = "enhance prompt: " + prompt
    outputs = prompt_enhancer(request, max_length=256)
    first_result = outputs[0]
    return first_result['generated_text']
@st.cache_data
def generate_image(prompt):
    """Generate one image for ``prompt`` with the BK-SDM-Tiny pipeline.

    Decorated with ``st.cache_data``, so the result is cached per prompt.

    Args:
        prompt: Text prompt handed to the diffusion pipeline.

    Returns:
        The first entry of the pipeline result's ``images`` list.
    """
    # Inference only: run the 50-step sampling without gradient tracking.
    with torch.no_grad():
        result = pipe(prompt, num_inference_steps=50)
        return result.images[0]
# ---------------------------------------------------------------------------
# Streamlit page: upload an image, analyse it, and generate a new image from
# the detected attributes.
# ---------------------------------------------------------------------------
st.title("Image Attribute Detection and Image Generation")

uploaded_file = st.file_uploader(
    "Choose an image...",
    type=["jpg", "jpeg", "png"],
)

if uploaded_file is not None:
    # Show the upload straight away; the analysis only runs on button press.
    image = Image.open(uploaded_file)
    st.image(image, caption='Uploaded Image', use_column_width=True)

    if st.button('Analyze Image'):
        # Step 1: classifiers + object detection.
        with st.spinner('Detecting attributes...'):
            attributes = detect_attributes(image)

        st.write("Detected Attributes:")
        for key, value in attributes.items():
            st.write(f"{key.capitalize()}: {value}")

        # Step 2: build the plain prompt, then let the enhancer rewrite it.
        with st.spinner('Generating prompt...'):
            initial_prompt = generate_prompt(attributes)
            enhanced_prompt = enhance_prompt(initial_prompt)

        st.write("Initial Prompt:")
        st.write(initial_prompt)
        st.write("Enhanced Prompt:")
        st.write(enhanced_prompt)

        # Step 3: render an image from the enhanced prompt.
        with st.spinner('Generating image...'):
            generated_image = generate_image(enhanced_prompt)

        st.image(generated_image, caption='Generated Image', use_column_width=True)