File size: 1,053 Bytes
4709571
 
 
fd5a1b3
4709571
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fd5a1b3
b493a01
4709571
fd5a1b3
 
 
 
b493a01
4709571
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
import json
from itertools import islice

import srsly
import typer
from tqdm import tqdm
from transformers import AutoModel, AutoTokenizer


def load_data(data_path, sample_size):
    """Load and return the parsed contents of a JSON file.

    Args:
        data_path: Path of the JSON file to read.
        sample_size: Accepted for interface compatibility only; it is not
            applied here, so the whole parsed document is returned.

    Returns:
        The deserialized JSON document.
    """
    # Pin the encoding so decoding does not depend on the platform's
    # locale default (JSON interchange is UTF-8).
    with open(data_path, encoding="utf-8") as f:
        return json.load(f)


def tag(
    data_path: str,
    tagged_data_path: str,
    sample_size: int = 10,
    batch_size: int = 10,
):
    """Tag the first grants of a JSONL file with WellcomeBertMesh labels.

    Reads up to ``sample_size`` records from ``data_path``, runs each record's
    ``title_and_description`` text through the ``Wellcome/WellcomeBertMesh``
    model, stores the model output for each record under its ``tags`` key and
    writes the records to ``tagged_data_path`` as JSONL.

    Args:
        data_path: JSONL file of grant records to read.
        tagged_data_path: JSONL file the tagged records are written to.
        sample_size: Maximum number of records to read from the input.
        batch_size: Number of texts passed to the model per call.

    Raises:
        ValueError: If ``batch_size`` is not a positive integer.
        KeyError: If a record has no ``title_and_description`` field.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be a positive integer")

    # islice stops quietly at end of file; calling next() directly raised
    # StopIteration when the file held fewer than sample_size records.
    # max() keeps a negative sample_size meaning "no records", as before.
    data = list(islice(srsly.read_jsonl(data_path), max(sample_size, 0)))

    tokenizer = AutoTokenizer.from_pretrained("Wellcome/WellcomeBertMesh")
    model = AutoModel.from_pretrained(
        "Wellcome/WellcomeBertMesh", trust_remote_code=True
    )

    texts = [grant["title_and_description"] for grant in data]
    for batch_index in tqdm(range(0, len(texts), batch_size)):
        batch_texts = texts[batch_index : batch_index + batch_size]

        # Without truncation, a text longer than the tokenizer's maximum
        # length is left over-long despite padding="max_length".
        # NOTE(review): presumably the model cannot accept such inputs --
        # confirm against the model card.
        inputs = tokenizer(batch_texts, padding="max_length", truncation=True)
        labels = model(**inputs, return_labels=True)

        # labels is iterated in input order, one entry per text in the batch.
        for i, tags in enumerate(labels):
            data[batch_index + i]["tags"] = tags

    srsly.write_jsonl(tagged_data_path, data)


# When executed as a script, hand `tag` to typer so it runs as a command-line
# program.
if __name__ == "__main__":
    typer.run(tag)