Spaces:
Sleeping
Sleeping
import gradio as gr | |
import torch | |
from diffusers import StableDiffusionControlNetPipeline, ControlNetModel, UniPCMultistepScheduler | |
from torchvision import transforms | |
import numpy as np | |
from keybert import KeyBERT | |
import gensim | |
from gensim.parsing.preprocessing import STOPWORDS as stop_words | |
import re | |
def sketch_transform(size):
    """Build the preprocessing pipeline applied to the user's sketch.

    Args:
        size: Target size forwarded unchanged to ``transforms.Resize``.

    Returns:
        A ``transforms.Compose`` containing a single resize step.
    """
    resize_step = transforms.Resize(size)
    return transforms.Compose([resize_step])
# Matches every character that is not an ASCII letter, a digit or whitespace.
_NON_ENGLISH_CHARS = re.compile(r'[^a-zA-Z0-9\s]')


def keep_english(text):
    """Drop everything except ASCII letters, digits and whitespace.

    Args:
        text: Free-form input string.

    Returns:
        ``text`` with all other characters removed and with leading and
        trailing whitespace trimmed.
    """
    # Trim *after* filtering: removing leading/trailing punctuation can
    # expose whitespace that a strip performed beforehand would miss.
    return _NON_ENGLISH_CHARS.sub('', text).strip()
def remove_stopwords(text):
    """Tokenise ``text`` with gensim and drop gensim's stop words.

    Args:
        text: Input string.

    Returns:
        The surviving tokens joined by single spaces.
    """
    # deacc=True and min_len=0 are passed straight to simple_preprocess.
    tokens = gensim.utils.simple_preprocess(text, deacc=True, min_len=0)
    kept_tokens = (token for token in tokens if token not in stop_words)
    return " ".join(kept_tokens)
def remove_unwanted_words(text, unwanted_words):
    """Lower-case ``text`` and remove every word listed in ``unwanted_words``.

    Args:
        text: Input string; it is split on whitespace after lower-casing.
        unwanted_words: Iterable of words to drop. Matching is
            case-insensitive.

    Returns:
        The remaining lower-cased words joined by single spaces.
    """
    # The text is lower-cased, so the blocklist must be too; otherwise a
    # mixed-case entry could never match. A set gives O(1) membership tests.
    blocked = {word.lower() for word in unwanted_words}
    return ' '.join(word for word in text.lower().split() if word not in blocked)
# Lazily created KeyBERT instance shared by all requests.
_keyword_model = None


def _get_keyword_model():
    """Return the shared KeyBERT instance, creating it on first use."""
    global _keyword_model
    if _keyword_model is None:
        _keyword_model = KeyBERT()
    return _keyword_model


def generate_image(image, text_prompt, text_city):
    """Generate an image from a colour-coded sketch and a text description.

    Args:
        image: Value produced by the ``gr.ImageEditor`` input; its first
            entry under ``'layers'`` is used as the conditioning sketch.
        text_prompt: Free-text description; key phrases are extracted from it.
        text_city: City name interpolated into the final prompt.

    Returns:
        The first image returned by the diffusion pipeline.

    Raises:
        gr.Error: If no sketch was provided or the editor has no layers.
    """
    # Fail early with a user-visible message instead of hitting an unbound
    # ``sketch`` (image is None) or an IndexError (no drawn layer).
    if image is None:
        raise gr.Error("Please draw a sketch before generating an image.")
    layers = image['layers']
    if not layers:
        raise gr.Error("Please draw a sketch before generating an image.")

    unwanted_words = ['seattle', 'chicago', 'sanfrancisco', 'newyork', 'image', 'shows', 'generally', 'sky', 'view']
    n = 5
    diversity = 0.3
    ngram = (3, 3)

    # Clean the description, then keep only its top key phrases.
    prompt = keep_english(text_prompt or "")
    prompt = remove_unwanted_words(prompt, unwanted_words)
    keywords = _get_keyword_model().extract_keywords(
        prompt,
        top_n=n,
        use_mmr=True,
        diversity=diversity,
        keyphrase_ngram_range=ngram,
    )
    keywords_substring = ", ".join([key[0] for key in keywords])

    # NOTE(review): "characterstics" is misspelled, but the template is left
    # byte-for-byte unchanged because the ControlNet weights may have been
    # trained against this exact wording; confirm before correcting it.
    prompt = f"Realistic {text_city} aerial satellite top view image with high quality details with buildings and roads in {text_city} that probably has the following objects and characterstics: {keywords_substring}"
    print(prompt)

    # Convert the sketch to a tensor, add a leading batch dimension and
    # resize it to 256x256 before handing it to the pipeline.
    sketch = transforms.ToTensor()(layers[0]).unsqueeze(0)
    sketch = sketch_transform(size=(256, 256))(sketch)

    with torch.no_grad():  # Disable gradient calculation for inference
        # Fixed seed (0) for the sampling generator.
        model_output = pipe(prompt, num_inference_steps=20, generator=torch.manual_seed(0), image=sketch)
    generated_image = model_output.images[0]
    return generated_image
# Model setup: load the ControlNet and the Stable Diffusion pipeline once at
# import time so every request reuses the same weights.
controlnet_model_name_or_path = "./controlnet"
pretrained_model_name_or_path = "runwayml/stable-diffusion-v1-5"

device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Half precision is only used on GPU: diffusers warns that pipelines loaded
# in float16 cannot run on a CPU device, so fall back to float32 there.
torch_dtype = torch.float16 if device == 'cuda' else torch.float32

controlnet = ControlNetModel.from_pretrained(
    controlnet_model_name_or_path,
    torch_dtype=torch_dtype,
    conditioning_channels=3,
)
pipe = StableDiffusionControlNetPipeline.from_pretrained(
    pretrained_model_name_or_path,
    controlnet=controlnet,
    torch_dtype=torch_dtype,
    safety_checker=None,
)
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)

# Offloading sub-models between CPU and an accelerator only makes sense when
# an accelerator exists; on a CPU-only host the pipeline simply stays on CPU.
if device == 'cuda':
    pipe.enable_model_cpu_offload(device=device)
# Fixed brush palette for the sketch editor; the trailing labels name what
# each colour stands for. Two colours are intentionally left disabled.
brush_palette = [
    "#ffb266",  # building
    "#4059ff",  # parking
    "#66ff66",  # grass
    # "#009900",  # forest (disabled)
    "#cce5ff",  # water
    # "#c0c0c0",  # path (disabled)
    "#606060",  # road
]

# Drawing canvas: no upload sources, RGB PIL images, fixed-colour brush.
sketch_input = gr.ImageEditor(
    sources=(),
    image_mode='RGB',
    type='pil',
    brush=gr.Brush(colors=brush_palette, color_mode="fixed"),
)
prompt_input = gr.Textbox(placeholder='residential area with a lot of trees', label='Prompt')
city_input = gr.Textbox(placeholder='Seattle, SanFrancisco, NewYork, Chicago', label='City')

# Wire the three inputs to generate_image and start the app.
iface = gr.Interface(
    fn=generate_image,
    inputs=[sketch_input, prompt_input, city_input],
    outputs="image",
)
iface.launch()