Fails when using multi-threading and a CUDA device. SOLVED

#3
by CoderCowMoo - opened

Quan mate, I've spent about 10-20 hours digging through the modelling file, the PyTorch issues and code, and the Transformers documentation and code to figure out why the example code in the README.md doesn't work in a Gradio demo.

It turns out Gradio uses separate threads to execute the functions tied to inputs and outputs.
It also turns out that torch.set_default_device doesn't carry across threads in PyTorch <= 2.2.2.
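
For anyone else hitting this, here's a minimal sketch of the failure mode (my own repro, assuming a CUDA machine and PyTorch <= 2.2.2):

```python
import threading

import torch

# The README example sets the default device in the main thread.
torch.set_default_device("cuda")
print(torch.empty(1).device)  # cuda:0

def handler():
    # Gradio runs event handlers in worker threads like this one.
    # On affected PyTorch versions the default device doesn't carry over,
    # so tensors created here land on the CPU and the model call fails.
    print(torch.empty(1).device)  # cpu on affected versions

t = threading.Thread(target=handler)
t.start()
t.join()
```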

The solution is one line:

```python
torch.set_default_tensor_type('torch.cuda.FloatTensor')
```
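
The call only needs to run once at module level, before Gradio spins up its worker threads. A rough sketch of where it sits in a demo, with a placeholder generate function standing in for the actual model call from the README:

```python
import gradio as gr
import torch

# This setting is process-wide, so tensors created inside Gradio's
# worker threads are also allocated on the GPU.
torch.set_default_tensor_type('torch.cuda.FloatTensor')

def generate(prompt):
    # Placeholder for the real model call from the README example.
    return f"echo: {prompt}"

gr.Interface(fn=generate, inputs="text", outputs="text").launch()
```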

Gradio demo (with streaming hopefully) coming soon.
