OVAWARE committed on
Commit
369aa68
·
verified ·
1 Parent(s): 894b6b1

Attempt at fixing model

Browse files
Files changed (1) hide show
  1. app.py +56 -20
app.py CHANGED
@@ -7,11 +7,11 @@ from transformers import BertTokenizer, BertModel
7
  import numpy as np
8
  import os
9
  import time
 
10
 
11
  LATENT_DIM = 128
12
  HIDDEN_DIM = 256
13
 
14
-
15
  # Text encoder
16
  class TextEncoder(nn.Module):
17
  def __init__(self, hidden_size, output_size):
@@ -23,7 +23,7 @@ class TextEncoder(nn.Module):
23
  outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
24
  return self.fc(outputs.last_hidden_state[:, 0, :])
25
 
26
- # CVAE model
27
  class CVAE(nn.Module):
28
  def __init__(self, text_encoder):
29
  super(CVAE, self).__init__()
@@ -81,14 +81,20 @@ class CVAE(nn.Module):
81
  # Initialize the BERT tokenizer
82
  tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
83
 
84
- def clean_image(image, threshold=0.75):
85
  np_image = np.array(image)
86
  alpha_channel = np_image[:, :, 3]
87
  alpha_channel[alpha_channel <= int(threshold * 255)] = 0
88
  alpha_channel[alpha_channel > int(threshold * 255)] = 255
89
  return Image.fromarray(np_image)
90
 
91
- def generate_image(model, text_prompt, device, input_image=None, img_control=0.5):
 
 
 
 
 
 
92
  encoded_input = tokenizer(text_prompt, padding=True, truncation=True, return_tensors="pt")
93
  input_ids = encoded_input['input_ids'].to(device)
94
  attention_mask = encoded_input['attention_mask'].to(device)
@@ -110,31 +116,52 @@ def generate_image(model, text_prompt, device, input_image=None, img_control=0.5
110
 
111
  return generated_image
112
 
113
- def load_model(model_path, device):
114
- text_encoder = TextEncoder(hidden_size=HIDDEN_DIM, output_size=HIDDEN_DIM)
115
- model = CVAE(text_encoder).to(device)
116
- model.load_state_dict(torch.load(model_path, map_location=device))
117
- model.eval()
118
- return model
119
-
120
- def generate_image_gradio(prompt, model_path, clean_image_flag, size, input_image=None, img_control=0.5):
 
 
 
 
 
 
 
 
 
 
 
121
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
122
- model = load_model(model_path, device)
 
 
 
 
123
 
124
  start_time = time.time()
125
- generated_image = generate_image(model, prompt, device, input_image, img_control)
 
 
 
 
126
  end_time = time.time()
127
  generation_time = end_time - start_time
128
 
129
  if clean_image_flag:
130
  generated_image = clean_image(generated_image)
131
 
132
- generated_image = generated_image.resize((size, size), resample=Image.NEAREST)
 
 
 
133
 
134
  return generated_image, f"Generation time: {generation_time:.4f} seconds"
135
 
136
- # Gradio interface
137
- def gradio_interface():
138
  with gr.Blocks() as demo:
139
  gr.Markdown("# Image Generator from Text Prompt")
140
 
@@ -152,14 +179,23 @@ def gradio_interface():
152
  output_image = gr.Image(label="Generated Image")
153
  generation_time = gr.Textbox(label="Generation Time")
154
 
 
155
  generate_button.click(
156
- generate_image_gradio,
157
  inputs=[prompt, model_path, clean_image_flag, size, input_image, img_control],
158
- outputs=[output_image, generation_time]
 
159
  )
160
 
161
  return demo
162
 
163
  if __name__ == "__main__":
164
  demo = gradio_interface()
165
- demo.launch()
 
 
 
 
 
 
 
 
7
  import numpy as np
8
  import os
9
  import time
10
+ from typing import Optional, Union
11
 
12
  LATENT_DIM = 128
13
  HIDDEN_DIM = 256
14
 
 
15
  # Text encoder
16
  class TextEncoder(nn.Module):
17
  def __init__(self, hidden_size, output_size):
 
23
  outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
24
  return self.fc(outputs.last_hidden_state[:, 0, :])
25
 
26
+ # CVAE model (unchanged)
27
  class CVAE(nn.Module):
28
  def __init__(self, text_encoder):
29
  super(CVAE, self).__init__()
 
81
  # Initialize the BERT tokenizer
82
  tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
83
 
84
+ def clean_image(image: Image.Image, threshold: float = 0.75) -> Image.Image:
85
  np_image = np.array(image)
86
  alpha_channel = np_image[:, :, 3]
87
  alpha_channel[alpha_channel <= int(threshold * 255)] = 0
88
  alpha_channel[alpha_channel > int(threshold * 255)] = 255
89
  return Image.fromarray(np_image)
90
 
91
+ def generate_image(
92
+ model: CVAE,
93
+ text_prompt: str,
94
+ device: torch.device,
95
+ input_image: Optional[Image.Image] = None,
96
+ img_control: float = 0.5
97
+ ) -> Image.Image:
98
  encoded_input = tokenizer(text_prompt, padding=True, truncation=True, return_tensors="pt")
99
  input_ids = encoded_input['input_ids'].to(device)
100
  attention_mask = encoded_input['attention_mask'].to(device)
 
116
 
117
  return generated_image
118
 
119
+ # Model loading with caching
120
+ _model_cache = {}
121
+ def load_model(model_path: str, device: torch.device) -> CVAE:
122
+ if model_path not in _model_cache:
123
+ text_encoder = TextEncoder(hidden_size=HIDDEN_DIM, output_size=HIDDEN_DIM)
124
+ model = CVAE(text_encoder).to(device)
125
+ model.load_state_dict(torch.load(model_path, map_location=device))
126
+ model.eval()
127
+ _model_cache[model_path] = model
128
+ return _model_cache[model_path]
129
+
130
+ def generate_image_gradio(
131
+ prompt: str,
132
+ model_path: str,
133
+ clean_image_flag: bool,
134
+ size: int,
135
+ input_image: Optional[Image.Image] = None,
136
+ img_control: float = 0.5
137
+ ) -> tuple[Image.Image, str]:
138
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
139
+
140
+ try:
141
+ model = load_model(model_path, device)
142
+ except Exception as e:
143
+ raise gr.Error(f"Failed to load model: {str(e)}")
144
 
145
  start_time = time.time()
146
+ try:
147
+ generated_image = generate_image(model, prompt, device, input_image, img_control)
148
+ except Exception as e:
149
+ raise gr.Error(f"Failed to generate image: {str(e)}")
150
+
151
  end_time = time.time()
152
  generation_time = end_time - start_time
153
 
154
  if clean_image_flag:
155
  generated_image = clean_image(generated_image)
156
 
157
+ try:
158
+ generated_image = generated_image.resize((size, size), resample=Image.NEAREST)
159
+ except Exception as e:
160
+ raise gr.Error(f"Failed to resize image: {str(e)}")
161
 
162
  return generated_image, f"Generation time: {generation_time:.4f} seconds"
163
 
164
+ def gradio_interface() -> gr.Blocks:
 
165
  with gr.Blocks() as demo:
166
  gr.Markdown("# Image Generator from Text Prompt")
167
 
 
179
  output_image = gr.Image(label="Generated Image")
180
  generation_time = gr.Textbox(label="Generation Time")
181
 
182
+ # generate_image_gradio raises gr.Error on failure, so errors are reported through Gradio
183
  generate_button.click(
184
+ fn=generate_image_gradio,
185
  inputs=[prompt, model_path, clean_image_flag, size, input_image, img_control],
186
+ outputs=[output_image, generation_time],
187
+ api_name="generate" # Explicit API endpoint name
188
  )
189
 
190
  return demo
191
 
192
  if __name__ == "__main__":
193
  demo = gradio_interface()
194
+ demo.launch(
195
+ server_name="0.0.0.0",
196
+ server_port=7860,
197
+ show_error=True,
198
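+ # Optional settings; NOTE: allowed_paths governs file serving, not CORS — verify both options against the Gradio launch() docs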
199
+ # allowed_paths=["/custom/path"],
200
+ # cors_allowed_origins=["*"]
201
+ )