ailanta commited on
Commit
9624440
·
verified ·
1 Parent(s): c65b861

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +154 -0
app.py ADDED
@@ -0,0 +1,154 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Stylegan-nada-ailanta.ipynb
2
+ Automatically generated by Colab.
3
+ Original file is located at
4
+ https://colab.research.google.com/drive/1ysq4Y2sv7WTE0sW-n5W_HSgE28vaUDNE
5
+ # Проект "CLIP-Guided Domain Adaptation of Image Generators"
6
+ Данный проект представляет собой имплементацию подхода StyleGAN-NADA, предложенного в статье [StyleGAN-NADA: CLIP-Guided Domain Adaptation of Image Generators](https://arxiv.org/pdf/2108.00946).
7
+ Представленный ниже функционал предназначен для визуализации реализованного проекта и включает в себя:
8
+ - Сдвиг генератора по текстовому промпту
9
+ - Генерация примеров
10
+ - Генерация примеров из готовых пресетов
11
+ - Веб-демо
12
+ - Стилизация изображения из файла
13
+ ## 1. Установка
14
+ """
15
+
16
+ # @title
17
+ # Импорт нужных библиотек
18
+ import os
19
+ import sys
20
+ from tqdm import tqdm
21
+ import torch
22
+ import torch.nn as nn
23
+ import torch.optim as optim
24
+ from torchvision import transforms
25
+ from torchvision.utils import save_image
26
+ from PIL import Image
27
+ import numpy as np
28
+ import matplotlib.pyplot as plt
29
+
30
+ # Настройка устройства
31
+ device = "cuda" if torch.cuda.is_available() else "cpu"
32
+
33
+ import os
34
+ import subprocess
35
+
36
+ if not os.path.exists("stylegan2-pytorch"):
37
+ subprocess.run(["git", "clone", "https://github.com/rosinality/stylegan2-pytorch.git"])
38
+ os.chdir("stylegan2-pytorch")
39
+
40
+ import gdown
41
+ gdown.download('https://drive.google.com/uc?id=1EM87UquaoQmk17Q8d5kYIAHqu0dkYqdT')
42
+ gdown.download('https://drive.google.com/uc?id=1N0MZSqPRJpLfP4mFQCS14ikrVSe8vQlL')
43
+ sys.path.append(os.path.abspath("stylegan2-pytorch"))
44
+ from model import Generator
45
+
46
+ # Параметры генератора
47
+ latent_dim = 512
48
+ f_generator = Generator(size=1024, style_dim=latent_dim, n_mlp=8).to(device)
49
+ state_dict = torch.load('stylegan2-ffhq-config-f.pt', map_location=device)
50
+ f_generator.load_state_dict(state_dict['g_ema'])
51
+ f_generator.eval()
52
+
53
+ g_generator = Generator(size=1024, style_dim=latent_dim, n_mlp=8).to(device)
54
+ g_generator.load_state_dict(state_dict['g_ema'])
55
+
56
+ # Загрузка модели CLIP
57
+ import clip
58
+ clip_model, preprocess = clip.load("ViT-B/32", device=device)
59
+
60
+ latent_dim=512
61
+ batch_size=4
62
+
63
+ """## 6. Готовые пресеты"""
64
+
65
+ # @title Загрузка пресетов
66
+ os.makedirs("/content/presets", exist_ok=True)
67
+
68
+ gdown.download('https://drive.google.com/uc?id=1trcBvlz7jeBRLNeCyNVCXE4esW25GPaZ', '/content/presets/sketch.pth')
69
+ gdown.download('https://drive.google.com/uc?id=1N4C-aTwxeOamZX2GeEElppsMv-ALKojL', '/content/presets/modigliani.pth')
70
+ gdown.download('https://drive.google.com/uc?id=1VZHEalFyEFGWIaHei98f9XPyHHvMBp6J', '/content/presets/werewolf.pth')
71
+
72
+ # @title Генерация примеров из пресета
73
+ # Загрузка генератора из файла
74
+ def load_model(file_path, latent_dim=512, size=1024):
75
+
76
+ state_dicts = torch.load(file_path, map_location=device)
77
+
78
+ # Инициализация
79
+ trained_generator = Generator(size=size, style_dim=latent_dim, n_mlp=8).to(device)
80
+
81
+ # Загрузка весов
82
+ trained_generator.load_state_dict(state_dicts)
83
+
84
+ trained_generator.eval()
85
+
86
+ return trained_generator
87
+
88
+ model_paths = {
89
+ "Photo -> Pencil Sketch": "/content/presets/sketch.pth",
90
+ "Photo -> Modigliani Painting": "/content/presets/modigliani.pth",
91
+ "Human -> Werewolf": "/content/presets/werewolf.pth"
92
+ }
93
+
94
+
95
+ """## 8. Веб-демо"""
96
+
97
+ import gradio as gr
98
+
99
+ def get_avg_image(net):
100
+ avg_image = net(net.latent_avg.unsqueeze(0),
101
+ input_code=True,
102
+ randomize_noise=False,
103
+ return_latents=False,
104
+ average_code=True)[0]
105
+ avg_image = avg_image.to('cuda').float().detach()
106
+ return avg_image
107
+
108
+ # Функция обработки изображения
109
+ def process_image(image):
110
+ # Конвертация в объект PIL
111
+ image = Image.fromarray(image)
112
+
113
+ # Изменение размера до 256x256
114
+ image = image.resize((256, 256))
115
+
116
+ input_image = transform(image).unsqueeze(0).to(device)
117
+
118
+ opts.n_iters_per_batch = 5
119
+ opts.resize_outputs = False # generate outputs at full resolution
120
+
121
+ from restyle.utils.inference_utils import run_on_batch
122
+
123
+ with torch.no_grad():
124
+ avg_image = get_avg_image(restyle_net)
125
+ result_batch, result_latents = run_on_batch(input_image, restyle_net, opts, avg_image)
126
+
127
+ inverted_latent = torch.Tensor(result_latents[0][4]).cuda().unsqueeze(0).unsqueeze(1)
128
+
129
+ with torch.no_grad():
130
+ sampled_src = f_generator(inverted_latent, input_is_latent=True)[0]
131
+ frozen_image = (sampled_src.clamp(-1, 1) + 1) / 2.0 # Нормализация к [0, 1]
132
+ frozen_image = frozen_image.permute(0, 2, 3, 1).cpu().numpy()
133
+
134
+ g_generator.eval()
135
+
136
+ sampled_src = g_generator(inverted_latent, input_is_latent=True)[0]
137
+ trained_image = (sampled_src.clamp(-1, 1) + 1) / 2.0 # Нормализация к [0, 1]
138
+ trained_image = trained_image.permute(0, 2, 3, 1).cpu().numpy()
139
+ images = []
140
+ images.append(image)
141
+ images.append(frozen_image.squeeze(0))
142
+ images.append(trained_image.squeeze(0))
143
+ return images
144
+
145
+ # Интерфейс Gradio
146
+ iface = gr.Interface(
147
+ fn=process_image, # Функция обработки
148
+ inputs=gr.Image(type="numpy"), # Поле для загрузки изображения
149
+ outputs=gr.Gallery(label="Результаты генерации", columns=2),
150
+ title="Обработка изображения",
151
+ description="Загрузите изображение"
152
+ )
153
+
154
+ iface.launch()