EdBoy2202 commited on
Commit
8868d43
·
verified ·
1 Parent(s): bfc9dc4

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +122 -0
app.py ADDED
@@ -0,0 +1,122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from PIL import Image
3
+ import torch
4
+ from transformers import (
5
+ ViTFeatureExtractor,
6
+ ViTForImageClassification,
7
+ pipeline,
8
+ AutoTokenizer,
9
+ AutoModelForSeq2SeqLM
10
+ )
11
+ from diffusers import StableDiffusionPipeline
12
+
13
+ # Load models
14
+ @st.cache_resource
15
+ def load_models():
16
+ age_model = ViTForImageClassification.from_pretrained('nateraw/vit-age-classifier')
17
+ age_transforms = ViTFeatureExtractor.from_pretrained('nateraw/vit-age-classifier')
18
+
19
+ gender_model = ViTForImageClassification.from_pretrained('rizvandwiki/gender-classification-2')
20
+ gender_transforms = ViTFeatureExtractor.from_pretrained('rizvandwiki/gender-classification-2')
21
+
22
+ emotion_model = ViTForImageClassification.from_pretrained('dima806/facial_emotions_image_detection')
23
+ emotion_transforms = ViTFeatureExtractor.from_pretrained('dima806/facial_emotions_image_detection')
24
+
25
+ object_detector = pipeline("object-detection", model="facebook/detr-resnet-50")
26
+
27
+ action_model = ViTForImageClassification.from_pretrained('rvv-karma/Human-Action-Recognition-VIT-Base-patch16-224')
28
+ action_transforms = ViTFeatureExtractor.from_pretrained('rvv-karma/Human-Action-Recognition-VIT-Base-patch16-224')
29
+
30
+ prompt_enhancer_tokenizer = AutoTokenizer.from_pretrained("gokaygokay/Flux-Prompt-Enhance")
31
+ prompt_enhancer_model = AutoModelForSeq2SeqLM.from_pretrained("gokaygokay/Flux-Prompt-Enhance")
32
+ prompt_enhancer = pipeline('text2text-generation',
33
+ model=prompt_enhancer_model,
34
+ tokenizer=prompt_enhancer_tokenizer,
35
+ repetition_penalty=1.2,
36
+ device="cpu")
37
+
38
+ # Load BK-SDM-Tiny for image generation
39
+ pipe = StableDiffusionPipeline.from_pretrained("nota-ai/bk-sdm-tiny", torch_dtype=torch.float16)
40
+ return (age_model, age_transforms, gender_model, gender_transforms,
41
+ emotion_model, emotion_transforms, object_detector,
42
+ action_model, action_transforms, prompt_enhancer, pipe)
43
+
44
+ models = load_models()
45
+ (age_model, age_transforms, gender_model, gender_transforms,
46
+ emotion_model, emotion_transforms, object_detector,
47
+ action_model, action_transforms, prompt_enhancer, pipe) = models
48
+
49
+ def predict(image, model, transforms):
50
+ # Convert the image to RGB format if necessary
51
+ if image.mode != 'RGB':
52
+ image = image.convert('RGB')
53
+
54
+ # Apply the transformations and predict
55
+ inputs = transforms(images=[image], return_tensors='pt')
56
+ output = model(**inputs)
57
+ proba = output.logits.softmax(1)
58
+ return proba.argmax(1).item()
59
+
60
+ def detect_attributes(image):
61
+ age = predict(image, age_model, age_transforms)
62
+ gender = predict(image, gender_model, gender_transforms)
63
+ emotion = predict(image, emotion_model, emotion_transforms)
64
+ action = predict(image, action_model, action_transforms)
65
+
66
+ objects = object_detector(image)
67
+
68
+ return {
69
+ 'age': age_model.config.id2label[age],
70
+ 'gender': gender_model.config.id2label[gender],
71
+ 'emotion': emotion_model.config.id2label[emotion],
72
+ 'action': action_model.config.id2label[action],
73
+ 'objects': [obj['label'] for obj in objects]
74
+ }
75
+
76
+ def generate_prompt(attributes):
77
+ prompt = f"A {attributes['age']} {attributes['gender']} person feeling {attributes['emotion']} "
78
+ prompt += f"while {attributes['action']}. "
79
+ if attributes['objects']:
80
+ prompt += f"Image has {', '.join(attributes['objects'])}. "
81
+ return prompt
82
+
83
+ def enhance_prompt(prompt):
84
+ prefix = "enhance prompt: "
85
+ enhanced = prompt_enhancer(prefix + prompt, max_length=256)
86
+ return enhanced[0]['generated_text']
87
+
88
+ @st.cache_data
89
+ def generate_image(prompt):
90
+ # Generate image from the prompt using the BK-SDM-Tiny model
91
+ with torch.no_grad():
92
+ image = pipe(prompt, num_inference_steps=50).images[0]
93
+ return image
94
+
95
+ st.title("Image Attribute Detection and Image Generation")
96
+
97
+ uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
98
+
99
+ if uploaded_file is not None:
100
+ image = Image.open(uploaded_file)
101
+ st.image(image, caption='Uploaded Image', use_column_width=True)
102
+
103
+ if st.button('Analyze Image'):
104
+ with st.spinner('Detecting attributes...'):
105
+ attributes = detect_attributes(image)
106
+
107
+ st.write("Detected Attributes:")
108
+ for key, value in attributes.items():
109
+ st.write(f"{key.capitalize()}: {value}")
110
+
111
+ with st.spinner('Generating prompt...'):
112
+ initial_prompt = generate_prompt(attributes)
113
+ enhanced_prompt = enhance_prompt(initial_prompt)
114
+
115
+ st.write("Initial Prompt:")
116
+ st.write(initial_prompt)
117
+ st.write("Enhanced Prompt:")
118
+ st.write(enhanced_prompt)
119
+
120
+ with st.spinner('Generating image...'):
121
+ generated_image = generate_image(enhanced_prompt)
122
+ st.image(generated_image, caption='Generated Image', use_column_width=True)