# flan-t5-ct2 / main.py
# vasilee's picture
# Update main.py
# be9af02
# raw
# history blame contribute delete
# No virus
# 2.89 kB
from typing import Union

import torch
from ctranslate2 import Translator
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from torch import Tensor
from transformers import AutoTokenizer, AutoModel
def average_pool(last_hidden_states: Tensor,
                 attention_mask: Tensor) -> Tensor:
    """Average the hidden states over dim 1, skipping masked-out positions.

    Positions where ``attention_mask`` is zero are zeroed before summing, and
    the sum is divided by the number of unmasked positions in each row.

    Args:
        last_hidden_states: Hidden states to pool; dim 1 is reduced.
        attention_mask: Mask aligned with the leading dims of
            ``last_hidden_states``; non-zero marks a position to keep.

    Returns:
        The masked mean along dim 1. A row whose mask is entirely zero yields
        zeros rather than NaN.
    """
    keep = attention_mask[..., None].bool()
    summed = last_hidden_states.masked_fill(~keep, 0.0).sum(dim=1)
    counts = attention_mask.sum(dim=1)[..., None]
    # A fully masked row has count 0; dividing by it would give NaN (0 / 0).
    # Its sum is already 0, so substituting 1 makes the row pool to zeros and
    # leaves every other row untouched.
    counts = counts.masked_fill(counts == 0, 1)
    return summed / counts
# Local model directories; each one is used for both its tokenizer and model.
_EMBEDDING_MODEL_DIR = './multilingual-e5-base'
_INFERENCE_MODEL_DIR = "./fastchat-t5-3b-ct2"

# Embedding stack -- used as the replacement for text-ada.
embeddingTokenizer = AutoTokenizer.from_pretrained(_EMBEDDING_MODEL_DIR)
embeddingModel = AutoModel.from_pretrained(_EMBEDDING_MODEL_DIR)

# Generation stack -- used as the replacement for ChatGPT. The CTranslate2
# translator is loaded with int8 compute on the CPU.
inferenceTokenizer = AutoTokenizer.from_pretrained(_INFERENCE_MODEL_DIR)
inferenceTranslator = Translator(
    _INFERENCE_MODEL_DIR,
    compute_type="int8",
    device="cpu",
)
class EmbeddingRequest(BaseModel):
    """Request body accepted by the ``/text-embedding`` endpoint."""
    # Text to embed; the field is optional and defaults to None when omitted.
    input: Union[str, None] = None
class TokensCountRequest(BaseModel):
    """Request body accepted by the ``/tokens-count`` endpoint."""
    # Text to tokenize; the field is optional and defaults to None when omitted.
    input: Union[str, None] = None
class InferenceRequest(BaseModel):
    """Request body accepted by the ``/inference`` endpoint."""
    # Prompt text; the field is optional and defaults to None when omitted.
    input: Union[str, None] = None
    # Requested decoding length; defaults to 0 when the client omits it.
    max_length: Union[int, None] = 0
# Application object on which the route decorators in this file register
# their handlers.
app = FastAPI()
@app.get("/")
async def root():
    """Answer GET requests on "/" with a fixed greeting payload."""
    greeting = {"message": "Hello World"}
    return greeting
@app.post("/text-embedding")
async def text_embedding(request: EmbeddingRequest):
    """Embed the request text and return the pooled vector.

    The text is tokenized (truncated to 512 tokens), run through the
    embedding model, and mean-pooled over the unmasked tokens with
    ``average_pool``.

    Args:
        request: Body carrying the text to embed in ``input``.

    Returns:
        ``{'embedding': [...]}`` where the list is the pooled vector of the
        single input text.

    Raises:
        HTTPException: 422 when ``input`` is missing or null.
    """
    text = request.input
    if text is None:
        # Reject explicitly instead of letting the tokenizer fail on None.
        raise HTTPException(status_code=422,
                            detail="Field 'input' is required.")

    batch_dict = embeddingTokenizer([text], max_length=512,
                                    padding=True, truncation=True,
                                    return_tensors='pt')
    # Inference only: skip autograd bookkeeping so no graph is retained.
    with torch.no_grad():
        outputs = embeddingModel(**batch_dict)
        embeddings = average_pool(outputs.last_hidden_state,
                                  batch_dict['attention_mask'])

    return {
        'embedding': embeddings[0].tolist()
    }
@app.post('/inference')
async def inference(request: InferenceRequest):
    """Generate text for the request prompt with the CTranslate2 model.

    Args:
        request: Body carrying the prompt in ``input`` and an optional
            ``max_length`` decoding budget.

    Returns:
        ``{'generated_text': ...}`` holding the decoded first hypothesis.

    Raises:
        HTTPException: 422 when ``input`` is missing or null.
    """
    input_text = request.input
    if input_text is None:
        # Reject explicitly instead of letting the tokenizer fail on None.
        raise HTTPException(status_code=422,
                            detail="Field 'input' is required.")

    default_max_length = 256
    max_length_cap = 1024
    requested = request.max_length
    # The request model defaults max_length to 0, so an omitted value must
    # fall back to the default here; a non-positive budget would leave no
    # room to decode any tokens.
    if requested is None or requested <= 0:
        max_length = default_max_length
    else:
        max_length = min(max_length_cap, requested)

    # The translator consumes token strings, not ids.
    input_tokens = inferenceTokenizer.convert_ids_to_tokens(
        inferenceTokenizer.encode(input_text))
    results = inferenceTranslator.translate_batch(
        [input_tokens],
        beam_size=1,
        max_input_length=0,
        max_decoding_length=max_length,
        num_hypotheses=1,
        repetition_penalty=1.3,
        sampling_topk=40,
        sampling_temperature=0.7,
        use_vmap=False,
    )
    output_tokens = results[0].hypotheses[0]
    output_text = inferenceTokenizer.decode(
        inferenceTokenizer.convert_tokens_to_ids(output_tokens))

    return {
        'generated_text': output_text
    }
@app.post('/tokens-count')
async def tokens_count(request: TokensCountRequest):
    """Tokenize the request text with the inference tokenizer and count it.

    Args:
        request: Body carrying the text to tokenize in ``input``.

    Returns:
        ``{'tokens': [...], 'total': n}`` where ``tokens`` are the token
        strings produced by the tokenizer's ``encode`` call and ``total``
        is their count.

    Raises:
        HTTPException: 422 when ``input`` is missing or null.
    """
    input_text = request.input
    if input_text is None:
        # Reject explicitly instead of letting the tokenizer fail on None.
        raise HTTPException(status_code=422,
                            detail="Field 'input' is required.")

    token_ids = inferenceTokenizer.encode(input_text)
    tokens = inferenceTokenizer.convert_ids_to_tokens(token_ids)

    return {
        'tokens': tokens,
        'total': len(tokens)
    }