Spaces:
Running
on
Zero
Update app.py
Browse files
app.py
CHANGED
@@ -98,11 +98,11 @@ def clean_chat_history(chat_history):
|
|
98 |
# ============================================
|
99 |
|
100 |
# Environment variables and parameters for Stable Diffusion XL
|
101 |
-
MODEL_ID_SD = os.getenv("MODEL_VAL_PATH") #
|
102 |
MAX_IMAGE_SIZE = int(os.getenv("MAX_IMAGE_SIZE", "4096"))
|
103 |
USE_TORCH_COMPILE = os.getenv("USE_TORCH_COMPILE", "0") == "1"
|
104 |
ENABLE_CPU_OFFLOAD = os.getenv("ENABLE_CPU_OFFLOAD", "0") == "1"
|
105 |
-
BATCH_SIZE = int(os.getenv("BATCH_SIZE", "1")) # For
|
106 |
|
107 |
# Load the SDXL pipeline
|
108 |
sd_pipe = StableDiffusionXLPipeline.from_pretrained(
|
@@ -113,7 +113,11 @@ sd_pipe = StableDiffusionXLPipeline.from_pretrained(
|
|
113 |
).to(device)
|
114 |
sd_pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(sd_pipe.scheduler.config)
|
115 |
|
116 |
-
#
|
|
|
|
|
|
|
|
|
117 |
if USE_TORCH_COMPILE:
|
118 |
sd_pipe.compile()
|
119 |
|
@@ -191,16 +195,16 @@ def generate(
|
|
191 |
repetition_penalty: float = 1.2,
|
192 |
):
|
193 |
"""
|
194 |
-
Generates chatbot responses with support for multimodal input, TTS, and
|
195 |
-
|
196 |
-
- "@tts1" or "@tts2"
|
197 |
-
- "@image"
|
198 |
"""
|
199 |
text = input_dict["text"]
|
200 |
files = input_dict.get("files", [])
|
201 |
|
202 |
# ----------------------------
|
203 |
-
#
|
204 |
# ----------------------------
|
205 |
if text.strip().lower().startswith("@image"):
|
206 |
# Remove the "@image" tag and use the rest as prompt
|
@@ -343,4 +347,5 @@ demo = gr.ChatInterface(
|
|
343 |
)
|
344 |
|
345 |
if __name__ == "__main__":
|
|
|
346 |
demo.queue(max_size=20).launch(share=True)
|
|
|
98 |
# ============================================
|
99 |
|
100 |
# Environment variables and parameters for Stable Diffusion XL
|
101 |
+
MODEL_ID_SD = os.getenv("MODEL_VAL_PATH") # SDXL Model repository path via env variable
|
102 |
MAX_IMAGE_SIZE = int(os.getenv("MAX_IMAGE_SIZE", "4096"))
|
103 |
USE_TORCH_COMPILE = os.getenv("USE_TORCH_COMPILE", "0") == "1"
|
104 |
ENABLE_CPU_OFFLOAD = os.getenv("ENABLE_CPU_OFFLOAD", "0") == "1"
|
105 |
+
BATCH_SIZE = int(os.getenv("BATCH_SIZE", "1")) # For batched image generation
|
106 |
|
107 |
# Load the SDXL pipeline
|
108 |
sd_pipe = StableDiffusionXLPipeline.from_pretrained(
|
|
|
113 |
).to(device)
|
114 |
sd_pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(sd_pipe.scheduler.config)
|
115 |
|
116 |
+
# **Fix for dtype mismatch in the text encoder:**
|
117 |
+
if torch.cuda.is_available():
|
118 |
+
sd_pipe.text_encoder = sd_pipe.text_encoder.half()
|
119 |
+
|
120 |
+
# Optional: compile the model for speedup if enabled
|
121 |
if USE_TORCH_COMPILE:
|
122 |
sd_pipe.compile()
|
123 |
|
|
|
195 |
repetition_penalty: float = 1.2,
|
196 |
):
|
197 |
"""
|
198 |
+
Generates chatbot responses with support for multimodal input, TTS, and image generation.
|
199 |
+
Special commands:
|
200 |
+
- "@tts1" or "@tts2": triggers text-to-speech.
|
201 |
+
- "@image": triggers image generation using the SDXL pipeline.
|
202 |
"""
|
203 |
text = input_dict["text"]
|
204 |
files = input_dict.get("files", [])
|
205 |
|
206 |
# ----------------------------
|
207 |
+
# IMAGE GENERATION BRANCH
|
208 |
# ----------------------------
|
209 |
if text.strip().lower().startswith("@image"):
|
210 |
# Remove the "@image" tag and use the rest as prompt
|
|
|
347 |
)
|
348 |
|
349 |
if __name__ == "__main__":
|
350 |
+
# To create a public link, set share=True in launch().
|
351 |
demo.queue(max_size=20).launch(share=True)
|