import streamlit as st
from PIL import Image
import torch
from transformers import (
    ViTFeatureExtractor,
    ViTForImageClassification,
    pipeline,
    AutoTokenizer,
    AutoModelForSeq2SeqLM
)
from diffusers import StableDiffusionPipeline

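# All checkpoints are loaded once and cached with st.cache_resource, so
# Streamlit reruns reuse the same model instances instead of reloading them.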
@st.cache_resource
def load_models():
    # Vision Transformer classifiers for age, gender, emotion, and action.
    age_model = ViTForImageClassification.from_pretrained('nateraw/vit-age-classifier')
    age_transforms = ViTFeatureExtractor.from_pretrained('nateraw/vit-age-classifier')

    gender_model = ViTForImageClassification.from_pretrained('rizvandwiki/gender-classification-2')
    gender_transforms = ViTFeatureExtractor.from_pretrained('rizvandwiki/gender-classification-2')

    emotion_model = ViTForImageClassification.from_pretrained('dima806/facial_emotions_image_detection')
    emotion_transforms = ViTFeatureExtractor.from_pretrained('dima806/facial_emotions_image_detection')

    # DETR object detector used to list the objects present in the scene.
    object_detector = pipeline("object-detection", model="facebook/detr-resnet-50")

    action_model = ViTForImageClassification.from_pretrained('rvv-karma/Human-Action-Recognition-VIT-Base-patch16-224')
    action_transforms = ViTFeatureExtractor.from_pretrained('rvv-karma/Human-Action-Recognition-VIT-Base-patch16-224')

    # Seq2seq model that rewrites the handcrafted prompt into a richer one.
    prompt_enhancer_tokenizer = AutoTokenizer.from_pretrained("gokaygokay/Flux-Prompt-Enhance")
    prompt_enhancer_model = AutoModelForSeq2SeqLM.from_pretrained("gokaygokay/Flux-Prompt-Enhance")
    prompt_enhancer = pipeline('text2text-generation',
                               model=prompt_enhancer_model,
                               tokenizer=prompt_enhancer_tokenizer,
                               repetition_penalty=1.2,
                               device="cpu")

    # Load Stable Diffusion in fp16 only when a CUDA device is available;
    # half-precision inference is not supported on CPU, so fall back to fp32 there.
    if torch.cuda.is_available():
        pipe = StableDiffusionPipeline.from_pretrained("nota-ai/bk-sdm-tiny", torch_dtype=torch.float16).to("cuda")
    else:
        pipe = StableDiffusionPipeline.from_pretrained("nota-ai/bk-sdm-tiny", torch_dtype=torch.float32)

    return (age_model, age_transforms, gender_model, gender_transforms,
            emotion_model, emotion_transforms, object_detector,
            action_model, action_transforms, prompt_enhancer, pipe)

models = load_models()
(age_model, age_transforms, gender_model, gender_transforms,
 emotion_model, emotion_transforms, object_detector,
 action_model, action_transforms, prompt_enhancer, pipe) = models

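# Shared helper: run one ViT classifier on a PIL image and return the index
# of the highest-probability class.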
def predict(image, model, transforms):
    if image.mode != 'RGB':
        image = image.convert('RGB')

    inputs = transforms(images=[image], return_tensors='pt')
    output = model(**inputs)
    proba = output.logits.softmax(1)
    return proba.argmax(1).item()

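# Run every classifier plus the object detector and map the predicted class
# indices back to human-readable labels via each model's id2label config.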
def detect_attributes(image):
    age = predict(image, age_model, age_transforms)
    gender = predict(image, gender_model, gender_transforms)
    emotion = predict(image, emotion_model, emotion_transforms)
    action = predict(image, action_model, action_transforms)

    objects = object_detector(image)

    return {
        'age': age_model.config.id2label[age],
        'gender': gender_model.config.id2label[gender],
        'emotion': emotion_model.config.id2label[emotion],
        'action': action_model.config.id2label[action],
        'objects': [obj['label'] for obj in objects]
    }

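# Compose a plain-English prompt from the detected attributes; detected
# objects are appended only when the detector found any.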
def generate_prompt(attributes):
    prompt = f"A {attributes['age']} year old {attributes['gender']} person feeling {attributes['emotion']} "
    prompt += f"while {attributes['action']}. "
    if attributes['objects']:
        prompt += f"Image has {', '.join(attributes['objects'])}. "
    return prompt

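# Prepend the task prefix expected by the prompt-enhancer model and return
# the rewritten, more detailed prompt.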
def enhance_prompt(prompt):
    prefix = "enhance prompt: "
    enhanced = prompt_enhancer(prefix + prompt, max_length=256)
    return enhanced[0]['generated_text']

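# Image generation is cached with st.cache_data, so repeated runs with the
# same prompt return the previously generated image instead of re-sampling.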
@st.cache_data
def generate_image(prompt):
    with torch.no_grad():
        image = pipe(prompt, num_inference_steps=50).images[0]
    return image

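# Streamlit UI: upload an image, detect its attributes, build and enhance a
# prompt, then generate a new image from the enhanced prompt.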
st.title("Image Attribute Detection and Image Generation")

uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])

if uploaded_file is not None:
    image = Image.open(uploaded_file)
    st.image(image, caption='Uploaded Image', use_column_width=True)

    if st.button('Analyze Image'):
        with st.spinner('Detecting attributes...'):
            attributes = detect_attributes(image)

        st.write("Detected Attributes:")
        for key, value in attributes.items():
            st.write(f"{key.capitalize()}: {value}")

        with st.spinner('Generating prompt...'):
            initial_prompt = generate_prompt(attributes)
            enhanced_prompt = enhance_prompt(initial_prompt)

        st.write("Initial Prompt:")
        st.write(initial_prompt)
        st.write("Enhanced Prompt:")
        st.write(enhanced_prompt)

        with st.spinner('Generating image...'):
            generated_image = generate_image(enhanced_prompt)
            st.image(generated_image, caption='Generated Image', use_column_width=True)