File size: 4,554 Bytes
db177eb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39cae37
 
 
db177eb
 
 
 
401fb48
 
db177eb
 
68ed1d2
 
d171212
db177eb
 
 
 
 
 
 
 
 
39cae37
db177eb
 
68ed1d2
 
 
db177eb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
# -*- coding: utf-8 -*-
"""Stylegan-nada-ailanta.ipynb

Automatically generated by Colab.

Original file is located at
    https://colab.research.google.com/drive/1ysq4Y2sv7WTE0sW-n5W_HSgE28vaUDNE

# Проект "CLIP-Guided Domain Adaptation of Image Generators"

Данный проект представляет собой имплементацию подхода StyleGAN-NADA, предложенного в статье [StyleGAN-NADA: CLIP-Guided Domain Adaptation of Image Generators](https://arxiv.org/pdf/2108.00946).
Представленный ниже функционал предназначен для визуализации реализованного проекта и включает в себя:
- Сдвиг генератора по текстовому промпту
- Генерация примеров
- Генерация примеров из готовых пресетов
- Веб-демо
- Стилизация изображения из файла

## 1. Установка
"""

# @title
# Импорт нужных библиотек
import os
import sys
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms
from torchvision.utils import save_image
from PIL import Image
import numpy as np
import gradio as gr
import subprocess
import gdown

# Настройка устройства
device = "cuda" if torch.cuda.is_available() else "cpu"

if not os.path.exists("stylegan2-pytorch"):
    subprocess.run(["git", "clone", "https://github.com/rosinality/stylegan2-pytorch.git"])
os.chdir("stylegan2-pytorch")

gdown.download('https://drive.google.com/uc?id=1EM87UquaoQmk17Q8d5kYIAHqu0dkYqdT')
gdown.download('https://drive.google.com/uc?id=1N0MZSqPRJpLfP4mFQCS14ikrVSe8vQlL')
sys.path.append("/home/user/app/stylegan2-pytorch")
from model import Generator

# Параметры генератора
latent_dim = 512
f_generator = Generator(size=1024, style_dim=latent_dim, n_mlp=8).to(device)
state_dict = torch.load('stylegan2-ffhq-config-f.pt', map_location=device)
f_generator.load_state_dict(state_dict['g_ema'])
f_generator.eval()

# Загрузка пресетов
os.makedirs("/content/presets", exist_ok=True)

gdown.download('https://drive.google.com/uc?id=1trcBvlz7jeBRLNeCyNVCXE4esW25GPaZ', '/content/presets/sketch.pth')
gdown.download('https://drive.google.com/uc?id=1N4C-aTwxeOamZX2GeEElppsMv-ALKojL', '/content/presets/modigliani.pth')
gdown.download('https://drive.google.com/uc?id=1VZHEalFyEFGWIaHei98f9XPyHHvMBp6J', '/content/presets/werewolf.pth')

# Загрузка генератора из файла
def load_model(file_path, latent_dim=512, size=1024):

    state_dicts = torch.load(file_path, map_location=device)

    # Инициализация
    trained_generator = Generator(size=size, style_dim=latent_dim, n_mlp=8).to(device)

    # Загрузка весов
    trained_generator.load_state_dict(state_dicts)

    trained_generator.eval()

    return trained_generator

model_paths = {
    "Photo -> Pencil Sketch": "/content/presets/sketch.pth",
    "Photo -> Modigliani Painting": "/content/presets/modigliani.pth",
    "Human -> Werewolf": "/content/presets/werewolf.pth"
}

# Функция обработки
def generate(model_name):
    model_path = model_paths[model_name]
    g_generator = load_model(model_path)
    images = []
    with torch.no_grad():
        w_optimized = f_generator.style(torch.randn(2, latent_dim).to(device))
        w_plus = w_optimized.unsqueeze(1).repeat(1, f_generator.n_latent, 1).clone()

        frozen_images = f_generator(w_plus.unsqueeze(0), input_is_latent=True)[0]
        frozen_images = (frozen_images.clamp(-1, 1) + 1) / 2.0  # Нормализация к [0, 1]
        frozen_images = frozen_images.permute(0, 2, 3, 1).cpu().numpy()
        images.extend(frozen_images)
        trained_images = g_generator(w_plus.unsqueeze(0), input_is_latent=True)[0]
        trained_images = (trained_images.clamp(-1, 1) + 1) / 2.0  # Нормализация к [0, 1]
        trained_images = trained_images.permute(0, 2, 3, 1).cpu().numpy()
        images.extend(trained_images)
    return images

# Интерфейс
iface = gr.Interface(
    fn=generate,
    inputs=gr.Dropdown(choices=list(model_paths.keys()), label="Выберите пресет"),
    outputs=gr.Gallery(label="Результаты генерации", columns=2),
    title="Выбор модели",
    description="Выберите преобразование из списка."
)

iface.launch(debug=True)