# SearchMesh / tag.py
# Commit 4709571 by nsorros: "Add search app" (879 bytes)
import itertools
import json

import srsly
import typer
from transformers import AutoModel, AutoTokenizer
def load_data(data_path, sample_size):
    """Read a JSON document from ``data_path`` and return the parsed object.

    Args:
        data_path: Path of the JSON file to read.
        sample_size: Not used by this function; the whole document is
            always returned regardless of its value.

    Returns:
        Whatever the JSON document decodes to (dict, list, ...).

    Raises:
        OSError: If the file cannot be opened.
        UnicodeDecodeError: If the file is not valid UTF-8.
        json.JSONDecodeError: If the content is not valid JSON.
    """
    # JSON text is UTF-8; pin the encoding so decoding does not depend on
    # the platform's locale default.
    with open(data_path, encoding="utf-8") as f:
        return json.load(f)
def tag(data_path, tagged_data_path, sample_size: int = 10):
    """Tag the leading records of a JSONL file with WellcomeBertMesh.

    Reads records from ``data_path``, runs the "Wellcome/WellcomeBertMesh"
    model over each record's ``title_and_description`` value, stores the
    model output for that record under a ``"tags"`` key, and writes the
    records to ``tagged_data_path`` as JSONL.

    Args:
        data_path: Input JSONL file; every sampled record must have a
            ``title_and_description`` key.
        tagged_data_path: Output JSONL file.
        sample_size: Maximum number of leading records to tag. When the
            input holds fewer records, all of them are tagged; a value of
            zero or less tags nothing.

    Raises:
        KeyError: If a sampled record has no ``title_and_description`` key.
    """
    # islice stops quietly at end of input, whereas calling next() exactly
    # sample_size times raised StopIteration on files with fewer records.
    # Clamp at 0 because islice rejects negative stops.
    records = list(
        itertools.islice(srsly.read_jsonl(data_path), max(sample_size, 0))
    )
    if not records:
        # Nothing to tag: write an empty output and skip loading the model.
        srsly.write_jsonl(tagged_data_path, records)
        return

    tokenizer = AutoTokenizer.from_pretrained("Wellcome/WellcomeBertMesh")
    # trust_remote_code=True lets transformers run the model code shipped in
    # the Hub repository; only keep it for repositories you trust.
    model = AutoModel.from_pretrained(
        "Wellcome/WellcomeBertMesh", trust_remote_code=True
    )

    texts = [record["title_and_description"] for record in records]
    # NOTE(review): truncation=True clips texts longer than the tokenizer's
    # maximum length so every row matches the padded length; without it an
    # over-long text yields a longer row than the rest of the batch.
    inputs = tokenizer(texts, padding="max_length", truncation=True)
    # Presumably one collection of labels per input text, in input order --
    # confirm against the model card.
    labels = model(**inputs, return_labels=True)

    for record, tags in zip(records, labels):
        record["tags"] = tags
    srsly.write_jsonl(tagged_data_path, records)
if __name__ == "__main__":
    # Script entry point: hand `tag` to typer.run when executed directly.
    typer.run(tag)