SearchMesh / tag.py
nsorros's picture
Update tagged grants
fd5a1b3
import json
from itertools import islice

import srsly
import typer
from tqdm import tqdm
from transformers import AutoModel, AutoTokenizer
def load_data(data_path, sample_size):
    """Load and return the JSON document stored at *data_path*.

    Args:
        data_path: Path of the JSON file to read.
        sample_size: Accepted for interface compatibility; this function
            does not use it and always returns the whole document.

    Returns:
        The decoded JSON value.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file does not contain valid JSON.
    """
    # Open explicitly as UTF-8 so decoding does not depend on the
    # platform's default locale encoding.
    with open(data_path, encoding="utf-8") as f:
        return json.load(f)
def tag(
    data_path,
    tagged_data_path,
    sample_size: int = 10,
    batch_size: int = 10,
):
    """Tag the first records of a JSONL file and write them back out.

    Reads up to ``sample_size`` records from ``data_path``, runs the
    ``Wellcome/WellcomeBertMesh`` model over each record's
    ``"title_and_description"`` text in batches, stores the model output
    for each record under its ``"tags"`` key, and writes all sampled
    records to ``tagged_data_path`` as JSONL.

    Args:
        data_path: Path of the input JSONL file. Each record must have a
            ``"title_and_description"`` key.
        tagged_data_path: Path of the JSONL file to write.
        sample_size: Maximum number of records to read from the start of
            the input. Fewer are used if the file is shorter.
        batch_size: Number of texts sent to the model per call.

    Raises:
        ValueError: If ``batch_size`` is less than 1.
        KeyError: If a record lacks ``"title_and_description"``.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    # islice stops cleanly when the file holds fewer than sample_size
    # records, where repeated next() calls would raise StopIteration.
    data = list(islice(srsly.read_jsonl(data_path), max(sample_size, 0)))

    tokenizer = AutoTokenizer.from_pretrained("Wellcome/WellcomeBertMesh")
    model = AutoModel.from_pretrained(
        "Wellcome/WellcomeBertMesh", trust_remote_code=True
    )

    for batch_start in tqdm(range(0, len(data), batch_size)):
        batch = data[batch_start:batch_start + batch_size]
        batch_texts = [grant["title_and_description"] for grant in batch]
        inputs = tokenizer(batch_texts, padding="max_length")
        # NOTE(review): assumes the model returns one entry per input
        # text, in input order, when return_labels=True -- confirm
        # against the model card.
        labels = model(**inputs, return_labels=True)
        # Pair each label set with the record from *this* batch; indexing
        # ``data`` by the within-batch position would overwrite the first
        # batch's records on every later batch.
        for grant, tags in zip(batch, labels):
            grant["tags"] = tags

    srsly.write_jsonl(tagged_data_path, data)
# Script entry point: when executed directly (not imported), hand `tag`
# to typer so it is invoked from the command line.
if __name__ == "__main__":
    typer.run(tag)