Mavthunder committed
Commit de7814e · verified · 1 Parent(s): cf39539

Update app.py

Files changed (1):
  app.py  +86 -76
app.py CHANGED
@@ -1,91 +1,101 @@
import gradio as gr
- from PIL import Image, ImageEnhance
- import numpy as np
import torch
- from transformers import AutoProcessor, AutoModel, pipeline, ViTFeatureExtractor, ViTForImageClassification, CLIPProcessor
- import cv2

- device = "cuda" if torch.cuda.is_available() else "cpu"

- # Aesthetic Scorer: rsinema/aesthetic-scorer (public)
- ae_processor = CLIPProcessor.from_pretrained("rsinema/aesthetic-scorer")
- ae_model = AutoModel.from_pretrained("rsinema/aesthetic-scorer").to(device)
- ae_model.eval()

- def aesthetic_score(img_pil):
-     inputs = ae_processor(images=img_pil, return_tensors="pt")["pixel_values"].to(device)
    with torch.no_grad():
-         scores = ae_model(inputs)
-     # scores returns 7 dims; first is overall aesthetic
-     return float(scores[0][0].item())
-
- # Enhancement using public Zero-DCE model
- zero_dce_pipe = pipeline(
-     "image-enhancement",
-     model="nateraw/zero-dce",
-     device=0 if torch.cuda.is_available() else -1
- )

- def enhance_image(img_pil):
-     enhanced = zero_dce_pipe(img_pil)
-     return enhanced[0]

- # Image Classifier (ViT)
- cls_ext = ViTFeatureExtractor.from_pretrained("google/vit-base-patch16-224")
- cls_model = ViTForImageClassification.from_pretrained("google/vit-base-patch16-224").to(device)
- cls_model.eval()

- def classify_image(img_pil):
-     inputs = cls_ext(images=img_pil, return_tensors="pt").to(device)
    with torch.no_grad():
-         logits = cls_model(**inputs).logits
-     label = cls_model.config.id2label[logits.argmax(-1).item()].lower()
-     return label

- # Category-specific vibes
- CATEGORY_VIBES = {
-     "person": [...],  # same presets as before
-     "food": [...],
-     "landscape": [...],
-     "default": [...],
- }

- def apply_adjustments(img, exposure, contrast, saturation, warmth, clarity):
-     img = img.convert("RGB")
-     if exposure: img = ImageEnhance.Brightness(img).enhance(2**exposure)
-     if contrast: img = ImageEnhance.Contrast(img).enhance(1 + contrast)
-     if saturation: img = ImageEnhance.Color(img).enhance(1 + saturation)
-     if clarity:
-         arr = np.array(img).astype(np.float32)
-         arr = np.clip(arr * (1 + clarity), 0, 255).astype(np.uint8)
-         img = Image.fromarray(arr)
-     if warmth:
-         r, g, b = img.split()
-         r = r.point(lambda i: min(255, i*(1+warmth)))
-         b = b.point(lambda i: min(255, i*(1-warmth)))
-         img = Image.merge("RGB",(r,g,b))
-     return img

- def process(image):
-     enhanced = enhance_image(image)
-     label = classify_image(enhanced)
-     vibes = CATEGORY_VIBES.get(label, CATEGORY_VIBES["default"])
-
-     best, best_score, best_name = None, -float("inf"), None
-     for vibe in vibes:
-         out = apply_adjustments(enhanced, **vibe)
-         score = aesthetic_score(out)
-         if score > best_score:
-             best, best_score, best_name = out, score, vibe["name"]

-     return best, f"Classified as {label} → Chosen style: {best_name} (score {best_score:.2f})"

- demo = gr.Interface(
-     fn=process,
-     inputs=gr.Image(type="pil"),
-     outputs=[gr.Image(type="pil"), gr.Text()],
-     title="Content-Aware Aesthetic AI (Public)",
-     description="Enhance → classify → apply category vibes → score with public aesthetic model"
- )
- if __name__ == "__main__":
-     demo.launch()
import gradio as gr
import torch
+ import torch.nn as nn
+ from transformers import CLIPProcessor, CLIPModel
+ from PIL import Image
+ import numpy as np

+ # -----------------------------
+ # 1. Zero-DCE model (light enhancement)
+ # -----------------------------
+ class ZeroDCE(nn.Module):
+     def __init__(self):
+         super(ZeroDCE, self).__init__()
+         self.conv1 = nn.Conv2d(3, 32, 3, padding=1)
+         self.conv2 = nn.Conv2d(32, 32, 3, padding=1)
+         self.conv3 = nn.Conv2d(32, 32, 3, padding=1)
+         self.conv4 = nn.Conv2d(32, 24, 3, padding=1)
+         self.relu = nn.ReLU(inplace=True)

+     def forward(self, x):
+         x1 = self.relu(self.conv1(x))
+         x2 = self.relu(self.conv2(x1))
+         x3 = self.relu(self.conv3(x2))
+         x_r = torch.tanh(self.conv4(x3))
+         return x_r

+ def enhance_image(img, model):
+     img_tensor = torch.from_numpy(np.array(img)).float() / 255.0
+     img_tensor = img_tensor.permute(2, 0, 1).unsqueeze(0).to(device)
    with torch.no_grad():
+         enhanced = model(img_tensor)
+     enhanced = enhanced.squeeze(0).permute(1, 2, 0).cpu().numpy()
+     enhanced = np.clip(enhanced * 255, 0, 255).astype(np.uint8)
+     return Image.fromarray(enhanced)

+ # -----------------------------
+ # 2. Aesthetic Scoring Model
+ # -----------------------------
+ class AestheticPredictor(nn.Module):
+     def __init__(self):
+         super().__init__()
+         self.clip = CLIPModel.from_pretrained("openai/clip-vit-base-patch16")
+         self.mlp = nn.Sequential(
+             nn.Linear(self.clip.config.projection_dim, 512),
+             nn.ReLU(),
+             nn.Linear(512, 1)
+         )

+     def forward(self, pixel_values, input_ids, attention_mask):
+         outputs = self.clip(pixel_values=pixel_values, input_ids=input_ids, attention_mask=attention_mask)
+         pooled_output = outputs.pooler_output
+         return self.mlp(pooled_output)

+ def score_image(image, processor, model):
+     inputs = processor(text=["aesthetic photo"], images=image, return_tensors="pt", padding=True).to(device)
    with torch.no_grad():
+         score = model(**inputs)
+     return score.item()

+ # -----------------------------
+ # 3. Pipeline function
+ # -----------------------------
+ def process_image(input_img):
+     # Step 1: enhance
+     enhanced_img = enhance_image(input_img, zero_dce)

+     # Step 2: aesthetic scoring
+     original_score = score_image(input_img, processor, ae_model)
+     enhanced_score = score_image(enhanced_img, processor, ae_model)

+     # Step 3: choose best
+     if enhanced_score > original_score:
+         return enhanced_img, f"Enhanced chosen (score {enhanced_score:.2f} vs {original_score:.2f})"
+     else:
+         return input_img, f"Original kept (score {original_score:.2f} vs {enhanced_score:.2f})"
+
+ # -----------------------------
+ # 4. Setup
+ # -----------------------------
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+ zero_dce = ZeroDCE().to(device)
+ ae_model = AestheticPredictor().to(device)
+ processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch16")
+
+ # -----------------------------
+ # 5. Gradio UI
+ # -----------------------------
+ with gr.Blocks() as demo:
+     gr.Markdown("## 📸 AI Photo Enhancer with Aesthetic Scoring")
+
+     with gr.Row():
+         inp = gr.Image(type="pil", label="Upload your photo")
+         out = gr.Image(type="pil", label="Best looking result")

+     info = gr.Label(label="Result Info")
+
+     btn = gr.Button("Enhance ✨")
+     btn.click(process_image, inputs=inp, outputs=[out, info])

+ demo.launch()
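
In the new enhance_image, the raw network output is converted straight back to an image; in the original Zero-DCE formulation, though, the 24 channels produced by conv4 are eight iterations of per-pixel RGB curve maps, not pixel values. A minimal sketch of how those curve maps are typically applied, assuming the 8 x 3-channel layout used by the reference implementation (apply_curve_maps is an illustrative helper, not part of app.py):

import torch

def apply_curve_maps(x, x_r, iterations=8):
    # x:   (1, 3, H, W) input image scaled to [0, 1]
    # x_r: (1, 24, H, W) curve maps in [-1, 1], split into 8 RGB maps
    curve_maps = torch.split(x_r, 3, dim=1)
    for r in curve_maps[:iterations]:
        # Zero-DCE light-enhancement curve: LE(x) = x + r * (x^2 - x)
        x = x + r * (x ** 2 - x)
    return x.clamp(0, 1)

Under that reading, enhance_image would run apply_curve_maps(img_tensor, model(img_tensor)) and convert the result to a PIL image; without pretrained Zero-DCE weights the predicted curves remain random in either case.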
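
For the aesthetic scorer (section 2), AestheticPredictor.forward reads outputs.pooler_output, but the combined CLIPModel output in current transformers releases keeps its pooled vectors inside text_model_output and vision_model_output rather than exposing a top-level pooler_output, and the MLP head starts from random weights, so its scores are not yet meaningful. One possible image-only variant built on CLIPModel.get_image_features, whose output width matches projection_dim (ImageAestheticScorer is an illustrative name, and loading trained head weights is assumed to happen separately):

import torch
import torch.nn as nn
from transformers import CLIPModel, CLIPProcessor

class ImageAestheticScorer(nn.Module):
    def __init__(self, clip_name="openai/clip-vit-base-patch16"):
        super().__init__()
        self.clip = CLIPModel.from_pretrained(clip_name)
        self.head = nn.Sequential(
            nn.Linear(self.clip.config.projection_dim, 512),
            nn.ReLU(),
            nn.Linear(512, 1),
        )

    @torch.no_grad()
    def score(self, image, processor):
        # Embed the image with CLIP's projected image features, then score it
        device = next(self.parameters()).device
        pixel_values = processor(images=image, return_tensors="pt")["pixel_values"].to(device)
        feats = self.clip.get_image_features(pixel_values=pixel_values)  # (1, projection_dim)
        return self.head(feats).item()

In process_image this would stand in for score_image, e.g. original_score = scorer.score(input_img, processor), applied identically to the original and the enhanced image.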