adpro committed
Commit 763ef08 · verified · 1 Parent(s): 42a5c12

Update app.py

Files changed (1)
app.py +73 -205
app.py CHANGED
@@ -1,209 +1,77 @@
- import spaces
- import gradio as gr
  import numpy as np
- import random
- from PIL import Image
  import torch
- from diffusers import (
-     ControlNetModel,
-     DiffusionPipeline,
-     StableDiffusionControlNetPipeline,
-     StableDiffusionXLControlNetPipeline,
-     UniPCMultistepScheduler,
-     EulerDiscreteScheduler,
-     AutoencoderKL
- )
- from transformers import DPTFeatureExtractor, DPTForDepthEstimation, DPTImageProcessor
- from transformers import CLIPImageProcessor
- from diffusers.utils import load_image
- from gradio_imageslider import ImageSlider
- import boto3
- from io import BytesIO
- from datetime import datetime
- import json
-
- device = "cuda"
- base_model_id = "SG161222/RealVisXL_V5.0"
- controlnet_model_id = "diffusers/controlnet-depth-sdxl-1.0"
- vae_model_id = "madebyollin/sdxl-vae-fp16-fix"
-
-
- if torch.cuda.is_available():
-
-     # load pipe
-     controlnet = ControlNetModel.from_pretrained(
-         controlnet_model_id,
-         variant="fp16",
-         use_safetensors=True,
-         torch_dtype=torch.bfloat16
-     )
-     vae = AutoencoderKL.from_pretrained(vae_model_id, torch_dtype=torch.bfloat16)
-     pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
-         base_model_id,
-         controlnet=controlnet,
-         vae=vae,
-         variant="fp16",
-         use_safetensors=True,
-         torch_dtype=torch.bfloat16,
-     )
-     pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config)
-     pipe.to(device)
-
-     depth_estimator = DPTForDepthEstimation.from_pretrained("Intel/dpt-hybrid-midas").to("cuda")
-     feature_extractor = DPTImageProcessor.from_pretrained("Intel/dpt-hybrid-midas")
-
-
- MAX_SEED = np.iinfo(np.int32).max
- MAX_IMAGE_SIZE = 1024
-
- USE_TORCH_COMPILE = 0
- ENABLE_CPU_OFFLOAD = 0
-
-
- def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
-     if randomize_seed:
-         seed = random.randint(0, MAX_SEED)
-     return seed
-
-
- def get_depth_map(image):
-     original_size = (image.size[1], image.size[0])
-     print("start generate depth", original_size)
-     image = feature_extractor(images=image, return_tensors="pt").pixel_values.to("cuda")
-     with torch.no_grad(), torch.autocast("cuda"):
-         depth_map = depth_estimator(image).predicted_depth
-     depth_map = torch.nn.functional.interpolate(
-         depth_map.unsqueeze(1),
-         size=original_size,
-         mode="bicubic",
-         align_corners=False,
-     )
-     depth_min = torch.amin(depth_map, dim=[1, 2, 3], keepdim=True)
-     depth_max = torch.amax(depth_map, dim=[1, 2, 3], keepdim=True)
-     depth_map = (depth_map - depth_min) / (depth_max - depth_min)
-     image = torch.cat([depth_map] * 3, dim=1)
-     image = image.permute(0, 2, 3, 1).cpu().numpy()[0]
-     image = Image.fromarray((image * 255.0).clip(0, 255).astype(np.uint8))
-     print("generate depth success")
-     return image
-
-
- def upload_image_to_s3(image, account_id, access_key, secret_key, bucket_name):
-     print("upload_image_to_s3", account_id, access_key, secret_key, bucket_name)
-     connectionUrl = f"https://{account_id}.r2.cloudflarestorage.com"
-
-     s3 = boto3.client(
-         's3',
-         endpoint_url=connectionUrl,
-         region_name='auto',
-         aws_access_key_id=access_key,
-         aws_secret_access_key=secret_key
-     )
-
-     current_time = datetime.now().strftime("%Y%m%d_%H%M%S")
-     image_file = f"generated_images/{current_time}_{random.randint(0, MAX_SEED)}.png"
-     buffer = BytesIO()
-     image.save(buffer, "PNG")
-     buffer.seek(0)
-     s3.upload_fileobj(buffer, bucket_name, image_file)
-     print("upload finish", image_file)
-     return image_file
-
-
-
- @spaces.GPU(duration=120)
- def process(image, image_url, prompt, n_prompt, num_steps, guidance_scale, control_strength, seed, upload_to_s3, account_id, access_key, secret_key, bucket, progress=gr.Progress(track_tqdm=True)):
-     print("process start")
-     if image_url:
-         print(image_url)
-         orginal_image = load_image(image_url)
-     else:
-         orginal_image = Image.fromarray(image)
-
-     size = (orginal_image.size[0], orginal_image.size[1])
-     print("gorinal image size", size)
-     depth_image = get_depth_map(orginal_image)
-     generator = torch.Generator().manual_seed(seed)
-     print(prompt, n_prompt, guidance_scale, num_steps, control_strength)
-     print("run pipe")
-     generated_image = pipe(
-         prompt=prompt,
-         negative_prompt=n_prompt,
-         width=size[0],
-         height=size[1],
-         guidance_scale=guidance_scale,
-         num_inference_steps=num_steps,
-         strength=control_strength,
-         generator=generator,
-         image=depth_image
-     ).images[0]
-     print("geneate image success")
-     if upload_to_s3:
-         url = upload_image_to_s3(generated_image, account_id, access_key, secret_key, bucket)
-         result = {"status": "success", "url": url}
      else:
-         result = {"status": "success", "message": "Image generated but not uploaded"}
-
-     return generated_image, json.dumps(result)
-
- with gr.Blocks() as demo:
-
-     with gr.Row():
-         with gr.Column():
-             image = gr.Image()
-             image_url = gr.Textbox(label="Image Url", placeholder="Enter image URL here (optional)")
-             prompt = gr.Textbox(label="Prompt")
-             run_button = gr.Button("Run")
-
-             with gr.Accordion("Advanced options", open=True):
-
-                 num_steps = gr.Slider(label="Number of steps", minimum=1, maximum=100, value=30, step=1)
-                 guidance_scale = gr.Slider(label="Guidance scale", minimum=0.1, maximum=30.0, value=7.5, step=0.1)
-                 control_strength = gr.Slider(label="Control Strength", minimum=0.1, maximum=4.0, value=0.8, step=0.1)
-                 seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
-                 randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
-                 n_prompt = gr.Textbox(
-                     label="Negative prompt",
-                     value="longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality",
-                 )
-
-                 upload_to_s3 = gr.Checkbox(label="Upload to R2", value=False)
-                 account_id = gr.Textbox(label="Account Id", placeholder="Enter R2 account id")
-                 access_key = gr.Textbox(label="Access Key", placeholder="Enter R2 access key here")
-                 secret_key = gr.Textbox(label="Secret Key", placeholder="Enter R2 secret key here")
-                 bucket = gr.Textbox(label="Bucket Name", placeholder="Enter R2 bucket name here")
-
-
-         with gr.Column():
-             result = gr.Image(label="Generated Image")
-             logs = gr.Textbox(label="logs")
-
-     inputs = [
-         image,
-         image_url,
-         prompt,
-         n_prompt,
-         num_steps,
-         guidance_scale,
-         control_strength,
-         seed,
-         upload_to_s3,
-         account_id,
-         access_key,
-         secret_key,
-         bucket
-     ]
-     run_button.click(
-         fn=randomize_seed_fn,
-         inputs=[seed, randomize_seed],
-         outputs=seed,
-         queue=False,
-         api_name=False,
-     ).then(
-         fn=process,
-         inputs=inputs,
-         outputs=[result, logs],
-         api_name="predict"
-     )

- demo.queue().launch()
+ from fastapi import FastAPI, UploadFile, File, Response
+ import cv2
  import numpy as np
  import torch
+ import torchvision.transforms as T
+ from PIL import Image
+ import io
+ import time
+
+ app = FastAPI()
+
+ # Load the MiDaS depth-estimation model
+ midas = torch.hub.load("intel-isl/MiDaS", "MiDaS_small")
+ midas.eval()
+ transform = T.Compose([
+     T.Resize((256, 256)),
+     T.ToTensor(),
+     T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+ ])
+
+ @app.post("/upload/")
+ async def upload_image(file: UploadFile = File(...)):
+     try:
+         start_time = time.time()
+         image_bytes = await file.read()
+         print(f"📷 Received image ({len(image_bytes)} bytes)")
+
+         image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
+         print("✅ Image opened successfully!")
+         # Flip vertically and horizontally (a 180° rotation of the input)
+         image = image.transpose(Image.FLIP_TOP_BOTTOM)
+         image = image.transpose(Image.FLIP_LEFT_RIGHT)
+         # Convert the image to a tensor
+         img_tensor = transform(image).unsqueeze(0)
+         with torch.no_grad():
+             depth_map = midas(img_tensor).squeeze().cpu().numpy()
+
+         # Normalize the depth map to 0-255
+         depth_map = cv2.normalize(depth_map, None, 0, 255, cv2.NORM_MINMAX).astype(np.uint8)
+         depth_resized = cv2.resize(depth_map, (160, 120))
+
+         # Encode the depth map as JPEG (the buffer is currently unused)
+         _, buffer = cv2.imencode(".jpg", depth_resized)
+         print("✅ Depth map generated!")
+         end_time = time.time()
+         print(f"⏳ Depth estimation took {end_time - start_time:.4f} s")
+
+         start_detect_time = time.time()
+         command = detect_path(depth_map)
+         end_detect_time = time.time()
+         print(f"⏳ detect_path() ran in {end_detect_time - start_detect_time:.4f} s")
+
+         return {"command": command}
+     except Exception as e:
+         print("❌ Image processing error:", str(e))
+         return {"error": str(e)}
+
+ def detect_path(depth_map):
+     """Analyze the drivable path from the depth map"""
+     h, w = depth_map.shape
+     center_x = w // 2
+     scan_y = int(h * 0.8)  # Scan the row at 80% of the image height
+
+     left_region = np.mean(depth_map[scan_y, :center_x])
+     right_region = np.mean(depth_map[scan_y, center_x:])
+     center_region = np.mean(depth_map[scan_y, center_x - 40:center_x + 40])
+
+     # 🟢 Improved decision logic
+     threshold = 100  # Threshold for separating obstacles from free space
+     if center_region > threshold:
+         return "forward"
+     elif left_region > right_region:
+         return "left"
+     elif right_region > left_region:
+         return "right"
      else:
+         return "backward"

+ if __name__ == "__main__":
+     import uvicorn
+     uvicorn.run(app, host="0.0.0.0", port=7860)
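To exercise the new endpoint, here is a minimal client sketch (an illustration, not part of the commit): it assumes the server is running locally on port 7860 and that test.jpg is any image on disk.

import requests

with open("test.jpg", "rb") as f:
    resp = requests.post(
        "http://localhost:7860/upload/",
        files={"file": ("test.jpg", f, "image/jpeg")},
    )
print(resp.json())  # e.g. {"command": "forward"} or {"error": "..."}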
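And a quick sanity check of detect_path on a synthetic depth map (also an illustration under assumed values): a bright band across the image center should push the center-region mean above the threshold of 100 and yield "forward".

import numpy as np
from app import detect_path  # assumes the file is saved as app.py; importing it also loads MiDaS

depth_map = np.zeros((120, 160), dtype=np.uint8)
depth_map[:, 60:100] = 255  # bright vertical band covering the image center
# center_x = 80, scan_y = 96; center-region mean = (40 * 255) / 80 = 127.5 > 100
print(detect_path(depth_map))  # -> "forward"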