Spaces:
Sleeping
Sleeping
from pathlib import Path | |
from num2words import num2words | |
import numpy as np | |
import os | |
import random | |
import re | |
import torch | |
import json | |
from shapely.geometry.polygon import Polygon | |
from shapely.affinity import scale | |
from PIL import Image, ImageDraw, ImageOps, ImageFilter, ImageFont, ImageColor | |
#2.7.5 | |
#os.system('pip3 install gradio==2.7.5') | |
#os.system('pip3 install gradio==3.14.0') | |
import gradio as gr | |
from transformers import AutoTokenizer, AutoConfig, AutoModelForCausalLM | |
from transformers.models.gptj.modeling_gptj import apply_rotary_pos_emb as apply_rotary_pos_emb_pt | |
# Fine-tuned GPT-J checkpoint used both for tokenization and generation.
_CHECKPOINT = "architext/gptj-162M"
tokenizer = AutoTokenizer.from_pretrained(_CHECKPOINT)
finetuned = AutoModelForCausalLM.from_pretrained(_CHECKPOINT)
# Run on the first CUDA device when available, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"
print(device)
finetuned = finetuned.to(device)
# Utility functions | |
def containsNumber(value):
    """Return True if at least one character of *value* satisfies str.isdigit."""
    return any(ch.isdigit() for ch in value)
# Sampling settings per UI creativity level: (top_p, top_k).
_CREATIVITY_SETTINGS = {
    "Low": (0.95, 10),
    "Medium": (0.9, 50),
    "High": (0.85, 100),
}


def creativity(intensity):
    """Map a creativity level to nucleus/top-k sampling parameters.

    Args:
        intensity: one of "Low", "Medium" or "High".

    Returns:
        (top_p, top_k) tuple to pass to ``generate``.

    Raises:
        ValueError: if *intensity* is not a known level (previously this
            surfaced as an UnboundLocalError).
    """
    try:
        return _CREATIVITY_SETTINGS[intensity]
    except KeyError:
        raise ValueError(
            f"Unknown creativity level {intensity!r}; "
            f"expected one of: {', '.join(_CREATIVITY_SETTINGS)}"
        ) from None
# Integer label for each room type; the label is also the index into
# architext_colors. "corridor" shares label 8 with "hallway".
housegan_labels = {"living_room": 1, "kitchen": 2, "bedroom": 3, "bathroom": 4, "missing": 5, "closet": 6,
                   "balcony": 7, "hallway": 8, "dining_room": 9, "laundry_room": 10, "corridor": 8}
# RGB fill colours indexed by housegan label. Index 0 is not used by any
# label; index 11 repeats index 7 and no label maps to it.
architext_colors = [[0, 0, 0], [249, 222, 182], [195, 209, 217], [250, 120, 128], [126, 202, 234], [190, 0, 198], [255, 255, 255],
                    [6, 53, 17], [17, 33, 58], [132, 151, 246], [197, 203, 159], [6, 53, 17],]
# Captures the text inside each "(...)" group, e.g. "(10,20)(30,40)" -> ["10,20", "30,40"].
# Raw string: "\(" in a plain literal is an invalid escape sequence.
regex = re.compile(r".*?\((.*?)\)")
def _exterior_yx(shape):
    """Flatten a polygon's exterior ring to [y0, x0, y1, x1, ...] for PIL.

    The x/y axes are swapped, matching how the layout is drawn on the canvas.
    """
    xs, ys = shape.exterior.xy
    return list(np.dstack((ys, xs)).flatten())


def draw_polygons(polygons, colors, im_size=(512, 512), b_color="white", fpath=None):
    """Render room polygons onto a new RGBA canvas.

    Each polygon is first filled in black, then a copy inset by 1 unit
    (negative buffer) is filled with the room colour, leaving a thin black
    outline around every room. The finished image is flipped vertically.

    Args:
        polygons: iterable of shapely Polygons in canvas coordinates.
        colors: iterable of RGB triples, paired positionally with *polygons*
            (``zip`` ignores extra items in either).
        im_size: (width, height) of the canvas in pixels.
        b_color: background colour of the canvas (any PIL colour spec).
        fpath: if given, the final image is also saved to this path.

    Returns:
        (draw, image): the ImageDraw used for rendering (bound to the
        un-flipped canvas) and the flipped final image.
    """
    # Previously ignored: the background was hard-coded to "white".
    image = Image.new("RGBA", im_size, color=b_color)
    draw = ImageDraw.Draw(image)
    for poly, color in zip(polygons, colors):
        draw.polygon(_exterior_yx(poly), fill=(0, 0, 0))
        inner = poly.buffer(-1, resolution=32, cap_style=2, join_style=2, mitre_limit=5.0)
        if inner.is_empty:
            # Room too thin to inset: keep it as a solid black shape.
            continue
        # Insetting can split a room into several pieces; test the inset
        # geometry (not the input polygon) to decide how to draw it.
        if inner.geom_type == 'MultiPolygon':
            pieces = inner.geoms  # direct iteration was removed in Shapely 2
        elif inner.geom_type == 'Polygon':
            pieces = [inner]
        else:
            continue
        for piece in pieces:
            draw.polygon(_exterior_yx(piece), fill=tuple(color))
    # Image.FLIP_TOP_BOTTOM moved to Image.Transpose in Pillow 9.1 and the old
    # alias was removed in Pillow 10; support both.
    flip = getattr(Image, "Transpose", Image).FLIP_TOP_BOTTOM
    image = image.transpose(flip)
    if fpath:
        image.save(fpath, quality=100, subsampling=0)
    return draw, image
def _spell_out_numbers(text):
    """Replace each space-separated run of decimal digits with English words."""
    # str.isdecimal (not isdigit) guarantees int() succeeds, e.g. for '²'.
    return ' '.join(
        num2words(int(word)).lower() if word.isdecimal() else word
        for word in text.split(' ')
    )


def _generate_layout_text(model_prompt, top_p, top_k):
    """Sample one completion for *model_prompt*; return the text after '[Layout] '."""
    input_ids = tokenizer(model_prompt, return_tensors='pt').to(device)
    output = finetuned.generate(**input_ids, do_sample=True, top_p=top_p, top_k=top_k,
                                eos_token_id=50256, max_length=400)
    decoded = tokenizer.batch_decode(output, skip_special_tokens=True)
    # An IndexError here means the sample lacks the expected markers.
    return decoded[0].split('[User prompt]')[1].split('[Layout] ')[1]


def _parse_layout(layout_text):
    """Split 'name: (x,y)(x,y)..., name: ...' into [(name, ['x,y', ...]), ...].

    Entries without a ':' carry no coordinates and are dropped, so every
    room name stays paired with its own coordinates.
    """
    rooms = []
    for entry in layout_text.split(', '):
        if ':' not in entry:
            continue
        parts = entry.split(':')
        rooms.append((parts[0], re.findall(regex, parts[1].rstrip())))
    return rooms


def _scaled_numbers(points):
    """Flatten 'x,y' strings into numbers divided by 14.2, skipping non-numeric parts."""
    values = []
    for point in points:
        for num in point.split(','):
            clean = re.sub(r'^\D*|\D*$', '', num)  # trim non-digits at both ends
            if clean.isdecimal():
                values.append(int(clean) / 14.2)
    return values


def _number_duplicates(names):
    """Suffix repeated names with 1-based counters: ['a', 'b', 'a'] -> ['a1', 'b', 'a2']."""
    from collections import Counter

    totals = Counter(names)
    seen = Counter()
    numbered = []
    for name in names:
        seen[name] += 1
        numbered.append(name + str(seen[name]) if totals[name] > 1 else name)
    return numbered


def _space_color(name):
    """Return the colour of the first housegan label contained in *name*, else None."""
    for key, label in housegan_labels.items():
        if key in name:
            return architext_colors[label]
    return None


def prompt_to_layout(user_prompt, intensity, fpath=None):
    """Generate a floor-plan layout for a text prompt and render it.

    Digits in the prompt are spelled out ("2 bedrooms" -> "two bedrooms")
    before it is wrapped in the '[User prompt] ... [Layout]' template and
    sampled with the top-p/top-k values returned by ``creativity(intensity)``.

    Args:
        user_prompt: free-text description of the apartment.
        intensity: creativity level accepted by ``creativity``.
        fpath: optional path forwarded to ``draw_polygons`` to save the plan.

    Returns:
        (image, coords_json): a PIL image of the rendered plan stacked above
        the legend loaded from ``labels.png``, and a JSON string mapping each
        room name (duplicates numbered, e.g. bedroom1, bedroom2) to its flat
        list of coordinates divided by 14.2.

    Raises:
        Whatever parsing/rendering raises on malformed model output (e.g.
        IndexError, ValueError); ``retry_prompt_to_layout`` retries on these.
    """
    # The digit-to-words prompt used to be built and then always overwritten
    # by the raw prompt, so it never reached the model. The variant that also
    # prepended "Hallways are adjacent to bedrooms." was discarded the same
    # way and is intentionally not sent.
    prompt = _spell_out_numbers(user_prompt) if containsNumber(user_prompt) else user_prompt
    top_p, top_k = creativity(intensity)
    model_prompt = '[User prompt] {} [Layout]'.format(prompt)
    rooms = _parse_layout(_generate_layout_text(model_prompt, top_p, top_k))

    names = _number_duplicates([name for name, _ in rooms])
    out_dict = json.dumps(dict(zip(names, (_scaled_numbers(points) for _, points in rooms))))

    # Build geometry and colours together so zip() in draw_polygons always
    # pairs each polygon with its own room colour.
    geom, colors = [], []
    for name, points in rooms:
        vertices = [list(map(int, point.split(','))) for point in points]
        color = _space_color(name)
        if len(vertices) >= 4 and color is not None:
            geom.append(scale(Polygon(vertices), xfact=2, yfact=2, origin=(0, 0)))
            colors.append(color)

    _, im = draw_polygons(geom, colors, fpath=fpath)
    with Image.open("labels.png") as legend:
        # Match the plan's mode so both arrays have the same channel count.
        legend_arr = np.asarray(legend.convert(im.mode))
    imgs_comb = Image.fromarray(np.vstack([np.asarray(im), legend_arr]))
    return imgs_comb, out_dict
# Gradio App
# Custom stylesheet passed to gr.Interface(css=...). It imports a Typekit web
# font and restyles the page layout, submit button, text inputs, radio group
# and output image by targeting Gradio's generated class names.
# NOTE(review): selectors such as .gradio_wrapper / .gradio_bg / .panel_button
# look like Gradio 2.x markup; confirm they still match the Gradio version this
# app runs on (Interface(concurrency_limit=...) suggests Gradio 4).
custom_css="""
@import url("https://use.typekit.net/nid3pfr.css");
.gradio_wrapper .gradio_bg[is_embedded=false] {
min-height: 80%;
}
.gradio_wrapper .gradio_bg[is_embedded=false] .gradio_page {
display: flex;
width: 100vw;
min-height: 50vh;
flex-direction: column;
justify-content: center;
align-items: center;
margin: 0px;
max-width: 100vw;
background: #FFFFFF;
}
.gradio_wrapper .gradio_bg[is_embedded=false] .gradio_page .content {
padding: 0px;
margin: 0px;
}
.gradio_interface {
width: 100vw;
max-width: 1500px;
}
.gradio_interface .panel:nth-child(2) .component:nth-child(3) {
display:none
}
.gradio_wrapper .gradio_bg[theme=default] .panel_buttons {
justify-content: flex-end;
}
.gradio_wrapper .gradio_bg[theme=default] .panel_button {
flex: 0 0 0;
min-width: 150px;
}
.gradio_wrapper .gradio_bg[theme=default] .gradio_interface .panel_button.submit {
background: #11213A;
border-radius: 5px;
color: #FFFFFF;
text-transform: uppercase;
min-width: 150px;
height: 4em;
letter-spacing: 0.15em;
flex: 0 0 0;
}
.gradio_wrapper .gradio_bg[theme=default] .gradio_interface .panel_button.submit:hover {
background: #000000;
}
.input_text:focus {
border-color: #FA7880;
}
.gradio_wrapper .gradio_bg[theme=default] .gradio_interface .input_text input,
.gradio_wrapper .gradio_bg[theme=default] .gradio_interface .input_text textarea {
font: 200 45px garamond-premier-pro-display, serif;
line-height: 110%;
color: #11213A;
border-radius: 5px;
padding: 15px;
border: none;
background: #F2F4F4;
}
.input_text textarea:focus-visible {
outline: none;
}
.gradio_wrapper .gradio_bg[theme=default] .gradio_interface .input_radio .radio_item.selected {
background-color: #11213A;
}
.gradio_wrapper .gradio_bg[theme=default] .gradio_interface .input_radio .selected .radio_circle {
border-color: #4365c4;
}
.gradio_wrapper .gradio_bg[theme=default] .gradio_interface .output_image {
width: 100%;
height: 40vw;
max-height: 630px;
}
.gradio_wrapper .gradio_bg[theme=default] .gradio_interface .output_image .image_preview_holder {
background: transparent;
}
.panel:nth-child(1) {
margin-left: 50px;
margin-right: 50px;
margin-bottom: 80px;
max-width: 750px;
}
.panel {
background: transparent;
}
.gradio_wrapper .gradio_bg[theme=default] .gradio_interface .component_set {
background: transparent;
box-shadow: none;
}
.panel:nth-child(2) .gradio_wrapper .gradio_bg[theme=default] .gradio_interface .panel_header {
display: none;
}
.gradio_wrapper .gradio_bg[is_embedded=false] .gradio_page .footer {
transform: scale(0.75);
filter: grayscale(1);
}
.labels {
height: 20px;
width: auto;
}
@media (max-width: 1000px){
.panel:nth-child(1) {
margin-left: 0px;
margin-right: 0px;
}
.gradio_wrapper .gradio_bg[theme=default] .gradio_interface .output_image {
height: auto;
}
}
"""
# Inputs: a creativity level (turned into sampling parameters by creativity())
# and a free-text description of the apartment.
creative_slider = gr.Radio(
    ["Low", "Medium", "High"],
    label='Creativity',
    value="Low",
)
textbox = gr.Textbox(
    label="DESCRIBE YOUR IDEAL APARTMENT",
    lines=3,
    placeholder='An apartment with two bedrooms and one bathroom',
)
# Outputs: the rendered floor plan and the JSON coordinates string.
generated = gr.Image(type='numpy', label='Generated Layout')
layout = gr.Textbox(label='Layout Coordinates')
# Pre-filled [prompt, creativity] rows offered beneath the form.
examples = [
    ["two bedrooms and two bathrooms", "Low"],
    ["three bedrooms with a kitchen adjacent to the dining room", "Medium"],
]
def retry_prompt_to_layout(user_prompt, intensity, fpath=None, max_attempts=5):
    """Run prompt_to_layout, retrying when a sampled layout cannot be processed.

    Generation samples stochastically, so a failed parse/render is retried
    up to *max_attempts* times; each failure is printed.

    Args:
        user_prompt, intensity, fpath: forwarded to ``prompt_to_layout``.
        max_attempts: maximum number of generation attempts.

    Returns:
        The (image, coords_json) pair from the first successful attempt.

    Raises:
        gr.Error: if every attempt fails (chained to the last underlying
            exception). Previously the function silently returned None, which
            left the interface with no outputs.
    """
    last_error = None
    for attempt in range(1, max_attempts + 1):
        try:
            return prompt_to_layout(user_prompt, intensity, fpath)
        except Exception as e:  # broad on purpose: any bad sample is retried
            last_error = e
            print(f"Attempt {attempt} failed with error: {e}")
    raise gr.Error(
        f"Could not generate a layout after {max_attempts} attempts; please try again."
    ) from last_error
# Wire the widgets to the retrying generator, then serve the app with
# request queueing enabled.
iface = gr.Interface(
    fn=retry_prompt_to_layout,
    inputs=[textbox, creative_slider],
    outputs=[generated, layout],
    examples=examples,
    cache_examples=False,
    allow_flagging='never',
    theme="default",
    css=custom_css,
    concurrency_limit=20,
)
queued_app = iface.queue()
queued_app.launch()