| |
|
|
import os
import tempfile

import matplotlib.pyplot as plt
import numpy as np
import torch
from huggingface_hub import hf_hub_download
from PIL import Image

from counting import CountingModule


# Module-level state written by load_model(): the counting model (None until
# load_model() assigns it) and the torch.device it runs on (CPU by default).
MODEL = None
DEVICE = torch.device("cpu")
|
|
def load_model(use_box=False):
    """
    Load the counting model from the Hugging Face Hub.

    Builds a ``CountingModule``, downloads ``microscopy_matching_cnt.pth`` from
    the ``phoebe777777/111`` repo, loads it with ``strict=True``, switches the
    model to eval mode and moves it to CUDA when available (CPU otherwise).
    On success the module-level ``MODEL`` and ``DEVICE`` globals are updated.

    Args:
        use_box: use bounding box as input (default: False)

    Returns:
        (model, device): the loaded counting model and the ``torch.device`` it
        was moved to. On any failure the error and its traceback are printed,
        the globals are left untouched, and ``(None, torch.device("cpu"))`` is
        returned.
    """
    global MODEL, DEVICE

    try:
        print("🔄 Loading counting model...")

        # Build into a local first: assigning the global before the checkpoint
        # is loaded would leave a random-weight model in MODEL if the
        # download or state-dict load failed.
        model = CountingModule(use_box=use_box)

        ckpt_path = hf_hub_download(
            repo_id="phoebe777777/111",
            filename="microscopy_matching_cnt.pth",
            token=None,
            force_download=False,
        )
        print(f"✅ Checkpoint downloaded: {ckpt_path}")

        # NOTE(security): torch.load unpickles the file, so only load
        # checkpoints from a trusted repo; consider weights_only=True on
        # torch versions that support it.
        state_dict = torch.load(ckpt_path, map_location="cpu")
        model.load_state_dict(state_dict, strict=True)
        model.eval()

        use_cuda = torch.cuda.is_available()
        device = torch.device("cuda" if use_cuda else "cpu")
        model.move_to_device(device)
        print("✅ Model moved to CUDA" if use_cuda else "✅ Model on CPU")

        # Publish only a fully loaded model.
        MODEL, DEVICE = model, device
        print("✅ Counting model loaded successfully")
        return MODEL, DEVICE

    except Exception as e:
        print(f"❌ Error loading counting model: {e}")
        import traceback
        traceback.print_exc()
        return None, torch.device("cpu")
|
|
|
|
@torch.no_grad()
def run(model, img_path, box=None, device="cpu", visualize=True):
    """
    Run counting inference on a single image.

    Moves the model to ``device``, puts it in eval mode, enables box-
    conditioned mode only when ``box`` is given, and calls
    ``model(img_path, box)``. Gradients are disabled by the decorator.

    Args:
        model: loaded counting model (may be None, see Returns)
        img_path: image path
        box: bounding box [[x1, y1, x2, y2], ...] or None
        device: device to run on
        visualize: accepted for API compatibility but currently unused; call
            ``visualize_result`` separately to render an overlay.

    Returns:
        result_dict: {
            'density_map': density map returned by the model,
            'count': count returned by the model (formatted as a float),
            'visualized_path': always None at present,
        }
        If ``model`` is None or anything raises during setup/inference, the
        dict instead holds ``density_map=None``, ``count=0``,
        ``visualized_path=None`` and an ``'error'`` message.
    """
    # Check before touching the model: calling move_to_device on None would
    # raise AttributeError instead of returning the error dict.
    if model is None:
        return {
            'density_map': None,
            'count': 0,
            'visualized_path': None,
            'error': 'Model not loaded'
        }

    print("DEVICE:", device)

    try:
        # Device/eval setup is inside the try so its failures are reported
        # through the same error-dict contract as inference failures.
        model.move_to_device(device)
        model.eval()
        model.use_box = box is not None

        print(f"🔍 Running counting inference on {img_path}")
        density_map, count = model(img_path, box)
        print(f"✅ Counting result: {count:.1f} objects")

        return {
            'density_map': density_map,
            'count': count,
            'visualized_path': None
        }

    except Exception as e:
        print(f"❌ Counting inference error: {e}")
        import traceback
        traceback.print_exc()
        return {
            'density_map': None,
            'count': 0,
            'visualized_path': None,
            'error': str(e)
        }
|
|
|
|
def visualize_result(image_path, density_map, count):
    """
    Overlay a predicted density map on the original image and save it as a PNG.

    The image is read with ``skimage.io``, reduced to 3 channels (grayscale is
    replicated to RGB), min-max normalised, and drawn with the density map on
    top using the ``jet`` colormap at 50% opacity. The figure is saved at
    300 dpi.

    Args:
        image_path: original image path
        density_map: predicted density map, as a numpy array (or array-like)
            or a torch tensor (detached and moved to CPU). Singleton
            dimensions are squeezed. It is drawn over the image's full pixel
            extent, so a map of a different resolution is stretched to align.
        count: predicted count (accepted for API symmetry; not drawn)

    Returns:
        output_path: path of a newly created temporary ``.png`` file. It is not
        deleted automatically; the caller owns cleanup. On any error the
        exception is printed and ``image_path`` is returned instead.
    """
    fig = None
    try:
        import skimage.io as io

        img = io.imread(image_path)

        # Drop alpha / extra channels, and promote grayscale to 3-channel RGB.
        if img.ndim == 3 and img.shape[2] > 3:
            img = img[:, :, :3]
        if img.ndim == 2:
            img = np.stack([img] * 3, axis=-1)
        img_show = img.squeeze()

        # Accept tensors (possibly on CUDA / requiring grad) as well as arrays.
        if torch.is_tensor(density_map):
            density_map = density_map.detach().cpu().numpy()
        density_map_show = np.asarray(density_map).squeeze()

        # Min-max normalise to [0, 1]; the epsilon guards flat images.
        lo, hi = np.min(img_show), np.max(img_show)
        img_show = (img_show - lo) / (hi - lo + 1e-8)

        fig, ax = plt.subplots(figsize=(8, 6))
        base = ax.imshow(img_show)
        # Draw the overlay into the base image's extent so a density map with
        # a different resolution is stretched onto the image, not misaligned.
        ax.imshow(density_map_show, cmap='jet', alpha=0.5, extent=base.get_extent())
        ax.axis('off')
        fig.tight_layout()

        # Close our handle immediately (delete=False keeps the file): an open
        # handle leaks a descriptor and blocks writing by name on Windows.
        with tempfile.NamedTemporaryFile(delete=False, suffix='.png') as tmp:
            output_path = tmp.name
        fig.savefig(output_path, dpi=300)

        print(f"✅ Visualization saved to {output_path}")
        return output_path

    except Exception as e:
        print(f"❌ Visualization error: {e}")
        import traceback
        traceback.print_exc()
        return image_path

    finally:
        # Release the figure even when drawing or saving failed.
        if fig is not None:
            plt.close(fig)
|
|
|
|
if __name__ == "__main__":
    # Smoke test: load the model, then run one inference on a bundled example
    # image if it is present.
    print("=" * 60)
    print("Testing Counting Model")
    print("=" * 60)

    model, device = load_model(use_box=False)

    if model is None:
        print("\n❌ Model loading failed")
    else:
        print("\n" + "=" * 60)
        print("Model loaded successfully, testing inference...")
        print("=" * 60)

        test_image = "example_imgs/1977_Well_F-5_Field_1.png"

        if not os.path.exists(test_image):
            print(f"\n⚠️ Test image not found: {test_image}")
        else:
            result = run(
                model,
                test_image,
                box=None,
                device=device,
                visualize=True,
            )

            if 'error' in result:
                print(f"\n❌ Inference failed: {result['error']}")
            else:
                print("\n" + "=" * 60)
                print("Inference Results:")
                print("=" * 60)
                print(f"Count: {result['count']:.1f}")
                print(f"Density map shape: {result['density_map'].shape}")
                if result['visualized_path']:
                    print(f"Visualization saved to: {result['visualized_path']}")
|
|