ffreemt committed on
Commit
2999f6e
1 Parent(s): 0d8fa99

Update Dockerfile m3_server.py (port = 7860)

Files changed (5)
  1. Dockerfile +26 -0
  2. README.md +3 -2
  3. m3_server.py +158 -0
  4. requirements.txt +3 -0
  5. start-m3-server.sh +3 -0
Dockerfile ADDED
@@ -0,0 +1,26 @@
+FROM python:3.10
+ENV PIP_ROOT_USER_ACTION=ignore \
+    TZ=Asia/Shanghai
+
+WORKDIR /app
+COPY . .
+
+RUN pip install --no-cache-dir --upgrade pip && \
+    pip install --no-cache-dir -r requirements.txt
+
+# EXPOSE 7860
+# ENV PYTHONUNBUFFERED=1 \
+#     GRADIO_ALLOW_FLAGGING=never \
+#     GRADIO_NUM_PORTS=1 \
+#     GRADIO_SERVER_NAME=0.0.0.0 \
+#     GRADIO_THEME=huggingface \
+#     SYSTEM=spaces \
+#     SHELL=/bin/bash
+
+# CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "7860"]
+# CMD ["TRANSFORMERS_CACHE=./", "infinity_emb", "--model-name-or-path", "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2", "--port", "7860"]
+# CMD ["python", "app.py"]
+
+# CMD ["sh", "start-m3-server.sh"]
+
+CMD ["python", "m3_server.py"]
README.md CHANGED
@@ -1,11 +1,12 @@
 ---
-title: Baai M3
+title: baai-m3
 emoji: 💻
 colorFrom: red
 colorTo: gray
 sdk: docker
-pinned: false
+pinned: true
 license: mit
 ---
+Swagger UI at https://mikeee-baai-m3.hf.space/docs
 
 Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
m3_server.py ADDED
@@ -0,0 +1,158 @@
+import asyncio
+import time
+from concurrent.futures import ThreadPoolExecutor
+from typing import List, Tuple, Union
+from uuid import uuid4
+
+from fastapi import FastAPI, HTTPException, Request
+from fastapi.responses import JSONResponse
+from FlagEmbedding import BGEM3FlagModel
+from pydantic import BaseModel
+from starlette.status import HTTP_504_GATEWAY_TIMEOUT
+
+batch_size = 2  # GPU batch size; set according to your available VRAM
+max_request = 10  # max requests accumulated per flush (basic for now; room to improve API-call / GPU batching)
+max_length = 5000  # max context length for embeddings and passages in the re-ranker
+max_q_length = 256  # max context length for questions in the re-ranker
+request_flush_timeout = 0.1  # flush timeout (basic for now; room to improve API-call / GPU batching)
+rerank_weights = [0.4, 0.2, 0.4]  # re-rank score weights
+request_time_out = 30  # request timeout threshold
+gpu_time_out = 5  # GPU processing timeout threshold
+# port = 3000  # superseded: Hugging Face Spaces serves on 7860
+port = 7860
+
+class m3Wrapper:
+    def __init__(self, model_name: str, device: str = 'cuda'):
+        """Init."""
+        self.model = BGEM3FlagModel(model_name, device=device, use_fp16=True if device != 'cpu' else False)
+
+    def embed(self, sentences: List[str]) -> List[List[float]]:
+        embeddings = self.model.encode(sentences, batch_size=batch_size, max_length=max_length)['dense_vecs']
+        embeddings = embeddings.tolist()
+        return embeddings
+
+    def rerank(self, sentence_pairs: List[Tuple[str, str]]) -> List[float]:
+        scores = self.model.compute_score(
+            sentence_pairs,
+            batch_size=batch_size,
+            max_query_length=max_q_length,
+            max_passage_length=max_length,
+            weights_for_different_modes=rerank_weights
+        )['colbert+sparse+dense']
+        return scores
+
+class EmbedRequest(BaseModel):
+    sentences: List[str]
+
+class RerankRequest(BaseModel):
+    sentence_pairs: List[Tuple[str, str]]
+
+class EmbedResponse(BaseModel):
+    embeddings: List[List[float]]
+
+class RerankResponse(BaseModel):
+    scores: List[float]
+
+class RequestProcessor:
+    def __init__(self, model: m3Wrapper, max_request_to_flush: int, accumulation_timeout: float):
+        """Init."""
+        self.model = model
+        self.max_batch_size = max_request_to_flush
+        self.accumulation_timeout = accumulation_timeout
+        self.queue = asyncio.Queue()
+        self.response_futures = {}
+        self.processing_loop_task = None
+        self.processing_loop_started = False  # lazy-init flag for the processing loop
+        self.executor = ThreadPoolExecutor()  # thread pool
+        self.gpu_lock = asyncio.Semaphore(1)  # semaphore to serialize GPU usage
+
+    async def ensure_processing_loop_started(self):
+        if not self.processing_loop_started:
+            print('starting processing_loop')
+            self.processing_loop_task = asyncio.create_task(self.processing_loop())
+            self.processing_loop_started = True
+
+    async def processing_loop(self):
+        while True:
+            requests, request_types, request_ids = [], [], []
+            start_time = asyncio.get_event_loop().time()
+
+            while len(requests) < self.max_batch_size:
+                timeout = self.accumulation_timeout - (asyncio.get_event_loop().time() - start_time)
+                if timeout <= 0:
+                    break
+
+                try:
+                    req_data, req_type, req_id = await asyncio.wait_for(self.queue.get(), timeout=timeout)
+                    requests.append(req_data)
+                    request_types.append(req_type)
+                    request_ids.append(req_id)
+                except asyncio.TimeoutError:
+                    break
+
+            if requests:
+                await self.process_requests_by_type(requests, request_types, request_ids)
+
+    async def process_requests_by_type(self, requests, request_types, request_ids):
+        tasks = []
+        for request_data, request_type, request_id in zip(requests, request_types, request_ids):
+            if request_type == 'embed':
+                task = asyncio.create_task(self.run_with_semaphore(self.model.embed, request_data.sentences, request_id))
+            else:  # 'rerank'
+                task = asyncio.create_task(self.run_with_semaphore(self.model.rerank, request_data.sentence_pairs, request_id))
+            tasks.append(task)
+        await asyncio.gather(*tasks)
+
+    async def run_with_semaphore(self, func, data, request_id):
+        async with self.gpu_lock:  # wait for the semaphore
+            future = self.executor.submit(func, data)
+            try:
+                result = await asyncio.wait_for(asyncio.wrap_future(future), timeout=gpu_time_out)
+                self.response_futures[request_id].set_result(result)
+            except asyncio.TimeoutError:
+                self.response_futures[request_id].set_exception(TimeoutError("GPU processing timeout"))
+            except Exception as e:
+                self.response_futures[request_id].set_exception(e)
+
+    async def process_request(self, request_data: Union[EmbedRequest, RerankRequest], request_type: str):
+        try:
+            await self.ensure_processing_loop_started()
+            request_id = str(uuid4())
+            self.response_futures[request_id] = asyncio.Future()
+            await self.queue.put((request_data, request_type, request_id))
+            return await self.response_futures[request_id]
+        except Exception as e:
+            raise HTTPException(status_code=500, detail=f"Internal Server Error {e}")
+
+app = FastAPI()
+
+# Initialize the model and request processor
+model = m3Wrapper('BAAI/bge-m3')
+processor = RequestProcessor(model, accumulation_timeout=request_flush_timeout, max_request_to_flush=max_request)
+
+# Middleware returning a 504 error if request processing time exceeds the threshold
+@app.middleware("http")
+async def timeout_middleware(request: Request, call_next):
+    try:
+        start_time = time.time()
+        return await asyncio.wait_for(call_next(request), timeout=request_time_out)
+
+    except asyncio.TimeoutError:
+        process_time = time.time() - start_time
+        return JSONResponse({'detail': 'Request processing time exceeded limit',
+                             'processing_time': process_time},
+                            status_code=HTTP_504_GATEWAY_TIMEOUT)
+
+@app.post("/embeddings/", response_model=EmbedResponse)
+async def get_embeddings(request: EmbedRequest):
+    embeddings = await processor.process_request(request, 'embed')
+    return EmbedResponse(embeddings=embeddings)
+
+@app.post("/rerank/", response_model=RerankResponse)
+async def rerank(request: RerankRequest):
+    scores = await processor.process_request(request, 'rerank')
+    return RerankResponse(scores=scores)
+
+if __name__ == "__main__":
+    import uvicorn
+    uvicorn.run(app, host="0.0.0.0", port=port)
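With the server running, both endpoints can be exercised with plain curl; the JSON bodies below follow the EmbedRequest and RerankRequest models, and the host/port assume a local run on 7860 (the sample sentences are illustrative):

    curl -X POST http://localhost:7860/embeddings/ \
      -H 'Content-Type: application/json' \
      -d '{"sentences": ["hello world", "BGE-M3 is a multilingual embedding model"]}'

    curl -X POST http://localhost:7860/rerank/ \
      -H 'Content-Type: application/json' \
      -d '{"sentence_pairs": [["what is BGE-M3?", "BGE-M3 is a multilingual embedding model from BAAI"]]}'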
requirements.txt ADDED
@@ -0,0 +1,3 @@
+fastapi
+flagembedding
+uvicorn
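For a non-Docker run, the same dependencies can be installed directly; note that m3_server.py defaults to device='cuda', so this sketch assumes a CUDA-capable machine (pip resolves flagembedding to the FlagEmbedding package, since package names are case-insensitive):

    pip install -r requirements.txt
    python m3_server.py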
start-m3-server.sh ADDED
@@ -0,0 +1,3 @@
+export HF_HOME=/tmp/cache
+export TRANSFORMERS_CACHE=/tmp/cache
+python m3_server.py