ffreemt committed
Commit 670ac68
1 Parent(s): b153e87

Update device='cuda' if ... else 'cpu' in m3_server.py

Files changed (2):
  1. Dockerfile (+1, -0)
  2. m3_server.py (+87, -37)
Dockerfile CHANGED
@@ -23,3 +23,4 @@ RUN pip install --no-cache-dir --upgrade pip && \
 
 CMD ["sh", "start-m3-server.sh"]
 # CMD ["sh", "-c", "HF_HOME=/tmp/cache", "python", "m3_server.py"]
+# ["sh", "-c", "'FOO=BAR python app.py'"]
m3_server.py CHANGED
@@ -1,11 +1,12 @@
 import asyncio
 import os
 import time
-from pathlib import Path
 from concurrent.futures import ThreadPoolExecutor
+from pathlib import Path
 from typing import List, Tuple, Union
 from uuid import uuid4
 
+import torch
 from fastapi import FastAPI, HTTPException, Request
 from fastapi.responses import JSONResponse
 from FlagEmbedding import BGEM3FlagModel
@@ -13,26 +14,34 @@ from pydantic import BaseModel
 from starlette.status import HTTP_504_GATEWAY_TIMEOUT
 
 Path("/tmp/cache").mkdir(exist_ok=True)
-os.environ["HF_HOME"] = "/tmp/cache"  # does not quite work, need
-
-batch_size = 2  # gpu batch_size in order of your available vram
-max_request = 10  # max request for future improvements on api calls / gpu batches (for now is pretty basic)
-max_length = 5000  # max context length for embeddings and passages in re-ranker
-max_q_length = 256  # max context length for questions in re-ranker
-request_flush_timeout = .1  # flush timeout for future improvements on api calls / gpu batches (for now is pretty basic)
-rerank_weights = [0.4, 0.2, 0.4]  # re-rank score weights
+os.environ[
+    "HF_HOME"
+] = "/tmp/cache"  # does not quite work, need Path("/tmp/cache").mkdir(exist_ok=True)?
+
+batch_size = 2  # gpu batch_size in order of your available vram
+max_request = 10  # max request for future improvements on api calls / gpu batches (for now is pretty basic)
+max_length = 5000  # max context length for embeddings and passages in re-ranker
+max_q_length = 256  # max context length for questions in re-ranker
+request_flush_timeout = 0.1  # flush timeout for future improvements on api calls / gpu batches (for now is pretty basic)
+rerank_weights = [0.4, 0.2, 0.4]  # re-rank score weights
 request_time_out = 30  # Timeout threshold
-gpu_time_out = 5  # gpu processing timeout threshold
-port = 3000
-port = 7860
+gpu_time_out = 5  # gpu processing timeout threshold
+port = 3000
+port = 7860
+DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+
 
 class m3Wrapper:
-    def __init__(self, model_name: str, device: str = 'cuda'):
+    def __init__(self, model_name: str, device: str = DEVICE):
         """Init."""
-        self.model = BGEM3FlagModel(model_name, device=device, use_fp16=True if device != 'cpu' else False)
+        self.model = BGEM3FlagModel(
+            model_name, device=device, use_fp16=True if device != "cpu" else False
+        )
 
     def embed(self, sentences: List[str]) -> List[List[float]]:
-        embeddings = self.model.encode(sentences, batch_size=batch_size, max_length=max_length)['dense_vecs']
+        embeddings = self.model.encode(
+            sentences, batch_size=batch_size, max_length=max_length
+        )["dense_vecs"]
         embeddings = embeddings.tolist()
         return embeddings
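The "does not quite work" comment is plausibly an import-order issue: huggingface_hub and the libraries built on it read `HF_HOME` into module-level constants when first imported, so assigning it after `from FlagEmbedding import BGEM3FlagModel` comes too late. A minimal sketch of the assumed fix (not confirmed by this commit):

```python
import os
from pathlib import Path

# Export the cache location BEFORE importing anything from the HF stack,
# so its module-level cache-path constants pick it up.
Path("/tmp/cache").mkdir(exist_ok=True)
os.environ["HF_HOME"] = "/tmp/cache"

from FlagEmbedding import BGEM3FlagModel  # noqa: E402  deliberate late import
```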
 
@@ -42,24 +51,31 @@ class m3Wrapper:
             batch_size=batch_size,
             max_query_length=max_q_length,
             max_passage_length=max_length,
-            weights_for_different_modes=rerank_weights
-        )['colbert+sparse+dense']
+            weights_for_different_modes=rerank_weights,
+        )["colbert+sparse+dense"]
         return scores
 
+
 class EmbedRequest(BaseModel):
     sentences: List[str]
 
+
 class RerankRequest(BaseModel):
     sentence_pairs: List[Tuple[str, str]]
 
+
 class EmbedResponse(BaseModel):
     embeddings: List[List[float]]
 
+
 class RerankResponse(BaseModel):
     scores: List[float]
 
+
 class RequestProcessor:
-    def __init__(self, model: m3Wrapper, max_request_to_flush: int, accumulation_timeout: float):
+    def __init__(
+        self, model: m3Wrapper, max_request_to_flush: int, accumulation_timeout: float
+    ):
         """Init."""
         self.model = model
         self.max_batch_size = max_request_to_flush
@@ -73,7 +89,7 @@ class RequestProcessor:
 
     async def ensure_processing_loop_started(self):
         if not self.processing_loop_started:
-            print('starting processing_loop')
+            print("starting processing_loop")
             self.processing_loop_task = asyncio.create_task(self.processing_loop())
             self.processing_loop_started = True
 
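For context on `weights_for_different_modes`: BGE-M3 scores each pair in three modes (dense, sparse/lexical, and ColBERT-style multi-vector), and `'colbert+sparse+dense'` is their weighted combination. A rough sketch of that blending, assuming the weights are ordered [dense, sparse, colbert] and normalized by their sum (FlagEmbedding's exact formula may differ):

```python
def combined_score(dense: float, sparse: float, colbert: float,
                   weights: tuple = (0.4, 0.2, 0.4)) -> float:
    """Weighted average of the three per-mode relevance scores."""
    w_dense, w_sparse, w_colbert = weights
    total = w_dense + w_sparse + w_colbert
    return (w_dense * dense + w_sparse * sparse + w_colbert * colbert) / total
```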
 
@@ -83,12 +99,16 @@ class RequestProcessor:
         start_time = asyncio.get_event_loop().time()
 
         while len(requests) < self.max_batch_size:
-            timeout = self.accumulation_timeout - (asyncio.get_event_loop().time() - start_time)
+            timeout = self.accumulation_timeout - (
+                asyncio.get_event_loop().time() - start_time
+            )
             if timeout <= 0:
                 break
 
             try:
-                req_data, req_type, req_id = await asyncio.wait_for(self.queue.get(), timeout=timeout)
+                req_data, req_type, req_id = await asyncio.wait_for(
+                    self.queue.get(), timeout=timeout
+                )
                 requests.append(req_data)
                 request_types.append(req_type)
                 request_ids.append(req_id)
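This loop coalesces queued requests into micro-batches: it keeps pulling from the queue until either the batch is full or the accumulation window expires. The same pattern, condensed into a standalone sketch (hypothetical helper, not in the file):

```python
import asyncio


async def gather_batch(queue: asyncio.Queue, max_batch: int, window: float) -> list:
    """Collect up to max_batch items, waiting at most `window` seconds total."""
    batch: list = []
    deadline = asyncio.get_event_loop().time() + window
    while len(batch) < max_batch:
        remaining = deadline - asyncio.get_event_loop().time()
        if remaining <= 0:
            break  # accumulation window elapsed; flush what we have
        try:
            batch.append(await asyncio.wait_for(queue.get(), timeout=remaining))
        except asyncio.TimeoutError:
            break  # queue stayed empty for the rest of the window
    return batch
```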
@@ -96,15 +116,27 @@ class RequestProcessor:
                 break
 
         if requests:
-            await self.process_requests_by_type(requests, request_types, request_ids)
+            await self.process_requests_by_type(
+                requests, request_types, request_ids
+            )
 
     async def process_requests_by_type(self, requests, request_types, request_ids):
         tasks = []
-        for request_data, request_type, request_id in zip(requests, request_types, request_ids):
-            if request_type == 'embed':
-                task = asyncio.create_task(self.run_with_semaphore(self.model.embed, request_data.sentences, request_id))
+        for request_data, request_type, request_id in zip(
+            requests, request_types, request_ids
+        ):
+            if request_type == "embed":
+                task = asyncio.create_task(
+                    self.run_with_semaphore(
+                        self.model.embed, request_data.sentences, request_id
+                    )
+                )
             else:  # 'rerank'
-                task = asyncio.create_task(self.run_with_semaphore(self.model.rerank, request_data.sentence_pairs, request_id))
+                task = asyncio.create_task(
+                    self.run_with_semaphore(
+                        self.model.rerank, request_data.sentence_pairs, request_id
+                    )
+                )
             tasks.append(task)
         await asyncio.gather(*tasks)
 
 
@@ -112,14 +144,20 @@ class RequestProcessor:
         async with self.gpu_lock:  # Wait for sem
             future = self.executor.submit(func, data)
             try:
-                result = await asyncio.wait_for(asyncio.wrap_future(future), timeout=gpu_time_out)
+                result = await asyncio.wait_for(
+                    asyncio.wrap_future(future), timeout=gpu_time_out
+                )
                 self.response_futures[request_id].set_result(result)
             except asyncio.TimeoutError:
-                self.response_futures[request_id].set_exception(TimeoutError("GPU processing timeout"))
+                self.response_futures[request_id].set_exception(
+                    TimeoutError("GPU processing timeout")
+                )
             except Exception as e:
                 self.response_futures[request_id].set_exception(e)
 
-    async def process_request(self, request_data: Union[EmbedRequest, RerankRequest], request_type: str):
+    async def process_request(
+        self, request_data: Union[EmbedRequest, RerankRequest], request_type: str
+    ):
         try:
             await self.ensure_processing_loop_started()
             request_id = str(uuid4())
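One caveat with this timeout: `asyncio.wait_for` stops waiting after `gpu_time_out` seconds, but it cannot stop the worker thread; `concurrent.futures` only cancels work that has not started yet, so a stuck call keeps occupying the executor (and the GPU) until it finishes on its own. A small illustration of that behavior (hypothetical, not from this commit):

```python
import asyncio
import time
from concurrent.futures import ThreadPoolExecutor


def slow_job() -> str:
    time.sleep(10)  # stands in for a long GPU call
    return "done"


async def main() -> None:
    future = ThreadPoolExecutor(max_workers=1).submit(slow_job)
    try:
        await asyncio.wait_for(asyncio.wrap_future(future), timeout=1)
    except asyncio.TimeoutError:
        # We stopped waiting, but slow_job() is still running in its thread;
        # cancel() returns False once execution has begun.
        print("timed out; cancelled:", future.cancel())


asyncio.run(main())  # exits only after the worker thread finishes
```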
@@ -129,6 +167,7 @@ class RequestProcessor:
         except Exception as e:
             raise HTTPException(status_code=500, detail=f"Internal Server Error {e}")
 
+
 app = FastAPI(
     title="baai m3, serving embed and rerank",
     description="Swagger UI at https://mikeee-baai-m3.hf.space/docs",
@@ -136,8 +175,10 @@ app = FastAPI(
 )
 
 # Initialize the model and request processor
-model = m3Wrapper('BAAI/bge-m3')
-processor = RequestProcessor(model, accumulation_timeout=request_flush_timeout, max_request_to_flush=max_request)
+model = m3Wrapper("BAAI/bge-m3")
+processor = RequestProcessor(
+    model, accumulation_timeout=request_flush_timeout, max_request_to_flush=max_request
+)
 
 # Adding a middleware returning a 504 error if the request processing time is above a certain threshold
 @app.middleware("http")
@@ -148,25 +189,34 @@ async def timeout_middleware(request: Request, call_next):
 
     except asyncio.TimeoutError:
         process_time = time.time() - start_time
-        return JSONResponse({'detail': 'Request processing time exceeded limit',
-                             'processing_time': process_time},
-                            status_code=HTTP_504_GATEWAY_TIMEOUT)
+        return JSONResponse(
+            {
+                "detail": "Request processing time exceeded limit",
+                "processing_time": process_time,
+            },
+            status_code=HTTP_504_GATEWAY_TIMEOUT,
+        )
+
 
 @app.get("/")
 async def landing():
     """Define landing page."""
     return "Swagger UI at https://mikeee-baai-m3.hf.space/docs"
 
+
 @app.post("/embeddings/", response_model=EmbedResponse)
 async def get_embeddings(request: EmbedRequest):
-    embeddings = await processor.process_request(request, 'embed')
+    embeddings = await processor.process_request(request, "embed")
     return EmbedResponse(embeddings=embeddings)
 
+
 @app.post("/rerank/", response_model=RerankResponse)
 async def rerank(request: RerankRequest):
-    scores = await processor.process_request(request, 'rerank')
+    scores = await processor.process_request(request, "rerank")
    return RerankResponse(scores=scores)
 
+
 if __name__ == "__main__":
     import uvicorn
-    uvicorn.run(app, host="0.0.0.0", port=port)
+
+    uvicorn.run(app, host="0.0.0.0", port=port)
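A quick smoke test for the two endpoints after this change (hypothetical client snippet; assumes a local run on the port 7860 configured above):

```python
import requests

BASE = "http://localhost:7860"  # assumed local deployment; the HF Space URL differs

r = requests.post(f"{BASE}/embeddings/", json={"sentences": ["hello world"]})
print(len(r.json()["embeddings"][0]))  # dense vector size (1024 for bge-m3)

r = requests.post(
    f"{BASE}/rerank/",
    json={"sentence_pairs": [["what is bge-m3?", "BGE-M3 is an embedding model"]]},
)
print(r.json()["scores"])  # one combined score per pair
```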
 
 