Spaces:
Running
on
Zero
Running
on
Zero
Update app.py
Browse files
app.py
CHANGED
@@ -1,209 +1,77 @@
|
|
1 |
-
import
|
2 |
-
import
|
3 |
import numpy as np
|
4 |
-
import random
|
5 |
-
from PIL import Image
|
6 |
import torch
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
-
|
11 |
-
|
12 |
-
|
13 |
-
|
14 |
-
|
15 |
-
)
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
depth_map = depth_estimator(image).predicted_depth
|
75 |
-
depth_map = torch.nn.functional.interpolate(
|
76 |
-
depth_map.unsqueeze(1),
|
77 |
-
size=original_size,
|
78 |
-
mode="bicubic",
|
79 |
-
align_corners=False,
|
80 |
-
)
|
81 |
-
depth_min = torch.amin(depth_map, dim=[1, 2, 3], keepdim=True)
|
82 |
-
depth_max = torch.amax(depth_map, dim=[1, 2, 3], keepdim=True)
|
83 |
-
depth_map = (depth_map - depth_min) / (depth_max - depth_min)
|
84 |
-
image = torch.cat([depth_map] * 3, dim=1)
|
85 |
-
image = image.permute(0, 2, 3, 1).cpu().numpy()[0]
|
86 |
-
image = Image.fromarray((image * 255.0).clip(0, 255).astype(np.uint8))
|
87 |
-
print("generate depth success")
|
88 |
-
return image
|
89 |
-
|
90 |
-
|
91 |
-
def upload_image_to_s3(image, account_id, access_key, secret_key, bucket_name):
    """Upload an image as PNG to a Cloudflare R2 bucket and return its key.

    Args:
        image: Object with a ``save(fileobj, format)`` method (used like a
            PIL image).
        account_id: R2 account id; used to build the endpoint URL.
        access_key: Access key id passed to the S3 client.
        secret_key: Secret access key passed to the S3 client.
        bucket_name: Destination bucket.

    Returns:
        The object key, ``generated_images/<timestamp>_<random int>.png``.
    """
    # Credentials are intentionally kept out of the log line.
    print("upload_image_to_s3", account_id, bucket_name)
    endpoint_url = f"https://{account_id}.r2.cloudflarestorage.com"

    s3 = boto3.client(
        "s3",
        endpoint_url=endpoint_url,
        region_name="auto",
        aws_access_key_id=access_key,
        aws_secret_access_key=secret_key,
    )

    # Timestamp plus a random suffix keeps keys from colliding within a second.
    current_time = datetime.now().strftime("%Y%m%d_%H%M%S")
    image_file = f"generated_images/{current_time}_{random.randint(0, MAX_SEED)}.png"

    # Serialize in memory, then rewind so the upload reads from the start.
    buffer = BytesIO()
    image.save(buffer, "PNG")
    buffer.seek(0)
    s3.upload_fileobj(buffer, bucket_name, image_file)
    print("upload finish", image_file)
    return image_file
|
111 |
-
|
112 |
-
|
113 |
-
|
114 |
-
@spaces.GPU(duration=120)
|
115 |
-
def process(image, image_url, prompt, n_prompt, num_steps, guidance_scale, control_strength, seed, upload_to_s3, account_id, access_key, secret_key, bucket, progress=gr.Progress(track_tqdm=True)):
|
116 |
-
print("process start")
|
117 |
-
if image_url:
|
118 |
-
print(image_url)
|
119 |
-
orginal_image = load_image(image_url)
|
120 |
-
else:
|
121 |
-
orginal_image = Image.fromarray(image)
|
122 |
-
|
123 |
-
size = (orginal_image.size[0], orginal_image.size[1])
|
124 |
-
print("gorinal image size", size)
|
125 |
-
depth_image = get_depth_map(orginal_image)
|
126 |
-
generator = torch.Generator().manual_seed(seed)
|
127 |
-
print(prompt, n_prompt, guidance_scale, num_steps, control_strength)
|
128 |
-
print("run pipe")
|
129 |
-
generated_image = pipe(
|
130 |
-
prompt=prompt,
|
131 |
-
negative_prompt=n_prompt,
|
132 |
-
width=size[0],
|
133 |
-
height=size[1],
|
134 |
-
guidance_scale=guidance_scale,
|
135 |
-
num_inference_steps=num_steps,
|
136 |
-
strength=control_strength,
|
137 |
-
generator=generator,
|
138 |
-
image=depth_image
|
139 |
-
).images[0]
|
140 |
-
print("geneate image success")
|
141 |
-
if upload_to_s3:
|
142 |
-
url = upload_image_to_s3(generated_image, account_id, access_key, secret_key, bucket)
|
143 |
-
result = {"status": "success", "url": url}
|
144 |
else:
|
145 |
-
|
146 |
-
|
147 |
-
return generated_image, json.dumps(result)
|
148 |
-
|
149 |
-
with gr.Blocks() as demo:
|
150 |
-
|
151 |
-
with gr.Row():
|
152 |
-
with gr.Column():
|
153 |
-
image = gr.Image()
|
154 |
-
image_url = gr.Textbox(label="Image Url", placeholder="Enter image URL here (optional)")
|
155 |
-
prompt = gr.Textbox(label="Prompt")
|
156 |
-
run_button = gr.Button("Run")
|
157 |
-
|
158 |
-
with gr.Accordion("Advanced options", open=True):
|
159 |
-
|
160 |
-
num_steps = gr.Slider(label="Number of steps", minimum=1, maximum=100, value=30, step=1)
|
161 |
-
guidance_scale = gr.Slider(label="Guidance scale", minimum=0.1, maximum=30.0, value=7.5, step=0.1)
|
162 |
-
control_strength = gr.Slider(label="Control Strength", minimum=0.1, maximum=4.0, value=0.8, step=0.1)
|
163 |
-
seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
|
164 |
-
randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
|
165 |
-
n_prompt = gr.Textbox(
|
166 |
-
label="Negative prompt",
|
167 |
-
value="longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality",
|
168 |
-
)
|
169 |
-
|
170 |
-
upload_to_s3 = gr.Checkbox(label="Upload to R2", value=False)
|
171 |
-
account_id = gr.Textbox(label="Account Id", placeholder="Enter R2 account id")
|
172 |
-
access_key = gr.Textbox(label="Access Key", placeholder="Enter R2 access key here")
|
173 |
-
secret_key = gr.Textbox(label="Secret Key", placeholder="Enter R2 secret key here")
|
174 |
-
bucket = gr.Textbox(label="Bucket Name", placeholder="Enter R2 bucket name here")
|
175 |
-
|
176 |
-
|
177 |
-
with gr.Column():
|
178 |
-
result = gr.Image(label="Generated Image")
|
179 |
-
logs = gr.Textbox(label="logs")
|
180 |
-
|
181 |
-
inputs = [
|
182 |
-
image,
|
183 |
-
image_url,
|
184 |
-
prompt,
|
185 |
-
n_prompt,
|
186 |
-
num_steps,
|
187 |
-
guidance_scale,
|
188 |
-
control_strength,
|
189 |
-
seed,
|
190 |
-
upload_to_s3,
|
191 |
-
account_id,
|
192 |
-
access_key,
|
193 |
-
secret_key,
|
194 |
-
bucket
|
195 |
-
]
|
196 |
-
run_button.click(
|
197 |
-
fn=randomize_seed_fn,
|
198 |
-
inputs=[seed, randomize_seed],
|
199 |
-
outputs=seed,
|
200 |
-
queue=False,
|
201 |
-
api_name=False,
|
202 |
-
).then(
|
203 |
-
fn=process,
|
204 |
-
inputs=inputs,
|
205 |
-
outputs=[result, logs],
|
206 |
-
api_name="predict"
|
207 |
-
)
|
208 |
|
209 |
-
|
|
|
|
|
|
1 |
+
from fastapi import FastAPI, UploadFile, File, Response
|
2 |
+
import cv2
|
3 |
import numpy as np
|
|
|
|
|
4 |
import torch
|
5 |
+
import torchvision.transforms as T
|
6 |
+
from PIL import Image
|
7 |
+
import io
|
8 |
+
|
9 |
+
# Application object; the route handlers in this file register on it.
app = FastAPI()

# Fetch the small MiDaS model through torch.hub and switch it to
# evaluation mode, since it is only used for inference here.
midas = torch.hub.load("intel-isl/MiDaS", "MiDaS_small")
midas.eval()

# Per-channel normalization constants (these match the widely used
# ImageNet RGB statistics).
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]

# Preprocessing applied to every uploaded image before inference:
# fixed 256x256 resize, conversion to a tensor, then normalization.
transform = T.Compose(
    [
        T.Resize((256, 256)),
        T.ToTensor(),
        T.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
    ]
)
|
19 |
+
|
20 |
+
@app.post("/upload/")
async def upload_image(file: UploadFile = File(...)):
    """Estimate depth for an uploaded image and return a driving command.

    The upload is decoded as RGB, flipped on both axes, run through the
    module-level ``transform`` and ``midas`` model, min-max normalized to
    0..255 and handed to ``detect_path``.

    Args:
        file: The uploaded image file.

    Returns:
        ``{"command": <str>}`` on success, or ``{"error": <message>}`` if any
        step raises (the exception is logged and reported in the payload
        rather than propagated).
    """
    # Imported locally because the module header does not import ``time``.
    import time

    try:
        image_bytes = await file.read()
        print(f"📷 Ảnh nhận được ({len(image_bytes)} bytes)")

        image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
        print("✅ Ảnh mở thành công!")
        # Flipping on both axes is equivalent to a 180-degree rotation.
        # NOTE(review): presumably compensates for camera mounting - confirm.
        image = image.transpose(Image.FLIP_TOP_BOTTOM)
        image = image.transpose(Image.FLIP_LEFT_RIGHT)

        # Convert the image to a batched tensor.
        img_tensor = transform(image).unsqueeze(0)
        with torch.no_grad():
            depth_map = midas(img_tensor).squeeze().cpu().numpy()

        # Normalize the depth map to the 0..255 range as uint8.
        depth_map = cv2.normalize(
            depth_map, None, 0, 255, cv2.NORM_MINMAX
        ).astype(np.uint8)
        print("✅ Depth Map đã được tạo!")

        start_detect_time = time.time()
        command = detect_path(depth_map)
        end_detect_time = time.time()
        print(f"⏳ detect_path() xử lý trong {end_detect_time - start_detect_time:.4f} giây")

        return {"command": command}
    except Exception as e:
        # Boundary handler: report the failure to the client as a payload.
        print("❌ Lỗi xử lý ảnh:", str(e))
        return {"error": str(e)}
|
54 |
+
def detect_path(depth_map, threshold=100, center_half_width=40):
    """Choose a movement command from one scan line of a depth map.

    A single row, 80% of the way down the map, is split into a left half,
    a right half and a window around the horizontal center; the mean value
    of each region drives the decision.

    Args:
        depth_map: 2-D array of depth values (rows x columns).
        threshold: Center mean above which ``"forward"`` is returned.
        center_half_width: Half-width, in columns, of the center window.

    Returns:
        ``"forward"`` if the center mean exceeds ``threshold``; otherwise
        ``"left"`` or ``"right"`` for whichever half has the larger mean,
        or ``"backward"`` when the halves are equal.

    Raises:
        ValueError: If ``depth_map`` is not 2-D (shape unpacking fails).
    """
    # NOTE(review): if the map comes from MiDaS, larger values usually mean
    # *nearer* surfaces (relative inverse depth), so treating a bright
    # center as free space may be inverted - confirm against the robot.
    h, w = depth_map.shape
    center_x = w // 2
    scan_y = int(h * 0.8)  # scan the row 80% of the way down from the top
    row = depth_map[scan_y]

    left_region = np.mean(row[:center_x])
    right_region = np.mean(row[center_x:])

    # Clamp the window start: on maps narrower than twice the half-width a
    # negative start index would wrap around and sample the wrong columns.
    center_start = max(0, center_x - center_half_width)
    center_region = np.mean(row[center_start:center_x + center_half_width])

    if center_region > threshold:
        return "forward"
    if left_region > right_region:
        return "left"
    if right_region > left_region:
        return "right"
    return "backward"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
74 |
|
75 |
+
if __name__ == "__main__":
    # Imported lazily so uvicorn is only needed when run as a script.
    import uvicorn

    # Listen on all interfaces, port 7860.
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=7860,
    )
|