prithivMLmods committed on
Commit
6c3e861
·
verified ·
1 Parent(s): 753dbac

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -8
app.py CHANGED
@@ -98,11 +98,11 @@ def clean_chat_history(chat_history):
98
  # ============================================
99
 
100
  # Environment variables and parameters for Stable Diffusion XL
101
- MODEL_ID_SD = os.getenv("MODEL_VAL_PATH") # Use SDXL Model repo path via MODEL_VAL_PATH env var
102
  MAX_IMAGE_SIZE = int(os.getenv("MAX_IMAGE_SIZE", "4096"))
103
  USE_TORCH_COMPILE = os.getenv("USE_TORCH_COMPILE", "0") == "1"
104
  ENABLE_CPU_OFFLOAD = os.getenv("ENABLE_CPU_OFFLOAD", "0") == "1"
105
- BATCH_SIZE = int(os.getenv("BATCH_SIZE", "1")) # For potential batched image generation
106
 
107
  # Load the SDXL pipeline
108
  sd_pipe = StableDiffusionXLPipeline.from_pretrained(
@@ -113,7 +113,11 @@ sd_pipe = StableDiffusionXLPipeline.from_pretrained(
113
  ).to(device)
114
  sd_pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(sd_pipe.scheduler.config)
115
 
116
- # Optional: compile the model for speedup
 
 
 
 
117
  if USE_TORCH_COMPILE:
118
  sd_pipe.compile()
119
 
@@ -191,16 +195,16 @@ def generate(
191
  repetition_penalty: float = 1.2,
192
  ):
193
  """
194
- Generates chatbot responses with support for multimodal input, TTS, and now image generation.
195
- If the query starts with:
196
- - "@tts1" or "@tts2", it triggers text-to-speech.
197
- - "@image", it triggers image generation using the SDXL pipeline.
198
  """
199
  text = input_dict["text"]
200
  files = input_dict.get("files", [])
201
 
202
  # ----------------------------
203
- # NEW: IMAGE GENERATION BRANCH
204
  # ----------------------------
205
  if text.strip().lower().startswith("@image"):
206
  # Remove the "@image" tag and use the rest as prompt
@@ -343,4 +347,5 @@ demo = gr.ChatInterface(
343
  )
344
 
345
  if __name__ == "__main__":
 
346
  demo.queue(max_size=20).launch(share=True)
 
98
  # ============================================
99
 
100
  # Environment variables and parameters for Stable Diffusion XL
101
+ MODEL_ID_SD = os.getenv("MODEL_VAL_PATH") # SDXL Model repository path via env variable
102
  MAX_IMAGE_SIZE = int(os.getenv("MAX_IMAGE_SIZE", "4096"))
103
  USE_TORCH_COMPILE = os.getenv("USE_TORCH_COMPILE", "0") == "1"
104
  ENABLE_CPU_OFFLOAD = os.getenv("ENABLE_CPU_OFFLOAD", "0") == "1"
105
+ BATCH_SIZE = int(os.getenv("BATCH_SIZE", "1")) # For batched image generation
106
 
107
  # Load the SDXL pipeline
108
  sd_pipe = StableDiffusionXLPipeline.from_pretrained(
 
113
  ).to(device)
114
  sd_pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(sd_pipe.scheduler.config)
115
 
116
+ # Fix for dtype mismatch in the text encoder: cast it to half precision when CUDA is available.
117
+ if torch.cuda.is_available():
118
+ sd_pipe.text_encoder = sd_pipe.text_encoder.half()
119
+
120
+ # Optional: compile the model for speedup if enabled
121
  if USE_TORCH_COMPILE:
122
  sd_pipe.compile()
123
 
 
195
  repetition_penalty: float = 1.2,
196
  ):
197
  """
198
+ Generates chatbot responses with support for multimodal input, TTS, and image generation.
199
+ Special commands:
200
+ - "@tts1" or "@tts2": triggers text-to-speech.
201
+ - "@image": triggers image generation using the SDXL pipeline.
202
  """
203
  text = input_dict["text"]
204
  files = input_dict.get("files", [])
205
 
206
  # ----------------------------
207
+ # IMAGE GENERATION BRANCH
208
  # ----------------------------
209
  if text.strip().lower().startswith("@image"):
210
  # Remove the "@image" tag and use the rest as prompt
 
347
  )
348
 
349
  if __name__ == "__main__":
350
+ # share=True in launch() creates a public link.
351
  demo.queue(max_size=20).launch(share=True)