import argparse
import json
import os
from typing import Optional, Tuple

import torch
from datasets import Dataset, load_dataset
from tqdm.auto import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer


def check_base_path(path: str) -> Optional[str]:
    """Validate that *path* points into an existing directory.

    Used as an argparse ``type=`` callback for the ``--output`` option, so it
    must either return the (unchanged) value or raise.

    Args:
        path: Candidate output-file path (may be ``None``).

    Returns:
        The same ``path`` when its parent directory exists (an empty dirname
        means the current working directory, which always exists), or ``None``
        when ``path`` is ``None``.

    Raises:
        Exception: If the parent directory of ``path`` does not exist.
    """
    if path is not None:
        # BUG FIX: the original checked os.path.basename(path) with
        # os.path.exists, which demanded that the output *file* already exist
        # in the CWD — so the default fresh output path always raised.
        # The intent is to validate the directory we will write into.
        parent_dir = os.path.dirname(path)
        if parent_dir and not os.path.exists(parent_dir):
            raise Exception(f'Path not found {parent_dir}')
    return path


def parse_args() -> argparse.Namespace:
    """Build and parse the command-line arguments for this test script.

    Returns:
        The parsed ``argparse.Namespace`` with model/dataset/delimiter/output
        settings.
    """
    DEFAULT_MODEL_ID = 'EmbeddingStudio/query-parser-falcon-7b-instruct'
    DEFAULT_DATASET = 'EmbeddingStudio/query-parsing-instructions-falcon'
    DEFAULT_SPLIT = 'test'
    DEFAULT_INSTRUCTION_FIELD = 'text'
    DEFAULT_RESPONSE_DELIMITER = '## Response:\n'
    DEFAULT_CATEGORY_DELIMITER = '## Category:'
    DEFAULT_OUTPUT_PATH = f'{DEFAULT_MODEL_ID.split("/")[-1]}-test.json'

    parser = argparse.ArgumentParser(
        description='EmbeddingStudio script for testing Zero-Shot Search Query Parsers'
    )
    parser.add_argument(
        "--model-id",
        help=f"Huggingface model ID (default: {DEFAULT_MODEL_ID})",
        default=DEFAULT_MODEL_ID,
        type=str,
    )
    parser.add_argument(
        "--dataset-name",
        help=f"Huggingface dataset name which contains instructions (default: {DEFAULT_DATASET})",
        default=DEFAULT_DATASET,
        type=str,
    )
    parser.add_argument(
        "--dataset-split",
        help=f"Huggingface dataset split name (default: {DEFAULT_SPLIT})",
        default=DEFAULT_SPLIT,
        type=str,
    )
    parser.add_argument(
        "--dataset-instructions-field",
        help=f"Huggingface dataset field with instructions (default: {DEFAULT_INSTRUCTION_FIELD})",
        default=DEFAULT_INSTRUCTION_FIELD,
        type=str,
    )
    parser.add_argument(
        "--instructions-response-delimiter",
        help=f"Instruction response delimiter (default: {DEFAULT_RESPONSE_DELIMITER})",
        default=DEFAULT_RESPONSE_DELIMITER,
        type=str,
    )
    parser.add_argument(
        "--instructions-category-delimiter",
        help=f"Instruction category name delimiter (default: {DEFAULT_CATEGORY_DELIMITER})",
        default=DEFAULT_CATEGORY_DELIMITER,
        type=str,
    )
    parser.add_argument(
        "--output",
        help=f"JSON file with test results (default: {DEFAULT_OUTPUT_PATH})",
        default=DEFAULT_OUTPUT_PATH,
        # Validates the output directory up front so the script fails fast
        # instead of after a long generation run.
        type=check_base_path,
    )
    return parser.parse_args()


def load_model(model_id: str) -> Tuple[AutoTokenizer, AutoModelForCausalLM]:
    """Load the tokenizer and causal-LM model for *model_id*.

    The model is placed entirely on GPU 0 via ``device_map={"": 0}``.

    Args:
        model_id: Huggingface Hub model identifier.

    Returns:
        A ``(tokenizer, model)`` pair.
    """
    tokenizer = AutoTokenizer.from_pretrained(
        model_id,
        trust_remote_code=True,
        add_prefix_space=True,
        use_fast=False,
    )
    # The model has no dedicated pad token; reuse EOS so batching/generation
    # does not warn or fail.
    tokenizer.pad_token = tokenizer.eos_token
    model = AutoModelForCausalLM.from_pretrained(model_id, device_map={"": 0})
    return tokenizer, model


@torch.no_grad()
def predict(
    tokenizer: AutoTokenizer,
    model: AutoModelForCausalLM,
    dataset: Dataset,
    index: int,
    field_name: str = 'text',
    response_delimiter: str = '## Response:\n',
    category_delimiter: str = '## Category: ',
) -> Tuple[dict, dict, str]:
    """Run the model on one instruction example and parse its JSON answer.

    Args:
        tokenizer: Tokenizer matching ``model``.
        model: Causal LM used for generation.
        dataset: Split of instruction examples; each row's ``field_name``
            column holds the full instruction text including the gold answer.
        index: Row index inside ``dataset``.
        field_name: Column containing the instruction text.
        response_delimiter: Marker separating the prompt from the gold/
            generated JSON response.
        category_delimiter: Marker preceding the category name in the prompt.

    Returns:
        ``(parsed, real, category)`` — the model's parsed JSON output, the
        gold JSON answer, and the example's category string.

    Raises:
        json.JSONDecodeError: If the gold answer or the generated text after
            the response delimiter is not valid JSON.
    """
    example_text = dataset[index][field_name]
    # Prompt = everything before the gold response, with the delimiter
    # re-appended so the model continues right after it.
    input_text = example_text.split(response_delimiter)[0] + response_delimiter
    input_ids = tokenizer.encode(input_text, return_tensors='pt')

    # Gold answer and category are embedded in the same instruction text.
    real = json.loads(example_text.split(response_delimiter)[-1])
    category = example_text.split(category_delimiter)[-1].split('\n')[0]

    # Near-greedy sampling: very low temperature keeps output deterministic
    # enough to be parseable while avoiding degenerate greedy loops.
    output = model.generate(
        input_ids.to(model.device),
        max_new_tokens=1000,
        do_sample=True,
        temperature=0.05,
        pad_token_id=50256,
    )
    parsed = json.loads(
        tokenizer.decode(output[0], skip_special_tokens=True)
        .split(response_delimiter)[-1]
    )
    # BUG FIX: the original returned a list despite the Tuple annotation.
    return parsed, real, category


@torch.no_grad()
def test_model(
    model_id: str,
    dataset_name: str,
    split_name: str,
    field_name: str,
    response_delimiter: str,
    category_delimiter: str,
    output_path: str,
) -> None:
    """Evaluate *model_id* on every example of a dataset split and save results.

    Examples whose generation or JSON parsing fails are skipped (best-effort
    evaluation); the number of skipped examples is reported at the end.

    Args:
        model_id: Huggingface model to test.
        dataset_name: Huggingface dataset with instruction examples.
        split_name: Which split of the dataset to evaluate.
        field_name: Dataset column holding the instruction text.
        response_delimiter: Marker separating prompt from response.
        category_delimiter: Marker preceding the category name.
        output_path: Where to write the JSON list of
            ``(parsed, real, category)`` results.
    """
    # BUG FIX: load_dataset(..., split=...) already returns a single Dataset,
    # not a DatasetDict — the original indexed it again with dataset[split_name],
    # which fails. Iterate the returned split directly.
    dataset = load_dataset(dataset_name, split=split_name)
    tokenizer, model = load_model(model_id)
    model.eval()

    test_results = []
    skipped = 0
    for index in tqdm(range(len(dataset))):
        try:
            test_results.append(
                predict(
                    tokenizer,
                    model,
                    dataset,
                    index,
                    field_name,
                    response_delimiter,
                    category_delimiter,
                )
            )
        except Exception:
            # Best-effort: a malformed generation should not abort the run,
            # but we no longer hide how often it happens.
            skipped += 1
            continue

    if skipped:
        print(f'Skipped {skipped} example(s) due to generation/parsing errors')

    with open(output_path, 'w') as f:
        # BUG FIX: json.dump requires the file object; the original call
        # json.dump(test_results) raised TypeError and never wrote results.
        json.dump(test_results, f)


if __name__ == '__main__':
    args = parse_args()
    test_model(
        args.model_id,
        args.dataset_name,
        args.dataset_split,
        args.dataset_instructions_field,
        args.instructions_response_delimiter,
        args.instructions_category_delimiter,
        args.output,
    )