import os
import sys

import datasets
import eval_utils
import utils
from constants import DIALECTS_WITH_LABELS
from huggingface_hub import login
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    AutoModelForSequenceClassification,
)

# Authenticate to the Hugging Face Hub with the token from the environment.
access_token = os.environ["HF_TOKEN"]
login(token=access_token)

# The model under evaluation is identified by its name, a pinned commit, and
# the eval_utils inference function to run it with (passed as CLI arguments).
model_name = sys.argv[1]
commit_id = sys.argv[2]
inference_function = sys.argv[3]

# Mark the model as in progress in the shared evaluation queue.
utils.update_model_queue(
    repo_id=os.environ["PREDICTIONS_DATASET_NAME"],
    model_name=model_name,
    commit_id=commit_id,
    inference_function=inference_function,
    status="in_progress",
)

try:
    tokenizer = AutoTokenizer.from_pretrained(model_name, revision=commit_id)

    # Chat-style prompting needs a causal LM; otherwise assume a classifier head.
    if inference_function == "prompt_chat_LLM":
        model = AutoModelForCausalLM.from_pretrained(model_name, revision=commit_id)
    else:
        model = AutoModelForSequenceClassification.from_pretrained(
            model_name, revision=commit_id
        )

    # Load the test split of the evaluation dataset.
    dataset_name = os.environ["DATASET_NAME"]
    dataset = datasets.load_dataset(dataset_name)["test"]
    sentences = dataset["sentence"]
    # Gold labels for each annotated dialect (unused in this script).
    labels = {dialect: dataset[dialect] for dialect in DIALECTS_WITH_LABELS}

    # Run the selected inference function on each sentence, reporting progress.
    predictions = []
    for i, sentence in enumerate(sentences):
        predictions.append(
            getattr(eval_utils, inference_function)(model, tokenizer, sentence)
        )
        print(
            f"Inference progress ({model_name}, {inference_function}): "
            f"{round(100 * (i + 1) / len(sentences), 1)}%"
        )

    # Store the predictions in a private dataset.
    utils.upload_predictions(
        os.environ["PREDICTIONS_DATASET_NAME"],
        predictions,
        model_name,
        commit_id,
        inference_function,
    )
    print("Inference completed!")

except Exception as e:
    # On any failure, report the error and flag the model in the queue so the
    # run can be inspected or retried.
    print(f"An error occurred during inference of {model_name}: {e}")
    utils.update_model_queue(
        repo_id=os.environ["PREDICTIONS_DATASET_NAME"],
        model_name=model_name,
        commit_id=commit_id,
        inference_function=inference_function,
        status="failed (online)",
    )
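
# A minimal sketch of how this script is expected to be invoked; the script
# filename and the concrete environment values below are assumptions for
# illustration, not part of the repository:
#
#   HF_TOKEN=hf_xxx \
#   DATASET_NAME=org/eval-dataset \
#   PREDICTIONS_DATASET_NAME=org/predictions \
#   python run_inference.py <model_name> <commit_id> prompt_chat_LLM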