vasilee committed on
Commit 902ad9c · 1 Parent(s): 8de2cc9

Update main.py

Files changed (1):
  1. main.py +94 -14

main.py CHANGED
@@ -1,20 +1,100 @@
- from fastapi import FastAPI
- from fastapi.staticfiles import StaticFiles
- from fastapi.responses import FileResponse

- from transformers import pipeline

- app = FastAPI()

- pipe_flan = pipeline("text2text-generation", model="reasonwang/flan-t5-xl-8bit")

- @app.get("/infer_t5")
- def t5(input):
-     output = pipe_flan(input)
-     return {"output": output[0]["generated_text"]}

- app.mount("/", StaticFiles(directory="static", html=True), name="static")

- @app.get("/")
- def index() -> FileResponse:
-     return FileResponse(path="/app/static/index.html", media_type="text/html")
+ from flask import Flask, request, jsonify
+ from torch import Tensor
+ from transformers import AutoTokenizer, AutoModel
+ from ctranslate2 import Translator


+ def average_pool(last_hidden_states: Tensor,
+                  attention_mask: Tensor) -> Tensor:
+     # Mean-pool the token embeddings, zeroing out padding positions first
+     last_hidden = last_hidden_states.masked_fill(
+         ~attention_mask[..., None].bool(), 0.0)
+     return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]


+ # text-ada replacement
+ embeddingTokenizer = AutoTokenizer.from_pretrained(
+     './multilingual-e5-base')
+ embeddingModel = AutoModel.from_pretrained('./multilingual-e5-base')

+ # chatGpt replacement
+ inferenceTokenizer = AutoTokenizer.from_pretrained(
+     "./ct2fast-flan-alpaca-xl")
+ inferenceTranslator = Translator(
+     "./ct2fast-flan-alpaca-xl", compute_type="int8", device="cpu")

+
+ app = Flask(__name__)
+
+
+ @app.route('/text-embedding', methods=['POST'])
+ def text_embedding():
+     # Get the JSON data from the request
+     data = request.get_json()
+     input_text = data["input"]
+
+     # Tokenize, embed, and mean-pool the input
+     batch_dict = embeddingTokenizer([input_text], max_length=512,
+                                     padding=True, truncation=True,
+                                     return_tensors='pt')
+     outputs = embeddingModel(**batch_dict)
+     embeddings = average_pool(outputs.last_hidden_state,
+                               batch_dict['attention_mask'])
+
+     # Create a JSON response
+     response = {
+         'embedding': embeddings[0].tolist()
+     }
+
+     return jsonify(response)
+
+
+ @app.route('/inference', methods=['POST'])
+ def inference():
+     # Get the JSON data from the request
+     data = request.get_json()
+     input_text = data["input"]
+     max_length = 256
+     try:
+         max_length = int(data["max_length"])
+         max_length = min(1024, max_length)
+     except (KeyError, TypeError, ValueError):
+         pass
+
+     input_tokens = inferenceTokenizer.convert_ids_to_tokens(
+         inferenceTokenizer.encode(input_text))
+
+     results = inferenceTranslator.translate_batch(
+         [input_tokens], max_input_length=0, max_decoding_length=max_length,
+         num_hypotheses=1, repetition_penalty=1.3, sampling_topk=30,
+         sampling_temperature=1.1, use_vmap=True)
+
+     output_tokens = results[0].hypotheses[0]
+     output_text = inferenceTokenizer.decode(
+         inferenceTokenizer.convert_tokens_to_ids(output_tokens))
+
+     # Create a JSON response
+     response = {
+         'generated_text': output_text
+     }
+
+     return jsonify(response)
+
+
+ @app.route('/tokens-count', methods=['POST'])
+ def tokens_count():
+     # Get the JSON data from the request
+     data = request.get_json()
+     input_text = data["input"]
+
+     tokens = inferenceTokenizer.convert_ids_to_tokens(
+         inferenceTokenizer.encode(input_text))
+
+     # Create a JSON response
+     response = {
+         'tokens': tokens,
+         'total': len(tokens)
+     }
+
+     return jsonify(response)
+
+
+ if __name__ == '__main__':
+     app.run()
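
For reference, a minimal client sketch against the three endpoints this commit adds. It assumes the app is running on Flask's default http://127.0.0.1:5000 and that the requests package is installed; the base URL and sample payloads are illustrative, only the routes and JSON keys ("input", "max_length", "embedding", "generated_text", "tokens", "total") come from the commit.

import requests

BASE = "http://127.0.0.1:5000"  # assumption: Flask default host/port

# Embed a sentence with the multilingual-e5-base "text-ada replacement"
r = requests.post(f"{BASE}/text-embedding", json={"input": "hello world"})
print(len(r.json()["embedding"]))  # embedding dimension

# Generate text with the ct2fast-flan-alpaca-xl "chatGpt replacement"
r = requests.post(f"{BASE}/inference",
                  json={"input": "Summarize: Flask is a Python web framework.",
                        "max_length": 64})
print(r.json()["generated_text"])

# Count tokens as the inference tokenizer sees the input
r = requests.post(f"{BASE}/tokens-count", json={"input": "hello world"})
print(r.json()["total"])

Note that both models are loaded once at import time, so each request pays only the per-call tokenization and inference cost, not model startup.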