|
|
|
|
|
|
| import torch |
| import torch.nn as nn |
| from clip import clip |
|
|
|
|
def article(name):
    """Return the English indefinite article ("a" or "an") for *name*.

    Only the first character is inspected: a vowel letter (a, e, i, o, u) in
    either upper or lower case yields "an"; anything else yields "a".

    Args:
        name: The noun phrase the article will precede.

    Returns:
        "an" or "a". An empty *name* yields "a" instead of raising.
    """
    # Slicing (rather than indexing) tolerates an empty string, and the tuple
    # membership test avoids the substring pitfall where `"" in "aeiou"` is True.
    first_letter = name[:1].lower()
    return "an" if first_letter in ("a", "e", "i", "o", "u") else "a"
|
|
|
|
def processed_name(name, rm_dot=False):
    """Normalize a category name for insertion into a text prompt.

    Underscores become spaces, each "/" becomes " or ", and the result is
    lower-cased.

    Args:
        name: Raw category name.
        rm_dot: When true, trailing "." characters are stripped as well.

    Returns:
        The normalized name.
    """
    # One translation pass handles both single-character substitutions.
    substitutions = str.maketrans({"_": " ", "/": " or "})
    normalized = name.translate(substitutions).lower()
    return normalized.rstrip(".") if rm_dot else normalized
|
|
|
|
# Single-prompt template list; "{}" is the placeholder for the category name.
single_template = ["a photo of a {}."]

# Prompt templates used to build an ensemble of texts per category.
# "{}" receives the category name and "{article}" receives "a" or "an".
multiple_templates = [
    # Prompts that place the object in a scene.
    "There is {article} {} in the scene.",
    "There is the {} in the scene.",
    "a photo of {article} {} in the scene.",
    "a photo of the {} in the scene.",
    "a photo of one {} in the scene.",
    # "itap of ..." prompts.
    "itap of {article} {}.",
    "itap of my {}.",
    "itap of the {}.",
    # Plain photo prompts.
    "a photo of {article} {}.",
    "a photo of my {}.",
    "a photo of the {}.",
    "a photo of one {}.",
    "a photo of many {}.",
    # Good / bad photo prompts.
    "a good photo of {article} {}.",
    "a good photo of the {}.",
    "a bad photo of {article} {}.",
    "a bad photo of the {}.",
    # Prompts with an adjective describing the object.
    "a photo of a nice {}.",
    "a photo of the nice {}.",
    "a photo of a cool {}.",
    "a photo of the cool {}.",
    "a photo of a weird {}.",
    "a photo of the weird {}.",
    "a photo of a small {}.",
    "a photo of the small {}.",
    "a photo of a large {}.",
    "a photo of the large {}.",
    "a photo of a clean {}.",
    "a photo of the clean {}.",
    "a photo of a dirty {}.",
    "a photo of the dirty {}.",
    # Prompts describing the photo's lighting or image quality.
    "a bright photo of {article} {}.",
    "a bright photo of the {}.",
    "a dark photo of {article} {}.",
    "a dark photo of the {}.",
    "a photo of a hard to see {}.",
    "a photo of the hard to see {}.",
    "a low resolution photo of {article} {}.",
    "a low resolution photo of the {}.",
    "a cropped photo of {article} {}.",
    "a cropped photo of the {}.",
    "a close-up photo of {article} {}.",
    "a close-up photo of the {}.",
    "a jpeg corrupted photo of {article} {}.",
    "a jpeg corrupted photo of the {}.",
    "a blurry photo of {article} {}.",
    "a blurry photo of the {}.",
    "a pixelated photo of {article} {}.",
    "a pixelated photo of the {}.",
    "a black and white photo of the {}.",
    "a black and white photo of {article} {}.",
    # Non-photographic renderings of the object.
    "a plastic {}.",
    "the plastic {}.",
    "a toy {}.",
    "the toy {}.",
    "a plushie {}.",
    "the plushie {}.",
    "a cartoon {}.",
    "the cartoon {}.",
    "an embroidered {}.",
    "the embroidered {}.",
    "a painting of the {}.",
    "a painting of a {}.",
]
|
|
|
|
# Default category names; build_openset_label_embedding falls back to this
# list when it is called without explicit categories.
# NOTE(review): the variable name suggests these are rare/unseen OpenImages
# classes -- confirm against the dataset source.
openimages_rare_unseen = [
    "Aerial photography",
    "Aircraft engine",
    "Ale",
    "Aloe",
    "Amphibian",
    "Angling",
    "Anole",
    "Antique car",
    "Arcade game",
    "Arthropod",
    "Assault rifle",
    "Athletic shoe",
    "Auto racing",
    "Backlighting",
    "Bagpipes",
    "Ball game",
    "Barbecue chicken",
    "Barechested",
    "Barquentine",
    "Beef tenderloin",
    "Billiard room",
    "Billiards",
    "Bird of prey",
    "Black swan",
    "Black-and-white",
    "Blond",
    "Boating",
    "Bonbon",
    "Bottled water",
    "Bouldering",
    "Bovine",
    "Bratwurst",
    "Breadboard",
    "Briefs",
    "Brisket",
    "Brochette",
    "Calabaza",
    "Camera operator",
    "Canola",
    "Childbirth",
    "Chordophone",
    "Church bell",
    "Classical sculpture",
    "Close-up",
    "Cobblestone",
    "Coca-cola",
    "Combat sport",
    "Comics",
    "Compact car",
    "Computer speaker",
    "Cookies and crackers",
    "Coral reef fish",
    "Corn on the cob",
    "Cosmetics",
    "Crocodilia",
    "Digital camera",
    "Dishware",
    "Divemaster",
    "Dobermann",
    "Dog walking",
    "Domestic rabbit",
    "Domestic short-haired cat",
    "Double-decker bus",
    "Drums",
    "Electric guitar",
    "Electric piano",
    "Electronic instrument",
    "Equestrianism",
    "Equitation",
    "Erinaceidae",
    "Extreme sport",
    "Falafel",
    "Figure skating",
    "Filling station",
    "Fire apparatus",
    "Firearm",
    "Flatbread",
    "Floristry",
    "Forklift truck",
    "Freight transport",
    "Fried food",
    "Fried noodles",
    "Frigate",
    "Frozen yogurt",
    "Frying",
    "Full moon",
    "Galleon",
    "Glacial landform",
    "Gliding",
    "Go-kart",
    "Goats",
    "Grappling",
    "Great white shark",
    "Gumbo",
    "Gun turret",
    "Hair coloring",
    "Halter",
    "Headphones",
    "Heavy cruiser",
    "Herding",
    "High-speed rail",
    "Holding hands",
    "Horse and buggy",
    "Horse racing",
    "Hound",
    "Hunting knife",
    "Hurdling",
    "Inflatable",
    "Jackfruit",
    "Jeans",
    "Jiaozi",
    "Junk food",
    "Khinkali",
    "Kitesurfing",
    "Lawn game",
    "Leaf vegetable",
    "Lechon",
    "Lifebuoy",
    "Locust",
    "Lumpia",
    "Luxury vehicle",
    "Machine tool",
    "Medical imaging",
    "Melee weapon",
    "Microcontroller",
    "Middle ages",
    "Military person",
    "Military vehicle",
    "Milky way",
    "Miniature Poodle",
    "Modern dance",
    "Molluscs",
    "Monoplane",
    "Motorcycling",
    "Musical theatre",
    "Narcissus",
    "Nest box",
    "Newsagent's shop",
    "Nile crocodile",
    "Nordic skiing",
    "Nuclear power plant",
    "Orator",
    "Outdoor shoe",
    "Parachuting",
    "Pasta salad",
    "Peafowl",
    "Pelmeni",
    "Perching bird",
    "Performance car",
    "Personal water craft",
    "Pit bull",
    "Plant stem",
    "Pork chop",
    "Portrait photography",
    "Primate",
    "Procyonidae",
    "Prosciutto",
    "Public speaking",
    "Racewalking",
    "Ramen",
    "Rear-view mirror",
    "Residential area",
    "Ribs",
    "Rice ball",
    "Road cycling",
    "Roller skating",
    "Roman temple",
    "Rowing",
    "Rural area",
    "Sailboat racing",
    "Scaled reptile",
    "Scuba diving",
    "Senior citizen",
    "Shallot",
    "Shinto shrine",
    "Shooting range",
    "Siberian husky",
    "Sledding",
    "Soba",
    "Solar energy",
    "Sport climbing",
    "Sport utility vehicle",
    "Steamed rice",
    "Stemware",
    "Sumo",
    "Surfing Equipment",
    "Team sport",
    "Touring car",
    "Toy block",
    "Trampolining",
    "Underwater diving",
    "Vegetarian food",
    "Wallaby",
    "Water polo",
    "Watercolor paint",
    "Whiskers",
    "Wind wave",
    "Woodwind instrument",
    "Yakitori",
    "Zeppelin",
]
|
|
|
|
def build_openset_label_embedding(categories=None):
    """Build one CLIP text embedding per category from an ensemble of prompts.

    Each category name is normalized with ``processed_name``, substituted into
    every template of ``multiple_templates``, tokenized and encoded with the
    CLIP text encoder. The per-prompt embeddings are L2-normalized, averaged,
    and the average is normalized again to give one vector per category.

    Args:
        categories: Sequence of category names. When ``None``, the
            module-level ``openimages_rare_unseen`` list is used.

    Returns:
        A tuple ``(openset_label_embedding, categories)``. The embeddings are
        stacked so that the i-th row corresponds to the i-th category
        (assuming ``encode_text`` returns one row per prompt -- TODO confirm
        against the ``clip`` module in use). ``categories`` is returned as
        received (or the default list).

    Note:
        An empty ``categories`` is not supported: there is nothing to stack.
    """
    if categories is None:
        categories = openimages_rare_unseen

    model, _ = clip.load("ViT-B-16.pt")
    templates = multiple_templates

    run_on_gpu = torch.cuda.is_available()
    if run_on_gpu:
        # Move the model once up front instead of once per category.
        model = model.cuda()

    with torch.no_grad():
        openset_label_embedding = []
        for category in categories:
            # Both values are invariant across templates, so compute them once.
            name = processed_name(category, rm_dot=True)
            # Choose the article from the lower-cased name so capitalized
            # categories such as "Ale" get "an" rather than "a".
            name_article = article(category.lower())

            texts = [
                template.format(name, article=name_article)
                for template in templates
            ]
            # Prompts that begin with "a"/"the" are bare noun phrases; prefix
            # them so every prompt reads as a sentence.
            texts = [
                "This is " + text if text.startswith(("a", "the")) else text
                for text in texts
            ]

            tokens = clip.tokenize(texts)
            if run_on_gpu:
                tokens = tokens.cuda()

            text_embeddings = model.encode_text(tokens)
            # Normalize each prompt embedding, average over prompts, then
            # normalize the averaged vector.
            text_embeddings /= text_embeddings.norm(dim=-1, keepdim=True)
            text_embedding = text_embeddings.mean(dim=0)
            text_embedding /= text_embedding.norm()
            openset_label_embedding.append(text_embedding)

        # Stack the category vectors as columns; transposed below into rows.
        openset_label_embedding = torch.stack(openset_label_embedding, dim=1)
        if run_on_gpu:
            openset_label_embedding = openset_label_embedding.cuda()

    openset_label_embedding = openset_label_embedding.t()
    return openset_label_embedding, categories
|
|
|
|
|
|
|
|
|
|