import html
import random
from functools import partialmethod

import gradio as gr
import torch
from gradio.mix import Series
from rudalle import get_rudalle_model, get_tokenizer, get_vae
from rudalle.pipelines import generate_images
from tqdm import tqdm
from transformers import pipeline, FSMTForConditionalGeneration, FSMTTokenizer

# disable tqdm logging from the rudalle pipeline
tqdm.__init__ = partialmethod(tqdm.__init__, disable=True)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Half precision is only used on CUDA; on a CPU-only host fall back to fp32
# so that model loading and generation do not depend on CPU fp16 support.
use_fp16 = device.type == "cuda"

translation_model = FSMTForConditionalGeneration.from_pretrained(
    "facebook/wmt19-en-ru",
    torch_dtype=torch.float16 if use_fp16 else torch.float32,
).to(device)
translation_tokenizer = FSMTTokenizer.from_pretrained("facebook/wmt19-en-ru")

dalle = get_rudalle_model("Malevich", pretrained=True, fp16=use_fp16, device=device)
tokenizer = get_tokenizer()
vae = get_vae().to(device)


def translation_wrapper(text: str) -> str:
    """Translate ``text`` with the ``facebook/wmt19-en-ru`` model.

    Args:
        text: The prompt typed by the user.

    Returns:
        The decoded translation with special tokens stripped.
    """
    input_ids = translation_tokenizer.encode(text, return_tensors="pt")
    # The model lives on ``device``; its inputs must be on the same device,
    # otherwise ``generate`` fails with a device mismatch on CUDA hosts.
    generated = translation_model.generate(input_ids.to(device))
    return translation_tokenizer.decode(generated[0], skip_special_tokens=True)


def dalle_wrapper(prompt: str):
    """Generate one image for ``prompt`` with ruDALL-E.

    A ``(top_k, top_p)`` sampling pair is picked at random on every call, so
    repeated submissions of the same prompt yield different images.

    Args:
        prompt: Text prompt handed to ``generate_images``.

    Returns:
        A ``(title, image)`` pair: the HTML-escaped prompt and the first
        generated image.
    """
    top_k, top_p = random.choice([
        (1024, 0.98),
        (512, 0.97),
        (384, 0.96),
    ])
    # Only the images are used; the second return value is discarded.
    images, _ = generate_images(
        prompt,
        tokenizer,
        dalle,
        vae,
        top_k=top_k,
        images_num=1,
        top_p=top_p,
    )
    # The title is rendered by an HTML output component, so escape it.
    title = html.escape(f"{prompt}")
    return title, images[0]


translator = gr.Interface(
    fn=translation_wrapper,
    inputs=[gr.inputs.Textbox(label='What would you like to see?')],
    outputs="text",
)

outputs = [
    gr.outputs.HTML(label=""),
    gr.outputs.Image(label=""),
]
generator = gr.Interface(fn=dalle_wrapper, inputs="text", outputs=outputs)

description = (
    "ruDALL-E is a 1.3B params text-to-image model by SberAI (links at the bottom). "
    "This demo uses an English-Russian translation model to adapt the prompts. "
    "Try pressing [Submit] multiple times to generate new images!"
)

# NOTE(review): the description promises "links at the bottom", but this text
# carries no markup or URLs -- they appear to have been lost. Restore the
# GitHub and article links here once the intended targets are confirmed.
article = (
    "\n\n"
    "GitHub | "
    "Article (in Russian)"
    "\n\n"
)

examples = [
    ["A still life of grapes and a bottle of wine"],
    ["Город в стиле киберпанк"],
    ["A colorful photo of a coral reef"],
    ["A white cat sitting in a cardboard box"],
]

# Chain the two interfaces: the translator's text output feeds the generator.
series = Series(
    translator,
    generator,
    title='Kinda-English ruDALL-E',
    description=description,
    article=article,
    layout='horizontal',
    theme='huggingface',
    examples=examples,
    allow_flagging=False,
    live=False,
    enable_queue=True,
)

series.launch()