# STAN-EM / app.py
import time
import tempfile

import torch
import torch.nn.functional as F
import gradio as gr
from transformers import AutoTokenizer, AutoModel, PreTrainedTokenizer, PreTrainedModel

# Cache model downloads in a temporary directory and use the GPU when one is available.
temp_dir = tempfile.TemporaryDirectory()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

def get_tokenizer() -> PreTrainedTokenizer:
    # Tokenizer for the GTE-large embedding model.
    return AutoTokenizer.from_pretrained('thenlper/gte-large', trust_remote_code=True, cache_dir=temp_dir.name)

def get_model() -> PreTrainedModel:
    # Embedding model, moved to the selected device.
    return AutoModel.from_pretrained('thenlper/gte-large', trust_remote_code=True, cache_dir=temp_dir.name).to(device)

def average_pooling(last_hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    # Zero out padding positions, then average the token embeddings over the real tokens.
    last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0)
    return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]

def normalize_embeddings(embeddings: torch.Tensor) -> list[list[float]]:
    # L2-normalize each embedding so that a dot product between vectors equals cosine similarity.
    norm_embeddings = F.normalize(embeddings, p=2, dim=1)
    return norm_embeddings.tolist()

def get_embeddings(text: str) -> list[list[float]]:
    # Note: the tokenizer and model are reloaded on every call; load them once at
    # module level if request latency matters.
    tokenizer = get_tokenizer()
    model = get_model()
    with torch.inference_mode():
        start = time.time()
        batch_dict = tokenizer(
            text,
            max_length=512,
            truncation=True,
            padding=True,
            return_tensors='pt'
        ).to(device)
        outputs = model(**batch_dict, return_dict=True)
        embeddings = average_pooling(
            last_hidden_states=outputs.last_hidden_state,
            attention_mask=batch_dict['attention_mask']
        )
        norm_embeddings = normalize_embeddings(embeddings)
        end = time.time()
        print("Execution time:", end - start)
        return norm_embeddings
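
# Minimal usage sketch: because the embeddings are L2-normalized above, the cosine
# similarity of two texts reduces to a dot product of their vectors. The `similarity`
# helper below is illustrative only and is not used by the Gradio app.
def similarity(text_a: str, text_b: str) -> float:
    vec_a = torch.tensor(get_embeddings(text_a)[0])
    vec_b = torch.tensor(get_embeddings(text_b)[0])
    return float(vec_a @ vec_b)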

# Simple Gradio UI: paste text, get back the normalized embedding as text.
iface = gr.Interface(fn=get_embeddings, inputs="text", outputs="text")
iface.launch()
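
# Client-side sketch (assumes the app is running locally on Gradio's default port;
# the URL is a placeholder, and "/predict" is the default endpoint name for a plain Interface):
# from gradio_client import Client
# client = Client("http://127.0.0.1:7860")
# result = client.predict("hello world", api_name="/predict")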