Msaqibsharif commited on
Commit
c7a5bac
·
verified ·
1 Parent(s): d57c17f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +140 -1
app.py CHANGED
@@ -1 +1,140 @@
1
- print("hello")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ from PIL import Image
4
+ import gradio as gr
5
+ from transformers import DetrImageProcessor, DetrForObjectDetection
6
+ from diffusers import StableDiffusionPipeline
7
+ from huggingface_hub import login
8
+ from dotenv import load_dotenv
9
+
10
+ # Load environment variables from .env file
11
+ load_dotenv()
12
+
13
+ # Retrieve Hugging Face token from environment variable
14
+ HF_TOKEN = os.getenv('HF_TOKEN')
15
+
16
+ if HF_TOKEN is None:
17
+ raise ValueError("Hugging Face token not found in environment variables.")
18
+
19
+ # Login to Hugging Face using the token
20
+ login(token=HF_TOKEN)
21
+
22
+ # Load DETR model for object detection
23
+ def load_detr_model():
24
+ try:
25
+ model = DetrForObjectDetection.from_pretrained('facebook/detr-resnet-50')
26
+ processor = DetrImageProcessor.from_pretrained('facebook/detr-resnet-50')
27
+ return model, processor, None
28
+ except Exception as e:
29
+ return None, None, f"Error loading DETR model: {str(e)}"
30
+
31
+ detr_model, detr_processor, detr_error = load_detr_model()
32
+
33
+ def detect_objects(image):
34
+ if image is None:
35
+ return None, "Invalid image: image is None."
36
+
37
+ if detr_model is not None and detr_processor is not None:
38
+ try:
39
+ inputs = detr_processor(images=image, return_tensors="pt")
40
+ outputs = detr_model(**inputs)
41
+ target_sizes = torch.tensor([image.size[::-1]])
42
+ results = detr_processor.post_process_object_detection(outputs, target_sizes=target_sizes, threshold=0.9)[0]
43
+
44
+ detected_objects = [
45
+ {"label": detr_model.config.id2label[label.item()],
46
+ "box": box.tolist()}
47
+ for label, box in zip(results['labels'], results['boxes'])
48
+ ]
49
+ return detected_objects, None
50
+ except Exception as e:
51
+ return None, f"Error in detect_objects: {str(e)}"
52
+ else:
53
+ return None, "DETR models not loaded. Skipping object detection."
54
+
55
+ # Load Stable Diffusion model for image generation
56
+ def load_stable_diffusion_model():
57
+ try:
58
+ device = "cuda" if torch.cuda.is_available() else "cpu"
59
+ pipeline = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4").to(device)
60
+ return pipeline, None
61
+ except Exception as e:
62
+ return None, f"Error loading Stable Diffusion model: {str(e)}"
63
+
64
+ sd_pipeline, sd_error = load_stable_diffusion_model()
65
+
66
+ def adjust_dimensions(width, height):
67
+ # Adjust width and height to be divisible by 8
68
+ adjusted_width = (width // 8) * 8
69
+ adjusted_height = (height // 8) * 8
70
+ return adjusted_width, adjusted_height
71
+
72
+ def generate_image(prompt, width, height):
73
+ if sd_pipeline is not None:
74
+ try:
75
+ adjusted_width, adjusted_height = adjust_dimensions(width, height)
76
+ image = sd_pipeline(prompt, width=adjusted_width, height=adjusted_height).images[0]
77
+ # Resize back to original dimensions if needed
78
+ image = image.resize((width, height), Image.LANCZOS)
79
+ return image, None
80
+ except Exception as e:
81
+ return None, f"Error in generate_image: {str(e)}"
82
+ else:
83
+ return None, "Stable Diffusion model not loaded. Skipping image generation."
84
+
85
+ def process_image(image):
86
+ if image is None:
87
+ return None, "Invalid image: image is None."
88
+
89
+ try:
90
+ # Detect objects in the provided image
91
+ detected_objects, detect_error = detect_objects(image)
92
+ if detect_error:
93
+ return None, detect_error
94
+
95
+ # Create a prompt based on detected objects
96
+ prompt = "modern redesign of an interior room with "
97
+ if detected_objects:
98
+ prompt += ", ".join([obj['label'] for obj in detected_objects])
99
+ else:
100
+ prompt += "empty room"
101
+
102
+ # Generate a redesigned image based on the prompt
103
+ width, height = image.size
104
+ generated_image, gen_image_error = generate_image(prompt, width, height)
105
+ if gen_image_error:
106
+ return None, gen_image_error
107
+
108
+ return generated_image, None
109
+ except Exception as e:
110
+ return None, f"Error in process_image: {str(e)}"
111
+
112
+ # Custom CSS for styling
113
+ custom_css = """
114
+ body {
115
+ background-color: black;
116
+ }
117
+
118
+ h1 {
119
+ background: linear-gradient(to right, blue, purple);
120
+ -webkit-background-clip: text;
121
+ color: transparent;
122
+ font-size: 3em;
123
+ text-align: center;
124
+ margin-bottom: 20px;
125
+ }
126
+ """
127
+
128
+ # Creating the Gradio interface with custom styling
129
+ iface = gr.Interface(
130
+ fn=process_image,
131
+ inputs=[gr.Image(type="pil", label="Upload Room Image")],
132
+ outputs=[gr.Image(type="pil", label="Redesigned Image"), gr.Textbox(label="Error Message")],
133
+ title="Interior Redesign",
134
+ css=custom_css
135
+ )
136
+
137
+ try:
138
+ iface.launch()
139
+ except Exception as e:
140
+ print(f"Error occurred while launching the interface: {str(e)}")