nb-linguistic-quality-evaluator / run_regressor_bert.py
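"""Run a BERT-based linguistic quality regressor over a jsonlines file.

Each input line is expected to be a JSON object containing a text column
("text" by default). For every line the script appends two fields,
"<prefix>_score" (the raw regression output) and "<prefix>_int_score"
(rounded and clamped to the 0-5 range), and writes results incrementally
so an interrupted run can be resumed.

Example invocation (the file paths below are illustrative only):

    python run_regressor_bert.py \
        --input_file data/corpus.jsonl \
        --output_file data/corpus_scored.jsonl
"""
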
import argparse
import os

import jsonlines
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from datasets import Dataset
from tqdm import tqdm


def main(args):
    # Load the tokenizer and the regression model; bfloat16 keeps memory use low
    tokenizer = AutoTokenizer.from_pretrained(args.model_name)
    model = AutoModelForSequenceClassification.from_pretrained(args.model_name, torch_dtype=torch.bfloat16)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    print(f"Running on {device}")

    # Load local jsonlines file
    with jsonlines.open(args.input_file) as reader:
        data = [line for line in reader]

    # Convert list of dictionaries to dictionary of lists
    data_dict = {key: [d[key] for d in data] for key in data[0]}
    dataset = Dataset.from_dict(data_dict)

    # Check how many lines have already been written to the output file
    if os.path.exists(args.output_file):
        with open(args.output_file, "r") as f:
            existing_lines = sum(1 for _ in f)
        print(f"Skipping {existing_lines} already processed lines.")
    else:
        existing_lines = 0

    # Skip already processed lines
    if existing_lines > 0:
        dataset = dataset.select(range(existing_lines, len(dataset)))

    def compute_scores(batch):
        # Tokenize the batch, run the regressor, and attach raw and clamped integer scores
        inputs = tokenizer(
            batch[args.text_column],
            return_tensors="pt",
            padding="longest",
            truncation=True,
            max_length=args.max_length,
        ).to(device)
        with torch.no_grad():
            outputs = model(**inputs)
        logits = outputs.logits.squeeze(-1).float().cpu().numpy()
        prefix = args.prefix
        batch[f"{prefix}_score"] = logits.tolist()
        batch[f"{prefix}_int_score"] = [int(round(max(0, min(score, 5)))) for score in logits]
        return batch

    # Process and write each batch incrementally
    with jsonlines.open(args.output_file, mode="a") as writer:
        for batch in tqdm(dataset.iter(batch_size=args.batch_size), total=(len(dataset) + args.batch_size - 1) // args.batch_size):
            processed_batch = compute_scores(batch)
            # Convert the column-oriented batch back into one JSON object per row
            writer.write_all([dict(zip(batch.keys(), vals)) for vals in zip(*processed_batch.values())])


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_name", type=str, default="north/scandinavian_linguistic_classifier_bert")
    parser.add_argument("--input_file", type=str, required=True, help="Path to the input jsonlines file")
    parser.add_argument("--output_file", type=str, required=True, help="Path to save the output jsonlines file")
    parser.add_argument("--text_column", type=str, default="text", help="Name of the text column in the input file")
    parser.add_argument("--max_length", type=int, default=512, help="Maximum sequence length for tokenization")
    parser.add_argument("--batch_size", type=int, default=1024, help="Batch size for processing")
    parser.add_argument("--prefix", type=str, default="ling", help="Prefix for the score fields")
    args = parser.parse_args()
    main(args)