Floor / app.py
sialnoman318's picture
Update app.py
20775db verified
raw
history blame contribute delete
3.75 kB
import os
import random

import gradio as gr
import matplotlib.pyplot as plt
from datasets import load_dataset
from groq import Groq
from matplotlib.patches import Rectangle
# Load the floor-plan dataset once at import time; every request reuses it.
dataset = load_dataset("zimhe/pseudo-floor-plan-12k")

# Initialize the Groq client. The API key is read from the GROQ_API_KEY
# environment variable rather than being committed to source control. Any key
# that was previously hard-coded here must be treated as leaked and revoked.
client = Groq(api_key=os.environ.get("GROQ_API_KEY"))
def inspect_dataset():
    """Check that the loaded dataset has the structure this app relies on.

    Returns:
        An error message string when the ``train`` split is missing or empty,
        or when it has no ``caption`` feature; ``None`` when both checks pass.
    """
    has_train = "train" in dataset and bool(dataset["train"])
    if not has_train:
        return "Error: Dataset does not contain a valid 'train' split."

    has_caption = "caption" in dataset["train"].features
    return None if has_caption else "Error: 'caption' field not found in the dataset."
def get_floor_plan_by_caption(caption):
    """Pick a random dataset row whose caption contains ``caption``.

    Matching is a case-insensitive substring test. Only the ``caption`` column
    is scanned, and the full row is fetched just for the chosen match, so the
    remaining columns of the other rows are not loaded on each request.

    Args:
        caption: Text to look for. ``None`` is treated as an empty string and
            surrounding whitespace is ignored. An empty query matches every
            row that has a string caption.

    Returns:
        A ``(template, error)`` tuple: ``(row, None)`` on success, or
        ``(None, message)`` when the dataset is invalid or nothing matches.
    """
    error = inspect_dataset()
    if error:
        return None, error

    needle = (caption or "").strip().lower()
    train_split = dataset["train"]

    # Collect indices only; rows with a missing/non-string caption are skipped
    # instead of raising AttributeError on ``.lower()``.
    matching_rows = [
        row_index
        for row_index, text in enumerate(train_split["caption"])
        if isinstance(text, str) and needle in text.lower()
    ]
    if not matching_rows:
        return None, "Error: No templates available for the specified caption."
    return train_split[random.choice(matching_rows)], None
def create_floor_plan_from_template(template):
    """Render a floor-plan template to a PNG file and return the file path.

    Args:
        template: Mapping with a ``plot_size`` entry (``width``/``height``), a
            ``rooms`` list and an optional ``features`` list. Each room needs
            ``x``, ``y``, ``width``, ``height`` and ``name``; each feature
            needs ``x``, ``y``, ``width``, ``height`` and ``type``.

    Returns:
        The path of the saved image (``floor_plan_template.png``), or a string
        starting with ``"Error"`` when the template lacks required data.

    Raises:
        KeyError: If an individual room or feature entry lacks a required key.
    """
    # Validate up front so that no figure is created for an unusable template.
    try:
        plot_size = template["plot_size"]
        plot_width, plot_height = plot_size["width"], plot_size["height"]
    except (KeyError, TypeError):
        return "Error: Template is missing plot size information."

    rooms = template.get("rooms")
    if rooms is None:
        return "Error: Template is missing room information."

    # Features (courtyard, parking, washrooms, ...) are optional.
    features = template.get("features") or []
    feature_colors = {"courtyard": "green", "parking": "red"}

    fig, ax = plt.subplots(figsize=(10, 6))
    try:
        ax.set_xlim(0, plot_width)
        ax.set_ylim(0, plot_height)

        # Plot boundary.
        ax.add_patch(
            Rectangle(
                (0, 0),
                plot_width,
                plot_height,
                edgecolor="black",
                fill=False,
                linewidth=2,
                label="Plot Boundary",
            )
        )

        # Rooms: outlined in blue with the room name centred inside.
        for room in rooms:
            x, y, width, height = room["x"], room["y"], room["width"], room["height"]
            ax.add_patch(
                Rectangle((x, y), width, height, edgecolor="blue", fill=False, linewidth=1)
            )
            ax.text(
                x + width / 2,
                y + height / 2,
                room["name"],
                ha="center",
                va="center",
                fontsize=8,
            )

        # Features: colour depends on type; unknown types fall back to purple.
        for feature in features:
            x, y, width, height = (
                feature["x"],
                feature["y"],
                feature["width"],
                feature["height"],
            )
            color = feature_colors.get(feature["type"], "purple")
            ax.add_patch(
                Rectangle((x, y), width, height, edgecolor=color, fill=False, linewidth=1.5)
            )
            ax.text(
                x + width / 2,
                y + height / 2,
                feature["type"].capitalize(),
                ha="center",
                va="center",
                fontsize=8,
            )

        ax.axis("off")
        fig.tight_layout()

        img_path = "floor_plan_template.png"
        fig.savefig(img_path)
    finally:
        # Always release the figure, even if drawing or saving failed.
        plt.close(fig)
    return img_path
def floor_plan_with_groq(caption):
    """Gradio handler: build a floor-plan image for the given caption.

    Looks up a template matching ``caption``, renders it to an image file and,
    when the Groq client offers an ``enhance_image`` method, passes the image
    through it.

    Args:
        caption: Caption text entered by the user.

    Returns:
        The enhanced image when enhancement is available, otherwise the path
        of the rendered floor-plan image.

    Raises:
        gr.Error: If no template matches or the template cannot be rendered.
    """
    template, error = get_floor_plan_by_caption(caption)
    if error:
        # The output component is an image, so a returned error string would
        # be handled as image data; raise so Gradio shows the message instead.
        raise gr.Error(error)

    floor_plan_image = create_floor_plan_from_template(template)
    if isinstance(floor_plan_image, str) and floor_plan_image.startswith("Error"):
        raise gr.Error(floor_plan_image)

    # NOTE(review): `enhance_image` does not appear to be part of the
    # documented Groq client API — confirm. Until then, fall back to the
    # rendered image instead of failing with AttributeError.
    enhance = getattr(client, "enhance_image", None)
    if not callable(enhance):
        return floor_plan_image
    return enhance(floor_plan_image)
# Gradio UI: a single caption textbox in, a single image out.
interface = gr.Interface(
    floor_plan_with_groq,
    [gr.Textbox(label="Enter Caption for Floor Plan")],
    "image",
    title="Enhanced Floor Plan Generator with Groq AI",
    description="Generate diverse and realistic floor plans using captions from a pre-trained dataset and Groq's AI capabilities.",
)

# Start the web server only when run as a script, not when imported.
if __name__ == "__main__":
    interface.launch()