fastelectronicvegetable committed on
Commit
f0405e6
1 Parent(s): 2c88fd5
Files changed (3) hide show
  1. Dockerfile +1 -1
  2. main.py +8 -14
  3. requirements.txt +1 -1
Dockerfile CHANGED
@@ -17,4 +17,4 @@ WORKDIR $HOME/app
17
 
18
  COPY --chown=user . $HOME/app
19
 
20
- CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "7860"]
 
17
 
18
  COPY --chown=user . $HOME/app
19
 
20
+ CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "7860", "--workers=4"]
main.py CHANGED
@@ -1,11 +1,11 @@
1
  from typing import Generic, List, Optional, TypeVar
2
  from functools import partial
3
- from pydantic import BaseModel, ValidationError, validator
4
- from pydantic.generics import GenericModel
5
  from sentence_transformers import SentenceTransformer
6
  from fastapi import FastAPI
7
- import os, asyncio, numpy, ujson
8
  from fastapi.middleware.cors import CORSMiddleware
 
9
 
10
  MODEL = SentenceTransformer("all-mpnet-base-v2")
11
 
@@ -24,24 +24,18 @@ def cache(func):
24
 
25
  @cache
26
  def _encode(sentences: List[str]):
27
- array = [numpy.around(a.numpy(), 3) for a in MODEL.encode(sentences, normalize_embeddings=True, convert_to_tensor=True, batch_size=4, show_progress_bar=True)]
 
28
  return array
29
 
30
- async def encode(sentences: List[str]) -> List[numpy.ndarray]:
31
- loop = asyncio.get_event_loop()
32
- result = await loop.run_in_executor(None, _encode, sentences)
33
- return result
34
-
35
  class EmbedReq(BaseModel):
36
  sentences: List[str]
37
 
38
  app = FastAPI()
39
 
40
- @app.post("/embed")
41
- async def embed(embed: EmbedReq):
42
- result = await encode(embed.sentences)
43
- # Convert it to an ordinary list of floats
44
- return ujson.dumps([r.tolist() for r in result])
45
 
46
  app.add_middleware(
47
  CORSMiddleware,
 
1
  from typing import Generic, List, Optional, TypeVar
2
  from functools import partial
3
+ from pydantic import BaseModel
 
4
  from sentence_transformers import SentenceTransformer
5
  from fastapi import FastAPI
6
+ import numpy
7
  from fastapi.middleware.cors import CORSMiddleware
8
+ from fastapi.responses import ORJSONResponse
9
 
10
  MODEL = SentenceTransformer("all-mpnet-base-v2")
11
 
 
24
 
25
  @cache
26
  def _encode(sentences: List[str]):
27
+ embeddings = MODEL.encode(sentences, normalize_embeddings=True, batch_size=2, show_progress_bar=True)
28
+ array = [numpy.around(a, 3).tolist() for a in embeddings]
29
  return array
30
 
 
 
 
 
 
31
  class EmbedReq(BaseModel):
32
  sentences: List[str]
33
 
34
  app = FastAPI()
35
 
36
+ @app.post("/embed", response_class=ORJSONResponse)
37
+ def embed(embed: EmbedReq):
38
+ return _encode(embed.sentences)
 
 
39
 
40
  app.add_middleware(
41
  CORSMiddleware,
requirements.txt CHANGED
@@ -4,4 +4,4 @@ joblib==1.2
4
  fastapi==0.89
5
  uvicorn[standard]==0.20
6
  huggingface-hub==0.10.1
7
- ujson==5.7
 
4
  fastapi==0.89
5
  uvicorn[standard]==0.20
6
  huggingface-hub==0.10.1
7
+ orjson==3.8.9