diff --git a/checkpoint/barc_complete/model_best.pth.tar b/checkpoint/barc_complete/model_best.pth.tar new file mode 100644 index 0000000000000000000000000000000000000000..fd0fa82e10dc43d53ac35e6af2fc5bc9bb8bd3cf --- /dev/null +++ b/checkpoint/barc_complete/model_best.pth.tar @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0834c7f6a298a707e748da7185bd52a318697a34d7d0462e86cf57e287fa5da3 +size 549078471 diff --git a/checkpoint/barc_normflow_pret/rgbddog_v3_model.pt b/checkpoint/barc_normflow_pret/rgbddog_v3_model.pt new file mode 100644 index 0000000000000000000000000000000000000000..7d7ae431a18aed3a51cee275dfc300d68f430487 --- /dev/null +++ b/checkpoint/barc_normflow_pret/rgbddog_v3_model.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4ff03508f6b9431da1c224697ce1c68cab000758215c4a4766e136c28f828f2d +size 1725484 diff --git a/data/breed_data/NIHMS866262-supplement-2.xlsx b/data/breed_data/NIHMS866262-supplement-2.xlsx new file mode 100644 index 0000000000000000000000000000000000000000..0bcea54381008d956311a639c52f19ec6b26d6c4 --- /dev/null +++ b/data/breed_data/NIHMS866262-supplement-2.xlsx @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:fd6301ec254452ecb86df745220bef98b69d59794429c5cb452b03bb76e17eae +size 94169 diff --git a/data/breed_data/complete_abbrev_dict_v2.pkl b/data/breed_data/complete_abbrev_dict_v2.pkl new file mode 100644 index 0000000000000000000000000000000000000000..94987bc0d384c5ecc32c2eb7d3b54c43fe9e0a75 --- /dev/null +++ b/data/breed_data/complete_abbrev_dict_v2.pkl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2354d2c7e3b2f7ee88f41234e138b7828d58fa6618c0ed0d0d4b12febaee8801 +size 26517 diff --git a/data/breed_data/complete_summary_breeds_v2.pkl b/data/breed_data/complete_summary_breeds_v2.pkl new file mode 100644 index 0000000000000000000000000000000000000000..0464c827536776681b9c4fcc6ee435a57b82332e --- /dev/null +++ b/data/breed_data/complete_summary_breeds_v2.pkl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:95461e44d7a6924e1d9879711c865177ac7f15faa1ffb932cb42995c8eae3412 +size 89004 diff --git a/data/smal_data/mean_dog_bone_lengths.txt b/data/smal_data/mean_dog_bone_lengths.txt new file mode 100644 index 0000000000000000000000000000000000000000..abf7bbf02f8fc4eda65bddf7a5c8eb3ab88d6e38 --- /dev/null +++ b/data/smal_data/mean_dog_bone_lengths.txt @@ -0,0 +1,34 @@ +0.0 +0.09044851362705231 +0.1525898575782776 +0.08656660467386246 +0.08330804109573364 +0.17591887712478638 +0.1955687403678894 +0.1663869321346283 +0.20741023123264313 +0.10695090889930725 +0.1955687403678894 +0.1663869321346283 +0.20741020143032074 +0.10695091634988785 +0.19678470492362976 +0.135447695851326 +0.10385762155056 +0.1951410472393036 +0.22369971871376038 +0.14296436309814453 +0.10385762155056 +0.1951410472393036 +0.22369973361492157 +0.14296436309814453 +0.11435563117265701 +0.1225045919418335 +0.055157795548439026 +0.07148551940917969 +0.0759430006146431 +0.09476413577795029 +0.0287716593593359 +0.11548781394958496 +0.15003003180027008 +0.15003003180027008 diff --git a/data/smal_data/my_smpl_SMBLD_nbj_v3.pkl b/data/smal_data/my_smpl_SMBLD_nbj_v3.pkl new file mode 100644 index 0000000000000000000000000000000000000000..7cf953b71028d76b12c08a4e04e90a44d0155e5e --- /dev/null +++ b/data/smal_data/my_smpl_SMBLD_nbj_v3.pkl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:bf01081234c09445ede7079083727705e6c13a21a77bf97f305e4ad6527f06df +size 34904364 
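Note on the files added above: the checkpoints, breed dictionaries, the xlsx supplement and the SMBLD model are committed as Git LFS pointers (they must be fetched with `git lfs pull` before use), while `mean_dog_bone_lengths.txt` is a plain-text list of 34 mean bone lengths, one value per line. A minimal sketch of how these files could be checked and loaded, assuming numpy is available and paths are relative to the repository root (this sketch is illustration only, not part of the diff):

```
import numpy as np

# one float per line -> array of shape (34,), one mean length per skeleton bone
bone_lengths = np.loadtxt("data/smal_data/mean_dog_bone_lengths.txt")
assert bone_lengths.shape == (34,)

def is_unfetched_lfs_pointer(path):
    # un-fetched LFS files still contain the pointer header instead of the binary payload
    with open(path, "rb") as f:
        return f.read(42).startswith(b"version https://git-lfs.github.com/spec/v1")

# if this prints True, run `git lfs pull` to download the real checkpoint
print(is_unfetched_lfs_pointer("checkpoint/barc_complete/model_best.pth.tar"))
```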
diff --git a/data/smal_data/my_smpl_data_SMBLD_v3.pkl b/data/smal_data/my_smpl_data_SMBLD_v3.pkl new file mode 100644 index 0000000000000000000000000000000000000000..0372a9723c39f810cd61076867844b76ce49917a --- /dev/null +++ b/data/smal_data/my_smpl_data_SMBLD_v3.pkl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:84ad0ef6f85d662464c4d0301adede172bb241158b1cea66a810a930a9473cc8 +size 31841 diff --git a/data/smal_data/symmetry_inds.json b/data/smal_data/symmetry_inds.json new file mode 100644 index 0000000000000000000000000000000000000000..c17c305b15222b9cba70acd64b46725a2c0332c1 --- /dev/null +++ b/data/smal_data/symmetry_inds.json @@ -0,0 +1,3897 @@ +{ + "center_inds": [ + 0, + 1, + 2, + 3, + 4, + 5, + 6, + 7, + 8, + 9, + 10, + 11, + 12, + 13, + 14, + 15, + 16, + 17, + 18, + 19, + 20, + 21, + 22, + 23, + 24, + 25, + 26, + 27, + 28, + 29, + 30, + 31, + 32, + 37, + 55, + 119, + 120, + 163, + 209, + 210, + 211, + 213, + 216, + 227, + 326, + 395, + 452, + 578, + 910, + 959, + 964, + 975, + 976, + 977, + 1172, + 1175, + 1176, + 1178, + 1194, + 1243, + 1739, + 1796, + 1797, + 1798, + 1799, + 1800, + 1801, + 1802, + 1803, + 1804, + 1805, + 1806, + 1807, + 1808, + 1809, + 1810, + 1811, + 1812, + 1813, + 1814, + 1815, + 1816, + 1817, + 1818, + 1819, + 1820, + 1821, + 1822, + 1823, + 1824, + 1825, + 1826, + 1827, + 1828, + 1829, + 1830, + 1831, + 1832, + 1833, + 1834, + 1835, + 1836, + 1837, + 1838, + 1839, + 1840, + 1842, + 1843, + 1844, + 1845, + 1846, + 1847, + 1848, + 1849, + 1850, + 1851, + 1852, + 1853, + 1854, + 1855, + 1856, + 1857, + 1858, + 1859, + 1860, + 1861, + 1862, + 1863, + 1870, + 1919, + 1960, + 1961, + 1965, + 1967, + 2003 + ], + "left_inds": [ + 2012, + 2013, + 2014, + 2015, + 2016, + 2017, + 2018, + 2019, + 2020, + 2021, + 2022, + 2023, + 2024, + 2025, + 2026, + 2027, + 2028, + 2029, + 2030, + 2031, + 2032, + 2033, + 2034, + 2035, + 2036, + 2037, + 2038, + 2039, + 2040, + 2041, + 2042, + 2043, + 2044, + 2045, + 2046, + 2047, + 2048, + 2049, + 2050, + 2051, + 2052, + 2053, + 2054, + 2055, + 2056, + 2057, + 2058, + 2059, + 2060, + 2061, + 2062, + 2063, + 2064, + 2065, + 2066, + 2067, + 2068, + 2069, + 2070, + 2071, + 2072, + 2073, + 2074, + 2075, + 2076, + 2077, + 2078, + 2079, + 2080, + 2081, + 2082, + 2083, + 2084, + 2085, + 2086, + 2087, + 2088, + 2089, + 2090, + 2091, + 2092, + 2093, + 2094, + 2095, + 2096, + 2097, + 2098, + 2099, + 2100, + 2101, + 2102, + 2103, + 2104, + 2105, + 2106, + 2107, + 2108, + 2109, + 2110, + 2111, + 2112, + 2113, + 2114, + 2115, + 2116, + 2117, + 2118, + 2119, + 2120, + 2121, + 2122, + 2123, + 2124, + 2125, + 2126, + 2127, + 2128, + 2129, + 2130, + 2131, + 2132, + 2133, + 2134, + 2135, + 2136, + 2137, + 2138, + 2139, + 2140, + 2141, + 2142, + 2143, + 2144, + 2145, + 2146, + 2147, + 2148, + 2149, + 2150, + 2151, + 2152, + 2153, + 2154, + 2155, + 2156, + 2157, + 2158, + 2159, + 2160, + 2161, + 2162, + 2163, + 2164, + 2165, + 2166, + 2167, + 2168, + 2169, + 2170, + 2171, + 2172, + 2173, + 2174, + 2175, + 2176, + 2177, + 2178, + 2179, + 2180, + 2181, + 2182, + 2183, + 2184, + 2185, + 2186, + 2187, + 2188, + 2189, + 2190, + 2191, + 2192, + 2193, + 2194, + 2195, + 2196, + 2197, + 2198, + 2199, + 2200, + 2201, + 2202, + 2203, + 2204, + 2205, + 2206, + 2207, + 2208, + 2209, + 2210, + 2211, + 2212, + 2213, + 2214, + 2215, + 2216, + 2217, + 2218, + 2219, + 2220, + 2221, + 2222, + 2223, + 2224, + 2225, + 2226, + 2227, + 2228, + 2229, + 2230, + 2231, + 2232, + 2233, + 2234, + 2235, + 2236, + 2237, + 2238, + 2239, + 2240, + 2241, + 2242, 
+ 2243, + 2244, + 2245, + 2246, + 2247, + 2248, + 2249, + 2250, + 2251, + 2252, + 2253, + 2254, + 2255, + 2256, + 2257, + 2258, + 2259, + 2260, + 2261, + 2262, + 2263, + 2264, + 2265, + 2266, + 2267, + 2268, + 2269, + 2270, + 2271, + 2272, + 2273, + 2274, + 2275, + 2276, + 2277, + 2278, + 2279, + 2280, + 2281, + 2282, + 2283, + 2284, + 2285, + 2286, + 2287, + 2288, + 2289, + 2290, + 2291, + 2292, + 2293, + 2294, + 2295, + 2296, + 2297, + 2298, + 2299, + 2300, + 2301, + 2302, + 2303, + 2304, + 2305, + 2306, + 2307, + 2308, + 2309, + 2310, + 2311, + 2312, + 2313, + 2314, + 2315, + 2316, + 2317, + 2318, + 2319, + 2320, + 2321, + 2322, + 2323, + 2324, + 2325, + 2326, + 2327, + 2328, + 2329, + 2330, + 2331, + 2332, + 2333, + 2334, + 2335, + 2336, + 2337, + 2338, + 2339, + 2340, + 2341, + 2342, + 2343, + 2344, + 2345, + 2346, + 2347, + 2348, + 2349, + 2350, + 2351, + 2352, + 2353, + 2354, + 2355, + 2356, + 2357, + 2358, + 2359, + 2360, + 2361, + 2362, + 2363, + 2364, + 2365, + 2366, + 2367, + 2368, + 2369, + 2370, + 2371, + 2372, + 2373, + 2374, + 2375, + 2376, + 2377, + 2378, + 2379, + 2380, + 2381, + 2382, + 2383, + 2384, + 2385, + 2386, + 2387, + 2388, + 2389, + 2390, + 2391, + 2392, + 2393, + 2394, + 2395, + 2396, + 2397, + 2398, + 2399, + 2400, + 2401, + 2402, + 2403, + 2404, + 2405, + 2406, + 2407, + 2408, + 2409, + 2410, + 2411, + 2412, + 2413, + 2414, + 2415, + 2416, + 2417, + 2418, + 2419, + 2420, + 2421, + 2422, + 2423, + 2424, + 2425, + 2426, + 2427, + 2428, + 2429, + 2430, + 2431, + 2432, + 2433, + 2434, + 2435, + 2436, + 2437, + 2438, + 2439, + 2440, + 2441, + 2442, + 2443, + 2444, + 2445, + 2446, + 2447, + 2448, + 2449, + 2450, + 2451, + 2452, + 2453, + 2454, + 2455, + 2456, + 2457, + 2458, + 2459, + 2460, + 2461, + 2462, + 2463, + 2464, + 2465, + 2466, + 2467, + 2468, + 2469, + 2470, + 2471, + 2472, + 2473, + 2474, + 2475, + 2476, + 2477, + 2478, + 2479, + 2480, + 2481, + 2482, + 2483, + 2484, + 2485, + 2486, + 2487, + 2488, + 2489, + 2490, + 2491, + 2492, + 2493, + 2494, + 2495, + 2496, + 2497, + 2498, + 2499, + 2500, + 2501, + 2502, + 2503, + 2504, + 2505, + 2506, + 2507, + 2508, + 2509, + 2510, + 2511, + 2512, + 2513, + 2514, + 2515, + 2516, + 2517, + 2518, + 2519, + 2520, + 2521, + 2522, + 2523, + 2524, + 2525, + 2526, + 2527, + 2528, + 2529, + 2530, + 2531, + 2532, + 2533, + 2534, + 2535, + 2536, + 2537, + 2538, + 2539, + 2540, + 2541, + 2542, + 2543, + 2544, + 2545, + 2546, + 2547, + 2548, + 2549, + 2550, + 2551, + 2552, + 2553, + 2554, + 2555, + 2556, + 2557, + 2558, + 2559, + 2560, + 2561, + 2562, + 2563, + 2564, + 2565, + 2566, + 2567, + 2568, + 2569, + 2570, + 2571, + 2572, + 2573, + 2574, + 2575, + 2576, + 2577, + 2578, + 2579, + 2580, + 2581, + 2582, + 2583, + 2584, + 2585, + 2586, + 2587, + 2588, + 2589, + 2590, + 2591, + 2592, + 2593, + 2594, + 2595, + 2596, + 2597, + 2598, + 2599, + 2600, + 2601, + 2602, + 2603, + 2604, + 2605, + 2606, + 2607, + 2608, + 2609, + 2610, + 2611, + 2612, + 2613, + 2614, + 2615, + 2616, + 2617, + 2618, + 2619, + 2620, + 2621, + 2622, + 2623, + 2624, + 2625, + 2626, + 2627, + 2628, + 2629, + 2630, + 2631, + 2632, + 2633, + 2634, + 2635, + 2636, + 2637, + 2638, + 2639, + 2640, + 2641, + 2642, + 2643, + 2644, + 2645, + 2646, + 2647, + 2648, + 2649, + 2650, + 2651, + 2652, + 2653, + 2654, + 2655, + 2656, + 2657, + 2658, + 2659, + 2660, + 2661, + 2662, + 2663, + 2664, + 2665, + 2666, + 2667, + 2668, + 2669, + 2670, + 2671, + 2672, + 2673, + 2674, + 2675, + 2676, + 2677, + 2678, + 2679, + 2680, + 2681, + 2682, + 2683, + 2684, + 2685, + 2686, + 
2687, + 2688, + 2689, + 2690, + 2691, + 2692, + 2693, + 2694, + 2695, + 2696, + 2697, + 2698, + 2699, + 2700, + 2701, + 2702, + 2703, + 2704, + 2705, + 2706, + 2707, + 2708, + 2709, + 2710, + 2711, + 2712, + 2713, + 2714, + 2715, + 2716, + 2717, + 2718, + 2719, + 2720, + 2721, + 2722, + 2723, + 2724, + 2725, + 2726, + 2727, + 2728, + 2729, + 2730, + 2731, + 2732, + 2733, + 2734, + 2735, + 2736, + 2737, + 2738, + 2739, + 2740, + 2741, + 2742, + 2743, + 2744, + 2745, + 2746, + 2747, + 2748, + 2749, + 2750, + 2751, + 2752, + 2753, + 2754, + 2755, + 2756, + 2757, + 2758, + 2759, + 2760, + 2761, + 2762, + 2763, + 2764, + 2765, + 2766, + 2767, + 2768, + 2769, + 2770, + 2771, + 2772, + 2773, + 2774, + 2775, + 2776, + 2777, + 2778, + 2779, + 2780, + 2781, + 2782, + 2783, + 2784, + 2785, + 2786, + 2787, + 2788, + 2789, + 2790, + 2791, + 2792, + 2793, + 2794, + 2795, + 2796, + 2797, + 2798, + 2799, + 2800, + 2801, + 2802, + 2803, + 2804, + 2805, + 2806, + 2807, + 2808, + 2809, + 2810, + 2811, + 2812, + 2813, + 2814, + 2815, + 2816, + 2817, + 2818, + 2819, + 2820, + 2821, + 2822, + 2823, + 2824, + 2825, + 2826, + 2827, + 2828, + 2829, + 2830, + 2831, + 2832, + 2833, + 2834, + 2835, + 2836, + 2837, + 2838, + 2839, + 2840, + 2841, + 2842, + 2843, + 2844, + 2845, + 2846, + 2847, + 2848, + 2849, + 2850, + 2851, + 2852, + 2853, + 2854, + 2855, + 2856, + 2857, + 2858, + 2859, + 2860, + 2861, + 2862, + 2863, + 2864, + 2865, + 2866, + 2867, + 2868, + 2869, + 2870, + 2871, + 2872, + 2873, + 2874, + 2875, + 2876, + 2877, + 2878, + 2879, + 2880, + 2881, + 2882, + 2883, + 2884, + 2885, + 2886, + 2887, + 2888, + 2889, + 2890, + 2891, + 2892, + 2893, + 2894, + 2895, + 2896, + 2897, + 2898, + 2899, + 2900, + 2901, + 2902, + 2903, + 2904, + 2905, + 2906, + 2907, + 2908, + 2909, + 2910, + 2911, + 2912, + 2913, + 2914, + 2915, + 2916, + 2917, + 2918, + 2919, + 2920, + 2921, + 2922, + 2923, + 2924, + 2925, + 2926, + 2927, + 2928, + 2929, + 2930, + 2931, + 2932, + 2933, + 2934, + 2935, + 2936, + 2937, + 2938, + 2939, + 2940, + 2941, + 2942, + 2943, + 2944, + 2945, + 2946, + 2947, + 2948, + 2949, + 2950, + 2951, + 2952, + 2953, + 2954, + 2955, + 2956, + 2957, + 2958, + 2959, + 2960, + 2961, + 2962, + 2963, + 2964, + 2965, + 2966, + 2967, + 2968, + 2969, + 2970, + 2971, + 2972, + 2973, + 2974, + 2975, + 2976, + 2977, + 2978, + 2979, + 2980, + 2981, + 2982, + 2983, + 2984, + 2985, + 2986, + 2987, + 2988, + 2989, + 2990, + 2991, + 2992, + 2993, + 2994, + 2995, + 2996, + 2997, + 2998, + 2999, + 3000, + 3001, + 3002, + 3003, + 3004, + 3005, + 3006, + 3007, + 3008, + 3009, + 3010, + 3011, + 3012, + 3013, + 3014, + 3015, + 3016, + 3017, + 3018, + 3019, + 3020, + 3021, + 3022, + 3023, + 3024, + 3025, + 3026, + 3027, + 3028, + 3029, + 3030, + 3031, + 3032, + 3033, + 3034, + 3035, + 3036, + 3037, + 3038, + 3039, + 3040, + 3041, + 3042, + 3043, + 3044, + 3045, + 3046, + 3047, + 3048, + 3049, + 3050, + 3051, + 3052, + 3053, + 3054, + 3055, + 3056, + 3057, + 3058, + 3059, + 3060, + 3061, + 3062, + 3063, + 3064, + 3065, + 3066, + 3067, + 3068, + 3069, + 3070, + 3071, + 3072, + 3073, + 3074, + 3075, + 3076, + 3077, + 3078, + 3079, + 3080, + 3081, + 3082, + 3083, + 3084, + 3085, + 3086, + 3087, + 3088, + 3089, + 3090, + 3091, + 3092, + 3093, + 3094, + 3095, + 3096, + 3097, + 3098, + 3099, + 3100, + 3101, + 3102, + 3103, + 3104, + 3105, + 3106, + 3107, + 3108, + 3109, + 3110, + 3111, + 3112, + 3113, + 3114, + 3115, + 3116, + 3117, + 3118, + 3119, + 3120, + 3121, + 3122, + 3123, + 3124, + 3125, + 3126, + 3127, + 3128, + 3129, + 3130, + 
3131, + 3132, + 3133, + 3134, + 3135, + 3136, + 3137, + 3138, + 3139, + 3140, + 3141, + 3142, + 3143, + 3144, + 3145, + 3146, + 3147, + 3148, + 3149, + 3150, + 3151, + 3152, + 3153, + 3154, + 3155, + 3156, + 3157, + 3158, + 3159, + 3160, + 3161, + 3162, + 3163, + 3164, + 3165, + 3166, + 3167, + 3168, + 3169, + 3170, + 3171, + 3172, + 3173, + 3174, + 3175, + 3176, + 3177, + 3178, + 3179, + 3180, + 3181, + 3182, + 3183, + 3184, + 3185, + 3186, + 3187, + 3188, + 3189, + 3190, + 3191, + 3192, + 3193, + 3194, + 3195, + 3196, + 3197, + 3198, + 3199, + 3200, + 3201, + 3202, + 3203, + 3204, + 3205, + 3206, + 3207, + 3208, + 3209, + 3210, + 3211, + 3212, + 3213, + 3214, + 3215, + 3216, + 3217, + 3218, + 3219, + 3220, + 3221, + 3222, + 3223, + 3224, + 3225, + 3226, + 3227, + 3228, + 3229, + 3230, + 3231, + 3232, + 3233, + 3234, + 3235, + 3236, + 3237, + 3238, + 3239, + 3240, + 3241, + 3242, + 3243, + 3244, + 3245, + 3246, + 3247, + 3248, + 3249, + 3250, + 3251, + 3252, + 3253, + 3254, + 3255, + 3256, + 3257, + 3258, + 3259, + 3260, + 3261, + 3262, + 3263, + 3264, + 3265, + 3266, + 3267, + 3268, + 3269, + 3270, + 3271, + 3272, + 3273, + 3274, + 3275, + 3276, + 3277, + 3278, + 3279, + 3280, + 3281, + 3282, + 3283, + 3284, + 3285, + 3286, + 3287, + 3288, + 3289, + 3290, + 3291, + 3292, + 3293, + 3294, + 3295, + 3296, + 3297, + 3298, + 3299, + 3300, + 3301, + 3302, + 3303, + 3304, + 3305, + 3306, + 3307, + 3308, + 3309, + 3310, + 3311, + 3312, + 3313, + 3314, + 3315, + 3316, + 3317, + 3318, + 3319, + 3320, + 3321, + 3322, + 3323, + 3324, + 3325, + 3326, + 3327, + 3328, + 3329, + 3330, + 3331, + 3332, + 3333, + 3334, + 3335, + 3336, + 3337, + 3338, + 3339, + 3340, + 3341, + 3342, + 3343, + 3344, + 3345, + 3346, + 3347, + 3348, + 3349, + 3350, + 3351, + 3352, + 3353, + 3354, + 3355, + 3356, + 3357, + 3358, + 3359, + 3360, + 3361, + 3362, + 3363, + 3364, + 3365, + 3366, + 3367, + 3368, + 3369, + 3370, + 3371, + 3372, + 3373, + 3374, + 3375, + 3376, + 3377, + 3378, + 3379, + 3380, + 3381, + 3382, + 3383, + 3384, + 3385, + 3386, + 3387, + 3388, + 3389, + 3390, + 3391, + 3392, + 3393, + 3394, + 3395, + 3396, + 3397, + 3398, + 3399, + 3400, + 3401, + 3402, + 3403, + 3404, + 3405, + 3406, + 3407, + 3408, + 3409, + 3410, + 3411, + 3412, + 3413, + 3414, + 3415, + 3416, + 3417, + 3418, + 3419, + 3420, + 3421, + 3422, + 3423, + 3424, + 3425, + 3426, + 3427, + 3428, + 3429, + 3430, + 3431, + 3432, + 3433, + 3434, + 3435, + 3436, + 3437, + 3438, + 3439, + 3440, + 3441, + 3442, + 3443, + 3444, + 3445, + 3446, + 3447, + 3448, + 3449, + 3450, + 3451, + 3452, + 3453, + 3454, + 3455, + 3456, + 3457, + 3458, + 3459, + 3460, + 3461, + 3462, + 3463, + 3464, + 3465, + 3466, + 3467, + 3468, + 3469, + 3470, + 3471, + 3472, + 3473, + 3474, + 3475, + 3476, + 3477, + 3478, + 3479, + 3480, + 3481, + 3482, + 3483, + 3484, + 3485, + 3486, + 3487, + 3488, + 3489, + 3490, + 3491, + 3492, + 3493, + 3494, + 3495, + 3496, + 3497, + 3498, + 3499, + 3500, + 3501, + 3502, + 3503, + 3504, + 3505, + 3506, + 3507, + 3508, + 3509, + 3510, + 3511, + 3512, + 3513, + 3514, + 3515, + 3516, + 3517, + 3518, + 3519, + 3520, + 3521, + 3522, + 3523, + 3524, + 3525, + 3526, + 3527, + 3528, + 3529, + 3530, + 3531, + 3532, + 3533, + 3534, + 3535, + 3536, + 3537, + 3538, + 3539, + 3540, + 3541, + 3542, + 3543, + 3544, + 3545, + 3546, + 3547, + 3548, + 3549, + 3550, + 3551, + 3552, + 3553, + 3554, + 3555, + 3556, + 3557, + 3558, + 3559, + 3560, + 3561, + 3562, + 3563, + 3564, + 3565, + 3566, + 3567, + 3568, + 3569, + 3570, + 3571, + 3572, + 3573, + 3574, + 
3575, + 3576, + 3577, + 3578, + 3579, + 3580, + 3581, + 3582, + 3583, + 3584, + 3585, + 3586, + 3587, + 3588, + 3589, + 3590, + 3591, + 3592, + 3593, + 3594, + 3595, + 3596, + 3597, + 3598, + 3599, + 3600, + 3601, + 3602, + 3603, + 3604, + 3605, + 3606, + 3607, + 3608, + 3609, + 3610, + 3611, + 3612, + 3613, + 3614, + 3615, + 3616, + 3617, + 3618, + 3619, + 3620, + 3621, + 3622, + 3623, + 3624, + 3625, + 3626, + 3627, + 3628, + 3629, + 3630, + 3631, + 3632, + 3633, + 3634, + 3635, + 3636, + 3637, + 3638, + 3639, + 3640, + 3641, + 3642, + 3643, + 3644, + 3645, + 3646, + 3647, + 3648, + 3649, + 3650, + 3651, + 3652, + 3653, + 3654, + 3655, + 3656, + 3657, + 3658, + 3659, + 3660, + 3661, + 3662, + 3663, + 3664, + 3665, + 3666, + 3667, + 3668, + 3669, + 3670, + 3671, + 3672, + 3673, + 3674, + 3675, + 3676, + 3677, + 3678, + 3679, + 3680, + 3681, + 3682, + 3683, + 3684, + 3685, + 3686, + 3687, + 3688, + 3689, + 3690, + 3691, + 3692, + 3693, + 3694, + 3695, + 3696, + 3697, + 3698, + 3699, + 3700, + 3701, + 3702, + 3703, + 3704, + 3705, + 3706, + 3707, + 3708, + 3709, + 3710, + 3711, + 3712, + 3713, + 3714, + 3715, + 3716, + 3717, + 3718, + 3719, + 3720, + 3721, + 3722, + 3723, + 3724, + 3725, + 3726, + 3727, + 3728, + 3729, + 3730, + 3731, + 3732, + 3733, + 3734, + 3735, + 3736, + 3737, + 3738, + 3739, + 3740, + 3741, + 3742, + 3743, + 3744, + 3745, + 3746, + 3747, + 3748, + 3749, + 3750, + 3751, + 3752, + 3753, + 3754, + 3755, + 3756, + 3757, + 3758, + 3759, + 3760, + 3761, + 3762, + 3763, + 3764, + 3765, + 3766, + 3767, + 3768, + 3769, + 3770, + 3771, + 3772, + 3773, + 3774, + 3775, + 3776, + 3777, + 3778, + 3779, + 3780, + 3781, + 3782, + 3783, + 3784, + 3785, + 3786, + 3787, + 3788, + 3789, + 3790, + 3791, + 3792, + 3793, + 3794, + 3795, + 3796, + 3797, + 3798, + 3799, + 3800, + 3801, + 3802, + 3803, + 3804, + 3805, + 3806, + 3807, + 3808, + 3809, + 3810, + 3811, + 3812, + 3813, + 3814, + 3815, + 3816, + 3817, + 3818, + 3819, + 3820, + 3821, + 3822, + 3823, + 3824, + 3825, + 3826, + 3827, + 3828, + 3829, + 3830, + 3831, + 3832, + 3833, + 3834, + 3835, + 3836, + 3837, + 3838, + 3839, + 3840, + 3841, + 3842, + 3843, + 3844, + 3845, + 3846, + 3847, + 3848, + 3849, + 3850, + 3851, + 3852, + 3853, + 3854, + 3855, + 3856, + 3857, + 3858, + 3859, + 3860, + 3861, + 3862, + 3863, + 3864, + 3865, + 3866, + 3867, + 3868, + 3869, + 3870, + 3871, + 3872, + 3873, + 3874, + 3875, + 3876, + 3877, + 3878, + 3879, + 3880, + 3881, + 3882, + 3883, + 3884, + 3885, + 3886, + 3887, + 3888 + ], + "right_inds": [ + 33, + 34, + 35, + 36, + 38, + 39, + 40, + 41, + 42, + 43, + 44, + 45, + 46, + 47, + 48, + 49, + 50, + 51, + 52, + 53, + 54, + 56, + 57, + 58, + 59, + 60, + 61, + 62, + 63, + 64, + 65, + 66, + 67, + 68, + 69, + 70, + 71, + 72, + 73, + 74, + 75, + 76, + 77, + 78, + 79, + 80, + 81, + 82, + 83, + 84, + 85, + 86, + 87, + 88, + 89, + 90, + 91, + 92, + 93, + 94, + 95, + 96, + 97, + 98, + 99, + 100, + 101, + 102, + 103, + 104, + 105, + 106, + 107, + 108, + 109, + 110, + 111, + 112, + 113, + 114, + 115, + 116, + 117, + 118, + 121, + 122, + 123, + 124, + 125, + 126, + 127, + 128, + 129, + 130, + 131, + 132, + 133, + 134, + 135, + 136, + 137, + 138, + 139, + 140, + 141, + 142, + 143, + 144, + 145, + 146, + 147, + 148, + 149, + 150, + 151, + 152, + 153, + 154, + 155, + 156, + 157, + 158, + 159, + 160, + 161, + 162, + 164, + 165, + 166, + 167, + 168, + 169, + 170, + 171, + 172, + 173, + 174, + 175, + 176, + 177, + 178, + 179, + 180, + 181, + 182, + 183, + 184, + 185, + 186, + 187, + 188, + 189, + 190, + 191, + 192, + 
193, + 194, + 195, + 196, + 197, + 198, + 199, + 200, + 201, + 202, + 203, + 204, + 205, + 206, + 207, + 208, + 212, + 214, + 215, + 217, + 218, + 219, + 220, + 221, + 222, + 223, + 224, + 225, + 226, + 228, + 229, + 230, + 231, + 232, + 233, + 234, + 235, + 236, + 237, + 238, + 239, + 240, + 241, + 242, + 243, + 244, + 245, + 246, + 247, + 248, + 249, + 250, + 251, + 252, + 253, + 254, + 255, + 256, + 257, + 258, + 259, + 260, + 261, + 262, + 263, + 264, + 265, + 266, + 267, + 268, + 269, + 270, + 271, + 272, + 273, + 274, + 275, + 276, + 277, + 278, + 279, + 280, + 281, + 282, + 283, + 284, + 285, + 286, + 287, + 288, + 289, + 290, + 291, + 292, + 293, + 294, + 295, + 296, + 297, + 298, + 299, + 300, + 301, + 302, + 303, + 304, + 305, + 306, + 307, + 308, + 309, + 310, + 311, + 312, + 313, + 314, + 315, + 316, + 317, + 318, + 319, + 320, + 321, + 322, + 323, + 324, + 325, + 327, + 328, + 329, + 330, + 331, + 332, + 333, + 334, + 335, + 336, + 337, + 338, + 339, + 340, + 341, + 342, + 343, + 344, + 345, + 346, + 347, + 348, + 349, + 350, + 351, + 352, + 353, + 354, + 355, + 356, + 357, + 358, + 359, + 360, + 361, + 362, + 363, + 364, + 365, + 366, + 367, + 368, + 369, + 370, + 371, + 372, + 373, + 374, + 375, + 376, + 377, + 378, + 379, + 380, + 381, + 382, + 383, + 384, + 385, + 386, + 387, + 388, + 389, + 390, + 391, + 392, + 393, + 394, + 396, + 397, + 398, + 399, + 400, + 401, + 402, + 403, + 404, + 405, + 406, + 407, + 408, + 409, + 410, + 411, + 412, + 413, + 414, + 415, + 416, + 417, + 418, + 419, + 420, + 421, + 422, + 423, + 424, + 425, + 426, + 427, + 428, + 429, + 430, + 431, + 432, + 433, + 434, + 435, + 436, + 437, + 438, + 439, + 440, + 441, + 442, + 443, + 444, + 445, + 446, + 447, + 448, + 449, + 450, + 451, + 453, + 454, + 455, + 456, + 457, + 458, + 459, + 460, + 461, + 462, + 463, + 464, + 465, + 466, + 467, + 468, + 469, + 470, + 471, + 472, + 473, + 474, + 475, + 476, + 477, + 478, + 479, + 480, + 481, + 482, + 483, + 484, + 485, + 486, + 487, + 488, + 489, + 490, + 491, + 492, + 493, + 494, + 495, + 496, + 497, + 498, + 499, + 500, + 501, + 502, + 503, + 504, + 505, + 506, + 507, + 508, + 509, + 510, + 511, + 512, + 513, + 514, + 515, + 516, + 517, + 518, + 519, + 520, + 521, + 522, + 523, + 524, + 525, + 526, + 527, + 528, + 529, + 530, + 531, + 532, + 533, + 534, + 535, + 536, + 537, + 538, + 539, + 540, + 541, + 542, + 543, + 544, + 545, + 546, + 547, + 548, + 549, + 550, + 551, + 552, + 553, + 554, + 555, + 556, + 557, + 558, + 559, + 560, + 561, + 562, + 563, + 564, + 565, + 566, + 567, + 568, + 569, + 570, + 571, + 572, + 573, + 574, + 575, + 576, + 577, + 579, + 580, + 581, + 582, + 583, + 584, + 585, + 586, + 587, + 588, + 589, + 590, + 591, + 592, + 593, + 594, + 595, + 596, + 597, + 598, + 599, + 600, + 601, + 602, + 603, + 604, + 605, + 606, + 607, + 608, + 609, + 610, + 611, + 612, + 613, + 614, + 615, + 616, + 617, + 618, + 619, + 620, + 621, + 622, + 623, + 624, + 625, + 626, + 627, + 628, + 629, + 630, + 631, + 632, + 633, + 634, + 635, + 636, + 637, + 638, + 639, + 640, + 641, + 642, + 643, + 644, + 645, + 646, + 647, + 648, + 649, + 650, + 651, + 652, + 653, + 654, + 655, + 656, + 657, + 658, + 659, + 660, + 661, + 662, + 663, + 664, + 665, + 666, + 667, + 668, + 669, + 670, + 671, + 672, + 673, + 674, + 675, + 676, + 677, + 678, + 679, + 680, + 681, + 682, + 683, + 684, + 685, + 686, + 687, + 688, + 689, + 690, + 691, + 692, + 693, + 694, + 695, + 696, + 697, + 698, + 699, + 700, + 701, + 702, + 703, + 704, + 705, + 706, + 707, + 708, + 709, + 710, 
+ 711, + 712, + 713, + 714, + 715, + 716, + 717, + 718, + 719, + 720, + 721, + 722, + 723, + 724, + 725, + 726, + 727, + 728, + 729, + 730, + 731, + 732, + 733, + 734, + 735, + 736, + 737, + 738, + 739, + 740, + 741, + 742, + 743, + 744, + 745, + 746, + 747, + 748, + 749, + 750, + 751, + 752, + 753, + 754, + 755, + 756, + 757, + 758, + 759, + 760, + 761, + 762, + 763, + 764, + 765, + 766, + 767, + 768, + 769, + 770, + 771, + 772, + 773, + 774, + 775, + 776, + 777, + 778, + 779, + 780, + 781, + 782, + 783, + 784, + 785, + 786, + 787, + 788, + 789, + 790, + 791, + 792, + 793, + 794, + 795, + 796, + 797, + 798, + 799, + 800, + 801, + 802, + 803, + 804, + 805, + 806, + 807, + 808, + 809, + 810, + 811, + 812, + 813, + 814, + 815, + 816, + 817, + 818, + 819, + 820, + 821, + 822, + 823, + 824, + 825, + 826, + 827, + 828, + 829, + 830, + 831, + 832, + 833, + 834, + 835, + 836, + 837, + 838, + 839, + 840, + 841, + 842, + 843, + 844, + 845, + 846, + 847, + 848, + 849, + 850, + 851, + 852, + 853, + 854, + 855, + 856, + 857, + 858, + 859, + 860, + 861, + 862, + 863, + 864, + 865, + 866, + 867, + 868, + 869, + 870, + 871, + 872, + 873, + 874, + 875, + 876, + 877, + 878, + 879, + 880, + 881, + 882, + 883, + 884, + 885, + 886, + 887, + 888, + 889, + 890, + 891, + 892, + 893, + 894, + 895, + 896, + 897, + 898, + 899, + 900, + 901, + 902, + 903, + 904, + 905, + 906, + 907, + 908, + 909, + 911, + 912, + 913, + 914, + 915, + 916, + 917, + 918, + 919, + 920, + 921, + 922, + 923, + 924, + 925, + 926, + 927, + 928, + 929, + 930, + 931, + 932, + 933, + 934, + 935, + 936, + 937, + 938, + 939, + 940, + 941, + 942, + 943, + 944, + 945, + 946, + 947, + 948, + 949, + 950, + 951, + 952, + 953, + 954, + 955, + 956, + 957, + 958, + 960, + 961, + 962, + 963, + 965, + 966, + 967, + 968, + 969, + 970, + 971, + 972, + 973, + 974, + 978, + 979, + 980, + 981, + 982, + 983, + 984, + 985, + 986, + 987, + 988, + 989, + 990, + 991, + 992, + 993, + 994, + 995, + 996, + 997, + 998, + 999, + 1000, + 1001, + 1002, + 1003, + 1004, + 1005, + 1006, + 1007, + 1008, + 1009, + 1010, + 1011, + 1012, + 1013, + 1014, + 1015, + 1016, + 1017, + 1018, + 1019, + 1020, + 1021, + 1022, + 1023, + 1024, + 1025, + 1026, + 1027, + 1028, + 1029, + 1030, + 1031, + 1032, + 1033, + 1034, + 1035, + 1036, + 1037, + 1038, + 1039, + 1040, + 1041, + 1042, + 1043, + 1044, + 1045, + 1046, + 1047, + 1048, + 1049, + 1050, + 1051, + 1052, + 1053, + 1054, + 1055, + 1056, + 1057, + 1058, + 1059, + 1060, + 1061, + 1062, + 1063, + 1064, + 1065, + 1066, + 1067, + 1068, + 1069, + 1070, + 1071, + 1072, + 1073, + 1074, + 1075, + 1076, + 1077, + 1078, + 1079, + 1080, + 1081, + 1082, + 1083, + 1084, + 1085, + 1086, + 1087, + 1088, + 1089, + 1090, + 1091, + 1092, + 1093, + 1094, + 1095, + 1096, + 1097, + 1098, + 1099, + 1100, + 1101, + 1102, + 1103, + 1104, + 1105, + 1106, + 1107, + 1108, + 1109, + 1110, + 1111, + 1112, + 1113, + 1114, + 1115, + 1116, + 1117, + 1118, + 1119, + 1120, + 1121, + 1122, + 1123, + 1124, + 1125, + 1126, + 1127, + 1128, + 1129, + 1130, + 1131, + 1132, + 1133, + 1134, + 1135, + 1136, + 1137, + 1138, + 1139, + 1140, + 1141, + 1142, + 1143, + 1144, + 1145, + 1146, + 1147, + 1148, + 1149, + 1150, + 1151, + 1152, + 1153, + 1154, + 1155, + 1156, + 1157, + 1158, + 1159, + 1160, + 1161, + 1162, + 1163, + 1164, + 1165, + 1166, + 1167, + 1168, + 1169, + 1170, + 1171, + 1173, + 1174, + 1177, + 1179, + 1180, + 1181, + 1182, + 1183, + 1184, + 1185, + 1186, + 1187, + 1188, + 1189, + 1190, + 1191, + 1192, + 1193, + 1195, + 1196, + 1197, + 1198, + 1199, + 1200, + 
1201, + 1202, + 1203, + 1204, + 1205, + 1206, + 1207, + 1208, + 1209, + 1210, + 1211, + 1212, + 1213, + 1214, + 1215, + 1216, + 1217, + 1218, + 1219, + 1220, + 1221, + 1222, + 1223, + 1224, + 1225, + 1226, + 1227, + 1228, + 1229, + 1230, + 1231, + 1232, + 1233, + 1234, + 1235, + 1236, + 1237, + 1238, + 1239, + 1240, + 1241, + 1242, + 1244, + 1245, + 1246, + 1247, + 1248, + 1249, + 1250, + 1251, + 1252, + 1253, + 1254, + 1255, + 1256, + 1257, + 1258, + 1259, + 1260, + 1261, + 1262, + 1263, + 1264, + 1265, + 1266, + 1267, + 1268, + 1269, + 1270, + 1271, + 1272, + 1273, + 1274, + 1275, + 1276, + 1277, + 1278, + 1279, + 1280, + 1281, + 1282, + 1283, + 1284, + 1285, + 1286, + 1287, + 1288, + 1289, + 1290, + 1291, + 1292, + 1293, + 1294, + 1295, + 1296, + 1297, + 1298, + 1299, + 1300, + 1301, + 1302, + 1303, + 1304, + 1305, + 1306, + 1307, + 1308, + 1309, + 1310, + 1311, + 1312, + 1313, + 1314, + 1315, + 1316, + 1317, + 1318, + 1319, + 1320, + 1321, + 1322, + 1323, + 1324, + 1325, + 1326, + 1327, + 1328, + 1329, + 1330, + 1331, + 1332, + 1333, + 1334, + 1335, + 1336, + 1337, + 1338, + 1339, + 1340, + 1341, + 1342, + 1343, + 1344, + 1345, + 1346, + 1347, + 1348, + 1349, + 1350, + 1351, + 1352, + 1353, + 1354, + 1355, + 1356, + 1357, + 1358, + 1359, + 1360, + 1361, + 1362, + 1363, + 1364, + 1365, + 1366, + 1367, + 1368, + 1369, + 1370, + 1371, + 1372, + 1373, + 1374, + 1375, + 1376, + 1377, + 1378, + 1379, + 1380, + 1381, + 1382, + 1383, + 1384, + 1385, + 1386, + 1387, + 1388, + 1389, + 1390, + 1391, + 1392, + 1393, + 1394, + 1395, + 1396, + 1397, + 1398, + 1399, + 1400, + 1401, + 1402, + 1403, + 1404, + 1405, + 1406, + 1407, + 1408, + 1409, + 1410, + 1411, + 1412, + 1413, + 1414, + 1415, + 1416, + 1417, + 1418, + 1419, + 1420, + 1421, + 1422, + 1423, + 1424, + 1425, + 1426, + 1427, + 1428, + 1429, + 1430, + 1431, + 1432, + 1433, + 1434, + 1435, + 1436, + 1437, + 1438, + 1439, + 1440, + 1441, + 1442, + 1443, + 1444, + 1445, + 1446, + 1447, + 1448, + 1449, + 1450, + 1451, + 1452, + 1453, + 1454, + 1455, + 1456, + 1457, + 1458, + 1459, + 1460, + 1461, + 1462, + 1463, + 1464, + 1465, + 1466, + 1467, + 1468, + 1469, + 1470, + 1471, + 1472, + 1473, + 1474, + 1475, + 1476, + 1477, + 1478, + 1479, + 1480, + 1481, + 1482, + 1483, + 1484, + 1485, + 1486, + 1487, + 1488, + 1489, + 1490, + 1491, + 1492, + 1493, + 1494, + 1495, + 1496, + 1497, + 1498, + 1499, + 1500, + 1501, + 1502, + 1503, + 1504, + 1505, + 1506, + 1507, + 1508, + 1509, + 1510, + 1511, + 1512, + 1513, + 1514, + 1515, + 1516, + 1517, + 1518, + 1519, + 1520, + 1521, + 1522, + 1523, + 1524, + 1525, + 1526, + 1527, + 1528, + 1529, + 1530, + 1531, + 1532, + 1533, + 1534, + 1535, + 1536, + 1537, + 1538, + 1539, + 1540, + 1541, + 1542, + 1543, + 1544, + 1545, + 1546, + 1547, + 1548, + 1549, + 1550, + 1551, + 1552, + 1553, + 1554, + 1555, + 1556, + 1557, + 1558, + 1559, + 1560, + 1561, + 1562, + 1563, + 1564, + 1565, + 1566, + 1567, + 1568, + 1569, + 1570, + 1571, + 1572, + 1573, + 1574, + 1575, + 1576, + 1577, + 1578, + 1579, + 1580, + 1581, + 1582, + 1583, + 1584, + 1585, + 1586, + 1587, + 1588, + 1589, + 1590, + 1591, + 1592, + 1593, + 1594, + 1595, + 1596, + 1597, + 1598, + 1599, + 1600, + 1601, + 1602, + 1603, + 1604, + 1605, + 1606, + 1607, + 1608, + 1609, + 1610, + 1611, + 1612, + 1613, + 1614, + 1615, + 1616, + 1617, + 1618, + 1619, + 1620, + 1621, + 1622, + 1623, + 1624, + 1625, + 1626, + 1627, + 1628, + 1629, + 1630, + 1631, + 1632, + 1633, + 1634, + 1635, + 1636, + 1637, + 1638, + 1639, + 1640, + 1641, + 1642, + 1643, + 1644, + 1645, + 
1646, + 1647, + 1648, + 1649, + 1650, + 1651, + 1652, + 1653, + 1654, + 1655, + 1656, + 1657, + 1658, + 1659, + 1660, + 1661, + 1662, + 1663, + 1664, + 1665, + 1666, + 1667, + 1668, + 1669, + 1670, + 1671, + 1672, + 1673, + 1674, + 1675, + 1676, + 1677, + 1678, + 1679, + 1680, + 1681, + 1682, + 1683, + 1684, + 1685, + 1686, + 1687, + 1688, + 1689, + 1690, + 1691, + 1692, + 1693, + 1694, + 1695, + 1696, + 1697, + 1698, + 1699, + 1700, + 1701, + 1702, + 1703, + 1704, + 1705, + 1706, + 1707, + 1708, + 1709, + 1710, + 1711, + 1712, + 1713, + 1714, + 1715, + 1716, + 1717, + 1718, + 1719, + 1720, + 1721, + 1722, + 1723, + 1724, + 1725, + 1726, + 1727, + 1728, + 1729, + 1730, + 1731, + 1732, + 1733, + 1734, + 1735, + 1736, + 1737, + 1738, + 1740, + 1741, + 1742, + 1743, + 1744, + 1745, + 1746, + 1747, + 1748, + 1749, + 1750, + 1751, + 1752, + 1753, + 1754, + 1755, + 1756, + 1757, + 1758, + 1759, + 1760, + 1761, + 1762, + 1763, + 1764, + 1765, + 1766, + 1767, + 1768, + 1769, + 1770, + 1771, + 1772, + 1773, + 1774, + 1775, + 1776, + 1777, + 1778, + 1779, + 1780, + 1781, + 1782, + 1783, + 1784, + 1785, + 1786, + 1787, + 1788, + 1789, + 1790, + 1791, + 1792, + 1793, + 1794, + 1795, + 1841, + 1864, + 1865, + 1866, + 1867, + 1868, + 1869, + 1871, + 1872, + 1873, + 1874, + 1875, + 1876, + 1877, + 1878, + 1879, + 1880, + 1881, + 1882, + 1883, + 1884, + 1885, + 1886, + 1887, + 1888, + 1889, + 1890, + 1891, + 1892, + 1893, + 1894, + 1895, + 1896, + 1897, + 1898, + 1899, + 1900, + 1901, + 1902, + 1903, + 1904, + 1905, + 1906, + 1907, + 1908, + 1909, + 1910, + 1911, + 1912, + 1913, + 1914, + 1915, + 1916, + 1917, + 1918, + 1920, + 1921, + 1922, + 1923, + 1924, + 1925, + 1926, + 1927, + 1928, + 1929, + 1930, + 1931, + 1932, + 1933, + 1934, + 1935, + 1936, + 1937, + 1938, + 1939, + 1940, + 1941, + 1942, + 1943, + 1944, + 1945, + 1946, + 1947, + 1948, + 1949, + 1950, + 1951, + 1952, + 1953, + 1954, + 1955, + 1956, + 1957, + 1958, + 1959, + 1962, + 1963, + 1964, + 1966, + 1968, + 1969, + 1970, + 1971, + 1972, + 1973, + 1974, + 1975, + 1976, + 1977, + 1978, + 1979, + 1980, + 1981, + 1982, + 1983, + 1984, + 1985, + 1986, + 1987, + 1988, + 1989, + 1990, + 1991, + 1992, + 1993, + 1994, + 1995, + 1996, + 1997, + 1998, + 1999, + 2000, + 2001, + 2002, + 2004, + 2005, + 2006, + 2007, + 2008, + 2009, + 2010, + 2011 + ] +} \ No newline at end of file diff --git a/data/statistics/statistics_modified_v1.json b/data/statistics/statistics_modified_v1.json new file mode 100644 index 0000000000000000000000000000000000000000..2eb46afa77a3dd46b8c8c41eb97939b606001550 --- /dev/null +++ b/data/statistics/statistics_modified_v1.json @@ -0,0 +1,615 @@ +{ + "trans_mean": [ + 0.02, + 0.0, + 14.79 + ], + "trans_std": [ + 0.10, + 0.10, + 2.65 + ], + "flength_mean": [ + 2169.0 + ], + "flength_std": [ + 448.0 + ], + "pose_mean": [ + [ + [ + 0.44, + 0.0, + -0.0 + ], + [ + 0.0, + 0.0, + -1.0 + ], + [ + -0.0, + 0.44, + 0.0 + ] + ], + [ + [ + 0.97, + -0.0, + -0.08 + ], + [ + 0.0, + 0.98, + 0.0 + ], + [ + 0.08, + -0.0, + 0.98 + ] + ], + [ + [ + 0.98, + 0.0, + 0.01 + ], + [ + -0.0, + 0.99, + 0.0 + ], + [ + -0.01, + 0.0, + 0.98 + ] + ], + [ + [ + 0.98, + -0.0, + -0.03 + ], + [ + 0.0, + 0.99, + 0.0 + ], + [ + 0.04, + -0.0, + 0.98 + ] + ], + [ + [ + 0.98, + 0.0, + 0.02 + ], + [ + -0.0, + 0.99, + -0.0 + ], + [ + -0.02, + -0.0, + 0.98 + ] + ], + [ + [ + 0.99, + 0.0, + -0.0 + ], + [ + -0.0, + 0.99, + -0.0 + ], + [ + 0.0, + 0.0, + 0.99 + ] + ], + [ + [ + 0.99, + 0.0, + 0.03 + ], + [ + 0.0, + 0.99, + -0.0 + ], + [ + -0.03, + 0.0, + 0.99 + ] + ], + [ + [ 
+ 0.95, + -0.05, + 0.04 + ], + [ + 0.05, + 0.98, + -0.01 + ], + [ + -0.03, + 0.01, + 0.96 + ] + ], + [ + [ + 0.91, + -0.01, + -0.19 + ], + [ + -0.01, + 0.98, + -0.05 + ], + [ + 0.19, + 0.03, + 0.91 + ] + ], + [ + [ + 0.85, + -0.04, + 0.23 + ], + [ + -0.0, + 0.99, + 0.07 + ], + [ + -0.23, + -0.06, + 0.85 + ] + ], + [ + [ + 0.93, + 0.0, + 0.16 + ], + [ + -0.01, + 0.99, + 0.01 + ], + [ + -0.16, + -0.02, + 0.93 + ] + ], + [ + [ + 0.95, + 0.05, + 0.03 + ], + [ + -0.05, + 0.98, + 0.02 + ], + [ + -0.03, + -0.01, + 0.96 + ] + ], + [ + [ + 0.91, + 0.01, + -0.19 + ], + [ + 0.02, + 0.98, + 0.05 + ], + [ + 0.2, + -0.03, + 0.91 + ] + ], + [ + [ + 0.84, + 0.03, + 0.24 + ], + [ + 0.01, + 0.99, + -0.06 + ], + [ + -0.24, + 0.07, + 0.84 + ] + ], + [ + [ + 0.93, + -0.0, + 0.18 + ], + [ + 0.01, + 0.99, + -0.01 + ], + [ + -0.18, + 0.02, + 0.93 + ] + ], + [ + [ + 0.95, + -0.0, + 0.01 + ], + [ + -0.0, + 0.96, + 0.0 + ], + [ + -0.0, + -0.0, + 0.99 + ] + ], + [ + [ + 0.93, + 0.0, + -0.11 + ], + [ + -0.0, + 0.97, + -0.0 + ], + [ + 0.12, + 0.0, + 0.95 + ] + ], + [ + [ + 0.96, + -0.04, + -0.06 + ], + [ + 0.03, + 0.98, + -0.02 + ], + [ + 0.06, + 0.01, + 0.96 + ] + ], + [ + [ + 0.96, + 0.05, + 0.04 + ], + [ + -0.05, + 0.98, + -0.05 + ], + [ + -0.05, + 0.05, + 0.96 + ] + ], + [ + [ + 0.96, + -0.0, + -0.09 + ], + [ + -0.0, + 0.99, + 0.01 + ], + [ + 0.09, + -0.01, + 0.96 + ] + ], + [ + [ + 0.96, + 0.0, + 0.06 + ], + [ + -0.02, + 0.98, + 0.05 + ], + [ + -0.05, + -0.06, + 0.96 + ] + ], + [ + [ + 0.96, + 0.04, + -0.07 + ], + [ + -0.03, + 0.98, + 0.02 + ], + [ + 0.07, + -0.01, + 0.96 + ] + ], + [ + [ + 0.96, + -0.05, + 0.04 + ], + [ + 0.04, + 0.98, + 0.05 + ], + [ + -0.05, + -0.04, + 0.97 + ] + ], + [ + [ + 0.96, + -0.0, + -0.09 + ], + [ + 0.0, + 0.99, + -0.01 + ], + [ + 0.09, + 0.01, + 0.96 + ] + ], + [ + [ + 0.96, + -0.0, + 0.06 + ], + [ + 0.02, + 0.98, + -0.05 + ], + [ + -0.05, + 0.06, + 0.96 + ] + ], + [ + [ + 0.73, + 0.0, + -0.4 + ], + [ + -0.0, + 0.98, + 0.0 + ], + [ + 0.39, + 0.0, + 0.73 + ] + ], + [ + [ + 0.95, + -0.0, + -0.07 + ], + [ + 0.0, + 0.99, + -0.0 + ], + [ + 0.07, + 0.0, + 0.95 + ] + ], + [ + [ + 0.98, + 0.0, + -0.09 + ], + [ + -0.0, + 0.99, + -0.0 + ], + [ + 0.09, + 0.0, + 0.98 + ] + ], + [ + [ + 0.99, + -0.0, + 0.03 + ], + [ + 0.0, + 0.99, + -0.0 + ], + [ + -0.03, + 0.0, + 0.99 + ] + ], + [ + [ + 0.96, + -0.0, + 0.1 + ], + [ + 0.0, + 0.98, + -0.0 + ], + [ + -0.09, + 0.0, + 0.96 + ] + ], + [ + [ + 0.79, + -0.01, + 0.21 + ], + [ + 0.01, + 0.96, + 0.0 + ], + [ + -0.2, + 0.0, + 0.82 + ] + ], + [ + [ + 0.89, + -0.0, + 0.07 + ], + [ + 0.0, + 0.98, + 0.0 + ], + [ + -0.07, + 0.0, + 0.9 + ] + ], + [ + [ + 0.96, + -0.0, + 0.09 + ], + [ + 0.0, + 0.99, + -0.0 + ], + [ + -0.1, + 0.0, + 0.96 + ] + ], + [ + [ + 0.93, + -0.09, + -0.07 + ], + [ + 0.1, + 0.93, + 0.03 + ], + [ + 0.03, + -0.06, + 0.95 + ] + ], + [ + [ + 0.86, + 0.1, + -0.37 + ], + [ + -0.12, + 0.94, + 0.01 + ], + [ + 0.35, + 0.05, + 0.88 + ] + ] + ] +} diff --git a/datasets/test_image_crops/201030094143-stock-rhodesian-ridgeback-super-tease.jpg b/datasets/test_image_crops/201030094143-stock-rhodesian-ridgeback-super-tease.jpg new file mode 100644 index 0000000000000000000000000000000000000000..c3120182f24a6246e2fca3fdccbb905a30b70cc8 Binary files /dev/null and b/datasets/test_image_crops/201030094143-stock-rhodesian-ridgeback-super-tease.jpg differ diff --git a/datasets/test_image_crops/Akita-standing-outdoors-in-the-summer-400x267.jpg b/datasets/test_image_crops/Akita-standing-outdoors-in-the-summer-400x267.jpg new file mode 100644 index 
0000000000000000000000000000000000000000..73092b851ac8b555406365a34ee91a8d8069b9bc Binary files /dev/null and b/datasets/test_image_crops/Akita-standing-outdoors-in-the-summer-400x267.jpg differ diff --git a/datasets/test_image_crops/image_n02089078-black-and-tan_coonhound_n02089078_3810.png b/datasets/test_image_crops/image_n02089078-black-and-tan_coonhound_n02089078_3810.png new file mode 100644 index 0000000000000000000000000000000000000000..16b538f73b30ba8bd00816a9804803c62b12ad6a Binary files /dev/null and b/datasets/test_image_crops/image_n02089078-black-and-tan_coonhound_n02089078_3810.png differ diff --git a/gradio_demo/barc_demo_v3.py b/gradio_demo/barc_demo_v3.py new file mode 100644 index 0000000000000000000000000000000000000000..84c8b11541419092853bdbaf9b96b2b406e1dc8a --- /dev/null +++ b/gradio_demo/barc_demo_v3.py @@ -0,0 +1,268 @@ +# python gradio_demo/barc_demo_v3.py + +import numpy as np +import os +import glob +import torch +from torch.utils.data import DataLoader +import torchvision +from torchvision.models.detection.faster_rcnn import FastRCNNPredictor +import torchvision.transforms as T +import cv2 +from matplotlib import pyplot as plt +from PIL import Image + +import gradio as gr + + + +import sys +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '../', 'src')) +from stacked_hourglass.datasets.imgcropslist import ImgCrops +from combined_model.train_main_image_to_3d_withbreedrel import do_visual_epoch +from combined_model.model_shape_v7 import ModelImageTo3d_withshape_withproj + +from configs.barc_cfg_defaults import get_cfg_global_updated + + + +def get_prediction(model, img_path_or_img, confidence=0.5): + """ + see https://haochen23.github.io/2020/04/object-detection-faster-rcnn.html#.YsMCm4TP3-g + get_prediction + parameters: + - img_path - path of the input image + - confidence - threshold value for prediction score + method: + - Image is obtained from the image path + - the image is converted to image tensor using PyTorch's Transforms + - image is passed through the model to get the predictions + - class, box coordinates are obtained, but only prediction score > threshold + are chosen. + + """ + if isinstance(img_path_or_img, str): + img = Image.open(img_path_or_img).convert('RGB') + else: + img = img_path_or_img + transform = T.Compose([T.ToTensor()]) + img = transform(img) + pred = model([img]) + # pred_class = [COCO_INSTANCE_CATEGORY_NAMES[i] for i in list(pred[0]['labels'].numpy())] + pred_class = list(pred[0]['labels'].numpy()) + pred_boxes = [[(int(i[0]), int(i[1])), (int(i[2]), int(i[3]))] for i in list(pred[0]['boxes'].detach().numpy())] + pred_score = list(pred[0]['scores'].detach().numpy()) + try: + pred_t = [pred_score.index(x) for x in pred_score if x>confidence][-1] + pred_boxes = pred_boxes[:pred_t+1] + pred_class = pred_class[:pred_t+1] + return pred_boxes, pred_class, pred_score + except: + print('no bounding box with a score that is high enough found! 
-> work on full image') + return None, None, None + +def detect_object(model, img_path_or_img, confidence=0.5, rect_th=2, text_size=0.5, text_th=1): + """ + see https://haochen23.github.io/2020/04/object-detection-faster-rcnn.html#.YsMCm4TP3-g + object_detection_api + parameters: + - img_path_or_img - path of the input image + - confidence - threshold value for prediction score + - rect_th - thickness of bounding box + - text_size - size of the class label text + - text_th - thichness of the text + method: + - prediction is obtained from get_prediction method + - for each prediction, bounding box is drawn and text is written + with opencv + - the final image is displayed + """ + boxes, pred_cls, pred_scores = get_prediction(model, img_path_or_img, confidence) + if isinstance(img_path_or_img, str): + img = cv2.imread(img_path_or_img) + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + else: + img = img_path_or_img + is_first = True + bbox = None + if boxes is not None: + for i in range(len(boxes)): + cls = pred_cls[i] + if cls == 18 and bbox is None: + cv2.rectangle(img, boxes[i][0], boxes[i][1],color=(0, 255, 0), thickness=rect_th) + # cv2.putText(img, pred_cls[i], boxes[i][0], cv2.FONT_HERSHEY_SIMPLEX, text_size, (0,255,0),thickness=text_th) + cv2.putText(img, str(pred_scores[i]), boxes[i][0], cv2.FONT_HERSHEY_SIMPLEX, text_size, (0,255,0),thickness=text_th) + bbox = boxes[i] + return img, bbox + + + +def run_bbox_inference(input_image): + model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True) + model.eval() + out_path = os.path.join(cfg.paths.ROOT_OUT_PATH, 'gradio_examples', 'test2.png') + img, bbox = detect_object(model=model, img_path_or_img=input_image, confidence=0.5) + fig = plt.figure() # plt.figure(figsize=(20,30)) + plt.imsave(out_path, img) + return img, bbox + + + + + +def run_barc_inference(input_image, bbox=None): + + # load configs + cfg = get_cfg_global_updated() + + model_file_complete = os.path.join(cfg.paths.ROOT_CHECKPOINT_PATH, 'barc_complete', 'model_best.pth.tar') + + + + # Select the hardware device to use for inference. + if torch.cuda.is_available() and cfg.device=='cuda': + device = torch.device('cuda', torch.cuda.current_device()) + # torch.backends.cudnn.benchmark = True + else: + device = torch.device('cpu') + + path_model_file_complete = os.path.join(cfg.paths.ROOT_CHECKPOINT_PATH, model_file_complete) + + # Disable gradient calculations. 
+ torch.set_grad_enabled(False) + + # prepare complete model + complete_model = ModelImageTo3d_withshape_withproj( + num_stage_comb=cfg.params.NUM_STAGE_COMB, num_stage_heads=cfg.params.NUM_STAGE_HEADS, \ + num_stage_heads_pose=cfg.params.NUM_STAGE_HEADS_POSE, trans_sep=cfg.params.TRANS_SEP, \ + arch=cfg.params.ARCH, n_joints=cfg.params.N_JOINTS, n_classes=cfg.params.N_CLASSES, \ + n_keyp=cfg.params.N_KEYP, n_bones=cfg.params.N_BONES, n_betas=cfg.params.N_BETAS, n_betas_limbs=cfg.params.N_BETAS_LIMBS, \ + n_breeds=cfg.params.N_BREEDS, n_z=cfg.params.N_Z, image_size=cfg.params.IMG_SIZE, \ + silh_no_tail=cfg.params.SILH_NO_TAIL, thr_keyp_sc=cfg.params.KP_THRESHOLD, add_z_to_3d_input=cfg.params.ADD_Z_TO_3D_INPUT, + n_segbps=cfg.params.N_SEGBPS, add_segbps_to_3d_input=cfg.params.ADD_SEGBPS_TO_3D_INPUT, add_partseg=cfg.params.ADD_PARTSEG, n_partseg=cfg.params.N_PARTSEG, \ + fix_flength=cfg.params.FIX_FLENGTH, structure_z_to_betas=cfg.params.STRUCTURE_Z_TO_B, structure_pose_net=cfg.params.STRUCTURE_POSE_NET, + nf_version=cfg.params.NF_VERSION) + + # load trained model + print(path_model_file_complete) + assert os.path.isfile(path_model_file_complete) + print('Loading model weights from file: {}'.format(path_model_file_complete)) + checkpoint_complete = torch.load(path_model_file_complete) + state_dict_complete = checkpoint_complete['state_dict'] + complete_model.load_state_dict(state_dict_complete, strict=False) + complete_model = complete_model.to(device) + + save_imgs_path = os.path.join(cfg.paths.ROOT_OUT_PATH, 'gradio_examples') + if not os.path.exists(save_imgs_path): + os.makedirs(save_imgs_path) + + input_image_list = [input_image] + if bbox is not None: + input_bbox_list = [bbox] + else: + input_bbox_list = None + val_dataset = ImgCrops(image_list=input_image_list, bbox_list=input_bbox_list, dataset_mode='complete') + test_name_list = val_dataset.test_name_list + val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False, + num_workers=0, pin_memory=True, drop_last=False) + + # run visual evaluation + # remark: take ACC_Joints and DATA_INFO from StanExt as this is the training dataset + all_results = do_visual_epoch(val_loader, complete_model, device, + ImgCrops.DATA_INFO, + weight_dict=None, + acc_joints=ImgCrops.ACC_JOINTS, + save_imgs_path=None, # save_imgs_path, + metrics='all', + test_name_list=test_name_list, + render_all=cfg.params.RENDER_ALL, + pck_thresh=cfg.params.PCK_THRESH, + return_results=True) + + mesh = all_results[0]['mesh_posed'] + result_path = os.path.join(save_imgs_path, test_name_list[0] + '_z') + + mesh.apply_transform([[-1, 0, 0, 0], + [0, -1, 0, 0], + [0, 0, 1, 1], + [0, 0, 0, 1]]) + mesh.export(file_obj=result_path + '.glb') + result_gltf = result_path + '.glb' + return [result_gltf, result_gltf] + + + + + + +def run_complete_inference(input_image): + + output_interm_image, output_interm_bbox = run_bbox_inference(input_image.copy()) + + print(output_interm_bbox) + + # output_image = run_barc_inference(input_image) + output_image = run_barc_inference(input_image, output_interm_bbox) + + return output_image + + + + +# demo = gr.Interface(run_barc_inference, gr.Image(), "image") +# demo = gr.Interface(run_complete_inference, gr.Image(), "image") + + + +# see: https://huggingface.co/spaces/radames/PIFu-Clothed-Human-Digitization/blob/main/PIFu/spaces.py + +description = ''' +# BARC + +#### Project Page +* https://barc.is.tue.mpg.de/ + +#### Description +This is a demo for BARC. 
While BARC is trained on image crops, this demo uses a pretrained Faster-RCNN in order to get bounding boxes for the dogs.
+To see your result you may have to wait a minute or two, please be patient.
+
+<details>
+<summary>More</summary>
+
+#### Citation
+
+```
+@inproceedings{BARC:2022,
+  title = {{BARC}: Learning to Regress {3D} Dog Shape from Images by Exploiting Breed Information},
+  author = {Rueegg, Nadine and Zuffi, Silvia and Schindler, Konrad and Black, Michael J.},
+  booktitle = {Proceedings IEEE Conf. on Computer Vision and Pattern Recognition (CVPR)},
+  year = {2022}
+}
+```
+
+</details>
+'''
+
+examples = sorted(glob.glob(os.path.join(os.path.dirname(__file__), '../', 'datasets', 'test_image_crops', '*.jpg')) + glob.glob(os.path.join(os.path.dirname(__file__), '../', 'datasets', 'test_image_crops', '*.png')))
+
+
+demo = gr.Interface(
+    fn=run_complete_inference,
+    description=description,
+    # inputs=gr.Image(type="filepath", label="Input Image"),
+    inputs=gr.Image(label="Input Image"),
+    outputs=[
+        gr.Model3D(
+            clear_color=[0.0, 0.0, 0.0, 0.0], label="3D Model"),
+        gr.File(label="Download 3D Model")
+    ],
+    examples=examples,
+    thumbnail="barc_thumbnail.png",
+    allow_flagging="never",
+    cache_examples=True
+)
+
+
+
+demo.launch(share=True)
\ No newline at end of file
diff --git a/src/bps_2d/bps_for_segmentation.py b/src/bps_2d/bps_for_segmentation.py
new file mode 100644
index 0000000000000000000000000000000000000000..ef7382c5e875f878b296321fed6e0c46b037781e
--- /dev/null
+++ b/src/bps_2d/bps_for_segmentation.py
@@ -0,0 +1,114 @@
+
+# code idea from https://github.com/sergeyprokudin/bps
+
+import os
+import numpy as np
+from PIL import Image
+import time
+import scipy
+import scipy.spatial
+import pymp
+
+
+#####################
+QUERY_POINTS = np.asarray([30, 34, 31, 55, 29, 84, 35, 108, 34, 145, 29, 171, 27,
+    196, 29, 228, 58, 35, 61, 55, 57, 83, 56, 109, 63, 148, 58, 164, 57, 197, 60,
+    227, 81, 26, 87, 58, 85, 87, 89, 117, 86, 142, 89, 172, 84, 197, 88, 227, 113,
+    32, 116, 58, 112, 88, 118, 113, 109, 147, 114, 173, 119, 201, 113, 229, 139,
+    29, 141, 59, 142, 93, 139, 117, 146, 147, 141, 173, 142, 201, 143, 227, 170,
+    26, 173, 59, 166, 90, 174, 117, 176, 141, 169, 175, 167, 198, 172, 227, 198,
+    30, 195, 59, 204, 85, 198, 116, 195, 140, 198, 175, 194, 193, 199, 227, 221,
+    26, 223, 57, 227, 83, 227, 113, 227, 140, 226, 173, 230, 196, 228, 229]).reshape((64, 2))
+#####################
+
+class SegBPS():
+
+    def __init__(self, query_points=QUERY_POINTS, size=256):
+        self.size = size
+        self.query_points = query_points
+        row, col = np.indices((self.size, self.size))
+        self.indices_rc = np.stack((row, col), axis=2)    # (256, 256, 2)
+        self.pts_aranged = np.arange(64)
+        return
+
+    def _do_kdtree(self, combined_x_y_arrays, points):
+        # see https://stackoverflow.com/questions/10818546/finding-index-of-nearest-
+        # point-in-numpy-arrays-of-x-and-y-coordinates
+        mytree = scipy.spatial.cKDTree(combined_x_y_arrays)
+        dist, indexes = mytree.query(points)
+        return indexes
+
+    def calculate_bps_points(self, seg, thr=0.5, vis=False, out_path=None):
+        # seg: input segmentation image of shape (256, 256) with values between 0 and 1
+        query_val = seg[self.query_points[:, 0], self.query_points[:, 1]]
+        pts_fg = self.pts_aranged[query_val>=thr]
+        pts_bg = self.pts_aranged[query_val<thr]
+        candidate_inds_fg = self.indices_rc[seg>=thr]
+        candidate_inds_bg = self.indices_rc[seg<thr]
+        if candidate_inds_bg.shape[0] == 0:
+            candidate_inds_bg = np.ones((1, 2)) * 128    # np.zeros((1, 2))
+        if candidate_inds_fg.shape[0] == 0:
+            candidate_inds_fg = np.ones((1, 2)) * 128    # np.zeros((1, 2))
+        # calculate nearest points
+        all_nearest_points = np.zeros((64, 2))
+        all_nearest_points[pts_fg, :] = candidate_inds_bg[self._do_kdtree(candidate_inds_bg, self.query_points[pts_fg, :]), :]
+        all_nearest_points[pts_bg, :] = candidate_inds_fg[self._do_kdtree(candidate_inds_fg, self.query_points[pts_bg, :]), :]
+        all_nearest_points_01 = all_nearest_points / 255.
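+        # note: all_nearest_points_01 has shape (64, 2); for every query point it holds the
+        # (row, col) of the closest pixel on the opposite side of the silhouette (nearest
+        # background pixel for foreground queries and vice versa), divided by 255 so that
+        # the values lie in [0, 1]; flattened, this gives a 128-dimensional BPS encoding of the mask.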
+ if vis: + self.visualize_result(seg, all_nearest_points, out_path=out_path) + return all_nearest_points_01 + + def calculate_bps_points_batch(self, seg_batch, thr=0.5, vis=False, out_path=None): + # seg_batch: input segmentation image of shape (bs, 256, 256) with values between 0 and 1 + bs = seg_batch.shape[0] + all_nearest_points_01_batch = np.zeros((bs, self.query_points.shape[0], 2)) + for ind in range(0, bs): # 0.25 + seg = seg_batch[ind, :, :] + all_nearest_points_01 = self.calculate_bps_points(seg, thr=thr, vis=vis, out_path=out_path) + all_nearest_points_01_batch[ind, :, :] = all_nearest_points_01 + return all_nearest_points_01_batch + + def visualize_result(self, seg, all_nearest_points, out_path=None): + import matplotlib as mpl + mpl.use('Agg') + import matplotlib.pyplot as plt + # img: (256, 256, 3) + img = (np.stack((seg, seg, seg), axis=2) * 155).astype(np.int) + if out_path is None: + ind_img = 0 + out_path = '../test_img' + str(ind_img) + '.png' + fig, ax = plt.subplots() + plt.imshow(img) + plt.gca().set_axis_off() + plt.subplots_adjust(top = 1, bottom = 0, right = 1, left = 0, hspace = 0, wspace = 0) + plt.margins(0,0) + ratio_in_out = 1 # 255 + for idx, (y, x) in enumerate(self.query_points): + x = int(x*ratio_in_out) + y = int(y*ratio_in_out) + plt.scatter([x], [y], marker="x", s=50) + x2 = int(all_nearest_points[idx, 1]) + y2 = int(all_nearest_points[idx, 0]) + plt.scatter([x2], [y2], marker="o", s=50) + plt.plot([x, x2], [y, y2]) + plt.savefig(out_path, bbox_inches='tight', pad_inches=0) + plt.close() + return + + + + + +if __name__ == "__main__": + ind_img = 2 # 4 + path_seg_top = '...../pytorch-stacked-hourglass/results/dogs_hg8_ks_24_v1/test/' + path_seg = os.path.join(path_seg_top, 'seg_big_' + str(ind_img) + '.png') + img = np.asarray(Image.open(path_seg)) + # min is 0.004, max is 0.9 + # low values are background, high values are foreground + seg = img[:, :, 1] / 255. 
+ # calculate points + bps = SegBPS() + bps.calculate_bps_points(seg, thr=0.5, vis=False, out_path=None) + + diff --git a/src/combined_model/loss_image_to_3d_withbreedrel.py b/src/combined_model/loss_image_to_3d_withbreedrel.py new file mode 100644 index 0000000000000000000000000000000000000000..5414f8443d9df4aac7ceb409380474d9d69ef27b --- /dev/null +++ b/src/combined_model/loss_image_to_3d_withbreedrel.py @@ -0,0 +1,277 @@ + + +import torch +import numpy as np +import pickle as pkl + +import os +import sys +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..', 'src')) +# from priors.pose_prior_35 import Prior +# from priors.tiger_pose_prior.tiger_pose_prior import GaussianMixturePrior +from priors.normalizing_flow_prior.normalizing_flow_prior import NormalizingFlowPrior +from priors.shape_prior import ShapePrior +from lifting_to_3d.utils.geometry_utils import rot6d_to_rotmat, batch_rot2aa +from configs.SMAL_configs import UNITY_SMAL_SHAPE_PRIOR_DOGS + +class Loss(torch.nn.Module): + def __init__(self, data_info, nf_version=None): + super(Loss, self).__init__() + self.criterion_regr = torch.nn.MSELoss() # takes the mean + self.criterion_class = torch.nn.CrossEntropyLoss() + self.data_info = data_info + self.register_buffer('keypoint_weights', torch.tensor(data_info.keypoint_weights)[None, :]) + self.l_anchor = None + self.l_pos = None + self.l_neg = None + + if nf_version is not None: + self.normalizing_flow_pose_prior = NormalizingFlowPrior(nf_version=nf_version) + self.shape_prior = ShapePrior(UNITY_SMAL_SHAPE_PRIOR_DOGS) + self.criterion_triplet = torch.nn.TripletMarginLoss(margin=1) + + # load 3d data for the unity dogs (an optional shape prior for 11 breeds) + with open(UNITY_SMAL_SHAPE_PRIOR_DOGS, 'rb') as f: + data = pkl.load(f) + dog_betas_unity = data['dogs_betas'] + self.dog_betas_unity = {29: torch.tensor(dog_betas_unity[0, :]).float(), + 91: torch.tensor(dog_betas_unity[1, :]).float(), + 84: torch.tensor(0.5*dog_betas_unity[3, :] + 0.5*dog_betas_unity[14, :]).float(), + 85: torch.tensor(dog_betas_unity[5, :]).float(), + 28: torch.tensor(dog_betas_unity[6, :]).float(), + 94: torch.tensor(dog_betas_unity[7, :]).float(), + 92: torch.tensor(dog_betas_unity[8, :]).float(), + 95: torch.tensor(dog_betas_unity[10, :]).float(), + 20: torch.tensor(dog_betas_unity[11, :]).float(), + 83: torch.tensor(dog_betas_unity[12, :]).float(), + 99: torch.tensor(dog_betas_unity[16, :]).float()} + + def prepare_anchor_pos_neg(self, batch_size, device): + l0 = np.arange(0, batch_size, 2) + l_anchor = [] + l_pos = [] + l_neg = [] + for ind in l0: + xx = set(np.arange(0, batch_size)) + xx.discard(ind) + xx.discard(ind+1) + for ind2 in xx: + if ind2 % 2 == 0: + l_anchor.append(ind) + l_pos.append(ind + 1) + else: + l_anchor.append(ind + 1) + l_pos.append(ind) + l_neg.append(ind2) + self.l_anchor = torch.Tensor(l_anchor).to(torch.int64).to(device) + self.l_pos = torch.Tensor(l_pos).to(torch.int64).to(device) + self.l_neg = torch.Tensor(l_neg).to(torch.int64).to(device) + return + + + def forward(self, output_reproj, target_dict, weight_dict=None): + # output_reproj: ['vertices_smal', 'keyp_3d', 'keyp_2d', 'silh_image'] + # target_dict: ['index', 'center', 'scale', 'pts', 'tpts', 'target_weight'] + batch_size = output_reproj['keyp_2d'].shape[0] + + # loss on reprojected keypoints + output_kp_resh = (output_reproj['keyp_2d']).reshape((-1, 2)) + target_kp_resh = (target_dict['tpts'][:, :, :2] / 64. * (256. 
- 1)).reshape((-1, 2)) + weights_resh = target_dict['tpts'][:, :, 2].reshape((-1)) + keyp_w_resh = self.keypoint_weights.repeat((batch_size, 1)).reshape((-1)) + loss_keyp = ((((output_kp_resh - target_kp_resh)[weights_resh>0]**2).sum(axis=1).sqrt()*weights_resh[weights_resh>0])*keyp_w_resh[weights_resh>0]).sum() / \ + max((weights_resh[weights_resh>0]*keyp_w_resh[weights_resh>0]).sum(), 1e-5) + + # loss on reprojected silhouette + assert output_reproj['silh'].shape == (target_dict['silh'][:, None, :, :]).shape + silh_loss_type = 'default' + if silh_loss_type == 'default': + with torch.no_grad(): + thr_silh = 20 + diff = torch.norm(output_kp_resh - target_kp_resh, dim=1) + diff_x = diff.reshape((batch_size, -1)) + weights_resh_x = weights_resh.reshape((batch_size, -1)) + unweighted_kp_mean_dist = (diff_x * weights_resh_x).sum(dim=1) / ((weights_resh_x).sum(dim=1)+1e-6) + loss_silh_bs = ((output_reproj['silh'] - target_dict['silh'][:, None, :, :]) ** 2).sum(axis=3).sum(axis=2).sum(axis=1) / (output_reproj['silh'].shape[2]*output_reproj['silh'].shape[3]) + loss_silh = loss_silh_bs[unweighted_kp_mean_dist 0: + for ind_dog in range(target_dict['breed_index'].shape[0]): + breed_index = np.asscalar(target_dict['breed_index'][ind_dog].detach().cpu().numpy()) + if breed_index in self.dog_betas_unity.keys(): + betas_target = self.dog_betas_unity[breed_index][:output_reproj['betas'].shape[1]].to(output_reproj['betas'].device) + betas_output = output_reproj['betas'][ind_dog, :] + betas_limbs_output = output_reproj['betas_limbs'][ind_dog, :] + loss_models3d += ((betas_limbs_output**2).sum() + ((betas_output-betas_target)**2).sum()) / (output_reproj['betas'].shape[1] + output_reproj['betas_limbs'].shape[1]) + else: + weight_dict['models3d'] = 0 + + # shape resularization loss on shapedirs + # -> in the current version shapedirs are kept fixed, so we don't need those losses + if weight_dict['shapedirs'] > 0: + raise NotImplementedError + else: + loss_shapedirs = torch.zeros((1)).mean().to(output_reproj['betas'].device) + + # prior on back joints (not used in cvpr 2022 paper) + # -> elementwise MSE loss on all 6 coefficients of 6d rotation representation + if 'pose_0' in weight_dict.keys(): + if weight_dict['pose_0'] > 0: + pred_pose_rot6d = output_reproj['pose_rot6d'] + w_rj_np = np.zeros((pred_pose_rot6d.shape[1])) + w_rj_np[[2, 3, 4, 5]] = 1.0 # back + w_rj = torch.tensor(w_rj_np).to(torch.float32).to(pred_pose_rot6d.device) + zero_rot = torch.tensor([1, 0, 0, 1, 0, 0]).to(pred_pose_rot6d.device).to(torch.float32)[None, None, :].repeat((batch_size, pred_pose_rot6d.shape[1], 1)) + loss_pose = self.criterion_regr(pred_pose_rot6d*w_rj[None, :, None], zero_rot*w_rj[None, :, None]) + else: + loss_pose = torch.zeros((1)).mean() + + # pose prior + # -> we did experiment with different pose priors, for example: + # * similart to SMALify (https://github.com/benjiebob/SMALify/blob/master/smal_fitter/smal_fitter.py, + # https://github.com/benjiebob/SMALify/blob/master/smal_fitter/priors/pose_prior_35.py) + # * vae + # * normalizing flow pose prior + # -> our cvpr 2022 paper uses the normalizing flow pose prior as implemented below + if 'poseprior' in weight_dict.keys(): + if weight_dict['poseprior'] > 0: + pred_pose_rot6d = output_reproj['pose_rot6d'] + pred_pose = rot6d_to_rotmat(pred_pose_rot6d.reshape((-1, 6))).reshape((batch_size, -1, 3, 3)) + if 'normalizing_flow_tiger' in weight_dict['poseprior_options']: + if output_reproj['normflow_z'] is not None: + loss_poseprior = 
self.normalizing_flow_pose_prior.calculate_loss_from_z(output_reproj['normflow_z'], type='square') + else: + loss_poseprior = self.normalizing_flow_pose_prior.calculate_loss(pred_pose_rot6d, type='square') + elif 'normalizing_flow_tiger_logprob' in weight_dict['poseprior_options']: + if output_reproj['normflow_z'] is not None: + loss_poseprior = self.normalizing_flow_pose_prior.calculate_loss_from_z(output_reproj['normflow_z'], type='neg_log_prob') + else: + loss_poseprior = self.normalizing_flow_pose_prior.calculate_loss(pred_pose_rot6d, type='neg_log_prob') + else: + raise NotImplementedError + else: + loss_poseprior = torch.zeros((1)).mean() + else: + weight_dict['poseprior'] = 0 + loss_poseprior = torch.zeros((1)).mean() + + # add a prior which penalizes side-movement angles for legs + if 'poselegssidemovement' in weight_dict.keys(): + use_pose_legs_side_loss = True + else: + use_pose_legs_side_loss = False + if use_pose_legs_side_loss: + leg_indices_right = np.asarray([7, 8, 9, 10, 17, 18, 19, 20]) # front, back + leg_indices_left = np.asarray([11, 12, 13, 14, 21, 22, 23, 24]) # front, back + vec = torch.zeros((3, 1)).to(device=pred_pose.device, dtype=pred_pose.dtype) + vec[2] = -1 + x0_rotmat = pred_pose + x0_rotmat_legs_left = x0_rotmat[:, leg_indices_left, :, :] + x0_rotmat_legs_right = x0_rotmat[:, leg_indices_right, :, :] + x0_legs_left = x0_rotmat_legs_left.reshape((-1, 3, 3))@vec + x0_legs_right = x0_rotmat_legs_right.reshape((-1, 3, 3))@vec + eps=0 # 1e-7 + # use the component of the vector which points to the side + loss_poselegssidemovement = (x0_legs_left[:, 1]**2).mean() + (x0_legs_right[:, 1]**2).mean() + else: + loss_poselegssidemovement = torch.zeros((1)).mean() + weight_dict['poselegssidemovement'] = 0 + + # dog breed classification loss + dog_breed_gt = target_dict['breed_index'] + dog_breed_pred = output_reproj['dog_breed'] + loss_class = self.criterion_class(dog_breed_pred, dog_breed_gt) + + # dog breed relationship loss + # -> we did experiment with many other options, but none was significantly better + if '4' in weight_dict['breed_options']: # we have pairs of dogs of the same breed + assert weight_dict['breed'] > 0 + z = output_reproj['z'] + # go through all pairs and compare them to each other sample + if self.l_anchor is None: + self.prepare_anchor_pos_neg(batch_size, z.device) + anchor = torch.index_select(z, 0, self.l_anchor) + positive = torch.index_select(z, 0, self.l_pos) + negative = torch.index_select(z, 0, self.l_neg) + loss_breed = self.criterion_triplet(anchor, positive, negative) + else: + loss_breed = torch.zeros((1)).mean() + + # regularizarion for focal length + loss_flength_near_mean = torch.mean(output_reproj['flength']**2) + loss_flength = loss_flength_near_mean + + # bodypart segmentation loss + if 'partseg' in weight_dict.keys(): + if weight_dict['partseg'] > 0: + raise NotImplementedError + else: + loss_partseg = torch.zeros((1)).mean() + else: + weight_dict['partseg'] = 0 + loss_partseg = torch.zeros((1)).mean() + + # weight and combine losses + loss_keyp_weighted = loss_keyp * weight_dict['keyp'] + loss_silh_weighted = loss_silh * weight_dict['silh'] + loss_shapedirs_weighted = loss_shapedirs * weight_dict['shapedirs'] + loss_pose_weighted = loss_pose * weight_dict['pose_0'] + loss_class_weighted = loss_class * weight_dict['class'] + loss_breed_weighted = loss_breed * weight_dict['breed'] + loss_flength_weighted = loss_flength * weight_dict['flength'] + loss_poseprior_weighted = loss_poseprior * weight_dict['poseprior'] + 
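+ # Note on the breed triplet loss above (worked example, assuming batch_size = 4 with same-breed
+ # pairs at indices (0, 1) and (2, 3)): prepare_anchor_pos_neg() builds the triplets
+ # (anchor=0, positive=1, negative=2), (anchor=1, positive=0, negative=3),
+ # (anchor=2, positive=3, negative=0), (anchor=3, positive=2, negative=1),
+ # i.e. each latent code z is pulled towards its same-breed partner and pushed away from
+ # samples of the other pairs by the TripletMarginLoss (margin 1).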
loss_partseg_weighted = loss_partseg * weight_dict['partseg'] + loss_models3d_weighted = loss_models3d * weight_dict['models3d'] + loss_poselegssidemovement_weighted = loss_poselegssidemovement * weight_dict['poselegssidemovement'] + + #################################################################################################### + loss = loss_keyp_weighted + loss_silh_weighted + loss_shape_weighted + loss_pose_weighted + loss_class_weighted + \ + loss_shapedirs_weighted + loss_breed_weighted + loss_flength_weighted + loss_poseprior_weighted + \ + loss_partseg_weighted + loss_models3d_weighted + loss_poselegssidemovement_weighted + #################################################################################################### + + loss_dict = {'loss': loss.item(), + 'loss_keyp_weighted': loss_keyp_weighted.item(), \ + 'loss_silh_weighted': loss_silh_weighted.item(), \ + 'loss_shape_weighted': loss_shape_weighted.item(), \ + 'loss_shapedirs_weighted': loss_shapedirs_weighted.item(), \ + 'loss_pose0_weighted': loss_pose_weighted.item(), \ + 'loss_class_weighted': loss_class_weighted.item(), \ + 'loss_breed_weighted': loss_breed_weighted.item(), \ + 'loss_flength_weighted': loss_flength_weighted.item(), \ + 'loss_poseprior_weighted': loss_poseprior_weighted.item(), \ + 'loss_partseg_weighted': loss_partseg_weighted.item(), \ + 'loss_models3d_weighted': loss_models3d_weighted.item(), \ + 'loss_poselegssidemovement_weighted': loss_poselegssidemovement_weighted.item()} + + return loss, loss_dict + + + + diff --git a/src/combined_model/model_shape_v7.py b/src/combined_model/model_shape_v7.py new file mode 100644 index 0000000000000000000000000000000000000000..807488d335e9a4f0870cff88a0540cc90b998f3f --- /dev/null +++ b/src/combined_model/model_shape_v7.py @@ -0,0 +1,500 @@ + +import pickle as pkl +import numpy as np +import torchvision.models as models +from torchvision import transforms +import torch +from torch import nn +from torch.nn.parameter import Parameter +from kornia.geometry.subpix import dsnt # kornia 0.4.0 + +import os +import sys +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..')) +from stacked_hourglass.utils.evaluation import get_preds_soft +from stacked_hourglass import hg1, hg2, hg8 +from lifting_to_3d.linear_model import LinearModelComplete, LinearModel +from lifting_to_3d.inn_model_for_shape import INNForShape +from lifting_to_3d.utils.geometry_utils import rot6d_to_rotmat, rotmat_to_rot6d +from smal_pytorch.smal_model.smal_torch_new import SMAL +from smal_pytorch.renderer.differentiable_renderer import SilhRenderer +from bps_2d.bps_for_segmentation import SegBPS +from configs.SMAL_configs import UNITY_SMAL_SHAPE_PRIOR_DOGS as SHAPE_PRIOR +from configs.SMAL_configs import MEAN_DOG_BONE_LENGTHS_NO_RED, VERTEX_IDS_TAIL + + + +class SmallLinear(nn.Module): + def __init__(self, input_size=64, output_size=30, linear_size=128): + super(SmallLinear, self).__init__() + self.relu = nn.ReLU(inplace=True) + self.w1 = nn.Linear(input_size, linear_size) + self.w2 = nn.Linear(linear_size, linear_size) + self.w3 = nn.Linear(linear_size, output_size) + def forward(self, x): + # pre-processing + y = self.w1(x) + y = self.relu(y) + y = self.w2(y) + y = self.relu(y) + y = self.w3(y) + return y + + +class MyConv1d(nn.Module): + def __init__(self, input_size=37, output_size=30, start=True): + super(MyConv1d, self).__init__() + self.input_size = input_size + self.output_size = output_size + self.start = start + self.weight = Parameter(torch.ones((self.output_size))) + 
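+ # despite its name, this module is not a convolution: it selects either the first or the last
+ # `output_size` entries of the input (depending on `start`) and applies an elementwise scale
+ # (self.weight) and offset (self.bias), see forward() below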
self.bias = Parameter(torch.zeros((self.output_size))) + def forward(self, x): + # pre-processing + if self.start: + y = x[:, :self.output_size] + else: + y = x[:, -self.output_size:] + y = y * self.weight[None, :] + self.bias[None, :] + return y + + +class ModelShapeAndBreed(nn.Module): + def __init__(self, n_betas=10, n_betas_limbs=13, n_breeds=121, n_z=512, structure_z_to_betas='default'): + super(ModelShapeAndBreed, self).__init__() + self.n_betas = n_betas + self.n_betas_limbs = n_betas_limbs # n_betas_logscale + self.n_breeds = n_breeds + self.structure_z_to_betas = structure_z_to_betas + if self.structure_z_to_betas == '1dconv': + if not (n_z == self.n_betas+self.n_betas_limbs): + raise ValueError + # shape branch + self.resnet = models.resnet34(pretrained=False) + # replace the first layer + n_in = 3 + 1 + self.resnet.conv1 = nn.Conv2d(n_in, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False) + # replace the last layer + self.resnet.fc = nn.Linear(512, n_z) + # softmax + self.soft_max = torch.nn.Softmax(dim=1) + # fc network (and other versions) to connect z with betas + p_dropout = 0.2 + if self.structure_z_to_betas == 'default': + self.linear_betas = LinearModel(linear_size=1024, + num_stage=1, + p_dropout=p_dropout, + input_size=n_z, + output_size=self.n_betas) + self.linear_betas_limbs = LinearModel(linear_size=1024, + num_stage=1, + p_dropout=p_dropout, + input_size=n_z, + output_size=self.n_betas_limbs) + elif self.structure_z_to_betas == 'lin': + self.linear_betas = nn.Linear(n_z, self.n_betas) + self.linear_betas_limbs = nn.Linear(n_z, self.n_betas_limbs) + elif self.structure_z_to_betas == 'fc_0': + self.linear_betas = SmallLinear(linear_size=128, # 1024, + input_size=n_z, + output_size=self.n_betas) + self.linear_betas_limbs = SmallLinear(linear_size=128, # 1024, + input_size=n_z, + output_size=self.n_betas_limbs) + elif structure_z_to_betas == 'fc_1': + self.linear_betas = LinearModel(linear_size=64, # 1024, + num_stage=1, + p_dropout=0, + input_size=n_z, + output_size=self.n_betas) + self.linear_betas_limbs = LinearModel(linear_size=64, # 1024, + num_stage=1, + p_dropout=0, + input_size=n_z, + output_size=self.n_betas_limbs) + elif self.structure_z_to_betas == '1dconv': + self.linear_betas = MyConv1d(n_z, self.n_betas, start=True) + self.linear_betas_limbs = MyConv1d(n_z, self.n_betas_limbs, start=False) + elif self.structure_z_to_betas == 'inn': + self.linear_betas_and_betas_limbs = INNForShape(self.n_betas, self.n_betas_limbs, betas_scale=1.0, betas_limbs_scale=1.0) + else: + raise ValueError + # network to connect latent shape vector z with dog breed classification + self.linear_breeds = LinearModel(linear_size=1024, # 1024, + num_stage=1, + p_dropout=p_dropout, + input_size=n_z, + output_size=self.n_breeds) + # shape multiplicator + self.shape_multiplicator_np = np.ones(self.n_betas) + with open(SHAPE_PRIOR, 'rb') as file: + u = pkl._Unpickler(file) + u.encoding = 'latin1' + res = u.load() + # shape predictions are centered around the mean dog of our dog model + self.betas_mean_np = res['dog_cluster_mean'] + + def forward(self, img, seg_raw=None, seg_prep=None): + # img is the network input image + # seg_raw is before softmax and subtracting 0.5 + # seg_prep would be the prepared_segmentation + if seg_prep is None: + seg_prep = self.soft_max(seg_raw)[:, 1:2, :, :] - 0.5 + input_img_and_seg = torch.cat((img, seg_prep), axis=1) + res_output = self.resnet(input_img_and_seg) + dog_breed_output = self.linear_breeds(res_output) + if 
self.structure_z_to_betas == 'inn': + shape_output_orig, shape_limbs_output_orig = self.linear_betas_and_betas_limbs(res_output) + else: + shape_output_orig = self.linear_betas(res_output) * 0.1 + betas_mean = torch.tensor(self.betas_mean_np).float().to(img.device) + shape_output = shape_output_orig + betas_mean[None, 0:self.n_betas] + shape_limbs_output_orig = self.linear_betas_limbs(res_output) + shape_limbs_output = shape_limbs_output_orig * 0.1 + output_dict = {'z': res_output, + 'breeds': dog_breed_output, + 'betas': shape_output_orig, + 'betas_limbs': shape_limbs_output_orig} + return output_dict + + + +class LearnableShapedirs(nn.Module): + def __init__(self, sym_ids_dict, shapedirs_init, n_betas, n_betas_fixed=10): + super(LearnableShapedirs, self).__init__() + # shapedirs_init = self.smal.shapedirs.detach() + self.n_betas = n_betas + self.n_betas_fixed = n_betas_fixed + self.sym_ids_dict = sym_ids_dict + sym_left_ids = self.sym_ids_dict['left'] + sym_right_ids = self.sym_ids_dict['right'] + sym_center_ids = self.sym_ids_dict['center'] + self.n_center = sym_center_ids.shape[0] + self.n_left = sym_left_ids.shape[0] + self.n_sd = self.n_betas - self.n_betas_fixed # number of learnable shapedirs + # get indices to go from half_shapedirs to shapedirs + inds_back = np.zeros((3889)) + for ind in range(0, sym_center_ids.shape[0]): + ind_in_forward = sym_center_ids[ind] + inds_back[ind_in_forward] = ind + for ind in range(0, sym_left_ids.shape[0]): + ind_in_forward = sym_left_ids[ind] + inds_back[ind_in_forward] = sym_center_ids.shape[0] + ind + for ind in range(0, sym_right_ids.shape[0]): + ind_in_forward = sym_right_ids[ind] + inds_back[ind_in_forward] = sym_center_ids.shape[0] + sym_left_ids.shape[0] + ind + self.register_buffer('inds_back_torch', torch.Tensor(inds_back).long()) + # self.smal.shapedirs: (51, 11667) + # shapedirs: (3889, 3, n_sd) + # shapedirs_half: (2012, 3, n_sd) + sd = shapedirs_init[:self.n_betas, :].permute((1, 0)).reshape((-1, 3, self.n_betas)) + self.register_buffer('sd', sd) + sd_center = sd[sym_center_ids, :, self.n_betas_fixed:] + sd_left = sd[sym_left_ids, :, self.n_betas_fixed:] + self.register_parameter('learnable_half_shapedirs_c0', torch.nn.Parameter(sd_center[:, 0, :].detach())) + self.register_parameter('learnable_half_shapedirs_c2', torch.nn.Parameter(sd_center[:, 2, :].detach())) + self.register_parameter('learnable_half_shapedirs_l0', torch.nn.Parameter(sd_left[:, 0, :].detach())) + self.register_parameter('learnable_half_shapedirs_l1', torch.nn.Parameter(sd_left[:, 1, :].detach())) + self.register_parameter('learnable_half_shapedirs_l2', torch.nn.Parameter(sd_left[:, 2, :].detach())) + def forward(self): + device = self.learnable_half_shapedirs_c0.device + half_shapedirs_center = torch.stack((self.learnable_half_shapedirs_c0, \ + torch.zeros((self.n_center, self.n_sd)).to(device), \ + self.learnable_half_shapedirs_c2), axis=1) + half_shapedirs_left = torch.stack((self.learnable_half_shapedirs_l0, \ + self.learnable_half_shapedirs_l1, \ + self.learnable_half_shapedirs_l2), axis=1) + half_shapedirs_right = torch.stack((self.learnable_half_shapedirs_l0, \ + - self.learnable_half_shapedirs_l1, \ + self.learnable_half_shapedirs_l2), axis=1) + half_shapedirs_tot = torch.cat((half_shapedirs_center, half_shapedirs_left, half_shapedirs_right)) + shapedirs = torch.index_select(half_shapedirs_tot, dim=0, index=self.inds_back_torch) + shapedirs_complete = torch.cat((self.sd[:, :, :self.n_betas_fixed], shapedirs), axis=2) # (3889, 3, n_sd) + 
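+ # the 'prepared' variant below flattens the displacements per shape coefficient, i.e.
+ # (n_betas, 3*n_vertices) = (30, 11667) given the hard-coded 10 fixed + 20 learnable
+ # coefficients, which matches the (n_betas, 11667) layout of self.smal.shapedirs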
shapedirs_complete_prepared = torch.cat((self.sd[:, :, :10], shapedirs), axis=2).reshape((-1, 30)).permute((1, 0)) # (n_sd, 11667) + return shapedirs_complete, shapedirs_complete_prepared + + + + + +class ModelImageToBreed(nn.Module): + def __init__(self, arch='hg8', n_joints=35, n_classes=20, n_partseg=15, n_keyp=20, n_bones=24, n_betas=10, n_betas_limbs=7, n_breeds=121, image_size=256, n_z=512, thr_keyp_sc=None, add_partseg=True): + super(ModelImageToBreed, self).__init__() + self.n_classes = n_classes + self.n_partseg = n_partseg + self.n_betas = n_betas + self.n_betas_limbs = n_betas_limbs + self.n_keyp = n_keyp + self.n_bones = n_bones + self.n_breeds = n_breeds + self.image_size = image_size + self.upsample_seg = True + self.threshold_scores = thr_keyp_sc + self.n_z = n_z + self.add_partseg = add_partseg + # ------------------------------ STACKED HOUR GLASS ------------------------------ + if arch == 'hg8': + self.stacked_hourglass = hg8(pretrained=False, num_classes=self.n_classes, num_partseg=self.n_partseg, upsample_seg=self.upsample_seg, add_partseg=self.add_partseg) + else: + raise Exception('unrecognised model architecture: ' + arch) + # ------------------------------ SHAPE AND BREED MODEL ------------------------------ + self.breed_model = ModelShapeAndBreed(n_betas=self.n_betas, n_betas_limbs=self.n_betas_limbs, n_breeds=self.n_breeds, n_z=self.n_z) + def forward(self, input_img, norm_dict=None, bone_lengths_prepared=None, betas=None): + batch_size = input_img.shape[0] + device = input_img.device + # ------------------------------ STACKED HOUR GLASS ------------------------------ + hourglass_out_dict = self.stacked_hourglass(input_img) + last_seg = hourglass_out_dict['seg_final'] + last_heatmap = hourglass_out_dict['out_list_kp'][-1] + # - prepare keypoints (from heatmap) + # normalize predictions -> from logits to probability distribution + # last_heatmap_norm = dsnt.spatial_softmax2d(last_heatmap, temperature=torch.tensor(1)) + # keypoints = dsnt.spatial_expectation2d(last_heatmap_norm, normalized_coordinates=False) + 1 # (bs, 20, 2) + # keypoints_norm = dsnt.spatial_expectation2d(last_heatmap_norm, normalized_coordinates=True) # (bs, 20, 2) + keypoints_norm, scores = get_preds_soft(last_heatmap, return_maxval=True, norm_coords=True) + if self.threshold_scores is not None: + scores[scores>self.threshold_scores] = 1.0 + scores[scores<=self.threshold_scores] = 0.0 + # ------------------------------ SHAPE AND BREED MODEL ------------------------------ + # breed_model takes as input the image as well as the predicted segmentation map + # -> we need to split up ModelImageTo3d, such that we can use the silhouette + resnet_output = self.breed_model(img=input_img, seg_raw=last_seg) + pred_breed = resnet_output['breeds'] # (bs, n_breeds) + pred_betas = resnet_output['betas'] + pred_betas_limbs = resnet_output['betas_limbs'] + small_output = {'keypoints_norm': keypoints_norm, + 'keypoints_scores': scores} + small_output_reproj = {'betas': pred_betas, + 'betas_limbs': pred_betas_limbs, + 'dog_breed': pred_breed} + return small_output, None, small_output_reproj + +class ModelImageTo3d_withshape_withproj(nn.Module): + def __init__(self, arch='hg8', num_stage_comb=2, num_stage_heads=1, num_stage_heads_pose=1, trans_sep=False, n_joints=35, n_classes=20, n_partseg=15, n_keyp=20, n_bones=24, n_betas=10, n_betas_limbs=6, n_breeds=121, image_size=256, n_z=512, n_segbps=64*2, thr_keyp_sc=None, add_z_to_3d_input=True, add_segbps_to_3d_input=False, add_partseg=True, silh_no_tail=True, 
fix_flength=False, render_partseg=False, structure_z_to_betas='default', structure_pose_net='default', nf_version=None): + super(ModelImageTo3d_withshape_withproj, self).__init__() + self.n_classes = n_classes + self.n_partseg = n_partseg + self.n_betas = n_betas + self.n_betas_limbs = n_betas_limbs + self.n_keyp = n_keyp + self.n_bones = n_bones + self.n_breeds = n_breeds + self.image_size = image_size + self.threshold_scores = thr_keyp_sc + self.upsample_seg = True + self.silh_no_tail = silh_no_tail + self.add_z_to_3d_input = add_z_to_3d_input + self.add_segbps_to_3d_input = add_segbps_to_3d_input + self.add_partseg = add_partseg + assert (not self.add_segbps_to_3d_input) or (not self.add_z_to_3d_input) + self.n_z = n_z + if add_segbps_to_3d_input: + self.n_segbps = n_segbps # 64 + self.segbps_model = SegBPS() + else: + self.n_segbps = 0 + self.fix_flength = fix_flength + self.render_partseg = render_partseg + self.structure_z_to_betas = structure_z_to_betas + self.structure_pose_net = structure_pose_net + assert self.structure_pose_net in ['default', 'vae', 'normflow'] + self.nf_version = nf_version + self.register_buffer('betas_zeros', torch.zeros((1, self.n_betas))) + self.register_buffer('mean_dog_bone_lengths', torch.tensor(MEAN_DOG_BONE_LENGTHS_NO_RED, dtype=torch.float32)) + p_dropout = 0.2 # 0.5 + # ------------------------------ SMAL MODEL ------------------------------ + self.smal = SMAL(template_name='neutral') + # New for rendering without tail + f_np = self.smal.faces.detach().cpu().numpy() + self.f_no_tail_np = f_np[np.isin(f_np[:,:], VERTEX_IDS_TAIL).sum(axis=1)==0, :] + # in theory we could optimize for improved shapedirs, but we do not do that + # -> would need to implement regularizations + # -> there are better ways than changing the shapedirs + self.model_learnable_shapedirs = LearnableShapedirs(self.smal.sym_ids_dict, self.smal.shapedirs.detach(), self.n_betas, 10) + # ------------------------------ STACKED HOUR GLASS ------------------------------ + if arch == 'hg8': + self.stacked_hourglass = hg8(pretrained=False, num_classes=self.n_classes, num_partseg=self.n_partseg, upsample_seg=self.upsample_seg, add_partseg=self.add_partseg) + else: + raise Exception('unrecognised model architecture: ' + arch) + # ------------------------------ SHAPE AND BREED MODEL ------------------------------ + self.breed_model = ModelShapeAndBreed(n_betas=self.n_betas, n_betas_limbs=self.n_betas_limbs, n_breeds=self.n_breeds, n_z=self.n_z, structure_z_to_betas=self.structure_z_to_betas) + # ------------------------------ LINEAR 3D MODEL ------------------------------ + # 3d model -> from image to 3d parameters {2d keypoints from heatmap, pose, trans, flength} + self.soft_max = torch.nn.Softmax(dim=1) + input_size = self.n_keyp*3 + self.n_bones + self.model_3d = LinearModelComplete(linear_size=1024, + num_stage_comb=num_stage_comb, + num_stage_heads=num_stage_heads, + num_stage_heads_pose=num_stage_heads_pose, + trans_sep=trans_sep, + p_dropout=p_dropout, # 0.5, + input_size=input_size, + intermediate_size=1024, + output_info=None, + n_joints=n_joints, + n_z=self.n_z, + add_z_to_3d_input=self.add_z_to_3d_input, + n_segbps=self.n_segbps, + add_segbps_to_3d_input=self.add_segbps_to_3d_input, + structure_pose_net=self.structure_pose_net, + nf_version = self.nf_version) + # ------------------------------ RENDERING ------------------------------ + self.silh_renderer = SilhRenderer(image_size) + + def forward(self, input_img, norm_dict=None, bone_lengths_prepared=None, betas=None): + batch_size 
= input_img.shape[0] + device = input_img.device + # ------------------------------ STACKED HOUR GLASS ------------------------------ + hourglass_out_dict = self.stacked_hourglass(input_img) + last_seg = hourglass_out_dict['seg_final'] + last_heatmap = hourglass_out_dict['out_list_kp'][-1] + # - prepare keypoints (from heatmap) + # normalize predictions -> from logits to probability distribution + # last_heatmap_norm = dsnt.spatial_softmax2d(last_heatmap, temperature=torch.tensor(1)) + # keypoints = dsnt.spatial_expectation2d(last_heatmap_norm, normalized_coordinates=False) + 1 # (bs, 20, 2) + # keypoints_norm = dsnt.spatial_expectation2d(last_heatmap_norm, normalized_coordinates=True) # (bs, 20, 2) + keypoints_norm, scores = get_preds_soft(last_heatmap, return_maxval=True, norm_coords=True) + if self.threshold_scores is not None: + scores[scores>self.threshold_scores] = 1.0 + scores[scores<=self.threshold_scores] = 0.0 + # ------------------------------ LEARNABLE SHAPE MODEL ------------------------------ + # in our cvpr 2022 paper we do not change the shapedirs + # learnable_sd_complete has shape (3889, 3, n_sd) + # learnable_sd_complete_prepared has shape (n_sd, 11667) + learnable_sd_complete, learnable_sd_complete_prepared = self.model_learnable_shapedirs() + shapedirs_sel = learnable_sd_complete_prepared # None + # ------------------------------ SHAPE AND BREED MODEL ------------------------------ + # breed_model takes as input the image as well as the predicted segmentation map + # -> we need to split up ModelImageTo3d, such that we can use the silhouette + resnet_output = self.breed_model(img=input_img, seg_raw=last_seg) + pred_breed = resnet_output['breeds'] # (bs, n_breeds) + pred_z = resnet_output['z'] + # - prepare shape + pred_betas = resnet_output['betas'] + pred_betas_limbs = resnet_output['betas_limbs'] + # - calculate bone lengths + with torch.no_grad(): + use_mean_bone_lengths = False + if use_mean_bone_lengths: + bone_lengths_prepared = torch.cat(batch_size*[self.mean_dog_bone_lengths.reshape((1, -1))]) + else: + assert (bone_lengths_prepared is None) + bone_lengths_prepared = self.smal.caclulate_bone_lengths(pred_betas, pred_betas_limbs, shapedirs_sel=shapedirs_sel, short=True) + # ------------------------------ LINEAR 3D MODEL ------------------------------ + # 3d model -> from image to 3d parameters {2d keypoints from heatmap, pose, trans, flength} + # prepare input for 2d-to-3d network + keypoints_prepared = torch.cat((keypoints_norm, scores), axis=2) + if bone_lengths_prepared is None: + bone_lengths_prepared = torch.cat(batch_size*[self.mean_dog_bone_lengths.reshape((1, -1))]) + # should we add silhouette to 3d input? should we add z? + if self.add_segbps_to_3d_input: + seg_raw = last_seg + seg_prep_bps = self.soft_max(seg_raw)[:, 1, :, :] # class 1 is the dog + with torch.no_grad(): + seg_prep_np = seg_prep_bps.detach().cpu().numpy() + bps_output_np = self.segbps_model.calculate_bps_points_batch(seg_prep_np) # (bs, 64, 2) + bps_output = torch.tensor(bps_output_np, dtype=torch.float32).to(device).reshape((batch_size, -1)) + bps_output_prep = bps_output * 2. 
- 1 + input_vec_keyp_bones = torch.cat((keypoints_prepared.reshape((batch_size, -1)), bone_lengths_prepared), axis=1) + input_vec = torch.cat((input_vec_keyp_bones, bps_output_prep), dim=1) + elif self.add_z_to_3d_input: + # we do not use this in our cvpr 2022 version + input_vec_keyp_bones = torch.cat((keypoints_prepared.reshape((batch_size, -1)), bone_lengths_prepared), axis=1) + input_vec_additional = pred_z + input_vec = torch.cat((input_vec_keyp_bones, input_vec_additional), dim=1) + else: + input_vec = torch.cat((keypoints_prepared.reshape((batch_size, -1)), bone_lengths_prepared), axis=1) + # predict 3d parameters (those are normalized, we need to correct mean and std in a next step) + output = self.model_3d(input_vec) + # add predicted keypoints to the output dict + output['keypoints_norm'] = keypoints_norm + output['keypoints_scores'] = scores + # - denormalize 3d parameters -> so far predictions were normalized, now we denormalize them again + pred_trans = output['trans'] * norm_dict['trans_std'][None, :] + norm_dict['trans_mean'][None, :] # (bs, 3) + if self.structure_pose_net == 'default': + pred_pose_rot6d = output['pose'] + norm_dict['pose_rot6d_mean'][None, :] + elif self.structure_pose_net == 'normflow': + pose_rot6d_mean_zeros = torch.zeros_like(norm_dict['pose_rot6d_mean'][None, :]) + pose_rot6d_mean_zeros[:, 0, :] = norm_dict['pose_rot6d_mean'][None, 0, :] + pred_pose_rot6d = output['pose'] + pose_rot6d_mean_zeros + else: + pose_rot6d_mean_zeros = torch.zeros_like(norm_dict['pose_rot6d_mean'][None, :]) + pose_rot6d_mean_zeros[:, 0, :] = norm_dict['pose_rot6d_mean'][None, 0, :] + pred_pose_rot6d = output['pose'] + pose_rot6d_mean_zeros + pred_pose_reshx33 = rot6d_to_rotmat(pred_pose_rot6d.reshape((-1, 6))) + pred_pose = pred_pose_reshx33.reshape((batch_size, -1, 3, 3)) + pred_pose_rot6d = rotmat_to_rot6d(pred_pose_reshx33).reshape((batch_size, -1, 6)) + + if self.fix_flength: + output['flength'] = torch.zeros_like(output['flength']) + pred_flength = torch.ones_like(output['flength'])*2100 # norm_dict['flength_mean'][None, :] + else: + pred_flength_orig = output['flength'] * norm_dict['flength_std'][None, :] + norm_dict['flength_mean'][None, :] # (bs, 1) + pred_flength = pred_flength_orig.clone() # torch.abs(pred_flength_orig) + pred_flength[pred_flength_orig<=0] = norm_dict['flength_mean'][None, :] + + # ------------------------------ RENDERING ------------------------------ + # get 3d model (SMAL) + V, keyp_green_3d, _ = self.smal(beta=pred_betas, betas_limbs=pred_betas_limbs, pose=pred_pose, trans=pred_trans, get_skin=True, keyp_conf='green', shapedirs_sel=shapedirs_sel) + keyp_3d = keyp_green_3d[:, :self.n_keyp, :] # (bs, 20, 3) + # render silhouette + faces_prep = self.smal.faces.unsqueeze(0).expand((batch_size, -1, -1)) + if not self.silh_no_tail: + pred_silh_images, pred_keyp = self.silh_renderer(vertices=V, + points=keyp_3d, faces=faces_prep, focal_lengths=pred_flength) + else: + faces_no_tail_prep = torch.tensor(self.f_no_tail_np).to(device).expand((batch_size, -1, -1)) + pred_silh_images, pred_keyp = self.silh_renderer(vertices=V, + points=keyp_3d, faces=faces_no_tail_prep, focal_lengths=pred_flength) + # get torch 'Meshes' + torch_meshes = self.silh_renderer.get_torch_meshes(vertices=V, faces=faces_prep) + + # render body parts (not part of cvpr 2022 version) + if self.render_partseg: + raise NotImplementedError + else: + partseg_images = None + partseg_images_hg = None + + # ------------------------------ PREPARE OUTPUT ------------------------------ + # create 
output dictionarys + # output: contains all output from model_image_to_3d + # output_unnorm: same as output, but normalizations are undone + # output_reproj: smal output and reprojected keypoints as well as silhouette + keypoints_heatmap_256 = (output['keypoints_norm'] / 2. + 0.5) * (self.image_size - 1) + output_unnorm = {'pose_rotmat': pred_pose, + 'flength': pred_flength, + 'trans': pred_trans, + 'keypoints':keypoints_heatmap_256} + output_reproj = {'vertices_smal': V, + 'torch_meshes': torch_meshes, + 'keyp_3d': keyp_3d, + 'keyp_2d': pred_keyp, + 'silh': pred_silh_images, + 'betas': pred_betas, + 'betas_limbs': pred_betas_limbs, + 'pose_rot6d': pred_pose_rot6d, # used for pose prior... + 'dog_breed': pred_breed, + 'shapedirs': shapedirs_sel, + 'z': pred_z, + 'flength_unnorm': pred_flength, + 'flength': output['flength'], + 'partseg_images_rend': partseg_images, + 'partseg_images_hg_nograd': partseg_images_hg, + 'normflow_z': output['normflow_z']} + + return output, output_unnorm, output_reproj + + def render_vis_nograd(self, vertices, focal_lengths, color=0): + # this function is for visualization only + # vertices: (bs, n_verts, 3) + # focal_lengths: (bs, 1) + # color: integer, either 0 or 1 + # returns a torch tensor of shape (bs, image_size, image_size, 3) + with torch.no_grad(): + batch_size = vertices.shape[0] + faces_prep = self.smal.faces.unsqueeze(0).expand((batch_size, -1, -1)) + visualizations = self.silh_renderer.get_visualization_nograd(vertices, + faces_prep, focal_lengths, color=color) + return visualizations + diff --git a/src/combined_model/train_main_image_to_3d_withbreedrel.py b/src/combined_model/train_main_image_to_3d_withbreedrel.py new file mode 100644 index 0000000000000000000000000000000000000000..8c06655d08cbc60e1239147aa01272cd901fa04b --- /dev/null +++ b/src/combined_model/train_main_image_to_3d_withbreedrel.py @@ -0,0 +1,470 @@ + +import torch +import torch.nn as nn +import torch.backends.cudnn +import torch.nn.parallel +from tqdm import tqdm +import os +import pathlib +from matplotlib import pyplot as plt +import cv2 +import numpy as np +import torch +import trimesh + +import sys +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..')) +from stacked_hourglass.utils.evaluation import accuracy, AverageMeter, final_preds, get_preds, get_preds_soft +from stacked_hourglass.utils.visualization import save_input_image_with_keypoints, save_input_image +from metrics.metrics import Metrics +from configs.SMAL_configs import EVAL_KEYPOINTS, KEYPOINT_GROUPS + + +# --------------------------------------------------------------------------------------------------------------------------- +def do_training_epoch(train_loader, model, loss_module, device, data_info, optimiser, quiet=False, acc_joints=None, weight_dict=None): + losses = AverageMeter() + losses_keyp = AverageMeter() + losses_silh = AverageMeter() + losses_shape = AverageMeter() + losses_pose = AverageMeter() + losses_class = AverageMeter() + losses_breed = AverageMeter() + losses_partseg = AverageMeter() + accuracies = AverageMeter() + # Put the model in training mode. 
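+ # (training mode enables e.g. the dropout layers of the 2d-to-3d linear network and lets the
+ # batch-norm statistics of the hourglass and ResNet branches update)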
+ model.train() + # prepare progress bar + iterable = enumerate(train_loader) + progress = None + if not quiet: + progress = tqdm(iterable, desc='Train', total=len(train_loader), ascii=True, leave=False) + iterable = progress + # information for normalization + norm_dict = { + 'pose_rot6d_mean': torch.from_numpy(data_info.pose_rot6d_mean).float().to(device), + 'trans_mean': torch.from_numpy(data_info.trans_mean).float().to(device), + 'trans_std': torch.from_numpy(data_info.trans_std).float().to(device), + 'flength_mean': torch.from_numpy(data_info.flength_mean).float().to(device), + 'flength_std': torch.from_numpy(data_info.flength_std).float().to(device)} + # prepare variables, put them on the right device + for i, (input, target_dict) in iterable: + batch_size = input.shape[0] + for key in target_dict.keys(): + if key == 'breed_index': + target_dict[key] = target_dict[key].long().to(device) + elif key in ['index', 'pts', 'tpts', 'target_weight', 'silh', 'silh_distmat_tofg', 'silh_distmat_tobg', 'sim_breed_index', 'img_border_mask']: + target_dict[key] = target_dict[key].float().to(device) + elif key == 'has_seg': + target_dict[key] = target_dict[key].to(device) + else: + pass + input = input.float().to(device) + + # ----------------------- do training step ----------------------- + assert model.training, 'model must be in training mode.' + with torch.enable_grad(): + # ----- forward pass ----- + output, output_unnorm, output_reproj = model(input, norm_dict=norm_dict) + # ----- loss ----- + loss, loss_dict = loss_module(output_reproj=output_reproj, + target_dict=target_dict, + weight_dict=weight_dict) + # ----- backward pass and parameter update ----- + optimiser.zero_grad() + loss.backward() + optimiser.step() + # ---------------------------------------------------------------- + + # prepare losses for progress bar + bs_fake = 1 # batch_size + losses.update(loss_dict['loss'], bs_fake) + losses_keyp.update(loss_dict['loss_keyp_weighted'], bs_fake) + losses_silh.update(loss_dict['loss_silh_weighted'], bs_fake) + losses_shape.update(loss_dict['loss_shape_weighted'], bs_fake) + losses_pose.update(loss_dict['loss_poseprior_weighted'], bs_fake) + losses_class.update(loss_dict['loss_class_weighted'], bs_fake) + losses_breed.update(loss_dict['loss_breed_weighted'], bs_fake) + losses_partseg.update(loss_dict['loss_partseg_weighted'], bs_fake) + acc = - loss_dict['loss_keyp_weighted'] # this will be used to keep track of the 'best model' + accuracies.update(acc, bs_fake) + # Show losses as part of the progress bar. 
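+ # (the displayed values are running averages over the epoch; 'acc' above is the negative
+ # weighted keypoint loss and is what is used to select the 'best' checkpoint)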
+ if progress is not None: + my_string = 'Loss: {loss:0.4f}, loss_keyp: {loss_keyp:0.4f}, loss_silh: {loss_silh:0.4f}, loss_partseg: {loss_partseg:0.4f}, loss_shape: {loss_shape:0.4f}, loss_pose: {loss_pose:0.4f}, loss_class: {loss_class:0.4f}, loss_breed: {loss_breed:0.4f}'.format( + loss=losses.avg, + loss_keyp=losses_keyp.avg, + loss_silh=losses_silh.avg, + loss_shape=losses_shape.avg, + loss_pose=losses_pose.avg, + loss_class=losses_class.avg, + loss_breed=losses_breed.avg, + loss_partseg=losses_partseg.avg + ) + progress.set_postfix_str(my_string) + + return my_string, accuracies.avg + + +# --------------------------------------------------------------------------------------------------------------------------- +def do_validation_epoch(val_loader, model, loss_module, device, data_info, flip=False, quiet=False, acc_joints=None, save_imgs_path=None, weight_dict=None, metrics=None, val_opt='default', test_name_list=None, render_all=False, pck_thresh=0.15, len_dataset=None): + losses = AverageMeter() + losses_keyp = AverageMeter() + losses_silh = AverageMeter() + losses_shape = AverageMeter() + losses_pose = AverageMeter() + losses_class = AverageMeter() + losses_breed = AverageMeter() + losses_partseg = AverageMeter() + accuracies = AverageMeter() + if save_imgs_path is not None: + pathlib.Path(save_imgs_path).mkdir(parents=True, exist_ok=True) + # Put the model in evaluation mode. + model.eval() + # prepare progress bar + iterable = enumerate(val_loader) + progress = None + if not quiet: + progress = tqdm(iterable, desc='Valid', total=len(val_loader), ascii=True, leave=False) + iterable = progress + # summarize information for normalization + norm_dict = { + 'pose_rot6d_mean': torch.from_numpy(data_info.pose_rot6d_mean).float().to(device), + 'trans_mean': torch.from_numpy(data_info.trans_mean).float().to(device), + 'trans_std': torch.from_numpy(data_info.trans_std).float().to(device), + 'flength_mean': torch.from_numpy(data_info.flength_mean).float().to(device), + 'flength_std': torch.from_numpy(data_info.flength_std).float().to(device)} + batch_size = val_loader.batch_size + # prepare variables, put them on the right device + my_step = 0 + for i, (input, target_dict) in iterable: + curr_batch_size = input.shape[0] + for key in target_dict.keys(): + if key == 'breed_index': + target_dict[key] = target_dict[key].long().to(device) + elif key in ['index', 'pts', 'tpts', 'target_weight', 'silh', 'silh_distmat_tofg', 'silh_distmat_tobg', 'sim_breed_index', 'img_border_mask']: + target_dict[key] = target_dict[key].float().to(device) + elif key == 'has_seg': + target_dict[key] = target_dict[key].to(device) + else: + pass + input = input.float().to(device) + + # ----------------------- do validation step ----------------------- + with torch.no_grad(): + # ----- forward pass ----- + # output: (['pose', 'flength', 'trans', 'keypoints_norm', 'keypoints_scores']) + # output_unnorm: (['pose_rotmat', 'flength', 'trans', 'keypoints']) + # output_reproj: (['vertices_smal', 'torch_meshes', 'keyp_3d', 'keyp_2d', 'silh', 'betas', 'pose_rot6d', 'dog_breed', 'shapedirs', 'z', 'flength_unnorm', 'flength']) + # target_dict: (['index', 'center', 'scale', 'pts', 'tpts', 'target_weight', 'breed_index', 'sim_breed_index', 'ind_dataset', 'silh']) + output, output_unnorm, output_reproj = model(input, norm_dict=norm_dict) + # ----- loss ----- + if metrics == 'no_loss': + loss, loss_dict = loss_module(output_reproj=output_reproj, + target_dict=target_dict, + weight_dict=weight_dict) + # 
---------------------------------------------------------------- + + if i == 0: + if len_dataset is None: + len_data = val_loader.batch_size * len(val_loader) # 1703 + else: + len_data = len_dataset + if metrics == 'all' or metrics == 'no_loss': + pck = np.zeros((len_data)) + pck_by_part = {group:np.zeros((len_data)) for group in KEYPOINT_GROUPS} + acc_sil_2d = np.zeros(len_data) + + all_betas = np.zeros((len_data, output_reproj['betas'].shape[1])) + all_betas_limbs = np.zeros((len_data, output_reproj['betas_limbs'].shape[1])) + all_z = np.zeros((len_data, output_reproj['z'].shape[1])) + all_pose_rotmat = np.zeros((len_data, output_unnorm['pose_rotmat'].shape[1], 3, 3)) + all_flength = np.zeros((len_data, output_unnorm['flength'].shape[1])) + all_trans = np.zeros((len_data, output_unnorm['trans'].shape[1])) + all_breed_indices = np.zeros((len_data)) + all_image_names = [] # len_data * [None] + + index = i + ind_img = 0 + if save_imgs_path is not None: + # render predicted 3d models + visualizations = model.render_vis_nograd(vertices=output_reproj['vertices_smal'], + focal_lengths=output_unnorm['flength'], + color=0) # color=2) + for ind_img in range(len(target_dict['index'])): + try: + if test_name_list is not None: + img_name = test_name_list[int(target_dict['index'][ind_img].cpu().detach().numpy())].replace('/', '_') + img_name = img_name.split('.')[0] + else: + img_name = str(index) + '_' + str(ind_img) + # save image with predicted keypoints + out_path = save_imgs_path + '/keypoints_pred_' + img_name + '.png' + pred_unp = (output['keypoints_norm'][ind_img, :, :] + 1.) / 2 * (data_info.image_size - 1) + pred_unp_maxval = output['keypoints_scores'][ind_img, :, :] + pred_unp_prep = torch.cat((pred_unp, pred_unp_maxval), 1) + inp_img = input[ind_img, :, :, :].detach().clone() + save_input_image_with_keypoints(inp_img, pred_unp_prep, out_path=out_path, threshold=0.1, print_scores=True, ratio_in_out=1.0) # threshold=0.3 + # save predicted 3d model (front view) + pred_tex = visualizations[ind_img, :, :, :].permute((1, 2, 0)).cpu().detach().numpy() / 256 + pred_tex_max = np.max(pred_tex, axis=2) + out_path = save_imgs_path + '/tex_pred_' + img_name + '.png' + plt.imsave(out_path, pred_tex) + input_image = input[ind_img, :, :, :].detach().clone() + for t, m, s in zip(input_image, data_info.rgb_mean, data_info.rgb_stddev): t.add_(m) + input_image_np = input_image.detach().cpu().numpy().transpose(1, 2, 0) + im_masked = cv2.addWeighted(input_image_np,0.2,pred_tex,0.8,0) + im_masked[pred_tex_max<0.01, :] = input_image_np[pred_tex_max<0.01, :] + out_path = save_imgs_path + '/comp_pred_' + img_name + '.png' + plt.imsave(out_path, im_masked) + # save predicted 3d model (side view) + vertices_cent = output_reproj['vertices_smal'] - output_reproj['vertices_smal'].mean(dim=1)[:, None, :] + roll = np.pi / 2 * torch.ones(1).float().to(device) + pitch = np.pi / 2 * torch.ones(1).float().to(device) + tensor_0 = torch.zeros(1).float().to(device) + tensor_1 = torch.ones(1).float().to(device) + RX = torch.stack([torch.stack([tensor_1, tensor_0, tensor_0]), torch.stack([tensor_0, torch.cos(roll), -torch.sin(roll)]),torch.stack([tensor_0, torch.sin(roll), torch.cos(roll)])]).reshape(3,3) + RY = torch.stack([ + torch.stack([torch.cos(pitch), tensor_0, torch.sin(pitch)]), + torch.stack([tensor_0, tensor_1, tensor_0]), + torch.stack([-torch.sin(pitch), tensor_0, torch.cos(pitch)])]).reshape(3,3) + vertices_rot = (torch.matmul(RY, vertices_cent.reshape((-1, 3))[:, :, None])).reshape((curr_batch_size, -1, 3)) + 
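+ # only the yaw rotation RY is applied here (RX is constructed above but not used); the next
+ # line then pushes the rotated mesh away from the camera along +z so that the side view is
+ # fully visible when it is re-rendered below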
vertices_rot[:, :, 2] = vertices_rot[:, :, 2] + torch.ones_like(vertices_rot[:, :, 2]) * 20 # 18 # *16 + + visualizations_rot = model.render_vis_nograd(vertices=vertices_rot, + focal_lengths=output_unnorm['flength'], + color=0) # 2) + pred_tex = visualizations_rot[ind_img, :, :, :].permute((1, 2, 0)).cpu().detach().numpy() / 256 + pred_tex_max = np.max(pred_tex, axis=2) + out_path = save_imgs_path + '/rot_tex_pred_' + img_name + '.png' + plt.imsave(out_path, pred_tex) + if render_all: + # save input image + inp_img = input[ind_img, :, :, :].detach().clone() + out_path = save_imgs_path + '/image_' + img_name + '.png' + save_input_image(inp_img, out_path) + # save mesh + V_posed = output_reproj['vertices_smal'][ind_img, :, :].detach().cpu().numpy() + Faces = model.smal.f + mesh_posed = trimesh.Trimesh(vertices=V_posed, faces=Faces, process=False) + mesh_posed.export(save_imgs_path + '/mesh_posed_' + img_name + '.obj') + except: + print('dont save an image') + + if metrics == 'all' or metrics == 'no_loss': + # prepare a dictionary with all the predicted results + preds = {} + preds['betas'] = output_reproj['betas'].cpu().detach().numpy() + preds['betas_limbs'] = output_reproj['betas_limbs'].cpu().detach().numpy() + preds['z'] = output_reproj['z'].cpu().detach().numpy() + preds['pose_rotmat'] = output_unnorm['pose_rotmat'].cpu().detach().numpy() + preds['flength'] = output_unnorm['flength'].cpu().detach().numpy() + preds['trans'] = output_unnorm['trans'].cpu().detach().numpy() + preds['breed_index'] = target_dict['breed_index'].cpu().detach().numpy().reshape((-1)) + img_names = [] + for ind_img2 in range(0, output_reproj['betas'].shape[0]): + if test_name_list is not None: + img_name2 = test_name_list[int(target_dict['index'][ind_img2].cpu().detach().numpy())].replace('/', '_') + img_name2 = img_name2.split('.')[0] + else: + img_name2 = str(index) + '_' + str(ind_img2) + img_names.append(img_name2) + preds['image_names'] = img_names + # prepare keypoints for PCK calculation - predicted as well as ground truth + pred_keypoints_norm = output['keypoints_norm'] # -1 to 1 + pred_keypoints_256 = output_reproj['keyp_2d'] + pred_keypoints = pred_keypoints_256 + gt_keypoints_256 = target_dict['tpts'][:, :, :2] / 64. * (256. 
- 1) + gt_keypoints_norm = gt_keypoints_256 / 256 / 0.5 - 1 + gt_keypoints = torch.cat((gt_keypoints_256, target_dict['tpts'][:, :, 2:3]), dim=2) # gt_keypoints_norm + # prepare silhouette for IoU calculation - predicted as well as ground truth + has_seg = target_dict['has_seg'] + img_border_mask = target_dict['img_border_mask'][:, 0, :, :] + gtseg = target_dict['silh'] + synth_silhouettes = output_reproj['silh'][:, 0, :, :] # output_reproj['silh'] + synth_silhouettes[synth_silhouettes>0.5] = 1 + synth_silhouettes[synth_silhouettes<0.5] = 0 + # calculate PCK as well as IoU (similar to WLDO) + preds['acc_PCK'] = Metrics.PCK( + pred_keypoints, gt_keypoints, + gtseg, has_seg, idxs=EVAL_KEYPOINTS, + thresh_range=[pck_thresh], # [0.15], + ) + preds['acc_IOU'] = Metrics.IOU( + synth_silhouettes, gtseg, + img_border_mask, mask=has_seg + ) + for group, group_kps in KEYPOINT_GROUPS.items(): + preds[f'{group}_PCK'] = Metrics.PCK( + pred_keypoints, gt_keypoints, gtseg, has_seg, + thresh_range=[pck_thresh], # [0.15], + idxs=group_kps + ) + # add results for all images in this batch to lists + curr_batch_size = pred_keypoints_256.shape[0] + if not (preds['acc_PCK'].data.cpu().numpy().shape == (pck[my_step * batch_size:my_step * batch_size + curr_batch_size]).shape): + import pdb; pdb.set_trace() + pck[my_step * batch_size:my_step * batch_size + curr_batch_size] = preds['acc_PCK'].data.cpu().numpy() + acc_sil_2d[my_step * batch_size:my_step * batch_size + curr_batch_size] = preds['acc_IOU'].data.cpu().numpy() + for part in pck_by_part: + pck_by_part[part][my_step * batch_size:my_step * batch_size + curr_batch_size] = preds[f'{part}_PCK'].data.cpu().numpy() + all_betas[my_step * batch_size:my_step * batch_size + curr_batch_size, ...] = preds['betas'] + all_betas_limbs[my_step * batch_size:my_step * batch_size + curr_batch_size, ...] = preds['betas_limbs'] + all_z[my_step * batch_size:my_step * batch_size + curr_batch_size, ...] = preds['z'] + all_pose_rotmat[my_step * batch_size:my_step * batch_size + curr_batch_size, ...] = preds['pose_rotmat'] + all_flength[my_step * batch_size:my_step * batch_size + curr_batch_size, ...] = preds['flength'] + all_trans[my_step * batch_size:my_step * batch_size + curr_batch_size, ...] = preds['trans'] + all_breed_indices[my_step * batch_size:my_step * batch_size + curr_batch_size] = preds['breed_index'] + all_image_names.extend(preds['image_names']) + # update progress bar + if progress is not None: + my_string = "PCK: {0:.2f}, IOU: {1:.2f}".format( + pck[:(my_step * batch_size + curr_batch_size)].mean(), + acc_sil_2d[:(my_step * batch_size + curr_batch_size)].mean()) + progress.set_postfix_str(my_string) + else: + # measure accuracy and record loss + bs_fake = 1 # batch_size + losses.update(loss_dict['loss'], bs_fake) + losses_keyp.update(loss_dict['loss_keyp_weighted'], bs_fake) + losses_silh.update(loss_dict['loss_silh_weighted'], bs_fake) + losses_shape.update(loss_dict['loss_shape_weighted'], bs_fake) + losses_pose.update(loss_dict['loss_poseprior_weighted'], bs_fake) + losses_class.update(loss_dict['loss_class_weighted'], bs_fake) + losses_breed.update(loss_dict['loss_breed_weighted'], bs_fake) + losses_partseg.update(loss_dict['loss_partseg_weighted'], bs_fake) + acc = - loss_dict['loss_keyp_weighted'] # this will be used to keep track of the 'best model' + accuracies.update(acc, bs_fake) + # Show losses as part of the progress bar. 
+ if progress is not None: + my_string = 'Loss: {loss:0.4f}, loss_keyp: {loss_keyp:0.4f}, loss_silh: {loss_silh:0.4f}, loss_partseg: {loss_partseg:0.4f}, loss_shape: {loss_shape:0.4f}, loss_pose: {loss_pose:0.4f}, loss_class: {loss_class:0.4f}, loss_breed: {loss_breed:0.4f}'.format( + loss=losses.avg, + loss_keyp=losses_keyp.avg, + loss_silh=losses_silh.avg, + loss_shape=losses_shape.avg, + loss_pose=losses_pose.avg, + loss_class=losses_class.avg, + loss_breed=losses_breed.avg, + loss_partseg=losses_partseg.avg + ) + progress.set_postfix_str(my_string) + my_step += 1 + if metrics == 'all': + summary = {'pck': pck, 'acc_sil_2d': acc_sil_2d, 'pck_by_part':pck_by_part, + 'betas': all_betas, 'betas_limbs': all_betas_limbs, 'z': all_z, 'pose_rotmat': all_pose_rotmat, + 'flenght': all_flength, 'trans': all_trans, 'image_names': all_image_names, 'breed_indices': all_breed_indices} + return my_string, summary + elif metrics == 'no_loss': + return my_string, np.average(np.asarray(acc_sil_2d)) + else: + return my_string, accuracies.avg + + +# --------------------------------------------------------------------------------------------------------------------------- +def do_visual_epoch(val_loader, model, device, data_info, flip=False, quiet=False, acc_joints=None, save_imgs_path=None, weight_dict=None, metrics=None, val_opt='default', test_name_list=None, render_all=False, pck_thresh=0.15, return_results=False): + if save_imgs_path is not None: + pathlib.Path(save_imgs_path).mkdir(parents=True, exist_ok=True) + all_results = [] + + # Put the model in evaluation mode. + model.eval() + + iterable = enumerate(val_loader) + + # information for normalization + norm_dict = { + 'pose_rot6d_mean': torch.from_numpy(data_info.pose_rot6d_mean).float().to(device), + 'trans_mean': torch.from_numpy(data_info.trans_mean).float().to(device), + 'trans_std': torch.from_numpy(data_info.trans_std).float().to(device), + 'flength_mean': torch.from_numpy(data_info.flength_mean).float().to(device), + 'flength_std': torch.from_numpy(data_info.flength_std).float().to(device)} + + for i, (input, target_dict) in iterable: + batch_size = input.shape[0] + input = input.float().to(device) + partial_results = {} + + # ----------------------- do visualization step ----------------------- + with torch.no_grad(): + output, output_unnorm, output_reproj = model(input, norm_dict=norm_dict) + + index = i + ind_img = 0 + for ind_img in range(batch_size): # range(min(12, batch_size)): # range(12): # [0]: #range(0, batch_size): + + try: + if test_name_list is not None: + img_name = test_name_list[int(target_dict['index'][ind_img].cpu().detach().numpy())].replace('/', '_') + img_name = img_name.split('.')[0] + else: + img_name = str(index) + '_' + str(ind_img) + partial_results['img_name'] = img_name + visualizations = model.render_vis_nograd(vertices=output_reproj['vertices_smal'], + focal_lengths=output_unnorm['flength'], + color=0) # 2) + # save image with predicted keypoints + pred_unp = (output['keypoints_norm'][ind_img, :, :] + 1.) 
/ 2 * (data_info.image_size - 1) + pred_unp_maxval = output['keypoints_scores'][ind_img, :, :] + pred_unp_prep = torch.cat((pred_unp, pred_unp_maxval), 1) + inp_img = input[ind_img, :, :, :].detach().clone() + if save_imgs_path is not None: + out_path = save_imgs_path + '/keypoints_pred_' + img_name + '.png' + save_input_image_with_keypoints(inp_img, pred_unp_prep, out_path=out_path, threshold=0.1, print_scores=True, ratio_in_out=1.0) # threshold=0.3 + # save predicted 3d model + # (1) front view + pred_tex = visualizations[ind_img, :, :, :].permute((1, 2, 0)).cpu().detach().numpy() / 256 + pred_tex_max = np.max(pred_tex, axis=2) + partial_results['tex_pred'] = pred_tex + if save_imgs_path is not None: + out_path = save_imgs_path + '/tex_pred_' + img_name + '.png' + plt.imsave(out_path, pred_tex) + input_image = input[ind_img, :, :, :].detach().clone() + for t, m, s in zip(input_image, data_info.rgb_mean, data_info.rgb_stddev): t.add_(m) + input_image_np = input_image.detach().cpu().numpy().transpose(1, 2, 0) + im_masked = cv2.addWeighted(input_image_np,0.2,pred_tex,0.8,0) + im_masked[pred_tex_max<0.01, :] = input_image_np[pred_tex_max<0.01, :] + partial_results['comp_pred'] = im_masked + if save_imgs_path is not None: + out_path = save_imgs_path + '/comp_pred_' + img_name + '.png' + plt.imsave(out_path, im_masked) + # (2) side view + vertices_cent = output_reproj['vertices_smal'] - output_reproj['vertices_smal'].mean(dim=1)[:, None, :] + roll = np.pi / 2 * torch.ones(1).float().to(device) + pitch = np.pi / 2 * torch.ones(1).float().to(device) + tensor_0 = torch.zeros(1).float().to(device) + tensor_1 = torch.ones(1).float().to(device) + RX = torch.stack([torch.stack([tensor_1, tensor_0, tensor_0]), torch.stack([tensor_0, torch.cos(roll), -torch.sin(roll)]),torch.stack([tensor_0, torch.sin(roll), torch.cos(roll)])]).reshape(3,3) + RY = torch.stack([ + torch.stack([torch.cos(pitch), tensor_0, torch.sin(pitch)]), + torch.stack([tensor_0, tensor_1, tensor_0]), + torch.stack([-torch.sin(pitch), tensor_0, torch.cos(pitch)])]).reshape(3,3) + vertices_rot = (torch.matmul(RY, vertices_cent.reshape((-1, 3))[:, :, None])).reshape((batch_size, -1, 3)) + vertices_rot[:, :, 2] = vertices_rot[:, :, 2] + torch.ones_like(vertices_rot[:, :, 2]) * 20 # 18 # *16 + visualizations_rot = model.render_vis_nograd(vertices=vertices_rot, + focal_lengths=output_unnorm['flength'], + color=0) # 2) + pred_tex = visualizations_rot[ind_img, :, :, :].permute((1, 2, 0)).cpu().detach().numpy() / 256 + pred_tex_max = np.max(pred_tex, axis=2) + partial_results['rot_tex_pred'] = pred_tex + if save_imgs_path is not None: + out_path = save_imgs_path + '/rot_tex_pred_' + img_name + '.png' + plt.imsave(out_path, pred_tex) + render_all = True + if render_all: + # save input image + inp_img = input[ind_img, :, :, :].detach().clone() + if save_imgs_path is not None: + out_path = save_imgs_path + '/image_' + img_name + '.png' + save_input_image(inp_img, out_path) + # save posed mesh + V_posed = output_reproj['vertices_smal'][ind_img, :, :].detach().cpu().numpy() + Faces = model.smal.f + mesh_posed = trimesh.Trimesh(vertices=V_posed, faces=Faces, process=False) + partial_results['mesh_posed'] = mesh_posed + if save_imgs_path is not None: + mesh_posed.export(save_imgs_path + '/mesh_posed_' + img_name + '.obj') + except: + print('pass...') + all_results.append(partial_results) + if return_results: + return all_results + else: + return \ No newline at end of file diff --git a/src/configs/SMAL_configs.py b/src/configs/SMAL_configs.py new 
file mode 100644 index 0000000000000000000000000000000000000000..5c977b887b2265ddc44bb85f4713be673bae64ff --- /dev/null +++ b/src/configs/SMAL_configs.py @@ -0,0 +1,165 @@ + + +import numpy as np +import os +import sys + + +# SMAL_DATA_DIR = '/is/cluster/work/nrueegg/dog_project/pytorch-dogs-inference/src/smal_pytorch/smpl_models/' +# SMAL_DATA_DIR = os.path.join(os.path.dirname(__file__), '..', 'smal_pytorch', 'smal_data') +SMAL_DATA_DIR = os.path.join(os.path.dirname(__file__), '..', '..', 'data', 'smal_data') + +# we replace the old SMAL model by a more dog specific model (see BARC cvpr 2022 paper) +# our model has several differences compared to the original SMAL model, some of them are: +# - the PCA shape space is recalculated (from partially new data and weighted) +# - coefficients for limb length changes are allowed (similar to WLDO, we did borrow some of their code) +# - all dogs have a core of approximately the same length +# - dogs are centered in their root joint (which is close to the tail base) +# -> like this the root rotations is always around this joint AND (0, 0, 0) +# -> before this it would happen that the animal 'slips' from the image middle to the side when rotating it. Now +# 'trans' also defines the center of the rotation +# - we correct the back joint locations such that all those joints are more aligned +SMAL_MODEL_PATH = os.path.join(SMAL_DATA_DIR, 'my_smpl_SMBLD_nbj_v3.pkl') +UNITY_SMAL_SHAPE_PRIOR_DOGS = os.path.join(SMAL_DATA_DIR, 'my_smpl_data_SMBLD_v3.pkl') + +SYMMETRY_INDS_FILE = os.path.join(SMAL_DATA_DIR, 'symmetry_inds.json') + +mean_dog_bone_lengths_txt = os.path.join(SMAL_DATA_DIR, 'mean_dog_bone_lengths.txt') + +# there exist different keypoint configurations, for example keypoints corresponding to SMAL joints or keypoints defined based on vertex locations +KEYPOINT_CONFIGURATION = 'green' # green: same as in https://github.com/benjiebob/SMALify/blob/master/config.py + +# some vertex indices, (from silvia zuffi´s code, create_projected_images_cats.py) +KEY_VIDS = np.array(([1068, 1080, 1029, 1226], # left eye + [2660, 3030, 2675, 3038], # right eye + [910], # mouth low + [360, 1203, 1235, 1230], # front left leg, low + [3188, 3156, 2327, 3183], # front right leg, low + [1976, 1974, 1980, 856], # back left leg, low + [3854, 2820, 3852, 3858], # back right leg, low + [452, 1811], # tail start + [416, 235, 182], # front left leg, top + [2156, 2382, 2203], # front right leg, top + [829], # back left leg, top + [2793], # back right leg, top + [60, 114, 186, 59], # throat, close to base of neck + [2091, 2037, 2036, 2160], # withers (a bit lower than in reality) + [384, 799, 1169, 431], # front left leg, middle + [2351, 2763, 2397, 3127], # front right leg, middle + [221, 104], # back left leg, middle + [2754, 2192], # back right leg, middle + [191, 1158, 3116, 2165], # neck + [28], # Tail tip + [542], # Left Ear + [2507], # Right Ear + [1039, 1845, 1846, 1870, 1879, 1919, 2997, 3761, 3762], # nose tip + [0, 464, 465, 726, 1824, 2429, 2430, 2690]), dtype=object) # half tail + +# the following vertices are used for visibility only: if one of the vertices is visible, +# then we assume that the joint is visible! 
There is some noise, but we don't care, as this is +# for generation of the synthetic dataset only +KEY_VIDS_VISIBILITY_ONLY = np.array(([1068, 1080, 1029, 1226, 645], # left eye + [2660, 3030, 2675, 3038, 2567], # right eye + [910, 11, 5], # mouth low + [360, 1203, 1235, 1230, 298, 408, 303, 293, 384], # front left leg, low + [3188, 3156, 2327, 3183, 2261, 2271, 2573, 2265], # front right leg, low + [1976, 1974, 1980, 856, 559, 851, 556], # back left leg, low + [3854, 2820, 3852, 3858, 2524, 2522, 2815, 2072], # back right leg, low + [452, 1811, 63, 194, 52, 370, 64], # tail start + [416, 235, 182, 440, 8, 80, 73, 112], # front left leg, top + [2156, 2382, 2203, 2050, 2052, 2406, 3], # front right leg, top + [829, 219, 218, 173, 17, 7, 279], # back left leg, top + [2793, 582, 140, 87, 2188, 2147, 2063], # back right leg, top + [60, 114, 186, 59, 878, 130, 189, 45], # throat, close to base of neck + [2091, 2037, 2036, 2160, 190, 2164], # withers (a bit lower than in reality) + [384, 799, 1169, 431, 321, 314, 437, 310, 323], # front left leg, middle + [2351, 2763, 2397, 3127, 2278, 2285, 2282, 2275, 2359], # front right leg, middle + [221, 104, 105, 97, 103], # back left leg, middle + [2754, 2192, 2080, 2251, 2075, 2074], # back right leg, middle + [191, 1158, 3116, 2165, 154, 653, 133, 339], # neck + [28, 474, 475, 731, 24], # Tail tip + [542, 147, 509, 200, 522], # Left Ear + [2507,2174, 2122, 2126, 2474], # Right Ear + [1039, 1845, 1846, 1870, 1879, 1919, 2997, 3761, 3762], # nose tip + [0, 464, 465, 726, 1824, 2429, 2430, 2690]), dtype=object) # half tail + +# see: https://github.com/benjiebob/SMALify/blob/master/config.py +# JOINT DEFINITIONS - based on SMAL joints and additional {eyes, ear tips, chin and nose} +TORSO_JOINTS = [2, 5, 8, 11, 12, 23] +CANONICAL_MODEL_JOINTS = [ + 10, 9, 8, # upper_left [paw, middle, top] + 20, 19, 18, # lower_left [paw, middle, top] + 14, 13, 12, # upper_right [paw, middle, top] + 24, 23, 22, # lower_right [paw, middle, top] + 25, 31, # tail [start, end] + 33, 34, # ear base [left, right] + 35, 36, # nose, chin + 38, 37, # ear tip [left, right] + 39, 40, # eyes [left, right] + 6, 11, # withers, throat (throat is inaccurate and withers also) + 28] # tail middle + # old: 15, 15, # withers, throat (TODO: Labelled same as throat for now), throat + + + +# the following list gives the indices of the KEY_VIDS_JOINTS that must be taken in order +# to judge if the CANONICAL_MODEL_JOINTS are visible - those are all approximations! 
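# Illustrative sketch (not from the original file; the helper name and the 'vertex_visibility'
# argument are assumptions): it shows how KEY_VIDS_VISIBILITY_ONLY above and the
# CMJ_VISIBILITY_IN_KEY_VIDS list defined right below are typically combined with a per-vertex
# visibility map for one sample (e.g. one row of the map returned by
# SilhRenderer.calculate_vertex_visibility later in this diff): a canonical joint counts as
# visible as soon as one of its key vertices is visible.
def joint_visibility_from_vertex_visibility(vertex_visibility):
    # vertex_visibility: array of shape (n_vertices,) with entries in {0, 1}
    n_cmj = len(CMJ_VISIBILITY_IN_KEY_VIDS)
    joint_visible = np.zeros(n_cmj, dtype=bool)
    for j in range(n_cmj):
        vertex_ids = np.asarray(KEY_VIDS_VISIBILITY_ONLY[CMJ_VISIBILITY_IN_KEY_VIDS[j]])
        joint_visible[j] = bool(np.any(vertex_visibility[vertex_ids] > 0.5))
    return joint_visible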
+CMJ_VISIBILITY_IN_KEY_VIDS = [ + 3, 14, 8, # left front leg + 5, 16, 10, # left rear leg + 4, 15, 9, # right front leg + 6, 17, 11, # right rear leg + 7, 19, # tail front, tail back + 20, 21, # ear base (but can not be found in blue, se we take the tip) + 2, 2, # mouth (was: 22, 2) + 20, 21, # ear tips + 1, 0, # eyes + 18, # withers, not sure where this point is + 12, # throat + 23, # mid tail + ] + +# define which bone lengths are used as input to the 2d-to-3d network +IDXS_BONES_NO_REDUNDANCY = [6,7,8,9,16,17,18,19,32,1,2,3,4,5,14,15,24,25,26,27,28,29,30,31] +# load bone lengths of the mean dog (already filtered) +mean_dog_bone_lengths = [] +with open(mean_dog_bone_lengths_txt, 'r') as f: + for line in f: + mean_dog_bone_lengths.append(float(line.split('\n')[0])) +MEAN_DOG_BONE_LENGTHS_NO_RED = np.asarray(mean_dog_bone_lengths)[IDXS_BONES_NO_REDUNDANCY] # (24, ) + +# Body part segmentation: +# the body can be segmented based on the bones and for the new dog model also based on the new shapedirs +# axis_horizontal = self.shapedirs[2, :].reshape((-1, 3))[:, 0] +# all_indices = np.arange(3889) +# tail_indices = all_indices[axis_horizontal.detach().cpu().numpy() < 0.0] +VERTEX_IDS_TAIL = [ 0, 4, 9, 10, 24, 25, 28, 453, 454, 456, 457, + 458, 459, 460, 461, 462, 463, 464, 465, 466, 467, 468, + 469, 470, 471, 472, 473, 474, 475, 724, 725, 726, 727, + 728, 729, 730, 731, 813, 975, 976, 977, 1109, 1110, 1111, + 1811, 1813, 1819, 1820, 1821, 1822, 1823, 1824, 1825, 1826, 1827, + 1828, 1835, 1836, 1960, 1961, 1962, 1963, 1964, 1965, 1966, 1967, + 1968, 1969, 2418, 2419, 2421, 2422, 2423, 2424, 2425, 2426, 2427, + 2428, 2429, 2430, 2431, 2432, 2433, 2434, 2435, 2436, 2437, 2438, + 2439, 2440, 2688, 2689, 2690, 2691, 2692, 2693, 2694, 2695, 2777, + 3067, 3068, 3069, 3842, 3843, 3844, 3845, 3846, 3847] + +# same as in https://github.com/benjiebob/WLDO/blob/master/global_utils/config.py +EVAL_KEYPOINTS = [ + 0, 1, 2, # left front + 3, 4, 5, # left rear + 6, 7, 8, # right front + 9, 10, 11, # right rear + 12, 13, # tail start -> end + 14, 15, # left ear, right ear + 16, 17, # nose, chin + 18, 19] # left tip, right tip + +KEYPOINT_GROUPS = { + 'legs': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], # legs + 'tail': [12, 13], # tail + 'ears': [14, 15, 18, 19], # ears + 'face': [16, 17] # face +} + + diff --git a/src/configs/anipose_data_info.py b/src/configs/anipose_data_info.py new file mode 100644 index 0000000000000000000000000000000000000000..8e7bad68b45cf9926fdfd3ca1b7e1f147e909cfd --- /dev/null +++ b/src/configs/anipose_data_info.py @@ -0,0 +1,74 @@ +from dataclasses import dataclass +from typing import List +import json +import numpy as np +import os + +STATISTICS_DATA_DIR = os.path.join(os.path.dirname(__file__), '..', '..', 'data', 'statistics') +STATISTICS_PATH = os.path.join(STATISTICS_DATA_DIR, 'statistics_modified_v1.json') + +@dataclass +class DataInfo: + rgb_mean: List[float] + rgb_stddev: List[float] + joint_names: List[str] + hflip_indices: List[int] + n_joints: int + n_keyp: int + n_bones: int + n_betas: int + image_size: int + trans_mean: np.ndarray + trans_std: np.ndarray + flength_mean: np.ndarray + flength_std: np.ndarray + pose_rot6d_mean: np.ndarray + keypoint_weights: List[float] + +# SMAL samples 3d statistics +# statistics like mean values were calculated once when the project was started and they were not changed afterwards anymore +def load_statistics(statistics_path): + with open(statistics_path) as f: + statistics = json.load(f) + '''new_pose_mean = [[[np.round(val, 2) for val in 
sublst] for sublst in sublst_big] for sublst_big in statistics['pose_mean']] + statistics['pose_mean'] = new_pose_mean + j_out = json.dumps(statistics, indent=4) #, sort_keys=True) + with open(self.statistics_path, 'w') as file: file.write(j_out)''' + new_statistics = {'trans_mean': np.asarray(statistics['trans_mean']), + 'trans_std': np.asarray(statistics['trans_std']), + 'flength_mean': np.asarray(statistics['flength_mean']), + 'flength_std': np.asarray(statistics['flength_std']), + 'pose_mean': np.asarray(statistics['pose_mean']), + } + new_statistics['pose_rot6d_mean'] = new_statistics['pose_mean'][:, :, :2].reshape((-1, 6)) + return new_statistics +STATISTICS = load_statistics(STATISTICS_PATH) + +AniPose_JOINT_NAMES_swapped = [ + 'L_F_Paw', 'L_F_Knee', 'L_F_Elbow', + 'L_B_Paw', 'L_B_Knee', 'L_B_Elbow', + 'R_F_Paw', 'R_F_Knee', 'R_F_Elbow', + 'R_B_Paw', 'R_B_Knee', 'R_B_Elbow', + 'TailBase', '_Tail_end_', 'L_EarBase', 'R_EarBase', + 'Nose', '_Chin_', '_Left_ear_tip_', '_Right_ear_tip_', + 'L_Eye', 'R_Eye', 'Withers', 'Throat'] + +KEYPOINT_WEIGHTS = [3, 2, 2, 3, 2, 2, 3, 2, 2, 3, 2, 2, 3, 3, 2, 2, 3, 1, 2, 2] + +COMPLETE_DATA_INFO = DataInfo( + rgb_mean=[0.4404, 0.4440, 0.4327], # not sure + rgb_stddev=[0.2458, 0.2410, 0.2468], # not sure + joint_names=AniPose_JOINT_NAMES_swapped, # AniPose_JOINT_NAMES, + hflip_indices=[6, 7, 8, 9, 10, 11, 0, 1, 2, 3, 4, 5, 12, 13, 15, 14, 16, 17, 19, 18, 21, 20, 22, 23], + n_joints = 35, + n_keyp = 24, # 20, # 25, + n_bones = 24, + n_betas = 30, # 10, + image_size = 256, + trans_mean = STATISTICS['trans_mean'], + trans_std = STATISTICS['trans_std'], + flength_mean = STATISTICS['flength_mean'], + flength_std = STATISTICS['flength_std'], + pose_rot6d_mean = STATISTICS['pose_rot6d_mean'], + keypoint_weights = KEYPOINT_WEIGHTS + ) diff --git a/src/configs/barc_cfg_defaults.py b/src/configs/barc_cfg_defaults.py new file mode 100644 index 0000000000000000000000000000000000000000..d8be3802c91fe44e974ae453ca574f8cdce5fd80 --- /dev/null +++ b/src/configs/barc_cfg_defaults.py @@ -0,0 +1,111 @@ + +from yacs.config import CfgNode as CN +import argparse +import yaml +import os + +abs_barc_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..',)) + +_C = CN() +_C.barc_dir = abs_barc_dir +_C.device = 'cuda' + +## path settings +_C.paths = CN() +_C.paths.ROOT_OUT_PATH = abs_barc_dir + '/results/' +_C.paths.ROOT_CHECKPOINT_PATH = abs_barc_dir + '/checkpoint/' +_C.paths.MODELPATH_NORMFLOW = abs_barc_dir + '/checkpoint/barc_normflow_pret/rgbddog_v3_model.pt' + +## parameter settings +_C.params = CN() +_C.params.ARCH = 'hg8' +_C.params.STRUCTURE_POSE_NET = 'normflow' # 'default' # 'vae' +_C.params.NF_VERSION = 3 +_C.params.N_JOINTS = 35 +_C.params.N_KEYP = 24 #20 +_C.params.N_SEG = 2 +_C.params.N_PARTSEG = 15 +_C.params.UPSAMPLE_SEG = True +_C.params.ADD_PARTSEG = True # partseg: for the CVPR paper this part of the network exists, but is not trained (no part labels in StanExt) +_C.params.N_BETAS = 30 # 10 +_C.params.N_BETAS_LIMBS = 7 +_C.params.N_BONES = 24 +_C.params.N_BREEDS = 121 # 120 breeds plus background +_C.params.IMG_SIZE = 256 +_C.params.SILH_NO_TAIL = False +_C.params.KP_THRESHOLD = None +_C.params.ADD_Z_TO_3D_INPUT = False +_C.params.N_SEGBPS = 64*2 +_C.params.ADD_SEGBPS_TO_3D_INPUT = True +_C.params.FIX_FLENGTH = False +_C.params.RENDER_ALL = True +_C.params.VLIN = 2 +_C.params.STRUCTURE_Z_TO_B = 'lin' +_C.params.N_Z_FREE = 64 +_C.params.PCK_THRESH = 0.15 + +## optimization settings +_C.optim = CN() +_C.optim.LR = 5e-4 +_C.optim.SCHEDULE = 
[150, 175, 200] +_C.optim.GAMMA = 0.1 +_C.optim.MOMENTUM = 0 +_C.optim.WEIGHT_DECAY = 0 +_C.optim.EPOCHS = 220 +_C.optim.BATCH_SIZE = 12 # keep 12 (needs to be an even number, as we have a custom data sampler) +_C.optim.TRAIN_PARTS = 'all_without_shapedirs' + +## dataset settings +_C.data = CN() +_C.data.DATASET = 'stanext24' +_C.data.V12 = True +_C.data.SHORTEN_VAL_DATASET_TO = None +_C.data.VAL_OPT = 'val' +_C.data.VAL_METRICS = 'no_loss' + +# --------------------------------------- +def update_dependent_vars(cfg): + cfg.params.N_CLASSES = cfg.params.N_KEYP + cfg.params.N_SEG + if cfg.params.VLIN == 0: + cfg.params.NUM_STAGE_COMB = 2 + cfg.params.NUM_STAGE_HEADS = 1 + cfg.params.NUM_STAGE_HEADS_POSE = 1 + cfg.params.TRANS_SEP = False + elif cfg.params.VLIN == 1: + cfg.params.NUM_STAGE_COMB = 3 + cfg.params.NUM_STAGE_HEADS = 1 + cfg.params.NUM_STAGE_HEADS_POSE = 2 + cfg.params.TRANS_SEP = False + elif cfg.params.VLIN == 2: + cfg.params.NUM_STAGE_COMB = 3 + cfg.params.NUM_STAGE_HEADS = 1 + cfg.params.NUM_STAGE_HEADS_POSE = 2 + cfg.params.TRANS_SEP = True + else: + raise NotImplementedError + if cfg.params.STRUCTURE_Z_TO_B == '1dconv': + cfg.params.N_Z = cfg.params.N_BETAS + cfg.params.N_BETAS_LIMBS + else: + cfg.params.N_Z = cfg.params.N_Z_FREE + return + + +update_dependent_vars(_C) +global _cfg_global +_cfg_global = _C.clone() + + +def get_cfg_defaults(): + # Get a yacs CfgNode object with default values as defined within this file. + # Return a clone so that the defaults will not be altered. + return _C.clone() + +def update_cfg_global_with_yaml(cfg_yaml_file): + _cfg_global.merge_from_file(cfg_yaml_file) + update_dependent_vars(_cfg_global) + return + +def get_cfg_global_updated(): + # return _cfg_global.clone() + return _cfg_global + diff --git a/src/configs/barc_loss_weights.json b/src/configs/barc_loss_weights.json new file mode 100644 index 0000000000000000000000000000000000000000..8ddc9e1e6c882431f23b6881c124bf424ae7c3e9 --- /dev/null +++ b/src/configs/barc_loss_weights.json @@ -0,0 +1,30 @@ + + + +{ + "breed_options": [ + "4" + ], + "breed": 5.0, + "class": 1.0, + "models3d": 1.0, + "keyp": 0.2, + "silh": 50.0, + "shape_options": [ + "smal", + "limbs7" + ], + "shape": [ + 1e-05, + 1 + ], + "poseprior_options": [ + "normalizing_flow_tiger_logprob" + ], + "poseprior": 0.1, + "poselegssidemovement": 10.0, + "flength": 1.0, + "partseg": 0, + "shapedirs": 0, + "pose_0": 0.0 +} \ No newline at end of file diff --git a/src/configs/data_info.py b/src/configs/data_info.py new file mode 100644 index 0000000000000000000000000000000000000000..cf28608e6361b089d49520e6bf03d142e1aab799 --- /dev/null +++ b/src/configs/data_info.py @@ -0,0 +1,115 @@ +from dataclasses import dataclass +from typing import List +import json +import numpy as np +import os +import sys + +STATISTICS_DATA_DIR = os.path.join(os.path.dirname(__file__), '..', '..', 'data', 'statistics') +STATISTICS_PATH = os.path.join(STATISTICS_DATA_DIR, 'statistics_modified_v1.json') + +@dataclass +class DataInfo: + rgb_mean: List[float] + rgb_stddev: List[float] + joint_names: List[str] + hflip_indices: List[int] + n_joints: int + n_keyp: int + n_bones: int + n_betas: int + image_size: int + trans_mean: np.ndarray + trans_std: np.ndarray + flength_mean: np.ndarray + flength_std: np.ndarray + pose_rot6d_mean: np.ndarray + keypoint_weights: List[float] + +# SMAL samples 3d statistics +# statistics like mean values were calculated once when the project was started and they were not changed afterwards anymore +def 
load_statistics(statistics_path): + with open(statistics_path) as f: + statistics = json.load(f) + '''new_pose_mean = [[[np.round(val, 2) for val in sublst] for sublst in sublst_big] for sublst_big in statistics['pose_mean']] + statistics['pose_mean'] = new_pose_mean + j_out = json.dumps(statistics, indent=4) #, sort_keys=True) + with open(self.statistics_path, 'w') as file: file.write(j_out)''' + new_statistics = {'trans_mean': np.asarray(statistics['trans_mean']), + 'trans_std': np.asarray(statistics['trans_std']), + 'flength_mean': np.asarray(statistics['flength_mean']), + 'flength_std': np.asarray(statistics['flength_std']), + 'pose_mean': np.asarray(statistics['pose_mean']), + } + new_statistics['pose_rot6d_mean'] = new_statistics['pose_mean'][:, :, :2].reshape((-1, 6)) + return new_statistics +STATISTICS = load_statistics(STATISTICS_PATH) + + +############################################################################ +# for StanExt (original number of keypoints, 20 not 24) + +# for keypoint names see: https://github.com/benjiebob/StanfordExtra/blob/master/keypoint_definitions.csv +StanExt_JOINT_NAMES = [ + 'Left_front_leg_paw', 'Left_front_leg_middle_joint', 'Left_front_leg_top', + 'Left_rear_leg_paw', 'Left_rear_leg_middle_joint', 'Left_rear_leg_top', + 'Right_front_leg_paw', 'Right_front_leg_middle_joint', 'Right_front_leg_top', + 'Right_rear_leg_paw', 'Right_rear_leg_middle_joint', 'Right_rear_leg_top', + 'Tail_start', 'Tail_end', 'Base_of_left_ear', 'Base_of_right_ear', + 'Nose', 'Chin', 'Left_ear_tip', 'Right_ear_tip'] + +KEYPOINT_WEIGHTS = [3, 2, 2, 3, 2, 2, 3, 2, 2, 3, 2, 2, 3, 3, 2, 2, 3, 1, 2, 2] + +COMPLETE_DATA_INFO = DataInfo( + rgb_mean=[0.4404, 0.4440, 0.4327], # not sure + rgb_stddev=[0.2458, 0.2410, 0.2468], # not sure + joint_names=StanExt_JOINT_NAMES, + hflip_indices=[6, 7, 8, 9, 10, 11, 0, 1, 2, 3, 4, 5, 12, 13, 15, 14, 16, 17, 19, 18], + n_joints = 35, + n_keyp = 20, # 25, + n_bones = 24, + n_betas = 30, # 10, + image_size = 256, + trans_mean = STATISTICS['trans_mean'], + trans_std = STATISTICS['trans_std'], + flength_mean = STATISTICS['flength_mean'], + flength_std = STATISTICS['flength_std'], + pose_rot6d_mean = STATISTICS['pose_rot6d_mean'], + keypoint_weights = KEYPOINT_WEIGHTS + ) + + +############################################################################ +# new for StanExt24 + +# ..., 'Left_eye', 'Right_eye', 'Withers', 'Throat'] # the last 4 keypoints are in the animal_pose dataset, but not StanfordExtra +StanExt_JOINT_NAMES_24 = [ + 'Left_front_leg_paw', 'Left_front_leg_middle_joint', 'Left_front_leg_top', + 'Left_rear_leg_paw', 'Left_rear_leg_middle_joint', 'Left_rear_leg_top', + 'Right_front_leg_paw', 'Right_front_leg_middle_joint', 'Right_front_leg_top', + 'Right_rear_leg_paw', 'Right_rear_leg_middle_joint', 'Right_rear_leg_top', + 'Tail_start', 'Tail_end', 'Base_of_left_ear', 'Base_of_right_ear', + 'Nose', 'Chin', 'Left_ear_tip', 'Right_ear_tip', + 'Left_eye', 'Right_eye', 'Withers', 'Throat'] + +KEYPOINT_WEIGHTS_24 = [3, 2, 2, 3, 2, 2, 3, 2, 2, 3, 2, 2, 3, 3, 2, 2, 3, 1, 2, 2, 1, 1, 0, 0] + +COMPLETE_DATA_INFO_24 = DataInfo( + rgb_mean=[0.4404, 0.4440, 0.4327], # not sure + rgb_stddev=[0.2458, 0.2410, 0.2468], # not sure + joint_names=StanExt_JOINT_NAMES_24, + hflip_indices=[6, 7, 8, 9, 10, 11, 0, 1, 2, 3, 4, 5, 12, 13, 15, 14, 16, 17, 19, 18, 21, 20, 22, 23], + n_joints = 35, + n_keyp = 24, # 20, # 25, + n_bones = 24, + n_betas = 30, # 10, + image_size = 256, + trans_mean = STATISTICS['trans_mean'], + trans_std = STATISTICS['trans_std'], + 
flength_mean = STATISTICS['flength_mean'], + flength_std = STATISTICS['flength_std'], + pose_rot6d_mean = STATISTICS['pose_rot6d_mean'], + keypoint_weights = KEYPOINT_WEIGHTS_24 + ) + + diff --git a/src/configs/dataset_path_configs.py b/src/configs/dataset_path_configs.py new file mode 100644 index 0000000000000000000000000000000000000000..d7c46f58a298dba5f037d0f039c91853c30ade64 --- /dev/null +++ b/src/configs/dataset_path_configs.py @@ -0,0 +1,21 @@ + + +import numpy as np +import os +import sys + +abs_barc_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..',)) + +# stanext dataset +# (1) path to stanext dataset +STAN_V12_ROOT_DIR = abs_barc_dir + '/datasets/StanfordExtra_V12/' +IMG_V12_DIR = os.path.join(STAN_V12_ROOT_DIR, 'StanExtV12_Images') +JSON_V12_DIR = os.path.join(STAN_V12_ROOT_DIR, 'labels', "StanfordExtra_v12.json") +STAN_V12_TRAIN_LIST_DIR = os.path.join(STAN_V12_ROOT_DIR, 'labels', 'train_stanford_StanfordExtra_v12.npy') +STAN_V12_VAL_LIST_DIR = os.path.join(STAN_V12_ROOT_DIR, 'labels', 'val_stanford_StanfordExtra_v12.npy') +STAN_V12_TEST_LIST_DIR = os.path.join(STAN_V12_ROOT_DIR, 'labels', 'test_stanford_StanfordExtra_v12.npy') +# (2) path to related data such as breed indices and prepared predictions for withers, throat and eye keypoints +STANEXT_RELATED_DATA_ROOT_DIR = os.path.join(os.path.dirname(__file__), '..', '..', 'data', 'stanext_related_data') + +# image crop dataset (for demo, visualization) +TEST_IMAGE_CROP_ROOT_DIR = os.path.join(os.path.dirname(__file__), '..', '..', 'datasets', 'test_image_crops') diff --git a/src/configs/dog_breeds/dog_breed_class.py b/src/configs/dog_breeds/dog_breed_class.py new file mode 100644 index 0000000000000000000000000000000000000000..282052164ec6ecb742d91d07ea564cc82cf70ab8 --- /dev/null +++ b/src/configs/dog_breeds/dog_breed_class.py @@ -0,0 +1,170 @@ + +import os +import warnings +warnings.filterwarnings("ignore", category=DeprecationWarning) +import pandas as pd +import difflib +import json +import pickle as pkl +import csv +import numpy as np + + +# ----------------------------------------------------------------------------------------------------------------- # +class DogBreed(object): + def __init__(self, abbrev, name_akc=None, name_stanext=None, name_xlsx=None, path_akc=None, path_stanext=None, ind_in_xlsx=None, ind_in_xlsx_matrix=None, ind_in_stanext=None, clade=None): + self._abbrev = abbrev + self._name_xlsx = name_xlsx + self._name_akc = name_akc + self._name_stanext = name_stanext + self._path_stanext = path_stanext + self._additional_names = set() + if self._name_akc is not None: + self.add_akc_info(name_akc, path_akc) + if self._name_stanext is not None: + self.add_stanext_info(name_stanext, path_stanext, ind_in_stanext) + if self._name_xlsx is not None: + self.add_xlsx_info(name_xlsx, ind_in_xlsx, ind_in_xlsx_matrix, clade) + def add_xlsx_info(self, name_xlsx, ind_in_xlsx, ind_in_xlsx_matrix, clade): + assert (name_xlsx is not None) and (ind_in_xlsx is not None) and (ind_in_xlsx_matrix is not None) and (clade is not None) + self._name_xlsx = name_xlsx + self._ind_in_xlsx = ind_in_xlsx + self._ind_in_xlsx_matrix = ind_in_xlsx_matrix + self._clade = clade + def add_stanext_info(self, name_stanext, path_stanext, ind_in_stanext): + assert (name_stanext is not None) and (path_stanext is not None) and (ind_in_stanext is not None) + self._name_stanext = name_stanext + self._path_stanext = path_stanext + self._ind_in_stanext = ind_in_stanext + def add_akc_info(self, name_akc, path_akc): + assert 
(name_akc is not None) and (path_akc is not None) + self._name_akc = name_akc + self._path_akc = path_akc + def add_additional_names(self, name_list): + self._additional_names = self._additional_names.union(set(name_list)) + def add_text_info(self, text_height, text_weight, text_life_exp): + self._text_height = text_height + self._text_weight = text_weight + self._text_life_exp = text_life_exp + def get_datasets(self): + # all datasets in which this breed is found + datasets = set() + if self._name_akc is not None: + datasets.add('akc') + if self._name_stanext is not None: + datasets.add('stanext') + if self._name_xlsx is not None: + datasets.add('xlsx') + return datasets + def get_names(self): + # set of names for this breed + names = {self._abbrev, self._name_akc, self._name_stanext, self._name_xlsx, self._path_stanext}.union(self._additional_names) + names.discard(None) + return names + def get_names_as_pointing_dict(self): + # each name points to the abbreviation + names = self.get_names() + my_dict = {} + for name in names: + my_dict[name] = self._abbrev + return my_dict + def print_overview(self): + # print important information to get an overview of the class instance + if self._name_akc is not None: + name = self._name_akc + elif self._name_xlsx is not None: + name = self._name_xlsx + else: + name = self._name_stanext + print('----------------------------------------------------') + print('----- dog breed: ' + name ) + print('----------------------------------------------------') + print('[names]') + print(self.get_names()) + print('[datasets]') + print(self.get_datasets()) + # see https://stackoverflow.com/questions/9058305/getting-attributes-of-a-class + print('[instance attributes]') + for attribute, value in self.__dict__.items(): + print(attribute, '=', value) + def use_dict_to_save_class_instance(self): + my_dict = {} + for attribute, value in self.__dict__.items(): + my_dict[attribute] = value + return my_dict + def use_dict_to_load_class_instance(self, my_dict): + for attribute, value in my_dict.items(): + setattr(self, attribute, value) + return + +# ----------------------------------------------------------------------------------------------------------------- # +def get_name_list_from_summary(summary): + name_from_abbrev_dict = {} + for breed in summary.values(): + abbrev = breed._abbrev + all_names = breed.get_names() + name_from_abbrev_dict[abbrev] = list(all_names) + return name_from_abbrev_dict +def get_partial_summary(summary, part): + assert part in ['xlsx', 'akc', 'stanext'] + partial_summary = {} + for key, value in summary.items(): + if (part == 'xlsx' and value._name_xlsx is not None) \ + or (part == 'akc' and value._name_akc is not None) \ + or (part == 'stanext' and value._name_stanext is not None): + partial_summary[key] = value + return partial_summary +def get_akc_but_not_stanext_partial_summary(summary): + partial_summary = {} + for key, value in summary.items(): + if value._name_akc is not None: + if value._name_stanext is None: + partial_summary[key] = value + return partial_summary + +# ----------------------------------------------------------------------------------------------------------------- # +def main_load_dog_breed_classes(path_complete_abbrev_dict_v1, path_complete_summary_breeds_v1): + with open(path_complete_abbrev_dict_v1, 'rb') as file: + complete_abbrev_dict = pkl.load(file) + with open(path_complete_summary_breeds_v1, 'rb') as file: + complete_summary_breeds_attributes_only = pkl.load(file) + + complete_summary_breeds = {} + for key, 
value in complete_summary_breeds_attributes_only.items(): + attributes_only = complete_summary_breeds_attributes_only[key] + complete_summary_breeds[key] = DogBreed(abbrev=attributes_only['_abbrev']) + complete_summary_breeds[key].use_dict_to_load_class_instance(attributes_only) + return complete_abbrev_dict, complete_summary_breeds + + +# ----------------------------------------------------------------------------------------------------------------- # +def load_similarity_matrix_raw(xlsx_path): + # --- LOAD EXCEL FILE FROM DOG BREED PAPER + xlsx = pd.read_excel(xlsx_path) + # build the 168x168 raw similarity matrix and a mapping from breed abbreviation to matrix index + abbrev_indices = {} + matrix_raw = np.zeros((168, 168)) + for ind in range(1, 169): + abbrev = xlsx[xlsx.columns[2]][ind] + abbrev_indices[abbrev] = ind-1 + for ind_col in range(0, 168): + for ind_row in range(0, 168): + matrix_raw[ind_col, ind_row] = float(xlsx[xlsx.columns[3+ind_col]][1+ind_row]) + return matrix_raw, abbrev_indices + + + +# ----------------------------------------------------------------------------------------------------------------- # +# ----------------------------------------------------------------------------------------------------------------- # +# load the final dict of dog breed classes (created in advance) +ROOT_PATH_BREED_DATA = os.path.join(os.path.dirname(os.path.abspath(__file__)), '..', '..', '..', 'data', 'breed_data') +path_complete_abbrev_dict_v1 = os.path.join(ROOT_PATH_BREED_DATA, 'complete_abbrev_dict_v2.pkl') +path_complete_summary_breeds_v1 = os.path.join(ROOT_PATH_BREED_DATA, 'complete_summary_breeds_v2.pkl') +COMPLETE_ABBREV_DICT, COMPLETE_SUMMARY_BREEDS = main_load_dog_breed_classes(path_complete_abbrev_dict_v1, path_complete_summary_breeds_v1) +# load similarity matrix, data from: +# Parker H. G., Dreger D. L., Rimbault M., Davis B. W., Mullen A. B., Carpintero-Ramirez G., and Ostrander E. A. +# Genomic analyses reveal the influence of geographic origin, migration, and hybridization on modern dog breed +# development. Cell Reports, 19(4):697–708, 2017. 
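# Illustrative usage sketch (not from the original file; the helper name 'similarity_between_breeds'
# is an assumption): the module-level code right below loads the 168x168 matrix once via
# load_similarity_matrix_raw(), and a pairwise lookup between two breed abbreviations then reduces
# to indexing with the abbreviation-to-row mapping.
def similarity_between_breeds(abbrev_a, abbrev_b, matrix_raw, abbrev_indices):
    # abbrev_a / abbrev_b: breed abbreviations as used in the supplementary xlsx sheet
    return matrix_raw[abbrev_indices[abbrev_a], abbrev_indices[abbrev_b]]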
+xlsx_path = os.path.join(ROOT_PATH_BREED_DATA, 'NIHMS866262-supplement-2.xlsx') +SIM_MATRIX_RAW, SIM_ABBREV_INDICES = load_similarity_matrix_raw(xlsx_path) + diff --git a/src/lifting_to_3d/inn_model_for_shape.py b/src/lifting_to_3d/inn_model_for_shape.py new file mode 100644 index 0000000000000000000000000000000000000000..6ab7c1f18ca603a20406092bdd7163e370d17023 --- /dev/null +++ b/src/lifting_to_3d/inn_model_for_shape.py @@ -0,0 +1,61 @@ + + +from torch import distributions +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.utils.data import DataLoader +from torch.distributions import Normal +import numpy as np +import cv2 +import trimesh +from tqdm import tqdm +import warnings +warnings.filterwarnings("ignore", category=DeprecationWarning) +import FrEIA.framework as Ff +import FrEIA.modules as Fm + + +class INNForShape(nn.Module): + def __init__(self, n_betas, n_betas_limbs, k_tot=2, betas_scale=1.0, betas_limbs_scale=0.1): + super(INNForShape, self).__init__() + self.n_betas = n_betas + self.n_betas_limbs = n_betas_limbs + self.n_dim = n_betas + n_betas_limbs + self.betas_scale = betas_scale + self.betas_limbs_scale = betas_limbs_scale + self.k_tot = k_tot # number of coupling blocks + self.model_inn = self.build_inn_network(self.n_dim, k_tot=self.k_tot) + + def subnet_fc(self, c_in, c_out): + subnet = nn.Sequential(nn.Linear(c_in, 64), nn.ReLU(), + nn.Linear(64, 64), nn.ReLU(), + nn.Linear(64, c_out)) + return subnet + + def build_inn_network(self, n_input, k_tot=12, verbose=False): + coupling_block = Fm.RNVPCouplingBlock + nodes = [Ff.InputNode(n_input, name='input')] + for k in range(k_tot): + nodes.append(Ff.Node(nodes[-1], + coupling_block, + {'subnet_constructor':self.subnet_fc, 'clamp':2.0}, + name=F'coupling_{k}')) + nodes.append(Ff.Node(nodes[-1], + Fm.PermuteRandom, + {'seed':k}, + name=F'permute_{k}')) + nodes.append(Ff.OutputNode(nodes[-1], name='output')) + model = Ff.ReversibleGraphNet(nodes, verbose=verbose) + return model + + def forward(self, latent_rep): + shape, _ = self.model_inn(latent_rep, rev=False, jac=False) + betas = shape[:, :self.n_betas]*self.betas_scale + betas_limbs = shape[:, self.n_betas:]*self.betas_limbs_scale + return betas, betas_limbs + + def reverse(self, betas, betas_limbs): + shape = torch.cat((betas/self.betas_scale, betas_limbs/self.betas_limbs_scale), dim=1) + latent_rep, _ = self.model_inn(shape, rev=True, jac=False) + return latent_rep \ No newline at end of file diff --git a/src/lifting_to_3d/linear_model.py b/src/lifting_to_3d/linear_model.py new file mode 100644 index 0000000000000000000000000000000000000000..c11266acefcb6bbecd8a748a44cb4915ef4da4b9 --- /dev/null +++ b/src/lifting_to_3d/linear_model.py @@ -0,0 +1,297 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +# some code from https://raw.githubusercontent.com/weigq/3d_pose_baseline_pytorch/master/src/model.py + + +from __future__ import absolute_import +from __future__ import print_function +import torch +import torch.nn as nn + +import os +import sys +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..')) +# from priors.vae_pose_model.vae_model import VAEmodel +from priors.normalizing_flow_prior.normalizing_flow_prior import NormalizingFlowPrior + + +def weight_init_dangerous(m): + # this is dangerous as it may overwrite the normalizing flow weights + if isinstance(m, nn.Linear): + nn.init.kaiming_normal_(m.weight) + + +class Linear(nn.Module): + def __init__(self, linear_size, p_dropout=0.5): + super(Linear, self).__init__() + self.l_size = linear_size + + self.relu 
= nn.ReLU(inplace=True) + self.dropout = nn.Dropout(p_dropout) + + self.w1 = nn.Linear(self.l_size, self.l_size) + self.batch_norm1 = nn.BatchNorm1d(self.l_size) + + self.w2 = nn.Linear(self.l_size, self.l_size) + self.batch_norm2 = nn.BatchNorm1d(self.l_size) + + def forward(self, x): + y = self.w1(x) + y = self.batch_norm1(y) + y = self.relu(y) + y = self.dropout(y) + y = self.w2(y) + y = self.batch_norm2(y) + y = self.relu(y) + y = self.dropout(y) + out = x + y + return out + + +class LinearModel(nn.Module): + def __init__(self, + linear_size=1024, + num_stage=2, + p_dropout=0.5, + input_size=16*2, + output_size=16*3): + super(LinearModel, self).__init__() + self.linear_size = linear_size + self.p_dropout = p_dropout + self.num_stage = num_stage + # input + self.input_size = input_size # 2d joints: 16 * 2 + # output + self.output_size = output_size # 3d joints: 16 * 3 + # process input to linear size + self.w1 = nn.Linear(self.input_size, self.linear_size) + self.batch_norm1 = nn.BatchNorm1d(self.linear_size) + self.linear_stages = [] + for l in range(num_stage): + self.linear_stages.append(Linear(self.linear_size, self.p_dropout)) + self.linear_stages = nn.ModuleList(self.linear_stages) + # post-processing + self.w2 = nn.Linear(self.linear_size, self.output_size) + # helpers (relu and dropout) + self.relu = nn.ReLU(inplace=True) + self.dropout = nn.Dropout(self.p_dropout) + + def forward(self, x): + # pre-processing + y = self.w1(x) + y = self.batch_norm1(y) + y = self.relu(y) + y = self.dropout(y) + # linear layers + for i in range(self.num_stage): + y = self.linear_stages[i](y) + # post-processing + y = self.w2(y) + return y + + +class LinearModelComplete(nn.Module): + def __init__(self, + linear_size=1024, + num_stage_comb=2, + num_stage_heads=1, + num_stage_heads_pose=1, + trans_sep=False, + p_dropout=0.5, + input_size=16*2, + intermediate_size=1024, + output_info=None, + n_joints=25, + n_z=512, + add_z_to_3d_input=False, + n_segbps=64*2, + add_segbps_to_3d_input=False, + structure_pose_net='default', + fix_vae_weights=True, + nf_version=None): # 0): n_silh_enc + super(LinearModelComplete, self).__init__() + if add_z_to_3d_input: + self.n_z_to_add = n_z # 512 + else: + self.n_z_to_add = 0 + if add_segbps_to_3d_input: + self.n_segbps_to_add = n_segbps # 64 + else: + self.n_segbps_to_add = 0 + self.input_size = input_size + self.linear_size = linear_size + self.p_dropout = p_dropout + self.num_stage_comb = num_stage_comb + self.num_stage_heads = num_stage_heads + self.num_stage_heads_pose = num_stage_heads_pose + self.trans_sep = trans_sep + self.input_size = input_size + self.intermediate_size = intermediate_size + self.structure_pose_net = structure_pose_net + self.fix_vae_weights = fix_vae_weights # only relevant if structure_pose_net='vae' + self.nf_version = nf_version + if output_info is None: + pose = {'name': 'pose', 'n': n_joints*6, 'out_shape':[n_joints, 6]} + cam = {'name': 'flength', 'n': 1} + if self.trans_sep: + translation_xy = {'name': 'trans_xy', 'n': 2} + translation_z = {'name': 'trans_z', 'n': 1} + self.output_info = [pose, translation_xy, translation_z, cam] + else: + translation = {'name': 'trans', 'n': 3} + self.output_info = [pose, translation, cam] + if self.structure_pose_net == 'vae' or self.structure_pose_net == 'normflow': + global_pose = {'name': 'global_pose', 'n': 1*6, 'out_shape':[1, 6]} + self.output_info.append(global_pose) + else: + self.output_info = output_info + self.linear_combined = LinearModel(linear_size=self.linear_size, + 
num_stage=self.num_stage_comb, + p_dropout=p_dropout, + input_size=self.input_size + self.n_segbps_to_add + self.n_z_to_add, ###### + output_size=self.intermediate_size) + self.output_info_linear_models = [] + for ind_el, element in enumerate(self.output_info): + if element['name'] == 'pose': + num_stage = self.num_stage_heads_pose + if self.structure_pose_net == 'default': + output_size_pose_lin = element['n'] + elif self.structure_pose_net == 'vae': + # load vae decoder + self.pose_vae_model = VAEmodel() + self.pose_vae_model.initialize_with_pretrained_weights() + # define the input size of the vae decoder + output_size_pose_lin = self.pose_vae_model.latent_size + elif self.structure_pose_net == 'normflow': + # the following will automatically be initialized + self.pose_normflow_model = NormalizingFlowPrior(nf_version=self.nf_version) + output_size_pose_lin = element['n'] - 6 # no global rotation + else: + raise NotImplementedError + self.output_info_linear_models.append(LinearModel(linear_size=self.linear_size, + num_stage=num_stage, + p_dropout=p_dropout, + input_size=self.intermediate_size, + output_size=output_size_pose_lin)) + else: + if element['name'] == 'global_pose': + num_stage = self.num_stage_heads_pose + else: + num_stage = self.num_stage_heads + self.output_info_linear_models.append(LinearModel(linear_size=self.linear_size, + num_stage=num_stage, + p_dropout=p_dropout, + input_size=self.intermediate_size, + output_size=element['n'])) + element['linear_model_index'] = ind_el + self.output_info_linear_models = nn.ModuleList(self.output_info_linear_models) + + def forward(self, x): + device = x.device + # combined stage + if x.shape[1] == self.input_size + self.n_segbps_to_add + self.n_z_to_add: + y = self.linear_combined(x) + elif x.shape[1] == self.input_size + self.n_segbps_to_add: + x_mod = torch.cat((x, torch.normal(0, 1, size=(x.shape[0], self.n_z_to_add)).to(device)), dim=1) + y = self.linear_combined(x_mod) + else: + print(x.shape) + print(self.input_size) + print(self.n_segbps_to_add) + print(self.n_z_to_add) + raise ValueError + # heads + results = {} + results_trans = {} + for element in self.output_info: + linear_model = self.output_info_linear_models[element['linear_model_index']] + if element['name'] == 'pose': + if self.structure_pose_net == 'default': + results['pose'] = (linear_model(y)).reshape((-1, element['out_shape'][0], element['out_shape'][1])) + normflow_z = None + elif self.structure_pose_net == 'vae': + res_lin = linear_model(y) + if self.fix_vae_weights: + self.pose_vae_model.requires_grad_(False) # let gradients flow through but don't update the parameters + res_vae = self.pose_vae_model.inference(feat=res_lin) + self.pose_vae_model.requires_grad_(True) + else: + res_vae = self.pose_vae_model.inference(feat=res_lin) + res_pose_not_glob = res_vae.reshape((-1, element['out_shape'][0], element['out_shape'][1])) + normflow_z = None + elif self.structure_pose_net == 'normflow': + normflow_z = linear_model(y)*0.1 + self.pose_normflow_model.requires_grad_(False) # let gradients flow though but don't update the parameters + res_pose_not_glob = self.pose_normflow_model.run_backwards(z=normflow_z).reshape((-1, element['out_shape'][0]-1, element['out_shape'][1])) + else: + raise NotImplementedError + elif element['name'] == 'global_pose': + res_pose_glob = (linear_model(y)).reshape((-1, element['out_shape'][0], element['out_shape'][1])) + elif element['name'] == 'trans_xy' or element['name'] == 'trans_z': + results_trans[element['name']] = linear_model(y) + 
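            # note: when trans_sep is enabled, the separate 'trans_xy' and 'trans_z' head outputs
            # collected in results_trans are concatenated into results['trans'] right after this loop (see below)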
else: + results[element['name']] = linear_model(y) + if self.trans_sep: + results['trans'] = torch.cat((results_trans['trans_xy'], results_trans['trans_z']), dim=1) + # prepare pose including global rotation + if self.structure_pose_net == 'vae': + # results['pose'] = torch.cat((res_pose_glob, res_pose_not_glob), dim=1) + results['pose'] = torch.cat((res_pose_glob, res_pose_not_glob[:, 1:, :]), dim=1) + elif self.structure_pose_net == 'normflow': + results['pose'] = torch.cat((res_pose_glob, res_pose_not_glob[:, :, :]), dim=1) + # return a dictionary which contains all results + results['normflow_z'] = normflow_z + return results # this is a dictionary + + + + + +# ------------------------------------------ +# for pretraining of the 3d model only: +# (see combined_model/model_shape_v2.py) + +class Wrapper_LinearModelComplete(nn.Module): + def __init__(self, + linear_size=1024, + num_stage_comb=2, + num_stage_heads=1, + num_stage_heads_pose=1, + trans_sep=False, + p_dropout=0.5, + input_size=16*2, + intermediate_size=1024, + output_info=None, + n_joints=25, + n_z=512, + add_z_to_3d_input=False, + n_segbps=64*2, + add_segbps_to_3d_input=False, + structure_pose_net='default', + fix_vae_weights=True, + nf_version=None): + self.add_segbps_to_3d_input = add_segbps_to_3d_input + super(Wrapper_LinearModelComplete, self).__init__() + self.model_3d = LinearModelComplete(linear_size=linear_size, + num_stage_comb=num_stage_comb, + num_stage_heads=num_stage_heads, + num_stage_heads_pose=num_stage_heads_pose, + trans_sep=trans_sep, + p_dropout=p_dropout, # 0.5, + input_size=input_size, + intermediate_size=intermediate_size, + output_info=output_info, + n_joints=n_joints, + n_z=n_z, + add_z_to_3d_input=add_z_to_3d_input, + n_segbps=n_segbps, + add_segbps_to_3d_input=add_segbps_to_3d_input, + structure_pose_net=structure_pose_net, + fix_vae_weights=fix_vae_weights, + nf_version=nf_version) + def forward(self, input_vec): + # input_vec = torch.cat((keypoints_prepared.reshape((batch_size, -1)), bone_lengths_prepared), axis=1) + # predict 3d parameters (those are normalized, we need to correct mean and std in a next step) + output = self.model_3d(input_vec) + return output \ No newline at end of file diff --git a/src/lifting_to_3d/utils/geometry_utils.py b/src/lifting_to_3d/utils/geometry_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..a83466b7212cab58d6f4c0f88ae98a206583a3f8 --- /dev/null +++ b/src/lifting_to_3d/utils/geometry_utils.py @@ -0,0 +1,236 @@ + +import torch +from torch.nn import functional as F +import numpy as np +from torch import nn + + +def geodesic_loss(R, Rgt): + # see: Silvia tiger pose model 3d code + num_joints = R.shape[1] + RT = R.permute(0,1,3,2) + A = torch.matmul(RT.view(-1,3,3),Rgt.view(-1,3,3)) + # torch.trace works only for 2D tensors + n = A.shape[0] + po_loss = 0 + eps = 1e-7 + T = torch.sum(A[:,torch.eye(3).bool()],1) + theta = torch.clamp(0.5*(T-1), -1+eps, 1-eps) + angles = torch.acos(theta) + loss = torch.sum(angles)/(n*num_joints) + return loss + +class geodesic_loss_R(nn.Module): + def __init__(self,reduction='mean'): + super(geodesic_loss_R, self).__init__() + self.reduction = reduction + self.eps = 1e-6 + + # batch geodesic loss for rotation matrices + def bgdR(self,bRgts,bRps): + #return((bRgts - bRps)**2.).mean() + return geodesic_loss(bRgts, bRps) + + def forward(self, ypred, ytrue): + theta = geodesic_loss(ypred,ytrue) + if self.reduction == 'mean': + return torch.mean(theta) + else: + return theta + +def batch_rodrigues_numpy(theta): 
+ """ Code adapted from spin + Convert axis-angle representation to rotation matrix. + Remark: + this leads to the same result as kornia.angle_axis_to_rotation_matrix(theta) + Args: + theta: size = [B, 3] + Returns: + Rotation matrix corresponding to the quaternion -- size = [B, 3, 3] + """ + l1norm = np.linalg.norm(theta + 1e-8, ord = 2, axis = 1) + # angle = np.unsqueeze(l1norm, -1) + angle = l1norm.reshape((-1, 1)) + # normalized = np.div(theta, angle) + normalized = theta / angle + angle = angle * 0.5 + v_cos = np.cos(angle) + v_sin = np.sin(angle) + # quat = np.cat([v_cos, v_sin * normalized], dim = 1) + quat = np.concatenate([v_cos, v_sin * normalized], axis = 1) + return quat_to_rotmat_numpy(quat) + +def quat_to_rotmat_numpy(quat): + """Code from: https://github.com/nkolot/SPIN/blob/master/utils/geometry.py + Convert quaternion coefficients to rotation matrix. + Args: + quat: size = [B, 4] 4 <===>(w, x, y, z) + Returns: + Rotation matrix corresponding to the quaternion -- size = [B, 3, 3] + """ + norm_quat = quat + # norm_quat = norm_quat/norm_quat.norm(p=2, dim=1, keepdim=True) + norm_quat = norm_quat/np.linalg.norm(norm_quat, ord=2, axis=1, keepdims=True) + w, x, y, z = norm_quat[:,0], norm_quat[:,1], norm_quat[:,2], norm_quat[:,3] + B = quat.shape[0] + # w2, x2, y2, z2 = w.pow(2), x.pow(2), y.pow(2), z.pow(2) + w2, x2, y2, z2 = w**2, x**2, y**2, z**2 + wx, wy, wz = w*x, w*y, w*z + xy, xz, yz = x*y, x*z, y*z + rotMat = np.stack([w2 + x2 - y2 - z2, 2*xy - 2*wz, 2*wy + 2*xz, + 2*wz + 2*xy, w2 - x2 + y2 - z2, 2*yz - 2*wx, + 2*xz - 2*wy, 2*wx + 2*yz, w2 - x2 - y2 + z2], axis=1).reshape(B, 3, 3) + return rotMat + + +def batch_rodrigues(theta): + """Code from: https://github.com/nkolot/SPIN/blob/master/utils/geometry.py + Convert axis-angle representation to rotation matrix. 
+ Remark: + this leads to the same result as kornia.angle_axis_to_rotation_matrix(theta) + Args: + theta: size = [B, 3] + Returns: + Rotation matrix corresponding to the quaternion -- size = [B, 3, 3] + """ + l1norm = torch.norm(theta + 1e-8, p = 2, dim = 1) + angle = torch.unsqueeze(l1norm, -1) + normalized = torch.div(theta, angle) + angle = angle * 0.5 + v_cos = torch.cos(angle) + v_sin = torch.sin(angle) + quat = torch.cat([v_cos, v_sin * normalized], dim = 1) + return quat_to_rotmat(quat) + +def batch_rot2aa(Rs, epsilon=1e-7): + """ Code from: https://github.com/vchoutas/expose/blob/dffc38d62ad3817481d15fe509a93c2bb606cb8b/expose/utils/rotation_utils.py#L55 + Rs is B x 3 x 3 + void cMathUtil::RotMatToAxisAngle(const tMatrix& mat, tVector& out_axis, + double& out_theta) + { + double c = 0.5 * (mat(0, 0) + mat(1, 1) + mat(2, 2) - 1); + c = cMathUtil::Clamp(c, -1.0, 1.0); + out_theta = std::acos(c); + if (std::abs(out_theta) < 0.00001) + { + out_axis = tVector(0, 0, 1, 0); + } + else + { + double m21 = mat(2, 1) - mat(1, 2); + double m02 = mat(0, 2) - mat(2, 0); + double m10 = mat(1, 0) - mat(0, 1); + double denom = std::sqrt(m21 * m21 + m02 * m02 + m10 * m10); + out_axis[0] = m21 / denom; + out_axis[1] = m02 / denom; + out_axis[2] = m10 / denom; + out_axis[3] = 0; + } + } + """ + cos = 0.5 * (torch.einsum('bii->b', [Rs]) - 1) + cos = torch.clamp(cos, -1 + epsilon, 1 - epsilon) + theta = torch.acos(cos) + m21 = Rs[:, 2, 1] - Rs[:, 1, 2] + m02 = Rs[:, 0, 2] - Rs[:, 2, 0] + m10 = Rs[:, 1, 0] - Rs[:, 0, 1] + denom = torch.sqrt(m21 * m21 + m02 * m02 + m10 * m10 + epsilon) + axis0 = torch.where(torch.abs(theta) < 0.00001, m21, m21 / denom) + axis1 = torch.where(torch.abs(theta) < 0.00001, m02, m02 / denom) + axis2 = torch.where(torch.abs(theta) < 0.00001, m10, m10 / denom) + return theta.unsqueeze(1) * torch.stack([axis0, axis1, axis2], 1) + +def quat_to_rotmat(quat): + """Code from: https://github.com/nkolot/SPIN/blob/master/utils/geometry.py + Convert quaternion coefficients to rotation matrix. + Args: + quat: size = [B, 4] 4 <===>(w, x, y, z) + Returns: + Rotation matrix corresponding to the quaternion -- size = [B, 3, 3] + """ + norm_quat = quat + norm_quat = norm_quat/norm_quat.norm(p=2, dim=1, keepdim=True) + w, x, y, z = norm_quat[:,0], norm_quat[:,1], norm_quat[:,2], norm_quat[:,3] + + B = quat.size(0) + + w2, x2, y2, z2 = w.pow(2), x.pow(2), y.pow(2), z.pow(2) + wx, wy, wz = w*x, w*y, w*z + xy, xz, yz = x*y, x*z, y*z + + rotMat = torch.stack([w2 + x2 - y2 - z2, 2*xy - 2*wz, 2*wy + 2*xz, + 2*wz + 2*xy, w2 - x2 + y2 - z2, 2*yz - 2*wx, + 2*xz - 2*wy, 2*wx + 2*yz, w2 - x2 - y2 + z2], dim=1).view(B, 3, 3) + return rotMat + +def rot6d_to_rotmat(rot6d): + """ Code from: https://github.com/nkolot/SPIN/blob/master/utils/geometry.py + Convert 6D rotation representation to 3x3 rotation matrix. + Based on Zhou et al., "On the Continuity of Rotation Representations in Neural Networks", CVPR 2019 + Input: + (B,6) Batch of 6-D rotation representations + Output: + (B,3,3) Batch of corresponding rotation matrices + """ + rot6d = rot6d.view(-1,3,2) + a1 = rot6d[:, :, 0] + a2 = rot6d[:, :, 1] + b1 = F.normalize(a1) + b2 = F.normalize(a2 - torch.einsum('bi,bi->b', b1, a2).unsqueeze(-1) * b1) + b3 = torch.cross(b1, b2) + rotmat = torch.stack((b1, b2, b3), dim=-1) + return rotmat + +def rotmat_to_rot6d(rotmat): + """ Convert 3x3 rotation matrix to 6D rotation representation. 
+ Based on Zhou et al., "On the Continuity of Rotation Representations in Neural Networks", CVPR 2019 + Input: + (B,3,3) Batch of corresponding rotation matrices + Output: + (B,6) Batch of 6-D rotation representations + """ + rot6d = rotmat[:, :, :2].reshape((-1, 6)) + return rot6d + + +def main(): + # rotation matrix and 6d representation + # see "On the Continuity of Rotation Representations in Neural Networks" + from pyquaternion import Quaternion + batch_size = 5 + rotmat = np.zeros((batch_size, 3, 3)) + for ind in range(0, batch_size): + rotmat[ind, :, :] = Quaternion.random().rotation_matrix + rotmat_torch = torch.Tensor(rotmat) + rot6d = rotmat_to_rot6d(rotmat_torch) + rotmat_rec = rot6d_to_rotmat(rot6d) + print('..................... 1 ....................') + print(rotmat_torch[0, :, :]) + print(rotmat_rec[0, :, :]) + print('Conversion from rotmat to rot6d and inverse are ok!') + # rotation matrix and axis angle representation + import kornia + input = torch.rand(1, 3) + output = kornia.angle_axis_to_rotation_matrix(input) + input_rec = kornia.rotation_matrix_to_angle_axis(output) + print('..................... 2 ....................') + print(input) + print(input_rec) + print('Kornia implementation for rotation_matrix_to_angle_axis is wrong!!!!') + # For non-differentiable conversions use scipy: + # https://docs.scipy.org/doc/scipy/reference/generated/scipy.spatial.transform.Rotation.html + from scipy.spatial.transform import Rotation as R + r = R.from_matrix(rotmat[0, :, :]) + print('..................... 3 ....................') + print(r.as_matrix()) + print(r.as_rotvec()) + print(r.as_quat()) + # one might furthermore have a look at: + # https://github.com/silviazuffi/smalst/blob/master/utils/transformations.py + + + +if __name__ == "__main__": + main() + + diff --git a/src/metrics/metrics.py b/src/metrics/metrics.py new file mode 100644 index 0000000000000000000000000000000000000000..ffa1ae1c00bd286f55a4ede8565dc3eb619162a9 --- /dev/null +++ b/src/metrics/metrics.py @@ -0,0 +1,74 @@ +# code from: https://github.com/benjiebob/WLDO/blob/master/wldo_regressor/metrics.py + + +import torch +import torch.nn.functional as F +import numpy as np + +IMG_RES = 256 # in WLDO it is 224 + +class Metrics(): + + @staticmethod + def PCK_thresh( + pred_keypoints, gt_keypoints, + gtseg, has_seg, + thresh, idxs, biggs=False): + + pred_keypoints, gt_keypoints, gtseg = pred_keypoints[has_seg], gt_keypoints[has_seg], gtseg[has_seg] + + if idxs is None: + idxs = list(range(pred_keypoints.shape[1])) + + idxs = np.array(idxs).astype(int) + + pred_keypoints = pred_keypoints[:, idxs] + gt_keypoints = gt_keypoints[:, idxs] + + if biggs: + keypoints_gt = ((gt_keypoints + 1.0) * 0.5) * IMG_RES + dist = torch.norm(pred_keypoints - keypoints_gt[:, :, [1, 0]], dim = -1) + else: + keypoints_gt = gt_keypoints # (0 to IMG_SIZE) + dist = torch.norm(pred_keypoints - keypoints_gt[:, :, :2], dim = -1) + + seg_area = torch.sum(gtseg.reshape(gtseg.shape[0], -1), dim = -1).unsqueeze(-1) + + hits = (dist / torch.sqrt(seg_area)) < thresh + total_visible = torch.sum(gt_keypoints[:, :, -1], dim = -1) + pck = torch.sum(hits.float() * gt_keypoints[:, :, -1], dim = -1) / total_visible + + return pck + + @staticmethod + def PCK( + pred_keypoints, keypoints, + gtseg, has_seg, + thresh_range=[0.15], + idxs:list=None, + biggs=False): + """Calc PCK with same method as in eval. 
+ idxs = optional list of subset of keypoints to index from + """ + cumulative_pck = [] + for thresh in thresh_range: + pck = Metrics.PCK_thresh( + pred_keypoints, keypoints, + gtseg, has_seg, thresh, idxs, + biggs=biggs) + cumulative_pck.append(pck) + pck_mean = torch.stack(cumulative_pck, dim = 0).mean(dim=0) + return pck_mean + + @staticmethod + def IOU(synth_silhouettes, gt_seg, img_border_mask, mask): + for i in range(mask.shape[0]): + synth_silhouettes[i] *= mask[i] + # Do not penalize parts of the segmentation outside the img range + gt_seg = (gt_seg * img_border_mask) + synth_silhouettes * (1.0 - img_border_mask) + intersection = torch.sum((synth_silhouettes * gt_seg).reshape(synth_silhouettes.shape[0], -1), dim = -1) + union = torch.sum(((synth_silhouettes + gt_seg).reshape(synth_silhouettes.shape[0], -1) > 0.0).float(), dim = -1) + acc_IOU_SCORE = intersection / union + if torch.isnan(acc_IOU_SCORE).sum() > 0: + import pdb; pdb.set_trace() + return acc_IOU_SCORE \ No newline at end of file diff --git a/src/priors/normalizing_flow_prior/normalizing_flow_prior.py b/src/priors/normalizing_flow_prior/normalizing_flow_prior.py new file mode 100644 index 0000000000000000000000000000000000000000..5cf60fe51d722c31d7b045a637e1b57d4b577091 --- /dev/null +++ b/src/priors/normalizing_flow_prior/normalizing_flow_prior.py @@ -0,0 +1,115 @@ + +from torch import distributions +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.utils.data import DataLoader +from torch.distributions import Normal +import numpy as np +import cv2 +import trimesh +from tqdm import tqdm + +import warnings +warnings.filterwarnings("ignore", category=DeprecationWarning) +import FrEIA.framework as Ff +import FrEIA.modules as Fm +from configs.barc_cfg_defaults import get_cfg_global_updated + + +class NormalizingFlowPrior(nn.Module): + def __init__(self, nf_version=None): + super(NormalizingFlowPrior, self).__init__() + # the normalizing flow network takes as input a vector of size (35-1)*6 which is + # [all joints except root joint]*6. At the moment the rotation is represented as a 6D + # representation, which is actually not ideal. Nevertheless, in practice the + # results seem to be ok. + n_dim = (35 - 1) * 6 + self.param_dict = self.get_version_param_dict(nf_version) + self.model_inn = self.build_inn_network(n_dim, k_tot=self.param_dict['k_tot']) + self.initialize_with_pretrained_weights() + + def get_version_param_dict(self, nf_version): + # we had trained several versions of the normalizing flow pose prior; here we just provide + # the option that was used for the CVPR 2022 paper (nf_version=3) + if nf_version == 3: + param_dict = { + 'k_tot': 2, + 'path_pretrained': get_cfg_global_updated().paths.MODELPATH_NORMFLOW, + 'subnet_fc_type': '3_64'} + else: + print(nf_version) + raise ValueError + return param_dict + + def initialize_with_pretrained_weights(self, weight_path=None): + # The normalizing flow pose prior is pretrained separately. Afterwards all weights + # are kept fixed. Here we load those pretrained weights. 
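        # Note (descriptive, inferred from the loading code below): the checkpoint referenced by
        # cfg.paths.MODELPATH_NORMFLOW is expected to be a dict (as produced by torch.save) with a
        # 'model_state_dict' entry holding the FrEIA ReversibleGraphNet weights. Because
        # load_state_dict is called with strict=True, the architecture built above (k_tot coupling
        # blocks with '3_64' subnets for nf_version=3) has to match the pretrained model exactly.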
+ if weight_path is None: + weight_path = self.param_dict['path_pretrained'] + print(' normalizing flow pose prior: loading {}..'.format(weight_path)) + pretrained_dict = torch.load(weight_path)['model_state_dict'] + self.model_inn.load_state_dict(pretrained_dict, strict=True) + + def subnet_fc(self, c_in, c_out): + if self.param_dict['subnet_fc_type']=='3_512': + subnet = nn.Sequential(nn.Linear(c_in, 512), nn.ReLU(), + nn.Linear(512, 512), nn.ReLU(), + nn.Linear(512, c_out)) + elif self.param_dict['subnet_fc_type']=='3_64': + subnet = nn.Sequential(nn.Linear(c_in, 64), nn.ReLU(), + nn.Linear(64, 64), nn.ReLU(), + nn.Linear(64, c_out)) + return subnet + + def build_inn_network(self, n_input, k_tot=12, verbose=False): + coupling_block = Fm.RNVPCouplingBlock + nodes = [Ff.InputNode(n_input, name='input')] + for k in range(k_tot): + nodes.append(Ff.Node(nodes[-1], + coupling_block, + {'subnet_constructor':self.subnet_fc, 'clamp':2.0}, + name=F'coupling_{k}')) + nodes.append(Ff.Node(nodes[-1], + Fm.PermuteRandom, + {'seed':k}, + name=F'permute_{k}')) + nodes.append(Ff.OutputNode(nodes[-1], name='output')) + model = Ff.ReversibleGraphNet(nodes, verbose=verbose) + return model + + def calculate_loss_from_z(self, z, type='square'): + assert type in ['square', 'neg_log_prob'] + if type == 'square': + loss = (z**2).mean() # * 0.00001 + elif type == 'neg_log_prob': + means = torch.zeros((z.shape[0], z.shape[1]), dtype=z.dtype, device=z.device) + stds = torch.ones((z.shape[0], z.shape[1]), dtype=z.dtype, device=z.device) + normal_distribution = Normal(means, stds) + log_prob = normal_distribution.log_prob(z) + loss = - log_prob.mean() + return loss + + def calculate_loss(self, poses_rot6d, type='square'): + assert type in ['square', 'neg_log_prob'] + poses_rot6d_noglob = poses_rot6d[:, 1:, :].reshape((-1, 34*6)) + z, _ = self.model_inn(poses_rot6d_noglob, rev=False, jac=False) + loss = self.calculate_loss_from_z(z, type=type) + return loss + + def forward(self, poses_rot6d): + # from pose to latent pose representation z + # poses_rot6d has shape (bs, 34, 6) + poses_rot6d_noglob = poses_rot6d[:, 1:, :].reshape((-1, 34*6)) + z, _ = self.model_inn(poses_rot6d_noglob, rev=False, jac=False) + return z + + def run_backwards(self, z): + # from latent pose representation z to pose + poses_rot6d_noglob, _ = self.model_inn(z, rev=True, jac=False) + return poses_rot6d_noglob + + + + + \ No newline at end of file diff --git a/src/priors/shape_prior.py b/src/priors/shape_prior.py new file mode 100644 index 0000000000000000000000000000000000000000..f62ebc5d656aa6829427746d9582700db38481cc --- /dev/null +++ b/src/priors/shape_prior.py @@ -0,0 +1,40 @@ + +# some parts of the code adapted from https://github.com/benjiebob/WLDO and https://github.com/benjiebob/SMALify + +import numpy as np +import torch +import pickle as pkl + + + +class ShapePrior(torch.nn.Module): + def __init__(self, prior_path): + super(ShapePrior, self).__init__() + try: + with open(prior_path, 'r') as f: + res = pkl.load(f) + except (UnicodeDecodeError, TypeError) as e: + with open(prior_path, 'rb') as file: + u = pkl._Unpickler(file) + u.encoding = 'latin1' + res = u.load() + betas_mean = res['dog_cluster_mean'] + betas_cov = res['dog_cluster_cov'] + single_gaussian_inv_covs = np.linalg.inv(betas_cov + 1e-5 * np.eye(betas_cov.shape[0])) + single_gaussian_precs = torch.tensor(np.linalg.cholesky(single_gaussian_inv_covs)).float() + single_gaussian_means = torch.tensor(betas_mean).float() + self.register_buffer('single_gaussian_precs', 
single_gaussian_precs) # (20, 20) + self.register_buffer('single_gaussian_means', single_gaussian_means) # (20) + use_ind_tch = torch.from_numpy(np.ones(single_gaussian_means.shape[0], dtype=bool)).float() # .to(device) + self.register_buffer('use_ind_tch', use_ind_tch) + + def forward(self, betas_smal_orig, use_singe_gaussian=False): + n_betas_smal = betas_smal_orig.shape[1] + device = betas_smal_orig.device + use_ind_tch_corrected = self.use_ind_tch * torch.cat((torch.ones_like(self.use_ind_tch[:n_betas_smal]), torch.zeros_like(self.use_ind_tch[n_betas_smal:]))) + samples = torch.cat((betas_smal_orig, torch.zeros((betas_smal_orig.shape[0], self.single_gaussian_means.shape[0]-n_betas_smal)).float().to(device)), dim=1) + mean_sub = samples - self.single_gaussian_means.unsqueeze(0) + single_gaussian_precs_corr = self.single_gaussian_precs * use_ind_tch_corrected[:, None] * use_ind_tch_corrected[None, :] + res = torch.tensordot(mean_sub, single_gaussian_precs_corr, dims = ([1], [0])) + res_final_mean_2 = torch.mean(res ** 2) + return res_final_mean_2 diff --git a/src/smal_pytorch/renderer/differentiable_renderer.py b/src/smal_pytorch/renderer/differentiable_renderer.py new file mode 100644 index 0000000000000000000000000000000000000000..c3d76f1f34a16ee6b559e18d95ecaca4fa267b31 --- /dev/null +++ b/src/smal_pytorch/renderer/differentiable_renderer.py @@ -0,0 +1,280 @@ + +# part of the code from +# https://github.com/benjiebob/SMALify/blob/master/smal_fitter/p3d_renderer.py + +import torch +import torch.nn.functional as F +from scipy.io import loadmat +import numpy as np +# import config + +import pytorch3d +from pytorch3d.structures import Meshes +from pytorch3d.renderer import ( + PerspectiveCameras, look_at_view_transform, look_at_rotation, + RasterizationSettings, MeshRenderer, MeshRasterizer, BlendParams, + PointLights, HardPhongShader, SoftSilhouetteShader, Materials, Textures, + DirectionalLights +) +from pytorch3d.renderer import TexturesVertex, SoftPhongShader +from pytorch3d.io import load_objs_as_meshes + +MESH_COLOR_0 = [0, 172, 223] +MESH_COLOR_1 = [172, 223, 0] + + +''' +Explanation of the shift between projection results from opendr and pytorch3d: + (0, 0, ?) will be projected to 127.5 (pytorch3d) instead of 128 (opendr) + imagine you have an image of size 4: + middle of the first pixel is 0 + middle of the last pixel is 3 + => middle of the imgae would be 1.5 and not 2! + so in order to go from pytorch3d predictions to opendr we would calculate: p_odr = p_p3d * (128/127.5) +To reproject points (p3d) by hand according to this pytorch3d renderer we would do the following steps: + 1.) build camera matrix + K = np.array([[flength, 0, c_x], + [0, flength, c_y], + [0, 0, 1]], np.float) + 2.) we don't need to add extrinsics, as the mesh comes with translation (which is + added within smal_pytorch). all 3d points are already in the camera coordinate system. + -> projection reduces to p2d_proj = K*p3d + 3.) convert to pytorch3d conventions (0 in the middle of the first pixel) + p2d_proj_pytorch3d = p2d_proj / image_size * (image_size-1.) 
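    Worked numeric check (illustrative; the numbers are just an example): with image_size = 256 and
    principal point c_x = c_y = 128, a point on the optical axis (0, 0, z) projects to
    p2d_proj = (128, 128) in the opendr convention; converting with the formula above gives
    128 / 256 * 255 = 127.5 in the pytorch3d convention, which is exactly the shift described at
    the top of this comment block.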
+renderer.py - project_points_p3d: shows an example of what is described above, but + same focal length for the whole batch + +''' + +class SilhRenderer(torch.nn.Module): + def __init__(self, image_size, adapt_R_wldo=False): + super(SilhRenderer, self).__init__() + # see: https://pytorch3d.org/files/fit_textured_mesh.py, line 315 + # adapt_R=True is True for all my experiments + # image_size: one number, integer + # ----- + # set mesh color + self.register_buffer('mesh_color_0', torch.FloatTensor(MESH_COLOR_0)) + self.register_buffer('mesh_color_1', torch.FloatTensor(MESH_COLOR_1)) + # prepare extrinsics, which in our case don't change + R = torch.Tensor(np.eye(3)).float()[None, :, :] + T = torch.Tensor(np.zeros((1, 3))).float() + if adapt_R_wldo: + R[0, 0, 0] = -1 + else: # used for all my own experiments + R[0, 0, 0] = -1 + R[0, 1, 1] = -1 + self.register_buffer('R', R) + self.register_buffer('T', T) + # prepare that part of the intrinsics which does not change either + # principal_point_prep = torch.Tensor([self.image_size / 2., self.image_size / 2.]).float()[None, :].float().to(device) + # image_size_prep = torch.Tensor([self.image_size, self.image_size]).float()[None, :].float().to(device) + self.img_size_scalar = image_size + self.register_buffer('image_size', torch.Tensor([image_size, image_size]).float()[None, :].float()) + self.register_buffer('principal_point', torch.Tensor([image_size / 2., image_size / 2.]).float()[None, :].float()) + # Rasterization settings for differentiable rendering, where the blur_radius + # initialization is based on Liu et al, 'Soft Rasterizer: A Differentiable + # Renderer for Image-based 3D Reasoning', ICCV 2019 + self.blend_params = BlendParams(sigma=1e-4, gamma=1e-4) + self.raster_settings_soft = RasterizationSettings( + image_size=image_size, # 128 + blur_radius=np.log(1. / 1e-4 - 1.)*self.blend_params.sigma, + faces_per_pixel=100) #50, + # Renderer for Image-based 3D Reasoning', body part segmentation + self.blend_params_parts = BlendParams(sigma=2*1e-4, gamma=1e-4) + self.raster_settings_soft_parts = RasterizationSettings( + image_size=image_size, # 128 + blur_radius=np.log(1. 
/ 1e-4 - 1.)*self.blend_params_parts.sigma, + faces_per_pixel=60) #50, + # settings for visualization renderer + self.raster_settings_vis = RasterizationSettings( + image_size=image_size, + blur_radius=0.0, + faces_per_pixel=1) + + def _get_cam(self, focal_lengths): + device = focal_lengths.device + bs = focal_lengths.shape[0] + if pytorch3d.__version__ == '0.2.5': + cameras = PerspectiveCameras(device=device, + focal_length=focal_lengths.repeat((1, 2)), + principal_point=self.principal_point.repeat((bs, 1)), + R=self.R.repeat((bs, 1, 1)), T=self.T.repeat((bs, 1)), + image_size=self.image_size.repeat((bs, 1))) + elif pytorch3d.__version__ == '0.6.1': + cameras = PerspectiveCameras(device=device, in_ndc=False, + focal_length=focal_lengths.repeat((1, 2)), + principal_point=self.principal_point.repeat((bs, 1)), + R=self.R.repeat((bs, 1, 1)), T=self.T.repeat((bs, 1)), + image_size=self.image_size.repeat((bs, 1))) + else: + print('this part depends on the version of pytorch3d, code was developed with 0.2.5') + raise ValueError + return cameras + + def _get_visualization_from_mesh(self, mesh, cameras, lights=None): + # color renderer for visualization + with torch.no_grad(): + device = mesh.device + # renderer for visualization + if lights is None: + lights = PointLights(device=device, location=[[0.0, 0.0, 3.0]]) + vis_renderer = MeshRenderer( + rasterizer=MeshRasterizer( + cameras=cameras, + raster_settings=self.raster_settings_vis), + shader=HardPhongShader( + device=device, + cameras=cameras, + lights=lights)) + # render image: + visualization = vis_renderer(mesh).permute(0, 3, 1, 2)[:, :3, :, :] + return visualization + + + def calculate_vertex_visibility(self, vertices, faces, focal_lengths, soft=False): + tex = torch.ones_like(vertices) * self.mesh_color_0 # (1, V, 3) + textures = Textures(verts_rgb=tex) + mesh = Meshes(verts=vertices, faces=faces, textures=textures) + cameras = self._get_cam(focal_lengths) + # NEW: use the rasterizer to check vertex visibility + # see: https://github.com/facebookresearch/pytorch3d/issues/126 + # Get a rasterizer + if soft: + rasterizer = MeshRasterizer(cameras=cameras, + raster_settings=self.raster_settings_soft) + else: + rasterizer = MeshRasterizer(cameras=cameras, + raster_settings=self.raster_settings_vis) + # Get the output from rasterization + fragments = rasterizer(mesh) + # pix_to_face is of shape (N, H, W, 1) + pix_to_face = fragments.pix_to_face + # (F, 3) where F is the total number of faces across all the meshes in the batch + packed_faces = mesh.faces_packed() + # (V, 3) where V is the total number of verts across all the meshes in the batch + packed_verts = mesh.verts_packed() + vertex_visibility_map = torch.zeros(packed_verts.shape[0]) # (V,) + # Indices of unique visible faces + visible_faces = pix_to_face.unique() # [0] # (num_visible_faces ) + # Get Indices of unique visible verts using the vertex indices in the faces + visible_verts_idx = packed_faces[visible_faces] # (num_visible_faces, 3) + unique_visible_verts_idx = torch.unique(visible_verts_idx) # (num_visible_verts, ) + # Update visibility indicator to 1 for all visible vertices + vertex_visibility_map[unique_visible_verts_idx] = 1.0 + # since all meshes have the same amount of vertices, we can reshape the result + bs = vertices.shape[0] + vertex_visibility_map_resh = vertex_visibility_map.reshape((bs, -1)) + return pix_to_face, vertex_visibility_map_resh + + + def get_torch_meshes(self, vertices, faces, color=0): + # create pytorch mesh + if color == 0: + mesh_color = 
self.mesh_color_0 + else: + mesh_color = self.mesh_color_1 + tex = torch.ones_like(vertices) * mesh_color # (1, V, 3) + textures = Textures(verts_rgb=tex) + mesh = Meshes(verts=vertices, faces=faces, textures=textures) + return mesh + + + def get_visualization_nograd(self, vertices, faces, focal_lengths, color=0): + # vertices: torch.Size([bs, 3889, 3]) + # faces: torch.Size([bs, 7774, 3]), int + # focal_lengths: torch.Size([bs, 1]) + device = vertices.device + # create cameras + cameras = self._get_cam(focal_lengths) + # create pytorch mesh + if color == 0: + mesh_color = self.mesh_color_0 # blue + elif color == 1: + mesh_color = self.mesh_color_1 + elif color == 2: + MESH_COLOR_2 = [240, 250, 240] # white + mesh_color = torch.FloatTensor(MESH_COLOR_2).to(device) + elif color == 3: + # MESH_COLOR_3 = [223, 0, 172] # pink + # MESH_COLOR_3 = [245, 245, 220] # beige + MESH_COLOR_3 = [166, 173, 164] + mesh_color = torch.FloatTensor(MESH_COLOR_3).to(device) + else: + MESH_COLOR_2 = [240, 250, 240] + mesh_color = torch.FloatTensor(MESH_COLOR_2).to(device) + tex = torch.ones_like(vertices) * mesh_color # (1, V, 3) + textures = Textures(verts_rgb=tex) + mesh = Meshes(verts=vertices, faces=faces, textures=textures) + # render mesh (no gradients) + # lights = PointLights(device=device, location=[[0.0, 0.0, 3.0]]) + # lights = PointLights(device=device, location=[[2.0, 2.0, -2.0]]) + lights = DirectionalLights(device=device, direction=[[0.0, -5.0, -10.0]]) + visualization = self._get_visualization_from_mesh(mesh, cameras, lights=lights) + return visualization + + def project_points(self, points, focal_lengths=None, cameras=None): + # points: torch.Size([bs, n_points, 3]) + # either focal_lengths or cameras is needed: + # focal_lenghts: torch.Size([bs, 1]) + # cameras: pytorch camera, for example PerspectiveCameras() + bs = points.shape[0] + device = points.device + screen_size = self.image_size.repeat((bs, 1)) + if cameras is None: + cameras = self._get_cam(focal_lengths) + if pytorch3d.__version__ == '0.2.5': + proj_points_orig = cameras.transform_points_screen(points, screen_size)[:, :, [1, 0]] # used in the original virtuel environment (for cvpr BARC submission) + elif pytorch3d.__version__ == '0.6.1': + proj_points_orig = cameras.transform_points_screen(points)[:, :, [1, 0]] + else: + print('this part depends on the version of pytorch3d, code was developed with 0.2.5') + raise ValueError + # flip, otherwise the 1st and 2nd row are exchanged compared to the ground truth + proj_points = torch.flip(proj_points_orig, [2]) + # --- project points 'manually' + # j_proj = project_points_p3d(image_size, focal_length, points, device) + return proj_points + + def forward(self, vertices, points, faces, focal_lengths, color=None): + # vertices: torch.Size([bs, 3889, 3]) + # points: torch.Size([bs, n_points, 3]) (or None) + # faces: torch.Size([bs, 7774, 3]), int + # focal_lengths: torch.Size([bs, 1]) + # color: if None we don't render a visualization, else it should + # either be 0 or 1 + # ---> important: results are around 0.5 pixels off compared to chumpy! 
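+ # (pytorch3d places coordinate 0 at the center of the first pixel, so converting a prediction back
+ # to the opendr/chumpy convention amounts to scaling by image_size/(image_size-1.), e.g.
+ # p_odr = p_p3d * (128/127.5) for a 128x128 image)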
+ # have a look at renderer.py for an explanation + # create cameras + cameras = self._get_cam(focal_lengths) + # create pytorch mesh + if color is None or color == 0: + mesh_color = self.mesh_color_0 + else: + mesh_color = self.mesh_color_1 + tex = torch.ones_like(vertices) * mesh_color # (1, V, 3) + textures = Textures(verts_rgb=tex) + mesh = Meshes(verts=vertices, faces=faces, textures=textures) + # silhouette renderer + renderer_silh = MeshRenderer( + rasterizer=MeshRasterizer( + cameras=cameras, + raster_settings=self.raster_settings_soft), + shader=SoftSilhouetteShader(blend_params=self.blend_params)) + # project silhouette + silh_images = renderer_silh(mesh)[..., -1].unsqueeze(1) + # project points + if points is None: + proj_points = None + else: + proj_points = self.project_points(points=points, cameras=cameras) + if color is not None: + # color renderer for visualization (no gradients) + visualization = self._get_visualization_from_mesh(mesh, cameras) + return silh_images, proj_points, visualization + else: + return silh_images, proj_points + + + + diff --git a/src/smal_pytorch/smal_model/batch_lbs.py b/src/smal_pytorch/smal_model/batch_lbs.py new file mode 100644 index 0000000000000000000000000000000000000000..98e9d321cf721ac3a47504bd49843b9979a22e71 --- /dev/null +++ b/src/smal_pytorch/smal_model/batch_lbs.py @@ -0,0 +1,295 @@ +''' +Adjusted version of other PyTorch implementation of the SMAL/SMPL model +see: + 1.) https://github.com/silviazuffi/smalst/blob/master/smal_model/smal_torch.py + 2.) https://github.com/benjiebob/SMALify/blob/master/smal_model/smal_torch.py +''' + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import torch +import numpy as np + + +def batch_skew(vec, batch_size=None): + """ + vec is N x 3, batch_size is int + + returns N x 3 x 3. Skew_sym version of each matrix. + """ + device = vec.device + if batch_size is None: + batch_size = vec.shape.as_list()[0] + col_inds = torch.LongTensor([1, 2, 3, 5, 6, 7]) + indices = torch.reshape(torch.reshape(torch.arange(0, batch_size) * 9, [-1, 1]) + col_inds, [-1, 1]) + updates = torch.reshape( + torch.stack( + [ + -vec[:, 2], vec[:, 1], vec[:, 2], -vec[:, 0], -vec[:, 1], + vec[:, 0] + ], + dim=1), [-1]) + out_shape = [batch_size * 9] + res = torch.Tensor(np.zeros(out_shape[0])).to(device=device) + res[np.array(indices.flatten())] = updates + res = torch.reshape(res, [batch_size, 3, 3]) + + return res + + + +def batch_rodrigues(theta): + """ + Theta is Nx3 + """ + device = theta.device + batch_size = theta.shape[0] + + angle = (torch.norm(theta + 1e-8, p=2, dim=1)).unsqueeze(-1) + r = (torch.div(theta, angle)).unsqueeze(-1) + + angle = angle.unsqueeze(-1) + cos = torch.cos(angle) + sin = torch.sin(angle) + + outer = torch.matmul(r, r.transpose(1,2)) + + eyes = torch.eye(3).unsqueeze(0).repeat([batch_size, 1, 1]).to(device=device) + H = batch_skew(r, batch_size=batch_size) + R = cos * eyes + (1 - cos) * outer + sin * H + + return R + +def batch_lrotmin(theta): + """ + Output of this is used to compute joint-to-pose blend shape mapping. + Equation 9 in SMPL paper. + + + Args: + pose: `Tensor`, N x 72 vector holding the axis-angle rep of K joints. 
+ This includes the global rotation so K=24 + + Returns + diff_vec : `Tensor`: N x 207 rotation matrix of 23=(K-1) joints with identity subtracted., + """ + # Ignore global rotation + theta = theta[:,3:] + + Rs = batch_rodrigues(torch.reshape(theta, [-1,3])) + lrotmin = torch.reshape(Rs - torch.eye(3), [-1, 207]) + + return lrotmin + +def batch_global_rigid_transformation(Rs, Js, parent, rotate_base=False): + """ + Computes absolute joint locations given pose. + + rotate_base: if True, rotates the global rotation by 90 deg in x axis. + if False, this is the original SMPL coordinate. + + Args: + Rs: N x 24 x 3 x 3 rotation vector of K joints + Js: N x 24 x 3, joint locations before posing + parent: 24 holding the parent id for each index + + Returns + new_J : `Tensor`: N x 24 x 3 location of absolute joints + A : `Tensor`: N x 24 4 x 4 relative joint transformations for LBS. + """ + device = Rs.device + if rotate_base: + print('Flipping the SMPL coordinate frame!!!!') + rot_x = torch.Tensor([[1, 0, 0], [0, -1, 0], [0, 0, -1]]) + rot_x = torch.reshape(torch.repeat(rot_x, [N, 1]), [N, 3, 3]) # In tf it was tile + root_rotation = torch.matmul(Rs[:, 0, :, :], rot_x) + else: + root_rotation = Rs[:, 0, :, :] + + # Now Js is N x 24 x 3 x 1 + Js = Js.unsqueeze(-1) + N = Rs.shape[0] + + def make_A(R, t): + # Rs is N x 3 x 3, ts is N x 3 x 1 + R_homo = torch.nn.functional.pad(R, (0,0,0,1,0,0)) + t_homo = torch.cat([t, torch.ones([N, 1, 1]).to(device=device)], 1) + return torch.cat([R_homo, t_homo], 2) + + A0 = make_A(root_rotation, Js[:, 0]) + results = [A0] + for i in range(1, parent.shape[0]): + j_here = Js[:, i] - Js[:, parent[i]] + A_here = make_A(Rs[:, i], j_here) + res_here = torch.matmul( + results[parent[i]], A_here) + results.append(res_here) + + # 10 x 24 x 4 x 4 + results = torch.stack(results, dim=1) + + new_J = results[:, :, :3, 3] + + # --- Compute relative A: Skinning is based on + # how much the bone moved (not the final location of the bone) + # but (final_bone - init_bone) + # --- + Js_w0 = torch.cat([Js, torch.zeros([N, 35, 1, 1]).to(device=device)], 2) + init_bone = torch.matmul(results, Js_w0) + # Append empty 4 x 3: + init_bone = torch.nn.functional.pad(init_bone, (3,0,0,0,0,0,0,0)) + A = results - init_bone + + return new_J, A + + +######################################################################################### + +def get_bone_length_scales(part_list, betas_logscale): + leg_joints = list(range(7,11)) + list(range(11,15)) + list(range(17,21)) + list(range(21,25)) + tail_joints = list(range(25, 32)) + ear_joints = [33, 34] + neck_joints = [15, 6] # ? + core_joints = [4, 5] # ? 
+ mouth_joints = [16, 32] + log_scales = torch.zeros(betas_logscale.shape[0], 35).to(betas_logscale.device) + for ind, part in enumerate(part_list): + if part == 'legs_l': + log_scales[:, leg_joints] = betas_logscale[:, ind][:, None] + elif part == 'tail_l': + log_scales[:, tail_joints] = betas_logscale[:, ind][:, None] + elif part == 'ears_l': + log_scales[:, ear_joints] = betas_logscale[:, ind][:, None] + elif part == 'neck_l': + log_scales[:, neck_joints] = betas_logscale[:, ind][:, None] + elif part == 'core_l': + log_scales[:, core_joints] = betas_logscale[:, ind][:, None] + elif part == 'head_l': + log_scales[:, mouth_joints] = betas_logscale[:, ind][:, None] + else: + pass + all_scales = torch.exp(log_scales) + return all_scales[:, 1:] # don't count root + +def get_beta_scale_mask(part_list): + # which joints belong to which bodypart + leg_joints = list(range(7,11)) + list(range(11,15)) + list(range(17,21)) + list(range(21,25)) + tail_joints = list(range(25, 32)) + ear_joints = [33, 34] + neck_joints = [15, 6] # ? + core_joints = [4, 5] # ? + mouth_joints = [16, 32] + n_b_log = len(part_list) #betas_logscale.shape[1] # 8 # 6 + beta_scale_mask = torch.zeros(35, 3, n_b_log) # .to(betas_logscale.device) + for ind, part in enumerate(part_list): + if part == 'legs_l': + beta_scale_mask[leg_joints, [2], [ind]] = 1.0 # Leg lengthening + elif part == 'legs_f': + beta_scale_mask[leg_joints, [0], [ind]] = 1.0 # Leg fatness + beta_scale_mask[leg_joints, [1], [ind]] = 1.0 # Leg fatness + elif part == 'tail_l': + beta_scale_mask[tail_joints, [0], [ind]] = 1.0 # Tail lengthening + elif part == 'tail_f': + beta_scale_mask[tail_joints, [1], [ind]] = 1.0 # Tail fatness + beta_scale_mask[tail_joints, [2], [ind]] = 1.0 # Tail fatness + elif part == 'ears_y': + beta_scale_mask[ear_joints, [1], [ind]] = 1.0 # Ear y + elif part == 'ears_l': + beta_scale_mask[ear_joints, [2], [ind]] = 1.0 # Ear z + elif part == 'neck_l': + beta_scale_mask[neck_joints, [0], [ind]] = 1.0 # Neck lengthening + elif part == 'neck_f': + beta_scale_mask[neck_joints, [1], [ind]] = 1.0 # Neck fatness + beta_scale_mask[neck_joints, [2], [ind]] = 1.0 # Neck fatness + elif part == 'core_l': + beta_scale_mask[core_joints, [0], [ind]] = 1.0 # Core lengthening + # beta_scale_mask[core_joints, [1], [ind]] = 1.0 # Core fatness (height) + elif part == 'core_fs': + beta_scale_mask[core_joints, [2], [ind]] = 1.0 # Core fatness (side) + elif part == 'head_l': + beta_scale_mask[mouth_joints, [0], [ind]] = 1.0 # Head lengthening + elif part == 'head_f': + beta_scale_mask[mouth_joints, [1], [ind]] = 1.0 # Head fatness 0 + beta_scale_mask[mouth_joints, [2], [ind]] = 1.0 # Head fatness 1 + else: + print(part + ' not available') + raise ValueError + beta_scale_mask = torch.transpose( + beta_scale_mask.reshape(35*3, n_b_log), 0, 1) + return beta_scale_mask + +def batch_global_rigid_transformation_biggs(Rs, Js, parent, scale_factors_3x3, rotate_base = False, betas_logscale=None, opts=None): + """ + Computes absolute joint locations given pose. + + rotate_base: if True, rotates the global rotation by 90 deg in x axis. + if False, this is the original SMPL coordinate. + + Args: + Rs: N x 24 x 3 x 3 rotation vector of K joints + Js: N x 24 x 3, joint locations before posing + parent: 24 holding the parent id for each index + + Returns + new_J : `Tensor`: N x 24 x 3 location of absolute joints + A : `Tensor`: N x 24 4 x 4 relative joint transformations for LBS. 
+ """ + if rotate_base: + print('Flipping the SMPL coordinate frame!!!!') + rot_x = torch.Tensor([[1, 0, 0], [0, -1, 0], [0, 0, -1]]) + rot_x = torch.reshape(torch.repeat(rot_x, [N, 1]), [N, 3, 3]) # In tf it was tile + root_rotation = torch.matmul(Rs[:, 0, :, :], rot_x) + else: + root_rotation = Rs[:, 0, :, :] + + # Now Js is N x 24 x 3 x 1 + Js = Js.unsqueeze(-1) + N = Rs.shape[0] + + Js_orig = Js.clone() + + def make_A(R, t): + # Rs is N x 3 x 3, ts is N x 3 x 1 + R_homo = torch.nn.functional.pad(R, (0,0,0,1,0,0)) + t_homo = torch.cat([t, torch.ones([N, 1, 1]).to(Rs.device)], 1) + return torch.cat([R_homo, t_homo], 2) + + A0 = make_A(root_rotation, Js[:, 0]) + results = [A0] + for i in range(1, parent.shape[0]): + j_here = Js[:, i] - Js[:, parent[i]] + try: + s_par_inv = torch.inverse(scale_factors_3x3[:, parent[i]]) + except: + # import pdb; pdb.set_trace() + s_par_inv = torch.max(scale_factors_3x3[:, parent[i]], 0.01*torch.eye((3))[None, :, :].to(scale_factors_3x3.device)) + rot = Rs[:, i] + s = scale_factors_3x3[:, i] + + rot_new = s_par_inv @ rot @ s + + A_here = make_A(rot_new, j_here) + res_here = torch.matmul( + results[parent[i]], A_here) + + results.append(res_here) + + # 10 x 24 x 4 x 4 + results = torch.stack(results, dim=1) + + # scale updates + new_J = results[:, :, :3, 3] + + # --- Compute relative A: Skinning is based on + # how much the bone moved (not the final location of the bone) + # but (final_bone - init_bone) + # --- + Js_w0 = torch.cat([Js_orig, torch.zeros([N, 35, 1, 1]).to(Rs.device)], 2) + init_bone = torch.matmul(results, Js_w0) + # Append empty 4 x 3: + init_bone = torch.nn.functional.pad(init_bone, (3,0,0,0,0,0,0,0)) + A = results - init_bone + + return new_J, A \ No newline at end of file diff --git a/src/smal_pytorch/smal_model/smal_basics.py b/src/smal_pytorch/smal_model/smal_basics.py new file mode 100644 index 0000000000000000000000000000000000000000..bd2e71ce5c5bd1d087041aed79a376eae749ad24 --- /dev/null +++ b/src/smal_pytorch/smal_model/smal_basics.py @@ -0,0 +1,82 @@ +''' +Adjusted version of other PyTorch implementation of the SMAL/SMPL model +see: + 1.) https://github.com/silviazuffi/smalst/blob/master/smal_model/smal_torch.py + 2.) 
https://github.com/benjiebob/SMALify/blob/master/smal_model/smal_torch.py +''' + +import os +import pickle as pkl +import json +import numpy as np +import pickle as pkl + +import os +import sys +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..')) +from configs.SMAL_configs import SMAL_DATA_DIR, SYMMETRY_INDS_FILE + +# model_dir = 'smalst/smpl_models/' +# FILE_DIR = os.path.dirname(os.path.realpath(__file__)) +model_dir = SMAL_DATA_DIR # os.path.join(FILE_DIR, '..', 'smpl_models/') +symmetry_inds_file = SYMMETRY_INDS_FILE # os.path.join(FILE_DIR, '..', 'smpl_models/symmetry_inds.json') +with open(symmetry_inds_file) as f: + symmetry_inds_dict = json.load(f) +LEFT_INDS = np.asarray(symmetry_inds_dict['left_inds']) +RIGHT_INDS = np.asarray(symmetry_inds_dict['right_inds']) +CENTER_INDS = np.asarray(symmetry_inds_dict['center_inds']) + + +def get_symmetry_indices(): + sym_dict = {'left': LEFT_INDS, + 'right': RIGHT_INDS, + 'center': CENTER_INDS} + return sym_dict + +def verify_symmetry(shapedirs, center_inds=CENTER_INDS, left_inds=LEFT_INDS, right_inds=RIGHT_INDS): + # shapedirs: (3889, 3, n_sh) + assert (shapedirs[center_inds, 1, :] == 0.0).all() + assert (shapedirs[right_inds, 1, :] == -shapedirs[left_inds, 1, :]).all() + return + +def from_shapedirs_to_shapedirs_half(shapedirs, center_inds=CENTER_INDS, left_inds=LEFT_INDS, right_inds=RIGHT_INDS, verify=False): + # shapedirs: (3889, 3, n_sh) + # shapedirs_half: (2012, 3, n_sh) + selected_inds = np.concatenate((center_inds, left_inds), axis=0) + shapedirs_half = shapedirs[selected_inds, :, :] + if verify: + verify_symmetry(shapedirs) + else: + shapedirs_half[:center_inds.shape[0], 1, :] = 0.0 + return shapedirs_half + +def from_shapedirs_half_to_shapedirs(shapedirs_half, center_inds=CENTER_INDS, left_inds=LEFT_INDS, right_inds=RIGHT_INDS): + # shapedirs_half: (2012, 3, n_sh) + # shapedirs: (3889, 3, n_sh) + shapedirs = np.zeros((center_inds.shape[0] + 2*left_inds.shape[0], 3, shapedirs_half.shape[2])) + shapedirs[center_inds, :, :] = shapedirs_half[:center_inds.shape[0], :, :] + shapedirs[left_inds, :, :] = shapedirs_half[center_inds.shape[0]:, :, :] + shapedirs[right_inds, :, :] = shapedirs_half[center_inds.shape[0]:, :, :] + shapedirs[right_inds, 1, :] = - shapedirs_half[center_inds.shape[0]:, 1, :] + return shapedirs + +def align_smal_template_to_symmetry_axis(v, subtract_mean=True): + # These are the indexes of the points that are on the symmetry axis + I = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 37, 55, 119, 120, 163, 209, 210, 211, 213, 216, 227, 326, 395, 452, 578, 910, 959, 964, 975, 976, 977, 1172, 1175, 1176, 1178, 1194, 1243, 1739, 1796, 1797, 1798, 1799, 1800, 1801, 1802, 1803, 1804, 1805, 1806, 1807, 1808, 1809, 1810, 1811, 1812, 1813, 1814, 1815, 1816, 1817, 1818, 1819, 1820, 1821, 1822, 1823, 1824, 1825, 1826, 1827, 1828, 1829, 1830, 1831, 1832, 1833, 1834, 1835, 1836, 1837, 1838, 1839, 1840, 1842, 1843, 1844, 1845, 1846, 1847, 1848, 1849, 1850, 1851, 1852, 1853, 1854, 1855, 1856, 1857, 1858, 1859, 1860, 1861, 1862, 1863, 1870, 1919, 1960, 1961, 1965, 1967, 2003] + if subtract_mean: + v = v - np.mean(v) + y = np.mean(v[I,1]) + v[:,1] = v[:,1] - y + v[I,1] = 0 + left_inds = LEFT_INDS + right_inds = RIGHT_INDS + center_inds = CENTER_INDS + v[right_inds, :] = np.array([1,-1,1])*v[left_inds, :] + try: + assert(len(left_inds) == len(right_inds)) + except: + import pdb; pdb.set_trace() + return v, left_inds, right_inds, center_inds + + + 
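+# Minimal usage sketch for the symmetry helpers above (illustration only; the random basis, n_sh=10
+# and the __main__ guard are assumptions and are not used anywhere else in the code base):
+if __name__ == '__main__':
+    shapedirs_rand = np.random.randn(3889, 3, 10) # fake shape basis of the full mesh, (n_verts, 3, n_sh)
+    shapedirs_half = from_shapedirs_to_shapedirs_half(shapedirs_rand) # keeps center + left vertices, (2012, 3, n_sh)
+    shapedirs_full = from_shapedirs_half_to_shapedirs(shapedirs_half) # symmetric reconstruction, (3889, 3, n_sh)
+    verify_symmetry(shapedirs_full) # center vertices have zero y-offset, right vertices mirror the left ones
+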
diff --git a/src/smal_pytorch/smal_model/smal_torch_new.py b/src/smal_pytorch/smal_model/smal_torch_new.py new file mode 100644 index 0000000000000000000000000000000000000000..5562a33b97849116d827a5213e81c40ece705b70 --- /dev/null +++ b/src/smal_pytorch/smal_model/smal_torch_new.py @@ -0,0 +1,313 @@ +""" +PyTorch implementation of the SMAL/SMPL model +see: + 1.) https://github.com/silviazuffi/smalst/blob/master/smal_model/smal_torch.py + 2.) https://github.com/benjiebob/SMALify/blob/master/smal_model/smal_torch.py +main changes compared to SMALST and WLDO: + * new model + (/ps/scratch/nrueegg/new_projects/side_packages/SMALify/new_smal_pca/results/my_tposeref_results_3/) + dogs are part of the pca to create the model + al meshes are centered around their root joint + the animals are all scaled such that their body length (butt to breast) is 1 + X_init = np.concatenate((vertices_dogs, vertices_smal), axis=0) # vertices_dogs + X = [] + for ind in range(0, X_init.shape[0]): + X_tmp, _, _, _ = align_smal_template_to_symmetry_axis(X_init[ind, :, :], subtract_mean=True) # not sure if this is necessary + X.append(X_tmp) + X = np.asarray(X) + # define points which will be used for normalization + idxs_front = [6, 16, 8, 964] # [1172, 6, 16, 8, 964] + idxs_back = [174, 2148, 175, 2149] # not in the middle, but pairs + reg_j = np.asarray(dd['J_regressor'].todense()) + # normalize the meshes such that X_frontback_dist is 1 and the root joint is in the center (0, 0, 0) + X_front = X[:, idxs_front, :].mean(axis=1) + X_back = X[:, idxs_back, :].mean(axis=1) + X_frontback_dist = np.sqrt(((X_front - X_back)**2).sum(axis=1)) + X = X / X_frontback_dist[:, None, None] + X_j0 = np.sum(X[:, reg_j[0, :]>0, :] * reg_j[0, (reg_j[0, :]>0)][None, :, None], axis=1) + X = X - X_j0[:, None, :] + * add limb length changes the same way as in WLDO + * overall scale factor is added +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np +import torch +import chumpy as ch +import os.path +from torch import nn +from torch.autograd import Variable +import pickle as pkl +from .batch_lbs import batch_rodrigues, batch_global_rigid_transformation, batch_global_rigid_transformation_biggs, get_bone_length_scales, get_beta_scale_mask + +from .smal_basics import align_smal_template_to_symmetry_axis, get_symmetry_indices + +import os +import sys +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..')) +from configs.SMAL_configs import KEY_VIDS, CANONICAL_MODEL_JOINTS, IDXS_BONES_NO_REDUNDANCY, SMAL_MODEL_PATH + +from smal_pytorch.utils import load_vertex_colors + + +# There are chumpy variables so convert them to numpy. 
+def undo_chumpy(x): + return x if isinstance(x, np.ndarray) else x.r + +# class SMAL(object): +class SMAL(nn.Module): + def __init__(self, pkl_path=SMAL_MODEL_PATH, n_betas=None, template_name='neutral', use_smal_betas=True, logscale_part_list=None): + super(SMAL, self).__init__() + + if logscale_part_list is None: + self.logscale_part_list = ['legs_l', 'legs_f', 'tail_l', 'tail_f', 'ears_y', 'ears_l', 'head_l'] + self.betas_scale_mask = get_beta_scale_mask(part_list=self.logscale_part_list) + self.num_betas_logscale = len(self.logscale_part_list) + + self.use_smal_betas = use_smal_betas + + # -- Load SMPL params -- + try: + with open(pkl_path, 'r') as f: + dd = pkl.load(f) + except (UnicodeDecodeError, TypeError) as e: + with open(pkl_path, 'rb') as file: + u = pkl._Unpickler(file) + u.encoding = 'latin1' + dd = u.load() + + self.f = dd['f'] + self.register_buffer('faces', torch.from_numpy(self.f.astype(int))) + + # get the correct template (mean shape) + if template_name=='neutral': + v_template = dd['v_template'] + v = v_template + else: + raise NotImplementedError + + # Mean template vertices + self.register_buffer('v_template', torch.Tensor(v)) + # Size of mesh [Number of vertices, 3] + self.size = [self.v_template.shape[0], 3] + self.num_betas = dd['shapedirs'].shape[-1] + # symmetry indices + self.sym_ids_dict = get_symmetry_indices() + + # Shape blend shape basis + shapedir = np.reshape(undo_chumpy(dd['shapedirs']), [-1, self.num_betas]).T + shapedir.flags['WRITEABLE'] = True # not sure why this is necessary + self.register_buffer('shapedirs', torch.Tensor(shapedir)) + + # Regressor for joint locations given shape + self.register_buffer('J_regressor', torch.Tensor(dd['J_regressor'].T.todense())) + + # Pose blend shape basis + num_pose_basis = dd['posedirs'].shape[-1] + + posedirs = np.reshape(undo_chumpy(dd['posedirs']), [-1, num_pose_basis]).T + self.register_buffer('posedirs', torch.Tensor(posedirs)) + + # indices of parents for each joints + self.parents = dd['kintree_table'][0].astype(np.int32) + + # LBS weights + self.register_buffer('weights', torch.Tensor(undo_chumpy(dd['weights']))) + + + def _caclulate_bone_lengths_from_J(self, J, betas_logscale): + # NEW: calculate bone lengths: + all_bone_lengths_list = [] + for i in range(1, self.parents.shape[0]): + bone_vec = J[:, i] - J[:, self.parents[i]] + bone_length = torch.sqrt(torch.sum(bone_vec ** 2, axis=1)) + all_bone_lengths_list.append(bone_length) + all_bone_lengths = torch.stack(all_bone_lengths_list) + # some bones are pairs, it is enough to take one of the two bones + all_bone_length_scales = get_bone_length_scales(self.logscale_part_list, betas_logscale) + all_bone_lengths = all_bone_lengths.permute((1,0)) * all_bone_length_scales + + return all_bone_lengths #.permute((1,0)) + + + def caclulate_bone_lengths(self, beta, betas_logscale, shapedirs_sel=None, short=True): + nBetas = beta.shape[1] + + # 1. Add shape blend shapes + # do we use the original shapedirs or a new set of selected shapedirs? + if shapedirs_sel is None: + shapedirs_sel = self.shapedirs[:nBetas,:] + else: + assert shapedirs_sel.shape[0] == nBetas + v_shaped = self.v_template + torch.reshape(torch.matmul(beta, shapedirs_sel), [-1, self.size[0], self.size[1]]) + + # 2. Infer shape-dependent joint locations. 
+ Jx = torch.matmul(v_shaped[:, :, 0], self.J_regressor) + Jy = torch.matmul(v_shaped[:, :, 1], self.J_regressor) + Jz = torch.matmul(v_shaped[:, :, 2], self.J_regressor) + J = torch.stack([Jx, Jy, Jz], dim=2) + + # calculate bone lengths + all_bone_lengths = self._caclulate_bone_lengths_from_J(J, betas_logscale) + selected_bone_lengths = all_bone_lengths[:, IDXS_BONES_NO_REDUNDANCY] + + if short: + return selected_bone_lengths + else: + return all_bone_lengths + + + + def __call__(self, beta, betas_limbs, theta=None, pose=None, trans=None, del_v=None, get_skin=True, keyp_conf='red', get_all_info=False, shapedirs_sel=None): + device = beta.device + + betas_logscale = betas_limbs + # NEW: allow that rotation is given as rotation matrices instead of axis angle rotation + # theta: BSxNJointsx3 or BSx(NJoints*3) + # pose: NxNJointsx3x3 + if (theta is None) and (pose is None): + raise ValueError("Either pose (rotation matrices NxNJointsx3x3) or theta (axis angle BSxNJointsx3) must be given") + elif (theta is not None) and (pose is not None): + raise ValueError("Not both pose (rotation matrices NxNJointsx3x3) and theta (axis angle BSxNJointsx3) can be given") + + if True: # self.use_smal_betas: + nBetas = beta.shape[1] + else: + nBetas = 0 + + # 1. Add shape blend shapes + # do we use the original shapedirs or a new set of selected shapedirs? + if shapedirs_sel is None: + shapedirs_sel = self.shapedirs[:nBetas,:] + else: + assert shapedirs_sel.shape[0] == nBetas + + if nBetas > 0: + if del_v is None: + v_shaped = self.v_template + torch.reshape(torch.matmul(beta, shapedirs_sel), [-1, self.size[0], self.size[1]]) + else: + v_shaped = self.v_template + del_v + torch.reshape(torch.matmul(beta, shapedirs_sel), [-1, self.size[0], self.size[1]]) + else: + if del_v is None: + v_shaped = self.v_template.unsqueeze(0) + else: + v_shaped = self.v_template + del_v + + # 2. Infer shape-dependent joint locations. + Jx = torch.matmul(v_shaped[:, :, 0], self.J_regressor) + Jy = torch.matmul(v_shaped[:, :, 1], self.J_regressor) + Jz = torch.matmul(v_shaped[:, :, 2], self.J_regressor) + J = torch.stack([Jx, Jy, Jz], dim=2) + + # 3. Add pose blend shapes + # N x 24 x 3 x 3 + if pose is None: + Rs = torch.reshape( batch_rodrigues(torch.reshape(theta, [-1, 3])), [-1, 35, 3, 3]) + else: + Rs = pose + # Ignore global rotation. + pose_feature = torch.reshape(Rs[:, 1:, :, :] - torch.eye(3).to(device=device), [-1, 306]) + + v_posed = torch.reshape( + torch.matmul(pose_feature, self.posedirs), + [-1, self.size[0], self.size[1]]) + v_shaped + + #------------------------- + # new: add corrections of bone lengths to the template (before hypothetical pose blend shapes!) + # see biggs batch_lbs.py + betas_scale = torch.exp(betas_logscale @ self.betas_scale_mask.to(betas_logscale.device)) + scaling_factors = betas_scale.reshape(-1, 35, 3) + scale_factors_3x3 = torch.diag_embed(scaling_factors, dim1=-2, dim2=-1) + + # 4. Get the global joint location + # self.J_transformed, A = batch_global_rigid_transformation(Rs, J, self.parents) + self.J_transformed, A = batch_global_rigid_transformation_biggs(Rs, J, self.parents, scale_factors_3x3, betas_logscale=betas_logscale) + + # 2-BONES. Calculate bone lengths + all_bone_lengths = self._caclulate_bone_lengths_from_J(J, betas_logscale) + # selected_bone_lengths = all_bone_lengths[:, IDXS_BONES_NO_REDUNDANCY] + #------------------------- + + # 5. 
Do skinning: + num_batch = Rs.shape[0] + + weights_t = self.weights.repeat([num_batch, 1]) + W = torch.reshape(weights_t, [num_batch, -1, 35]) + + + T = torch.reshape( + torch.matmul(W, torch.reshape(A, [num_batch, 35, 16])), + [num_batch, -1, 4, 4]) + v_posed_homo = torch.cat( + [v_posed, torch.ones([num_batch, v_posed.shape[1], 1]).to(device=device)], 2) + v_homo = torch.matmul(T, v_posed_homo.unsqueeze(-1)) + + verts = v_homo[:, :, :3, 0] + + if trans is None: + trans = torch.zeros((num_batch,3)).to(device=device) + + verts = verts + trans[:,None,:] + + # Get joints: + joint_x = torch.matmul(verts[:, :, 0], self.J_regressor) + joint_y = torch.matmul(verts[:, :, 1], self.J_regressor) + joint_z = torch.matmul(verts[:, :, 2], self.J_regressor) + joints = torch.stack([joint_x, joint_y, joint_z], dim=2) + + # New... (see https://github.com/benjiebob/SMALify/blob/master/smal_model/smal_torch.py) + joints = torch.cat([ + joints, + verts[:, None, 1863], # end_of_nose + verts[:, None, 26], # chin + verts[:, None, 2124], # right ear tip + verts[:, None, 150], # left ear tip + verts[:, None, 3055], # left eye + verts[:, None, 1097], # right eye + ], dim = 1) + + if keyp_conf == 'blue' or keyp_conf == 'dict': + # Generate keypoints + nLandmarks = KEY_VIDS.shape[0] # 24 + j3d = torch.zeros((num_batch, nLandmarks, 3)).to(device=device) + for j in range(nLandmarks): + j3d[:, j,:] = torch.mean(verts[:, KEY_VIDS[j],:], dim=1) # translation is already added to the vertices + joints_blue = j3d + + joints_red = joints[:, :-6, :] + joints_green = joints[:, CANONICAL_MODEL_JOINTS, :] + + if keyp_conf == 'red': + relevant_joints = joints_red + elif keyp_conf == 'green': + relevant_joints = joints_green + elif keyp_conf == 'blue': + relevant_joints = joints_blue + elif keyp_conf == 'dict': + relevant_joints = {'red': joints_red, + 'green': joints_green, + 'blue': joints_blue} + else: + raise NotImplementedError + + if get_all_info: + return verts, relevant_joints, Rs, all_bone_lengths + else: + if get_skin: + return verts, relevant_joints, Rs # , v_shaped + else: + return relevant_joints + + + + + + + + + + + diff --git a/src/smal_pytorch/utils.py b/src/smal_pytorch/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..11e48a0fe88cf27472c56cb7c9d3359984fd9b9a --- /dev/null +++ b/src/smal_pytorch/utils.py @@ -0,0 +1,13 @@ +import numpy as np + +def load_vertex_colors(obj_path): + v_colors = [] + for line in open(obj_path, "r"): + if line.startswith('#'): continue + values = line.split() + if not values: continue + if values[0] == 'v': + v_colors.append(values[4:7]) + else: + continue + return np.asarray(v_colors, dtype=np.float32) diff --git a/src/stacked_hourglass/__init__.py b/src/stacked_hourglass/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a5da308c10ea77f41570a7a0d417c28ad19ae9d2 --- /dev/null +++ b/src/stacked_hourglass/__init__.py @@ -0,0 +1,2 @@ +from stacked_hourglass.model import hg1, hg2, hg4, hg8 +from stacked_hourglass.predictor import HumanPosePredictor diff --git a/src/stacked_hourglass/datasets/__init__.py b/src/stacked_hourglass/datasets/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/stacked_hourglass/datasets/imgcrops.py b/src/stacked_hourglass/datasets/imgcrops.py new file mode 100644 index 0000000000000000000000000000000000000000..89face653c8d6c92fb4bf453a1ae46957ee68dff --- /dev/null +++ b/src/stacked_hourglass/datasets/imgcrops.py @@ -0,0 
+1,77 @@ + + +import os +import glob +import numpy as np +import torch +import torch.utils.data as data + +import sys +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..')) +from configs.anipose_data_info import COMPLETE_DATA_INFO +from stacked_hourglass.utils.imutils import load_image +from stacked_hourglass.utils.transforms import crop, color_normalize +from stacked_hourglass.utils.pilutil import imresize +from stacked_hourglass.utils.imutils import im_to_torch +from configs.dataset_path_configs import TEST_IMAGE_CROP_ROOT_DIR +from configs.data_info import COMPLETE_DATA_INFO_24 + + +class ImgCrops(data.Dataset): + DATA_INFO = COMPLETE_DATA_INFO_24 + ACC_JOINTS = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 16] + + def __init__(self, img_crop_folder='default', image_path=None, is_train=False, inp_res=256, out_res=64, sigma=1, + scale_factor=0.25, rot_factor=30, label_type='Gaussian', + do_augment='default', shorten_dataset_to=None, dataset_mode='keyp_only'): + assert is_train == False + assert do_augment == 'default' or do_augment == False + self.inp_res = inp_res + if img_crop_folder == 'default': + self.folder_imgs = os.path.join(os.path.dirname(__file__), '..', '..', '..', 'datasets', 'test_image_crops') + else: + self.folder_imgs = img_crop_folder + name_list = glob.glob(os.path.join(self.folder_imgs, '*.png')) + glob.glob(os.path.join(self.folder_imgs, '*.jpg')) + glob.glob(os.path.join(self.folder_imgs, '*.jpeg')) + name_list = sorted(name_list) + self.test_name_list = [name.split('/')[-1] for name in name_list] + print('len(dataset): ' + str(self.__len__())) + + def __getitem__(self, index): + img_name = self.test_name_list[index] + # load image + img_path = os.path.join(self.folder_imgs, img_name) + img = load_image(img_path) # CxHxW + # prepare image (cropping and color) + img_max = max(img.shape[1], img.shape[2]) + img_padded = torch.zeros((img.shape[0], img_max, img_max)) + if img_max == img.shape[2]: + start = (img_max-img.shape[1])//2 + img_padded[:, start:start+img.shape[1], :] = img + else: + start = (img_max-img.shape[2])//2 + img_padded[:, :, start:start+img.shape[2]] = img + img = img_padded + img_prep = im_to_torch(imresize(img, [self.inp_res, self.inp_res], interp='bilinear')) + inp = color_normalize(img_prep, self.DATA_INFO.rgb_mean, self.DATA_INFO.rgb_stddev) + # add the following fields to make it compatible with stanext, most of them are fake + target_dict = {'index': index, 'center' : -2, 'scale' : -2, + 'breed_index': -2, 'sim_breed_index': -2, + 'ind_dataset': 1} + target_dict['pts'] = np.zeros((self.DATA_INFO.n_keyp, 3)) + target_dict['tpts'] = np.zeros((self.DATA_INFO.n_keyp, 3)) + target_dict['target_weight'] = np.zeros((self.DATA_INFO.n_keyp, 1)) + target_dict['silh'] = np.zeros((self.inp_res, self.inp_res)) + return inp, target_dict + + + def __len__(self): + return len(self.test_name_list) + + + + + + + + + diff --git a/src/stacked_hourglass/datasets/imgcropslist.py b/src/stacked_hourglass/datasets/imgcropslist.py new file mode 100644 index 0000000000000000000000000000000000000000..c5c87dbb902995cf02393247f990217dbd2746f7 --- /dev/null +++ b/src/stacked_hourglass/datasets/imgcropslist.py @@ -0,0 +1,95 @@ + + +import os +import glob +import numpy as np +import math +import torch +import torch.utils.data as data + +import sys +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..')) +from configs.anipose_data_info import COMPLETE_DATA_INFO +from stacked_hourglass.utils.imutils import load_image, im_to_torch +from 
stacked_hourglass.utils.transforms import crop, color_normalize +from stacked_hourglass.utils.pilutil import imresize +from stacked_hourglass.utils.imutils import im_to_torch +from configs.data_info import COMPLETE_DATA_INFO_24 + + +class ImgCrops(data.Dataset): + DATA_INFO = COMPLETE_DATA_INFO_24 + ACC_JOINTS = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 16] + + def __init__(self, image_list, bbox_list=None, inp_res=256, dataset_mode='keyp_only'): + # the list contains the images directly, not only their paths + self.image_list = image_list + self.bbox_list = bbox_list + self.inp_res = inp_res + self.test_name_list = [] + for ind in np.arange(0, len(self.image_list)): + self.test_name_list.append(str(ind)) + print('len(dataset): ' + str(self.__len__())) + + def __getitem__(self, index): + '''img_name = self.test_name_list[index] + # load image + img_path = os.path.join(self.folder_imgs, img_name) + img = load_image(img_path) # CxHxW''' + + # load image + '''img_hwc = self.image_list[index] + img = np.rollaxis(img_hwc, 2, 0) ''' + img = im_to_torch(self.image_list[index]) + + # import pdb; pdb.set_trace() + + # try loading bounding box + if self.bbox_list is not None: + bbox = self.bbox_list[index] + bbox_xywh = [bbox[0][0], bbox[0][1], bbox[1][0]-bbox[0][0], bbox[1][1]-bbox[0][1]] + bbox_c = [bbox_xywh[0]+0.5*bbox_xywh[2], bbox_xywh[1]+0.5*bbox_xywh[3]] + bbox_max = max(bbox_xywh[2], bbox_xywh[3]) + bbox_diag = math.sqrt(bbox_xywh[2]**2 + bbox_xywh[3]**2) + bbox_s = bbox_max / 200. * 256. / 200. # maximum side of the bbox will be 200 + c = torch.Tensor(bbox_c) + s = bbox_s + img_prep = crop(img, c, s, [self.inp_res, self.inp_res], rot=0) + + else: + + # prepare image (cropping and color) + img_max = max(img.shape[1], img.shape[2]) + img_padded = torch.zeros((img.shape[0], img_max, img_max)) + if img_max == img.shape[2]: + start = (img_max-img.shape[1])//2 + img_padded[:, start:start+img.shape[1], :] = img + else: + start = (img_max-img.shape[2])//2 + img_padded[:, :, start:start+img.shape[2]] = img + img = img_padded + img_prep = im_to_torch(imresize(img, [self.inp_res, self.inp_res], interp='bilinear')) + + inp = color_normalize(img_prep, self.DATA_INFO.rgb_mean, self.DATA_INFO.rgb_stddev) + # add the following fields to make it compatible with stanext, most of them are fake + target_dict = {'index': index, 'center' : -2, 'scale' : -2, + 'breed_index': -2, 'sim_breed_index': -2, + 'ind_dataset': 1} + target_dict['pts'] = np.zeros((self.DATA_INFO.n_keyp, 3)) + target_dict['tpts'] = np.zeros((self.DATA_INFO.n_keyp, 3)) + target_dict['target_weight'] = np.zeros((self.DATA_INFO.n_keyp, 1)) + target_dict['silh'] = np.zeros((self.inp_res, self.inp_res)) + return inp, target_dict + + + def __len__(self): + return len(self.image_list) + + + + + + + + + diff --git a/src/stacked_hourglass/datasets/samplers/custom_pair_samplers.py b/src/stacked_hourglass/datasets/samplers/custom_pair_samplers.py new file mode 100644 index 0000000000000000000000000000000000000000..6bb8a636d1138a58cd2265f931e2c19ef47a9220 --- /dev/null +++ b/src/stacked_hourglass/datasets/samplers/custom_pair_samplers.py @@ -0,0 +1,171 @@ + +import numpy as np +import random +import copy +import time +import warnings + +from torch.utils.data import Sampler +from torch._six import int_classes as _int_classes + +class CustomPairBatchSampler(Sampler): + """Wraps another sampler to yield a mini-batch of indices. + The structure of this sampler is way to complicated because it is a shorter/simplified version of + CustomBatchSampler. 
The relations between breeds are not relevant for the cvpr 2022 paper, but we kept + this structure which we were using for the experiments with clade related losses. ToDo: restructure + this sampler. + Args: + data_sampler_info (dict): a dictionnary, containing information about the dataset and breeds. + batch_size (int): Size of mini-batch. + """ + + def __init__(self, data_sampler_info, batch_size): + if not isinstance(batch_size, _int_classes) or isinstance(batch_size, bool) or \ + batch_size <= 0: + raise ValueError("batch_size should be a positive integer value, " + "but got batch_size={}".format(batch_size)) + assert batch_size%2 == 0 + self.data_sampler_info = data_sampler_info + self.batch_size = batch_size + self.n_desired_batches = int(np.floor(len(self.data_sampler_info['name_list']) / batch_size)) # 157 + + def get_description(self): + description = "\ + This sampler works only for even batch sizes. \n\ + It returns pairs of dogs of the same breed" + return description + + + def __iter__(self): + breeds_summary = self.data_sampler_info['breeds_summary'] + + breed_image_dict_orig = {} + for img_name in self.data_sampler_info['name_list']: # ['n02093859-Kerry_blue_terrier/n02093859_913.jpg', ... ] + folder_name = img_name.split('/')[0] + breed_name = folder_name.split(folder_name.split('-')[0] + '-')[1] + if not (breed_name in breed_image_dict_orig): + breed_image_dict_orig[breed_name] = [img_name] + else: + breed_image_dict_orig[breed_name].append(img_name) + + lengths = np.zeros((len(breed_image_dict_orig.values()))) + for ind, value in enumerate(breed_image_dict_orig.values()): + lengths[ind] = len(value) + + sim_matrix_raw = self.data_sampler_info['breeds_sim_martix_raw'] + sim_matrix_raw[sim_matrix_raw>0].shape # we have 1061 connections + + # from ind_in_sim_mat to breed_name + inverse_sim_dict = {} + for abbrev, ind in self.data_sampler_info['breeds_sim_abbrev_inds'].items(): + # breed_name might be None + breed = breeds_summary[abbrev] + breed_name = breed._name_stanext + inverse_sim_dict[ind] = {'abbrev': abbrev, + 'breed_name': breed_name} + + # similarity for relevant breeds only: + related_breeds_top_orig = {} + temp = np.arange(sim_matrix_raw.shape[0]) + for breed_name, breed_images in breed_image_dict_orig.items(): + abbrev = self.data_sampler_info['breeds_abbrev_dict'][breed_name] + related_breeds = {} + if abbrev in self.data_sampler_info['breeds_sim_abbrev_inds'].keys(): + ind_in_sim_mat = self.data_sampler_info['breeds_sim_abbrev_inds'][abbrev] + row = sim_matrix_raw[ind_in_sim_mat, :] + rel_inds = temp[row>0] + for ind in rel_inds: + rel_breed_name = inverse_sim_dict[ind]['breed_name'] + rel_abbrev = inverse_sim_dict[ind]['abbrev'] + # does this breed exist in this dataset? 
+ if (rel_breed_name is not None) and (rel_breed_name in breed_image_dict_orig.keys()) and not (rel_breed_name==breed_name): + related_breeds[rel_breed_name] = row[ind] + related_breeds_top_orig[breed_name] = related_breeds + + breed_image_dict = copy.deepcopy(breed_image_dict_orig) + related_breeds_top = copy.deepcopy(related_breeds_top_orig) + + # clean the related_breeds_top dict such that it only contains breeds which are available + for breed_name, breed_images in breed_image_dict.items(): + if len(breed_image_dict[breed_name]) < 1: + for breed_name_rel in list(related_breeds_top[breed_name].keys()): + related_breeds_top[breed_name_rel].pop(breed_name, None) + related_breeds_top[breed_name].pop(breed_name_rel, None) + + # 1) build pairs of dogs + set_of_breeds_with_at_least_2 = set() + for breed_name, breed_images in breed_image_dict.items(): + if len(breed_images) >= 2: + set_of_breeds_with_at_least_2.add(breed_name) + + n_unused_images = len(self.data_sampler_info['name_list']) + all_dog_duos = [] + n_new_duos = 1 + while n_new_duos > 0: + for breed_name, breed_images in breed_image_dict.items(): + # shuffle image list for this specific breed (this changes the dict) + random.shuffle(breed_images) + breed_list = list(related_breeds_top.keys()) + random.shuffle(breed_list) + n_new_duos = 0 + for breed_name in breed_list: + if len(breed_image_dict[breed_name]) >= 2: + dog_a = breed_image_dict[breed_name].pop() + dog_b = breed_image_dict[breed_name].pop() + dog_duo = [dog_a, dog_b] + all_dog_duos.append({'image_names': dog_duo}) + # clean the related_breeds_top dict such that it only contains breeds which are still available + if len(breed_image_dict[breed_name]) < 1: + for breed_name_rel in list(related_breeds_top[breed_name].keys()): + related_breeds_top[breed_name_rel].pop(breed_name, None) + related_breeds_top[breed_name].pop(breed_name_rel, None) + n_new_duos += 1 + n_unused_images -= 2 + + image_name_to_ind = {} + for ind_img_name, img_name in enumerate(self.data_sampler_info['name_list']): + image_name_to_ind[img_name] = ind_img_name + + # take all images and create the batches + n_avail_2 = len(all_dog_duos) + all_batches = [] + ind_in_duos = 0 + n_imgs_used_twice = 0 + for ind_b in range(0, self.n_desired_batches): + batch_with_image_names = [] + for ind in range(int(np.floor(self.batch_size / 2))): + if ind_in_duos >= n_avail_2: + ind_rand = random.randint(0, n_avail_2-1) + batch_with_image_names.extend(all_dog_duos[ind_rand]['image_names']) + n_imgs_used_twice += 2 + else: + batch_with_image_names.extend(all_dog_duos[ind_in_duos]['image_names']) + ind_in_duos += 1 + + + batch_with_inds = [] + for image_name in batch_with_image_names: # rather a folder than name + batch_with_inds.append(image_name_to_ind[image_name]) + + all_batches.append(batch_with_inds) + + for batch in all_batches: + yield batch + + def __len__(self): + # Since we are sampling pairs of dogs and not each breed has an even number of dogs, we can not + # guarantee to show each dog exacly once. What we do instead, is returning the same amount of + # batches as we would return with a standard sampler which is not based on dog pairs. 
+ '''if self.drop_last: + return len(self.sampler) // self.batch_size # type: ignore + else: + return (len(self.sampler) + self.batch_size - 1) // self.batch_size # type: ignore''' + return self.n_desired_batches + + + + + + + + diff --git a/src/stacked_hourglass/datasets/stanext24.py b/src/stacked_hourglass/datasets/stanext24.py new file mode 100644 index 0000000000000000000000000000000000000000..e217bf076fb63de5655fc173737ecd2e9803b1e6 --- /dev/null +++ b/src/stacked_hourglass/datasets/stanext24.py @@ -0,0 +1,301 @@ +# 24 joints instead of 20!! + + +import gzip +import json +import os +import random +import math +import numpy as np +import torch +import torch.utils.data as data +from importlib_resources import open_binary +from scipy.io import loadmat +from tabulate import tabulate +import itertools +import json +from scipy import ndimage + +from csv import DictReader +from pycocotools.mask import decode as decode_RLE + +import sys +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..')) +from configs.data_info import COMPLETE_DATA_INFO_24 +from stacked_hourglass.utils.imutils import load_image, draw_labelmap, draw_multiple_labelmaps +from stacked_hourglass.utils.misc import to_torch +from stacked_hourglass.utils.transforms import shufflelr, crop, color_normalize, fliplr, transform +import stacked_hourglass.datasets.utils_stanext as utils_stanext +from stacked_hourglass.utils.visualization import save_input_image_with_keypoints +from configs.dog_breeds.dog_breed_class import COMPLETE_ABBREV_DICT, COMPLETE_SUMMARY_BREEDS, SIM_MATRIX_RAW, SIM_ABBREV_INDICES +from configs.dataset_path_configs import STANEXT_RELATED_DATA_ROOT_DIR + + +class StanExt(data.Dataset): + DATA_INFO = COMPLETE_DATA_INFO_24 + + # Suggested joints to use for keypoint reprojection error calculations + ACC_JOINTS = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 16] + + def __init__(self, image_path=None, is_train=True, inp_res=256, out_res=64, sigma=1, + scale_factor=0.25, rot_factor=30, label_type='Gaussian', + do_augment='default', shorten_dataset_to=None, dataset_mode='keyp_only', V12=None, val_opt='test'): + self.V12 = V12 + self.is_train = is_train # training set or test set + if do_augment == 'yes': + self.do_augment = True + elif do_augment == 'no': + self.do_augment = False + elif do_augment=='default': + if self.is_train: + self.do_augment = True + else: + self.do_augment = False + else: + raise ValueError + self.inp_res = inp_res + self.out_res = out_res + self.sigma = sigma + self.scale_factor = scale_factor + self.rot_factor = rot_factor + self.label_type = label_type + self.dataset_mode = dataset_mode + if self.dataset_mode=='complete' or self.dataset_mode=='keyp_and_seg' or self.dataset_mode=='keyp_and_seg_and_partseg': + self.calc_seg = True + else: + self.calc_seg = False + self.val_opt = val_opt + + # create train/val split + self.img_folder = utils_stanext.get_img_dir(V12=self.V12) + self.train_dict, init_test_dict, init_val_dict = utils_stanext.load_stanext_json_as_dict(split_train_test=True, V12=self.V12) + self.train_name_list = list(self.train_dict.keys()) # 7004 + if self.val_opt == 'test': + self.test_dict = init_test_dict + self.test_name_list = list(self.test_dict.keys()) + elif self.val_opt == 'val': + self.test_dict = init_val_dict + self.test_name_list = list(self.test_dict.keys()) + else: + raise NotImplementedError + + # stanext breed dict (contains for each name a stanext specific index) + breed_json_path = os.path.join(STANEXT_RELATED_DATA_ROOT_DIR, 
'StanExt_breed_dict_v2.json') + self.breed_dict = self.get_breed_dict(breed_json_path, create_new_breed_json=False) + self.train_name_list = sorted(self.train_name_list) + self.test_name_list = sorted(self.test_name_list) + random.seed(4) + random.shuffle(self.train_name_list) + random.shuffle(self.test_name_list)
+ if shorten_dataset_to is not None: + # sometimes it is useful to have a smaller set (validation speed, debugging) + self.train_name_list = self.train_name_list[0 : min(len(self.train_name_list), shorten_dataset_to)] + self.test_name_list = self.test_name_list[0 : min(len(self.test_name_list), shorten_dataset_to)]
+ # special case for debugging: 12 similar images + if shorten_dataset_to == 12: + my_sample = self.test_name_list[2] + for ind in range(0, 12): + self.test_name_list[ind] = my_sample + print('len(dataset): ' + str(self.__len__())) +
+ # add results for eyes, withers and throat as obtained through anipose -> they are used + # as pseudo ground truth at training time. + self.path_anipose_out_root = os.path.join(STANEXT_RELATED_DATA_ROOT_DIR, 'animalpose_hg8_v0_results_on_StanExt') + +
+ def get_data_sampler_info(self): + # for custom data sampler + if self.is_train: + name_list = self.train_name_list + else: + name_list = self.test_name_list
+ info_dict = {'name_list': name_list, + 'stanext_breed_dict': self.breed_dict, + 'breeds_abbrev_dict': COMPLETE_ABBREV_DICT, + 'breeds_summary': COMPLETE_SUMMARY_BREEDS, + 'breeds_sim_martix_raw': SIM_MATRIX_RAW, + 'breeds_sim_abbrev_inds': SIM_ABBREV_INDICES + } + return info_dict + +
+ def get_breed_dict(self, breed_json_path, create_new_breed_json=False): + if create_new_breed_json: + breed_dict = {} + breed_index = 0 + for img_name in self.train_name_list: + folder_name = img_name.split('/')[0] + breed_name = folder_name.split(folder_name.split('-')[0] + '-')[1] + if not (folder_name in breed_dict): + breed_dict[folder_name] = { + 'breed_name': breed_name, + 'index': breed_index} + breed_index += 1
+ with open(breed_json_path, 'w', encoding='utf-8') as f: json.dump(breed_dict, f, ensure_ascii=False, indent=4) + else: + with open(breed_json_path) as json_file: breed_dict = json.load(json_file) + return breed_dict + +
+ def __getitem__(self, index): + + if self.is_train: + name = self.train_name_list[index] + data = self.train_dict[name] + else: + name = self.test_name_list[index] + data = self.test_dict[name] + + sf = self.scale_factor + rf = self.rot_factor +
+ img_path = os.path.join(self.img_folder, data['img_path']) + try: + anipose_res_path = os.path.join(self.path_anipose_out_root, data['img_path'].replace('.jpg', '.json')) + with open(anipose_res_path) as f: anipose_data = json.load(f) + anipose_thr = 0.2 + anipose_joints_0to24 = np.asarray(anipose_data['anipose_joints_0to24']).reshape((-1, 3)) + anipose_joints_0to24_scores = anipose_joints_0to24[:, 2]
+ # anipose_joints_0to24_scores[anipose_joints_0to24_scores>anipose_thr] = 1.0 + anipose_joints_0to24_scores[anipose_joints_0to24_scores < anipose_thr] = 0.0 + except: + # no anipose predictions available -> zero-confidence placeholders for the four extra joints + anipose_joints_0to24 = np.zeros((24, 3))
+ # 24 joints: the 20 annotated ones followed by the anipose pseudo ground truth (eyes, withers, throat) + joints = np.concatenate((np.asarray(data['joints'])[:20, :], anipose_joints_0to24[20:24, :]), axis=0) + pts = torch.Tensor(joints)
+ # bounding box, data['img_bbox'] is given as [x0, y0, width, height] + bbox_xywh = data['img_bbox'] + bbox_c = [bbox_xywh[0]+0.5*bbox_xywh[2], bbox_xywh[1]+0.5*bbox_xywh[3]] + bbox_max = max(bbox_xywh[2], bbox_xywh[3]) + bbox_diag = math.sqrt(bbox_xywh[2]**2 + bbox_xywh[3]**2) + # bbox_s = bbox_max / 200. # the dog will fill the image -> bbox_max = 256 + # bbox_s = bbox_diag / 200. # diagonal of the boundingbox will be 200 + bbox_s = bbox_max / 200. * 256. / 200.
# maximum side of the bbox will be 200 + c = torch.Tensor(bbox_c) + s = bbox_s + + # For single-person pose estimation with a centered/scaled figure + nparts = pts.size(0) + img = load_image(img_path) # CxHxW + + # segmentation map (we reshape it to 3xHxW, such that we can do the + # same transformations as with the image) + if self.calc_seg: + seg = torch.Tensor(utils_stanext.get_seg_from_entry(data)[None, :, :]) + seg = torch.cat(3*[seg]) + + r = 0 + do_flip = False + if self.do_augment: + s = s*torch.randn(1).mul_(sf).add_(1).clamp(1-sf, 1+sf)[0] + r = torch.randn(1).mul_(rf).clamp(-2*rf, 2*rf)[0] if random.random() <= 0.6 else 0 + # Flip + if random.random() <= 0.5: + do_flip = True + img = fliplr(img) + if self.calc_seg: + seg = fliplr(seg) + pts = shufflelr(pts, img.size(2), self.DATA_INFO.hflip_indices) + c[0] = img.size(2) - c[0] + # Color + img[0, :, :].mul_(random.uniform(0.8, 1.2)).clamp_(0, 1) + img[1, :, :].mul_(random.uniform(0.8, 1.2)).clamp_(0, 1) + img[2, :, :].mul_(random.uniform(0.8, 1.2)).clamp_(0, 1) + + # Prepare image and groundtruth map + inp = crop(img, c, s, [self.inp_res, self.inp_res], rot=r) + img_border_mask = torch.all(inp > 1.0/256, dim = 0).unsqueeze(0).float() # 1 is foreground + inp = color_normalize(inp, self.DATA_INFO.rgb_mean, self.DATA_INFO.rgb_stddev) + if self.calc_seg: + seg = crop(seg, c, s, [self.inp_res, self.inp_res], rot=r) + + # Generate ground truth + tpts = pts.clone() + target_weight = tpts[:, 2].clone().view(nparts, 1) + + target = torch.zeros(nparts, self.out_res, self.out_res) + for i in range(nparts): + # if tpts[i, 2] > 0: # This is evil!! + if tpts[i, 1] > 0: + tpts[i, 0:2] = to_torch(transform(tpts[i, 0:2]+1, c, s, [self.out_res, self.out_res], rot=r, as_int=False)) + target[i], vis = draw_labelmap(target[i], tpts[i]-1, self.sigma, type=self.label_type) + target_weight[i, 0] *= vis + # NEW: + '''target_new, vis_new = draw_multiple_labelmaps((self.out_res, self.out_res), tpts[:, :2]-1, self.sigma, type=self.label_type) + target_weight_new = tpts[:, 2].clone().view(nparts, 1) * vis_new + target_new[(target_weight_new==0).reshape((-1)), :, :] = 0''' + + # --- Meta info + this_breed = self.breed_dict[name.split('/')[0]] # 120 + # add information about location within breed similarity matrix + folder_name = name.split('/')[0] + breed_name = folder_name.split(folder_name.split('-')[0] + '-')[1] + abbrev = COMPLETE_ABBREV_DICT[breed_name] + try: + sim_breed_index = COMPLETE_SUMMARY_BREEDS[abbrev]._ind_in_xlsx_matrix + except: # some breeds are not in the xlsx file + sim_breed_index = -1 + meta = {'index' : index, 'center' : c, 'scale' : s, + 'pts' : pts, 'tpts' : tpts, 'target_weight': target_weight, + 'breed_index': this_breed['index'], 'sim_breed_index': sim_breed_index, + 'ind_dataset': 0} # ind_dataset=0 for stanext or stanexteasy or stanext 2 + meta2 = {'index' : index, 'center' : c, 'scale' : s, + 'pts' : pts, 'tpts' : tpts, 'target_weight': target_weight, + 'ind_dataset': 3} + + # return different things depending on dataset_mode + if self.dataset_mode=='keyp_only': + # save_input_image_with_keypoints(inp, meta['tpts'], out_path='./test_input_stanext.png', ratio_in_out=self.inp_res/self.out_res) + return inp, target, meta + elif self.dataset_mode=='keyp_and_seg': + meta['silh'] = seg[0, :, :] + meta['name'] = name + return inp, target, meta + elif self.dataset_mode=='keyp_and_seg_and_partseg': + # partseg is fake! 
this does only exist such that this dataset can be combined with an other datset that has part segmentations + meta2['silh'] = seg[0, :, :] + meta2['name'] = name + fake_body_part_matrix = torch.ones((3, 256, 256)).long() * (-1) + meta2['body_part_matrix'] = fake_body_part_matrix + return inp, target, meta2 + elif self.dataset_mode=='complete': + target_dict = meta + target_dict['silh'] = seg[0, :, :] + # NEW for silhouette loss + target_dict['img_border_mask'] = img_border_mask + target_dict['has_seg'] = True + if target_dict['silh'].sum() < 1: + if ((not self.is_train) and self.val_opt == 'test'): + raise ValueError + elif self.is_train: + print('had to replace training image') + replacement_index = max(0, index - 1) + inp, target_dict = self.__getitem__(replacement_index) + else: + # There seem to be a few validation images without segmentation + # which would lead to nan in iou calculation + replacement_index = max(0, index - 1) + inp, target_dict = self.__getitem__(replacement_index) + return inp, target_dict + else: + print('sampling error') + import pdb; pdb.set_trace() + raise ValueError + + + def __len__(self): + if self.is_train: + return len(self.train_name_list) + else: + return len(self.test_name_list) + + diff --git a/src/stacked_hourglass/datasets/utils_stanext.py b/src/stacked_hourglass/datasets/utils_stanext.py new file mode 100644 index 0000000000000000000000000000000000000000..83da8452f74ff8fb0ca95e2d8a42ba96972f684b --- /dev/null +++ b/src/stacked_hourglass/datasets/utils_stanext.py @@ -0,0 +1,114 @@ + +import os +from matplotlib import pyplot as plt +import glob +import json +import numpy as np +from scipy.io import loadmat +from csv import DictReader +from collections import OrderedDict +from pycocotools.mask import decode as decode_RLE + +import sys +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..')) +from configs.dataset_path_configs import IMG_V12_DIR, JSON_V12_DIR, STAN_V12_TRAIN_LIST_DIR, STAN_V12_VAL_LIST_DIR, STAN_V12_TEST_LIST_DIR + + +def get_img_dir(V12): + if V12: + return IMG_V12_DIR + else: + return IMG_DIR + +def get_seg_from_entry(entry): + """Given a .json entry, returns the binary mask as a numpy array""" + rle = { + "size": [entry['img_height'], entry['img_width']], + "counts": entry['seg']} + decoded = decode_RLE(rle) + return decoded + +def full_animal_visible(seg_data): + if seg_data[0, :].sum() == 0 and seg_data[seg_data.shape[0]-1, :].sum() == 0 and seg_data[:, 0].sum() == 0 and seg_data[:, seg_data.shape[1]-1].sum() == 0: + return True + else: + return False + +def load_train_and_test_lists(train_list_dir=None , test_list_dir=None): + """ returns sets containing names such as 'n02085620-Chihuahua/n02085620_5927.jpg' """ + # train data + train_list_mat = loadmat(train_list_dir) + train_list = [] + for ind in range(0, train_list_mat['file_list'].shape[0]): + name = train_list_mat['file_list'][ind, 0][0] + train_list.append(name) + # test data + test_list_mat = loadmat(test_list_dir) + test_list = [] + for ind in range(0, test_list_mat['file_list'].shape[0]): + name = test_list_mat['file_list'][ind, 0][0] + test_list.append(name) + return train_list, test_list + + + +def _filter_dict(t_list, j_dict, n_kp_min=4): + """ should only be used by load_stanext_json_as_dict() """ + out_dict = {} + for sample in t_list: + if sample in j_dict.keys(): + n_kp = np.asarray(j_dict[sample]['joints'])[:, 2].sum() + if n_kp >= n_kp_min: + out_dict[sample] = j_dict[sample] + return out_dict + +def 
load_stanext_json_as_dict(split_train_test=True, V12=True): + # load json into memory + if V12: + with open(JSON_V12_DIR) as infile: + json_data = json.load(infile) + # with open(JSON_V12_DIR) as infile: json_data = json.load(infile, object_pairs_hook=OrderedDict) + else: + with open(JSON_DIR) as infile: + json_data = json.load(infile) + # convert json data to a dictionary of img_path : all_data, for easy lookup + json_dict = {i['img_path']: i for i in json_data} + if split_train_test: + if V12: + train_list_numbers = np.load(STAN_V12_TRAIN_LIST_DIR) + val_list_numbers = np.load(STAN_V12_VAL_LIST_DIR) + test_list_numbers = np.load(STAN_V12_TEST_LIST_DIR) + train_list = [json_data[i]['img_path'] for i in train_list_numbers] + val_list = [json_data[i]['img_path'] for i in val_list_numbers] + test_list = [json_data[i]['img_path'] for i in test_list_numbers] + train_dict = _filter_dict(train_list, json_dict, n_kp_min=4) + val_dict = _filter_dict(val_list, json_dict, n_kp_min=4) + test_dict = _filter_dict(test_list, json_dict, n_kp_min=4) + return train_dict, test_dict, val_dict + else: + train_list, test_list = load_train_and_test_lists(train_list_dir=STAN_ORIG_TRAIN_LIST_DIR , test_list_dir=STAN_ORIG_TEST_LIST_DIR) + train_dict = _filter_dict(train_list, json_dict) + test_dict = _filter_dict(test_list, json_dict) + return train_dict, test_dict, None + else: + return json_dict + +def get_dog(json_dict, name, img_dir=None): # (json_dict, name, img_dir=IMG_DIR) + """ takes the name of a dog, and loads in all the relevant information as a dictionary: + dict_keys(['img_path', 'img_width', 'img_height', 'joints', 'img_bbox', + 'is_multiple_dogs', 'seg', 'img_data', 'seg_data']) + img_bbox: [x0, y0, width, height] """ + data = json_dict[name] + # load img + img_data = plt.imread(os.path.join(img_dir, data['img_path'])) + # load seg + seg_data = get_seg_from_entry(data) + # add to output + data['img_data'] = img_data # 0 to 255 + data['seg_data'] = seg_data # 0: bg, 1: fg + return data + + + + + diff --git a/src/stacked_hourglass/model.py b/src/stacked_hourglass/model.py new file mode 100644 index 0000000000000000000000000000000000000000..0df09246044e8450efeb0b12f86cb9780a435a60 --- /dev/null +++ b/src/stacked_hourglass/model.py @@ -0,0 +1,308 @@ +# Modified from: +# https://github.com/anibali/pytorch-stacked-hourglass +# https://github.com/bearpaw/pytorch-pose +# Hourglass network inserted in the pre-activated Resnet +# Use lr=0.01 for current version +# (c) YANG, Wei + + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.hub import load_state_dict_from_url + + +__all__ = ['HourglassNet', 'hg'] + + +model_urls = { + 'hg1': 'https://github.com/anibali/pytorch-stacked-hourglass/releases/download/v0.0.0/bearpaw_hg1-ce125879.pth', + 'hg2': 'https://github.com/anibali/pytorch-stacked-hourglass/releases/download/v0.0.0/bearpaw_hg2-15e342d9.pth', + 'hg8': 'https://github.com/anibali/pytorch-stacked-hourglass/releases/download/v0.0.0/bearpaw_hg8-90e5d470.pth', +} + + +class Bottleneck(nn.Module): + expansion = 2 + + def __init__(self, inplanes, planes, stride=1, downsample=None): + super(Bottleneck, self).__init__() + + self.bn1 = nn.BatchNorm2d(inplanes) + self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=True) + self.bn2 = nn.BatchNorm2d(planes) + self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, + padding=1, bias=True) + self.bn3 = nn.BatchNorm2d(planes) + self.conv3 = nn.Conv2d(planes, planes * 2, kernel_size=1, bias=True) + self.relu = 
nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = self.bn1(x) + out = self.relu(out) + out = self.conv1(out) + + out = self.bn2(out) + out = self.relu(out) + out = self.conv2(out) + + out = self.bn3(out) + out = self.relu(out) + out = self.conv3(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + + return out + + +class Hourglass(nn.Module): + def __init__(self, block, num_blocks, planes, depth): + super(Hourglass, self).__init__() + self.depth = depth + self.block = block + self.hg = self._make_hour_glass(block, num_blocks, planes, depth) + + def _make_residual(self, block, num_blocks, planes): + layers = [] + for i in range(0, num_blocks): + layers.append(block(planes*block.expansion, planes)) + return nn.Sequential(*layers) + + def _make_hour_glass(self, block, num_blocks, planes, depth): + hg = [] + for i in range(depth): + res = [] + for j in range(3): + res.append(self._make_residual(block, num_blocks, planes)) + if i == 0: + res.append(self._make_residual(block, num_blocks, planes)) + hg.append(nn.ModuleList(res)) + return nn.ModuleList(hg) + + def _hour_glass_forward(self, n, x): + up1 = self.hg[n-1][0](x) + low1 = F.max_pool2d(x, 2, stride=2) + low1 = self.hg[n-1][1](low1) + + if n > 1: + low2 = self._hour_glass_forward(n-1, low1) + else: + low2 = self.hg[n-1][3](low1) + low3 = self.hg[n-1][2](low2) + up2 = F.interpolate(low3, scale_factor=2) + out = up1 + up2 + return out + + def forward(self, x): + return self._hour_glass_forward(self.depth, x) + + +class HourglassNet(nn.Module): + '''Hourglass model from Newell et al ECCV 2016''' + def __init__(self, block, num_stacks=2, num_blocks=4, num_classes=16, upsample_seg=False, add_partseg=False, num_partseg=None): + super(HourglassNet, self).__init__() + + self.inplanes = 64 + self.num_feats = 128 + self.num_stacks = num_stacks + self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, + bias=True) + self.bn1 = nn.BatchNorm2d(self.inplanes) + self.relu = nn.ReLU(inplace=True) + self.layer1 = self._make_residual(block, self.inplanes, 1) + self.layer2 = self._make_residual(block, self.inplanes, 1) + self.layer3 = self._make_residual(block, self.num_feats, 1) + self.maxpool = nn.MaxPool2d(2, stride=2) + self.upsample_seg = upsample_seg + self.add_partseg = add_partseg + + # build hourglass modules + ch = self.num_feats*block.expansion + hg, res, fc, score, fc_, score_ = [], [], [], [], [], [] + for i in range(num_stacks): + hg.append(Hourglass(block, num_blocks, self.num_feats, 4)) + res.append(self._make_residual(block, self.num_feats, num_blocks)) + fc.append(self._make_fc(ch, ch)) + score.append(nn.Conv2d(ch, num_classes, kernel_size=1, bias=True)) + if i < num_stacks-1: + fc_.append(nn.Conv2d(ch, ch, kernel_size=1, bias=True)) + score_.append(nn.Conv2d(num_classes, ch, kernel_size=1, bias=True)) + self.hg = nn.ModuleList(hg) + self.res = nn.ModuleList(res) + self.fc = nn.ModuleList(fc) + self.score = nn.ModuleList(score) + self.fc_ = nn.ModuleList(fc_) + self.score_ = nn.ModuleList(score_) + + if self.add_partseg: + self.hg_ps = (Hourglass(block, num_blocks, self.num_feats, 4)) + self.res_ps = (self._make_residual(block, self.num_feats, num_blocks)) + self.fc_ps = (self._make_fc(ch, ch)) + self.score_ps = (nn.Conv2d(ch, num_partseg, kernel_size=1, bias=True)) + self.ups_upsampling_ps = nn.Upsample(scale_factor=4, mode='bilinear', align_corners=True) + + + if self.upsample_seg: + 
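+ # Segmentation refinement head (see the forward pass below): the last two score channels
+ # are treated as a coarse segmentation map, upsampled by a factor of 4 (64x64 -> 256x256
+ # for the default 256x256 input), and a small convolutional stack conditioned on the
+ # input image predicts a residual correction that is added to the upsampled map.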
self.ups_upsampling = nn.Upsample(scale_factor=4, mode='bilinear', align_corners=True) + self.ups_conv0 = nn.Conv2d(3, 32, kernel_size=7, stride=1, padding=3, + bias=True) + self.ups_bn1 = nn.BatchNorm2d(32) + self.ups_conv1 = nn.Conv2d(32, 16, kernel_size=7, stride=1, padding=3, + bias=True) + self.ups_bn2 = nn.BatchNorm2d(16+2) + self.ups_conv2 = nn.Conv2d(16+2, 16, kernel_size=5, stride=1, padding=2, + bias=True) + self.ups_bn3 = nn.BatchNorm2d(16) + self.ups_conv3 = nn.Conv2d(16, 2, kernel_size=5, stride=1, padding=2, + bias=True) + + + + def _make_residual(self, block, planes, blocks, stride=1): + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + nn.Conv2d(self.inplanes, planes * block.expansion, + kernel_size=1, stride=stride, bias=True), + ) + + layers = [] + layers.append(block(self.inplanes, planes, stride, downsample)) + self.inplanes = planes * block.expansion + for i in range(1, blocks): + layers.append(block(self.inplanes, planes)) + + return nn.Sequential(*layers) + + def _make_fc(self, inplanes, outplanes): + bn = nn.BatchNorm2d(inplanes) + conv = nn.Conv2d(inplanes, outplanes, kernel_size=1, bias=True) + return nn.Sequential( + conv, + bn, + self.relu, + ) + + def forward(self, x_in): + out = [] + out_seg = [] + out_partseg = [] + x = self.conv1(x_in) + x = self.bn1(x) + x = self.relu(x) + + x = self.layer1(x) + x = self.maxpool(x) + x = self.layer2(x) + x = self.layer3(x) + + for i in range(self.num_stacks): + if i == self.num_stacks - 1: + if self.add_partseg: + y_ps = self.hg_ps(x) + y_ps = self.res_ps(y_ps) + y_ps = self.fc_ps(y_ps) + score_ps = self.score_ps(y_ps) + out_partseg.append(score_ps[:, :, :, :]) + y = self.hg[i](x) + y = self.res[i](y) + y = self.fc[i](y) + score = self.score[i](y) + if self.upsample_seg: + out.append(score[:, :-2, :, :]) + out_seg.append(score[:, -2:, :, :]) + else: + out.append(score) + if i < self.num_stacks-1: + fc_ = self.fc_[i](y) + score_ = self.score_[i](score) + x = x + fc_ + score_ + + if self.upsample_seg: + # PLAN: add a residual to the upsampled version of the segmentation image + # upsample predicted segmentation + seg_score = score[:, -2:, :, :] + seg_score_256 = self.ups_upsampling(seg_score) + # prepare input image + + ups_img = self.ups_conv0(x_in) + + ups_img = self.ups_bn1(ups_img) + ups_img = self.relu(ups_img) + ups_img = self.ups_conv1(ups_img) + + # import pdb; pdb.set_trace() + + ups_conc = torch.cat((seg_score_256, ups_img), 1) + + # ups_conc = self.ups_bn2(ups_conc) + ups_conc = self.relu(ups_conc) + ups_conc = self.ups_conv2(ups_conc) + + ups_conc = self.ups_bn3(ups_conc) + ups_conc = self.relu(ups_conc) + correction = self.ups_conv3(ups_conc) + + seg_final = seg_score_256 + correction + + if self.add_partseg: + partseg_final = self.ups_upsampling_ps(score_ps) + out_dict = {'out_list_kp': out, + 'out_list_seg': out, + 'seg_final': seg_final, + 'out_list_partseg': out_partseg, + 'partseg_final': partseg_final + } + return out_dict + else: + out_dict = {'out_list_kp': out, + 'out_list_seg': out, + 'seg_final': seg_final + } + return out_dict + + return out + + +def hg(**kwargs): + model = HourglassNet(Bottleneck, num_stacks=kwargs['num_stacks'], num_blocks=kwargs['num_blocks'], + num_classes=kwargs['num_classes'], upsample_seg=kwargs['upsample_seg'], + add_partseg=kwargs['add_partseg'], num_partseg=kwargs['num_partseg']) + return model + + +def _hg(arch, pretrained, progress, **kwargs): + model = hg(**kwargs) + if pretrained: + state_dict = 
load_state_dict_from_url(model_urls[arch], progress=progress) + model.load_state_dict(state_dict) + return model + + +def hg1(pretrained=False, progress=True, num_blocks=1, num_classes=16, upsample_seg=False, add_partseg=False, num_partseg=None): + return _hg('hg1', pretrained, progress, num_stacks=1, num_blocks=num_blocks, + num_classes=num_classes, upsample_seg=upsample_seg, + add_partseg=add_partseg, num_partseg=num_partseg) + + +def hg2(pretrained=False, progress=True, num_blocks=1, num_classes=16, upsample_seg=False, add_partseg=False, num_partseg=None): + return _hg('hg2', pretrained, progress, num_stacks=2, num_blocks=num_blocks, + num_classes=num_classes, upsample_seg=upsample_seg, + add_partseg=add_partseg, num_partseg=num_partseg) + +def hg4(pretrained=False, progress=True, num_blocks=1, num_classes=16, upsample_seg=False, add_partseg=False, num_partseg=None): + return _hg('hg4', pretrained, progress, num_stacks=4, num_blocks=num_blocks, + num_classes=num_classes, upsample_seg=upsample_seg, + add_partseg=add_partseg, num_partseg=num_partseg) + +def hg8(pretrained=False, progress=True, num_blocks=1, num_classes=16, upsample_seg=False, add_partseg=False, num_partseg=None): + return _hg('hg8', pretrained, progress, num_stacks=8, num_blocks=num_blocks, + num_classes=num_classes, upsample_seg=upsample_seg, + add_partseg=add_partseg, num_partseg=num_partseg) diff --git a/src/stacked_hourglass/predictor.py b/src/stacked_hourglass/predictor.py new file mode 100644 index 0000000000000000000000000000000000000000..30be3b4fe816cc33018b61632c4ba120ea66dfc3 --- /dev/null +++ b/src/stacked_hourglass/predictor.py @@ -0,0 +1,119 @@ + +# Modified from: +# https://github.com/anibali/pytorch-stacked-hourglass +# https://github.com/bearpaw/pytorch-pose + +import torch +from stacked_hourglass.utils.evaluation import final_preds_untransformed +from stacked_hourglass.utils.imfit import fit, calculate_fit_contain_output_area +from stacked_hourglass.utils.transforms import color_normalize, fliplr, flip_back + + +def _check_batched(images): + if isinstance(images, (tuple, list)): + return True + if images.ndimension() == 4: + return True + return False + + +class HumanPosePredictor: + def __init__(self, model, device=None, data_info=None, input_shape=None): + """Helper class for predicting 2D human pose joint locations. + + Args: + model: The model for generating joint heatmaps. + device: The computational device to use for inference. + data_info: Specifications of the data (defaults to ``Mpii.DATA_INFO``). + input_shape: The input dimensions of the model (height, width). 
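+ 
+ Example (illustrative sketch only; assumes a trained keypoint model and a
+ matching ``data_info`` object are already available):
+ 
+ >>> predictor = HumanPosePredictor(model, device='cuda', data_info=data_info)
+ >>> joints = predictor.estimate_joints(image, flip=True) # (num_joints, 2), x/y order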
+ """ + if device is None: + device = 'cuda' if torch.cuda.is_available() else 'cpu' + device = torch.device(device) + model.to(device) + self.model = model + self.device = device + + if data_info is None: + raise ValueError + # self.data_info = Mpii.DATA_INFO + else: + self.data_info = data_info + + # Input shape ordering: H, W + if input_shape is None: + self.input_shape = (256, 256) + elif isinstance(input_shape, int): + self.input_shape = (input_shape, input_shape) + else: + self.input_shape = input_shape + + def do_forward(self, input_tensor): + self.model.eval() + with torch.no_grad(): + output = self.model(input_tensor) + return output + + def prepare_image(self, image): + was_fixed_point = not image.is_floating_point() + image = torch.empty_like(image, dtype=torch.float32).copy_(image) + if was_fixed_point: + image /= 255.0 + if image.shape[-2:] != self.input_shape: + image = fit(image, self.input_shape, fit_mode='contain') + image = color_normalize(image, self.data_info.rgb_mean, self.data_info.rgb_stddev) + return image + + def estimate_heatmaps(self, images, flip=False): + is_batched = _check_batched(images) + raw_images = images if is_batched else images.unsqueeze(0) + input_tensor = torch.empty((len(raw_images), 3, *self.input_shape), + device=self.device, dtype=torch.float32) + for i, raw_image in enumerate(raw_images): + input_tensor[i] = self.prepare_image(raw_image) + heatmaps = self.do_forward(input_tensor)[-1].cpu() + if flip: + flip_input = fliplr(input_tensor) + flip_heatmaps = self.do_forward(flip_input)[-1].cpu() + heatmaps += flip_back(flip_heatmaps, self.data_info.hflip_indices) + heatmaps /= 2 + if is_batched: + return heatmaps + else: + return heatmaps[0] + + def estimate_joints(self, images, flip=False): + """Estimate human joint locations from input images. + + Images are expected to be centred on a human subject and scaled reasonably. + + Args: + images: The images to estimate joint locations for. Can be a single image or a list + of images. + flip (bool): If set to true, evaluates on flipped versions of the images as well and + averages the results. + + Returns: + The predicted human joint locations in image pixel space. + """ + is_batched = _check_batched(images) + raw_images = images if is_batched else images.unsqueeze(0) + heatmaps = self.estimate_heatmaps(raw_images, flip=flip).cpu() + # final_preds_untransformed compares the first component of shape with x and second with y + # This relates to the image Width, Height (Heatmap has shape Height, Width) + coords = final_preds_untransformed(heatmaps, heatmaps.shape[-2:][::-1]) + # Rescale coords to pixel space of specified images. + for i, image in enumerate(raw_images): + # When returning to original image space we need to compensate for the fact that we are + # used fit_mode='contain' when preparing the images for inference. 
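+ # Concretely, the mapping back to the original image inverts three steps: heatmap grid
+ # -> network input resolution, then subtract the letterbox offset (y_off / x_off), then
+ # rescale the contained area back to the original image height / width.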
+ y_off, x_off, height, width = calculate_fit_contain_output_area(*image.shape[-2:], *self.input_shape) + coords[i, :, 1] *= self.input_shape[-2] / heatmaps.shape[-2] + coords[i, :, 1] -= y_off + coords[i, :, 1] *= image.shape[-2] / height + coords[i, :, 0] *= self.input_shape[-1] / heatmaps.shape[-1] + coords[i, :, 0] -= x_off + coords[i, :, 0] *= image.shape[-1] / width + if is_batched: + return coords + else: + return coords[0] diff --git a/src/stacked_hourglass/utils/__init__.py b/src/stacked_hourglass/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/stacked_hourglass/utils/evaluation.py b/src/stacked_hourglass/utils/evaluation.py new file mode 100644 index 0000000000000000000000000000000000000000..b02a4d804332ae263c4fb005d36cccd967ade029 --- /dev/null +++ b/src/stacked_hourglass/utils/evaluation.py @@ -0,0 +1,188 @@ +# Modified from: +# https://github.com/anibali/pytorch-stacked-hourglass +# https://github.com/bearpaw/pytorch-pose + +import math +import torch +from kornia.geometry.subpix import dsnt # kornia 0.4.0 +import torch.nn.functional as F +from .transforms import transform_preds + +__all__ = ['get_preds', 'get_preds_soft', 'calc_dists', 'dist_acc', 'accuracy', 'final_preds_untransformed', + 'final_preds', 'AverageMeter'] + +def get_preds(scores, return_maxval=False): + ''' get predictions from score maps in torch Tensor + return type: torch.LongTensor + ''' + assert scores.dim() == 4, 'Score maps should be 4-dim' + maxval, idx = torch.max(scores.view(scores.size(0), scores.size(1), -1), 2) + + maxval = maxval.view(scores.size(0), scores.size(1), 1) + idx = idx.view(scores.size(0), scores.size(1), 1) + 1 + + preds = idx.repeat(1, 1, 2).float() + + preds[:,:,0] = (preds[:,:,0] - 1) % scores.size(3) + 1 + preds[:,:,1] = torch.floor((preds[:,:,1] - 1) / scores.size(3)) + 1 + + pred_mask = maxval.gt(0).repeat(1, 1, 2).float() # values > 0 + preds *= pred_mask + if return_maxval: + return preds, maxval + else: + return preds + + +def get_preds_soft(scores, return_maxval=False, norm_coords=False, norm_and_unnorm_coords=False): + ''' get predictions from score maps in torch Tensor + predictions are made assuming a logit output map + return type: torch.LongTensor + ''' + + # New: work on logit predictions + scores_norm = dsnt.spatial_softmax2d(scores, temperature=torch.tensor(1)) + # maxval_norm, idx_norm = torch.max(scores_norm.view(scores.size(0), scores.size(1), -1), 2) + # from unnormalized to normalized see: + # from -1to1 to 0to64 + # see https://github.com/kornia/kornia/blob/b9ffe7efcba7399daeeb8028f10c22941b55d32d/kornia/utils/grid.py#L7 (line 40) + # xs = (xs / (width - 1) - 0.5) * 2 + # ys = (ys / (height - 1) - 0.5) * 2 + + device = scores.device + + if return_maxval: + preds_normalized = dsnt.spatial_expectation2d(scores_norm, normalized_coordinates=True) + # grid_sample(input, grid, mode='bilinear', padding_mode='zeros') + gs_input_single = scores_norm.reshape((-1, 1, scores_norm.shape[2], scores_norm.shape[3])) # (120, 1, 64, 64) + gs_input = scores_norm.reshape((-1, 1, scores_norm.shape[2], scores_norm.shape[3])) # (120, 1, 64, 64) + + half_pad = 2 + gs_input_single_padded = F.pad(input=gs_input_single, pad=(half_pad, half_pad, half_pad, half_pad, 0, 0, 0, 0), mode='constant', value=0) + gs_input_all = torch.zeros((gs_input_single.shape[0], 9, gs_input_single.shape[2], gs_input_single.shape[3])).to(device) + ind_tot = 0 + for ind0 in [-1, 0, 1]: + for ind1 in [-1, 0, 1]: + 
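+ # channel ind_tot of gs_input_all holds the heatmap shifted by (ind0, ind1); sampling all
+ # nine channels at the predicted keypoint location and summing them (below) gives the
+ # probability mass in a 3x3 window around the peak, returned in place of a simple
+ # per-keypoint maximum.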
gs_input_all[:, ind_tot, :, :] = gs_input_single_padded[:, 0, half_pad+ind0:-half_pad+ind0, half_pad+ind1:-half_pad+ind1] + ind_tot +=1 + + gs_grid = preds_normalized.reshape((-1, 2))[:, None, None, :] # (120, 1, 1, 2) + gs_output_all = F.grid_sample(gs_input_all, gs_grid, mode='nearest', padding_mode='zeros', align_corners=True).reshape((gs_input_all.shape[0], gs_input_all.shape[1], 1)) + gs_output = gs_output_all.sum(axis=1) + # scores_norm[0, :, :, :].max(axis=2)[0].max(axis=1)[0] + # gs_output[0, :, 0] + gs_output_resh = gs_output.reshape((scores_norm.shape[0], scores_norm.shape[1], 1)) + + if norm_and_unnorm_coords: + preds = dsnt.spatial_expectation2d(scores_norm, normalized_coordinates=False) + 1 + return preds, preds_normalized, gs_output_resh + elif norm_coords: + return preds_normalized, gs_output_resh + else: + preds = dsnt.spatial_expectation2d(scores_norm, normalized_coordinates=False) + 1 + return preds, gs_output_resh + else: + if norm_coords: + preds_normalized = dsnt.spatial_expectation2d(scores_norm, normalized_coordinates=True) + return preds_normalized + else: + preds = dsnt.spatial_expectation2d(scores_norm, normalized_coordinates=False) + 1 + return preds + + +def calc_dists(preds, target, normalize): + preds = preds.float() + target = target.float() + dists = torch.zeros(preds.size(1), preds.size(0)) + for n in range(preds.size(0)): + for c in range(preds.size(1)): + if target[n,c,0] > 1 and target[n, c, 1] > 1: + dists[c, n] = torch.dist(preds[n,c,:], target[n,c,:])/normalize[n] + else: + dists[c, n] = -1 + return dists + +def dist_acc(dist, thr=0.5): + ''' Return percentage below threshold while ignoring values with a -1 ''' + dist = dist[dist != -1] + if len(dist) > 0: + return 1.0 * (dist < thr).sum().item() / len(dist) + else: + return -1 + +def accuracy(output, target, idxs=None, thr=0.5): + ''' Calculate accuracy according to PCK, but uses ground truth heatmap rather than x,y locations + First value to be returned is average accuracy across 'idxs', followed by individual accuracies + ''' + if idxs is None: + idxs = list(range(target.shape[-3])) + preds = get_preds_soft(output) # get_preds(output) + gts = get_preds(target) + norm = torch.ones(preds.size(0))*output.size(3)/10 + dists = calc_dists(preds, gts, norm) + + acc = torch.zeros(len(idxs)+1) + avg_acc = 0 + cnt = 0 + + for i in range(len(idxs)): + acc[i+1] = dist_acc(dists[idxs[i]], thr=thr) + if acc[i+1] >= 0: + avg_acc = avg_acc + acc[i+1] + cnt += 1 + + if cnt != 0: + acc[0] = avg_acc / cnt + return acc + +def final_preds_untransformed(output, res): + coords = get_preds_soft(output) # get_preds(output) # float type + + # pose-processing + for n in range(coords.size(0)): + for p in range(coords.size(1)): + hm = output[n][p] + px = int(math.floor(coords[n][p][0])) + py = int(math.floor(coords[n][p][1])) + if px > 1 and px < res[0] and py > 1 and py < res[1]: + diff = torch.Tensor([hm[py - 1][px] - hm[py - 1][px - 2], hm[py][px - 1]-hm[py - 2][px - 1]]) + coords[n][p] += diff.sign() * .25 + coords += 0.5 + + if coords.dim() < 3: + coords = coords.unsqueeze(0) + + coords -= 1 # Convert from 1-based to 0-based coordinates + + return coords + +def final_preds(output, center, scale, res): + coords = final_preds_untransformed(output, res) + preds = coords.clone() + + # Transform back + for i in range(coords.size(0)): + preds[i] = transform_preds(coords[i], center[i], scale[i], res) + + if preds.dim() < 3: + preds = preds.unsqueeze(0) + + return preds + + +class AverageMeter(object): + """Computes and stores 
the average and current value""" + def __init__(self): + self.reset() + + def reset(self): + self.val = 0 + self.avg = 0 + self.sum = 0 + self.count = 0 + + def update(self, val, n=1): + self.val = val + self.sum += val * n + self.count += n + self.avg = self.sum / self.count diff --git a/src/stacked_hourglass/utils/finetune.py b/src/stacked_hourglass/utils/finetune.py new file mode 100644 index 0000000000000000000000000000000000000000..e7990b26a90e824f02141d7908907679f544f98c --- /dev/null +++ b/src/stacked_hourglass/utils/finetune.py @@ -0,0 +1,39 @@ +# Modified from: +# https://github.com/anibali/pytorch-stacked-hourglass +# https://github.com/bearpaw/pytorch-pose + +import torch +from torch.nn import Conv2d, ModuleList + + +def change_hg_outputs(model, indices): + """Change the output classes of the model. + + Args: + model: The model to modify. + indices: An array of indices describing the new model outputs. For example, [3, 4, None] + will modify the model to have 3 outputs, the first two of which have parameters + copied from the fourth and fifth outputs of the original model. + """ + with torch.no_grad(): + new_n_outputs = len(indices) + new_score = ModuleList() + for conv in model.score: + new_conv = Conv2d(conv.in_channels, new_n_outputs, conv.kernel_size, conv.stride) + new_conv = new_conv.to(conv.weight.device, conv.weight.dtype) + for i, index in enumerate(indices): + if index is not None: + new_conv.weight[i] = conv.weight[index] + new_conv.bias[i] = conv.bias[index] + new_score.append(new_conv) + model.score = new_score + new_score_ = ModuleList() + for conv in model.score_: + new_conv = Conv2d(new_n_outputs, conv.out_channels, conv.kernel_size, conv.stride) + new_conv = new_conv.to(conv.weight.device, conv.weight.dtype) + for i, index in enumerate(indices): + if index is not None: + new_conv.weight[:, i] = conv.weight[:, index] + new_conv.bias = conv.bias + new_score_.append(new_conv) + model.score_ = new_score_ diff --git a/src/stacked_hourglass/utils/imfit.py b/src/stacked_hourglass/utils/imfit.py new file mode 100644 index 0000000000000000000000000000000000000000..ee0d2e131bf3c1bd2e0c740d9c8cfd9d847f523d --- /dev/null +++ b/src/stacked_hourglass/utils/imfit.py @@ -0,0 +1,144 @@ +# Modified from: +# https://github.com/anibali/pytorch-stacked-hourglass +# https://github.com/bearpaw/pytorch-pose + +import torch +from torch.nn.functional import interpolate + + +def _resize(tensor, size, mode='bilinear'): + """Resize the image. + + Args: + tensor (torch.Tensor): The image tensor to be resized. + size (tuple of int): Size of the resized image (height, width). + mode (str): The pixel sampling interpolation mode to be used. + + Returns: + Tensor: The resized image tensor. + """ + assert len(size) == 2 + + # If the tensor is already the desired size, return it immediately. 
+ if tensor.shape[-2] == size[0] and tensor.shape[-1] == size[1]: + return tensor + + if not tensor.is_floating_point(): + dtype = tensor.dtype + tensor = tensor.to(torch.float32) + tensor = _resize(tensor, size, mode) + return tensor.to(dtype) + + out_shape = (*tensor.shape[:-2], *size) + if tensor.ndimension() < 3: + raise Exception('tensor must be at least 2D') + elif tensor.ndimension() == 3: + tensor = tensor.unsqueeze(0) + elif tensor.ndimension() > 4: + tensor = tensor.view(-1, *tensor.shape[-3:]) + align_corners = None + if mode in {'linear', 'bilinear', 'trilinear'}: + align_corners = False + resized = interpolate(tensor, size=size, mode=mode, align_corners=align_corners) + return resized.view(*out_shape) + + +def _crop(tensor, t, l, h, w, padding_mode='constant', fill=0): + """Crop the image, padding out-of-bounds regions. + + Args: + tensor (torch.Tensor): The image tensor to be cropped. + t (int): Top pixel coordinate. + l (int): Left pixel coordinate. + h (int): Height of the cropped image. + w (int): Width of the cropped image. + padding_mode (str): Padding mode (currently "constant" is the only valid option). + fill (float): Fill value to use with constant padding. + + Returns: + Tensor: The cropped image tensor. + """ + # If the _crop region is wholly within the image, simply narrow the tensor. + if t >= 0 and l >= 0 and t + h <= tensor.size(-2) and l + w <= tensor.size(-1): + return tensor[..., t:t+h, l:l+w] + + if padding_mode == 'constant': + result = torch.full((*tensor.size()[:-2], h, w), fill, + device=tensor.device, dtype=tensor.dtype) + else: + raise Exception('_crop only supports "constant" padding currently.') + + sx1 = l + sy1 = t + sx2 = l + w + sy2 = t + h + dx1 = 0 + dy1 = 0 + + if sx1 < 0: + dx1 = -sx1 + w += sx1 + sx1 = 0 + + if sy1 < 0: + dy1 = -sy1 + h += sy1 + sy1 = 0 + + if sx2 >= tensor.size(-1): + w -= sx2 - tensor.size(-1) + + if sy2 >= tensor.size(-2): + h -= sy2 - tensor.size(-2) + + # Copy the in-bounds sub-area of the _crop region into the result tensor. + if h > 0 and w > 0: + src = tensor.narrow(-2, sy1, h).narrow(-1, sx1, w) + dst = result.narrow(-2, dy1, h).narrow(-1, dx1, w) + dst.copy_(src) + + return result + + +def calculate_fit_contain_output_area(in_height, in_width, out_height, out_width): + ih, iw = in_height, in_width + k = min(out_width / iw, out_height / ih) + oh = round(k * ih) + ow = round(k * iw) + y_off = (out_height - oh) // 2 + x_off = (out_width - ow) // 2 + return y_off, x_off, oh, ow + + +def fit(tensor, size, fit_mode='cover', resize_mode='bilinear', *, fill=0): + """Fit the image within the given spatial dimensions. + + Args: + tensor (torch.Tensor): The image tensor to be fit. + size (tuple of int): Size of the output (height, width). + fit_mode (str): 'fill', 'contain', or 'cover'. These behave in the same way as CSS's + `object-fit` property. + fill (float): padding value (only applicable in 'contain' mode). + + Returns: + Tensor: The resized image tensor. 
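+ 
+ Example (illustrative): letterbox a CxHxW image tensor into a 256x256 canvas while
+ preserving its aspect ratio::
+ 
+ out = fit(img, (256, 256), fit_mode='contain')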
+ """ + if fit_mode == 'fill': + return _resize(tensor, size, mode=resize_mode) + elif fit_mode == 'contain': + y_off, x_off, oh, ow = calculate_fit_contain_output_area(*tensor.shape[-2:], *size) + resized = _resize(tensor, (oh, ow), mode=resize_mode) + result = tensor.new_full((*tensor.size()[:-2], *size), fill) + result[..., y_off:y_off + oh, x_off:x_off + ow] = resized + return result + elif fit_mode == 'cover': + ih, iw = tensor.shape[-2:] + k = max(size[-1] / iw, size[-2] / ih) + oh = round(k * ih) + ow = round(k * iw) + resized = _resize(tensor, (oh, ow), mode=resize_mode) + y_trim = (oh - size[-2]) // 2 + x_trim = (ow - size[-1]) // 2 + result = _crop(resized, y_trim, x_trim, size[-2], size[-1]) + return result + raise ValueError('Invalid fit_mode: ' + repr(fit_mode)) diff --git a/src/stacked_hourglass/utils/imutils.py b/src/stacked_hourglass/utils/imutils.py new file mode 100644 index 0000000000000000000000000000000000000000..5540728cc9f85e55b560308417c3b77d9c678a13 --- /dev/null +++ b/src/stacked_hourglass/utils/imutils.py @@ -0,0 +1,125 @@ +# Modified from: +# https://github.com/anibali/pytorch-stacked-hourglass +# https://github.com/bearpaw/pytorch-pose + +import numpy as np + +from .misc import to_numpy, to_torch +from .pilutil import imread, imresize +from kornia.geometry.subpix import dsnt +import torch + +def im_to_numpy(img): + img = to_numpy(img) + img = np.transpose(img, (1, 2, 0)) # H*W*C + return img + +def im_to_torch(img): + img = np.transpose(img, (2, 0, 1)) # C*H*W + img = to_torch(img).float() + if img.max() > 1: + img /= 255 + return img + +def load_image(img_path): + # H x W x C => C x H x W + return im_to_torch(imread(img_path, mode='RGB')) + +# ============================================================================= +# Helpful functions generating groundtruth labelmap +# ============================================================================= + +def gaussian(shape=(7,7),sigma=1): + """ + 2D gaussian mask - should give the same result as MATLAB's + fspecial('gaussian',[shape],[sigma]) + """ + m,n = [(ss-1.)/2. 
for ss in shape] + y,x = np.ogrid[-m:m+1,-n:n+1] + h = np.exp( -(x*x + y*y) / (2.*sigma*sigma) ) + h[ h < np.finfo(h.dtype).eps*h.max() ] = 0 + return to_torch(h).float() + +def draw_labelmap_orig(img, pt, sigma, type='Gaussian'): + # Draw a 2D gaussian + # Adopted from https://github.com/anewell/pose-hg-train/blob/master/src/pypose/draw.py + # maximum value of the gaussian is 1 + img = to_numpy(img) + + # Check that any part of the gaussian is in-bounds + ul = [int(pt[0] - 3 * sigma), int(pt[1] - 3 * sigma)] + br = [int(pt[0] + 3 * sigma + 1), int(pt[1] + 3 * sigma + 1)] + if (ul[0] >= img.shape[1] or ul[1] >= img.shape[0] or + br[0] < 0 or br[1] < 0): + # If not, just return the image as is + return to_torch(img), 0 + + # Generate gaussian + size = 6 * sigma + 1 + x = np.arange(0, size, 1, float) + y = x[:, np.newaxis] + x0 = y0 = size // 2 + # The gaussian is not normalized, we want the center value to equal 1 + if type == 'Gaussian': + g = np.exp(- ((x - x0) ** 2 + (y - y0) ** 2) / (2 * sigma ** 2)) + elif type == 'Cauchy': + g = sigma / (((x - x0) ** 2 + (y - y0) ** 2 + sigma ** 2) ** 1.5) + + # Usable gaussian range + g_x = max(0, -ul[0]), min(br[0], img.shape[1]) - ul[0] + g_y = max(0, -ul[1]), min(br[1], img.shape[0]) - ul[1] + # Image range + img_x = max(0, ul[0]), min(br[0], img.shape[1]) + img_y = max(0, ul[1]), min(br[1], img.shape[0]) + + img[img_y[0]:img_y[1], img_x[0]:img_x[1]] = g[g_y[0]:g_y[1], g_x[0]:g_x[1]] + + return to_torch(img), 1 + + + +def draw_labelmap(img, pt, sigma, type='Gaussian'): + # Draw a 2D gaussian + # real probability distribution: the sum of all values is 1 + img = to_numpy(img) + if not type == 'Gaussian': + raise NotImplementedError + + # Check that any part of the gaussian is in-bounds + ul = [int(pt[0] - 3 * sigma), int(pt[1] - 3 * sigma)] + br = [int(pt[0] + 3 * sigma + 1), int(pt[1] + 3 * sigma + 1)] + if (ul[0] >= img.shape[1] or ul[1] >= img.shape[0] or + br[0] < 0 or br[1] < 0): + # If not, just return the image as is + return to_torch(img), 0 + + # Generate gaussian + # img_new = dsnt.render_gaussian2d(mean=torch.tensor([[-1, 0]]).float(), std=torch.tensor([[sigma, sigma]]).float(), size=(img.shape[0], img.shape[1]), normalized_coordinates=False) + img_new = dsnt.render_gaussian2d(mean=torch.tensor([[pt[0], pt[1]]]).float(), \ + std=torch.tensor([[sigma, sigma]]).float(), \ + size=(img.shape[0], img.shape[1]), \ + normalized_coordinates=False) + img_new = img_new[0, :, :] # this is a torch image + return img_new, 1 + + +def draw_multiple_labelmaps(out_res, pts, sigma, type='Gaussian'): + # Draw a 2D gaussian + # real probability distribution: the sum of all values is 1 + if not type == 'Gaussian': + raise NotImplementedError + + # Generate gaussians + n_pts = pts.shape[0] + imgs_new = dsnt.render_gaussian2d(mean=pts[:, :2], \ + std=torch.tensor([[sigma, sigma]]).float().repeat((n_pts, 1)), \ + size=(out_res[0], out_res[1]), \ + normalized_coordinates=False) # shape: (n_pts, out_res[0], out_res[1]) + + visibility_orig = imgs_new.sum(axis=2).sum(axis=1) # shape: (n_pts) + visibility = torch.zeros((n_pts, 1), dtype=torch.float32) + visibility[visibility_orig>=0.99999] = 1.0 + + # import pdb; pdb.set_trace() + + return imgs_new, visibility.int() \ No newline at end of file diff --git a/src/stacked_hourglass/utils/logger.py b/src/stacked_hourglass/utils/logger.py new file mode 100644 index 0000000000000000000000000000000000000000..8e42823a88ae20117fc5aa191f569491c102b1f3 --- /dev/null +++ b/src/stacked_hourglass/utils/logger.py @@ -0,0 +1,73 @@ + 
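+# Usage sketch (illustrative only, not part of the upstream code): the Logger class below is
+# typically driven as
+#   logger = Logger('log.txt')
+#   logger.set_names(['Epoch', 'Train Loss', 'Val Loss'])
+#   logger.append([1, 0.52, 0.61])      # one call per epoch, same order as the names
+#   logger.plot_to_file('log.png', names=['Train Loss', 'Val Loss'])
+#   logger.close()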
+# Modified from: +# https://github.com/anibali/pytorch-stacked-hourglass +# https://github.com/bearpaw/pytorch-pose + +import numpy as np + +__all__ = ['Logger'] + + +class Logger: + """Log training metrics to a file.""" + def __init__(self, fpath, resume=False): + if resume: ############################################################################ + # Read header names and previously logged values. + with open(fpath, 'r') as f: + header_line = f.readline() + self.names = header_line.rstrip().split('\t') + self.numbers = {} + for _, name in enumerate(self.names): + self.numbers[name] = [] + for numbers in f: + numbers = numbers.rstrip().split('\t') + for i in range(0, len(numbers)): + self.numbers[self.names[i]].append(float(numbers[i])) + + self.file = open(fpath, 'a') + self.header_written = True + else: + self.file = open(fpath, 'w') + self.header_written = False + + def _write_line(self, field_values): + self.file.write('\t'.join(field_values) + '\n') + self.file.flush() + + def set_names(self, names): + """Set field names and write log header line.""" + assert not self.header_written, 'Log header has already been written' + self.names = names + self.numbers = {name: [] for name in self.names} + self._write_line(self.names) + self.header_written = True + + def append(self, numbers): + """Append values to the log.""" + assert self.header_written, 'Log header has not been written yet (use `set_names`)' + assert len(self.names) == len(numbers), 'Numbers do not match names' + for index, num in enumerate(numbers): + self.numbers[self.names[index]].append(num) + self._write_line(['{0:.6f}'.format(num) for num in numbers]) + + def plot(self, ax, names=None): + """Plot logged metrics on a set of Matplotlib axes.""" + names = self.names if names == None else names + for name in names: + values = self.numbers[name] + ax.plot(np.arange(len(values)), np.asarray(values)) + ax.grid(True) + ax.legend(names, loc='best') + + def plot_to_file(self, fpath, names=None, dpi=150): + """Plot logged metrics and save the resulting figure to a file.""" + import matplotlib.pyplot as plt + fig = plt.figure(dpi=dpi) + ax = fig.subplots() + self.plot(ax, names) + fig.savefig(fpath) + plt.close(fig) + del ax, fig + + def close(self): + self.file.close() diff --git a/src/stacked_hourglass/utils/misc.py b/src/stacked_hourglass/utils/misc.py new file mode 100644 index 0000000000000000000000000000000000000000..d754c55dc2206bbb2a5cabf18c4017b5c1ee3d04 --- /dev/null +++ b/src/stacked_hourglass/utils/misc.py @@ -0,0 +1,56 @@ +# Modified from: +# https://github.com/anibali/pytorch-stacked-hourglass +# https://github.com/bearpaw/pytorch-pose + +import os +import shutil + +import scipy.io +import torch + + +def to_numpy(tensor): + if torch.is_tensor(tensor): + return tensor.detach().cpu().numpy() + elif type(tensor).__module__ != 'numpy': + raise ValueError("Cannot convert {} to numpy array" + .format(type(tensor))) + return tensor + + +def to_torch(ndarray): + if type(ndarray).__module__ == 'numpy': + return torch.from_numpy(ndarray) + elif not torch.is_tensor(ndarray): + raise ValueError("Cannot convert {} to torch tensor" + .format(type(ndarray))) + return ndarray + + +def save_checkpoint(state, preds, is_best, checkpoint='checkpoint', filename='checkpoint.pth.tar', snapshot=None): + preds = to_numpy(preds) + filepath = os.path.join(checkpoint, filename) + torch.save(state, filepath) + scipy.io.savemat(os.path.join(checkpoint, 'preds.mat'), mdict={'preds' : preds}) + + if snapshot and state['epoch'] % snapshot == 0: + 
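+ # in addition to the rolling checkpoint, keep a numbered copy every `snapshot` epochs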
shutil.copyfile(filepath, os.path.join(checkpoint, 'checkpoint_{}.pth.tar'.format(state['epoch']))) + + if is_best: + shutil.copyfile(filepath, os.path.join(checkpoint, 'model_best.pth.tar')) + scipy.io.savemat(os.path.join(checkpoint, 'preds_best.mat'), mdict={'preds' : preds}) + + +def save_pred(preds, checkpoint='checkpoint', filename='preds_valid.mat'): + preds = to_numpy(preds) + filepath = os.path.join(checkpoint, filename) + scipy.io.savemat(filepath, mdict={'preds' : preds}) + + +def adjust_learning_rate(optimizer, epoch, lr, schedule, gamma): + """Sets the learning rate to the initial LR decayed by schedule""" + if epoch in schedule: + lr *= gamma + for param_group in optimizer.param_groups: + param_group['lr'] = lr + return lr diff --git a/src/stacked_hourglass/utils/pilutil.py b/src/stacked_hourglass/utils/pilutil.py new file mode 100644 index 0000000000000000000000000000000000000000..4306a31e76581cf9a7dd9901b88be1a2df2a75f0 --- /dev/null +++ b/src/stacked_hourglass/utils/pilutil.py @@ -0,0 +1,509 @@ +""" +A collection of image utilities using the Python Imaging Library (PIL). +""" + +# Copyright (c) 2001-2002 Enthought, Inc. 2003-2019, SciPy Developers. +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# +# 1. Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following +# disclaimer in the documentation and/or other materials provided +# with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +from __future__ import division, print_function, absolute_import + +import numpy +from PIL import Image +from numpy import (amin, amax, ravel, asarray, arange, ones, newaxis, + transpose, iscomplexobj, uint8, issubdtype, array) + +if not hasattr(Image, 'frombytes'): + Image.frombytes = Image.fromstring + +__all__ = ['fromimage', 'toimage', 'imsave', 'imread', 'bytescale', + 'imrotate', 'imresize'] + + +def bytescale(data, cmin=None, cmax=None, high=255, low=0): + """ + Byte scales an array (image). + + Byte scaling means converting the input image to uint8 dtype and scaling + the range to ``(low, high)`` (default 0-255). + If the input image already has dtype uint8, no scaling is done. + + This function is only available if Python Imaging Library (PIL) is installed. 
+ + Parameters + ---------- + data : ndarray + PIL image data array. + cmin : scalar, optional + Bias scaling of small values. Default is ``data.min()``. + cmax : scalar, optional + Bias scaling of large values. Default is ``data.max()``. + high : scalar, optional + Scale max value to `high`. Default is 255. + low : scalar, optional + Scale min value to `low`. Default is 0. + + Returns + ------- + img_array : uint8 ndarray + The byte-scaled array. + + Examples + -------- + >>> img = numpy.array([[ 91.06794177, 3.39058326, 84.4221549 ], + ... [ 73.88003259, 80.91433048, 4.88878881], + ... [ 51.53875334, 34.45808177, 27.5873488 ]]) + >>> bytescale(img) + array([[255, 0, 236], + [205, 225, 4], + [140, 90, 70]], dtype=uint8) + >>> bytescale(img, high=200, low=100) + array([[200, 100, 192], + [180, 188, 102], + [155, 135, 128]], dtype=uint8) + >>> bytescale(img, cmin=0, cmax=255) + array([[91, 3, 84], + [74, 81, 5], + [52, 34, 28]], dtype=uint8) + + """ + if data.dtype == uint8: + return data + + if high > 255: + raise ValueError("`high` should be less than or equal to 255.") + if low < 0: + raise ValueError("`low` should be greater than or equal to 0.") + if high < low: + raise ValueError("`high` should be greater than or equal to `low`.") + + if cmin is None: + cmin = data.min() + if cmax is None: + cmax = data.max() + + cscale = cmax - cmin + if cscale < 0: + raise ValueError("`cmax` should be larger than `cmin`.") + elif cscale == 0: + cscale = 1 + + scale = float(high - low) / cscale + bytedata = (data - cmin) * scale + low + return (bytedata.clip(low, high) + 0.5).astype(uint8) + + +def imread(name, flatten=False, mode=None): + """ + Read an image from a file as an array. + + This function is only available if Python Imaging Library (PIL) is installed. + + Parameters + ---------- + name : str or file object + The file name or file object to be read. + flatten : bool, optional + If True, flattens the color layers into a single gray-scale layer. + mode : str, optional + Mode to convert image to, e.g. ``'RGB'``. See the Notes for more + details. + + Returns + ------- + imread : ndarray + The array obtained by reading the image. + + Notes + ----- + `imread` uses the Python Imaging Library (PIL) to read an image. + The following notes are from the PIL documentation. + + `mode` can be one of the following strings: + + * 'L' (8-bit pixels, black and white) + * 'P' (8-bit pixels, mapped to any other mode using a color palette) + * 'RGB' (3x8-bit pixels, true color) + * 'RGBA' (4x8-bit pixels, true color with transparency mask) + * 'CMYK' (4x8-bit pixels, color separation) + * 'YCbCr' (3x8-bit pixels, color video format) + * 'I' (32-bit signed integer pixels) + * 'F' (32-bit floating point pixels) + + PIL also provides limited support for a few special modes, including + 'LA' ('L' with alpha), 'RGBX' (true color with padding) and 'RGBa' + (true color with premultiplied alpha). + + When translating a color image to black and white (mode 'L', 'I' or + 'F'), the library uses the ITU-R 601-2 luma transform:: + + L = R * 299/1000 + G * 587/1000 + B * 114/1000 + + When `flatten` is True, the image is converted using mode 'F'. + When `mode` is not None and `flatten` is True, the image is first + converted according to `mode`, and the result is then flattened using + mode 'F'. + + """ + + im = Image.open(name) + return fromimage(im, flatten=flatten, mode=mode) + + +def imsave(name, arr, format=None): + """ + Save an array as an image. 
+ + This function is only available if Python Imaging Library (PIL) is installed. + + .. warning:: + + This function uses `bytescale` under the hood to rescale images to use + the full (0, 255) range if ``mode`` is one of ``None, 'L', 'P', 'l'``. + It will also cast data for 2-D images to ``uint32`` for ``mode=None`` + (which is the default). + + Parameters + ---------- + name : str or file object + Output file name or file object. + arr : ndarray, MxN or MxNx3 or MxNx4 + Array containing image values. If the shape is ``MxN``, the array + represents a grey-level image. Shape ``MxNx3`` stores the red, green + and blue bands along the last dimension. An alpha layer may be + included, specified as the last colour band of an ``MxNx4`` array. + format : str + Image format. If omitted, the format to use is determined from the + file name extension. If a file object was used instead of a file name, + this parameter should always be used. + + Examples + -------- + Construct an array of gradient intensity values and save to file: + + >>> x = numpy.zeros((255, 255), dtype=numpy.uint8) + >>> x[:] = numpy.arange(255) + >>> imsave('gradient.png', x) + + Construct an array with three colour bands (R, G, B) and store to file: + + >>> rgb = numpy.zeros((255, 255, 3), dtype=numpy.uint8) + >>> rgb[..., 0] = numpy.arange(255) + >>> rgb[..., 1] = 55 + >>> rgb[..., 2] = 1 - numpy.arange(255) + >>> imsave('rgb_gradient.png', rgb) + + """ + im = toimage(arr, channel_axis=2) + if format is None: + im.save(name) + else: + im.save(name, format) + return + + +def fromimage(im, flatten=False, mode=None): + """ + Return a copy of a PIL image as a numpy array. + + This function is only available if Python Imaging Library (PIL) is installed. + + Parameters + ---------- + im : PIL image + Input image. + flatten : bool + If true, convert the output to grey-scale. + mode : str, optional + Mode to convert image to, e.g. ``'RGB'``. See the Notes of the + `imread` docstring for more details. + + Returns + ------- + fromimage : ndarray + The different colour bands/channels are stored in the + third dimension, such that a grey-image is MxN, an + RGB-image MxNx3 and an RGBA-image MxNx4. + + """ + if not Image.isImageType(im): + raise TypeError("Input is not a PIL image.") + + if mode is not None: + if mode != im.mode: + im = im.convert(mode) + elif im.mode == 'P': + # Mode 'P' means there is an indexed "palette". If we leave the mode + # as 'P', then when we do `a = array(im)` below, `a` will be a 2-D + # containing the indices into the palette, and not a 3-D array + # containing the RGB or RGBA values. + if 'transparency' in im.info: + im = im.convert('RGBA') + else: + im = im.convert('RGB') + + if flatten: + im = im.convert('F') + elif im.mode == '1': + # Workaround for crash in PIL. When im is 1-bit, the call array(im) + # can cause a seg. fault, or generate garbage. See + # https://github.com/scipy/scipy/issues/2138 and + # https://github.com/python-pillow/Pillow/issues/350. + # + # This converts im from a 1-bit image to an 8-bit image. + im = im.convert('L') + + a = array(im) + return a + + +_errstr = "Mode is unknown or incompatible with input array shape." + + +def toimage(arr, high=255, low=0, cmin=None, cmax=None, pal=None, + mode=None, channel_axis=None): + """Takes a numpy array and returns a PIL image. + + This function is only available if Python Imaging Library (PIL) is installed. + + The mode of the PIL image depends on the array shape and the `pal` and + `mode` keywords. 
+ + For 2-D arrays, if `pal` is a valid (N,3) byte-array giving the RGB values + (from 0 to 255) then ``mode='P'``, otherwise ``mode='L'``, unless mode + is given as 'F' or 'I' in which case a float and/or integer array is made. + + .. warning:: + + This function uses `bytescale` under the hood to rescale images to use + the full (0, 255) range if ``mode`` is one of ``None, 'L', 'P', 'l'``. + It will also cast data for 2-D images to ``uint32`` for ``mode=None`` + (which is the default). + + Notes + ----- + For 3-D arrays, the `channel_axis` argument tells which dimension of the + array holds the channel data. + + For 3-D arrays if one of the dimensions is 3, the mode is 'RGB' + by default or 'YCbCr' if selected. + + The numpy array must be either 2 dimensional or 3 dimensional. + + """ + data = asarray(arr) + if iscomplexobj(data): + raise ValueError("Cannot convert a complex-valued array.") + shape = list(data.shape) + valid = len(shape) == 2 or ((len(shape) == 3) and + ((3 in shape) or (4 in shape))) + if not valid: + raise ValueError("'arr' does not have a suitable array shape for " + "any mode.") + if len(shape) == 2: + shape = (shape[1], shape[0]) # columns show up first + if mode == 'F': + data32 = data.astype(numpy.float32) + image = Image.frombytes(mode, shape, data32.tostring()) + return image + if mode in [None, 'L', 'P']: + bytedata = bytescale(data, high=high, low=low, + cmin=cmin, cmax=cmax) + image = Image.frombytes('L', shape, bytedata.tostring()) + if pal is not None: + image.putpalette(asarray(pal, dtype=uint8).tostring()) + # Becomes a mode='P' automagically. + elif mode == 'P': # default gray-scale + pal = (arange(0, 256, 1, dtype=uint8)[:, newaxis] * + ones((3,), dtype=uint8)[newaxis, :]) + image.putpalette(asarray(pal, dtype=uint8).tostring()) + return image + if mode == '1': # high input gives threshold for 1 + bytedata = (data > high) + image = Image.frombytes('1', shape, bytedata.tostring()) + return image + if cmin is None: + cmin = amin(ravel(data)) + if cmax is None: + cmax = amax(ravel(data)) + data = (data*1.0 - cmin)*(high - low)/(cmax - cmin) + low + if mode == 'I': + data32 = data.astype(numpy.uint32) + image = Image.frombytes(mode, shape, data32.tostring()) + else: + raise ValueError(_errstr) + return image + + # if here then 3-d array with a 3 or a 4 in the shape length. 
+ # Check for 3 in datacube shape --- 'RGB' or 'YCbCr' + if channel_axis is None: + if (3 in shape): + ca = numpy.flatnonzero(asarray(shape) == 3)[0] + else: + ca = numpy.flatnonzero(asarray(shape) == 4) + if len(ca): + ca = ca[0] + else: + raise ValueError("Could not find channel dimension.") + else: + ca = channel_axis + + numch = shape[ca] + if numch not in [3, 4]: + raise ValueError("Channel axis dimension is not valid.") + + bytedata = bytescale(data, high=high, low=low, cmin=cmin, cmax=cmax) + if ca == 2: + strdata = bytedata.tobytes() # .tostring() + shape = (shape[1], shape[0]) + elif ca == 1: + strdata = transpose(bytedata, (0, 2, 1)).tobytes() #.tostring() + shape = (shape[2], shape[0]) + elif ca == 0: + strdata = transpose(bytedata, (1, 2, 0)).tobytes() #.tostring() + shape = (shape[2], shape[1]) + else: + raise ValueError("Unexpected channel axis.") + if mode is None: + if numch == 3: + mode = 'RGB' + else: + mode = 'RGBA' + + if mode not in ['RGB', 'RGBA', 'YCbCr', 'CMYK']: + raise ValueError(_errstr) + + if mode in ['RGB', 'YCbCr']: + if numch != 3: + raise ValueError("Invalid array shape for mode.") + if mode in ['RGBA', 'CMYK']: + if numch != 4: + raise ValueError("Invalid array shape for mode.") + + # Here we know data and mode is correct + image = Image.frombytes(mode, shape, strdata) + return image + + +def imrotate(arr, angle, interp='bilinear'): + """ + Rotate an image counter-clockwise by angle degrees. + + This function is only available if Python Imaging Library (PIL) is installed. + + .. warning:: + + This function uses `bytescale` under the hood to rescale images to use + the full (0, 255) range if ``mode`` is one of ``None, 'L', 'P', 'l'``. + It will also cast data for 2-D images to ``uint32`` for ``mode=None`` + (which is the default). + + Parameters + ---------- + arr : ndarray + Input array of image to be rotated. + angle : float + The angle of rotation. + interp : str, optional + Interpolation + + - 'nearest' : for nearest neighbor + - 'bilinear' : for bilinear + - 'lanczos' : for lanczos + - 'cubic' : for bicubic + - 'bicubic' : for bicubic + + Returns + ------- + imrotate : ndarray + The rotated array of image. + + """ + arr = asarray(arr) + func = {'nearest': 0, 'lanczos': 1, 'bilinear': 2, 'bicubic': 3, 'cubic': 3} + im = toimage(arr) + im = im.rotate(angle, resample=func[interp]) + return fromimage(im) + + +def imresize(arr, size, interp='bilinear', mode=None): + """ + Resize an image. + + This function is only available if Python Imaging Library (PIL) is installed. + + .. warning:: + + This function uses `bytescale` under the hood to rescale images to use + the full (0, 255) range if ``mode`` is one of ``None, 'L', 'P', 'l'``. + It will also cast data for 2-D images to ``uint32`` for ``mode=None`` + (which is the default). + + Parameters + ---------- + arr : ndarray + The array of image to be resized. + size : int, float or tuple + * int - Percentage of current size. + * float - Fraction of current size. + * tuple - Size of the output image (height, width). + + interp : str, optional + Interpolation to use for re-sizing ('nearest', 'lanczos', 'bilinear', + 'bicubic' or 'cubic'). + mode : str, optional + The PIL image mode ('P', 'L', etc.) to convert `arr` before resizing. + If ``mode=None`` (the default), 2-D images will be treated like + ``mode='L'``, i.e. casting to long integer. For 3-D and 4-D arrays, + `mode` will be set to ``'RGB'`` and ``'RGBA'`` respectively. + + Returns + ------- + imresize : ndarray + The resized array of image. 
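+ 
+ Examples
+ --------
+ Illustrative only::
+ 
+ resized = imresize(img, 50)          # 50 percent of the current size
+ resized = imresize(img, 0.5)         # the same, given as a fraction
+ resized = imresize(img, (256, 256))  # explicit (height, width)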
+ + See Also + -------- + toimage : Implicitly used to convert `arr` according to `mode`. + scipy.ndimage.zoom : More generic implementation that does not use PIL. + + """ + im = toimage(arr, mode=mode) + ts = type(size) + if issubdtype(ts, numpy.signedinteger): + percent = size / 100.0 + size = tuple((array(im.size)*percent).astype(int)) + elif issubdtype(type(size), numpy.floating): + size = tuple((array(im.size)*size).astype(int)) + else: + size = (size[1], size[0]) + func = {'nearest': 0, 'lanczos': 1, 'bilinear': 2, 'bicubic': 3, 'cubic': 3} + imnew = im.resize(size, resample=func[interp]) + return fromimage(imnew) diff --git a/src/stacked_hourglass/utils/transforms.py b/src/stacked_hourglass/utils/transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..7777e02f7a78e282c9032cb76325bafbbb16a5be --- /dev/null +++ b/src/stacked_hourglass/utils/transforms.py @@ -0,0 +1,150 @@ +# Modified from: +# https://github.com/anibali/pytorch-stacked-hourglass +# https://github.com/bearpaw/pytorch-pose + +import numpy as np +import torch + +from .imutils import im_to_numpy, im_to_torch +from .misc import to_torch +from .pilutil import imresize, imrotate + + +def color_normalize(x, mean, std): + if x.size(0) == 1: + x = x.repeat(3, 1, 1) + + for t, m, s in zip(x, mean, std): + t.sub_(m) + return x + + +def flip_back(flip_output, hflip_indices): + """flip and rearrange output maps""" + return fliplr(flip_output)[:, hflip_indices] + + +def shufflelr(x, width, hflip_indices): + """flip and rearrange coords""" + # Flip horizontal + x[:, 0] = width - x[:, 0] + # Change left-right parts + x = x[hflip_indices] + return x + + +def fliplr(x): + """Flip images horizontally.""" + if torch.is_tensor(x): + return torch.flip(x, [-1]) + else: + return np.ascontiguousarray(np.flip(x, -1)) + + +def get_transform(center, scale, res, rot=0): + """ + General image processing functions + """ + # Generate transformation matrix + h = 200 * scale + t = np.zeros((3, 3)) + t[0, 0] = float(res[1]) / h + t[1, 1] = float(res[0]) / h + t[0, 2] = res[1] * (-float(center[0]) / h + .5) + t[1, 2] = res[0] * (-float(center[1]) / h + .5) + t[2, 2] = 1 + if not rot == 0: + rot = -rot # To match direction of rotation from cropping + rot_mat = np.zeros((3,3)) + rot_rad = rot * np.pi / 180 + sn,cs = np.sin(rot_rad), np.cos(rot_rad) + rot_mat[0,:2] = [cs, -sn] + rot_mat[1,:2] = [sn, cs] + rot_mat[2,2] = 1 + # Need to rotate around center + t_mat = np.eye(3) + t_mat[0,2] = -res[1]/2 + t_mat[1,2] = -res[0]/2 + t_inv = t_mat.copy() + t_inv[:2,2] *= -1 + t = np.dot(t_inv,np.dot(rot_mat,np.dot(t_mat,t))) + return t + + +def transform(pt, center, scale, res, invert=0, rot=0, as_int=True): + # Transform pixel location to different reference + t = get_transform(center, scale, res, rot=rot) + if invert: + t = np.linalg.inv(t) + new_pt = np.array([pt[0] - 1, pt[1] - 1, 1.]).T + new_pt = np.dot(t, new_pt) + if as_int: + return new_pt[:2].astype(int) + 1 + else: + return new_pt[:2] + 1 + + + +def transform_preds(coords, center, scale, res): + # size = coords.size() + # coords = coords.view(-1, coords.size(-1)) + # print(coords.size()) + for p in range(coords.size(0)): + coords[p, 0:2] = to_torch(transform(coords[p, 0:2], center, scale, res, 1, 0)) + return coords + + +def crop(img, center, scale, res, rot=0, interp='bilinear'): + # import pdb; pdb.set_trace() + # mode = 'F' + + img = im_to_numpy(img) + + # Preprocessing for efficient cropping + ht, wd = img.shape[0], img.shape[1] + sf = scale * 200.0 / res[0] + if sf < 2: + sf 
= 1
+    else:
+        new_size = int(np.floor(max(ht, wd) / sf))
+        new_ht = int(np.floor(ht / sf))
+        new_wd = int(np.floor(wd / sf))
+        if new_size < 2:
+            return torch.zeros(res[0], res[1], img.shape[2]) \
+                        if len(img.shape) > 2 else torch.zeros(res[0], res[1])
+        else:
+            img = imresize(img, [new_ht, new_wd], interp=interp)  # , mode=mode)
+            center = center * 1.0 / sf
+            scale = scale / sf
+
+    # Upper left point
+    ul = np.array(transform([0, 0], center, scale, res, invert=1))
+    # Bottom right point
+    br = np.array(transform(res, center, scale, res, invert=1))
+
+    # Padding so that the proper amount of context is included when rotating
+    pad = int(np.linalg.norm(br - ul) / 2 - float(br[1] - ul[1]) / 2)
+    if not rot == 0:
+        ul -= pad
+        br += pad
+
+    new_shape = [br[1] - ul[1], br[0] - ul[0]]
+    if len(img.shape) > 2:
+        new_shape += [img.shape[2]]
+    new_img = np.zeros(new_shape)
+
+    # Range to fill new array
+    new_x = max(0, -ul[0]), min(br[0], img.shape[1]) - ul[0]
+    new_y = max(0, -ul[1]), min(br[1], img.shape[0]) - ul[1]
+    # Range to sample from original image
+    old_x = max(0, ul[0]), min(img.shape[1], br[0])
+    old_y = max(0, ul[1]), min(img.shape[0], br[1])
+    new_img[new_y[0]:new_y[1], new_x[0]:new_x[1]] = img[old_y[0]:old_y[1], old_x[0]:old_x[1]]
+
+    if not rot == 0:
+        # Rotate, then remove the padding that was added above
+        new_img = imrotate(new_img, rot, interp=interp)  # , mode=mode)
+        new_img = new_img[pad:-pad, pad:-pad]
+
+    new_img = im_to_torch(imresize(new_img, res, interp=interp))  # , mode=mode))
+    return new_img
diff --git a/src/stacked_hourglass/utils/visualization.py b/src/stacked_hourglass/utils/visualization.py
new file mode 100644
index 0000000000000000000000000000000000000000..4487e7f10c348af91b3958081f6f029308440772
--- /dev/null
+++ b/src/stacked_hourglass/utils/visualization.py
@@ -0,0 +1,180 @@
+
+# Modified from:
+# https://github.com/anibali/pytorch-stacked-hourglass
+# https://github.com/bearpaw/pytorch-pose
+
+import io
+import matplotlib as mpl
+mpl.use('Agg')
+import matplotlib.pyplot as plt
+import numpy as np
+import cv2
+import torch
+
+# import stacked_hourglass.datasets.utils_stanext as utils_stanext
+# COLORS, labels = utils_stanext.load_keypoint_labels_and_colours()
+COLORS = ['#d82400', '#d82400', '#d82400', '#fcfc00', '#fcfc00', '#fcfc00', '#48b455', '#48b455', '#48b455', '#0090aa', '#0090aa', '#0090aa', '#d848ff', '#d848ff', '#fc90aa', '#006caa', '#d89000', '#d89000', '#fc90aa', '#006caa', '#ededed', '#ededed', '#a9d08e', '#a9d08e']
+RGB_MEAN = [0.4404, 0.4440, 0.4327]
+RGB_STD = [0.2458, 0.2410, 0.2468]
+
+
+
+def get_img_from_fig(fig, dpi=180):
+    buf = io.BytesIO()
+    fig.savefig(buf, format="png", dpi=dpi)
+    buf.seek(0)
+    img_arr = np.frombuffer(buf.getvalue(), dtype=np.uint8)
+    buf.close()
+    img = cv2.imdecode(img_arr, 1)
+    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+    return img
+
+def save_input_image_with_keypoints(img, tpts, out_path='./test_input.png', colors=COLORS, rgb_mean=RGB_MEAN, rgb_std=RGB_STD, ratio_in_out=4., threshold=0.3, print_scores=False):
+    """
+    img has shape (3, 256, 256) and is a torch tensor
+    tpts has shape (20, 3) and is a torch tensor
+    -> this function is tested with the mpii dataset and the results look ok
+    """
+    # reverse color normalization
+    for t, m, s in zip(img, rgb_mean, rgb_std): t.add_(m)  # inverse to transforms.color_normalize()
+    img_np = img.detach().cpu().numpy().transpose(1, 2, 0)
+    # tpts_np = tpts.detach().cpu().numpy()
+    # plot image
+    fig, ax = plt.subplots()
+    plt.imshow(img_np)  # plt.imshow(im)
+    plt.gca().set_axis_off()
+    plt.subplots_adjust(top = 1,
bottom = 0, right = 1, left = 0, hspace = 0, wspace = 0) + plt.margins(0,0) + # plot all visible keypoints + #import pdb; pdb.set_trace() + + for idx, (x, y, v) in enumerate(tpts): + if v > threshold: + x = int(x*ratio_in_out) + y = int(y*ratio_in_out) + plt.scatter([x], [y], c=[colors[idx]], marker="x", s=50) + if print_scores: + txt = '{:2.2f}'.format(v.item()) + plt.annotate(txt, (x, y)) # , c=colors[idx]) + + plt.savefig(out_path, bbox_inches='tight', pad_inches=0) + + plt.close() + return + + + +def save_input_image(img, out_path, colors=COLORS, rgb_mean=RGB_MEAN, rgb_std=RGB_STD): + for t, m, s in zip(img, rgb_mean, rgb_std): t.add_(m) # inverse to transforms.color_normalize() + img_np = img.detach().cpu().numpy().transpose(1, 2, 0) + plt.imsave(out_path, img_np) + return + +###################################################################### +def get_bodypart_colors(): + # body colors + n_body = 8 + c = np.arange(1, n_body + 1) + norm = mpl.colors.Normalize(vmin=c.min(), vmax=c.max()) + cmap = mpl.cm.ScalarMappable(norm=norm, cmap=mpl.cm.gist_rainbow) + cmap.set_array([]) + body_cols = [] + for i in range(0, n_body): + body_cols.append(cmap.to_rgba(i + 1)) + # head colors + n_blue = 5 + c = np.arange(1, n_blue + 1) + norm = mpl.colors.Normalize(vmin=c.min()-1, vmax=c.max()+1) + cmap = mpl.cm.ScalarMappable(norm=norm, cmap=mpl.cm.Blues) + cmap.set_array([]) + head_cols = [] + for i in range(0, n_body): + head_cols.append(cmap.to_rgba(i + 1)) + # torso colors + n_blue = 2 + c = np.arange(1, n_blue + 1) + norm = mpl.colors.Normalize(vmin=c.min()-1, vmax=c.max()+1) + cmap = mpl.cm.ScalarMappable(norm=norm, cmap=mpl.cm.Greens) + cmap.set_array([]) + torso_cols = [] + for i in range(0, n_body): + torso_cols.append(cmap.to_rgba(i + 1)) + return body_cols, head_cols, torso_cols +body_cols, head_cols, torso_cols = get_bodypart_colors() +tbp_dict = {'full_body': [0, 8], + 'head': [8, 13], + 'torso': [13, 15]} + +def save_image_with_part_segmentation(partseg_big, seg_big, input_image_np, ind_img, out_path_seg=None, out_path_seg_overlay=None, thr=0.3): + soft_max = torch.nn.Softmax(dim=0) + # create dit with results + tbp_dict_res = {} + for ind_tbp, part in enumerate(['full_body', 'head', 'torso']): + partseg_tbp = partseg_big[:, tbp_dict[part][0]:tbp_dict[part][1], :, :] + segm_img_pred = soft_max((partseg_tbp[ind_img, :, :, :])) # [1, :, :] + m_v, m_i = segm_img_pred.max(axis=0) + tbp_dict_res[part] = { + 'inds': tbp_dict[part], + 'seg_probs': segm_img_pred, + 'seg_max_inds': m_i, + 'seg_max_values': m_v} + # create output_image + partseg_image = np.zeros((256, 256, 3)) + for ind_sp in range(0, 5): + # partseg_image[tbp_dict_res['head']['seg_max_inds']==ind_sp, :] = head_cols[ind_sp][0:3] + mask_a = tbp_dict_res['full_body']['seg_max_inds']==1 + mask_b = tbp_dict_res['head']['seg_max_inds']==ind_sp + partseg_image[mask_a*mask_b, :] = head_cols[ind_sp][0:3] + for ind_sp in range(0, 2): + # partseg_image[tbp_dict_res['torso']['seg_max_inds']==ind_sp, :] = torso_cols[ind_sp][0:3] + mask_a = tbp_dict_res['full_body']['seg_max_inds']==2 + mask_b = tbp_dict_res['torso']['seg_max_inds']==ind_sp + partseg_image[mask_a*mask_b, :] = torso_cols[ind_sp][0:3] + for ind_sp in range(0, 8): + if (not ind_sp == 1) and (not ind_sp == 2): # head and torso + partseg_image[tbp_dict_res['full_body']['seg_max_inds']==ind_sp, :] = body_cols[ind_sp][0:3] + partseg_image[soft_max((seg_big[ind_img, :, :, :]))[1, :, :]