# test_sketch / app.py
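"""Gradio demo: generates an aerial/satellite-style image from a color-coded
sketch using a Stable Diffusion v1.5 + ControlNet pipeline. The text prompt is
cleaned and condensed into keyphrases with KeyBERT before being passed to the
pipeline."""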
import gradio as gr
import torch
from diffusers import StableDiffusionControlNetPipeline, ControlNetModel, UniPCMultistepScheduler
from torchvision import transforms
import numpy as np
from keybert import KeyBERT
import gensim
from gensim.parsing.preprocessing import STOPWORDS as stop_words
import re
def sketch_transform(size):
    # Resize the sketch tensor to the resolution expected by the ControlNet.
    return transforms.Compose([
        transforms.Resize(size)
    ])

def keep_english(text):
    # Keep only ASCII letters, digits, and whitespace.
    return re.sub(r'[^a-zA-Z0-9\s]', '', text.strip())

def remove_stopwords(text):
    # Tokenize with gensim and drop common English stopwords.
    return " ".join([word for word in gensim.utils.simple_preprocess(text, deacc=True, min_len=0) if word not in stop_words])

def remove_unwanted_words(text, unwanted_words):
    # Drop words from a custom blocklist (city names, generic caption words).
    text_list = text.lower().split()
    resultwords = [word for word in text_list if word not in unwanted_words]
    return ' '.join(resultwords)
def generate_image(image, text_prompt, text_city):
    # Words stripped from the prompt before keyphrase extraction.
    unwanted_words = ['seattle', 'chicago', 'sanfrancisco', 'newyork', 'image', 'shows', 'generally', 'sky', 'view']
    n = 5
    diversity = 0.3
    ngram = (3, 3)
    kw_model = KeyBERT()

    # Clean the user prompt and condense it into keyphrases.
    prompt = text_prompt
    prompt = keep_english(prompt)
    prompt = remove_unwanted_words(prompt, unwanted_words)
    keywords = kw_model.extract_keywords(prompt, top_n=n, use_mmr=True, diversity=diversity, keyphrase_ngram_range=ngram)
    keywords_substring = ", ".join([key[0] for key in keywords])
    prompt = f"Realistic {text_city} aerial satellite top view image with high quality details with buildings and roads in {text_city} that probably has the following objects and characteristics: {keywords_substring}"
    print(prompt)

    if image is not None:
        # gr.ImageEditor returns a dict; the first layer holds the user's drawing.
        sketch = image['layers'][0]
        sketch = transforms.ToTensor()(sketch).unsqueeze(0)
        sketch = sketch_transform(size=(256, 256))(sketch)
        with torch.no_grad():  # Disable gradient calculation for inference
            model_output = pipe(prompt, num_inference_steps=20, generator=torch.manual_seed(0), image=sketch)
        generated_image = model_output.images[0]
        return generated_image
# Load the locally fine-tuned ControlNet and attach it to Stable Diffusion v1.5.
controlnet_model_name_or_path = "./controlnet"
pretrained_model_name_or_path = "runwayml/stable-diffusion-v1-5"
controlnet = ControlNetModel.from_pretrained(controlnet_model_name_or_path, torch_dtype=torch.float16, conditioning_channels=3)
pipe = StableDiffusionControlNetPipeline.from_pretrained(pretrained_model_name_or_path, controlnet=controlnet, torch_dtype=torch.float16, safety_checker=None)

device = 'cuda' if torch.cuda.is_available() else 'cpu'

pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
pipe.enable_model_cpu_offload(device=device)
iface = gr.Interface(
    fn=generate_image,
    inputs=[
        gr.ImageEditor(sources=(),
                       image_mode='RGB',
                       type='pil',
                       brush=gr.Brush(colors=["#ffb266",  # building
                                              "#4059ff",  # parking
                                              "#66ff66",  # grass
                                              #"#009900", # forest
                                              "#cce5ff",  # water
                                              #"#c0c0c0", # path
                                              "#606060"   # road
                                              ], color_mode="fixed")
                       ),
        gr.Textbox(placeholder='residential area with a lot of trees', label='Prompt'),
        gr.Textbox(placeholder='Seattle, SanFrancisco, NewYork, Chicago', label='City')
    ],
    outputs="image"
)
iface.launch()