SearchMesh / tag.py
nsorros's picture
Tag more grants and implement most common
b493a01
raw history blame
No virus
1.05 kB
import json
from itertools import islice

import srsly
import typer
from tqdm import tqdm
from transformers import AutoModel, AutoTokenizer
def load_data(data_path, sample_size):
    """Load and return the JSON document stored at ``data_path``.

    Args:
        data_path: Path of the JSON file to read.
        sample_size: Accepted for interface compatibility but not used;
            the whole document is always returned.

    Returns:
        The decoded top-level JSON value held by the file.
    """
    # Decode explicitly as UTF-8 instead of relying on the platform's locale
    # default, which can mis-decode non-ASCII text on some systems.
    with open(data_path, encoding="utf-8") as f:
        return json.load(f)
def tag(
    data_path: str,
    tagged_data_path: str,
    sample_size: int = 10,
    batch_size: int = 10,
):
    """Tag the first grants of a JSONL file and write them to a new JSONL file.

    Reads up to ``sample_size`` records from ``data_path``, runs the
    "Wellcome/WellcomeBertMesh" model over each record's
    "title_and_description" field in batches, stores what the model returns
    for each record under a "tags" key, and writes the records to
    ``tagged_data_path``.

    Args:
        data_path: Path of the input JSONL file; every record must contain a
            "title_and_description" key.
        tagged_data_path: Path of the JSONL file to write.
        sample_size: Maximum number of records to read and tag. If the file
            holds fewer records, all of them are tagged.
        batch_size: Number of texts passed to the model per call.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
        KeyError: If a record lacks the "title_and_description" key.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be a positive integer")

    # islice stops at the end of the input, so a file with fewer than
    # sample_size records is processed in full instead of raising
    # StopIteration.
    data = list(islice(srsly.read_jsonl(data_path), sample_size))

    tokenizer = AutoTokenizer.from_pretrained("Wellcome/WellcomeBertMesh")
    model = AutoModel.from_pretrained(
        "Wellcome/WellcomeBertMesh", trust_remote_code=True
    )

    texts = [grant["title_and_description"] for grant in data]
    for batch_start in tqdm(range(0, len(texts), batch_size)):
        batch_end = batch_start + batch_size
        # truncation=True keeps over-long texts within the tokenizer's maximum
        # length. NOTE(review): assumes the tokenizer config defines that
        # maximum, which padding="max_length" already relies on -- confirm.
        inputs = tokenizer(
            texts[batch_start:batch_end], padding="max_length", truncation=True
        )
        # The code relies on the model returning one entry per input text, in
        # input order, when called with return_labels=True.
        labels = model(**inputs, return_labels=True)
        for grant, tags in zip(data[batch_start:batch_end], labels):
            grant["tags"] = tags

    srsly.write_jsonl(tagged_data_path, data)
if __name__ == "__main__":
    # When executed as a script, hand `tag` to typer so that its parameters
    # are taken from the command line.
    typer.run(tag)