# Page header captured with the file (not part of the program):
# last commit a8d3972 "adding token access" by csaguiar; file size 2.45 kB.
import os
import torch
import streamlit as st
from diffusers import StableDiffusionPipeline
from transformers import MBart50TokenizerFast, MBartForConditionalGeneration
# Identifier of the Stable Diffusion checkpoint handed to
# StableDiffusionPipeline.from_pretrained.
DIFFUSION_MODEL_ID = "runwayml/stable-diffusion-v1-5"
# Identifier of the MBart checkpoint used to translate the prompt
# (Portuguese -> English, per the model name).
TRANSLATION_MODEL_ID = "Narrativa/mbart-large-50-finetuned-opus-pt-en-translation" # noqa
# Device the diffusion pipeline is moved to; taken from the environment,
# "cpu" when the variable is unset.
DEVICE_NAME = os.getenv("DEVICE_NAME", "cpu")
# Token passed as `use_auth_token` when loading models; None when the
# environment variable is unset.
HUGGING_FACE_TOKEN = os.getenv("HUGGING_FACE_TOKEN")
def load_translation_models(translation_model_id):
    """Load the tokenizer and model used to translate prompts.

    Both objects are created with ``from_pretrained`` and authenticate with
    the module-level ``HUGGING_FACE_TOKEN``. The tokenizer's source language
    is set to ``pt_XX`` before the model is loaded.

    Args:
        translation_model_id: Model identifier passed to ``from_pretrained``.

    Returns:
        A ``(tokenizer, text_model)`` tuple.
    """
    # Same credentials for both downloads
    auth_kwargs = {"use_auth_token": HUGGING_FACE_TOKEN}

    pt_tokenizer = MBart50TokenizerFast.from_pretrained(
        translation_model_id, **auth_kwargs
    )
    pt_tokenizer.src_lang = 'pt_XX'

    translator = MBartForConditionalGeneration.from_pretrained(
        translation_model_id, **auth_kwargs
    )
    return pt_tokenizer, translator
def pipeline_generate(diffusion_model_id):
    """Build the Stable Diffusion pipeline and move it to ``DEVICE_NAME``.

    Weights are fetched from the ``fp16`` revision. They stay in half
    precision on accelerator devices, but are loaded as float32 when the
    target device is the CPU (the default for ``DEVICE_NAME``): a float16
    pipeline cannot run inference on CPU because PyTorch lacks half-precision
    kernels for several CPU operations.

    Args:
        diffusion_model_id: Model identifier passed to ``from_pretrained``.

    Returns:
        The pipeline, placed on ``DEVICE_NAME`` with attention slicing
        enabled.
    """
    # torch.device normalises spellings such as "cpu" and "cpu:0"
    on_cpu = torch.device(DEVICE_NAME).type == "cpu"
    pipe = StableDiffusionPipeline.from_pretrained(
        diffusion_model_id,
        torch_dtype=torch.float32 if on_cpu else torch.float16,
        revision="fp16",
        use_auth_token=HUGGING_FACE_TOKEN,
    )
    pipe = pipe.to(DEVICE_NAME)
    # Recommended if your computer has < 64 GB of RAM
    pipe.enable_attention_slicing()
    return pipe
def translate(prompt, tokenizer, text_model):
    """Translate a single prompt and return the decoded text.

    The prompt is tokenized as a one-element batch, passed through
    ``text_model.generate`` using beam search, and decoded with special
    tokens removed.

    Args:
        prompt: Text to translate.
        tokenizer: Callable tokenizer that also provides ``batch_decode``.
        text_model: Model exposing ``generate``.

    Returns:
        The first (and only) decoded sequence.
    """
    generation_options = {
        "max_new_tokens": 100,
        "num_beams": 8,
        "early_stopping": True,
    }
    encoded = tokenizer([prompt], return_tensors="pt")
    generated = text_model.generate(**encoded, **generation_options)
    decoded = tokenizer.batch_decode(generated, skip_special_tokens=True)
    return decoded[0]
def generate_image(pipe, prompt):
    """Run the pipeline on *prompt* and return the first generated image.

    Args:
        pipe: Callable pipeline whose result exposes an ``images`` sequence.
        prompt: Text prompt handed to the pipeline.

    Returns:
        ``images[0]`` from the full (second) pipeline call.
    """
    # One-step "warmup" pass whose output is discarded. The explanation the
    # original comment pointed to is not in this file; presumably it works
    # around a first-inference issue on some backends.
    pipe(prompt, num_inference_steps=1)

    result = pipe(prompt)
    return result.images[0]
def process_prompt(prompt):
    """Translate *prompt* and generate an image from the translation.

    Loads the translation models, translates the prompt, builds the
    diffusion pipeline, and runs it. Both model sets are loaded anew on
    every call.

    Args:
        prompt: Prompt text as entered by the user.

    Returns:
        The image returned by ``generate_image``.
    """
    pt_tokenizer, translator = load_translation_models(TRANSLATION_MODEL_ID)
    translated_prompt = translate(prompt, pt_tokenizer, translator)

    diffusion_pipe = pipeline_generate(DIFFUSION_MODEL_ID)
    return generate_image(diffusion_pipe, translated_prompt)
# Page layout: title, prompt box, a "process" button and a "restart" button.
st.write("# Crie imagens com Stable Diffusion")
prompt_input = st.text_input("Escreva uma descrição da imagem")

# The "process" button is drawn inside a placeholder so it can be replaced
# by a disabled copy while an image is being generated.
placeholder = st.empty()
btn = placeholder.button('Processar imagem', disabled=False, key=1)
reload = st.button('Reiniciar', disabled=False)

if btn:
    if not prompt_input.strip():
        # Nothing to translate or draw: skip the expensive model loads and
        # generation, and ask for a description instead.
        st.warning("Escreva uma descrição da imagem")
    else:
        # Swap in a disabled button for the duration of the generation
        placeholder.button('Processar imagem', disabled=True, key=2)
        image = process_prompt(prompt_input)
        st.image(image)
        # Restore an enabled button, then clear the placeholder
        placeholder.button('Processar imagem', disabled=False, key=3)
        placeholder.empty()

if reload:
    # Re-run the script from the top, which redraws the initial page
    st.experimental_rerun()