"""Download one sample image, run the locally saved model on it, and print
the predictions whose score is above 0.5."""

import tempfile

import requests
import torch
from torchvision.io import read_image
from torchvision.transforms import v2 as T


def preprocess(image):
    """Convert a decoded image tensor into a single-image model batch.

    The image is min-max stretched to the full 0-255 range, truncated to
    its first three channels, converted to float in [0, 1] and given a
    leading batch dimension.

    Args:
        image: Channels-first image tensor, e.g. as returned by
            ``torchvision.io.read_image``.

    Returns:
        A float tensor of shape ``(1, C, H, W)`` with ``C <= 3``.
    """
    transform = T.Compose(
        [
            T.ToDtype(torch.float, scale=True),  # uint8 0..255 -> float 0..1
            T.ToPureTensor(),
        ]
    )

    # Stretch contrast to the full uint8 range. Work in float and guard the
    # span: for a constant image (max == min) the numerator is 0 everywhere,
    # so substituting a span of 1 yields all-zeros instead of NaN.
    image = image.to(torch.float32)
    lo, hi = image.min(), image.max()
    span = hi - lo
    if span == 0:
        span = torch.ones_like(span)
    image = (255.0 * (image - lo) / span).to(torch.uint8)

    image = image[:3, ...]  # keep at most the first three channels
    return torch.unsqueeze(transform(image), 0)


def filter_predictions(pred, score_threshold=0.5):
    """Keep only the detections whose score is strictly above a threshold.

    Args:
        pred: Mapping with a ``"scores"`` tensor. Every value is indexed with
            the same boolean mask, so all values are assumed to have the
            detection dimension first (TODO confirm for every output key).
        score_threshold: Detections with ``score <= score_threshold`` are
            dropped.

    Returns:
        A new dict with the same keys holding only the kept detections.
    """
    keep = pred["scores"] > score_threshold
    return {k: v[keep] for k, v in pred.items()}


# NOTE(security): this unpickles a whole model object, which can execute
# arbitrary code -- only load checkpoints from a trusted source. PyTorch
# >= 2.6 defaults to weights_only=True, which rejects full-model pickles,
# so opt out explicitly.
model = torch.load(
    "model.pth", map_location=torch.device("cpu"), weights_only=False
)
# Run in inference mode regardless of the mode the model was saved in.
# Presumably a torchvision-style detection model (its output carries
# "scores"); those refuse to run without targets in training mode.
model.eval()

image_url = "https://github.com/cyber2a/Cyber2A-RTS-ToyModel/blob/main/data/images/valtest_nitze_008.jpg?raw=true"
response = requests.get(image_url, timeout=30)
# Fail loudly rather than handing an HTML error page to the JPEG decoder.
response.raise_for_status()

with tempfile.NamedTemporaryFile(delete=True, suffix='.jpg') as tmp_file:
    # Write the content to the temporary file
    tmp_file.write(response.content)
    # Flush Python's write buffer: read_image reopens the file by name and
    # would otherwise see truncated (possibly empty) data.
    tmp_file.flush()
    tmp_file_path = tmp_file.name
    image = read_image(tmp_file_path)

scaled_tensor = preprocess(image)
with torch.no_grad():
    output = model(scaled_tensor)

# output[0]: predictions for the single image in the batch -- presumably a
# dict of per-detection tensors (TODO confirm against the model).
print(filter_predictions(output[0]))