# app.py — Hugging Face Space by csaguiar.
# Last commit: 6557a3f ("Update app.py"); file size 2.41 kB.
import os
import torch
import streamlit as st
from diffusers import StableDiffusionPipeline
from transformers import MBart50TokenizerFast, MBartForConditionalGeneration
# Hugging Face Hub checkpoints: the text-to-image model, and the
# Portuguese-to-English model used to translate prompts before rendering.
DIFFUSION_MODEL_ID = "runwayml/stable-diffusion-v1-5"
TRANSLATION_MODEL_ID = (
    "Narrativa/mbart-large-50-finetuned-opus-pt-en-translation"
)

# Runtime configuration taken from the environment: the torch device the
# diffusion pipeline is moved to (default "cpu") and an optional Hub token
# (None when unset).
DEVICE_NAME = os.environ.get("DEVICE_NAME", "cpu")
HUGGING_FACE_TOKEN = os.environ.get("HUGGING_FACE_TOKEN")
def load_translation_models(translation_model_id):
    """Load the MBart-50 tokenizer and model used to translate prompts.

    Both are fetched with ``from_pretrained`` using ``HUGGING_FACE_TOKEN``
    for authentication. The tokenizer's source language is set to
    Portuguese (``pt_XX``).

    Args:
        translation_model_id: Hub identifier of the translation checkpoint.

    Returns:
        A ``(tokenizer, text_model)`` tuple.
    """
    hub_kwargs = {"use_auth_token": HUGGING_FACE_TOKEN}

    tokenizer = MBart50TokenizerFast.from_pretrained(
        translation_model_id, **hub_kwargs
    )
    # Prompts typed into the app are Portuguese.
    tokenizer.src_lang = "pt_XX"

    text_model = MBartForConditionalGeneration.from_pretrained(
        translation_model_id, **hub_kwargs
    )
    return tokenizer, text_model
def pipeline_generate(diffusion_model_id):
    """Load a Stable Diffusion pipeline and prepare it for inference.

    The pipeline is fetched with ``HUGGING_FACE_TOKEN``, moved to
    ``DEVICE_NAME``, and configured to use attention slicing.

    Args:
        diffusion_model_id: Hub identifier of the diffusion checkpoint.

    Returns:
        The ready-to-call ``StableDiffusionPipeline``.
    """
    pipe = StableDiffusionPipeline.from_pretrained(
        diffusion_model_id, use_auth_token=HUGGING_FACE_TOKEN
    ).to(DEVICE_NAME)

    # Attention slicing trades some speed for lower peak memory; the
    # original author recommends it for machines with < 64 GB of RAM.
    # (It configures the pipeline in place, so it is not chained above.)
    pipe.enable_attention_slicing()
    return pipe
def translate(prompt, tokenizer, text_model):
    """Translate a single Portuguese *prompt* into English.

    The prompt is tokenized as a batch of one. Beam search (8 beams,
    early stopping, at most 100 new tokens) generates the output, and the
    first decoded sequence is returned without special tokens.

    NOTE(review): no ``forced_bos_token_id`` is passed to ``generate``; this
    relies on the fine-tuned checkpoint's own generation config to emit
    English. Confirm against the model card.

    Args:
        prompt: Portuguese text to translate.
        tokenizer: tokenizer returned by ``load_translation_models``.
        text_model: seq2seq model returned by ``load_translation_models``.

    Returns:
        The English translation as a string.
    """
    encoded = tokenizer([prompt], return_tensors="pt")
    generation_kwargs = {
        "max_new_tokens": 100,
        "num_beams": 8,
        "early_stopping": True,
    }
    output_ids = text_model.generate(**encoded, **generation_kwargs)
    decoded = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
    return decoded[0]
def generate_image(pipe, prompt, warmup=None):
    """Render *prompt* with *pipe* and return the first generated image.

    Args:
        pipe: a loaded diffusers pipeline, e.g. from ``pipeline_generate``.
        prompt: English text prompt.
        warmup: whether to run a throwaway one-step "priming" pass first.
            ``None`` (the default) primes only when the pipeline lives on
            the ``mps`` device. That is the case diffusers recommends
            priming for. On other devices the extra pass is wasted work.

    Returns:
        ``pipe(prompt).images[0]``.
    """
    if warmup is None:
        device = getattr(pipe, "device", None)
        warmup = getattr(device, "type", None) == "mps"
    if warmup:
        # diffusers' MPS guide: with PyTorch 1.13 the first inference pass
        # on MPS can differ from later ones, so prime with a 1-step pass and
        # discard its output.
        pipe(prompt, num_inference_steps=1)
    return pipe(prompt).images[0]
# Process-wide resource cache: ``st.cache_resource`` on Streamlit >= 1.18,
# falling back to its predecessor ``st.experimental_singleton`` on older
# releases.
_cache_resource = getattr(st, "cache_resource", None) or st.experimental_singleton


@_cache_resource
def _load_models():
    """Load the translation models and diffusion pipeline once per process.

    Streamlit re-executes this script on every interaction. Without caching,
    every click reloaded all model weights from disk.

    Returns:
        ``(tokenizer, text_model, pipe, lock)``. The lock guards use of the
        shared objects.
    """
    import threading

    tokenizer, text_model = load_translation_models(TRANSLATION_MODEL_ID)
    pipe = pipeline_generate(DIFFUSION_MODEL_ID)
    return tokenizer, text_model, pipe, threading.Lock()


def process_prompt(prompt):
    """Translate a Portuguese *prompt* to English and render it as an image.

    Models are loaded on first use and reused for later requests.

    Args:
        prompt: user-supplied Portuguese description.

    Returns:
        The image produced by ``generate_image``.
    """
    tokenizer, text_model, pipe, lock = _load_models()
    # The cached objects are shared by every session. Diffusers schedulers
    # keep per-run state (e.g. timesteps), so concurrent calls on one
    # pipeline could interfere; run one request at a time.
    with lock:
        english_prompt = translate(prompt, tokenizer, text_model)
        return generate_image(pipe, english_prompt)
# --- Streamlit UI (the whole script re-runs on every interaction) ---
st.write("# Crie imagens com Stable Diffusion")
prompt_input = st.text_input("Escreva uma descrição da imagem")

# The "process" button sits in a placeholder so it can be swapped for a
# disabled copy while an image is being generated.
placeholder = st.empty()
btn = placeholder.button('Processar imagem', disabled=False, key=1)
reload = st.button('Reiniciar', disabled=False)

if btn:
    if not prompt_input.strip():
        # Avoid running the slow translate + diffusion path on an empty prompt.
        st.warning("Escreva uma descrição da imagem antes de processar.")
    else:
        placeholder.button('Processar imagem', disabled=True, key=2)
        image = process_prompt(prompt_input)
        st.image(image)
        # NOTE(review): the enabled button drawn here is cleared immediately
        # by placeholder.empty(), so after a run only "Reiniciar" remains
        # visible. Confirm whether re-enabling the button was intended.
        placeholder.button('Processar imagem', disabled=False, key=3)
        placeholder.empty()

if reload:
    # st.rerun replaced st.experimental_rerun in newer Streamlit releases;
    # use whichever this installation provides.
    (getattr(st, "rerun", None) or st.experimental_rerun)()