Yussifweb3 commited on
Commit
3657bdc
·
verified ·
1 Parent(s): f5fe15c

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +90 -0
app.py ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, UploadFile, File, Form
2
+ from fastapi.middleware.cors import CORSMiddleware
3
+ from typing import Optional
4
+ from PIL import Image
5
+ import io
6
+ import base64
7
+ from gradio_client import Client
8
+ import uvicorn
9
+
10
+ app = FastAPI()
11
+
12
+ # Add CORS middleware
13
+ app.add_middleware(
14
+ CORSMiddleware,
15
+ allow_origins=["*"], # In production, replace with your frontend URL
16
+ allow_credentials=True,
17
+ allow_methods=["*"],
18
+ allow_headers=["*"],
19
+ )
20
+
21
+ def encode_image_to_base64(image_bytes):
22
+ """Convert image bytes to base64 string"""
23
+ return base64.b64encode(image_bytes).decode('utf-8')
24
+
25
+ def process_uploaded_image(file: UploadFile) -> bytes:
26
+ """Process uploaded image and return bytes"""
27
+ contents = file.file.read()
28
+ image = Image.open(io.BytesIO(contents))
29
+
30
+ # Convert to RGB if necessary
31
+ if image.mode != 'RGB':
32
+ image = image.convert('RGB')
33
+
34
+ # Save to bytes
35
+ img_byte_arr = io.BytesIO()
36
+ image.save(img_byte_arr, format='PNG')
37
+ img_byte_arr = img_byte_arr.getvalue()
38
+
39
+ return img_byte_arr
40
+
41
+ @app.post("/api/face-swap")
42
+ async def face_swap(
43
+ source_file: UploadFile = File(...),
44
+ target_file: UploadFile = File(...),
45
+ do_face_enhancer: Optional[bool] = Form(True)
46
+ ):
47
+ try:
48
+ # Process uploaded images
49
+ source_bytes = process_uploaded_image(source_file)
50
+ target_bytes = process_uploaded_image(target_file)
51
+
52
+ # Initialize Gradio client
53
+ client = Client("tuan2308/face-swap")
54
+
55
+ # Make prediction
56
+ result = await client.predict(
57
+ source_bytes, # Source image
58
+ target_bytes, # Target image
59
+ do_face_enhancer, # Face enhancement option
60
+ api_name="/predict"
61
+ )
62
+
63
+ # Process result
64
+ if isinstance(result, bytes):
65
+ # If result is already bytes, encode to base64
66
+ result_base64 = encode_image_to_base64(result)
67
+ else:
68
+ # If result is a path or other format, handle accordingly
69
+ # You might need to adjust this based on the actual return type
70
+ with open(result, 'rb') as f:
71
+ result_base64 = encode_image_to_base64(f.read())
72
+
73
+ return {
74
+ "status": "success",
75
+ "image": f"data:image/png;base64,{result_base64}"
76
+ }
77
+
78
+ except Exception as e:
79
+ return {
80
+ "status": "error",
81
+ "message": str(e)
82
+ }
83
+
84
+ # Health check endpoint
85
+ @app.get("/health")
86
+ async def health_check():
87
+ return {"status": "healthy"}
88
+
89
+ if __name__ == "__main__":
90
+ uvicorn.run(app, host="0.0.0.0", port=8000)