import time
import os
import PIL
import gradio as gr
import torch
import transformers
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers import pipeline
from diffusers import StableDiffusionPipeline
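
# Hugging Face access token read from the environment; passed as use_auth_token below.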
READ_TOKEN = os.environ.get('HF_ACCESS_TOKEN', None)
model_id = "runwayml/stable-diffusion-v1-5"
# model_id = "CompVis/stable-diffusion-v1-4"
has_cuda = torch.cuda.is_available()
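
# Load Stable Diffusion in fp16 on GPU when available, otherwise in full precision on CPU.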
if has_cuda:
    pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16, revision="fp16", use_auth_token=READ_TOKEN)
    device = "cuda"
else:
    pipe = StableDiffusionPipeline.from_pretrained(model_id, use_auth_token=READ_TOKEN)
    device = "cpu"
pipe.to(device)
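
# Replace the built-in safety checker with a no-op so generated images are returned unfiltered.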
def safety_checker(images, clip_input):
    return images, False
pipe.safety_checker = safety_checker
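
# distilgpt2 fine-tuned on the Lord of the Rings trilogy, used for story generation.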
SAVED_CHECKPOINT = 'mikegarts/distilgpt2-lotr'
model = AutoModelForCausalLM.from_pretrained(SAVED_CHECKPOINT)
tokenizer = AutoTokenizer.from_pretrained(SAVED_CHECKPOINT)
summarizer = pipeline("summarization")
#######################################################
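
# Trim generated text at the last full stop so it does not end mid-sentence.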
def break_until_dot(txt):
    return txt.rsplit('.', 1)[0] + '.'
def generate(prompt):
    input_context = prompt
    input_ids = tokenizer.encode(input_context, return_tensors="pt").to(model.device)
    outputs = model.generate(
        input_ids=input_ids,
        max_length=120,
        min_length=50,
        temperature=0.7,
        num_return_sequences=3,
        do_sample=True
    )
    decoded = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return break_until_dot(decoded)
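
# Generate a story continuation plus a short summary; the summary later serves as the image prompt.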
def generate_story(prompt):
    story = generate(prompt=prompt)
    summary = summarizer(story, min_length=5, max_length=15)[0]['summary_text']
    summary = break_until_dot(summary)
    return story, summary, gr.update(visible=True)
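
# Polled by the UI every few seconds: refreshes the gallery of intermediate images
# and animates the status message while an image is being generated.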
def on_change_event(app_state):
    print(f'on_change_event {app_state}')
    if app_state and app_state['running'] and app_state['img']:
        img = app_state['img']
        step = app_state['step']
        print(f'Updating the image: {app_state}')
        app_state['dots'] += 1
        app_state['dots'] = app_state['dots'] % 10
        message = app_state['status_msg'] + ' *' * app_state['dots']
        print(f'message={message}')
        return gr.update(value=app_state['img_list'], label='intermediate steps'), gr.update(value=message)
    else:
        return gr.update(label='images list'), gr.update(value='')
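
# Build the Gradio UI.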
with gr.Blocks() as demo:
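    # Run Stable Diffusion; the per-step callback decodes the intermediate latents with the
    # VAE so the UI can display progress while the image is being generated.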
    def generate_image(prompt, inference_steps, app_state):
        app_state['running'] = True
        app_state['img_list'] = []
        app_state['status_msg'] = 'Starting'

        def callback(step, ts, latents):
            app_state['status_msg'] = f'Reconstructing an image from the latent state on step {step}'
            latents = 1 / 0.18215 * latents
            res = pipe.vae.decode(latents).sample
            res = (res / 2 + 0.5).clamp(0, 1)
            res = res.cpu().permute(0, 2, 3, 1).detach().numpy()
            res = pipe.numpy_to_pil(res)[0]
            app_state['img'] = res
            app_state['step'] = step
            app_state['img_list'].append(res)
            app_state['status_msg'] = f'Generating step ({step + 1})'

        prompt = prompt + ' masterpiece charcoal pencil art lord of the rings illustration'
        img = pipe(prompt, height=512, width=512, num_inference_steps=inference_steps, callback=callback, callback_steps=1)

        app_state['running'] = False
        app_state['img'] = None
        app_state['status_msg'] = ''
        app_state['dots'] = 0
        return gr.update(value=img.images[0], label='Generated image')
    app_state = gr.State({'img': None,
                          'step': 0,
                          'running': False,
                          'status_msg': '',
                          'img_list': [],
                          'dots': 0
                          })
    title = gr.Markdown('## Lord of the Rings app')
    description = gr.Markdown('#### A Lord of the Rings inspired app that combines text and image generation.'
                              ' The language modeling is done by fine-tuning distilgpt2 on the LOTR trilogy.'
                              f' The text2img model is {model_id}. The summarization is done using distilbart.')
    prompt = gr.Textbox(label="Your prompt", value="Frodo took the sword and")
    story = gr.Textbox(label="Your story")
    summary = gr.Textbox(label="Summary")

    bt_make_text = gr.Button("Generate text")
    bt_make_image = gr.Button("Generate an image (takes about 10-15 minutes on CPU).", visible=False)

    img_description = gr.Markdown('Image generation takes some time'
                                  ' but here you can see what is generated from the latent state of the diffuser every few steps.'
                                  ' Usually there is a significant improvement around step 12 that yields a much better image.')

    status_msg = gr.Markdown()
    gallery = gr.Gallery()
    image = gr.Image(label='Illustration for your story', show_label=True)

    gallery.style(grid=[4])
    inference_steps = gr.Slider(5, 30,
                                value=20,
                                step=1,
                                visible=True,
                                label="Num inference steps (more steps yield a better image but take more time)")
    bt_make_text.click(fn=generate_story, inputs=prompt, outputs=[story, summary, bt_make_image])
    bt_make_image.click(fn=generate_image, inputs=[summary, inference_steps, app_state], outputs=image)

    eventslider = gr.Slider(visible=False)
    dep = demo.load(on_change_event, app_state, [gallery, status_msg], every=5)
    eventslider.change(fn=on_change_event, inputs=[app_state], outputs=[gallery, status_msg], every=5, cancels=[dep])
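
# Launch the app; when no access token is configured, launch with a public share link and debug output.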
if READ_TOKEN:
    demo.queue().launch()
else:
    demo.queue().launch(share=True, debug=True)