import os
import time
import shutil
from pathlib import Path
from functools import partial
from typing import Union, Dict, List

import torch
from torch.utils.data import DataLoader
import datasets
from datasets import load_dataset, Dataset
from transformers import AutoTokenizer, PreTrainedTokenizer, DataCollatorWithPadding
from huggingface_hub import Repository, create_repo, HfApi
from optimum.onnxruntime import (
    AutoOptimizationConfig,
    ORTModelForFeatureExtraction,
    ORTOptimizer,
)

os.environ["TOKENIZERS_PARALLELISM"] = "false"

opt_configs = {
    "O2": AutoOptimizationConfig.O2(),
    "O3": AutoOptimizationConfig.O3(),
    "O4": AutoOptimizationConfig.O4(),
}
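# Per the Optimum docs: O2 enables basic and extended (transformers-specific)
# fusions, O3 adds GELU approximation on top of O2, and O4 adds mixed (fp16)
# precision on top of O3 and is only meaningful on GPU.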
def get_batch_size(device_name: str, model_name: str, opt_level: str):
    """
    TODO: run actual tests

    T4 has 16GB
    A10 has 24GB

    Args:
        device_name (`str`):
            The name of the GPU device in use.
        model_name (`str`):
            The name of the model in use.
        opt_level (`str`):
            The optimization level in use.

    Returns:
        `int`:
            The batch size to use.
    """
    if "small" in model_name:
        bs = 192
    elif "base" in model_name:
        bs = 128
    elif "large" in model_name:
        bs = 64
    else:
        bs = 32

    if "A10" in device_name:
        bs *= 2

    if opt_level == "O4":
        bs *= 2

    return bs
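# Illustrative: a "base"-sized model on an A10 with O4 gets
# 128 (base) * 2 (A10) * 2 (O4) = 512 examples per batch.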
def mean_pooling(last_hidden_state: torch.Tensor, attention_mask: torch.Tensor):
    """
    Mean pool the token embeddings, ignoring padding tokens.

    Args:
        last_hidden_state (`torch.Tensor`):
            The last hidden state of the model.
        attention_mask (`torch.Tensor`):
            The attention mask.

    Returns:
        `torch.Tensor`:
            The mean pooled embeddings.
    """
    input_mask_expanded = (
        attention_mask.unsqueeze(-1).expand(last_hidden_state.size()).float()
    )
    # Sum only the non-padding token embeddings, then divide by the number of
    # real tokens; the clamp avoids division by zero for all-padding rows.
    return torch.sum(last_hidden_state * input_mask_expanded, 1) / torch.clamp(
        input_mask_expanded.sum(1), min=1e-9
    )
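# A quick sanity check of the masking math: with one 3-token sequence whose
# last token is padding, only the first two hidden states contribute and the
# sum is divided by 2.
#
#   h = torch.tensor([[[1.0, 1.0], [3.0, 3.0], [9.0, 9.0]]])
#   m = torch.tensor([[1, 1, 0]])
#   mean_pooling(h, m)  # tensor([[2., 2.]])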
def get_model_and_tokenizer(model_name: str, optimization_level: str, progress):
    """
    Load the model and tokenizer from the Hugging Face Hub.

    If the model has not already been optimized locally, export it to ONNX,
    optimize it, and save it to the local directory.

    Args:
        model_name (`str`):
            The name of the model to load.
        optimization_level (`str`):
            The optimization level to use. Should be one of `"O2"`, `"O3"`, or `"O4"`.
        progress:
            Optional progress callback (e.g. `gr.Progress`); pass `None` to disable updates.

    Returns:
        model (`ORTModelForFeatureExtraction`):
            The optimized model.
        tokenizer (`PreTrainedTokenizer`):
            The tokenizer.
    """
    optimized_model_name = f"model_optimized_{optimization_level}.onnx"

    model_dir = Path(model_name.replace("/", "_"))
    if not (model_dir / optimized_model_name).exists():
        if progress is not None:
            progress(0.2, "Downloading tokenizer...")
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        tokenizer.save_pretrained(model_dir)

        if progress is not None:
            progress(0.4, "Downloading model...")
        model = ORTModelForFeatureExtraction.from_pretrained(model_name, export=True)
        model.save_pretrained(model_dir)

        optimizer = ORTOptimizer.from_pretrained(model)
        optimization_config = opt_configs[optimization_level]

        if progress is not None:
            progress(0.6, "Optimizing model...")
        optimizer.optimize(save_dir=model_dir, optimization_config=optimization_config)
        # The optimizer always writes `model_optimized.onnx`; rename it so each
        # optimization level gets its own file.
        (model_dir / "model_optimized.onnx").rename(model_dir / optimized_model_name)
    else:
        tokenizer = AutoTokenizer.from_pretrained(model_dir)

    if progress is not None:
        progress(0.8, "Loading optimized model and tokenizer...")
    return (
        ORTModelForFeatureExtraction.from_pretrained(
            model_dir,
            file_name=optimized_model_name,
            provider="CUDAExecutionProvider",
        ),
        tokenizer,
    )
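# A minimal usage sketch (the model id here is only an example; any encoder
# that exports cleanly to ONNX should work):
#
#   model, tokenizer = get_model_and_tokenizer(
#       "sentence-transformers/all-MiniLM-L6-v2", "O3", progress=None
#   )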
# Earlier manual collator, superseded by the DataCollatorWithPadding used in
# batch_embed below; kept for reference.
# def collate_fn(examples, tokenizer=None, padding=None, column_name="text"):
#     try:
#         keys = examples[0].keys()
#     except KeyError:
#         print(examples)
#     else:
#         batch = {k: [] for k in examples[0].keys()}

#     tokenized = tokenizer(
#         [x[column_name] for x in examples],
#         truncation=True,
#         padding=padding,
#         max_length=512,
#         return_tensors="pt",
#     )
#     tokenized[column_name] = [x[column_name] for x in examples]

#     return tokenized
def batch_embed(
    ds: datasets.IterableDataset,
    model: ORTModelForFeatureExtraction,
    tokenizer: PreTrainedTokenizer,
    model_name: str,
    column_name: str,
    new_dataset_id: str,
    opt_level: str,
    upload_batch_size: int = 10_000,
    map_batch_size: int = 2000,
    num2skip: int = 0,
    num2embed: int = -1,
    progress=None,
):
    """
    Run the model on the dataset and upload the embeddings to the Hub.

    Args:
        ds (`datasets.IterableDataset`):
            dataset to embed. From `load_hf_dataset`
        model (`ORTModelForFeatureExtraction`):
            model to use for embedding. From `get_model_and_tokenizer`
        tokenizer (`PreTrainedTokenizer`):
            tokenizer to use for embedding. From `get_model_and_tokenizer`
        model_name (`str`):
            name of the model in use. Used to determine the batch size.
        column_name (`str`):
            column name to use for embedding. Default option in the gradio app is `text`
        new_dataset_id (`str`):
            id of the new dataset to create. Should include username or organization,
            e.g. nbroad/new-embeddings
        opt_level (`str`):
            optimization level in use. Should be one of `O2`, `O3`, or `O4`.
            See here for more details on optimization levels:
            https://huggingface.co/docs/optimum/onnxruntime/usage_guides/optimization#optimization-configuration
        upload_batch_size (`int`, *optional*, defaults to `10_000`):
            number of embeddings to upload at once.
        map_batch_size (`int`, *optional*, defaults to `2000`):
            number of examples to tokenize at once.
        num2skip (`int`, *optional*, defaults to `0`):
            number of examples to skip.
        num2embed (`int`, *optional*, defaults to `-1`):
            number of examples to embed. `-1` means all examples.
        progress:
            Optional progress callback (e.g. `gr.Progress`); pass `None` to disable updates.

    Returns:
        current_count (`int`):
            number of examples embedded so far
        time_taken (`float`):
            time taken to embed the examples, in seconds
    """
    api = HfApi(
        token=os.environ["HF_TOKEN"],
    )

    username = api.whoami()["name"]
    if "/" not in new_dataset_id:
        new_dataset_id = username + "/" + new_dataset_id

    repo = init_git_repo(new_dataset_id)

    embeds = []
    texts = []

    # current_count keeps track of how many examples have been embedded in total
    current_count = num2skip
    # last_count keeps track of how many had been embedded as of the last push
    last_count = current_count

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # `get_device_name` raises when there is no GPU, so fall back to a generic name
    device_name = torch.cuda.get_device_name(0) if torch.cuda.is_available() else "cpu"
    inference_bs = get_batch_size(device_name, model_name, opt_level)

    start_time = time.time()

    # pad_to_multiple_of=16 keeps sequence lengths friendly to fp16 tensor cores
    collator = DataCollatorWithPadding(
        tokenizer, padding=True, max_length=512, pad_to_multiple_of=16
    )

    dl = DataLoader(
        ds,
        batch_size=inference_bs,
        shuffle=False,
        num_workers=2,
        pin_memory=True,
        drop_last=False,
        collate_fn=collator,
    )
    for batch in dl:
        ids = batch["input_ids"].to(device)
        mask = batch["attention_mask"].to(device)
        # some exported encoders still expect token_type_ids; all zeros is
        # correct for single-sequence inputs
        t_ids = torch.zeros_like(ids)

        outputs = model(input_ids=ids, attention_mask=mask, token_type_ids=t_ids)
        embeds.extend(mean_pooling(outputs[0], mask).cpu().tolist())
        texts.extend(batch[column_name])

        current_count += ids.shape[0]

        # Periodically upload to the hub
        if len(embeds) > upload_batch_size:
            push_to_repo(new_dataset_id, last_count, current_count, embeds, texts, api)
            embeds = []
            texts = []
            last_count = current_count

        # Provide updates
        if progress is not None:
            progress(
                (current_count, None),
                "Embedding docs...",
                total=None,
                unit="Docs Embedded",
            )

    time_taken = time.time() - start_time

    # If there are any remaining embeddings, upload them
    if len(embeds) > 0:
        push_to_repo(new_dataset_id, last_count, current_count, embeds, texts, api)

    return current_count - num2skip, time_taken
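# A minimal end-to-end sketch, assuming the dataset has already been tokenized
# upstream (the collator above expects `input_ids`/`attention_mask` features
# alongside the raw text column). The dataset and repo ids are hypothetical:
#
#   ds = load_dataset("user/some-tokenized-corpus", split="train", streaming=True)
#   model, tokenizer = get_model_and_tokenizer(
#       "sentence-transformers/all-MiniLM-L6-v2", "O3", progress=None
#   )
#   count, seconds = batch_embed(
#       ds,
#       model,
#       tokenizer,
#       model_name="sentence-transformers/all-MiniLM-L6-v2",
#       column_name="text",
#       new_dataset_id="user/new-embeddings",
#       opt_level="O3",
#   )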
def init_git_repo(repo_id: str):
    """
    Initialize a git repo for the new dataset.

    ***Removes the existing local folder if it exists***

    Args:
        repo_id (`str`):
            id of the new dataset to create. Should include username or organization,
            e.g. nbroad/new-embeddings

    Returns:
        `huggingface_hub.Repository`:
            local clone of the dataset repo
    """
    local_dir = repo_id.replace("/", "_")

    create_repo(
        repo_id,
        repo_type="dataset",
        token=os.environ["HF_TOKEN"],
        private=True,
        exist_ok=True,
    )
    try:
        repo = Repository(
            local_dir=local_dir,
            clone_from=repo_id,
            repo_type="dataset",
            token=os.environ["HF_TOKEN"],
            skip_lfs_files=True,
        )
    except EnvironmentError:
        # A stale local clone can make `Repository` fail; wipe it and re-clone
        shutil.rmtree(local_dir)
        repo = Repository(
            local_dir=local_dir,
            clone_from=repo_id,
            repo_type="dataset",
            token=os.environ["HF_TOKEN"],
            skip_lfs_files=True,
        )
    if repo is not None:
        repo.git_pull()

    return repo
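# Note: `Repository` is the legacy git-based workflow in huggingface_hub; the
# actual uploads below go through `HfApi.upload_file`, and the clone mainly
# provides a predictable local working directory.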
def push_to_repo(
    repo_id: str,
    last_count: int,
    current_count: int,
    embeds: List[List[float]],
    texts: List[str],
    api: HfApi,
):
    """
    Push embeddings to the repo.

    Args:
        repo_id (`str`):
            id of the new dataset to create. Should include username or organization.
        last_count (`int`):
            last count of embeddings.
            This is the number of embeddings that had already been pushed.
        current_count (`int`):
            current count of embeddings.
            This is the number of embeddings that will have been pushed after this batch.
        embeds (`List[List[float]]`):
            list of embeddings to push to the repo
        texts (`List[str]`):
            list of texts to push to the repo
        api (`huggingface_hub.HfApi`):
            api to use to push to the repo
    """
    temp_ds = Dataset.from_dict(
        {
            "embedding": embeds,
            "text": texts,
        }
    )

    local_dir = repo_id.replace("/", "_")
    data_dir = Path(local_dir) / "data"
    data_dir.mkdir(exist_ok=True, parents=True)

    # zero-pad the counts so lexicographic sorting puts the files in order
    filename = f"embeddings_{last_count:08d}_{current_count:08d}.parquet"
    filepath = str(data_dir / filename)

    temp_ds.to_parquet(filepath)

    files = sorted(data_dir.glob("*.parquet"))

    api.upload_file(
        path_or_fileobj=filepath,
        path_in_repo=f"data/{filename}",
        repo_id=repo_id,
        repo_type="dataset",
        run_as_future=True,
        token=os.environ["HF_TOKEN"],
        commit_message=f"Embedded examples {last_count} thru {current_count}",
    )

    # Delete older local copies to save disk space; keep the most recent few
    # since their background uploads (`run_as_future=True`) may still be running
    if len(files) > 4:
        for file in files[:2]:
            file.unlink()