File size: 3,318 Bytes
be8596b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
#!/usr/bin/env python3
#
# Copyright 2020 Institute of Formal and Applied Linguistics, Faculty of
# Mathematics and Physics, Charles University, Czech Republic.
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.

import sys
import zipfile

import numpy as np

import wembeddings.wembeddings as wembeddings

def load_sentences(input_path, input_format):
    """Read a CoNLL or CoNLL-U file and return its sentences as lists of words.

    Sentences are separated by one or more empty lines. A new sentence is
    opened by the first non-empty line of a block, whatever that line is, so
    a block that contributes no words still yields an (empty) sentence.

    Args:
        input_path: Path of the UTF-8 encoded, tab-separated input file.
        input_format: Either "conll" or "conllu".
            - "conll": the first column of every non-empty line is a word.
            - "conllu": only lines whose first column consists of digits are
              words (so comment lines and ids such as "1-2" or "1.1" are
              skipped); the word is taken from the second column.

    Returns:
        A list of sentences, each a list of word strings.

    Raises:
        ValueError: If `input_format` is not supported, or a CoNLL-U word
            line does not have exactly 10 tab-separated columns.
    """
    if input_format not in ("conll", "conllu"):
        raise ValueError("Unsupported input format {!r}, expected 'conll' or 'conllu'.".format(input_format))

    sentences = []
    in_sentence = False
    with open(input_path, mode="r", encoding="utf-8") as input_file:
        for line_number, line in enumerate(input_file, start=1):
            line = line.rstrip("\n")
            if not line:
                # An empty line terminates the current sentence (if any).
                in_sentence = False
                continue

            if not in_sentence:
                sentences.append([])
                in_sentence = True

            columns = line.split("\t")
            if input_format == "conll":
                sentences[-1].append(columns[0])
            elif columns[0].isdigit():
                # Explicit check instead of `assert`, which is stripped under -O.
                if len(columns) != 10:
                    raise ValueError("{}:{}: expected 10 tab-separated columns, found {}.".format(
                        input_path, line_number, len(columns)))
                sentences[-1].append(columns[1])
    return sentences


if __name__ == "__main__":
    import argparse

    # Parse arguments
    parser = argparse.ArgumentParser()
    parser.add_argument("input_path", type=str, help="Input file")
    parser.add_argument("output_npz", type=str, help="Output NPZ file")
    parser.add_argument("--batch_size", default=64, type=int, help="Batch size")
    parser.add_argument("--dtype", default="float16", type=str, help="Dtype to save as")
    # `choices` makes argparse reject unsupported formats with a usage message.
    parser.add_argument("--format", default="conllu", type=str, choices=["conllu", "conll"],
                        help="Input format (conllu, conll)")
    parser.add_argument("--model", default="bert-base-multilingual-uncased-last4", type=str, help="Model name (see wembeddings.py for options)")
    parser.add_argument("--server", default=None, type=str, help="Use given server to compute the embeddings")
    parser.add_argument("--threads", default=4, type=int, help="Threads to use")
    args = parser.parse_args()

    # Resolve the dtype name through the numpy namespace, as before, but report
    # unknown names (or numpy attributes that are not dtypes) as a CLI error.
    try:
        args.dtype = np.dtype(getattr(np, args.dtype))
    except (AttributeError, TypeError):
        parser.error("unknown NumPy dtype: {!r}".format(args.dtype))

    # Load the input file
    sentences = load_sentences(args.input_path, args.format)
    print("Loaded {} sentences and {} words.".format(len(sentences), sum(map(len, sentences))), file=sys.stderr, flush=True)

    # Initialize suitable computational class: a network client when a server
    # is given, a local instance otherwise. A separate name is used so that the
    # imported `wembeddings` module is not shadowed.
    if args.server is not None:
        embedder = wembeddings.WEmbeddings.ClientNetwork(args.server)
    else:
        embedder = wembeddings.WEmbeddings(threads=args.threads)

    # Compute word embeddings batch by batch and write the result for the
    # sentence with global index k as an uncompressed ZIP entry "arr_k",
    # serialized with `np.save`.
    with zipfile.ZipFile(args.output_npz, mode="w", compression=zipfile.ZIP_STORED) as output_npz:
        for batch_start in range(0, len(sentences), args.batch_size):
            batch = sentences[batch_start:batch_start + args.batch_size]
            # NOTE(review): presumably yields one array per input sentence, in
            # order -- confirm against wembeddings.WEmbeddings.compute_embeddings.
            batch_embeddings = embedder.compute_embeddings(args.model, batch)
            for index, sentence_embeddings in enumerate(batch_embeddings, start=batch_start):
                with output_npz.open("arr_{}".format(index), mode="w") as embeddings_file:
                    np.save(embeddings_file, sentence_embeddings.astype(args.dtype))
                if (index + 1) % 100 == 0:
                    print("Processed {}/{} sentences.".format(index + 1, len(sentences)), file=sys.stderr, flush=True)
    print("Done, all embeddings saved.", file=sys.stderr, flush=True)