vasilee committed
Commit f3fc705 · 1 Parent(s): 4a33a49

Update main.py

Files changed (1)
  1. main.py +43 -38
main.py CHANGED
@@ -1,7 +1,10 @@
-from flask import Flask, request, jsonify
 from torch import Tensor
 from transformers import AutoTokenizer, AutoModel
 from ctranslate2 import Translator
+from typing import Union
+
+from fastapi import FastAPI
+from pydantic import BaseModel
 
 
 def average_pool(last_hidden_states: Tensor,
@@ -13,24 +16,40 @@ def average_pool(last_hidden_states: Tensor,
 
 # text-ada replacement
 embeddingTokenizer = AutoTokenizer.from_pretrained(
-    './multilingual-e5-base')
-embeddingModel = AutoModel.from_pretrained('./multilingual-e5-base')
+    './models/multilingual-e5-base')
+embeddingModel = AutoModel.from_pretrained('./models/multilingual-e5-base')
 
 # chatGpt replacement
 inferenceTokenizer = AutoTokenizer.from_pretrained(
-    "./ct2fast-flan-alpaca-xl")
+    "./models/ct2fast-flan-alpaca-xl")
 inferenceTranslator = Translator(
-    "./ct2fast-flan-alpaca-xl", compute_type="int8", device="cpu")
+    "./models/ct2fast-flan-alpaca-xl", compute_type="int8", device="cpu")
+
+
+class EmbeddingRequest(BaseModel):
+    input: Union[str, None] = None
+
+
+class TokensCountRequest(BaseModel):
+    input: Union[str, None] = None
+
+
+class InferenceRequest(BaseModel):
+    input: Union[str, None] = None
+    max_length: Union[int, None] = 0
+
 
+app = FastAPI()
 
-app = Flask(__name__)
 
+@app.get("/")
+async def root():
+    return {"message": "Hello World"}
 
-@app.route('/text-embedding', methods=['POST'])
-def text_embedding():
-    # Get the JSON data from the request
-    data = request.get_json()
-    input = data["input"]
+
+@app.post("/text-embedding")
+async def text_embedding(request: EmbeddingRequest):
+    input = request.input
 
     # Process the input data
     batch_dict = embeddingTokenizer([input], max_length=512,
@@ -38,28 +57,24 @@ def text_embedding():
     outputs = embeddingModel(**batch_dict)
     embeddings = average_pool(outputs.last_hidden_state,
                               batch_dict['attention_mask'])
-    token_ids = batch_dict["input_ids"][0].tolist()
 
-    # Create a JSON response
-    response = {
+    # create response
+    return {
         'embedding': embeddings[0].tolist()
     }
 
-    return jsonify(response)
 
-
-@app.route('/inference', methods=['POST'])
-def inference():
-    # Get the JSON data from the request
-    data = request.get_json()
-    input_text = data["input"]
+@app.post('/inference')
+async def inference(request: InferenceRequest):
+    input_text = request.input
     max_length = 256
     try:
-        max_length = int(data["max_length"])
+        max_length = int(request.max_length)
         max_length = min(1024, max_length)
     except:
        pass
 
+    # process request
     input_tokens = inferenceTokenizer.convert_ids_to_tokens(
         inferenceTokenizer.encode(input_text))
 
@@ -70,31 +85,21 @@ def inference():
     output_text = inferenceTokenizer.decode(
         inferenceTokenizer.convert_tokens_to_ids(output_tokens))
 
-    # Create a JSON response
-    response = {
+    # create response
+    return {
         'generated_text': output_text
     }
 
-    return jsonify(response)
 
-
-@app.route('/tokens-count', methods=['POST'])
-def tokens_count():
-    # Get the JSON data from the request
-    data = request.get_json()
-    input_text = data["input"]
+@app.post('/tokens-count')
+async def tokens_count(request: TokensCountRequest):
+    input_text = request.input
 
     tokens = inferenceTokenizer.convert_ids_to_tokens(
         inferenceTokenizer.encode(input_text))
 
-    # Create a JSON response
+    # create response
     response = {
         'tokens': tokens,
         'total': len(tokens)
     }
-
-    return jsonify(response)
-
-
-if __name__ == '__main__':
-    app.run()
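
For reference, below is one way to smoke-test the endpoints after this commit. It is a minimal sketch, not part of the commit: it assumes the app is served locally with uvicorn (e.g. uvicorn main:app, default port 8000), that the model directories under ./models/ are in place, and that the requests library is installed; the example inputs are illustrative.

# Smoke test for the new FastAPI endpoints. Assumes the server is already
# running, e.g. with:  uvicorn main:app
# (uvicorn, requests, and the inputs below are assumptions, not commit content)
import requests

BASE = "http://127.0.0.1:8000"  # uvicorn's default bind address

# root route added in this commit
print(requests.get(f"{BASE}/").json())  # {'message': 'Hello World'}

# /text-embedding takes an EmbeddingRequest body: {"input": <str>}
emb = requests.post(f"{BASE}/text-embedding",
                    json={"input": "query: how are you?"}).json()
print(len(emb["embedding"]))  # embedding dimensionality

# /inference takes an InferenceRequest body; max_length is capped at 1024
out = requests.post(f"{BASE}/inference",
                    json={"input": "Translate to German: hello",
                          "max_length": 64}).json()
print(out["generated_text"])

# /tokens-count takes a TokensCountRequest body
print(requests.post(f"{BASE}/tokens-count",
                    json={"input": "hello world"}).json())

One caveat visible in the diff itself: the new tokens_count builds its response dict, but the final hunk ends at the closing brace with no return statement before end of file, so the last call above is likely to receive null until a return is added.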