Slow initialization

#6
by drmeir - opened

I have a long input and would like to split it in order to process in distributed manner. To test how it would affect performance, I ran:

timer = Timer()  # timing helper from the calling code
n_parts = 100
part = len(text) // n_parts  # length of each chunk
for i in range(n_parts):
    begin = i * part
    end = begin + part
    # one inference call per chunk; only the total elapsed time matters here
    results = m.infer([text[begin:end]])
print(timer.stop())

In this code, I split the text into 100 parts and run inference on each part separately. This takes almost twice as long as performing inference on the whole text at once, which suggests to me that initialization takes significant time. Is there a way to make this more efficient?
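The comparison I'm describing looks roughly like this, written with time.perf_counter in place of my Timer helper (a minimal sketch assuming m and text are defined as above):

import time

n_parts = 100
part = len(text) // n_parts

# Baseline: one pass over the whole text.
t0 = time.perf_counter()
m.infer([text])
t_whole = time.perf_counter() - t0

# Chunked: n_parts separate calls, one chunk each.
t0 = time.perf_counter()
for i in range(n_parts):
    m.infer([text[i * part:(i + 1) * part]])
t_chunks = time.perf_counter() - t0

print(f"whole: {t_whole:.2f}s  chunked: {t_chunks:.2f}s")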

There's a lot to consider here.

First, if you're using the punctuators package, you're certainly running on CPU. I wouldn't have any expectations of good performance in any case. I wrote that package simply to demonstrate how to use the models; everyone's use case differs and I presumed others would optimize for their own case, as needed.

Second, with multiple inputs, it's best to batch, even on CPU: something like m.infer([text[0:10], text[10:20], ..., text[90:100]]).
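As a rough sketch of that batched call (reusing m, text, and n_parts from your snippet; the chunk size is just for illustration):

n_parts = 100
part = len(text) // n_parts

# Build all chunks up front...
chunks = [text[i * part:(i + 1) * part] for i in range(n_parts)]

# ...and run them through the model in a single batched call
# instead of n_parts separate infer() calls.
results = m.infer(chunks)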

Third, this model is a Transformer, which uses global scaled dot-product attention, so running time is quadratic in sequence length. A single input of length N costs roughly N^2, while N//n inputs of length n cost roughly (N//n) * n^2 = N*n, so splitting the input should actually make it faster whenever n < N.
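To put rough numbers on that (a back-of-the-envelope sketch that ignores constant factors and per-call overhead):

N = 10_000  # length of the whole text, in tokens
n = 100     # length of each chunk

cost_single  = N ** 2             # one pass over the full sequence
cost_chunked = (N // n) * n ** 2  # N//n passes over length-n chunks

print(cost_single, cost_chunked)  # 100000000 vs 1000000, i.e. 100x less attention work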

The "initialization time" here is presumably batching vs. multiple passes through the graph.

My Python backend runs on PythonAnywhere. To speed up segmentation of a long text, I submit each part of the text as a separate request. All requests are submitted in parallel, and each request is handled by its own web worker, with all web workers doing their work in parallel. My test script measures the total time the web workers would spend if I split the long text into 100 parts and gave them to 100 web workers.
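Schematically, the fan-out looks something like this (the endpoint URL, payload shape, and use of requests/ThreadPoolExecutor here are placeholders, not the actual backend code):

import requests
from concurrent.futures import ThreadPoolExecutor

API_URL = "https://example.pythonanywhere.com/segment"  # hypothetical endpoint

def segment_part(part_text):
    # Each part goes out as its own request, handled by a separate web worker.
    resp = requests.post(API_URL, json={"text": part_text})
    resp.raise_for_status()
    return resp.json()

n_parts = 100
size = len(text) // n_parts
parts = [text[i * size:(i + 1) * size] for i in range(n_parts)]

# Submit all parts in parallel and collect the results in order.
with ThreadPoolExecutor(max_workers=n_parts) as pool:
    results = list(pool.map(segment_part, parts))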

Regarding the O(n^2) point: I modified the code in the question to make the number of parts a variable (n_parts). What, then, explains my script running twice as slowly for n_parts=100 as for n_parts=1, when it should actually get faster?
