thienphuc12339 commited on
Commit
4d27720
·
verified ·
1 Parent(s): 4a4d69b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +117 -53
app.py CHANGED
@@ -6,7 +6,9 @@ from pathlib import Path
6
  import shutil
7
  import logging
8
  import uvicorn
9
- from typing import Optional
 
 
10
 
11
  from configs import ModelConfig, InferenceConfig
12
  from tools.models import load_pipeline
@@ -19,12 +21,29 @@ app = FastAPI(title="Sign Language Recognition API")
19
  logging.basicConfig(level=logging.INFO)
20
  logger = logging.getLogger(__name__)
21
 
 
 
 
 
 
 
 
 
22
  # Define a Pydantic model for the response
23
  class InferenceResponse(BaseModel):
24
  status: str
25
- predictions: Optional[list] = None
26
  message: Optional[str] = None
27
 
 
 
 
 
 
 
 
 
 
28
  @app.post("/inference", response_model=InferenceResponse)
29
  async def inference_endpoint(
30
  file: UploadFile = File(...),
@@ -42,61 +61,106 @@ async def inference_endpoint(
42
  Returns:
43
  InferenceResponse: Kết quả nhận diện.
44
  """
45
- # Kiểm tra file có hợp lệ không
46
- if not file.filename.endswith((".mp4", ".avi", ".mov", ".mkv")):
47
- raise HTTPException(status_code=400, detail="Unsupported file type.")
48
-
49
- # Tạo thư mục output nếu không tồn tại
50
- output_path = Path(output_dir)
51
- output_path.mkdir(parents=True, exist_ok=True)
52
-
53
- # Lưu video tạm thời
54
- video_path = output_path / file.filename
55
- with open(video_path, "wb") as buffer:
56
- shutil.copyfileobj(file.file, buffer)
57
-
58
- logger.info(f"Video saved to {video_path}")
59
-
60
- # Tải cấu hình mô hình dựa trên model_name
61
  try:
62
- if model_name == "spoter":
63
- model_config = ModelConfig(arch="spoter", pretrained="vsltranslation/spoter_v3.0")
64
- elif model_name == "sl_gcn":
65
- model_config = ModelConfig(arch="sl_gcn", pretrained="vsltranslation/sl_gcn_joint_v3_0")
66
- elif model_name == "dsta_slr":
67
- model_config = ModelConfig(arch="dsta_slr", pretrained="models/dsta_slr_joint_motion_v3_0.onnx")
68
- else:
69
- raise ValueError("Unsupported model name.")
70
 
71
- inference_config = InferenceConfig(
72
- source=str(video_path),
73
- output_dir=str(output_path),
74
- use_onnx=True if model_config.pretrained.endswith(".onnx") else False,
75
- device="cpu", # Bạn có thể thay đổi thành "cuda" nếu sử dụng GPU
76
- cache_dir="models/huggingface",
77
- visualize=False,
78
- show_skeleton=False,
79
- visibility=0.5,
80
- angle_threshold=140,
81
- min_num_up_frames=10,
82
- min_num_down_frames=10,
83
- delay=400,
84
- top_k=3,
85
- bone_stream=False,
86
- motion_stream=False
87
- )
88
 
89
- # Tải pipeline
90
- pipeline = load_pipeline(model_config, inference_config)
91
- logger.info("Pipeline loaded successfully.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
92
 
93
- # Chạy inference
94
- result = run_inference(model_config, inference_config, pipeline)
95
- logger.info("Inference completed successfully.")
 
96
 
97
- # Trả về kết quả
98
- return InferenceResponse(status="success", predictions=result.get("results", []))
99
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
100
  except Exception as e:
101
- logger.exception("Error during inference.")
102
- raise HTTPException(status_code=500, detail=str(e))
 
 
 
 
6
  import shutil
7
  import logging
8
  import uvicorn
9
+ from typing import Optional, List
10
+ import pandas as pd
11
+ import json
12
 
13
  from configs import ModelConfig, InferenceConfig
14
  from tools.models import load_pipeline
 
21
  logging.basicConfig(level=logging.INFO)
22
  logger = logging.getLogger(__name__)
23
 
24
+ # Define a Pydantic model for individual prediction
25
+ class Prediction(BaseModel):
26
+ gloss: str
27
+ score: float
28
+ start_time: float
29
+ end_time: float
30
+ inference_time: float
31
+
32
  # Define a Pydantic model for the response
33
  class InferenceResponse(BaseModel):
34
  status: str
35
+ predictions: Optional[List[Prediction]] = None
36
  message: Optional[str] = None
37
 
38
+ # Define id2gloss mapping
39
+ # Đây là một ví dụ. Bạn cần thay thế bằng bản đồ thực tế của bạn.
40
+ id2gloss = {
41
+ "0": "hello",
42
+ "1": "thanks",
43
+ "2": "yes",
44
+ # Thêm các ánh xạ cần thiết
45
+ }
46
+
47
  @app.post("/inference", response_model=InferenceResponse)
48
  async def inference_endpoint(
49
  file: UploadFile = File(...),
 
61
  Returns:
62
  InferenceResponse: Kết quả nhận diện.
63
  """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64
  try:
65
+ # Kiểm tra file có hợp lệ không
66
+ if not file.filename.endswith((".mp4", ".avi", ".mov", ".mkv")):
67
+ raise HTTPException(status_code=400, detail="Unsupported file type.")
 
 
 
 
 
68
 
69
+ # Tạo thư mục output nếu không tồn tại
70
+ output_path = Path(output_dir)
71
+ output_path.mkdir(parents=True, exist_ok=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
72
 
73
+ # Lưu video tạm thời
74
+ video_path = output_path / file.filename
75
+ with open(video_path, "wb") as buffer:
76
+ shutil.copyfileobj(file.file, buffer)
77
+
78
+ logger.info(f"Video saved to {video_path}")
79
+
80
+ # Tải cấu hình mô hình dựa trên model_name
81
+ try:
82
+ if model_name == "spoter":
83
+ model_config = ModelConfig(arch="spoter", pretrained="vsltranslation/spoter_v3.0")
84
+ elif model_name == "sl_gcn":
85
+ model_config = ModelConfig(arch="sl_gcn", pretrained="vsltranslation/sl_gcn_joint_v3_0")
86
+ elif model_name == "dsta_slr":
87
+ model_config = ModelConfig(arch="dsta_slr", pretrained="models/dsta_slr_joint_motion_v3_0.onnx")
88
+ else:
89
+ raise ValueError("Unsupported model name.")
90
+
91
+ inference_config = InferenceConfig(
92
+ source=str(video_path),
93
+ output_dir=str(output_path),
94
+ use_onnx=True if model_config.pretrained.endswith(".onnx") else False,
95
+ device="cpu", # Bạn có thể thay đổi thành "cuda" nếu sử dụng GPU
96
+ cache_dir="models/huggingface",
97
+ visualize=False,
98
+ show_skeleton=False,
99
+ visibility=0.5,
100
+ angle_threshold=140,
101
+ min_num_up_frames=10,
102
+ min_num_down_frames=10,
103
+ delay=400,
104
+ top_k=3,
105
+ bone_stream=False,
106
+ motion_stream=True # Theo cấu hình YAML bạn cung cấp
107
+ )
108
+
109
+ # Tải pipeline hoặc session
110
+ pipeline_or_session = load_pipeline(model_config, inference_config)
111
+ logger.info("Pipeline loaded successfully.")
112
+
113
+ # Chạy inference
114
+ run_inference(model_config, inference_config, pipeline_or_session)
115
+ logger.info("Inference completed successfully.")
116
 
117
+ # Đọc kết quả từ results.csv
118
+ results_csv = output_path / "results.csv"
119
+ if not results_csv.exists():
120
+ raise HTTPException(status_code=500, detail="Inference did not produce results.")
121
 
122
+ results_df = pd.read_csv(results_csv)
 
123
 
124
+ # Chuyển đổi DataFrame thành list of Prediction
125
+ predictions = []
126
+ for _, row in results_df.iterrows():
127
+ # Giả sử results.csv có các cột: start_time, end_time, inference_time, prediction
128
+ # Và 'prediction' là một danh sách các từ điển với 'gloss' và 'score'
129
+ start_time = row.get("start_time", 0.0)
130
+ end_time = row.get("end_time", 0.0)
131
+ inference_time = row.get("inference_time", 0.0)
132
+ prediction_list = row.get("prediction", [])
133
+
134
+ if isinstance(prediction_list, str):
135
+ # Nếu prediction được lưu dưới dạng chuỗi JSON
136
+ try:
137
+ prediction_list = json.loads(prediction_list.replace("'", '"'))
138
+ except json.JSONDecodeError:
139
+ prediction_list = []
140
+
141
+ for pred in prediction_list:
142
+ gloss = pred.get('gloss', 'Unknown')
143
+ score = pred.get('score', 0.0)
144
+ predictions.append(Prediction(
145
+ gloss=gloss,
146
+ score=score,
147
+ start_time=start_time,
148
+ end_time=end_time,
149
+ inference_time=inference_time
150
+ ))
151
+
152
+ return InferenceResponse(status="success", predictions=predictions)
153
+
154
+ except ValueError as ve:
155
+ logger.exception("ValueError during inference.")
156
+ raise HTTPException(status_code=400, detail=str(ve))
157
+ except Exception as e:
158
+ logger.exception("Error during inference.")
159
+ raise HTTPException(status_code=500, detail=str(e))
160
+
161
  except Exception as e:
162
+ logger.exception("Unexpected error.")
163
+ raise HTTPException(status_code=500, detail="An unexpected error occurred.")
164
+
165
+ if __name__ == "__main__":
166
+ uvicorn.run("api.main:app", host="0.0.0.0", port=8000, reload=True)