umjunsik1323 committed on
Commit
a791b0d
·
verified ·
1 Parent(s): ee0a294

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +152 -0
  2. requirements.txt +6 -0
app.py ADDED
@@ -0,0 +1,152 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from torchvision.datasets import CIFAR100
4
+ from PIL import Image
5
+ import random
6
+ import numpy as np
7
+ from transformers import AutoImageProcessor, Swin2SRForImageSuperResolution
8
+
9
# Silence Pillow's DecompressionBombWarning (the warnings module is reached
# through PIL.Image, which imports it).
Image.warnings.simplefilter('ignore', Image.DecompressionBombWarning)

# Super-resolution model. On failure BOTH names are bound to None so the rest
# of the module can test availability without risking a NameError, and the
# reason is printed instead of being silently discarded.
try:
    sr_processor = AutoImageProcessor.from_pretrained("caidas/swin2SR-classical-sr-x4-64")
    sr_model = Swin2SRForImageSuperResolution.from_pretrained("caidas/swin2SR-classical-sr-x4-64")
    sr_model.eval()
except Exception as e:
    print(f"Failed to load the super-resolution model: {e}")
    sr_processor = None
    sr_model = None

# CIFAR-100 classifier fetched through torch.hub; None when loading fails.
try:
    classifier_model = torch.hub.load("chenyaofo/pytorch-cifar-models", "cifar100_resnet56", pretrained=True)
    classifier_model.eval()
except Exception as e:
    print(f"Failed to load the CIFAR-100 classifier: {e}")
    classifier_model = None

# CIFAR-100 split selected by train=False (downloaded on first run) and its
# class-name list.
cifar100_dataset = CIFAR100(root="./cifar100_data", train=False, download=True)
cifar100_labels = cifar100_dataset.classes
26
+
27
def upscale_image(low_res_pil_image):
    """Upscale a PIL image with the super-resolution model.

    Args:
        low_res_pil_image: PIL image to enlarge, or None.

    Returns:
        The super-resolved PIL image; a 400x400 nearest-neighbour resize of
        the input when the super-resolution model is unavailable; or None
        when no image was given.
    """
    # Check the image before anything else so a missing image never reaches
    # the .resize() fallback below.
    if low_res_pil_image is None:
        return None
    if sr_model is None:
        return low_res_pil_image.resize((400, 400), Image.Resampling.NEAREST)

    with torch.no_grad():
        inputs = sr_processor(low_res_pil_image, return_tensors="pt")
        outputs = sr_model(**inputs)

    # Drop singleton axes and clamp the reconstruction to [0, 1].
    output_tensor = outputs.reconstruction.data.squeeze().float().cpu().clamp_(0, 1)
    # Move axis 0 to the end (presumably channels-first -> channels-last,
    # which is what Image.fromarray needs -- TODO confirm model output layout).
    output_numpy = np.moveaxis(output_tensor.numpy(), 0, -1)
    output_image = (output_numpy * 255.0).round().astype(np.uint8)

    return Image.fromarray(output_image)
40
+
41
def predict_ai(low_res_pil_image):
    """Predict the class name of an image with the classifier model.

    Args:
        low_res_pil_image: PIL image to classify, or None.

    Returns:
        The predicted entry of ``cifar100_labels``, or the string "Error"
        when there is no image, the classifier is unavailable, or
        prediction fails for any reason.
    """
    # Explicit guards instead of relying on an exception from calling None.
    if low_res_pil_image is None or classifier_model is None:
        return "Error"

    try:
        from torchvision import transforms

        preprocess_for_classifier = transforms.Compose([
            transforms.ToTensor(),
            # Per-channel mean/std; presumably the CIFAR-100 statistics the
            # classifier was trained with -- TODO confirm against the model.
            transforms.Normalize(
                mean=[0.5071, 0.4867, 0.4408],
                std=[0.2675, 0.2565, 0.2761],
            ),
        ])
        img_t = preprocess_for_classifier(low_res_pil_image.convert("RGB"))
        batch_t = img_t.unsqueeze(0)  # add the batch axis

        with torch.no_grad():
            out = classifier_model(batch_t)

        # Convert to a plain int so the list lookup does not depend on
        # tensor-to-index coercion.
        index = int(out.argmax(dim=1)[0])
        return cifar100_labels[index]
    except Exception as e:
        # Best-effort: the UI shows "Error", but the cause is reported.
        print(f"Classifier prediction failed: {e}")
        return "Error"
61
+
62
def generate_category_markdown(labels=None, columns=4):
    """Render class labels as a Markdown table with an empty header row.

    Args:
        labels: Sequence of label strings. Defaults to ``cifar100_labels``.
        columns: Number of labels per table row (default 4).

    Returns:
        The Markdown source: a blank header row, a left-aligned separator
        row, then one row per group of ``columns`` labels.

    Raises:
        ValueError: If ``columns`` is smaller than 1.
    """
    if columns < 1:
        raise ValueError("columns must be at least 1")
    if labels is None:
        labels = cifar100_labels

    # With columns=4 this is "|||||\n|:---|:---|:---|:---|\n".
    header = "|" * (columns + 1) + "\n" + "|:---" * columns + "|\n"
    rows = [
        "| " + " | ".join(labels[start:start + columns]) + " |\n"
        for start in range(0, len(labels), columns)
    ]
    return header + "".join(rows)
68
+
69
# --- 4. Game logic ---
def battle(user_guess, state):
    """Score the current round and advance to the next unplayed image.

    Args:
        user_guess: Text typed by the player (may be None or empty).
        state: Game state dict with keys "user_score", "ai_score",
            "current_image_idx" and "played_indices"; None or empty when the
            game could not be started.

    Returns:
        A 6-tuple ``(user_score, ai_score, message, "", next_image, state)``.
        The empty string clears the guess textbox; ``next_image`` is None
        when there is nothing further to show.
    """
    # No usable state (e.g. start-up failed): report it instead of raising.
    if not state:
        return 0, 0, "The game is not ready. Please reload the page.", "", None, state

    user_score = state["user_score"]
    ai_score = state["ai_score"]
    current_image_idx = state["current_image_idx"]
    # Work on a copy so the caller's state object is never mutated.
    played_indices = set(state["played_indices"])

    # A finished game stores no current image; do not score it again.
    if current_image_idx is None:
        return user_score, ai_score, "All images have been played! Game Over.", "", None, state

    low_res_image, label_idx = cifar100_dataset[current_image_idx]
    current_label = cifar100_labels[label_idx]
    expected = current_label.lower()

    ai_guess = predict_ai(low_res_image)

    if (user_guess or "").lower().strip() == expected:
        user_score += 1
    if ai_guess.lower() == expected:
        ai_score += 1

    message = f"AI's Guess: '{ai_guess}'\nCorrect Answer: '{current_label}'"
    dataset_size = len(cifar100_dataset)

    if len(played_indices) >= dataset_size:
        message += "\n\nAll images have been played! Game Over."
        next_image_idx = None  # marks the game as finished for later calls
        next_high_res_image = None
    else:
        # Rejection-sample until an image that has not been played is found.
        next_image_idx = random.randrange(dataset_size)
        while next_image_idx in played_indices:
            next_image_idx = random.randrange(dataset_size)
        played_indices.add(next_image_idx)

        next_low_res_image, _ = cifar100_dataset[next_image_idx]
        next_high_res_image = upscale_image(next_low_res_image)

    new_state = {
        "user_score": user_score,
        "ai_score": ai_score,
        "current_image_idx": next_image_idx,
        "played_indices": played_indices,
    }

    return user_score, ai_score, message, "", next_high_res_image, new_state
109
+
110
def start_game():
    """Start a new game with one randomly chosen image.

    Returns:
        A 6-tuple ``(user_score, ai_score, message, "", image, state)``.
        When either model is unavailable, the scores are 0, the image is
        None and the state is an empty dict.
    """
    # Identity check (not truthiness), matching how upscale_image tests
    # model availability.
    if classifier_model is None or sr_model is None:
        return 0, 0, "A required AI model failed to load. Please restart.", "", None, {}

    first_idx = random.randrange(len(cifar100_dataset))
    first_low_res_image, _ = cifar100_dataset[first_idx]
    first_high_res_image = upscale_image(first_low_res_image)

    initial_state = {
        "user_score": 0,
        "ai_score": 0,
        "current_image_idx": first_idx,
        # The first image counts as played so it is not drawn again.
        "played_indices": {first_idx},
    }
    return 0, 0, "Game Start! What is this high-resolution image?", "", first_high_res_image, initial_state
125
+
126
# --- 5. Gradio interface ---
with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue", secondary_hue="sky")) as demo:
    # Game state carried between events.
    state = gr.State()

    gr.Markdown("<h1>Human vs. AI: Super-Resolution Battle</h1>")
    gr.Markdown("A Swin2SR AI has upscaled a 32x32 image for you. Can you guess it before the other AI, which only sees the original low-res image?")

    # Scoreboard.
    with gr.Row():
        user_score_display = gr.Number(label="Your Score", value=0, interactive=False)
        ai_score_display = gr.Number(label="AI Score", value=0, interactive=False)

    # Left: image and round result. Right: guess input and category list.
    with gr.Row(equal_height=False):
        with gr.Column(scale=2):
            image_display = gr.Image(
                label="Guess this upscaled image!",
                type="pil",
                height=400,
                width=400,
                interactive=False,
            )
            result_display = gr.Textbox(label="Round Result", interactive=False, lines=3)
        with gr.Column(scale=1):
            user_input = gr.Textbox(
                label="What is this image?",
                placeholder="e.g., apple, bicycle, cloud...",
            )
            submit_button = gr.Button("Submit Guess", variant="primary")
            with gr.Accordion("View All 100 Categories", open=False):
                gr.Markdown(generate_category_markdown())

    # Every handler fills the same six outputs, in this order.
    game_outputs = [
        user_score_display,
        ai_score_display,
        result_display,
        user_input,
        image_display,
        state,
    ]
    guess_inputs = [user_input, state]

    # Clicking the button and pressing Enter in the textbox both play a round.
    for register in (submit_button.click, user_input.submit):
        register(fn=battle, inputs=guess_inputs, outputs=game_outputs)
    # A fresh game starts whenever the page loads.
    demo.load(fn=start_game, inputs=None, outputs=game_outputs)

if __name__ == "__main__":
    demo.launch()
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ gradio
2
+ torch
3
+ torchvision
4
+ Pillow
5
+ transformers
6
+ accelerate