Princeaka committed on
Commit 4c7150a · verified · 1 Parent(s): d6333ba

Update app.py

Files changed (1)
  1. app.py +222 -98
app.py CHANGED
@@ -1,133 +1,257 @@
  import os
  import shutil
  import asyncio
  from typing import Optional

  import gradio as gr
- from fastapi import FastAPI, UploadFile, Form
  import uvicorn
- import socket

  from multimodal_module import MultiModalChatModule

- # Initialize AI module
  AI = MultiModalChatModule()

- # ---------------------------
- # Utility
- # ---------------------------
- class GradioFileWrapper:
-     def __init__(self, file_path):
-         self._path = file_path

-     async def download_to_drive(self, dst_path: str):
          loop = asyncio.get_event_loop()
          await loop.run_in_executor(None, shutil.copyfile, self._path, dst_path)

- def run_async(coro):
-     return asyncio.run(coro)

- def get_free_port(default=7860):
-     """Find a free port if default is busy."""
-     with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
-         try:
-             s.bind(("0.0.0.0", default))
-             return default
-         except OSError:
-             s.bind(("0.0.0.0", 0))
-             return s.getsockname()[1]
-
- # ---------------------------
- # FastAPI API for external apps
- # ---------------------------
- api = FastAPI()
-
- @api.post("/api/text_chat")
- async def api_text_chat(
-     user_id: Optional[int] = Form(0),
      text: str = Form(...),
-     lang: str = Form("en")
  ):
      try:
-         reply = await AI.generate_response(text, int(user_id), lang)
-         return {"reply": reply}
      except Exception as e:
-         return {"error": str(e)}

- @api.post("/api/image_caption")
- async def api_image_caption(user_id: Optional[int] = Form(0), image: UploadFile = None):
      try:
-         temp_path = f"/tmp/{image.filename}"
-         with open(temp_path, "wb") as f:
-             f.write(await image.read())
-         wrapper = GradioFileWrapper(temp_path)
-         caption = await AI.process_image_message(wrapper, int(user_id))
-         return {"caption": caption}
      except Exception as e:
-         return {"error": str(e)}

- @api.post("/api/voice_process")
- async def api_voice_process(user_id: Optional[int] = Form(0), audio: UploadFile = None):
      try:
-         temp_path = f"/tmp/{audio.filename}"
-         with open(temp_path, "wb") as f:
-             f.write(await audio.read())
-         wrapper = GradioFileWrapper(temp_path)
-         reply = await AI.process_voice_message(wrapper, int(user_id))
-         return {"reply": reply}
      except Exception as e:
-         return {"error": str(e)}

- @api.post("/api/video_process")
- async def api_video_process(user_id: Optional[int] = Form(0), video: UploadFile = None):
      try:
-         temp_path = f"/tmp/{video.filename}"
-         with open(temp_path, "wb") as f:
-             f.write(await video.read())
-         wrapper = GradioFileWrapper(temp_path)
-         reply = await AI.process_video_message(wrapper, int(user_id))
-         return {"reply": reply}
      except Exception as e:
-         return {"error": str(e)}

- @api.post("/api/file_process")
- async def api_file_process(user_id: Optional[int] = Form(0), file: UploadFile = None):
      try:
-         temp_path = f"/tmp/{file.filename}"
-         with open(temp_path, "wb") as f:
-             f.write(await file.read())
-         wrapper = GradioFileWrapper(temp_path)
-         reply = await AI.process_file_message(wrapper, int(user_id))
-         return {"reply": reply}
      except Exception as e:
-         return {"error": str(e)}
-
- # ---------------------------
- # Gradio UI
- # ---------------------------
- with gr.Blocks(title="Multimodal Bot") as demo:
-     gr.Markdown("# 🧠 Multimodal Bot\nInteract via text, voice, images, video, or files.")
-
-     with gr.Tab("💬 Text Chat"):
-         user_id_txt = gr.Textbox(label="User ID", placeholder="0")
-         lang_sel = gr.Dropdown(choices=["en","zh","ja","ko","es","fr","de","it"], value="en", label="Language")
-         txt_in = gr.Textbox(label="Your message", lines=4)
-         txt_out = gr.Textbox(label="Bot reply", lines=6)
-         gr.Button("Send").click(lambda uid, txt, lang: run_async(AI.generate_response(txt, int(uid or 0), lang)),
-                                 [user_id_txt, txt_in, lang_sel], txt_out)
-
-     with gr.Tab("🖼 Image Captioning"):
-         user_id_img = gr.Textbox(label="User ID", placeholder="0")
-         img_in = gr.Image(type="filepath", label="Upload an image")
-         img_out = gr.Textbox(label="Caption")
-         gr.Button("Caption").click(lambda uid, img: run_async(AI.process_image_message(GradioFileWrapper(img), int(uid or 0))),
-                                    [user_id_img, img_in], img_out)
-
- # ---------------------------
- # Mount Gradio UI to FastAPI
- # ---------------------------
- api = gr.mount_gradio_app(api, demo, path="/")

  if __name__ == "__main__":
-     port = get_free_port()
-     uvicorn.run(api, host="0.0.0.0", port=port)
+ # app.py — FastAPI REST API + mounted Gradio UI (Hugging Face Spaces compatible)
  import os
  import shutil
  import asyncio
+ import inspect
  from typing import Optional

+ from fastapi import FastAPI, UploadFile, File, Form
+ from fastapi.responses import JSONResponse, PlainTextResponse
+ from fastapi.middleware.cors import CORSMiddleware
  import gradio as gr
  import uvicorn

+ # ---- Your module ----
  from multimodal_module import MultiModalChatModule

+ # Instantiate once at import time
  AI = MultiModalChatModule()

+ TMP_DIR = "/tmp"
+ os.makedirs(TMP_DIR, exist_ok=True)

+ # ---------------- Helpers ----------------
+ class FileWrapper:
+     """Tiny adapter so your module can .download_to_drive(path)."""
+     def __init__(self, path: str):
+         self._path = path
+
+     async def download_to_drive(self, dst_path: str) -> None:
          loop = asyncio.get_event_loop()
          await loop.run_in_executor(None, shutil.copyfile, self._path, dst_path)

+ async def save_upload_to_tmp(up: UploadFile) -> str:
+     if not up or not up.filename:
+         raise ValueError("No file uploaded")
+     dest = os.path.join(TMP_DIR, up.filename)
+     data = await up.read()
+     with open(dest, "wb") as f:
+         f.write(data)
+     return dest

+ async def call_ai(fn, *args, **kwargs):
+     """Call AI methods whether they are sync or async."""
+     if fn is None:
+         raise AttributeError("Requested AI method is not implemented in multimodal_module")
+     if inspect.iscoroutinefunction(fn):
+         return await fn(*args, **kwargs)
+     return await asyncio.to_thread(lambda: fn(*args, **kwargs))
+
+ # ---------------- FastAPI app ----------------
+ app = FastAPI(title="Multimodal Module API", version="1.0.0")
+
+ # CORS so external apps can call it
+ app.add_middleware(
+     CORSMiddleware,
+     allow_origins=["*"],  # tighten for production
+     allow_credentials=True,
+     allow_methods=["*"],
+     allow_headers=["*"],
+ )
+
+ # ---- Health / root ----
+ @app.get("/health", response_class=PlainTextResponse)
+ async def health():
+     return "ok"
+
+ @app.get("/")
+ async def root():
+     return {
+         "name": "Multimodal Module API",
+         "status": "ready",
+         "docs": "/docs",
+         "gradio_ui": "/ui"
+     }
+
+ # ---------------- REST Endpoints ----------------
+ # Text chat
+ @app.post("/api/text")
+ async def api_text(
      text: str = Form(...),
+     user_id: Optional[int] = Form(0),
+     lang: str = Form("en"),
+ ):
+     try:
+         fn = getattr(AI, "generate_response", getattr(AI, "process_text", None))
+         reply = await call_ai(fn, text, int(user_id), lang)
+         return {"status": "ok", "reply": reply}
+     except Exception as e:
+         return JSONResponse({"error": str(e)}, status_code=500)
+
+ # Hugging Face-style predict (optional)
+ @app.post("/api/predict")
+ async def api_predict(
+     inputs: str = Form(...),
+     user_id: Optional[int] = Form(0),
+     lang: str = Form("en"),
+ ):
+     try:
+         fn = getattr(AI, "generate_response", getattr(AI, "process_text", None))
+         reply = await call_ai(fn, inputs, int(user_id), lang)
+         return {"data": [reply]}
+     except Exception as e:
+         return JSONResponse({"error": str(e)}, status_code=500)
+
+ # Voice -> ASR / emotion
+ @app.post("/api/voice")
+ async def api_voice(
+     user_id: Optional[int] = Form(0),
+     audio_file: UploadFile = File(...),
+ ):
+     try:
+         path = await save_upload_to_tmp(audio_file)
+         fn = getattr(AI, "process_voice_message", None)
+         result = await call_ai(fn, FileWrapper(path), int(user_id))
+         return JSONResponse(result)
+     except Exception as e:
+         return JSONResponse({"error": str(e)}, status_code=500)
+
+ # TTS
+ @app.post("/api/voice_reply")
+ async def api_voice_reply(
+     user_id: Optional[int] = Form(0),
+     reply_text: str = Form(...),
+     fmt: str = Form("ogg"),
  ):
      try:
+         fn = getattr(AI, "generate_voice_reply", None)
+         out_path = await call_ai(fn, reply_text, int(user_id), fmt)
+         return {"status": "ok", "file": out_path}
      except Exception as e:
+         return JSONResponse({"error": str(e)}, status_code=500)

+ # Image caption
+ @app.post("/api/image_caption")
+ async def api_image_caption(
+     user_id: Optional[int] = Form(0),
+     image_file: UploadFile = File(...),
+ ):
      try:
+         path = await save_upload_to_tmp(image_file)
+         fn = getattr(AI, "process_image_message", None)
+         caption = await call_ai(fn, FileWrapper(path), int(user_id))
+         return {"status": "ok", "caption": caption}
      except Exception as e:
+         return JSONResponse({"error": str(e)}, status_code=500)

+ # Text-to-image
+ @app.post("/api/generate_image")
+ async def api_generate_image(
+     user_id: Optional[int] = Form(0),
+     prompt: str = Form(...),
+     width: int = Form(512),
+     height: int = Form(512),
+     steps: int = Form(30),
+ ):
      try:
+         fn = getattr(AI, "generate_image_from_text", None)
+         out_path = await call_ai(fn, prompt, int(user_id), width, height, steps)
+         return {"status": "ok", "file": out_path}
      except Exception as e:
+         return JSONResponse({"error": str(e)}, status_code=500)

+ # Image edit / inpaint
+ @app.post("/api/edit_image")
+ async def api_edit_image(
+     user_id: Optional[int] = Form(0),
+     image_file: UploadFile = File(...),
+     mask_file: Optional[UploadFile] = File(None),
+     prompt: str = Form(""),
+ ):
      try:
+         img_path = await save_upload_to_tmp(image_file)
+         mask_path = None
+         if mask_file:
+             mask_path = await save_upload_to_tmp(mask_file)
+         fn = getattr(AI, "edit_image_inpaint", None)
+         out_path = await call_ai(
+             fn,
+             FileWrapper(img_path),
+             FileWrapper(mask_path) if mask_path else None,
+             prompt,
+             int(user_id),
+         )
+         return {"status": "ok", "file": out_path}
      except Exception as e:
+         return JSONResponse({"error": str(e)}, status_code=500)

+ # Video
+ @app.post("/api/video")
+ async def api_video(
+     user_id: Optional[int] = Form(0),
+     video_file: UploadFile = File(...),
+ ):
      try:
+         path = await save_upload_to_tmp(video_file)
+         fn = getattr(AI, "process_video", None)
+         result = await call_ai(fn, FileWrapper(path), int(user_id))
+         return JSONResponse(result)
      except Exception as e:
+         return JSONResponse({"error": str(e)}, status_code=500)
+
+ # Files (PDF/DOCX/TXT)
+ @app.post("/api/file")
+ async def api_file(
+     user_id: Optional[int] = Form(0),
+     file_obj: UploadFile = File(...),
+ ):
+     try:
+         path = await save_upload_to_tmp(file_obj)
+         fn = getattr(AI, "process_file", None)
+         result = await call_ai(fn, FileWrapper(path), int(user_id))
+         return JSONResponse(result)
+     except Exception as e:
+         return JSONResponse({"error": str(e)}, status_code=500)
+
+ # Code completion
+ @app.post("/api/code")
+ async def api_code(
+     user_id: Optional[int] = Form(0),
+     prompt: str = Form(...),
+     max_tokens: int = Form(512),
+ ):
+     try:
+         fn = getattr(AI, "code_complete", None)
+         try:
+             result = await call_ai(fn, int(user_id), prompt, max_tokens)
+         except TypeError:
+             result = await call_ai(fn, prompt, max_tokens=max_tokens)
+         return {"status": "ok", "code": result}
+     except Exception as e:
+         return JSONResponse({"error": str(e)}, status_code=500)
+
+ # ---------------- Gradio UI (mounted at /ui) ----------------
+ def _gradio_text_fn(text, user_id, lang):
+     fn = getattr(AI, "generate_response", getattr(AI, "process_text", None))
+     if fn is None:
+         return "Error: text handler not implemented in multimodal_module"
+     # Gradio callbacks run in a worker thread, safe to create/own an event loop
+     return asyncio.run(call_ai(fn, text, int(user_id or 0), lang))
+
+ with gr.Blocks(title="Multimodal Bot — UI") as demo:
+     gr.Markdown("# 🧠 Multimodal Bot — UI\nThis is a helper UI. Use the REST API for external apps.")
+     with gr.Row():
+         g_uid = gr.Textbox(label="User ID", value="0")
+         g_lang = gr.Dropdown(["en", "zh", "ja", "ko", "es", "fr", "de", "it"], value="en", label="Language")
+     g_in = gr.Textbox(lines=3, label="Message")
+     g_out = gr.Textbox(lines=6, label="Reply")
+     gr.Button("Send").click(_gradio_text_fn, [g_in, g_uid, g_lang], g_out)
+
+ # Mount Gradio *into* FastAPI at /ui (does not open another port)
+ app = gr.mount_gradio_app(app, demo, path="/ui")

+ # ---------------- Entrypoint ----------------
  if __name__ == "__main__":
+     # Hugging Face Spaces (FastAPI template) sets PORT; bind exactly to it.
+     port = int(os.environ.get("PORT", "7860"))
+     uvicorn.run("app:app", host="0.0.0.0", port=port)
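
For reference, a minimal client sketch against the endpoints added in this commit. The paths and form/file field names (/health, /api/text, /api/image_caption, text, user_id, lang, image_file) come from the new app.py above; the base URL, the local example.jpg file, and the use of the requests library are assumptions, not part of the commit.

# client_example.py — hypothetical external caller for the new REST API
import requests

BASE = "http://localhost:7860"  # assumption: app running locally on the default port; use your Space URL otherwise

# Health check: GET /health returns the plain text "ok"
print(requests.get(f"{BASE}/health", timeout=30).text)

# Text chat: POST /api/text with form fields text, user_id, lang
r = requests.post(
    f"{BASE}/api/text",
    data={"text": "Hello!", "user_id": 0, "lang": "en"},
    timeout=120,
)
print(r.json())  # {"status": "ok", "reply": ...} on success, {"error": ...} with HTTP 500 otherwise

# Image captioning: POST /api/image_caption with a multipart upload named image_file
with open("example.jpg", "rb") as fh:  # hypothetical local image
    r = requests.post(
        f"{BASE}/api/image_caption",
        data={"user_id": 0},
        files={"image_file": fh},
        timeout=300,
    )
print(r.json())  # {"status": "ok", "caption": ...} on success

The other upload endpoints (/api/voice, /api/video, /api/file, /api/edit_image) follow the same multipart pattern, differing only in the field names defined in app.py.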