Qilex's picture
Update Model to one that doesn't overfit
0670d1c
raw
history blame
2.09 kB
import gradio as gr
from diffusers import DiffusionPipeline
from PIL import Image
import numpy as np
# Hugging Face Hub repository id of the trained virtual-pet diffusion model.
MODEL_REPO_ID = "Qilex/VirtualPetDiffusion2"

# Loaded once when the module is imported; generate_pets() reuses this
# single instance for every request.
pipeline = DiffusionPipeline.from_pretrained(MODEL_REPO_ID)
def generate_pets(num_to_generate):
    """Sample ``num_to_generate`` pet images from the module-level pipeline.

    Args:
        num_to_generate: How many images to sample. Coerced to ``int`` because
            UI widgets such as sliders may hand back a float (e.g. ``2.0``),
            which may not be accepted as a batch size.

    Returns:
        The images produced by ``pipeline`` (for diffusers image pipelines
        this is, by default, a list of PIL images).
    """
    batch_size = int(num_to_generate)
    output = pipeline(batch_size)
    # Early diffusers releases returned a plain dict keyed by "sample";
    # later releases return an output object exposing ``.images`` and no
    # longer accept the "sample" key. Support both so the app keeps working
    # whichever diffusers version is installed.
    images = getattr(output, "images", None)
    if images is None:
        images = output["sample"]
    return images
def concatenate_imgs(imgs, columns=2):
    """Tile a list of images into a single grid image.

    A single image is returned unchanged. Otherwise the images are laid out
    left-to-right, top-to-bottom, ``columns`` per row. When there is more
    than one row, the last row is padded with white tiles so every row has
    the same width (e.g. three images become a 2x2 grid with a blank
    bottom-right cell).

    Args:
        imgs: Sequence of images (PIL images or numpy arrays) that all share
            the same height, width and channel layout.
        columns: Number of tiles per row. Defaults to 2, which reproduces the
            1x2 / 2x2 layouts previously hard-coded for up to four images.

    Returns:
        ``imgs[0]`` itself when exactly one image is given, otherwise a new
        ``PIL.Image.Image`` containing the grid.

    Raises:
        ValueError: If ``imgs`` is empty, ``columns`` is less than 1, or the
            images do not share a common shape.
    """
    count = len(imgs)
    if count == 0:
        raise ValueError("concatenate_imgs() needs at least one image")
    if columns < 1:
        raise ValueError(f"columns must be >= 1, got {columns!r}")
    if count == 1:
        return imgs[0]

    # Convert everything to arrays once, instead of mixing PIL images and
    # arrays inside np.concatenate.
    tiles = [np.asarray(img) for img in imgs]
    if count > columns:
        # Pad the final row with white tiles shaped like the real images
        # (rather than a hard-coded 128x128x3 block) so all rows line up.
        missing = (columns - count % columns) % columns
        tiles.extend(np.full_like(tiles[0], 255) for _ in range(missing))

    rows = [
        np.concatenate(tiles[start:start + columns], axis=1)
        for start in range(0, len(tiles), columns)
    ]
    return Image.fromarray(np.concatenate(rows, axis=0))
def go(num):
    """Gradio callback: generate ``num`` pets and return them as one image.

    Args:
        num: Value of the "Number of images to generate" slider. Coerced to
            ``int`` in case the widget delivers a float.

    Returns:
        The grid image built by ``concatenate_imgs`` from the generated pets.
    """
    # The previous debug ``print(type(grid))`` was dropped: it only wrote
    # noise to stdout on every request.
    imgs = generate_pets(int(num))
    return concatenate_imgs(imgs)
# Page heading passed to gr.Interface(title=...).
title = "VirtualPet Dream"
# Intro text passed to gr.Interface(description=...). It contains inline HTML
# links. The literal "\n" escapes (this is not a raw string) add extra line
# breaks, presumably to force paragraph separation when Gradio renders it.
description = """
This AI will 'dream' you up a virtual pet.
\nThis is a denoising diffusion model trained in 48 hours for a hackathon, so the images can be pretty wonky.
\nImages are 128x128px.
\nBecause we're running on CPU, it takes 10-15 minutes to generate an image. Quick inference can be run in the <a href="https://colab.research.google.com/drive/19QtPOHv6HCpexyCMGXowX4vyZlF4ZZYN?usp=sharing">colab notebook</a>.
\n <a href="https://github.com/ke7osm/VirtualPet-Dream">Github Repo</a>
"""
# Footer HTML passed to gr.Interface(article=...): a flex row of externally
# hosted sample images.
article = '''Here's a gallery of some of the better pets:
<div style="display: flex; justify-content:space-evenly">
<img src="https://alexlyman.org/external_images/sample_10.png">
<img src="https://alexlyman.org/external_images/sample_5.png" >
<img src="https://alexlyman.org/external_images/sample_4.png" >
<img src="https://alexlyman.org/external_images/sample_8.png" >
</div>
\n
'''
# Build the web UI: one integer slider in, one grid image out, then serve it.
demo = gr.Interface(
    fn=go,
    inputs=gr.Slider(
        1,
        4,
        value=2,
        step=1,
        label="Number of images to generate (more takes longer)",
    ),
    outputs=gr.Image(),
    title=title,
    description=description,
    article=article,
)
demo.launch()