import random
import torch
from tqdm import tqdm
from functools import partialmethod
import gradio as gr
from gradio.mix import Series
from transformers import FSMTForConditionalGeneration, FSMTTokenizer
from rudalle.pipelines import generate_images
from rudalle import get_rudalle_model, get_tokenizer, get_vae

# disable tqdm progress bars emitted by the rudalle pipeline
tqdm.__init__ = partialmethod(tqdm.__init__, disable=True)

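# Pick the GPU when available, then load the EN->RU translation model (fp16),
# the ruDALL-E "Malevich" generator, its text tokenizer, and the VAE image decoder.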
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
translation_model = FSMTForConditionalGeneration.from_pretrained("facebook/wmt19-en-ru", torch_dtype=torch.float16).half().to(device)
translation_tokenizer = FSMTTokenizer.from_pretrained("facebook/wmt19-en-ru")
dalle = get_rudalle_model("Malevich", pretrained=True, fp16=True, device=device)
tokenizer = get_tokenizer()
vae = get_vae().to(device)

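# Translate the English prompt into Russian, which ruDALL-E expects.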
def translation_wrapper(text: str):
    input_ids = translation_tokenizer.encode(text, return_tensors="pt")
    outputs = translation_model.generate(input_ids.to(device))
    decoded = translation_tokenizer.decode(outputs[0], skip_special_tokens=True)
    return decoded

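# Generate a single image for the (Russian) prompt, picking top_k/top_p at random
# from a few presets so repeated submissions produce different samples.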
def dalle_wrapper(prompt: str):
    top_k, top_p = random.choice([
        (1024, 0.98),
        (512, 0.97),
        (384, 0.96),
    ])
    
    images, _ = generate_images(
        prompt,
        tokenizer,
        dalle,
        vae,
        top_k=top_k,
        images_num=1,
        top_p=top_p,
    )
    title = f"<b>{prompt}</b>"
    return title, images[0]


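# Wrap each step in its own Gradio Interface so they can be chained below.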
translator = gr.Interface(fn=translation_wrapper, 
                          inputs=[gr.inputs.Textbox(label='What would you like to see?')],
                          outputs="text")
outputs = [
    gr.outputs.HTML(label=""),   
    gr.outputs.Image(label=""),
]
generator = gr.Interface(fn=dalle_wrapper, inputs="text", outputs=outputs)


description = (
    "ruDALL-E is a 1.3B params text-to-image model by SberAI (links at the bottom). "
    "This demo uses an English-Russian translation model to adapt the prompts. "
    "Try pressing [Submit] multiple times to generate new images!"
)
article = (
    "<p style='text-align: center'>"
    "<a href='https://github.com/sberbank-ai/ru-dalle'>GitHub</a> | "
    "<a href='https://habr.com/ru/company/sberbank/blog/586926/'>Article (in Russian)</a>"
    "</p>"
)
examples = [["A still life of grapes and a bottle of wine"], 
            ["Город в стиле киберпанк"], 
            ["A colorful photo of a coral reef"], 
            ["A white cat sitting in a cardboard box"]]
            
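# Chain the two interfaces with gradio.mix.Series: the translated prompt is fed straight into the generator.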
series = Series(translator, generator, 
                title='Kinda-English ruDALL-E',
                description=description,
                article=article,
                layout='horizontal',
                theme='huggingface',
                examples=examples,
                allow_flagging=False,
                live=False, 
                enable_queue=True,
               )
series.launch()