Spaces:
Sleeping
Sleeping
Commit
·
4a8d5bf
1
Parent(s):
f193fae
Add logging and static route for image serving
Browse files
app.py
CHANGED
|
@@ -12,7 +12,7 @@ logging.basicConfig(level=logging.DEBUG)
|
|
| 12 |
|
| 13 |
# Define paths
|
| 14 |
MODEL_PATH = "Roshan1162003/fine_tuned_model"
|
| 15 |
-
STATIC_IMAGES_PATH = os.path.join("
|
| 16 |
os.makedirs(STATIC_IMAGES_PATH, exist_ok=True)
|
| 17 |
logging.info(f"Created static directory: {STATIC_IMAGES_PATH}")
|
| 18 |
|
|
@@ -22,28 +22,26 @@ RESTRICTED_TERMS = [
|
|
| 22 |
"offensive", "hate", "nude", "porn", "gore", "drug"
|
| 23 |
]
|
| 24 |
|
| 25 |
-
# Verify HF_TOKEN
|
| 26 |
-
HF_TOKEN = os.getenv("HF_TOKEN")
|
| 27 |
-
if not HF_TOKEN:
|
| 28 |
-
logging.error("HF_TOKEN environment variable is not set")
|
| 29 |
-
raise EnvironmentError("HF_TOKEN environment variable is not set")
|
| 30 |
-
|
| 31 |
# Load the fine-tuned model
|
| 32 |
pipe = None
|
| 33 |
try:
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 44 |
except Exception as e:
|
| 45 |
-
logging.error(f"
|
| 46 |
-
raise RuntimeError(f"Failed to load model: {str(e)}")
|
| 47 |
|
| 48 |
# Aspect ratio to resolution mapping
|
| 49 |
ASPECT_RATIOS = {
|
|
@@ -54,7 +52,7 @@ ASPECT_RATIOS = {
|
|
| 54 |
|
| 55 |
def is_prompt_safe(prompt):
|
| 56 |
"""Check if prompt contains restricted terms."""
|
| 57 |
-
prompt_lower = prompt.lower()
|
| 58 |
for term in RESTRICTED_TERMS:
|
| 59 |
if re.search(r'\b' + re.escape(term) + r'\b', prompt_lower):
|
| 60 |
return False
|
|
@@ -68,70 +66,43 @@ def index():
|
|
| 68 |
@app.route("/generate", methods=["POST"])
|
| 69 |
def generate():
|
| 70 |
try:
|
| 71 |
-
# Log full request data
|
| 72 |
-
logging.info(f"Request headers: {request.headers}")
|
| 73 |
-
logging.info(f"Request form data: {request.form}")
|
| 74 |
-
|
| 75 |
prompt = request.form.get("prompt", "").strip()
|
| 76 |
-
num_images = request.form.get("num_images")
|
| 77 |
aspect_ratio = request.form.get("aspect_ratio", "1:1")
|
| 78 |
model_name = request.form.get("model", "stable_diffusion")
|
|
|
|
| 79 |
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
return jsonify({"error": "
|
| 84 |
-
try:
|
| 85 |
-
num_images = int(num_images) if num_images else 1
|
| 86 |
-
except (ValueError, TypeError):
|
| 87 |
-
logging.error(f"Invalid num_images: {num_images}")
|
| 88 |
-
return jsonify({"error": "Number of images must be an integer"}), 400
|
| 89 |
if num_images < 1 or num_images > 5:
|
| 90 |
-
logging.error(f"Invalid num_images: {num_images}")
|
| 91 |
return jsonify({"error": "Number of images must be between 1 and 5"}), 400
|
| 92 |
if aspect_ratio not in ASPECT_RATIOS:
|
| 93 |
-
logging.error(f"Invalid aspect_ratio: {aspect_ratio}")
|
| 94 |
return jsonify({"error": "Invalid aspect ratio"}), 400
|
| 95 |
-
if model_name != "stable_diffusion":
|
| 96 |
-
logging.error(f"Invalid model: {model_name}")
|
| 97 |
-
return jsonify({"error": "Selected model is locked"}), 400
|
| 98 |
-
|
| 99 |
-
logging.info(f"Received prompt: {prompt}, num_images: {num_images}, aspect_ratio: {aspect_ratio}")
|
| 100 |
|
| 101 |
if not is_prompt_safe(prompt):
|
| 102 |
-
logging.error(f"Restricted prompt detected: {prompt}")
|
| 103 |
return jsonify({
|
| 104 |
"error": "You are violating the regulation policy terms and conditions due to restricted terms in the prompt."
|
| 105 |
}), 400
|
| 106 |
|
| 107 |
-
if pipe is None:
|
| 108 |
-
logging.error("Pipeline is None, cannot generate images")
|
| 109 |
-
return jsonify({"error": "Model pipeline is not initialized"}), 500
|
| 110 |
-
|
| 111 |
width, height = ASPECT_RATIOS[aspect_ratio]
|
| 112 |
image_paths = []
|
| 113 |
-
negative_prompt = "blurry, low quality, distorted face, unnatural features, overexposed, underexposed"
|
| 114 |
for i in range(num_images):
|
| 115 |
-
logging.info(f"Generating image {i+1}/{num_images} for prompt: {prompt}")
|
| 116 |
-
start_time = time.time()
|
| 117 |
image = pipe(
|
| 118 |
-
prompt
|
| 119 |
-
negative_prompt=negative_prompt,
|
| 120 |
width=width,
|
| 121 |
height=height,
|
| 122 |
num_inference_steps=50,
|
| 123 |
guidance_scale=7.5,
|
| 124 |
seed=42 + i
|
| 125 |
).images[0]
|
| 126 |
-
end_time = time.time()
|
| 127 |
-
logging.info(f"Image generation took {end_time - start_time:.2f} seconds")
|
| 128 |
timestamp = int(time.time() * 1000)
|
| 129 |
image_path = os.path.join(STATIC_IMAGES_PATH, f"generated_{timestamp}_{i}.png")
|
| 130 |
logging.info(f"Saving image to: {image_path}")
|
| 131 |
image.save(image_path)
|
| 132 |
logging.info(f"Image saved successfully: {image_path}")
|
| 133 |
-
image_paths.append(image_path.replace("/
|
| 134 |
-
logging.info(f"Image path for UI: {image_paths[-1]}")
|
| 135 |
logging.info(f"Returning image paths: {image_paths}")
|
| 136 |
return jsonify({"images": image_paths})
|
| 137 |
|
|
|
|
| 12 |
|
| 13 |
# Define paths
|
| 14 |
MODEL_PATH = "Roshan1162003/fine_tuned_model"
|
| 15 |
+
STATIC_IMAGES_PATH = os.path.join("static", "images")
|
| 16 |
os.makedirs(STATIC_IMAGES_PATH, exist_ok=True)
|
| 17 |
logging.info(f"Created static directory: {STATIC_IMAGES_PATH}")
|
| 18 |
|
|
|
|
| 22 |
"offensive", "hate", "nude", "porn", "gore", "drug"
|
| 23 |
]
|
| 24 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 25 |
# Load the fine-tuned model
|
| 26 |
pipe = None
|
| 27 |
try:
|
| 28 |
+
if torch.cuda.is_available():
|
| 29 |
+
pipe = StableDiffusionPipeline.from_pretrained(
|
| 30 |
+
MODEL_PATH,
|
| 31 |
+
torch_dtype=torch.float16,
|
| 32 |
+
use_safetensors=True,
|
| 33 |
+
use_auth_token=os.getenv("HF_TOKEN")
|
| 34 |
+
).to("cuda")
|
| 35 |
+
else:
|
| 36 |
+
pipe = StableDiffusionPipeline.from_pretrained(
|
| 37 |
+
MODEL_PATH,
|
| 38 |
+
torch_dtype=torch.float32,
|
| 39 |
+
use_safetensors=True,
|
| 40 |
+
use_auth_token=os.getenv("HF_TOKEN")
|
| 41 |
+
)
|
| 42 |
+
logging.info("Model loaded successfully")
|
| 43 |
except Exception as e:
|
| 44 |
+
logging.error(f"Error loading model: {e}")
|
|
|
|
| 45 |
|
| 46 |
# Aspect ratio to resolution mapping
|
| 47 |
ASPECT_RATIOS = {
|
|
|
|
| 52 |
|
| 53 |
def is_prompt_safe(prompt):
|
| 54 |
"""Check if prompt contains restricted terms."""
|
| 55 |
+
prompt_lower = prompt.lower()
|
| 56 |
for term in RESTRICTED_TERMS:
|
| 57 |
if re.search(r'\b' + re.escape(term) + r'\b', prompt_lower):
|
| 58 |
return False
|
|
|
|
| 66 |
@app.route("/generate", methods=["POST"])
|
| 67 |
def generate():
|
| 68 |
try:
|
|
|
|
|
|
|
|
|
|
|
|
|
| 69 |
prompt = request.form.get("prompt", "").strip()
|
| 70 |
+
num_images = int(request.form.get("num_images", 1))
|
| 71 |
aspect_ratio = request.form.get("aspect_ratio", "1:1")
|
| 72 |
model_name = request.form.get("model", "stable_diffusion")
|
| 73 |
+
logging.info(f"Received prompt: {prompt}, num_images: {num_images}, aspect_ratio: {aspect_ratio}")
|
| 74 |
|
| 75 |
+
if not prompt:
|
| 76 |
+
return jsonify({"error": "Prompt is required"}), 400
|
| 77 |
+
if model_name != "stable_diffusion":
|
| 78 |
+
return jsonify({"error": "Selected model is locked"}), 400
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 79 |
if num_images < 1 or num_images > 5:
|
|
|
|
| 80 |
return jsonify({"error": "Number of images must be between 1 and 5"}), 400
|
| 81 |
if aspect_ratio not in ASPECT_RATIOS:
|
|
|
|
| 82 |
return jsonify({"error": "Invalid aspect ratio"}), 400
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 83 |
|
| 84 |
if not is_prompt_safe(prompt):
|
|
|
|
| 85 |
return jsonify({
|
| 86 |
"error": "You are violating the regulation policy terms and conditions due to restricted terms in the prompt."
|
| 87 |
}), 400
|
| 88 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 89 |
width, height = ASPECT_RATIOS[aspect_ratio]
|
| 90 |
image_paths = []
|
|
|
|
| 91 |
for i in range(num_images):
|
|
|
|
|
|
|
| 92 |
image = pipe(
|
| 93 |
+
prompt,
|
|
|
|
| 94 |
width=width,
|
| 95 |
height=height,
|
| 96 |
num_inference_steps=50,
|
| 97 |
guidance_scale=7.5,
|
| 98 |
seed=42 + i
|
| 99 |
).images[0]
|
|
|
|
|
|
|
| 100 |
timestamp = int(time.time() * 1000)
|
| 101 |
image_path = os.path.join(STATIC_IMAGES_PATH, f"generated_{timestamp}_{i}.png")
|
| 102 |
logging.info(f"Saving image to: {image_path}")
|
| 103 |
image.save(image_path)
|
| 104 |
logging.info(f"Image saved successfully: {image_path}")
|
| 105 |
+
image_paths.append(image_path.replace("static/", ""))
|
|
|
|
| 106 |
logging.info(f"Returning image paths: {image_paths}")
|
| 107 |
return jsonify({"images": image_paths})
|
| 108 |
|