"""Streamlit app: translate a Portuguese prompt to English and draw it with Stable Diffusion."""

import os
import threading

import torch  # noqa: F401  (not referenced directly; diffusers/transformers need it installed)
import streamlit as st
from diffusers import StableDiffusionPipeline
from transformers import MBart50TokenizerFast, MBartForConditionalGeneration

DIFFUSION_MODEL_ID = "runwayml/stable-diffusion-v1-5"
TRANSLATION_MODEL_ID = "Narrativa/mbart-large-50-finetuned-opus-pt-en-translation"  # noqa
DEVICE_NAME = os.getenv("DEVICE_NAME", "cpu")
HUGGING_FACE_TOKEN = os.getenv("HUGGING_FACE_TOKEN")

# Streamlit re-executes this whole script on every widget interaction, so the
# expensive model loads must go through Streamlit's resource cache (one copy
# per server process). `st.cache_resource` (Streamlit >= 1.18) replaced
# `st.experimental_singleton`; use whichever this installation provides.
_cache_resource = getattr(st, "cache_resource", None) or st.experimental_singleton
# Likewise `st.rerun` (Streamlit >= 1.27) replaced `st.experimental_rerun`.
_rerun = getattr(st, "rerun", None) or st.experimental_rerun


@_cache_resource
def load_translation_models(translation_model_id):
    """Load the Portuguese->English MBart-50 tokenizer and model.

    Cached, so the weights are loaded once per server process instead of on
    every click.

    Args:
        translation_model_id: Hugging Face model id of the translation model.

    Returns:
        A ``(tokenizer, text_model)`` tuple.
    """
    # NOTE(review): recent transformers releases deprecate `use_auth_token`
    # in favour of `token`; kept for compatibility with older installs.
    tokenizer = MBart50TokenizerFast.from_pretrained(
        translation_model_id, use_auth_token=HUGGING_FACE_TOKEN
    )
    # MBart-50 language code of the input text: prompts are Portuguese.
    tokenizer.src_lang = 'pt_XX'
    text_model = MBartForConditionalGeneration.from_pretrained(
        translation_model_id, use_auth_token=HUGGING_FACE_TOKEN
    )
    return tokenizer, text_model


@_cache_resource
def pipeline_generate(diffusion_model_id):
    """Load the Stable Diffusion pipeline onto ``DEVICE_NAME`` (cached).

    Args:
        diffusion_model_id: Hugging Face model id of the diffusion model.

    Returns:
        The ready-to-use ``StableDiffusionPipeline``.
    """
    pipe = StableDiffusionPipeline.from_pretrained(
        diffusion_model_id, use_auth_token=HUGGING_FACE_TOKEN
    )
    pipe = pipe.to(DEVICE_NAME)
    # Recommended if your computer has < 64 GB of RAM
    pipe.enable_attention_slicing()
    if DEVICE_NAME.startswith("mps"):
        # The diffusers Apple Silicon guide recommends one throw-away 1-step
        # "warmup" pass (PyTorch 1.13 on mps), because the very first pass can
        # produce slightly different results. It is needed only once per
        # pipeline, so it runs here at load time rather than per image.
        _ = pipe("warmup", num_inference_steps=1)
    return pipe


@_cache_resource
def _generation_lock():
    """Return the process-wide lock that serializes use of the cached models.

    The cached tokenizer/model/pipeline are shared by every session, which
    Streamlit runs on separate threads. The diffusion scheduler mutates its
    own state during a run, so concurrent generations must not overlap.
    """
    return threading.Lock()


def translate(prompt, tokenizer, text_model):
    """Translate a Portuguese *prompt* to English using beam search.

    Returns:
        The decoded English string.
    """
    pt_tokens = tokenizer([prompt], return_tensors="pt")
    en_tokens = text_model.generate(
        **pt_tokens, max_new_tokens=100, num_beams=8, early_stopping=True
    )
    en_prompt = tokenizer.batch_decode(en_tokens, skip_special_tokens=True)
    return en_prompt[0]


def generate_image(pipe, prompt):
    """Run *pipe* on an English *prompt* and return the first generated image."""
    return pipe(prompt).images[0]


def process_prompt(prompt):
    """Translate a Portuguese *prompt* to English and generate an image for it."""
    tokenizer, text_model = load_translation_models(TRANSLATION_MODEL_ID)
    pipe = pipeline_generate(DIFFUSION_MODEL_ID)
    # The models are shared across sessions; serialize their use.
    with _generation_lock():
        en_prompt = translate(prompt, tokenizer, text_model)
        return generate_image(pipe, en_prompt)


def main():
    """Render the page: prompt input, "process" button and "reset" button."""
    st.write("# Crie imagens com Stable Diffusion")
    prompt_input = st.text_input("Escreva uma descrição da imagem")
    # The process button sits in a placeholder so that it can be swapped for
    # a disabled copy while an image is being generated.
    placeholder = st.empty()
    btn = placeholder.button('Processar imagem', disabled=False, key=1)
    reload = st.button('Reiniciar', disabled=False)

    if btn:
        prompt = prompt_input.strip()
        if not prompt:
            # Nothing to translate: don't spend a full diffusion run on it.
            st.warning("Escreva uma descrição antes de processar a imagem.")
        else:
            placeholder.button('Processar imagem', disabled=True, key=2)
            image = process_prompt(prompt)
            st.image(image)
            # Clear the button once the image is shown; the next script run
            # (e.g. after "Reiniciar") renders it again.
            placeholder.empty()

    if reload:
        _rerun()


if __name__ == "__main__":
    main()