import os
import re
import time
import shutil
from pathlib import Path
from functools import partial
from typing import Union, Dict, List

import torch
from torch.utils.data import DataLoader
import datasets
from datasets import load_dataset, Dataset
from transformers import AutoTokenizer, PreTrainedTokenizer
from huggingface_hub import Repository, create_repo, HfApi
from optimum.onnxruntime import (
    AutoOptimizationConfig,
    ORTModelForFeatureExtraction,
    ORTOptimizer,
)

os.environ["TOKENIZERS_PARALLELISM"] = "false"

opt_configs = {
    "O2": AutoOptimizationConfig.O2(),
    "O3": AutoOptimizationConfig.O3(),
    "O4": AutoOptimizationConfig.O4(),
}


def get_batch_size(device_name: str, model_name: str, opt_level: str):
    """
    Pick a rough inference batch size based on the GPU, the model size, and the optimization level.

    TODO: run actual tests
    T4 has 16GB
    A10 has 24GB

    Args:
        device_name (`str`):
            The name of the GPU device in use.
        model_name (`str`):
            The name of the model in use.
        opt_level (`str`):
            The optimization level in use.

    Returns:
        `int`:
            The batch size to use.
    """
    if "small" in model_name:
        bs = 192
    elif "base" in model_name:
        bs = 128
    elif "large" in model_name:
        bs = 64
    else:
        bs = 32

    if "A10" in device_name:
        bs *= 2

    if opt_level == "O4":
        bs *= 2

    return bs
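
# Illustrative calls (the device and model names below are assumptions for the example,
# not values the app necessarily passes):
#   get_batch_size("Tesla T4", "bge-small-en", "O2")     -> 192
#   get_batch_size("NVIDIA A10G", "bge-large-en", "O4")  -> 64 * 2 * 2 = 256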


def mean_pooling(last_hidden_state: torch.Tensor, attention_mask: torch.Tensor):
    """
    Mean pool the token embeddings, ignoring padding tokens.

    Args:
        last_hidden_state (`torch.Tensor`):
            The token embeddings output by the model, shape `(batch, seq_len, hidden)`.
        attention_mask (`torch.Tensor`):
            The attention mask, shape `(batch, seq_len)`.

    Returns:
        `torch.Tensor`:
            The mean pooled embeddings, shape `(batch, hidden)`.
    """
    input_mask_expanded = (
        attention_mask.unsqueeze(-1).expand(last_hidden_state.size()).float()
    )
    return torch.sum(last_hidden_state * input_mask_expanded, 1) / torch.clamp(
        input_mask_expanded.sum(1), min=1e-9
    )
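
# Minimal sketch of how mean_pooling behaves (shapes are assumptions for illustration):
#   last_hidden_state: (batch, seq_len, hidden), e.g. torch.randn(2, 4, 8)
#   attention_mask:    (batch, seq_len),         e.g. torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])
#   mean_pooling(last_hidden_state, attention_mask) -> (2, 8) tensor, averaging only unmasked tokens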


def load_hf_dataset(ds_name: str, ds_config: str = None, ds_split: str = "train"):
    """
    Load a dataset for embedding.

    Most datasets are loaded directly from the HuggingFace Hub. Wikipedia is a special
    case: it is assumed to have already been saved to `/data` as parquet chunks
    (see `download_wikipedia`), so those local files are loaded instead.

    Args:
        ds_name (`str`):
            The name of the dataset to load.
        ds_config (`str`, *optional*, defaults to `None`):
            The configuration of the dataset to load.
        ds_split (`str`, *optional*, defaults to `"train"`):
            The split of the dataset to load.

    Returns:
        ds (`datasets.Dataset`):
            The loaded dataset.
    """
    if ds_config == "":
        ds_config = None

    if ds_name == "wikipedia":
        pattern = re.compile(r"[^a-zA-Z0-9]")
        folder = Path("/data") / pattern.sub("", ds_name + (ds_config or ""))
        files = list(map(str, folder.glob("chunk_*")))
        return load_dataset("parquet", data_files=files, split="train")

    ds = load_dataset(ds_name, ds_config, split=ds_split)

    return ds
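
# Illustrative calls (the dataset names and config are assumptions for the example):
#   ds = load_hf_dataset("imdb", ds_config="", ds_split="train")
#   ds = load_hf_dataset("wikipedia", ds_config="20220301.en")  # reads /data/wikipedia20220301en/chunk_*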


def download_wikipedia(ds_name, ds_config, num2skip, num2embed):
    """
    Stream the wikipedia dataset from the Hub and save the selected slice to `/data`
    as parquet chunks so it can be re-loaded locally by `load_hf_dataset`.

    Args:
        ds_name (`str`):
            The name of the dataset (expected to be `"wikipedia"`).
        ds_config (`str`):
            The configuration of the dataset.
        num2skip (`int`):
            Number of examples to skip from the start of the stream.
        num2embed (`int`):
            Number of examples to keep. If not positive, keep everything after `num2skip`.
    """
    ds = load_dataset(ds_name, ds_config, streaming=True, split="train")

    def gen():
        if num2embed > 0:
            for example in ds.skip(num2skip).take(num2embed):
                yield {"text": example["text"]}
        else:
            for example in ds.skip(num2skip):
                yield {"text": example["text"]}

    ds2 = Dataset.from_generator(gen)

    chunk_size = 20_000
    filenames = []

    pattern = re.compile(r"[^a-zA-Z0-9]")
    folder = Path("/data") / pattern.sub("", ds_name + (ds_config or ""))
    folder.mkdir(exist_ok=True, parents=True)

    for chunk_num, start_idx in enumerate(range(0, len(ds2), chunk_size)):
        end_idx = min(start_idx + chunk_size, len(ds2))
        temp = ds2.select(range(start_idx, end_idx))
        temp.to_parquet(str(folder / f"chunk_{chunk_num}"))
        filenames.append(str(folder / f"chunk_{chunk_num}"))

    return load_dataset("parquet", data_files=filenames, split="train")


def get_model_and_tokenizer(model_name: str, optimization_level: str, progress):
    """
    Load the model and tokenizer from the HuggingFace Hub.

    If the model has not already been optimized, optimize it and save it to the local directory.

    Args:
        model_name (`str`):
            The name of the model to load.
        optimization_level (`str`):
            The optimization level to use. Should be one of `"O2"`, `"O3"`, or `"O4"`.
        progress:
            Optional progress callback (e.g. `gradio.Progress`). If `None`, progress updates are skipped.

    Returns:
        model (`ORTModelForFeatureExtraction`):
            The optimized model.
        tokenizer (`PreTrainedTokenizer`):
            The tokenizer.
    """
    optimized_model_name = f"model_optimized_{optimization_level}.onnx"

    model_dir = Path(model_name.replace("/", "_"))

    if not (model_dir / optimized_model_name).exists():
        if progress is not None:
            progress(0.2, "Downloading tokenizer...")
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        tokenizer.save_pretrained(model_dir)

        if progress is not None:
            progress(0.4, "Downloading model...")
        model = ORTModelForFeatureExtraction.from_pretrained(model_name, export=True)
        model.save_pretrained(model_dir)

        optimizer = ORTOptimizer.from_pretrained(model)
        optimization_config = opt_configs[optimization_level]

        if progress is not None:
            progress(0.6, "Optimizing model...")
        optimizer.optimize(save_dir=model_dir, optimization_config=optimization_config)
        (model_dir / "model_optimized.onnx").rename(model_dir / optimized_model_name)
    else:
        tokenizer = AutoTokenizer.from_pretrained(model_dir)

    if progress is not None:
        progress(0.8, "Loading optimized model and tokenizer...")

    return (
        ORTModelForFeatureExtraction.from_pretrained(
            model_dir,
            file_name=optimized_model_name,
            provider="CUDAExecutionProvider",
        ),
        tokenizer,
    )
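
# Minimal sketch of loading an optimized model (the model id is an assumption;
# any feature-extraction model on the Hub should work):
#   model, tokenizer = get_model_and_tokenizer(
#       "sentence-transformers/all-MiniLM-L6-v2", optimization_level="O2", progress=None
#   )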


def tokenize(
    examples: Dict[str, List[str]],
    tokenizer: PreTrainedTokenizer,
    column_name: str = "text",
    padding: Union[bool, str] = True,
    max_length: int = 512,
):
    """
    Tokenize the examples using the tokenizer.

    Args:
        examples (`Dict[str, List[str]]`):
            examples to tokenize
        tokenizer (`PreTrainedTokenizer`):
            tokenizer to use
        column_name (`str`, *optional*, defaults to `"text"`):
            column name to use for tokenization
        padding (`Union[bool, str]`, *optional*, defaults to `True`):
            whether to pad the examples. Use `"max_length"` if using the `O4`
            optimization level. If `True`, the batch will be padded to the
            longest sequence in the batch.
        max_length (`int`, *optional*, defaults to `512`):
            max length to use for the model. Any sequences longer will be truncated.
            If padding is `"max_length"`, sequences are padded up to `max_length`.

    Returns:
        `Dict[str, List[List[int]]]`:
            tokenized examples
    """
    # TODO: add lengths, sort by length, use dynamic padding
    # TODO: option for controlling length for models that can go shorter/longer than 512
    return tokenizer(
        examples[column_name], truncation=True, padding=padding, max_length=max_length
    )
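
# Illustrative use with `datasets.Dataset.map` (mirrors the commented-out map call in
# `batch_embed`; the `ds` and `tokenizer` objects here are assumptions for the example):
#   ds = ds.map(
#       tokenize,
#       batched=True,
#       fn_kwargs={"tokenizer": tokenizer, "column_name": "text", "padding": True},
#   )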


def collate_fn(examples, tokenizer=None, padding=None, column_name="text"):
    """
    Collate a list of dataset rows into a tokenized, padded batch of tensors.

    The raw text is kept in the batch under `column_name` so it can be stored
    alongside the embeddings when pushing to the Hub.
    """
    tokenized = tokenizer(
        [x[column_name] for x in examples],
        truncation=True,
        padding=padding,
        max_length=512,
        return_tensors="pt",
    )
    tokenized[column_name] = [x[column_name] for x in examples]
    return tokenized
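
# Sketch of how collate_fn is meant to be plugged into a DataLoader (this mirrors the
# call in `batch_embed`; the batch size and padding choice are assumptions here):
#   loader = DataLoader(
#       ds,
#       batch_size=64,
#       collate_fn=partial(collate_fn, tokenizer=tokenizer, padding=True, column_name="text"),
#   )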


def batch_embed(
    ds: datasets.IterableDataset,
    model: ORTModelForFeatureExtraction,
    tokenizer: PreTrainedTokenizer,
    model_name: str,
    column_name: str,
    new_dataset_id: str,
    opt_level: str,
    upload_batch_size: int = 10_000,
    map_batch_size: int = 2000,
    num2skip: int = 0,
    num2embed: int = -1,
    progress=None,
):
    """
    Run the model on the dataset and upload the embeddings to the hub.

    Args:
        ds (`datasets.Dataset`):
            dataset to embed. From `load_hf_dataset`
        model (`ORTModelForFeatureExtraction`):
            model to use for embedding. From `get_model_and_tokenizer`
        tokenizer (`PreTrainedTokenizer`):
            tokenizer to use for embedding. From `get_model_and_tokenizer`
        model_name (`str`):
            name of the model in use. Used to determine the batch size.
        column_name (`str`):
            column name to use for embedding. Default option in gradio app is `text`
        new_dataset_id (`str`):
            id of the new dataset to create. Should include username or organization,
            e.g. nbroad/new-embeddings
        opt_level (`str`):
            optimization level in use. Should be one of `O2`, `O3`, `O4`.
            See here for more details on optimization levels:
            https://huggingface.co/docs/optimum/onnxruntime/usage_guides/optimization#optimization-configuration
        upload_batch_size (`int`, *optional*, defaults to `10_000`):
            number of embeddings to upload at once
        map_batch_size (`int`, *optional*, defaults to `2000`):
            number of examples to tokenize at once
        num2skip (`int`, *optional*, defaults to `0`):
            number of examples to skip
        num2embed (`int`, *optional*, defaults to `-1`):
            number of examples to embed. `-1` means all examples
        progress (*optional*, defaults to `None`):
            optional progress callback (e.g. `gradio.Progress`). If `None`,
            progress updates are skipped

    Returns:
        current_count (`int`):
            number of examples embedded so far
        time_taken (`float`):
            time taken to embed the examples in seconds
    """
    api = HfApi(
        token=os.environ["HF_TOKEN"],
    )

    username = api.whoami()["name"]
    if "/" not in new_dataset_id:
        new_dataset_id = username + "/" + new_dataset_id

    repo = init_git_repo(new_dataset_id)

    # ds = ds.map(
    #     tokenize,
    #     batched=True,
    #     batch_size=map_batch_size,
    #     fn_kwargs={
    #         "tokenizer": tokenizer,
    #         "column_name": column_name,
    #         "padding": "max_length" if opt_level == "O4" else True,
    #     },
    # )

    embeds = []
    texts = []

    # last_count keeps track of how many had been embedded at the last push
    last_count = 0
    # current_count keeps track of how many have been embedded in total
    current_count = 0

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    inference_bs = get_batch_size(torch.cuda.get_device_name(0), model_name, opt_level)

    # skip through some examples if specified
    if num2skip > 0:
        ds = ds.skip(num2skip)

    start_time = time.time()
    for batch in DataLoader(
        ds,
        batch_size=inference_bs,
        shuffle=False,
        num_workers=2,
        pin_memory=True,
        drop_last=False,
        collate_fn=partial(
            collate_fn,
            column_name=column_name,
            tokenizer=tokenizer,
            padding="max_length" if opt_level == "O4" else True,
        ),
    ):
        ids = batch["input_ids"].to(device)
        mask = batch["attention_mask"].to(device)
        t_ids = torch.zeros_like(ids)

        outputs = model(input_ids=ids, attention_mask=mask, token_type_ids=t_ids)

        embeds.extend(mean_pooling(outputs[0], mask).cpu().tolist())
        texts.extend(batch[column_name])

        current_count += ids.shape[0]

        # Check if we have embedded enough examples (num2embed of -1 means embed everything)
        if num2embed > 0 and current_count >= num2embed:
            diff = current_count - num2embed
            if diff > 0:
                embeds = embeds[:-diff]
                texts = texts[:-diff]
            current_count = num2embed
            break

        # Periodically upload to the hub
        if len(embeds) > upload_batch_size:
            push_to_repo(new_dataset_id, last_count, current_count, embeds, texts, api)
            embeds = []
            texts = []
            last_count = current_count

        # Provide updates
        if progress is not None:
            progress(
                (current_count, None),
                "Embedding docs...",
                total=None,
                unit="Docs Embedded",
            )

    time_taken = time.time() - start_time

    # If there are any remaining embeddings, upload them
    if len(embeds) > 0:
        push_to_repo(new_dataset_id, last_count, current_count, embeds, texts, api)

    # current_count only counts embedded examples; the skipped ones were never counted
    return current_count, time_taken
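
# End-to-end sketch of how these helpers fit together (the dataset, model, and repo ids
# below are assumptions for illustration, not values hard-coded in the app):
#   model, tokenizer = get_model_and_tokenizer(
#       "sentence-transformers/all-MiniLM-L6-v2", "O2", progress=None
#   )
#   ds = load_hf_dataset("imdb", ds_config="", ds_split="train")
#   count, seconds = batch_embed(
#       ds,
#       model,
#       tokenizer,
#       model_name="sentence-transformers/all-MiniLM-L6-v2",
#       column_name="text",
#       new_dataset_id="my-username/imdb-embeddings",
#       opt_level="O2",
#   )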


def init_git_repo(repo_id: str):
    """
    Initialize a git repo for the new dataset.

    ***Removes the existing local folder if it exists***

    Args:
        repo_id (`str`):
            id of the new dataset to create. Should include username or organization,
            e.g. nbroad/new-embeddings
    """
    local_dir = repo_id.replace("/", "_")

    create_repo(
        repo_id,
        repo_type="dataset",
        token=os.environ["HF_TOKEN"],
        private=True,
        exist_ok=True,
    )
    try:
        repo = Repository(
            local_dir=local_dir,
            clone_from=repo_id,
            repo_type="dataset",
            token=os.environ["HF_TOKEN"],
            skip_lfs_files=True,
        )
    except EnvironmentError:
        shutil.rmtree(local_dir)
        repo = Repository(
            local_dir=local_dir,
            clone_from=repo_id,
            repo_type="dataset",
            token=os.environ["HF_TOKEN"],
            skip_lfs_files=True,
        )

    if repo is not None:
        repo.git_pull()

    return repo


def push_to_repo(
    repo_id: str,
    last_count: int,
    current_count: int,
    embeds: List[List[float]],
    texts: List[str],
    api: HfApi,
):
    """
    Push embeddings to the repo.

    Args:
        repo_id (`str`):
            id of the new dataset to create. Should include username or organization.
        last_count (`int`):
            last count of embeddings.
            This is the number of embeddings that have already been pushed.
        current_count (`int`):
            current count of embeddings.
            This is the number of embeddings that will have been pushed after this batch.
        embeds (`List[List[float]]`):
            list of embeddings to push to the repo
        texts (`List[str]`):
            list of texts to push to the repo
        api (`huggingface_hub.HfApi`):
            api client used to push to the repo
    """
    temp_ds = Dataset.from_dict(
        {
            "embedding": embeds,
            "text": texts,
        }
    )

    local_dir = repo_id.replace("/", "_")
    data_dir = Path(local_dir) / "data"
    data_dir.mkdir(exist_ok=True, parents=True)

    # use zfill so sorting puts the files in order
    filename = f"embeddings_{str(last_count).zfill(8)}_{current_count}.parquet"
    filepath = str(data_dir / filename)

    temp_ds.to_parquet(filepath)

    files = sorted(data_dir.glob("*.parquet"))

    api.upload_file(
        path_or_fileobj=filepath,
        path_in_repo=f"data/{filename}",
        repo_id=repo_id,
        repo_type="dataset",
        run_as_future=True,
        token=os.environ["HF_TOKEN"],
        commit_message=f"Embedded examples {last_count} thru {current_count}",
    )

    # Delete the two oldest local parquet files once more than four have accumulated
    if len(files) > 4:
        for file in files[:2]:
            file.unlink()