Performance improvement

#1
by ybelkada HF staff - opened

Great demo!
I have one suggestion for a performance improvement: you could call get_image_embeddings only once to retrieve the image embeddings, then run the regular forward pass for each new set of input points, making sure to pop pixel_values from the inputs dict and pass the stored embeddings instead.
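To make the suggestion concrete, here is a minimal sketch. It assumes the demo uses a SAM-style model and processor from transformers (for example `SamModel` and `SamProcessor`), which is not confirmed in this thread. The function name `segment_with_cached_embeddings` and its parameters are hypothetical, chosen only for illustration.

```python
from typing import Any, Iterable, List


def segment_with_cached_embeddings(
    model: Any, processor: Any, image: Any, point_prompts: Iterable[list]
) -> List[Any]:
    """Encode the image once, then run one forward pass per set of points.

    `model` and `processor` are assumed to behave like transformers'
    SamModel and SamProcessor; nothing here is specific to one checkpoint.
    """
    # Compute the image embeddings a single time.
    base_inputs = processor(image, return_tensors="pt")
    image_embeddings = model.get_image_embeddings(base_inputs["pixel_values"])

    outputs = []
    for points in point_prompts:
        inputs = processor(image, input_points=points, return_tensors="pt")
        # Drop pixel_values: the cached embeddings are passed in their place.
        inputs.pop("pixel_values", None)
        inputs.update({"image_embeddings": image_embeddings})
        outputs.append(model(**inputs))
    return outputs
```

With a real checkpoint, `model` and `processor` would be loaded with `from_pretrained`, and each entry of `point_prompts` would use the nested point format the processor expects, such as `[[[450, 600]]]`.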

Also, you might want to wrap the forward pass in a with torch.no_grad(): context manager for faster inference.
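For reference, a small sketch of the two usual ways to apply `torch.no_grad()`. The `nn.Linear` layer is only a stand-in for the real model, and `predict` is a hypothetical function name.

```python
import torch
from torch import nn

model = nn.Linear(4, 2)  # stand-in for the real model (assumption)
model.eval()
x = torch.randn(1, 4)

# Context-manager form: wrap only the forward pass.
with torch.no_grad():
    y_ctx = model(x)


# Decorator form: disable gradient tracking for the whole function.
@torch.no_grad()
def predict(inputs: torch.Tensor) -> torch.Tensor:
    return model(inputs)


y_dec = predict(x)

# Outside no_grad, autograd still tracks the result.
y_grad = model(x)
```

Both forms return the same values as a normal forward pass. The difference is that no autograd graph is built, which saves memory and some time at inference.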

Thanks for the feedback. I added torch.no_grad() to the function definition. I also added get_image_embeddings and now store the embedding until a new image is uploaded, which gives a significant speedup for the second inference on the same image. I might need to add another check, though, because I'm not sure how this will work when multiple people are using the app.
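One possible way to handle the multi-user case is to key the stored embeddings by image content instead of keeping a single global value, so that one user's upload cannot overwrite another's. The sketch below is not the app's actual code: `EmbeddingCache`, `compute` and `max_items` are hypothetical names, and `compute` stands for whatever function produces the embeddings.

```python
import hashlib
import threading
from collections import OrderedDict
from typing import Any, Callable


class EmbeddingCache:
    """Small thread-safe LRU cache of embeddings keyed by image bytes."""

    def __init__(self, compute: Callable[[bytes], Any], max_items: int = 8) -> None:
        self._compute = compute
        self._max_items = max_items
        self._items: "OrderedDict[str, Any]" = OrderedDict()
        self._lock = threading.Lock()

    def get(self, image_bytes: bytes) -> Any:
        # Identical images share one entry; different images never collide.
        key = hashlib.sha256(image_bytes).hexdigest()
        # Holding the lock during compute keeps the sketch simple,
        # at the cost of serializing concurrent cache misses.
        with self._lock:
            if key in self._items:
                self._items.move_to_end(key)  # mark as most recently used
                return self._items[key]
            value = self._compute(image_bytes)
            self._items[key] = value
            if len(self._items) > self._max_items:
                self._items.popitem(last=False)  # evict least recently used
            return value
```

The size bound keeps memory use predictable when many different images are uploaded. Per-session state provided by the UI framework would be another option.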

mattmdjaga changed discussion status to closed
