# Spaces: Sleeping
import json
from itertools import islice

import srsly
import typer
from tqdm import tqdm
from transformers import AutoModel, AutoTokenizer
def load_data(data_path, sample_size):
    """Read the JSON document at ``data_path`` and return the parsed object.

    Args:
        data_path: Path of the JSON file to read.
        sample_size: Accepted for interface compatibility only. This function
            never reads it, so the whole document is always returned.

    Returns:
        The Python object decoded from the file's JSON content.
    """
    # JSON text is UTF-8 by specification, so do not rely on the platform's
    # default locale encoding when decoding the file.
    with open(data_path, encoding="utf-8") as f:
        return json.load(f)
def tag(
    data_path: str,
    tagged_data_path: str,
    sample_size: int = 10,
    batch_size: int = 10,
) -> None:
    """Tag the first records of a JSONL file and write them back out.

    Reads at most ``sample_size`` records from ``data_path``, passes each
    record's ``"title_and_description"`` text through the
    ``Wellcome/WellcomeBertMesh`` model in batches, stores the model's
    per-text output under the record's ``"tags"`` key, and writes the
    records to ``tagged_data_path`` as JSONL.

    Args:
        data_path: Path of the input JSONL file. Every sampled record must
            have a ``"title_and_description"`` key.
        tagged_data_path: Path of the JSONL file to write.
        sample_size: Maximum number of leading records to process. If the
            file holds fewer records, all of them are processed.
        batch_size: Number of texts sent to the model per call.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
        KeyError: If a sampled record has no ``"title_and_description"``.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be a positive integer")

    # islice stops cleanly at end of input; calling next() a fixed number of
    # times raised StopIteration whenever the file was shorter than the
    # requested sample. A negative sample_size keeps yielding no records.
    data = list(islice(srsly.read_jsonl(data_path), max(sample_size, 0)))

    tokenizer = AutoTokenizer.from_pretrained("Wellcome/WellcomeBertMesh")
    # NOTE: trust_remote_code=True runs Python code shipped in the model
    # repository; it is required here because the model class is custom.
    model = AutoModel.from_pretrained(
        "Wellcome/WellcomeBertMesh", trust_remote_code=True
    )

    texts = [grant["title_and_description"] for grant in data]
    for batch_index in tqdm(range(0, len(texts), batch_size)):
        batch_texts = texts[batch_index : batch_index + batch_size]
        # NOTE(review): no truncation is requested, so very long texts may
        # exceed the model's maximum input length — confirm against the data.
        inputs = tokenizer(batch_texts, padding="max_length")
        labels = model(**inputs, return_labels=True)
        # labels is iterated in step with batch_texts: one entry per text.
        for i, tags in enumerate(labels):
            data[batch_index + i]["tags"] = tags

    srsly.write_jsonl(tagged_data_path, data)
if __name__ == "__main__":
    # Expose tag() as the script's command-line entry point via typer.
    typer.run(tag)