File size: 3,618 Bytes
10e7317
5fb0367
a8061e6
134e4fc
d9784e7
9e47866
 
 
 
10e7317
134e4fc
 
e80b364
134e4fc
42cf413
 
 
 
 
 
 
 
 
 
 
134e4fc
42cf413
 
 
 
 
 
 
 
 
 
 
 
 
303e6b7
10e7317
caadc58
e80b364
134e4fc
caadc58
5fb0367
 
 
 
 
 
542b24d
5fb0367
4dc01bc
11c6419
 
4dc01bc
 
6c398e4
4dc01bc
 
 
53094cb
4dc01bc
 
10e7317
 
da25d13
caadc58
 
0a880bd
caadc58
 
 
 
 
 
 
 
0a880bd
42cf413
 
0a880bd
caadc58
b11db55
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
import gradio as gr
import torch
from diffusers import StableDiffusionControlNetPipeline, ControlNetModel, UniPCMultistepScheduler
from torchvision import transforms
import numpy as np
from keybert import KeyBERT
import gensim
from gensim.parsing.preprocessing import STOPWORDS as stop_words
import re

def sketch_transform(size):
    """Build the preprocessing pipeline applied to a sketch before inference.

    Args:
        size: Target size forwarded unchanged to ``transforms.Resize``.

    Returns:
        A ``transforms.Compose`` holding a single resize step.
    """
    resize_step = transforms.Resize(size)
    return transforms.Compose([resize_step])

def keep_english(text):
    """Trim *text* and drop everything except ASCII letters, digits and whitespace.

    Args:
        text: The raw string to clean.

    Returns:
        The stripped string with every other character removed. Interior
        whitespace is left exactly as it was.
    """
    disallowed = re.compile(r'[^a-zA-Z0-9\s]')
    trimmed = text.strip()
    return disallowed.sub('', trimmed)

def remove_stopwords(text):
    """Tokenize *text* with gensim and rejoin the tokens that are not stop words.

    Tokenization uses ``gensim.utils.simple_preprocess`` with ``deacc=True`` and
    ``min_len=0``; tokens found in gensim's ``STOPWORDS`` (imported here as
    ``stop_words``) are discarded.

    Args:
        text: The string to filter.

    Returns:
        The surviving tokens joined by single spaces.
    """
    tokens = gensim.utils.simple_preprocess(text, deacc=True, min_len=0)
    kept = []
    for token in tokens:
        if token in stop_words:
            continue
        kept.append(token)
    return " ".join(kept)

def remove_unwanted_words(text, unwanted_words):
    """Lowercase *text* and drop every whitespace-separated word found in *unwanted_words*.

    Args:
        text: The string to filter.
        unwanted_words: Container tested with ``in`` for each lowercased word.

    Returns:
        The remaining lowercased words joined by single spaces.
    """
    kept = []
    for token in text.lower().split():
        if token in unwanted_words:
            continue
        kept.append(token)
    return ' '.join(kept)
    
# Shared KeyBERT instance, created on first use so that each request does not
# construct a new keyword model.
_keyword_model = None


def _get_keyword_model():
    """Return the shared KeyBERT instance, creating it on the first call."""
    global _keyword_model
    if _keyword_model is None:
        _keyword_model = KeyBERT()
    return _keyword_model


def generate_image(image, text_prompt, text_city):
    """Generate an image from a drawn sketch, a free-text prompt and a city name.

    The prompt is cleaned, reduced to key phrases with KeyBERT, and embedded in
    a fixed prompt template together with the city name. The first sketch layer
    is converted to a tensor, resized to 256x256 and passed to the module-level
    ``pipe`` as the conditioning image.

    Args:
        image: Value produced by the image editor input; a mapping whose
            ``'layers'`` entry is a list of drawn layers, or ``None``.
        text_prompt: Free-text description typed by the user.
        text_city: City name inserted into the prompt template.

    Returns:
        The first image of the pipeline output, or ``None`` when there is no
        sketch to condition on (``image`` is ``None`` or has no layers).
    """
    # Without a sketch there is nothing to condition on; return early instead
    # of failing on an unbound result or an empty layer list.
    layers = image['layers'] if image is not None else None
    if not layers:
        return None

    unwanted_words = ['seattle', 'chicago', 'sanfrancisco', 'newyork', 'image', 'shows', 'generally', 'sky', 'view']
    top_n = 5
    diversity = 0.3
    ngram_range = (3, 3)

    prompt = keep_english(text_prompt)
    prompt = remove_unwanted_words(prompt, unwanted_words)
    keywords = _get_keyword_model().extract_keywords(
        prompt,
        top_n=top_n,
        use_mmr=True,
        diversity=diversity,
        keyphrase_ngram_range=ngram_range,
    )
    # Each keyword entry is indexed with [0] to take the phrase itself.
    keywords_substring = ", ".join([key[0] for key in keywords])
    # NOTE(review): the template spelling ("characterstics") is left untouched
    # on purpose; presumably it matches the captions the ControlNet was
    # trained with -- confirm before correcting it.
    prompt = f"Realistic {text_city} aerial satellite top view image with high quality details with buildings and roads in {text_city} that probably has the following objects and characterstics: {keywords_substring}"
    print(prompt)

    # Add a leading batch dimension, then resize to the 256x256 conditioning size.
    sketch = transforms.ToTensor()(layers[0]).unsqueeze(0)
    sketch = sketch_transform(size=(256, 256))(sketch)

    with torch.no_grad():  # Disable gradient calculation for inference
        # Fixed seed so the same inputs reproduce the same output.
        model_output = pipe(prompt, num_inference_steps=20, generator=torch.manual_seed(0), image=sketch)

    return model_output.images[0]

# Local ControlNet checkpoint directory and the base Stable Diffusion model id.
controlnet_model_name_or_path = "./controlnet"
pretrained_model_name_or_path = "runwayml/stable-diffusion-v1-5"

# Choose the device before loading so the weight precision can follow it:
# half precision is only used on CUDA, while the CPU fallback loads float32
# (diffusers warns against running float16 pipelines on CPU).
device = 'cuda' if torch.cuda.is_available() else 'cpu'
weights_dtype = torch.float16 if device == 'cuda' else torch.float32

controlnet = ControlNetModel.from_pretrained(controlnet_model_name_or_path, torch_dtype=weights_dtype, conditioning_channels=3)
pipe = StableDiffusionControlNetPipeline.from_pretrained(pretrained_model_name_or_path, controlnet=controlnet, torch_dtype=weights_dtype, safety_checker=None)

pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)

# Model CPU offload only applies when there is a GPU to move sub-models onto;
# on CPU the pipeline is left where from_pretrained loaded it.
if device == 'cuda':
    pipe.enable_model_cpu_offload(device=device)

# Fixed brush palette for the sketch editor; the trailing comment on each
# colour is the label it is drawn for.
brush_palette = [
    "#ffb266",  # building
    "#4059ff",  # parking
    "#66ff66",  # grass
    # "#009900",  # forest (disabled)
    "#cce5ff",  # water
    # "#c0c0c0",  # path (disabled)
    "#606060",  # road
]

# Drawing canvas with no image sources, returning PIL images in RGB mode.
sketch_editor = gr.ImageEditor(
    sources=(),
    image_mode='RGB',
    type='pil',
    brush=gr.Brush(colors=brush_palette, color_mode="fixed"),
)
prompt_box = gr.Textbox(placeholder='residential area with a lot of trees', label='Prompt')
city_box = gr.Textbox(placeholder='Seattle, SanFrancisco, NewYork, Chicago', label='City')

# Inputs are passed to generate_image positionally: sketch, prompt, city.
iface = gr.Interface(
    fn=generate_image,
    inputs=[sketch_editor, prompt_box, city_box],
    outputs="image",
)

iface.launch()