Roshan1162003 committed
Commit 4a8d5bf · 1 Parent(s): f193fae

Add logging and static route for image serving

Files changed (1):
  app.py  +26 -55
app.py CHANGED
@@ -12,7 +12,7 @@ logging.basicConfig(level=logging.DEBUG)
 
 # Define paths
 MODEL_PATH = "Roshan1162003/fine_tuned_model"
-STATIC_IMAGES_PATH = os.path.join("/tmp", "images")
+STATIC_IMAGES_PATH = os.path.join("static", "images")
 os.makedirs(STATIC_IMAGES_PATH, exist_ok=True)
 logging.info(f"Created static directory: {STATIC_IMAGES_PATH}")
 
@@ -22,28 +22,26 @@ RESTRICTED_TERMS = [
     "offensive", "hate", "nude", "porn", "gore", "drug"
 ]
 
-# Verify HF_TOKEN
-HF_TOKEN = os.getenv("HF_TOKEN")
-if not HF_TOKEN:
-    logging.error("HF_TOKEN environment variable is not set")
-    raise EnvironmentError("HF_TOKEN environment variable is not set")
-
 # Load the fine-tuned model
 pipe = None
 try:
-    logging.info(f"Loading model from {MODEL_PATH} with HF_TOKEN")
-    dtype = torch.float32
-    logging.info(f"Using dtype: {dtype}")
-    pipe = StableDiffusionPipeline.from_pretrained(
-        MODEL_PATH,
-        torch_dtype=dtype,
-        use_safetensors=True,
-        use_auth_token=HF_TOKEN
-    ).to("cpu")
-    logging.info("Model loaded successfully with default scheduler")
+    if torch.cuda.is_available():
+        pipe = StableDiffusionPipeline.from_pretrained(
+            MODEL_PATH,
+            torch_dtype=torch.float16,
+            use_safetensors=True,
+            use_auth_token=os.getenv("HF_TOKEN")
+        ).to("cuda")
+    else:
+        pipe = StableDiffusionPipeline.from_pretrained(
+            MODEL_PATH,
+            torch_dtype=torch.float32,
+            use_safetensors=True,
+            use_auth_token=os.getenv("HF_TOKEN")
+        )
+    logging.info("Model loaded successfully")
 except Exception as e:
-    logging.error(f"Failed to load model: {str(e)}")
-    raise RuntimeError(f"Failed to load model: {str(e)}")
+    logging.error(f"Error loading model: {e}")
 
 # Aspect ratio to resolution mapping
 ASPECT_RATIOS = {
@@ -54,7 +52,7 @@ ASPECT_RATIOS = {
 
 def is_prompt_safe(prompt):
     """Check if prompt contains restricted terms."""
-    prompt_lower = prompt.lower() if isinstance(prompt, str) else ""
+    prompt_lower = prompt.lower()
     for term in RESTRICTED_TERMS:
         if re.search(r'\b' + re.escape(term) + r'\b', prompt_lower):
             return False
@@ -68,70 +66,43 @@ def index():
 @app.route("/generate", methods=["POST"])
 def generate():
     try:
-        # Log full request data
-        logging.info(f"Request headers: {request.headers}")
-        logging.info(f"Request form data: {request.form}")
-
         prompt = request.form.get("prompt", "").strip()
-        num_images = request.form.get("num_images")
+        num_images = int(request.form.get("num_images", 1))
         aspect_ratio = request.form.get("aspect_ratio", "1:1")
         model_name = request.form.get("model", "stable_diffusion")
+        logging.info(f"Received prompt: {prompt}, num_images: {num_images}, aspect_ratio: {aspect_ratio}")
 
-        # Validate inputs
-        if not prompt or not isinstance(prompt, str):
-            logging.error("Invalid prompt: empty or not a string")
-            return jsonify({"error": "Prompt is required and must be a string"}), 400
-        try:
-            num_images = int(num_images) if num_images else 1
-        except (ValueError, TypeError):
-            logging.error(f"Invalid num_images: {num_images}")
-            return jsonify({"error": "Number of images must be an integer"}), 400
+        if not prompt:
+            return jsonify({"error": "Prompt is required"}), 400
+        if model_name != "stable_diffusion":
+            return jsonify({"error": "Selected model is locked"}), 400
         if num_images < 1 or num_images > 5:
-            logging.error(f"Invalid num_images: {num_images}")
             return jsonify({"error": "Number of images must be between 1 and 5"}), 400
         if aspect_ratio not in ASPECT_RATIOS:
-            logging.error(f"Invalid aspect_ratio: {aspect_ratio}")
             return jsonify({"error": "Invalid aspect ratio"}), 400
-        if model_name != "stable_diffusion":
-            logging.error(f"Invalid model: {model_name}")
-            return jsonify({"error": "Selected model is locked"}), 400
-
-        logging.info(f"Received prompt: {prompt}, num_images: {num_images}, aspect_ratio: {aspect_ratio}")
 
         if not is_prompt_safe(prompt):
-            logging.error(f"Restricted prompt detected: {prompt}")
             return jsonify({
                 "error": "You are violating the regulation policy terms and conditions due to restricted terms in the prompt."
            }), 400
 
-        if pipe is None:
-            logging.error("Pipeline is None, cannot generate images")
-            return jsonify({"error": "Model pipeline is not initialized"}), 500
-
         width, height = ASPECT_RATIOS[aspect_ratio]
         image_paths = []
-        negative_prompt = "blurry, low quality, distorted face, unnatural features, overexposed, underexposed"
         for i in range(num_images):
-            logging.info(f"Generating image {i+1}/{num_images} for prompt: {prompt}")
-            start_time = time.time()
             image = pipe(
-                prompt=prompt,
-                negative_prompt=negative_prompt,
+                prompt,
                 width=width,
                 height=height,
                 num_inference_steps=50,
                 guidance_scale=7.5,
                 seed=42 + i
             ).images[0]
-            end_time = time.time()
-            logging.info(f"Image generation took {end_time - start_time:.2f} seconds")
             timestamp = int(time.time() * 1000)
             image_path = os.path.join(STATIC_IMAGES_PATH, f"generated_{timestamp}_{i}.png")
             logging.info(f"Saving image to: {image_path}")
             image.save(image_path)
             logging.info(f"Image saved successfully: {image_path}")
-            image_paths.append(image_path.replace("/tmp/", ""))
-            logging.info(f"Image path for UI: {image_paths[-1]}")
+            image_paths.append(image_path.replace("static/", ""))
         logging.info(f"Returning image paths: {image_paths}")
         return jsonify({"images": image_paths})
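
As a usage sketch (not part of the commit), the /generate route above can be exercised with a plain form POST. The base URL below is a placeholder for wherever this Flask app happens to run, and the field names match what the route reads from request.form:

import requests

# Hypothetical base URL; replace with wherever this Flask app is served.
BASE_URL = "http://localhost:7860"

resp = requests.post(
    f"{BASE_URL}/generate",
    data={
        "prompt": "a watercolor painting of a lighthouse at dawn",
        "num_images": 2,             # must be 1-5 per the route's validation
        "aspect_ratio": "1:1",       # must be a key of ASPECT_RATIOS
        "model": "stable_diffusion"  # any other value returns "Selected model is locked"
    },
    timeout=600,  # 50 inference steps on CPU can take a long time
)
print(resp.status_code, resp.json())
# e.g. 200 {"images": ["images/generated_<timestamp>_0.png", "images/generated_<timestamp>_1.png"]}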
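
The move from /tmp/images to static/images is what makes the returned paths servable. The visible hunks only change the save location; assuming the default Flask static folder, every file lands under static/, and the JSON paths ("images/generated_<timestamp>_<i>.png") resolve against Flask's built-in /static/ view without a custom route. A minimal sketch of that mapping:

from flask import Flask, url_for

app = Flask(__name__)  # default static_folder="static", static_url_path="/static"

with app.test_request_context():
    # A path as returned by /generate, relative to the static folder
    relative_path = "images/generated_1700000000000_0.png"
    # Flask's built-in "static" endpoint maps it to static/images/generated_...png
    print(url_for("static", filename=relative_path))
    # -> /static/images/generated_1700000000000_0.png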
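
One assumption worth flagging: the loop passes seed=42 + i straight to the pipeline call. diffusers' Stable Diffusion pipelines document reproducibility through a torch.Generator passed as the generator argument, so if the installed version does not accept a seed keyword, a helper along these lines (hypothetical name generate_seeded) would be the equivalent:

import torch
from diffusers import StableDiffusionPipeline

def generate_seeded(pipe: StableDiffusionPipeline, prompt: str, width: int, height: int, seed: int):
    """Run one reproducible generation, mirroring the loop body in /generate."""
    # Seed via a torch.Generator on the pipeline's device instead of a seed= keyword.
    generator = torch.Generator(device=pipe.device.type).manual_seed(seed)
    return pipe(
        prompt,
        width=width,
        height=height,
        num_inference_steps=50,
        guidance_scale=7.5,
        generator=generator,  # replaces seed=42 + i
    ).images[0]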