Spaces:
Runtime error
Runtime error
Add source code
Browse files- __init__.py +5 -0
- answer_list.json +1 -0
- configs/retrieval.yaml +73 -0
- configs/vqa.yaml +78 -0
- data/__init__.py +5 -0
- data/retrieval_datamodule.py +188 -0
- data/retrieval_dataset.py +149 -0
- data/transforms.py +139 -0
- data/vqa_datamodules.py +206 -0
- data/vqa_dataset.py +115 -0
- finetune_retrieval.py +400 -0
- finetune_vqa.py +204 -0
- images/COCO_val2014_000000026348.jpg +0 -0
- images/COCO_val2014_000000057222.jpg +0 -0
- images/COCO_val2014_000000111207.jpg +0 -0
- images/COCO_val2014_000000159269.jpg +0 -0
- images/COCO_val2014_000000184359.jpg +0 -0
- images/COCO_val2014_000000407072.jpg +0 -0
- images/COCO_val2014_000000473994.jpg +0 -0
- images/COCO_val2014_000000552075.jpg +0 -0
- model.py +666 -0
- requirements.txt +5 -0
- utils.py +127 -0
- vqa_data.json +1 -0
__init__.py
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the BSD-style license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
answer_list.json
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
["net", "pitcher", "orange", "yes", "white", "skiing", "red", "frisbee", "brushing teeth", "no", "black and white", "skateboard", "1", "blue", "green", "motorcycle", "gray", "2", "purse", "skis", "poles", "surfboard", "dog", "on", "office", "large", "very big", "laptop", "vent", "computer", "black", "bear", "3", "wii", "glasses", "tree", "eating", "log", "5", "raft", "left", "living room", "pink", "right", "railing", "grass", "wire", "10 years", "knife", "cake", "banana", "chef", "vanilla", "4", "outdoor", "mustard", "bun", "clouds", "dock", "brown", "silver", "refrigerator", "square", "teddy", "elm", "stripes", "baseball", "catcher", "beer", "bottom", "north", "nike", "yellow and white", "morning", "elephant", "red and white", "propeller", "tan", "wall", "rolex", "clock", "table", "0", "wood", "christmas", "spinach", "thick", "bag", "leaves", "necklace", "6", "bathroom", "shower", "towel", "solid", "referee", "wilson", "8:00", "e", "24", "hat", "grazing", "sheep", "10", "tag", "spanish", "hot dog", "plate", "lunch", "butter", "peppers", "onions", "very", "mayonnaise", "mayo", "sweet potato", "pig", "sweet", "flowers", "floral", "yellow", "window", "7", "pizza", "car", "cargo", "stairs", "abstract", "rug", "baseball cap", "texting", "pole", "crosswalk", "nothing", "urban", "bus", "light", "afternoon", "boat", "cheese", "paper", "real", "sun", "birthday", "words", "inside", "shadows", "tomato", "evergreen", "100 feet", "shingles", "trees", "building", "hay", "ski pole", "patterned", "walking", "ice", "laundry", "pepsi", "good", "1:50", "purple", "13", "africa", "teddy bears", "socks", "giraffe", "soccer", "blue and yellow", "zebras", "cupcake", "broccoli", "soldier", "parking lot", "cows", "herding", "on table", "fish", "nightstand", "50", "overcast", "cross", "toaster oven", "tile", "11:55", "red and yellow", "nowhere", "hair dryer", "truck", "11", "people", "rectangle", "hot dogs", "party", "12:55", "apron", "kitchen", "cooking", "ring", "1 way", "stop", "neither", "many", "female", "brushing", "tie", "tennis racket", "knife and fork", "restaurant", "cat", "bed", "sand", "ocean", "cold", "kites", "cumulus", "standing", "male", "star", "tracks", "chocolate", "round", "fork and knife", "yankees", "pictures", "dots", "bird", "parrot", "red white and blue", "man", "metal", "fence", "snowboarding", "pine", "snow", "shorts", "swim", "wine", "brick", "no parking", "children", "beef", "phone", "english", "cell phone", "pink and yellow", "clear", "watermelon", "bedroom", "fork", "cow", "rackets", "tennis rackets", "8", "collar", "tennis", "1950s", "playing tennis", "skirt", "30", "polka dot", "beach", "horse", "grill", "african american", "down", "street", "in air", "sweater", "yellow and blue", "park", "backyard", "spectators", "parasailing", "31", "river", "55", "shadow", "winter", "chicken", "tea", "evening", "dusk", "ski resort", "helmet", "penne", "bench", "resting", "elephants", "southwest", "usa", "cars", "town", "bananas", "umbrella", "container", "woman", "on counter", "salad", "striped", "motel", "vertical", "oranges", "hot sauce", "bottle", "juice", "eyes", "ground", "backpack", "black and yellow", "forward", "jackets", "1 on right", "green and yellow", "playing baseball", "riding", "sitting", "carrot", "basket", "seagull", "ski poles", "p", "parking", "street light", "mets", "strap", "bike", "riding bike", "poodle", "shoes", "carpet", "lettuce", "food", "1 foot", "roses", "mountains", "scissors", "camera", "beige", "beard", "cutting", "baby", "tape", "watch", "never", "taking picture", "eggs", "syrup", "sandwich", "water skiing", "microphone", "back", "bears", "donuts", "w", "sky", "double decker", "england", "surfing", "running", "shirt", "barn", "weather vane", "white and blue", "fishing", "bridge", "los angeles", "open", "red sox", "bat", "plane", "white and green", "transportation", "sunny", "bus stop", "city", "brown and white", "bicycle", "crow", "magazines", "daisy", "14", "old", "curtains", "jumped", "snowboard", "dinosaur", "racing", "asphalt", "court", "plastic", "circle", "red and blue", "zebra", "12", "biplane", "shallow", "brazil", "logo", "2:20", "electric", "night time", "motion", "toothbrushes", "orange and white", "66", "spoon", "toyota", "tennis shoes", "46", "second", "no 1", "iphone", "friend", "apple", "carnation", "15", "tiger", "glove", "airplane", "bow", "air france", "passengers", "tv", "on building", "3:55", "victorian", "steeple", "happy", "skateboarding", "fruit", "cutting board", "cantaloupe", "kiwi", "sliced", "heart", "water", "rainy", "carrots", "giraffes", "eat", "ramp", "lab", "field", "horizontal", "birds", "home", "shrimp", "12 feet", "girl", "modern", "turtle", "dell", "boots", "sunglasses", "black and orange", "yellow and black", "gloves", "hp", "desk", "both", "sign", "on street", "2000", "cirrus", "to dry", "ceiling", "fluorescent", "up", "9", "boys", "playing soccer", "american", "passenger", "turn", "palm", "no train", "wedding", "branch", "parrots", "air force", "on tracks", "small", "tank", "dirty", "france", "honda", "2.00", "whale", "vase", "flying", "professional", "driving", "tissue", "protest", "corona", "for balance", "twin", "clothes", "t shirt", "window sill", "wild", "noon", "caution", "spring", "raining", "cane", "school", "windsurfing", "parachute", "black and red", "25", "background", "toaster", "planes", "yellow and red", "spatula", "10:10", "ivory", "train", "welcome", "highway", "off", "on track", "electricity", "italy", "dinner", "sink", "squares", "5 ft", "parked", "store", "dress", "signs", "meow", "football", "rugby", "stainless steel", "la", "dirt", "blue and white", "klm", "house", "unknown", "ford", "reading", "chair", "mountain", "alive", "water skis", "picture", "parade", "slippers", "trailer", "boating", "holding it", "shade", "cloth", "6:20", "candle", "hose", "hand", "3:25", "on sidewalk", "poster", "downhill", "68", "reflection", "summer", "pickles", "halloween", "bats", "london", "zoo", "surfer", "racket", "flickr", "cutting hair", "strawberries", "mushroom", "teddy bear", "big", "suitcase", "veggie", "pepper", "houses", "70", "toshiba", "triangle", "boxes", "photograph", "smoke", "engine", "camel", "sidewalk", "left 1", "red and green", "4:35", "on couch", "candy", "minnie mouse", "homemade", "mouse", "box", "movie", "45", "strawberry", "fridge", "full", "vegetables", "bright", "play", "remote", "pond", "savannah", "celery", "concrete", "semi", "dump", "scania", "safety", "posing", "fabric", "laying", "couch", "blueberries", "handle", "pipe", "stick", "parmesan", "steak", "chain link", "catch", "barbed wire", "mozzarella", "soda", "fire hydrant", "cat food", "pepperoni", "lot", "licking", "red and black", "clay", "tennis court", "jumping", "potatoes", "toothbrush", "kite", "not at all", "flying kite", "broken", "black and silver", "lap", "outside", "44", "delta", "greyhound", "ring finger", "talking on phone", "bad", "kettle", "35", "motorcycles", "produce", "comfort", "steering wheel", "18", "humans", "coffee", "white and brown", "fall", "bread", "cherry", "4:30", "flag", "night", "lamp", "cucumber", "can't see", "porcelain", "oval", "museum", "rain", "sprinkles", "20", "kids", "bracelet", "sneakers", "mask", "mickey mouse", "twins", "very high", "costume", "cabbage", "paint", "lighting", "young", "air conditioner", "wooden", "board", "someone", "beets", "16", "day time", "4 inches", "lights", "ladder", "glass", "ferris wheel", "fries", "steamed", "shepherd", "cotton", "suit", "goatee", "on his head", "print", "happy birthday", "forks", "travel", "maple", "200", "oil", "jeans", "can", "chopsticks", "on wall", "construction", "mack", "36", "chinese", "moped", "festival", "gas", "throwing", "circus", "wires", "not possible", "plates", "sugar", "in", "women's", "door", "no man", "volleyball", "serving", "ponytail", "business", "decoration", "santa", "flat", "barrel", "12:15", "candles", "atv", "free", "hair", "waffle", "ball", "stop sign", "wetsuit", "very deep", "swimsuit", "green and black", "foreground", "stands", "china airlines", "flower", "300", "lobster", "on bench", "plaster", "phones", "sailboat", "apples", "road", "recently", "cones", "cactus", "rice", "vegetarian", "donut", "ketchup", "police", "mirror", "rock", "meat", "blinds", "cell phones", "china", "rust", "7:25", "stone", "vans", "middle", "eagle", "9:30", "ping pong", "microwave", "gmc", "umbrellas", "wrist", "cuddling", "laughing", "boy", "next to toilet", "tabby", "petting", "south", "40", "name tag", "checkered", "name", "slow", "cardboard", "windows", "croissant", "plain", "cookie", "on ground", "low", "water bottle", "goggles", "turkey", "pull", "shut", "kite flying", "bowl", "smile", "in bowl", "bush", "cloudy", "top left", "skateboarder", "coca cola", "pan", "drinking", "short", "floor", "thanksgiving", "radio", "drink", "on toilet", "bike rack", "bleachers", "train tracks", "horses", "far", "top", "toilet", "in water", "private", "nature", "checkers", "commercial", "stroller", "power", "stuffed animals", "uniforms", "japan", "liquor", "faucet", "green and orange", "corn", "sub", "white and yellow", "mercedes", "in sky", "tarp", "indian", "counter", "multicolored", "polar", "go", "now", "no number", "swimming", "bridle", "cowboy", "union station", "salt and pepper", "olives", "pizza cutter", "british airways", "nighttime", "domestic", "trolley", "australia", "tiles", "pug", "wicker", "british", "us airways express", "burton", "christmas tree", "napkin", "writing", "rocks", "hello kitty", "lacoste", "gold", "fan", "skateboards", "day", "on floor", "2008", "dark", "flying kites", "rural", "olympics", "bmw", "34", "factory", "denim", "typing", "for fun", "steel", "watching tv", "chevron", "driver", "baggage claim", "grapes", "f", "angels", "roof", "handlebars", "train station", "public", "oak", "sleeping", "canada", "on runway", "air canada", "on top", "tired", "blonde", "cups", "little", "adidas", "10 feet", "white and gray", "leaf", "fisheye", "forest", "war", "octagon", "raspberry", "helmets", "united states", "29", "noodles", "van", "long", "traveling", "luggage", "airport", "single", "pitching", "dugout", "garbage", "in street", "happiness", "cigarette", "on tower", "antelope", "graffiti", "skating", "on road", "curved", "red light", "washington", "ski lift", "athletics", "brace", "squatting", "catching", "batter", "batting", "game", "towards", "33", "sliding", "makeup", "japanese", "person", "pirates", "plaid", "rose", "daytime", "keyboard", "surfboards", "hummingbird", "ollie", "11:30", "clock tower", "5:55", "san francisco", "stopping", "tags", "samsung", "computers", "cabinets", "talking", "cage", "asparagus", "5 years", "hanger", "adult", "rabbit", "empty", "softball", "1st", "playing", "chairs", "farm", "cross country", "dump truck", "women", "snowboarder", "tall", "monkey", "mantle", "fire", "books", "quilt", "cessna", "chandelier", "dunkin donuts", "beans", "relish", "no flag", "parking meter", "spots", "ducks", "sandals", "doughnut", "lighthouse", "yacht", "german shepherd", "in middle", "raw", "chain", "2 feet", "pedestal", "sauerkraut", "bagels", "mutt", "dog and cat", "race", "poor", "cat and dog", "station", "printer", "daisies", "front", "gravel", "rear", "grassy", "pigeons", "dogs", "in car", "life", "wii remotes", "suv", "leather", "bottom right", "peace", "facebook", "blanket", "fountain", "frisbees", "12:30", "am", "scooter", "going", "analog", "america", "pitbull", "relaxing", "paddle boarding", "white and pink", "shampoo", "alps", "ride", "side", "mane", "on desk", "on chair", "2012", "multi", "straight", "big ben", "closed", "frosted", "3 feet", "waves", "buoy", "life vest", "trash can", "medium", "boxer", "very tall", "yamaha", "sunlight", "hit ball", "dry", "coke", "gym", "orange and black", "center", "rope", "flip flops", "4th of july", "siamese", "crafts", "color", "italian", "playing frisbee", "skate park", "orange juice", "windowsill", "corgi", "thumb", "peanut butter", "pie", "toast", "no hat", "benches", "diamond", "blender", "avocado", "television", "speakers", "pony", "baseball field", "pavement", "sydney", "not there", "diamonds", "4 feet", "goalie", "soccer ball", "runway", "video game", "gaming", "casual", "green and white", "toilet brush", "working", "pickup", "girls", "remotes", "pasta", "hood", "braves", "skier", "motorola", "17", "b", "100", "diet coke", "hospital", "wagon", "milk", "ferry", "rainbow", "on bed", "toward", "1:30", "19", "security", "herself", "mercedes benz", "supreme", "thin", "platform", "gray and red", "thai", "storage", "thailand", "swan", "peach", "10:05", "dome", "chiquita", "2:00", "mountain dew", "23", "knives", "street sign", "on beach", "playing wii", "using laptop", "stickers", "yogurt", "on grass", "9:50", "9:45", "sweat", "gatorade", "umpire", "37", "transport", "desktop", "desserts", "main", "boston", "fell", "top right", "case", "asleep", "over", "9:55", "grapefruit", "breakfast", "headphones", "freight", "cup", "sweatband", "nobody", "lamps", "9:25", "scarf", "on fridge", "main st", "moving", "confused", "fresh", "kiting", "blue jay", "flats", "long time", "chihuahua", "ceramic", "mushrooms", "on plate", "human", "power lines", "hotel", "map", "earring", "boarding", "display", "warm", "napkins", "brown and black", "broom", "basketball", "papers", "holding baby", "sad", "kickstand", "60", "shoulder", "sleep", "footprints", "tunnel", "1990", "hats", "6 inches", "ham", "bacon", "church", "53", "pineapple", "at camera", "red bull", "pilot", "tattoo", "work", "polar bear", "taking off", "website", "22", "4:00", "coffee maker", "fast", "fur", "rubber", "tongs", "german", "germany", "3 inches", "toy", "3:20", "calm", "pots", "balloons", "fruits", "9:20", "drawer", "oven", "soup", "stove", "heels", "wind", "island", "blood", "leg", "theater", "tennis racquet", "21", "gothic", "2:35", "wii remote", "turning", "20 feet", "pink and black", "ears", "fun", "wreath", "to right", "child", "fly", "head", "drywall", "shorter", "pier", "feeding giraffe", "in vase", "burger", "easter", "onion", "uniform", "remote control", "guitar", "time", "verizon", "tomatoes", "ship", "tulips", "glaze", "on suitcase", "tent", "1:45", "market", "bnsf", "bandana", "still", "don't know", "piano", "mouth", "run", "sparrow", "throw", "lines", "vest", "1950", "jet", "sepia", "2015", "busy", "lighter", "dessert", "bending", "75", "finch", "pastries", "outdoors", "bakery", "clean", "ipod", "tablecloth", "cigarettes", "looking at phone", "in front", "food truck", "face", "swinging", "safari", "500", "volkswagen", "2010", "shape", "shelves", "riding horses", "2016", "behind bus", "towels", "lemon", "straw", "bamboo", "5 feet", "hardwood", "oregon", "schnauzer", "organic", "h", "kid", "meter", "61", "charging", "bald", "caucasian", "man on left", "stand", "27", "dining room", "sandwiches", "32", "apartment", "tower", "virgin", "out", "white and red", "2:05", "i don't know", "chains", "legs", "age", "goats", "s", "congratulations", "dresser", "camper", "half", "silverware", "decorative", "hawaiian", "petting horse", "wheel", "florida", "reds", "washington dc", "moon", "conference", "screen", "controller", "robin", "men", "protection", "roll", "harley davidson", "coal", "mustache", "smiling", "pedestrians", "88", "me", "tray", "males", "monitor", "bell", "landscape", "club", "toothpick", "seagulls", "bowtie", "lake", "steam", "surf", "baseball glove", "blinders", "woods", "stuffed", "sunbathing", "shearing", "dad", "mixer", "pot", "blending", "identification", "owl", "wine glass", "on bike", "billabong", "new york", "yarn", "tube", "tennis ball", "2:55", "ice cream", "chevrolet", "shirt and tie", "taking selfie", "blue and green", "he isn't", "cutting cake", "east", "setting", "brewers", "riding bikes", "7 eleven", "stars", "jockey", "jacket", "standing still", "book", "gray and white", "pen", "red white blue", "above", "alaska", "tongue", "feathers", "k", "camping", "pasture", "corner", "away", "ski", "texas", "fire truck", "sailboats", "jump", "walk", "spray paint", "loading", "united", "1000", "brushing his teeth", "roman numerals", "garlic", "surprise", "3rd", "first", "side of road", "dodgers", "airplanes", "unsure", "russian", "wet", "skyscraper", "5 star", "brushing her teeth", "blankets", "natural", "across street", "smartphone", "duck", "sausage", "paris", "newspaper", "pants", "spices", "pillow", "to left", "snowboards", "colgate", "on elephant", "string", "horns", "2:40", "men's", "cobblestone", "regular", "staring", "28", "barber shop", "linoleum", "grind", "cut", "x", "above sink", "above stove", "dishes", "dalmatian", "watching", "glazed", "5:25", "j", "messy", "wallet", "tuna", "toasted", "grilled", "french", "green and blue", "sunflowers", "to catch frisbee", "wool", "sprint", "no grass", "cabinet", "shell", "foil", "bottles", "bar", "king", "paper towels", "friends", "beagle", "school bus", "laptops", "snowing", "cement", "pc", "accident", "stuffed animal", "wakeboard", "balance", "in suitcase", "white and black", "nikon", "cleats", "on sink", "pool", "mom", "downtown", "asian", "heater", "bathing", "193", "against wall", "canopy", "jungle", "berries", "military", "pickle", "clams", "seafood", "in box", "boats", "tables", "lizard", "lemonade", "m", "soft", "illinois", "country", "for sale", "arm", "listening", "curly", "play tennis", "hands", "cereal", "blue and red", "robe", "around neck", "red and silver", "soap", "trains", "throwing frisbee", "smoking", "india", "headband", "not very", "westin", "serve", "bicycles", "can't tell", "to catch ball", "visibility", "ana", "reins", "rodeo", "boot", "on horse", "12:35", "riding motorcycle", "mexico", "mother", "african", "left and right", "button", "earrings", "blackberry", "cell", "10:00", "harness", "pillows", "vegetable", "tablet", "fern", "cats", "golden retriever", "goat", "tractor", "valentine's day", "hearts", "khaki", "man on right", "mcdonald's", "player", "arriving", "husky", "on skateboard", "vases", "coat", "beanie", "coming", "granite", "shopping cart", "it's raining", "sports", "leash", "balls", "blurry", "baseball bat", "team", "mango", "mug", "eiffel tower", "worms", "trash", "robot", "show", "terrier", "painting", "rooster", "42", "jones", "state farm", "balloon", "trunk", "coach", "t", "playing game", "fireplace", "behind clouds", "uphill", "motocross", "sony", "magazine", "kitesurfing", "catching frisbee", "catch frisbee", "bud light", "drive", "fighting", "1 on left", "very old", "hallway", "lexus", "wii controller", "9:15", "fast food", "5:45", "catholic", "muffin", "traffic light", "band", "button up", "grocery", "shelf", "2:25", "honey", "plants", "oars", "foggy", "nathan's", "cord", "yard", "48", "donut shop", "chimney", "calico", "suits", "sideways", "animals", "black and blue", "bikini", "photographer", "700", "queen", "1:00", "12:05", "horseback riding", "awake", "bunny", "12:00", "continental", "flamingo", "rye", "family", "lots", "owner", "stew", "palm tree", "cruise ship", "56", "design", "ny", "far right", "tire", "younger", "biking", "at&t", "giants", "marshmallows", "caramel", "polo", "emirates", "salon", "focus", "on motorcycle", "magnets", "mat", "ivy", "cakes", "chrome", "bob", "asia", "graduation", "cauliflower", "in snow", "c", "rough", "vacation", "air", "windy", "victoria", "4:45", "trick", "coconut", "labrador", "on left", "yellow and green", "butterfly", "fake", "on napkin", "bricks", "wine glasses", "detroit", "man's", "parsley", "art", "subway", "wave", "placemat", "hydrant", "sofa", "pigeon", "riding elephant", "all", "branches", "plant", "to eat", "zucchini", "feta", "neon", "mouse pad", "cloud", "toilet paper", "pumpkin", "rowing", "toronto", "handicap", "seeds", "fly kite", "chicago", "marble", "frame", "150", "rocky", "give way", "sauce", "it's not", "control", "high chair", "playstation", "xbox", "not likely", "roman", "land", "1:35", "lifeguard", "on pizza", "size", "bull", "dandelions", "equestrian", "goose", "8 feet", "recessed", "statue", "index", "phillies", "strike", "mirrors", "pointing", "farmer", "collie", "motorbike", "lanes", "bikes", "biker", "arrows", "gas station", "logs", "smaller", "desert", "yield", "flags", "stool", "kitten", "doll", "daffodils", "letters", "dishwasher", "first base", "nuts", "2013", "persian", "swim trunks", "deep", "o", "doubles", "toothpicks", "in field", "wristband", "wheels", "baking", "4:15", "11:00", "ear", "2007", "51", "chevy", "using computer", "frog", "storm", "boogie board", "hungry", "by window", "ambulance", "pigtails", "audi", "microsoft", "on man", "cannot tell", "stained glass", "hugging", "laying down", "3:00", "taxi", "pedestrian", "landing", "numbers", "38", "stones", "on tree", "clocks", "new", "picnic", "fog", "buffalo", "under armour", "cocker spaniel", "orioles", "no sign", "telling time", "bags", "golden gate", "cover", "castle", "canoe", "selfie", "cream", "floating", "indoor", "antique", "aluminum", "silver and black", "cast iron", "peas", "sun hat", "on right", "swiss", "flour", "under sink", "fashion", "fedora", "shells", "1 hour", "puppy", "in stands", "not here", "motor", "thousands", "120", "sail", "butt", "mexican", "dead end", "paddle", "bathing suit", "shop", "onion rings", "boxing", "birthday cake", "chalk", "scenery", "style", "nissan", "sticker", "on rack", "1 4", "woman's", "surprised", "north face", "squash", "not sure", "email", "spotted", "seat", "himself", "circles", "san diego", "kia", "mattress", "obama", "lamb", "american flag", "climbing", "skull and crossbones", "roast beef", "visor", "herd", "double", "52", "high", "stagecoach", "cart", "feeding", "eaten", "cone", "11:15", "smoothie", "golf", "colorado", "electronics", "5:15", "bowling", "players", "ketchup and mustard", "styrofoam", "6 feet", "hawk", "cheddar", "12:28", "arabic", "12:25", "12:10", "shower curtain", "army", "salmon", "10:40", "hanging", "whole", "behind fence", "bars", "moss", "no dog", "traffic", "10:25", "r", "countryside", "machine", "directions", "cooked", "aa", "6:45", "4 way", "stripe", "brand", "baseball player", "bunk", "coleslaw", "fishing boat", "at table", "europe", "dead", "arch", "scrambled", "clothing", "closet", "egg", "suitcases", "indoors", "coffee pot", "tires", "lilies", "cafe", "9:35", "teal", "toothpaste", "in background", "tarmac", "painted", "sunset", "orange and yellow", "oar", "peaches", "zebra and giraffe", "ladybug", "20 ft", "sesame seeds", "hills", "2:30", "stucco", "tail", "couple", "kawasaki", "smooth", "powdered sugar", "pedestrian crossing", "french fries", "picnic table", "teeth", "ribbon", "saddle", "15 feet", "earbuds", "on train", "39", "curb", "tow", "shark", "white and orange", "6:25", "gravy", "fork and spoon", "pooping", "curtain", "lime", "skull", "crossing", "speed limit", "peacock", "boredom", "neck", "hit", "dragon", "tissues", "basil", "waving", "blue team", "rectangles", "helicopter", "mud", "us", "balcony", "red and gray", "firefighter", "sunflower", "wallpaper", "best buy", "11:20", "public market center", "seattle", "bookshelf", "looking", "1 inch", "harley", "urinal", "cartoon", "t shirt and jeans", "navy", "fedex", "rays", "deck", "coaster", "1:20", "50 feet", "4:20", "us open", "looking at camera", "600", "national express", "white house", "5:00", "jp morgan", "palm trees", "tub", "pens", "soldiers", "2 people", "animal", "speaker", "hamburger", "spaghetti", "green beans", "it isn't", "10:20", "buildings", "on shelf", "baseball uniform", "tiled", "orange and blue", "90", "north america", "arrow", "news", "tropicana", "formal", "in grass", "thumbs up", "clip", "gate", "tennis player", "lilac", "pastry", "nose", "pacifier", "11:35", "different teams", "cardinals", "exhaust", "hauling", "on tray", "bagel", "huge", "out of focus", "cook", "wheat", "photo", "ghost", "sedan", "qatar", "zig zag", "lanyard", "pink and white", "sesame", "space", "no clock", "warning", "snowy", "tater tots", "tropical", "grandfather", "mac", "magnet", "photoshop", "pajamas", "350", "casserole", "4:55", "pelican", "2009", "clydesdale", "tow truck", "belt", "west", "omelet", "heavy", "crown", "in corner", "hexagon", "mound", "iris", "g", "12:45", "2:15", "3:10", "drawing", "only", "little girl", "washing", "nokia", "windsor", "2 men", "parmesan cheese", "on woman", "freezer", "icing", "venice", "dairy", "several", "concentration", "3:15", "no smoking", "kayak", "frosting", "jetblue", "thoroughbred", "parakeet", "shoe", "skeleton", "britain", "ties", "in sink", "patio", "bank", "camouflage", "privacy", "bib", "blue and gray", "looking out window", "falling", "bucket", "cupcakes", "throw ball", "garden", "almonds", "ducati", "ireland", "plastic wrap", "starbucks", "all way", "bark", "home plate", "base", "dog food", "toys", "blue and orange", "1 in front", "foot", "dc", "california", "towing", "cheesecake", "bushes", "bow tie", "millions", "down street", "2011", "police officer", "windmill", "taking pictures", "street name", "cleaning", "on pole", "russia", "main street", "catch ball", "mario", "pirate", "track", "garage", "7:10", "they aren't", "mother and child", "tents", "fancy", "tattoos", "alcohol", "2:45", "wheelchair", "money", "top hat", "willow", "cd", "brushing hair", "pancake", "80", "listening to music", "green and red", "barrier", "vests", "hiking", "tank top", "lufthansa", "student", "menu", "forehand", "wii controllers", "acer", "wall st", "hundreds", "water ski", "furniture", "paisley", "pizza hut", "baseball game", "hill", "prom", "1 world", "tiara", "students", "information", "hazy", "nasa", "canon", "bird feeder", "crane", "dr pepper", "logitech", "2:10", "all of them", "utensils", "telephone", "converse", "bone", "jeep", "nursing", "krispy kreme", "cameraman", "pee", "ranch", "polka dots", "railroad crossing", "shirts", "feeder", "above toilet", "unclear", "below", "43", "spoons", "calendar", "vaio", "fox", "mint", "after", "spiderman", "lg", "concert", "on rock", "fluffy", "gray and black", "coats", "lady", "dodge", "easyjet", "pearl", "bunt", "flat screen", "10:30", "music", "polar bears", "riding horse", "lift", "angry", "cookies", "3:45", "buttons", "hot", "cute", "behind", "dole", "in motion", "26", "pans", "love", "winnie pooh", "pear", "copyright", "2 hours", "snowsuit", "kissing", "backhand", "to get to other side", "metro", "swans", "very fast", "can't see it", "nintendo", "direction", "waiting", "mohawk", "st patrick's day", "rail", "hoodie", "feet", "swirls", "muffins", "4:05", "106", "10:55", "coins", "mitt", "game controller", "room", "adults", "urinals", "cameras", "marker", "upright", "brass", "sled", "teacher", "conductor", "farmers market", "toiletries", "blue and black", "soccer field", "banana peel", "sprite", "doughnuts", "bank of america", "on his face", "heat", "emergency", "ski slope", "hard", "41", "6:00", "in his hand", "cluttered", "dog show", "on boat", "grizzly", "drums", "not", "in hand", "easy", "400", "under table", "d", "hitting ball", "photography", "intersection", "backwards", "crocs", "marina", "chips", "bible", "harry potter", "hawaii", "fanta", "half full", "carriage", "curious", "12:50", "black white", "geese", "pork", "mailbox", "l", "sidecar", "poop", "wings", "penguin", "to see", "pocket", "steps", "cubs", "junk", "deer", "ottoman", "salt", "condiments", "1:55", "post", "bulldog", "notebook", "no cat", "champagne", "jets", "knee pads", "throw frisbee", "drinks", "leopard", "taller", "cooler", "bundt", "monday", "grape", "wine tasting", "under", "baskets", "santa hat", "chest", "sewing", "on car", "sony ericsson", "peeing", "for photo", "tour", "few", "singapore", "fireman", "fire extinguisher", "wildebeest", "lemons", "peanuts", "babies", "wiimote", "guitar hero", "slide", "stopped", "library", "multi colored", "blue and pink", "choppy", "sailing", "brush", "grinding", "jelly", "dairy queen", "shaking hands", "ge", "tigers", "tokyo", "philadelphia", "ski boots", "buses", "11:45", "collage", "pink and blue", "jesus", "singles", "iron", "coffee table", "2 years", "don't walk", "classroom", "on water", "potato salad", "posts", "harbor", "residential", "joshua", "uk", "burgers", "deli", "kicking", "lace", "overalls", "vehicles", "ram", "dancing", "47", "shed", "lid", "he's not", "fans", "amtrak", "space shuttle", "ostrich", "bathtub", "kneeling", "2:50", "mall", "yellow and orange", "gazebo", "wax", "slow down", "lays", "hammer time", "octopus", "crib", "banana split", "broadway", "pottery", "wavy", "farmers", "holding phone", "on phone", "squirrel", "wax paper", "tusks", "dining", "packing", "kangaroo", "dawn", "defense", "powdered", "thomas", "budweiser", "back left", "stir fry", "beijing", "11:10", "tripod", "wide", "slope", "black and gray", "planter", "chili", "siblings", "kayaking", "captivity", "opaque", "rack", "panda", "doorway", "wheelie", "pelicans", "genetics", "not in service", "volvo", "dachshund", "v", "on laptop", "western", "gone", "birthday party", "parking garage", "tying tie", "blueberry", "scale", "notes", "train car", "man made", "stability", "lily", "lying down", "pacific", "high heels", "pare", "checkerboard", "partly cloudy", "cool", "n", "toilets", "tree branch", "copper", "cycling", "5:50", "870", "shopping", "7:05", "zipper", "holding umbrella", "batman", "lotion", "1:25", "black and brown", "playing video game", "girl on right", "legos", "drinking water", "burrito", "plow", "jet ski", "spiral", "ibm", "tools", "flashlight", "cherries", "maple leaf", "mountainous", "under tree", "vines", "sushi", "baker", "snake", "globe", "target", "john", "pomeranian", "tuxedo", "hockey", "sleeve", "leaning", "wireless", "11:05", "compaq", "do not enter", "radish", "1:05", "dim", "advertisement", "movement", "model", "hammock", "swing", "sheet", "google", "boardwalk", "right 1", "haircut", "ankle", "3:30", "exit", "csx", "tim hortons", "lego", "cucumbers", "angel", "12:20", "racquet", "behind woman", "potato", "egg salad", "controllers", "recliner", "upside down", "mosaic", "before", "antenna", "3:50", "10:15", "lion", "camo", "fighter", "silver and red", "dirt bike", "playing video games", "used", "crates", "horizontally", "plunger", "refrigerators", "radiator", "stork", "in basket", "cap", "living", "married", "briefcase", "bottom left", "30 mph", "ascending", "flip phone", "101", "11:50", "gun", "arizona", "foam", "serious", "y", "close up", "pancakes", "heineken", "paw", "cnn", "comforter", "sheets", "8:35", "driveway", "fair", "cleaner", "1 year", "delivery", "commuter", "apple and banana", "chase", "72", "safe", "trucks", "trunks", "spider", "64", "slacks", "meeting", "7:00", "skiers", "shaved", "carrot cake", "holding", "surfers", "giraffe and zebra", "7:45", "mississippi", "seaweed", "black and pink", "horse racing", "orchid", "rv", "tourist", "above door", "leaving", "pitch", "crest", "miami", "asics", "flood", "bus station", "take off", "amazon", "practice", "entering", "diesel", "pm", "wetsuits", "remodeling", "porch", "7:35", "tie dye", "baked", "life jacket", "cylinder", "grilled cheese", "meatballs", "paddling", "banana bread", "monster", "smiley face", "not high", "keys", "dreadlocks", "kitchenaid", "straight ahead", "badminton", "long sleeve", "sheepdog", "5:18", "end", "on shore", "scratching", "oriental", "5:05", "alligator", "city bus", "purple and white", "10:50", "each other", "weeds", "tinkerbell", "rottweiler", "apartments", "snowflakes", "stop light", "sweatshirt", "shore", "bidet", "switzerland", "stretching", "tv stand", "boundaries", "65", "bronze", "jar", "middle 1", "54", "skate", "easton", "turn right", "raspberries", "singing", "on bus", "carnations", "descending", "classic", "suspenders", "not long", "8:50", "father", "anniversary", "hsbc", "very long", "space needle", "skatepark", "fruit salad", "kenmore", "no water", "8:05", "db", "baby's breath", "shelter", "1980", "no left turn", "washington monument", "ham and cheese", "10 inches", "8:55", "savory", "6:35", "indians", "9:05", "fires", "pipes", "donkey", "cds", "mitsubishi", "tell time", "outfield", "christian", "puma", "parking meters", "cranes", "flip", "wine bottle", "stadium", "mouthwash", "heinz", "distance", "macaroni", "on plane", "triumph", "more", "4:50", "single engine", "disney", "on stove", "shih tzu", "fried", "to hit ball", "in her hand", "sunrise", "2nd", "elmo", "kite string", "suzuki", "traffic lights", "blt", "i", "hitting", "htc", "healthy", "current", "star alliance", "stomach", "watch tv", "tulip", "5:10", "right side", "4:40", "ginger", "on sign", "cushion", "5:30", "learning", "pencil", "maroon", "food processor", "5:40", "dog bed", "michigan", "close", "license plate", "crows", "right hand", "normal", "green and brown", "1.00", "000", "1:40", "wing", "american airlines", "kodak", "mural", "sniffing", "1:15", "behind bench", "cardinal", "no light", "warmth", "paved", "skyscrapers", "swinging bat", "watermark", "in cup", "pizza box", "dough", "hiding", "goal", "no plate", "shower head", "ripe", "1:10", "1 in back", "older", "nest", "multiple", "cinnamon", "bin", "new orleans", "colored", "enclosure", "bride", "on dresser", "star wars", "in back", "triangles", "over easy", "cilantro", "statues", "sticks", "formica", "roundabout", "bowls", "ahead", "years", "drain", "veggies", "no shirt", "taking photo", "tugboat", "broke", "59", "cadillac", "prince", "left side", "1 in middle", "10:45", "drying", "11:25", "silk", "conference room", "buoys", "pockets", "daffodil", "6:40", "walgreens", "4 ft", "6:05", "virgin atlantic", "12:40", "digital", "ups", "westjet", "bikers", "us air force", "limes", "comcast", "dip", "7:55", "man in middle", "bus driver", "soon", "futon", "selling", "braid", "mariners", "wisconsin", "99", "citizen", "broccoli and carrots", "grocery store", "us airways", "49", "bored", "red velvet", "hotel room", "qantas", "tam", "korean air", "10:35", "whirlpool", "coffee cup", "hilly", "9:12", "whipped cream", "video", "finger", "competition", "hollywood", "sas", "backward", "beads", "cosmo", "10:08", "jal", "6:30", "100 year party ct", "hispanic", "in cabbage town", "opponent", "woodpecker", "visilab", "mt airy", "crosstown", "freightliner"]
|
configs/retrieval.yaml
ADDED
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
hidden_size: &hidden_size 768
|
2 |
+
vocab_size: &vocab_size 30522
|
3 |
+
type_vocab_size: &type_vocab_size 2
|
4 |
+
max_position_embeddings: &max_position_embeddings 512
|
5 |
+
pad_token_id: &pad_token_id 0
|
6 |
+
embed_size: &embed_size 256
|
7 |
+
|
8 |
+
seed: 42
|
9 |
+
world_size: 1
|
10 |
+
device: "cuda"
|
11 |
+
dist_url: "env://"
|
12 |
+
output_path: "./examples/albef/outputs/retrieval_output.pt"
|
13 |
+
|
14 |
+
datamodule_args:
|
15 |
+
train_files: ["./examples/albef/data_files/coco_train.json"]
|
16 |
+
test_files: ["./examples/albef/data_files/coco_test.json"]
|
17 |
+
image_root: "./examples/albef/data_files/coco"
|
18 |
+
batch_size: 32
|
19 |
+
num_workers: 8
|
20 |
+
|
21 |
+
vision_encoder_args:
|
22 |
+
hidden_size: *hidden_size
|
23 |
+
image_size: 384
|
24 |
+
patch_size: 16
|
25 |
+
num_hidden_layers: 12
|
26 |
+
num_attention_heads: 12
|
27 |
+
mlp_dim: 3072
|
28 |
+
dropout: 0.0
|
29 |
+
attention_dropout: 0.0
|
30 |
+
layer_norm_eps: 1e-6
|
31 |
+
|
32 |
+
text_encoder_args:
|
33 |
+
vocab_size: *vocab_size
|
34 |
+
hidden_size: *hidden_size
|
35 |
+
type_vocab_size: *type_vocab_size
|
36 |
+
max_position_embeddings: *max_position_embeddings
|
37 |
+
pad_token_id: *pad_token_id
|
38 |
+
num_hidden_layers: 6
|
39 |
+
num_attention_heads: 12
|
40 |
+
intermediate_size: 3072
|
41 |
+
layer_norm_eps: 1e-12
|
42 |
+
dropout: 0.0
|
43 |
+
|
44 |
+
multimodal_encoder_args:
|
45 |
+
hidden_size: *hidden_size
|
46 |
+
num_hidden_layers: 6
|
47 |
+
num_attention_heads: 12
|
48 |
+
intermediate_size: 3072
|
49 |
+
layer_norm_eps: 1e-12
|
50 |
+
|
51 |
+
projection_args:
|
52 |
+
in_features: *hidden_size
|
53 |
+
out_features: *embed_size
|
54 |
+
|
55 |
+
similarity_args:
|
56 |
+
embed_size: *embed_size
|
57 |
+
queue_size: 65536
|
58 |
+
temp: 0.07
|
59 |
+
|
60 |
+
training_args:
|
61 |
+
log_every_n_steps: 100
|
62 |
+
alpha: 0.4
|
63 |
+
weight_decay: 0.02
|
64 |
+
lr: 1e-5
|
65 |
+
min_lr: 1e-6
|
66 |
+
max_epochs: 5
|
67 |
+
step_size: 100
|
68 |
+
warmup_steps: 1
|
69 |
+
checkpoint_root: "./examples/albef/checkpoints"
|
70 |
+
|
71 |
+
eval_args:
|
72 |
+
log_every_n_steps: 100
|
73 |
+
k_test: 256
|
configs/vqa.yaml
ADDED
@@ -0,0 +1,78 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
hidden_size: &hidden_size 768
|
2 |
+
vocab_size: &vocab_size 30522
|
3 |
+
type_vocab_size: &type_vocab_size 2
|
4 |
+
max_position_embeddings: &max_position_embeddings 512
|
5 |
+
pad_token_id: &pad_token_id 0
|
6 |
+
|
7 |
+
seed: 42
|
8 |
+
world_size: 1
|
9 |
+
device: "cuda"
|
10 |
+
dist_url: "env://"
|
11 |
+
output_root: "./examples/albef/outputs"
|
12 |
+
|
13 |
+
datamodule_args:
|
14 |
+
train_files: ["./examples/albef/data_files/vqa_train.json", "./examples/albef/data_files/vg_qa.json", "./examples/albef/data_files/vqa_val.json"]
|
15 |
+
test_files: ["./examples/albef/data_files/vqa_test.json"]
|
16 |
+
answer_list: "./examples/albef/data_files/answer_list.json"
|
17 |
+
vqa_root: "./examples/albef/data_files/coco"
|
18 |
+
vg_root: "./examples/albef/data_files/visual_genome"
|
19 |
+
batch_size: 32
|
20 |
+
num_workers: 8
|
21 |
+
|
22 |
+
vision_encoder_args:
|
23 |
+
hidden_size: *hidden_size
|
24 |
+
image_size: 384
|
25 |
+
patch_size: 16
|
26 |
+
num_hidden_layers: 12
|
27 |
+
num_attention_heads: 12
|
28 |
+
mlp_dim: 3072
|
29 |
+
dropout: 0.0
|
30 |
+
attention_dropout: 0.0
|
31 |
+
layer_norm_eps: 1e-6
|
32 |
+
|
33 |
+
text_encoder_args:
|
34 |
+
vocab_size: *vocab_size
|
35 |
+
hidden_size: *hidden_size
|
36 |
+
type_vocab_size: *type_vocab_size
|
37 |
+
max_position_embeddings: *max_position_embeddings
|
38 |
+
pad_token_id: *pad_token_id
|
39 |
+
num_hidden_layers: 6
|
40 |
+
num_attention_heads: 12
|
41 |
+
intermediate_size: 3072
|
42 |
+
layer_norm_eps: 1e-12
|
43 |
+
dropout: 0.0
|
44 |
+
|
45 |
+
multimodal_encoder_args:
|
46 |
+
hidden_size: *hidden_size
|
47 |
+
num_hidden_layers: 6
|
48 |
+
num_attention_heads: 12
|
49 |
+
intermediate_size: 3072
|
50 |
+
layer_norm_eps: 1e-12
|
51 |
+
|
52 |
+
text_embeddings_args:
|
53 |
+
hidden_size: *hidden_size
|
54 |
+
vocab_size: *vocab_size
|
55 |
+
pad_token_id: *pad_token_id
|
56 |
+
max_position_embeddings: *max_position_embeddings
|
57 |
+
type_vocab_size: *type_vocab_size
|
58 |
+
layer_norm_eps: 1e-12
|
59 |
+
|
60 |
+
prediction_head_args:
|
61 |
+
hidden_size: *hidden_size
|
62 |
+
vocab_size: *vocab_size
|
63 |
+
layer_norm_eps: 1e-12
|
64 |
+
|
65 |
+
training_args:
|
66 |
+
log_every_n_steps: 100
|
67 |
+
alpha: 0.4
|
68 |
+
weight_decay: 0.02
|
69 |
+
lr: 2e-5
|
70 |
+
min_lr: 1e-6
|
71 |
+
max_epochs: 8
|
72 |
+
step_size: 100
|
73 |
+
warmup_steps: 4
|
74 |
+
checkpoint_root: "./examples/albef/checkpoints"
|
75 |
+
|
76 |
+
eval_args:
|
77 |
+
log_every_n_steps: 100
|
78 |
+
k_test: 128
|
data/__init__.py
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the BSD-style license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
data/retrieval_datamodule.py
ADDED
@@ -0,0 +1,188 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the BSD-style license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
from typing import List, Optional, Tuple
|
8 |
+
|
9 |
+
import torch
|
10 |
+
from data.retrieval_dataset import (
|
11 |
+
ImageToTextRetrievalDataset,
|
12 |
+
RetrievalTrainingDataset,
|
13 |
+
TextToImageRetrievalDataset,
|
14 |
+
)
|
15 |
+
from data.transforms import (
|
16 |
+
ALBEFTextTransform,
|
17 |
+
testing_image_transform,
|
18 |
+
training_image_transform,
|
19 |
+
)
|
20 |
+
from pytorch_lightning import LightningDataModule
|
21 |
+
from torch import Tensor
|
22 |
+
from torch.nn.utils.rnn import pad_sequence
|
23 |
+
from torch.utils.data import DataLoader, Dataset, DistributedSampler
|
24 |
+
|
25 |
+
|
26 |
+
class RetrievalDataModule(LightningDataModule):
|
27 |
+
"""
|
28 |
+
The Data Module for Retrieval task.
|
29 |
+
|
30 |
+
Args:
|
31 |
+
train_files (List[str]): The paths to training json files.
|
32 |
+
test_files (List[str]): The paths to testing json files.
|
33 |
+
image_root (str): The path to image data directory.
|
34 |
+
batch_size (int): The sampling batch size.
|
35 |
+
num_workers (int): The number of workers for the distributed mode.
|
36 |
+
"""
|
37 |
+
|
38 |
+
def __init__(
|
39 |
+
self,
|
40 |
+
train_files: List[str],
|
41 |
+
test_files: List[str],
|
42 |
+
image_root: str,
|
43 |
+
batch_size: int,
|
44 |
+
num_workers: int,
|
45 |
+
) -> None:
|
46 |
+
super().__init__()
|
47 |
+
self.train_dataset = RetrievalTrainingDataset(
|
48 |
+
train_files,
|
49 |
+
image_root,
|
50 |
+
training_image_transform(),
|
51 |
+
ALBEFTextTransform(truncate=True, max_seq_len=30, add_end_token=False),
|
52 |
+
)
|
53 |
+
|
54 |
+
self.image_dataset = ImageToTextRetrievalDataset(
|
55 |
+
test_files,
|
56 |
+
image_root,
|
57 |
+
testing_image_transform(),
|
58 |
+
)
|
59 |
+
|
60 |
+
self.text_dataset = TextToImageRetrievalDataset(
|
61 |
+
test_files,
|
62 |
+
ALBEFTextTransform(
|
63 |
+
truncate=True,
|
64 |
+
pad_to_max_seq_len=True,
|
65 |
+
max_seq_len=30,
|
66 |
+
add_end_token=False,
|
67 |
+
),
|
68 |
+
)
|
69 |
+
|
70 |
+
self.batch_size = batch_size
|
71 |
+
self.num_workers = num_workers
|
72 |
+
|
73 |
+
def _get_sampler(
|
74 |
+
self,
|
75 |
+
dataset: Dataset,
|
76 |
+
shuffle: bool,
|
77 |
+
is_distributed: bool,
|
78 |
+
num_tasks: int,
|
79 |
+
global_rank: int,
|
80 |
+
) -> Optional[DistributedSampler]:
|
81 |
+
# do not return a sampler if is not in distributed mode
|
82 |
+
# a default RandomSampler is used in this case
|
83 |
+
if not is_distributed:
|
84 |
+
return None
|
85 |
+
|
86 |
+
return DistributedSampler(
|
87 |
+
dataset, num_replicas=num_tasks, rank=global_rank, shuffle=shuffle
|
88 |
+
)
|
89 |
+
|
90 |
+
def train_dataloader(
|
91 |
+
self,
|
92 |
+
is_distributed: bool = False,
|
93 |
+
num_tasks: int = 0,
|
94 |
+
global_rank: int = 0,
|
95 |
+
drop_last: bool = True,
|
96 |
+
) -> DataLoader:
|
97 |
+
"""
|
98 |
+
DataLoader Outputs:
|
99 |
+
images (Tensor): Tensor of shape (B, C, W, H) of image inputs.
|
100 |
+
text (Tensor): Tensor of shape (B, L) of text inputs.
|
101 |
+
text_atts (Tensor): Tensor of shape (B, L) of text attention mask.
|
102 |
+
idx (Tensor): Tensor of shape (B) of image identifiers.
|
103 |
+
"""
|
104 |
+
sampler = self._get_sampler(
|
105 |
+
dataset=self.train_dataset,
|
106 |
+
shuffle=True,
|
107 |
+
is_distributed=is_distributed,
|
108 |
+
num_tasks=num_tasks,
|
109 |
+
global_rank=global_rank,
|
110 |
+
)
|
111 |
+
shuffle = sampler is None
|
112 |
+
return DataLoader(
|
113 |
+
self.train_dataset,
|
114 |
+
batch_size=self.batch_size,
|
115 |
+
num_workers=self.num_workers,
|
116 |
+
pin_memory=True,
|
117 |
+
sampler=sampler,
|
118 |
+
shuffle=shuffle,
|
119 |
+
collate_fn=retrieval_train_collate_fn,
|
120 |
+
drop_last=drop_last,
|
121 |
+
)
|
122 |
+
|
123 |
+
def image_dataloader(
|
124 |
+
self,
|
125 |
+
drop_last: bool = False,
|
126 |
+
) -> DataLoader:
|
127 |
+
"""
|
128 |
+
DataLoader Outputs:
|
129 |
+
images (Tensor): Tensor of shape (B, C, W, H) of image inputs.
|
130 |
+
"""
|
131 |
+
return DataLoader(
|
132 |
+
self.image_dataset,
|
133 |
+
batch_size=self.batch_size,
|
134 |
+
num_workers=self.num_workers,
|
135 |
+
pin_memory=True,
|
136 |
+
sampler=None,
|
137 |
+
shuffle=False,
|
138 |
+
collate_fn=None,
|
139 |
+
drop_last=drop_last,
|
140 |
+
)
|
141 |
+
|
142 |
+
def text_dataloader(
|
143 |
+
self,
|
144 |
+
drop_last: bool = False,
|
145 |
+
) -> DataLoader:
|
146 |
+
"""
|
147 |
+
DataLoader Outputs:
|
148 |
+
text (Tensor): Tensor of shape (B, L) of text inputs.
|
149 |
+
text_atts (Tensor): Tensor of shape (B, L) of text attention mask.
|
150 |
+
"""
|
151 |
+
return DataLoader(
|
152 |
+
self.text_dataset,
|
153 |
+
batch_size=self.batch_size,
|
154 |
+
num_workers=self.num_workers,
|
155 |
+
pin_memory=True,
|
156 |
+
sampler=None,
|
157 |
+
shuffle=False,
|
158 |
+
collate_fn=text_collate_fn,
|
159 |
+
drop_last=drop_last,
|
160 |
+
)
|
161 |
+
|
162 |
+
|
163 |
+
def retrieval_train_collate_fn(
|
164 |
+
batch: List[Tuple[Tensor, Tensor, int]]
|
165 |
+
) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
|
166 |
+
image_list = []
|
167 |
+
text_list = []
|
168 |
+
idx_list = []
|
169 |
+
for image, text, idx in batch:
|
170 |
+
image_list.append(image)
|
171 |
+
text_list.append(text)
|
172 |
+
idx_list.append(idx)
|
173 |
+
images = torch.stack(image_list, dim=0)
|
174 |
+
text = pad_sequence(text_list, batch_first=True)
|
175 |
+
text_atts = (text != 0).type(torch.long)
|
176 |
+
idx = Tensor(idx_list).type(torch.long)
|
177 |
+
return (
|
178 |
+
images,
|
179 |
+
text,
|
180 |
+
text_atts,
|
181 |
+
idx,
|
182 |
+
)
|
183 |
+
|
184 |
+
|
185 |
+
def text_collate_fn(batch: List[Tensor]) -> Tuple[Tensor, Tensor]:
|
186 |
+
text = pad_sequence(batch, batch_first=True)
|
187 |
+
text_atts = (text != 0).type(torch.long)
|
188 |
+
return text, text_atts
|
data/retrieval_dataset.py
ADDED
@@ -0,0 +1,149 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the BSD-style license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
import json
|
8 |
+
import os
|
9 |
+
from typing import Callable, List, Tuple, Union
|
10 |
+
|
11 |
+
from PIL import Image
|
12 |
+
from torch import Tensor
|
13 |
+
from torch.utils.data import Dataset
|
14 |
+
|
15 |
+
|
16 |
+
class RetrievalTrainingDataset(Dataset):
|
17 |
+
"""
|
18 |
+
Create the training dataset for Retrieval task.
|
19 |
+
|
20 |
+
Args:
|
21 |
+
ann_file (List[str]): The paths to training annotation json files.
|
22 |
+
image_root (str): The path to image data directory.
|
23 |
+
image_transform (Callable[[Image.Image], Tensor]): Image data transform.
|
24 |
+
text_transform (Callable[[Union[List[str], str]], Tensor]): Text data transform.
|
25 |
+
|
26 |
+
Dataset Outputs:
|
27 |
+
image (Tensor): Transformed image input tensor of shape (C, H, W).
|
28 |
+
caption (Tensor): Transformed text token input ids.
|
29 |
+
idx (int): The unique identifier for the image.
|
30 |
+
"""
|
31 |
+
|
32 |
+
def __init__(
|
33 |
+
self,
|
34 |
+
ann_file: List[str],
|
35 |
+
image_root: str,
|
36 |
+
image_transform: Callable[[Image.Image], Tensor],
|
37 |
+
text_transform: Callable[[Union[List[str], str]], Tensor],
|
38 |
+
) -> None:
|
39 |
+
self.ann = []
|
40 |
+
for f in ann_file:
|
41 |
+
self.ann += json.load(open(f, "r"))
|
42 |
+
|
43 |
+
self.image_root = image_root
|
44 |
+
self.image_transform = image_transform
|
45 |
+
self.text_transform = text_transform
|
46 |
+
|
47 |
+
self.idx = {} # map str image_id from dataset to int ids
|
48 |
+
i = 0
|
49 |
+
for ann in self.ann:
|
50 |
+
image_id = ann["image_id"]
|
51 |
+
if image_id not in self.idx.keys():
|
52 |
+
self.idx[image_id] = i
|
53 |
+
i += 1
|
54 |
+
|
55 |
+
def __len__(self) -> int:
|
56 |
+
return len(self.ann)
|
57 |
+
|
58 |
+
def __getitem__(self, index: int) -> Tuple[Tensor, Tensor, int]:
|
59 |
+
ann = self.ann[index]
|
60 |
+
image_path = os.path.join(self.image_root, ann["image"])
|
61 |
+
image = Image.open(image_path).convert("RGB")
|
62 |
+
image = self.image_transform(image)
|
63 |
+
caption = self.text_transform(ann["caption"])
|
64 |
+
return image, caption, self.idx[ann["image_id"]]
|
65 |
+
|
66 |
+
|
67 |
+
class ImageToTextRetrievalDataset(Dataset):
|
68 |
+
"""
|
69 |
+
Create the dataset for Image-to-Text Retrieval task.
|
70 |
+
|
71 |
+
Args:
|
72 |
+
ann_file (List[str]): The paths to annotation json files.
|
73 |
+
image_root (str): The path to image data directory.
|
74 |
+
image_transform (Callable[[Image.Image], Tensor]): Image data transform.
|
75 |
+
|
76 |
+
Dataset Outputs:
|
77 |
+
image (Tensor): Transformed image input tensor of shape (C, H, W).
|
78 |
+
"""
|
79 |
+
|
80 |
+
def __init__(
|
81 |
+
self,
|
82 |
+
ann_file: List[str],
|
83 |
+
image_root: str,
|
84 |
+
image_transform: Callable[[Image.Image], Tensor],
|
85 |
+
) -> None:
|
86 |
+
self.image_root = image_root
|
87 |
+
self.image_transform = image_transform
|
88 |
+
|
89 |
+
self.ann = []
|
90 |
+
self.images = [] # paths to all images in the dataset
|
91 |
+
self.image_to_text = {} # map image ids to text ids for evaluation
|
92 |
+
for f in ann_file:
|
93 |
+
self.ann += json.load(open(f, "r"))
|
94 |
+
|
95 |
+
text_id = 0
|
96 |
+
for image_id, ann in enumerate(self.ann):
|
97 |
+
self.images.append(ann["image"])
|
98 |
+
num_text = len(ann["caption"])
|
99 |
+
self.image_to_text[image_id] = list(range(text_id, text_id + num_text))
|
100 |
+
text_id += num_text
|
101 |
+
|
102 |
+
def __len__(self) -> int:
|
103 |
+
return len(self.images)
|
104 |
+
|
105 |
+
def __getitem__(self, index: int) -> Tensor:
|
106 |
+
image_path = os.path.join(self.image_root, self.images[index])
|
107 |
+
image = Image.open(image_path).convert("RGB")
|
108 |
+
image = self.image_transform(image)
|
109 |
+
return image
|
110 |
+
|
111 |
+
|
112 |
+
class TextToImageRetrievalDataset(Dataset):
|
113 |
+
"""
|
114 |
+
Create the dataset for Text-to-Image Retrieval task.
|
115 |
+
|
116 |
+
Args:
|
117 |
+
ann_file (List[str]): The paths to annotation json files.
|
118 |
+
text_transform (Callable[[Union[List[str], str]], Tensor]): Text data transform.
|
119 |
+
|
120 |
+
Dataset Outputs:
|
121 |
+
text (Tensor): Transformed text token input ids.
|
122 |
+
"""
|
123 |
+
|
124 |
+
def __init__(
|
125 |
+
self,
|
126 |
+
ann_file: List[str],
|
127 |
+
text_transform: Callable[[Union[List[str], str]], Tensor],
|
128 |
+
) -> None:
|
129 |
+
self.text_transform = text_transform
|
130 |
+
|
131 |
+
self.ann = []
|
132 |
+
self.text = [] # all text strings in the dataset
|
133 |
+
self.text_to_image = {} # map text ids to image ids for evaluation
|
134 |
+
for f in ann_file:
|
135 |
+
self.ann += json.load(open(f, "r"))
|
136 |
+
|
137 |
+
text_id = 0
|
138 |
+
for image_id, ann in enumerate(self.ann):
|
139 |
+
for caption in ann["caption"]:
|
140 |
+
self.text.append(caption)
|
141 |
+
self.text_to_image[text_id] = image_id
|
142 |
+
text_id += 1
|
143 |
+
|
144 |
+
def __len__(self) -> int:
|
145 |
+
return len(self.text)
|
146 |
+
|
147 |
+
def __getitem__(self, index: int) -> Tensor:
|
148 |
+
text = self.text_transform(self.text[index])
|
149 |
+
return text
|
data/transforms.py
ADDED
@@ -0,0 +1,139 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the BSD-style license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
import re
|
8 |
+
from typing import List, Tuple, Union
|
9 |
+
|
10 |
+
import torch
|
11 |
+
|
12 |
+
from torchtext.transforms import PadTransform, Sequential, ToTensor, Truncate
|
13 |
+
from torchvision import transforms
|
14 |
+
from transformers.models.bert.tokenization_bert import BertTokenizer
|
15 |
+
|
16 |
+
# mean and standard deviation from the ALBEF repo:
|
17 |
+
# https://github.com/salesforce/ALBEF/blob/main/dataset/__init__.py#L16
|
18 |
+
MEAN = (0.48145466, 0.4578275, 0.40821073)
|
19 |
+
STD_DEV = (0.26862954, 0.26130258, 0.27577711)
|
20 |
+
|
21 |
+
|
22 |
+
class ALBEFTextTransform:
|
23 |
+
"""
|
24 |
+
Remove punctuations and trailing spaces in input text and transform it into
|
25 |
+
a Tensor of token ids using BERTTokenizer.
|
26 |
+
|
27 |
+
Args:
|
28 |
+
pretrained_tokenizer (str): Pretrained tokenizer to use.
|
29 |
+
Default: "bert-base-uncased"
|
30 |
+
do_pre_process (bool): Whether to pre-process input text.
|
31 |
+
Defaults to True.
|
32 |
+
truncate (bool): Whether to truncate input text to max_seq_length.
|
33 |
+
Defaults to False.
|
34 |
+
pad_to_max_seq_len (bool): Whether to pad the sequence to max_seq_length.
|
35 |
+
add_end_token (bool): Whether to add the end-of-sentence token.
|
36 |
+
Defaults to True.
|
37 |
+
max_seq_len (int): The max sequence length after truncating or padding.
|
38 |
+
Defaults to 25.
|
39 |
+
cls_token_id (int): Value to represent the start of each text.
|
40 |
+
Defaults to 101, Hugging Face's BERT cls token id.
|
41 |
+
sep_token_id (int): Value to represent the end of each text.
|
42 |
+
Defaults to 102, Hugging Face's BERT sep token id.
|
43 |
+
pad_token_id (int): Value with which to pad each text so that all texts are the same length.
|
44 |
+
Defaults to 0, Hugging Face's BERT pad token id.
|
45 |
+
|
46 |
+
Inputs:
|
47 |
+
text (Union[List[str], str]): Input text to transform.
|
48 |
+
"""
|
49 |
+
|
50 |
+
def __init__(
|
51 |
+
self,
|
52 |
+
pretrained_tokenizer: str = "bert-base-uncased",
|
53 |
+
do_pre_process: bool = True,
|
54 |
+
truncate: bool = False,
|
55 |
+
pad_to_max_seq_len: bool = False,
|
56 |
+
add_end_token: bool = True,
|
57 |
+
max_seq_len: int = 25,
|
58 |
+
cls_token_id: int = 101,
|
59 |
+
sep_token_id: int = 102,
|
60 |
+
pad_token_id: int = 0,
|
61 |
+
):
|
62 |
+
self.do_pre_process = do_pre_process
|
63 |
+
self.cls_token_id = cls_token_id
|
64 |
+
self.sep_token_id = sep_token_id
|
65 |
+
self.pad_token_id = pad_token_id
|
66 |
+
self.add_end_token = add_end_token
|
67 |
+
|
68 |
+
self.tokenizer = BertTokenizer.from_pretrained(pretrained_tokenizer)
|
69 |
+
self.transform = Sequential(
|
70 |
+
Truncate(max_seq_len=max_seq_len) if truncate else torch.nn.Identity(),
|
71 |
+
ToTensor(padding_value=self.pad_token_id),
|
72 |
+
PadTransform(max_length=max_seq_len, pad_value=self.pad_token_id)
|
73 |
+
if pad_to_max_seq_len
|
74 |
+
else torch.nn.Identity(),
|
75 |
+
)
|
76 |
+
|
77 |
+
def pre_process(self, text: str) -> str:
|
78 |
+
text = (
|
79 |
+
re.sub(
|
80 |
+
r"([,.'!?\"()*#:;~])",
|
81 |
+
"",
|
82 |
+
text,
|
83 |
+
)
|
84 |
+
.replace("-", " ")
|
85 |
+
.replace("/", " ")
|
86 |
+
)
|
87 |
+
text = text.rstrip(" ")
|
88 |
+
|
89 |
+
return text
|
90 |
+
|
91 |
+
def __call__(self, text: Union[List[str], str]) -> torch.Tensor:
|
92 |
+
if self.do_pre_process:
|
93 |
+
if isinstance(text, str):
|
94 |
+
text = self.pre_process(text)
|
95 |
+
else:
|
96 |
+
text = [self.pre_process(t) for t in text]
|
97 |
+
tokens = self.tokenizer(text)["input_ids"]
|
98 |
+
if not self.add_end_token and tokens[-1] == self.sep_token_id:
|
99 |
+
tokens = tokens[:-1]
|
100 |
+
input_ids = self.transform(tokens)
|
101 |
+
|
102 |
+
return input_ids
|
103 |
+
|
104 |
+
|
105 |
+
def training_image_transform(
|
106 |
+
image_size: int = 384,
|
107 |
+
scale: Tuple[float, float] = (0.5, 1.0),
|
108 |
+
image_interpolation=transforms.InterpolationMode.BICUBIC,
|
109 |
+
mean: Tuple[float, float, float] = MEAN,
|
110 |
+
std_dev: Tuple[float, float, float] = STD_DEV,
|
111 |
+
) -> transforms.Compose:
|
112 |
+
return transforms.Compose(
|
113 |
+
[
|
114 |
+
transforms.RandomResizedCrop(
|
115 |
+
image_size, scale=scale, interpolation=image_interpolation
|
116 |
+
),
|
117 |
+
transforms.RandomHorizontalFlip(),
|
118 |
+
transforms.RandAugment(2, 7),
|
119 |
+
transforms.ToTensor(),
|
120 |
+
transforms.Normalize(mean, std_dev),
|
121 |
+
]
|
122 |
+
)
|
123 |
+
|
124 |
+
|
125 |
+
def testing_image_transform(
|
126 |
+
image_size: int = 384,
|
127 |
+
image_interpolation=transforms.InterpolationMode.BICUBIC,
|
128 |
+
mean: Tuple[float, float, float] = MEAN,
|
129 |
+
std_dev: Tuple[float, float, float] = STD_DEV,
|
130 |
+
) -> transforms.Compose:
|
131 |
+
return transforms.Compose(
|
132 |
+
[
|
133 |
+
transforms.Resize(
|
134 |
+
(image_size, image_size), interpolation=image_interpolation
|
135 |
+
),
|
136 |
+
transforms.ToTensor(),
|
137 |
+
transforms.Normalize(mean, std_dev),
|
138 |
+
]
|
139 |
+
)
|
data/vqa_datamodules.py
ADDED
@@ -0,0 +1,206 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the BSD-style license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
from typing import List, Optional, Tuple
|
8 |
+
|
9 |
+
import torch
|
10 |
+
from data.transforms import (
|
11 |
+
ALBEFTextTransform,
|
12 |
+
testing_image_transform,
|
13 |
+
training_image_transform,
|
14 |
+
)
|
15 |
+
from data.vqa_dataset import VQADataset
|
16 |
+
from pytorch_lightning import LightningDataModule
|
17 |
+
from torch import Tensor
|
18 |
+
from torch.nn.utils.rnn import pad_sequence
|
19 |
+
from torch.utils.data import DataLoader, DistributedSampler
|
20 |
+
|
21 |
+
|
22 |
+
class VQADataModule(LightningDataModule):
|
23 |
+
"""
|
24 |
+
The Data Module for Visual Question Answering task.
|
25 |
+
|
26 |
+
Args:
|
27 |
+
train_files (List[str]): The paths to training json files.
|
28 |
+
test_files (List[str]): The paths to testing json files.
|
29 |
+
answer_list (str): The path to the answers list.
|
30 |
+
vqa_root (str): The path to vqa data directory.
|
31 |
+
vg_root (str): The path to vg data directory.
|
32 |
+
batch_size (int): The sampling batch size.
|
33 |
+
num_workers (int): The number of workers for the distributed mode.
|
34 |
+
"""
|
35 |
+
|
36 |
+
def __init__(
|
37 |
+
self,
|
38 |
+
train_files: List[str],
|
39 |
+
test_files: List[str],
|
40 |
+
answer_list: str,
|
41 |
+
vqa_root: str,
|
42 |
+
vg_root: str,
|
43 |
+
batch_size: int,
|
44 |
+
num_workers: int,
|
45 |
+
) -> None:
|
46 |
+
super().__init__()
|
47 |
+
self.train_dataset = VQADataset(
|
48 |
+
train_files,
|
49 |
+
vqa_root,
|
50 |
+
vg_root,
|
51 |
+
image_transform=training_image_transform(),
|
52 |
+
question_transform=ALBEFTextTransform(
|
53 |
+
truncate=True, max_seq_len=25, add_end_token=False
|
54 |
+
),
|
55 |
+
answer_transform=ALBEFTextTransform(do_pre_process=False),
|
56 |
+
split="train",
|
57 |
+
)
|
58 |
+
|
59 |
+
self.test_dataset = VQADataset(
|
60 |
+
test_files,
|
61 |
+
vqa_root,
|
62 |
+
vg_root,
|
63 |
+
image_transform=testing_image_transform(),
|
64 |
+
question_transform=ALBEFTextTransform(add_end_token=False),
|
65 |
+
answer_transform=ALBEFTextTransform(do_pre_process=False),
|
66 |
+
split="test",
|
67 |
+
answer_list=answer_list,
|
68 |
+
)
|
69 |
+
|
70 |
+
self.batch_size = batch_size
|
71 |
+
self.num_workers = num_workers
|
72 |
+
|
73 |
+
def _get_sampler(
|
74 |
+
self,
|
75 |
+
dataset: VQADataset,
|
76 |
+
shuffle: bool,
|
77 |
+
is_distributed: bool,
|
78 |
+
num_tasks: int,
|
79 |
+
global_rank: int,
|
80 |
+
) -> Optional[DistributedSampler]:
|
81 |
+
if not is_distributed:
|
82 |
+
return None
|
83 |
+
|
84 |
+
return DistributedSampler(
|
85 |
+
dataset, num_replicas=num_tasks, rank=global_rank, shuffle=shuffle
|
86 |
+
)
|
87 |
+
|
88 |
+
def train_dataloader(
|
89 |
+
self,
|
90 |
+
is_distributed: bool = False,
|
91 |
+
num_tasks: int = 0,
|
92 |
+
global_rank: int = 0,
|
93 |
+
drop_last: bool = True,
|
94 |
+
) -> DataLoader:
|
95 |
+
"""
|
96 |
+
DataLoader Outputs:
|
97 |
+
images (Tensor): Tensor of shape (B, C, W, H) of image inputs.
|
98 |
+
questions (Tensor): Tensor of shape (B, L) of question inputs.
|
99 |
+
question_atts (Tensor): Tensor of shape (B, L) of question attention mask.
|
100 |
+
answers (Tensor): Tensor of shape (N, M) of answer inputs.
|
101 |
+
N >= B because a vqa sample can have multiple answers.
|
102 |
+
answer_atts (Tensor): Tensor of shape (N, M) of answer attention mask.
|
103 |
+
weights (Tensor): Tensor of shape (N) of answer weights.
|
104 |
+
ans_lengths (List[int]): List of length B and sum N where
|
105 |
+
ans_lengths[i] = number of answers for images[i] and questions[i].
|
106 |
+
"""
|
107 |
+
sampler = self._get_sampler(
|
108 |
+
dataset=self.train_dataset,
|
109 |
+
shuffle=True,
|
110 |
+
is_distributed=is_distributed,
|
111 |
+
num_tasks=num_tasks,
|
112 |
+
global_rank=global_rank,
|
113 |
+
)
|
114 |
+
shuffle = sampler is None
|
115 |
+
return DataLoader(
|
116 |
+
self.train_dataset,
|
117 |
+
batch_size=self.batch_size,
|
118 |
+
num_workers=self.num_workers,
|
119 |
+
pin_memory=True,
|
120 |
+
sampler=sampler,
|
121 |
+
shuffle=shuffle,
|
122 |
+
collate_fn=vqa_train_collate_fn,
|
123 |
+
drop_last=drop_last,
|
124 |
+
)
|
125 |
+
|
126 |
+
def test_dataloader(
|
127 |
+
self,
|
128 |
+
is_distributed: bool = False,
|
129 |
+
num_tasks: int = 0,
|
130 |
+
global_rank: int = 0,
|
131 |
+
drop_last=False,
|
132 |
+
) -> DataLoader:
|
133 |
+
"""
|
134 |
+
DataLoader Outputs:
|
135 |
+
images (Tensor): Tensor of shape (B, C, W, H) of image inputs.
|
136 |
+
questions (Tensor): Tensor of shape (B, L) of question inputs.
|
137 |
+
question_atts (Tensor): Tensor of shape (B, L) of question attention mask.
|
138 |
+
question_ids (List): List of length B of question ids.
|
139 |
+
"""
|
140 |
+
sampler = self._get_sampler(
|
141 |
+
dataset=self.test_dataset,
|
142 |
+
shuffle=False,
|
143 |
+
is_distributed=is_distributed,
|
144 |
+
num_tasks=num_tasks,
|
145 |
+
global_rank=global_rank,
|
146 |
+
)
|
147 |
+
return DataLoader(
|
148 |
+
self.test_dataset,
|
149 |
+
batch_size=self.batch_size,
|
150 |
+
num_workers=self.num_workers,
|
151 |
+
pin_memory=True,
|
152 |
+
sampler=sampler,
|
153 |
+
shuffle=False,
|
154 |
+
collate_fn=vqa_test_collate_fn,
|
155 |
+
drop_last=drop_last,
|
156 |
+
)
|
157 |
+
|
158 |
+
|
159 |
+
def vqa_train_collate_fn(
|
160 |
+
batch: List[Tuple[Tensor, Tensor, List[Tensor], List[float]]]
|
161 |
+
) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, List[int]]:
|
162 |
+
image_list = []
|
163 |
+
question_list = []
|
164 |
+
answer_list = []
|
165 |
+
weight_list = []
|
166 |
+
ans_lengths = []
|
167 |
+
for image, question, answer, weights in batch:
|
168 |
+
image_list.append(image)
|
169 |
+
question_list.append(question)
|
170 |
+
answer_list += answer
|
171 |
+
weight_list += weights
|
172 |
+
ans_lengths.append(len(answer))
|
173 |
+
images = torch.stack(image_list, dim=0)
|
174 |
+
questions = pad_sequence(question_list, batch_first=True)
|
175 |
+
question_atts = (questions != 0).type(torch.long)
|
176 |
+
answers = pad_sequence(answer_list, batch_first=True)
|
177 |
+
answer_atts = (answers != 0).type(torch.long)
|
178 |
+
weights = torch.Tensor(weight_list)
|
179 |
+
return (
|
180 |
+
images,
|
181 |
+
questions,
|
182 |
+
question_atts,
|
183 |
+
answers,
|
184 |
+
answer_atts,
|
185 |
+
weights,
|
186 |
+
ans_lengths,
|
187 |
+
)
|
188 |
+
|
189 |
+
|
190 |
+
def vqa_test_collate_fn(
|
191 |
+
batch: List[Tuple[Tensor, Tensor, int]]
|
192 |
+
) -> Tuple[Tensor, Tensor, Tensor, List[int]]:
|
193 |
+
image_list, question_list, question_ids = [], [], []
|
194 |
+
for image, question, question_id in batch:
|
195 |
+
image_list.append(image)
|
196 |
+
question_list.append(question)
|
197 |
+
question_ids.append(question_id)
|
198 |
+
images = torch.stack(image_list, dim=0)
|
199 |
+
questions = pad_sequence(question_list, batch_first=True)
|
200 |
+
question_atts = (questions != 0).type(torch.long)
|
201 |
+
return (
|
202 |
+
images,
|
203 |
+
questions,
|
204 |
+
question_atts,
|
205 |
+
question_ids,
|
206 |
+
)
|
data/vqa_dataset.py
ADDED
@@ -0,0 +1,115 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the BSD-style license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
import json
|
8 |
+
import os
|
9 |
+
from typing import Callable, List, Tuple, Union
|
10 |
+
|
11 |
+
import torch
|
12 |
+
|
13 |
+
from PIL import Image
|
14 |
+
from torch import Tensor
|
15 |
+
from torch.utils.data import Dataset
|
16 |
+
|
17 |
+
|
18 |
+
class VQADataset(Dataset):
|
19 |
+
"""
|
20 |
+
Create the dataset for VQA task.
|
21 |
+
|
22 |
+
Args:
|
23 |
+
ann_file (List[str]): The paths to annotation json files.
|
24 |
+
vqa_root (str): The path to vqa data directory.
|
25 |
+
vg_root (str): The path to vg data directory.
|
26 |
+
image_transform (Callable[[Image.Image], Tensor]): image data transform.
|
27 |
+
question_transform (Callable[[Union[List[str], str]], Tensor]): text data transform for questions.
|
28 |
+
answer_transform (Callable[[Union[List[str], str]], Tensor]): text data transform for answers.
|
29 |
+
split (str): Indicates train or test. Default is train.
|
30 |
+
answer_list (str): The path to the answers list. Required for test split.
|
31 |
+
|
32 |
+
Dataset Outputs:
|
33 |
+
if split is train:
|
34 |
+
image (Tensor): Transformed image input tensor of shape (C, W, H).
|
35 |
+
question (Tensor): Transformed question token input ids.
|
36 |
+
answers (List[Tensor]): List of transformed answers token input ids.
|
37 |
+
answer_weights (List[float]): List of answer weights.
|
38 |
+
answer_weights[i] is proportional to the number of occurences of answers[i]
|
39 |
+
if split is test:
|
40 |
+
image (Tensor): Transformed image input tensor of shape (C, W, H).
|
41 |
+
question (Tensor): Transformed text token input ids.
|
42 |
+
question_id (int): The question sample id.
|
43 |
+
"""
|
44 |
+
|
45 |
+
def __init__(
|
46 |
+
self,
|
47 |
+
ann_file: List[str],
|
48 |
+
vqa_root: str,
|
49 |
+
vg_root: str,
|
50 |
+
image_transform: Callable[[Image.Image], Tensor],
|
51 |
+
question_transform: Callable[[Union[List[str], str]], Tensor],
|
52 |
+
answer_transform: Callable[[Union[List[str], str]], Tensor],
|
53 |
+
split: str = "train",
|
54 |
+
answer_list: str = None,
|
55 |
+
) -> None:
|
56 |
+
self.ann = []
|
57 |
+
for f in ann_file:
|
58 |
+
self.ann += json.load(open(f, "r"))
|
59 |
+
|
60 |
+
self.vqa_root = vqa_root
|
61 |
+
self.vg_root = vg_root
|
62 |
+
self.image_transform = image_transform
|
63 |
+
self.question_transform = question_transform
|
64 |
+
self.answer_transform = answer_transform
|
65 |
+
self.split = split
|
66 |
+
|
67 |
+
if split == "test":
|
68 |
+
self.answer_list = json.load(open(answer_list, "r"))
|
69 |
+
self.answer_input_ids = self.answer_transform(self.answer_list)
|
70 |
+
self.answer_attention_mask = (self.answer_input_ids != 0).type(torch.long)
|
71 |
+
|
72 |
+
def __len__(self) -> int:
|
73 |
+
return len(self.ann)
|
74 |
+
|
75 |
+
def __getitem__(
|
76 |
+
self, index: int
|
77 |
+
) -> Union[
|
78 |
+
Tuple[Tensor, Tensor, int], Tuple[Tensor, Tensor, List[Tensor], List[float]]
|
79 |
+
]:
|
80 |
+
ann = self.ann[index]
|
81 |
+
|
82 |
+
image_root = self.vqa_root if ann["dataset"] == "vqa" else self.vg_root
|
83 |
+
image_path = os.path.join(image_root, ann["image"])
|
84 |
+
image = Image.open(image_path).convert("RGB")
|
85 |
+
image = self.image_transform(image)
|
86 |
+
question = self.question_transform(ann["question"])
|
87 |
+
|
88 |
+
if self.split == "test":
|
89 |
+
return image, question, ann["question_id"]
|
90 |
+
|
91 |
+
elif self.split == "train":
|
92 |
+
if ann["dataset"] == "vqa":
|
93 |
+
# Each VQA sample question has a list of answers (with potential repeats)
|
94 |
+
# answer_weight[answer] = count(answer) / len(answers for the question)
|
95 |
+
answer_weights = {}
|
96 |
+
for answer in ann["answer"]:
|
97 |
+
if answer in answer_weights.keys():
|
98 |
+
answer_weights[answer] += 1 / len(ann["answer"])
|
99 |
+
else:
|
100 |
+
answer_weights[answer] = 1 / len(ann["answer"])
|
101 |
+
|
102 |
+
answers = list(answer_weights.keys())
|
103 |
+
answer_weights = list(answer_weights.values())
|
104 |
+
|
105 |
+
elif ann["dataset"] == "vg":
|
106 |
+
# A VG sample question has one answer so assign it a constant weight (0.5)
|
107 |
+
answers = [ann["answer"]]
|
108 |
+
answer_weights = [0.5]
|
109 |
+
|
110 |
+
answers = list(self.answer_transform(answers))
|
111 |
+
|
112 |
+
return image, question, answers, answer_weights
|
113 |
+
|
114 |
+
else:
|
115 |
+
raise ValueError("dataset split should be train or test")
|
finetune_retrieval.py
ADDED
@@ -0,0 +1,400 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the BSD-style license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
import argparse
|
8 |
+
import datetime
|
9 |
+
import os
|
10 |
+
import random
|
11 |
+
import time
|
12 |
+
|
13 |
+
import ruamel.yaml as yaml
|
14 |
+
import torch
|
15 |
+
import torch.backends.cudnn as cudnn
|
16 |
+
import torch.distributed as dist
|
17 |
+
from data.retrieval_datamodule import RetrievalDataModule
|
18 |
+
from model import albef_model_for_retrieval
|
19 |
+
from torch.optim import AdamW
|
20 |
+
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
|
21 |
+
from utils import (
|
22 |
+
add_weight_decay,
|
23 |
+
get_rank,
|
24 |
+
get_world_size,
|
25 |
+
init_distributed_mode,
|
26 |
+
is_dist_avail_and_initialized,
|
27 |
+
is_main_process,
|
28 |
+
)
|
29 |
+
|
30 |
+
|
31 |
+
def train(model, datamodule, args, device):
|
32 |
+
model.train()
|
33 |
+
|
34 |
+
model_without_ddp = model.module if is_dist_avail_and_initialized() else model
|
35 |
+
|
36 |
+
optimizer_params = add_weight_decay(model, args["weight_decay"])
|
37 |
+
optimizer = AdamW(optimizer_params, lr=args["lr"])
|
38 |
+
scheduler = CosineAnnealingWarmRestarts(
|
39 |
+
optimizer, T_0=args["max_epochs"], eta_min=args["min_lr"]
|
40 |
+
)
|
41 |
+
|
42 |
+
step_size = args["step_size"]
|
43 |
+
warmup_steps = args["warmup_steps"]
|
44 |
+
warmup_iterations = warmup_steps * step_size
|
45 |
+
|
46 |
+
data_loader = datamodule.train_dataloader(
|
47 |
+
is_distributed=is_dist_avail_and_initialized(),
|
48 |
+
num_tasks=get_world_size(),
|
49 |
+
global_rank=get_rank(),
|
50 |
+
)
|
51 |
+
|
52 |
+
start_time = time.time()
|
53 |
+
|
54 |
+
for epoch in range(args["max_epochs"]):
|
55 |
+
if epoch > 0:
|
56 |
+
scheduler.step(epoch + warmup_steps)
|
57 |
+
|
58 |
+
for batch, (image, text, text_atts, idx) in enumerate(data_loader):
|
59 |
+
if epoch > 0:
|
60 |
+
alpha = args["alpha"]
|
61 |
+
else:
|
62 |
+
alpha = args["alpha"] * min(1, batch / len(data_loader))
|
63 |
+
|
64 |
+
image = image.to(device, non_blocking=True)
|
65 |
+
text = text.to(device)
|
66 |
+
text_atts = text_atts.to(device)
|
67 |
+
idx = idx.to(device, non_blocking=True)
|
68 |
+
loss = model(image, text, text_atts, idx, alpha, is_train=True)
|
69 |
+
|
70 |
+
optimizer.zero_grad()
|
71 |
+
loss.backward()
|
72 |
+
optimizer.step()
|
73 |
+
|
74 |
+
if epoch == 0 and batch % step_size == 0 and batch <= warmup_iterations:
|
75 |
+
scheduler.step(batch // step_size)
|
76 |
+
|
77 |
+
if batch % args["log_every_n_steps"] == 0:
|
78 |
+
total_time = time.time() - start_time
|
79 |
+
time_str = "time {},".format(
|
80 |
+
datetime.timedelta(seconds=int(total_time))
|
81 |
+
)
|
82 |
+
epoch_str = "epoch {}/{},".format(epoch, args["max_epochs"])
|
83 |
+
batch_str = "batch {}/{},".format(batch, len(data_loader))
|
84 |
+
loss_str = "loss {}".format(loss.item())
|
85 |
+
print(time_str, epoch_str, batch_str, loss_str)
|
86 |
+
|
87 |
+
if is_main_process():
|
88 |
+
save_obj = {
|
89 |
+
"model": model_without_ddp.state_dict(),
|
90 |
+
"optimizer": optimizer.state_dict(),
|
91 |
+
"lr_scheduler": scheduler.state_dict(),
|
92 |
+
"epoch": epoch,
|
93 |
+
}
|
94 |
+
torch.save(
|
95 |
+
save_obj,
|
96 |
+
os.path.join(
|
97 |
+
args["checkpoint_root"], "retrieval_checkpoint_%02d.pt" % epoch
|
98 |
+
),
|
99 |
+
)
|
100 |
+
|
101 |
+
if is_dist_avail_and_initialized():
|
102 |
+
dist.barrier()
|
103 |
+
torch.cuda.empty_cache()
|
104 |
+
|
105 |
+
|
106 |
+
@torch.no_grad()
|
107 |
+
def encode_text(model, text_dataloader, device):
|
108 |
+
text_embeds = []
|
109 |
+
text_feats = []
|
110 |
+
text_atts = []
|
111 |
+
for text, text_att in text_dataloader:
|
112 |
+
text = text.to(device)
|
113 |
+
text_att = text_att.to(device)
|
114 |
+
text_embed, text_feat = model(
|
115 |
+
text=text, text_atts=text_att, input_type="text", is_train=False
|
116 |
+
)
|
117 |
+
text_embeds.append(text_embed)
|
118 |
+
text_feats.append(text_feat)
|
119 |
+
text_atts.append(text_att)
|
120 |
+
text_embeds = torch.cat(text_embeds, dim=0)
|
121 |
+
text_feats = torch.cat(text_feats, dim=0)
|
122 |
+
text_atts = torch.cat(text_atts, dim=0)
|
123 |
+
return text_embeds, text_feats, text_atts
|
124 |
+
|
125 |
+
|
126 |
+
@torch.no_grad()
|
127 |
+
def encode_image(model, image_dataloader, device):
|
128 |
+
image_embeds = []
|
129 |
+
image_feats = []
|
130 |
+
for image in image_dataloader:
|
131 |
+
image = image.to(device)
|
132 |
+
image_embed, image_feat = model(image=image, input_type="image", is_train=False)
|
133 |
+
image_embeds.append(image_embed)
|
134 |
+
image_feats.append(image_feat)
|
135 |
+
image_embeds = torch.cat(image_embeds, dim=0)
|
136 |
+
image_feats = torch.cat(image_feats, dim=0)
|
137 |
+
return image_embeds, image_feats
|
138 |
+
|
139 |
+
|
140 |
+
@torch.no_grad()
|
141 |
+
def image_to_text(
|
142 |
+
model,
|
143 |
+
image_embeds,
|
144 |
+
text_embeds,
|
145 |
+
text_atts,
|
146 |
+
sims_matrix,
|
147 |
+
num_images,
|
148 |
+
num_text,
|
149 |
+
device,
|
150 |
+
args,
|
151 |
+
):
|
152 |
+
start_time = time.time()
|
153 |
+
world_size = get_world_size()
|
154 |
+
rank = get_rank()
|
155 |
+
step = sims_matrix.size(0) // world_size + 1
|
156 |
+
start = rank * step
|
157 |
+
end = min(sims_matrix.size(0), start + step)
|
158 |
+
k = args["k_test"]
|
159 |
+
|
160 |
+
image_to_text_scores = torch.full((num_images, num_text), -100.0).to(device)
|
161 |
+
for i, sims in enumerate(sims_matrix[start:end]):
|
162 |
+
_, topk_idx = sims.topk(k, dim=0)
|
163 |
+
|
164 |
+
score = model(
|
165 |
+
image=image_embeds[start + i].repeat(k, 1, 1),
|
166 |
+
text=text_embeds[topk_idx],
|
167 |
+
text_atts=text_atts[topk_idx],
|
168 |
+
input_type="multimodal",
|
169 |
+
is_train=False,
|
170 |
+
)
|
171 |
+
image_to_text_scores[start + i, topk_idx] = score
|
172 |
+
|
173 |
+
if i % args["log_every_n_steps"] == 0:
|
174 |
+
total_time = time.time() - start_time
|
175 |
+
time_str = "time {},".format(datetime.timedelta(seconds=int(total_time)))
|
176 |
+
batch_str = "batch {}/{},".format(i, len(sims_matrix[start:end]))
|
177 |
+
print("image to text retrieval", time_str, batch_str)
|
178 |
+
return image_to_text_scores
|
179 |
+
|
180 |
+
|
181 |
+
@torch.no_grad()
|
182 |
+
def text_to_image(
|
183 |
+
model,
|
184 |
+
image_embeds,
|
185 |
+
text_embeds,
|
186 |
+
text_atts,
|
187 |
+
sims_matrix,
|
188 |
+
num_images,
|
189 |
+
num_text,
|
190 |
+
device,
|
191 |
+
args,
|
192 |
+
):
|
193 |
+
start_time = time.time()
|
194 |
+
world_size = get_world_size()
|
195 |
+
rank = get_rank()
|
196 |
+
step = sims_matrix.size(0) // world_size + 1
|
197 |
+
start = rank * step
|
198 |
+
end = min(sims_matrix.size(0), start + step)
|
199 |
+
k = args["k_test"]
|
200 |
+
|
201 |
+
text_to_image_scores = torch.full((num_text, num_images), -100.0).to(device)
|
202 |
+
for i, sims in enumerate(sims_matrix[start:end]):
|
203 |
+
_, topk_idx = sims.topk(k, dim=0)
|
204 |
+
score = model(
|
205 |
+
image=image_embeds[topk_idx],
|
206 |
+
text=text_embeds[start + i].repeat(k, 1, 1),
|
207 |
+
text_atts=text_atts[start + i].repeat(k, 1, 1),
|
208 |
+
input_type="multimodal",
|
209 |
+
is_train=False,
|
210 |
+
)
|
211 |
+
text_to_image_scores[start + i, topk_idx] = score
|
212 |
+
|
213 |
+
if i % args["log_every_n_steps"] == 0:
|
214 |
+
total_time = time.time() - start_time
|
215 |
+
time_str = "time {},".format(datetime.timedelta(seconds=int(total_time)))
|
216 |
+
batch_str = "batch {}/{},".format(i, len(sims_matrix[start:end]))
|
217 |
+
print("text to image retrieval", time_str, batch_str)
|
218 |
+
return text_to_image_scores
|
219 |
+
|
220 |
+
|
221 |
+
@torch.no_grad()
|
222 |
+
def evaluation(model, datamodule, args, device):
|
223 |
+
model.eval()
|
224 |
+
|
225 |
+
text_loader = datamodule.text_dataloader()
|
226 |
+
image_loader = datamodule.image_dataloader()
|
227 |
+
num_images = len(datamodule.image_dataset)
|
228 |
+
num_text = len(datamodule.text_dataset)
|
229 |
+
|
230 |
+
text_embeds, text_feats, text_atts = encode_text(model, text_loader, device)
|
231 |
+
image_embeds, image_feats = encode_image(model, image_loader, device)
|
232 |
+
|
233 |
+
sims_matrix = image_feats @ text_feats.t()
|
234 |
+
image_to_text_scores = image_to_text(
|
235 |
+
model,
|
236 |
+
image_embeds,
|
237 |
+
text_embeds,
|
238 |
+
text_atts,
|
239 |
+
sims_matrix,
|
240 |
+
num_images,
|
241 |
+
num_text,
|
242 |
+
device,
|
243 |
+
args,
|
244 |
+
)
|
245 |
+
|
246 |
+
sims_matrix = sims_matrix.t()
|
247 |
+
text_to_image_scores = text_to_image(
|
248 |
+
model,
|
249 |
+
image_embeds,
|
250 |
+
text_embeds,
|
251 |
+
text_atts,
|
252 |
+
sims_matrix,
|
253 |
+
num_images,
|
254 |
+
num_text,
|
255 |
+
device,
|
256 |
+
args,
|
257 |
+
)
|
258 |
+
|
259 |
+
if is_dist_avail_and_initialized():
|
260 |
+
dist.barrier()
|
261 |
+
torch.distributed.all_reduce(
|
262 |
+
image_to_text_scores, op=torch.distributed.ReduceOp.SUM
|
263 |
+
)
|
264 |
+
torch.distributed.all_reduce(
|
265 |
+
text_to_image_scores, op=torch.distributed.ReduceOp.SUM
|
266 |
+
)
|
267 |
+
|
268 |
+
return image_to_text_scores.cpu(), text_to_image_scores.cpu()
|
269 |
+
|
270 |
+
|
271 |
+
@torch.no_grad()
|
272 |
+
def itm_eval(
|
273 |
+
image_to_text_scores,
|
274 |
+
text_to_image_scores,
|
275 |
+
image_to_text_mapping,
|
276 |
+
text_to_image_mapping,
|
277 |
+
):
|
278 |
+
# Images to Text
|
279 |
+
ranks = torch.zeros(image_to_text_scores.size(0))
|
280 |
+
for index, score in enumerate(image_to_text_scores):
|
281 |
+
inds = torch.flip(torch.argsort(score), dims=[0])
|
282 |
+
rank = 1e10
|
283 |
+
# each image has multiple text mappings
|
284 |
+
# check retrieved inds with each ground truth mappping i
|
285 |
+
for i in image_to_text_mapping[index]:
|
286 |
+
tmp = torch.where(inds == i)[0][0]
|
287 |
+
if tmp < rank:
|
288 |
+
rank = tmp
|
289 |
+
ranks[index] = rank
|
290 |
+
|
291 |
+
# Compute metrics
|
292 |
+
tr1 = 100.0 * len(torch.where(ranks < 1)[0]) / len(ranks)
|
293 |
+
tr5 = 100.0 * len(torch.where(ranks < 5)[0]) / len(ranks)
|
294 |
+
tr10 = 100.0 * len(torch.where(ranks < 10)[0]) / len(ranks)
|
295 |
+
|
296 |
+
# Text to Images
|
297 |
+
ranks = torch.zeros(text_to_image_scores.size(0))
|
298 |
+
for index, score in enumerate(text_to_image_scores):
|
299 |
+
inds = torch.flip(torch.argsort(score), dims=[0])
|
300 |
+
ranks[index] = torch.where(inds == text_to_image_mapping[index])[0][0]
|
301 |
+
|
302 |
+
# Compute metrics
|
303 |
+
ir1 = 100.0 * len(torch.where(ranks < 1)[0]) / len(ranks)
|
304 |
+
ir5 = 100.0 * len(torch.where(ranks < 5)[0]) / len(ranks)
|
305 |
+
ir10 = 100.0 * len(torch.where(ranks < 10)[0]) / len(ranks)
|
306 |
+
|
307 |
+
tr_mean = (tr1 + tr5 + tr10) / 3
|
308 |
+
ir_mean = (ir1 + ir5 + ir10) / 3
|
309 |
+
r_mean = (tr_mean + ir_mean) / 2
|
310 |
+
|
311 |
+
eval_result = {
|
312 |
+
"txt_r1": tr1,
|
313 |
+
"txt_r5": tr5,
|
314 |
+
"txt_r10": tr10,
|
315 |
+
"txt_r_mean": tr_mean,
|
316 |
+
"img_r1": ir1,
|
317 |
+
"img_r5": ir5,
|
318 |
+
"img_r10": ir10,
|
319 |
+
"img_r_mean": ir_mean,
|
320 |
+
"r_mean": r_mean,
|
321 |
+
}
|
322 |
+
return eval_result
|
323 |
+
|
324 |
+
|
325 |
+
@torch.no_grad()
|
326 |
+
def format_output(
|
327 |
+
image_to_text_scores,
|
328 |
+
text_to_image_scores,
|
329 |
+
image_dataset,
|
330 |
+
text_dataset,
|
331 |
+
):
|
332 |
+
image_to_text_output = {}
|
333 |
+
for index, score in enumerate(image_to_text_scores):
|
334 |
+
image = image_dataset.images[index]
|
335 |
+
top10_ids = torch.flip(torch.argsort(score), dims=[0])[:10]
|
336 |
+
top10_text = [text_dataset.text[i] for i in top10_ids]
|
337 |
+
image_to_text_output[index] = {
|
338 |
+
"image": image,
|
339 |
+
"output": top10_text,
|
340 |
+
}
|
341 |
+
text_to_image_output = {}
|
342 |
+
for index, score in enumerate(text_to_image_scores):
|
343 |
+
text = text_dataset.text[index]
|
344 |
+
top10_ids = torch.flip(torch.argsort(score), dims=[0])[:10]
|
345 |
+
top10_images = [image_dataset.images[i] for i in top10_ids]
|
346 |
+
text_to_image_output[index] = {
|
347 |
+
"text": text,
|
348 |
+
"output": top10_images,
|
349 |
+
}
|
350 |
+
return image_to_text_output, text_to_image_output
|
351 |
+
|
352 |
+
|
353 |
+
def main():
|
354 |
+
parser = argparse.ArgumentParser()
|
355 |
+
parser.add_argument("--config", default="./examples/albef/configs/retrieval.yaml")
|
356 |
+
args = parser.parse_args()
|
357 |
+
config = yaml.load(open(args.config, "r"), Loader=yaml.Loader)
|
358 |
+
|
359 |
+
init_distributed_mode(config)
|
360 |
+
device = torch.device(config["device"])
|
361 |
+
|
362 |
+
seed = config["seed"] + get_rank()
|
363 |
+
torch.manual_seed(seed)
|
364 |
+
random.seed(seed)
|
365 |
+
cudnn.benchmark = True
|
366 |
+
|
367 |
+
datamodule = RetrievalDataModule(**config["datamodule_args"])
|
368 |
+
model = albef_model_for_retrieval(config, pretrained=True)
|
369 |
+
model = model.to(device)
|
370 |
+
if is_dist_avail_and_initialized():
|
371 |
+
model = torch.nn.parallel.DistributedDataParallel(
|
372 |
+
model, device_ids=[config["gpu"]]
|
373 |
+
)
|
374 |
+
|
375 |
+
train(model, datamodule, config["training_args"], device)
|
376 |
+
image_to_text_scores, text_to_image_scores = evaluation(
|
377 |
+
model, datamodule, config["eval_args"], device
|
378 |
+
)
|
379 |
+
val_result = itm_eval(
|
380 |
+
image_to_text_scores,
|
381 |
+
text_to_image_scores,
|
382 |
+
datamodule.image_dataset.image_to_text,
|
383 |
+
datamodule.text_dataset.text_to_image,
|
384 |
+
)
|
385 |
+
image_to_text_output, text_to_image_output = format_output(
|
386 |
+
image_to_text_scores,
|
387 |
+
text_to_image_scores,
|
388 |
+
datamodule.image_dataset,
|
389 |
+
datamodule.text_dataset,
|
390 |
+
)
|
391 |
+
result = {
|
392 |
+
"image_to_text_output": image_to_text_output,
|
393 |
+
"text_to_image_output": text_to_image_output,
|
394 |
+
**val_result,
|
395 |
+
}
|
396 |
+
torch.save(result, config["output_path"])
|
397 |
+
|
398 |
+
|
399 |
+
if __name__ == "__main__":
|
400 |
+
main()
|
finetune_vqa.py
ADDED
@@ -0,0 +1,204 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the BSD-style license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
import argparse
|
8 |
+
import datetime
|
9 |
+
import os
|
10 |
+
import random
|
11 |
+
import time
|
12 |
+
|
13 |
+
import ruamel.yaml as yaml
|
14 |
+
import torch
|
15 |
+
import torch.backends.cudnn as cudnn
|
16 |
+
import torch.distributed as dist
|
17 |
+
from data.vqa_datamodules import VQADataModule
|
18 |
+
from model import albef_model_for_vqa
|
19 |
+
from torch.optim import AdamW
|
20 |
+
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
|
21 |
+
|
22 |
+
from utils import (
|
23 |
+
add_weight_decay,
|
24 |
+
get_rank,
|
25 |
+
get_world_size,
|
26 |
+
init_distributed_mode,
|
27 |
+
is_dist_avail_and_initialized,
|
28 |
+
is_main_process,
|
29 |
+
save_result,
|
30 |
+
)
|
31 |
+
|
32 |
+
|
33 |
+
def train(model, datamodule, args, device):
|
34 |
+
model_without_ddp = model.module if is_dist_avail_and_initialized() else model
|
35 |
+
model.train()
|
36 |
+
|
37 |
+
optimizer_params = add_weight_decay(model, args["weight_decay"])
|
38 |
+
optimizer = AdamW(optimizer_params, lr=args["lr"])
|
39 |
+
scheduler = CosineAnnealingWarmRestarts(
|
40 |
+
optimizer, T_0=args["max_epochs"], eta_min=args["min_lr"]
|
41 |
+
)
|
42 |
+
|
43 |
+
step_size = args["step_size"]
|
44 |
+
warmup_steps = args["warmup_steps"]
|
45 |
+
warmup_iterations = warmup_steps * step_size
|
46 |
+
|
47 |
+
data_loader = datamodule.train_dataloader(
|
48 |
+
is_distributed=is_dist_avail_and_initialized(),
|
49 |
+
num_tasks=get_world_size(),
|
50 |
+
global_rank=get_rank(),
|
51 |
+
)
|
52 |
+
|
53 |
+
start_time = time.time()
|
54 |
+
|
55 |
+
for epoch in range(args["max_epochs"]):
|
56 |
+
if is_dist_avail_and_initialized():
|
57 |
+
data_loader.sampler.set_epoch(epoch)
|
58 |
+
|
59 |
+
if epoch > 0:
|
60 |
+
scheduler.step(epoch + warmup_steps)
|
61 |
+
|
62 |
+
for batch, (
|
63 |
+
images,
|
64 |
+
questions,
|
65 |
+
questions_atts,
|
66 |
+
answers,
|
67 |
+
answers_atts,
|
68 |
+
ans_weights,
|
69 |
+
ans_lengths,
|
70 |
+
) in enumerate(data_loader):
|
71 |
+
if epoch > 0:
|
72 |
+
alpha = args["alpha"]
|
73 |
+
else:
|
74 |
+
alpha = args["alpha"] * min(1, batch / len(data_loader))
|
75 |
+
|
76 |
+
images = images.to(device, non_blocking=True)
|
77 |
+
questions = questions.to(device)
|
78 |
+
questions_atts = questions_atts.to(device)
|
79 |
+
answers = answers.to(device)
|
80 |
+
answers_atts = answers_atts.to(device)
|
81 |
+
ans_weights = ans_weights.to(device)
|
82 |
+
|
83 |
+
loss = model(
|
84 |
+
images,
|
85 |
+
questions,
|
86 |
+
questions_atts,
|
87 |
+
answers,
|
88 |
+
answers_atts,
|
89 |
+
ans_weights=ans_weights,
|
90 |
+
ans_lengths=ans_lengths,
|
91 |
+
alpha=alpha,
|
92 |
+
is_train=True,
|
93 |
+
)
|
94 |
+
|
95 |
+
optimizer.zero_grad()
|
96 |
+
loss.backward()
|
97 |
+
optimizer.step()
|
98 |
+
|
99 |
+
if epoch == 0 and batch % step_size == 0 and batch <= warmup_iterations:
|
100 |
+
scheduler.step(batch // step_size)
|
101 |
+
|
102 |
+
if batch % args["log_every_n_steps"] == 0:
|
103 |
+
total_time = time.time() - start_time
|
104 |
+
time_str = "time {},".format(
|
105 |
+
datetime.timedelta(seconds=int(total_time))
|
106 |
+
)
|
107 |
+
epoch_str = "epoch {}/{},".format(epoch, args["max_epochs"])
|
108 |
+
batch_str = "batch {}/{},".format(batch, len(data_loader))
|
109 |
+
loss_str = "loss {}".format(loss.item())
|
110 |
+
print(time_str, epoch_str, batch_str, loss_str)
|
111 |
+
|
112 |
+
if is_main_process():
|
113 |
+
save_obj = {
|
114 |
+
"model": model_without_ddp.state_dict(),
|
115 |
+
"optimizer": optimizer.state_dict(),
|
116 |
+
"scheduler": scheduler.state_dict(),
|
117 |
+
"epoch": epoch,
|
118 |
+
}
|
119 |
+
torch.save(
|
120 |
+
save_obj,
|
121 |
+
os.path.join(args["checkpoint_root"], "vqa_checkpoint_%02d.pt" % epoch),
|
122 |
+
)
|
123 |
+
|
124 |
+
if is_dist_avail_and_initialized():
|
125 |
+
dist.barrier()
|
126 |
+
|
127 |
+
|
128 |
+
@torch.no_grad()
|
129 |
+
def evaluation(model, datamodule, args, device):
|
130 |
+
model.eval()
|
131 |
+
|
132 |
+
result = []
|
133 |
+
|
134 |
+
answer_list = datamodule.test_dataset.answer_list
|
135 |
+
answer_input_ids = datamodule.test_dataset.answer_input_ids.to(device)
|
136 |
+
answer_atts = datamodule.test_dataset.answer_attention_mask.to(device)
|
137 |
+
data_loader = datamodule.test_dataloader(
|
138 |
+
is_distributed=is_dist_avail_and_initialized(),
|
139 |
+
num_tasks=get_world_size(),
|
140 |
+
global_rank=get_rank(),
|
141 |
+
)
|
142 |
+
|
143 |
+
start_time = time.time()
|
144 |
+
|
145 |
+
for batch, (img, ques, ques_atts, ques_ids) in enumerate(data_loader):
|
146 |
+
img = img.to(device, non_blocking=True)
|
147 |
+
ques = ques.to(device)
|
148 |
+
ques_atts = ques_atts.to(device)
|
149 |
+
|
150 |
+
topk_ids, topk_probs = model(
|
151 |
+
img,
|
152 |
+
ques,
|
153 |
+
ques_atts,
|
154 |
+
answer_input_ids,
|
155 |
+
answer_atts,
|
156 |
+
k=args["k_test"],
|
157 |
+
is_train=False,
|
158 |
+
)
|
159 |
+
|
160 |
+
for ques_id, topk_id, topk_prob in zip(ques_ids, topk_ids, topk_probs):
|
161 |
+
_, pred = topk_prob.max(dim=0)
|
162 |
+
result.append(
|
163 |
+
{"question_id": ques_id, "answer": answer_list[topk_id[pred]]}
|
164 |
+
)
|
165 |
+
|
166 |
+
if batch % args["log_every_n_steps"] == 0:
|
167 |
+
total_time = time.time() - start_time
|
168 |
+
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
|
169 |
+
print(
|
170 |
+
"time {}, batch {}/{}".format(total_time_str, batch, len(data_loader))
|
171 |
+
)
|
172 |
+
|
173 |
+
return result
|
174 |
+
|
175 |
+
|
176 |
+
def main():
|
177 |
+
parser = argparse.ArgumentParser()
|
178 |
+
parser.add_argument("--config", default="./examples/albef/configs/vqa.yaml")
|
179 |
+
args = parser.parse_args()
|
180 |
+
config = yaml.load(open(args.config, "r"), Loader=yaml.Loader)
|
181 |
+
|
182 |
+
init_distributed_mode(config)
|
183 |
+
device = torch.device(config["device"])
|
184 |
+
|
185 |
+
seed = config["seed"] + get_rank()
|
186 |
+
torch.manual_seed(seed)
|
187 |
+
random.seed(seed)
|
188 |
+
cudnn.benchmark = True
|
189 |
+
|
190 |
+
datamodule = VQADataModule(**config["datamodule_args"])
|
191 |
+
model = albef_model_for_vqa(config, pretrained=True)
|
192 |
+
model = model.to(device)
|
193 |
+
if is_dist_avail_and_initialized():
|
194 |
+
model = torch.nn.parallel.DistributedDataParallel(
|
195 |
+
model, device_ids=[config["gpu"]]
|
196 |
+
)
|
197 |
+
|
198 |
+
train(model, datamodule, config["training_args"], device)
|
199 |
+
result = evaluation(model, datamodule, config["eval_args"], device)
|
200 |
+
save_result(result, config["output_root"], "vqa_output")
|
201 |
+
|
202 |
+
|
203 |
+
if __name__ == "__main__":
|
204 |
+
main()
|
images/COCO_val2014_000000026348.jpg
ADDED
images/COCO_val2014_000000057222.jpg
ADDED
images/COCO_val2014_000000111207.jpg
ADDED
images/COCO_val2014_000000159269.jpg
ADDED
images/COCO_val2014_000000184359.jpg
ADDED
images/COCO_val2014_000000407072.jpg
ADDED
images/COCO_val2014_000000473994.jpg
ADDED
images/COCO_val2014_000000552075.jpg
ADDED
model.py
ADDED
@@ -0,0 +1,666 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the BSD-style license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
import copy
|
8 |
+
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
9 |
+
|
10 |
+
import torch
|
11 |
+
import torch.nn.functional as F
|
12 |
+
from torch import nn, Tensor
|
13 |
+
from torchmultimodal.models.albef.image_encoder import ALBEFVisionEncoder
|
14 |
+
from torchmultimodal.models.albef.model import ALBEFModel, ALBEFModelWithSimilarity
|
15 |
+
from torchmultimodal.models.albef.multimodal_encoder import ALBEFMultimodalEncoder
|
16 |
+
from torchmultimodal.modules.encoders.bert_text_encoder import bert_text_encoder
|
17 |
+
from torchmultimodal.modules.layers.text_embedding import BERTTextEmbeddings
|
18 |
+
from torchmultimodal.modules.losses.albef import (
|
19 |
+
CausalLanguageModelingLoss,
|
20 |
+
ImageTextContrastiveLoss,
|
21 |
+
)
|
22 |
+
from torchmultimodal.utils.attention import get_causal_attention_mask
|
23 |
+
from torchmultimodal.utils.common import momentum_update, remove_grad
|
24 |
+
|
25 |
+
|
26 |
+
_ALBEF_PRETRAINED_URLS = {
|
27 |
+
"vqa": "https://download.pytorch.org/models/multimodal/albef/pretrained_vqa_checkpoint.pt",
|
28 |
+
"retrieval": "https://download.pytorch.org/models/multimodal/albef/pretrained_retrieval_checkpoint.pt",
|
29 |
+
}
|
30 |
+
|
31 |
+
|
32 |
+
class PredictionHead(nn.Module):
|
33 |
+
"""
|
34 |
+
Predict the following token autoregressively.
|
35 |
+
|
36 |
+
Args:
|
37 |
+
vocab_size (int): The number of different tokens the prediction_head can predict.
|
38 |
+
hidden_size (int): The hidden size of the prediction_head.
|
39 |
+
layer_norm_eps (float): The epsilon used by the prediction_head normalization layer.
|
40 |
+
transform_act_fn (Callable[[Tensor], Tensor]): The activation function in the prediction_head.
|
41 |
+
|
42 |
+
Inputs:
|
43 |
+
hidden_states (Tensor): The hidden states of preceding tokens.
|
44 |
+
|
45 |
+
Returns:
|
46 |
+
Tensor: Prediction scores for the following token.
|
47 |
+
"""
|
48 |
+
|
49 |
+
def __init__(
|
50 |
+
self,
|
51 |
+
vocab_size: int = 30522,
|
52 |
+
hidden_size: int = 768,
|
53 |
+
layer_norm_eps: float = 1e-12,
|
54 |
+
transform_act_fn: Callable[[Tensor], Tensor] = nn.functional.gelu,
|
55 |
+
) -> None:
|
56 |
+
super().__init__()
|
57 |
+
self.dense = nn.Linear(hidden_size, hidden_size)
|
58 |
+
self.transform_act_fn = transform_act_fn
|
59 |
+
self.layer_norm = nn.LayerNorm(hidden_size, eps=layer_norm_eps)
|
60 |
+
self.decoder = nn.Linear(hidden_size, vocab_size)
|
61 |
+
|
62 |
+
def forward(self, hidden_states: Tensor) -> Tensor:
|
63 |
+
hidden_states = self.dense(hidden_states)
|
64 |
+
hidden_states = self.transform_act_fn(hidden_states)
|
65 |
+
hidden_states = self.layer_norm(hidden_states)
|
66 |
+
hidden_states = self.decoder(hidden_states)
|
67 |
+
return hidden_states
|
68 |
+
|
69 |
+
|
70 |
+
class ALBEFDecoder(nn.Module):
|
71 |
+
"""
|
72 |
+
Generate the prediction scores for answers from image and question hidden states.
|
73 |
+
|
74 |
+
Args:
|
75 |
+
text_embeddings (ALBEFTextEmbeddings): Instantiated ALBEFTextEmbeddings.
|
76 |
+
multimodal_encoder (ALBEFMultimodalEncoder): Instantiated ALBEFMultimodalEncoder.
|
77 |
+
prediction_head (PredictionHead): Instantiated PredictionHead.
|
78 |
+
|
79 |
+
Inputs:
|
80 |
+
input_ids (Tensor of shape (batch_size, seq_len)):
|
81 |
+
Input ids for input text tokens.
|
82 |
+
attention_mask (Tensor of shape (batch_size, seq_len)):
|
83 |
+
Input attention mask to avoid performing attention on padding token indices.
|
84 |
+
encoder_hidden_states (Tensor of shape (batch_size, encoder_seq_len, hidden_size)):
|
85 |
+
The encoder hidden states.
|
86 |
+
encoder_attention_mask (Tensor of shape (batch_size, encoder_seq_len)):
|
87 |
+
The attention mask for encoder hidden states.
|
88 |
+
|
89 |
+
Returns:
|
90 |
+
Tensor: Prediction scores for answers.
|
91 |
+
"""
|
92 |
+
|
93 |
+
def __init__(
|
94 |
+
self,
|
95 |
+
text_embeddings: BERTTextEmbeddings,
|
96 |
+
multimodal_encoder: ALBEFMultimodalEncoder,
|
97 |
+
prediction_head: PredictionHead,
|
98 |
+
) -> None:
|
99 |
+
super().__init__()
|
100 |
+
self.text_embeddings = text_embeddings
|
101 |
+
self.multimodal_encoder = multimodal_encoder
|
102 |
+
self.prediction_head = prediction_head
|
103 |
+
|
104 |
+
def get_extended_attention_mask_for_decoder(self, attention_mask: Tensor) -> Tensor:
|
105 |
+
"""
|
106 |
+
Apply a causal mask in addition to the padding mask and make the mask broadcastable,
|
107 |
+
such that future and masked tokens are ignored.
|
108 |
+
|
109 |
+
Args:
|
110 |
+
attention_mask (Tensor):
|
111 |
+
Padding mask with ones indicating tokens to attend to, zeros for tokens to ignore.
|
112 |
+
|
113 |
+
Returns:
|
114 |
+
extended_attention_mask (Tensor):
|
115 |
+
The broadcastable attention mask, with the same dtype as ``attention_mask.dtype``.
|
116 |
+
"""
|
117 |
+
device = attention_mask.device
|
118 |
+
batch_size, seq_length = attention_mask.shape
|
119 |
+
causal_mask = get_causal_attention_mask(seq_length).to(device)
|
120 |
+
causal_mask = causal_mask.repeat(batch_size, 1).view(
|
121 |
+
batch_size, seq_length, seq_length
|
122 |
+
)
|
123 |
+
extended_attention_mask = (
|
124 |
+
causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
|
125 |
+
)
|
126 |
+
extended_attention_mask = extended_attention_mask.to(dtype=attention_mask.dtype)
|
127 |
+
return extended_attention_mask
|
128 |
+
|
129 |
+
def forward(
|
130 |
+
self,
|
131 |
+
input_ids: Tensor,
|
132 |
+
attention_mask: Tensor,
|
133 |
+
encoder_hidden_states: Tensor,
|
134 |
+
encoder_attention_mask: Tensor,
|
135 |
+
) -> Tensor:
|
136 |
+
hidden_states = self.text_embeddings(input_ids)
|
137 |
+
attention_mask = self.get_extended_attention_mask_for_decoder(attention_mask)
|
138 |
+
decoder_output = self.multimodal_encoder(
|
139 |
+
hidden_states=hidden_states,
|
140 |
+
attention_mask=attention_mask,
|
141 |
+
encoder_hidden_states=encoder_hidden_states,
|
142 |
+
encoder_attention_mask=encoder_attention_mask,
|
143 |
+
)
|
144 |
+
prediction_scores = self.prediction_head(decoder_output)
|
145 |
+
return prediction_scores
|
146 |
+
|
147 |
+
|
148 |
+
class ALBEFModelForVQA(nn.Module):
|
149 |
+
"""
|
150 |
+
ALBEF Model for VQA finetuning and inference.
|
151 |
+
|
152 |
+
Args:
|
153 |
+
model (ALBEFModel): Instantiated ALBEFModel.
|
154 |
+
answer_decoder (ALBEFDecoder): Instantiated ALBEFDecoder.
|
155 |
+
loss (CausalLanguageModelingLoss): Instantiated CausalLanguageModelingLoss.
|
156 |
+
|
157 |
+
Inputs:
|
158 |
+
image (Tensor of shape (B, C, H, W)): Image features.
|
159 |
+
question (Tensor of shape (B, L)): Question text features.
|
160 |
+
question_atts (Tensor of shape (B, L)): Question attention mask.
|
161 |
+
answers (Tensor of shape (N, M)): Answer text features.
|
162 |
+
answers_atts (Tensor of shape (N, M)): Answer attention mask.
|
163 |
+
ans_weights (Optional[Tensor] of shape (N)): Weights for each answer.
|
164 |
+
Required if is_train is True.
|
165 |
+
ans_lengths (Optional[List[int]] of length B): Number of answers for each question.
|
166 |
+
ans_lengths should sum to N.
|
167 |
+
Required if is_train is True.
|
168 |
+
alpha (Optional[float]): The interpolation value between clm_loss and loss_distill.
|
169 |
+
Required if is_train is True.
|
170 |
+
k (Optional[int]): The number of answers to return for inference.
|
171 |
+
Required if is_train is False.
|
172 |
+
is_train (Optional[bool]): Whether the model is in training.
|
173 |
+
|
174 |
+
Returns:
|
175 |
+
is_train is True:
|
176 |
+
Tensor: The masked language modeling loss for input.
|
177 |
+
is_train is False:
|
178 |
+
Tuple[Tensor, Tensor]: The ids and probabilities for the top k predicted answers.
|
179 |
+
"""
|
180 |
+
|
181 |
+
def __init__(
|
182 |
+
self,
|
183 |
+
model: ALBEFModel,
|
184 |
+
answer_decoder: ALBEFDecoder,
|
185 |
+
loss: CausalLanguageModelingLoss,
|
186 |
+
) -> None:
|
187 |
+
super().__init__()
|
188 |
+
self.model = model
|
189 |
+
self.answer_decoder = answer_decoder
|
190 |
+
self.loss = loss
|
191 |
+
self.answer_decoder_m = copy.deepcopy(self.answer_decoder)
|
192 |
+
remove_grad(
|
193 |
+
self.answer_decoder_m
|
194 |
+
) # remove gradient for the momentum decoder model
|
195 |
+
|
196 |
+
def _train_forward(
|
197 |
+
self,
|
198 |
+
image: Tensor,
|
199 |
+
question: Tensor,
|
200 |
+
question_atts: Tensor,
|
201 |
+
answers: Tensor,
|
202 |
+
answers_atts: Tensor,
|
203 |
+
ans_weights: Tensor,
|
204 |
+
ans_lengths: List[int],
|
205 |
+
alpha: float,
|
206 |
+
) -> Tensor:
|
207 |
+
"""
|
208 |
+
Forward step for training. Encode the inputs with the ALBEFModel.
|
209 |
+
Generate pseudo-targets using answer_decoder_m (momentum decoder model).
|
210 |
+
Generate answer predictions using answer_decoder.
|
211 |
+
Compute masked language modeling loss of the predictions using answers as labels,
|
212 |
+
pseudo-targets as soft-labels, and alpha as their interpolation value.
|
213 |
+
|
214 |
+
Inputs:
|
215 |
+
image (Tensor of shape (B, C, H, W)): Image features.
|
216 |
+
question (Tensor of shape (B, L)): Question text features.
|
217 |
+
question_atts (Tensor of shape (B, L)): Question attention mask.
|
218 |
+
answers (Tensor of shape (N, M)): Answer text features.
|
219 |
+
answers_atts (Tensor of shape (N, M)): Answer attention mask.
|
220 |
+
ans_weights (Tensor of shape (N)): Weights for each answer.
|
221 |
+
ans_lengths (List[int] of length B): Number of answers for each question.
|
222 |
+
ans_lengths should sum to N.
|
223 |
+
alpha (float): The interpolation value between clm_loss and loss_distill.
|
224 |
+
|
225 |
+
Returns:
|
226 |
+
Tensor: The masked language modeling loss for input.
|
227 |
+
"""
|
228 |
+
# get image-question embeddings from the ALBEFModel and format it to match the ans_lengths
|
229 |
+
encoder_outputs = self.model(image, question, question_atts)
|
230 |
+
(
|
231 |
+
encoder_hidden_states,
|
232 |
+
encoder_hidden_states_m,
|
233 |
+
encoder_attention_mask,
|
234 |
+
) = self._encoder_hidden_states(
|
235 |
+
encoder_outputs.multimodal_embeddings,
|
236 |
+
encoder_outputs.multimodal_embeddings_m,
|
237 |
+
question_atts,
|
238 |
+
ans_lengths,
|
239 |
+
)
|
240 |
+
|
241 |
+
# use the momentum model to generate pseudo-targets
|
242 |
+
with torch.no_grad():
|
243 |
+
momentum_update(
|
244 |
+
self.answer_decoder, self.answer_decoder_m, self.model.momentum
|
245 |
+
)
|
246 |
+
prediction_scores_m = self.answer_decoder_m(
|
247 |
+
input_ids=answers,
|
248 |
+
attention_mask=answers_atts,
|
249 |
+
encoder_hidden_states=encoder_hidden_states_m,
|
250 |
+
encoder_attention_mask=encoder_attention_mask,
|
251 |
+
)
|
252 |
+
|
253 |
+
# generate answer predictions
|
254 |
+
prediction_scores = self.answer_decoder(
|
255 |
+
input_ids=answers,
|
256 |
+
attention_mask=answers_atts,
|
257 |
+
encoder_hidden_states=encoder_hidden_states,
|
258 |
+
encoder_attention_mask=encoder_attention_mask,
|
259 |
+
)
|
260 |
+
|
261 |
+
# compute masked language modeling loss from the prediction scores
|
262 |
+
labels = answers.masked_fill(answers == 0, self.loss.mask_token_id)
|
263 |
+
loss = self.loss(labels, prediction_scores, prediction_scores_m, alpha)
|
264 |
+
loss = ans_weights * loss
|
265 |
+
loss = loss.sum() / image.size(0)
|
266 |
+
return loss
|
267 |
+
|
268 |
+
def _eval_forward(
|
269 |
+
self,
|
270 |
+
image: Tensor,
|
271 |
+
question: Tensor,
|
272 |
+
question_atts: Tensor,
|
273 |
+
answers: Tensor,
|
274 |
+
answer_atts: Tensor,
|
275 |
+
k: int = 128,
|
276 |
+
) -> Tuple[Tensor, Tensor]:
|
277 |
+
"""
|
278 |
+
Forward step for evaluation. Encode the inputs with the ALBEFModel.
|
279 |
+
Generate answer autoregressively using the decoder, starting with the [CLS] token.
|
280 |
+
Compute the answer ids and their perspective probabilities of the top k predictions.
|
281 |
+
|
282 |
+
Inputs:
|
283 |
+
image (Tensor of shape (B, C, H, W)): Image features.
|
284 |
+
question (Tensor of shape (B, L)): Question text features.
|
285 |
+
question_atts (Tensor of shape (B, L)): Question attention mask.
|
286 |
+
answers (Tensor of shape (N, M)): Answer text features.
|
287 |
+
answer_atts (Tensor of shape (N, M)): Answer attention mask.
|
288 |
+
k (int): The number of answers to return for inference.
|
289 |
+
|
290 |
+
Returns:
|
291 |
+
Tuple[Tensor, Tensor]: The ids and probabilities for the top k predicted answers.
|
292 |
+
"""
|
293 |
+
# get multimodal embeddings from the ALBEFModel and
|
294 |
+
# feed it to the decoder as cross attention
|
295 |
+
encoder_outputs = self.model(image, question, question_atts)
|
296 |
+
|
297 |
+
# use cls token as the decoder's initial input token
|
298 |
+
num_ques = question.size(0)
|
299 |
+
start_ids = answers[0, 0].repeat(num_ques, 1)
|
300 |
+
atts = torch.ones(start_ids.shape).to(image.device)
|
301 |
+
|
302 |
+
# auto-regressively generates the answer
|
303 |
+
prediction_scores = self.answer_decoder(
|
304 |
+
input_ids=start_ids,
|
305 |
+
attention_mask=atts,
|
306 |
+
encoder_hidden_states=encoder_outputs.multimodal_embeddings,
|
307 |
+
encoder_attention_mask=question_atts,
|
308 |
+
)
|
309 |
+
|
310 |
+
logits = prediction_scores[:, 0, :]
|
311 |
+
answer_first_token = answers[:, 1]
|
312 |
+
prob_first_token = F.softmax(logits, dim=1).index_select(
|
313 |
+
dim=1, index=answer_first_token
|
314 |
+
)
|
315 |
+
topk_probs, topk_ids = prob_first_token.topk(k, dim=1)
|
316 |
+
|
317 |
+
input_ids = []
|
318 |
+
input_atts = []
|
319 |
+
for topk_id in topk_ids:
|
320 |
+
input_ids.append(answers.index_select(dim=0, index=topk_id))
|
321 |
+
input_atts.append(answer_atts.index_select(dim=0, index=topk_id))
|
322 |
+
input_ids = torch.cat(input_ids)
|
323 |
+
input_atts = torch.cat(input_atts)
|
324 |
+
targets_ids = input_ids.masked_fill(input_ids == 0, self.loss.mask_token_id)
|
325 |
+
|
326 |
+
question_states = encoder_outputs.multimodal_embeddings.repeat_interleave(
|
327 |
+
k, dim=0
|
328 |
+
)
|
329 |
+
question_atts = question_atts.repeat_interleave(k, dim=0)
|
330 |
+
|
331 |
+
prediction_scores = self.answer_decoder(
|
332 |
+
input_ids=input_ids,
|
333 |
+
attention_mask=input_atts,
|
334 |
+
encoder_hidden_states=question_states,
|
335 |
+
encoder_attention_mask=question_atts,
|
336 |
+
)
|
337 |
+
|
338 |
+
answer_loss = self.loss(targets_ids, prediction_scores)
|
339 |
+
answer_loss = answer_loss.view(input_ids.size(0), -1)
|
340 |
+
|
341 |
+
# topk_prob: first token probability
|
342 |
+
topk_probs = topk_probs.view(-1, 1)
|
343 |
+
log_probs = torch.cat([topk_probs.log(), -answer_loss], dim=1)
|
344 |
+
|
345 |
+
# re-calculate log probabilities for the answer sequences using chain rule
|
346 |
+
log_probs_sum = log_probs.sum(1)
|
347 |
+
log_probs_sum = log_probs_sum.view(num_ques, k)
|
348 |
+
|
349 |
+
topk_probs = F.softmax(log_probs_sum, dim=-1)
|
350 |
+
|
351 |
+
# get top-k after re-ranking
|
352 |
+
topk_probs, rerank_id = topk_probs.topk(k, dim=1)
|
353 |
+
topk_ids = torch.gather(topk_ids, 1, rerank_id)
|
354 |
+
|
355 |
+
return topk_ids, topk_probs
|
356 |
+
|
357 |
+
def _encoder_hidden_states(
|
358 |
+
self,
|
359 |
+
multimodal_embeds: Tensor,
|
360 |
+
multimodal_embeds_m: Tensor,
|
361 |
+
question_atts: Tensor,
|
362 |
+
ans_lengths: List[int],
|
363 |
+
) -> Tuple[Tensor, Tensor, Tensor]:
|
364 |
+
"""
|
365 |
+
Repeat each image-question input, repeat its embedding and mask to match the number of answers it has.
|
366 |
+
|
367 |
+
Args:
|
368 |
+
multimodal_embeds (Tensor): Image-question embeddings.
|
369 |
+
multimodal_embeds_m (Tensor): Image-question embeddings from the momentum model.
|
370 |
+
question_atts (Tensor): Question attention mask.
|
371 |
+
ans_lengths (List[int]): The number of answers each image-question input has.
|
372 |
+
|
373 |
+
Returns:
|
374 |
+
encoder_hidden_states (Tensor): Image-question embeddings after the repetition.
|
375 |
+
encoder_hidden_states_m (Tensor): Image-question embeddings from the momentum model after the repetition.
|
376 |
+
encoder_attention_mask (Tensor): Question attention mask after the repetition.
|
377 |
+
"""
|
378 |
+
encoder_hidden_states = []
|
379 |
+
encoder_attention_mask = []
|
380 |
+
for b, n in enumerate(ans_lengths):
|
381 |
+
encoder_hidden_states += [multimodal_embeds[b]] * n
|
382 |
+
encoder_attention_mask += [question_atts[b]] * n
|
383 |
+
encoder_hidden_states = torch.stack(encoder_hidden_states)
|
384 |
+
encoder_attention_mask = torch.stack(encoder_attention_mask)
|
385 |
+
|
386 |
+
with torch.no_grad():
|
387 |
+
encoder_hidden_states_m = []
|
388 |
+
for b, n in enumerate(ans_lengths):
|
389 |
+
encoder_hidden_states_m += [multimodal_embeds_m[b]] * n
|
390 |
+
encoder_hidden_states_m = torch.stack(encoder_hidden_states_m)
|
391 |
+
|
392 |
+
return encoder_hidden_states, encoder_hidden_states_m, encoder_attention_mask
|
393 |
+
|
394 |
+
def forward(
|
395 |
+
self,
|
396 |
+
image: Tensor,
|
397 |
+
question: Tensor,
|
398 |
+
question_atts: Tensor,
|
399 |
+
answers: Tensor,
|
400 |
+
answers_atts: Tensor,
|
401 |
+
ans_weights: Optional[Tensor] = None,
|
402 |
+
ans_lengths: Optional[List[int]] = None,
|
403 |
+
alpha: Optional[float] = 0.0,
|
404 |
+
k: Optional[int] = 128,
|
405 |
+
is_train: Optional[bool] = True,
|
406 |
+
) -> Union[Tensor, Tuple[Tensor, Tensor]]:
|
407 |
+
if is_train:
|
408 |
+
return self._train_forward(
|
409 |
+
image,
|
410 |
+
question,
|
411 |
+
question_atts,
|
412 |
+
answers,
|
413 |
+
answers_atts,
|
414 |
+
ans_weights,
|
415 |
+
ans_lengths,
|
416 |
+
alpha,
|
417 |
+
)
|
418 |
+
else:
|
419 |
+
return self._eval_forward(
|
420 |
+
image,
|
421 |
+
question,
|
422 |
+
question_atts,
|
423 |
+
answers,
|
424 |
+
answers_atts,
|
425 |
+
k,
|
426 |
+
)
|
427 |
+
|
428 |
+
|
429 |
+
class ALBEFModelForRetrieval(nn.Module):
|
430 |
+
"""
|
431 |
+
ALBEF Model for Retrieval finetuning and inference.
|
432 |
+
In training mode, the forward step computes image-text contrastive loss and
|
433 |
+
image-text matching loss.
|
434 |
+
In evaluation mode, the forward step takes 3 types of input:
|
435 |
+
image: encode image input, project and normalize the embeddings.
|
436 |
+
text: encode text input, project and normalize the embeddings.
|
437 |
+
multimodal: create multimodal embeddings from image and text
|
438 |
+
embeddings, and compute image-text matching scores.
|
439 |
+
|
440 |
+
Args:
|
441 |
+
model_with_similarity (ALBEFModelWithSimilarity): Instantiated ALBEFModelWithSimilarity.
|
442 |
+
itc_loss (ImageTextContrastiveLoss): Instantiated ImageTextContrastiveLoss.
|
443 |
+
hidden_size (int): Dimensionality of encoder outputs.
|
444 |
+
|
445 |
+
Inputs:
|
446 |
+
image (Optional[Tensor] of shape (B, C, H, W)): Image features.
|
447 |
+
Required if is_train is True.
|
448 |
+
Required if input_type is "image" or "multimodal".
|
449 |
+
text (Optional[Tensor] of shape (B, L)): Text features.
|
450 |
+
Required if is_train is True.
|
451 |
+
Required if input_type is "text" or "multimodal".
|
452 |
+
text_atts (Tensor of shape (B, L)): Text attention mask.
|
453 |
+
Required if is_train is True.
|
454 |
+
Required if input_type is "text" or "multimodal".
|
455 |
+
idx (Tensor of shape (B)): Identifier for each image sample.
|
456 |
+
Required if is_train is True.
|
457 |
+
alpha (Optional[float]): The interpolation value between clm_loss and loss_distill.
|
458 |
+
Default is 0.
|
459 |
+
input_type (Optional[str]): "image", "text", or "multimodal" indicating the encoding type.
|
460 |
+
Required if is_train is False.
|
461 |
+
is_train (Optional[bool]): Whether the model is in training.
|
462 |
+
Default is True.
|
463 |
+
|
464 |
+
Returns:
|
465 |
+
is_train is True:
|
466 |
+
Tensor: The sum of itc loss and itm loss.
|
467 |
+
is_train is False:
|
468 |
+
input_type is "image":
|
469 |
+
Tuple[Tensor, Tensor]: Image embeddings and projected image features.
|
470 |
+
input_type is "text":
|
471 |
+
Tuple[Tensor, Tensor]: Text embeddings and projected text features.
|
472 |
+
input_type is "multimodal"
|
473 |
+
Tensor: Scores for the retrieval task.
|
474 |
+
"""
|
475 |
+
|
476 |
+
def __init__(
|
477 |
+
self,
|
478 |
+
model_with_similarity: ALBEFModelWithSimilarity,
|
479 |
+
itc_loss: ImageTextContrastiveLoss,
|
480 |
+
hidden_size: int,
|
481 |
+
) -> None:
|
482 |
+
super().__init__()
|
483 |
+
self.model_with_similarity = model_with_similarity
|
484 |
+
self.itc_loss = itc_loss
|
485 |
+
self.itm_head = nn.Linear(hidden_size, 2)
|
486 |
+
|
487 |
+
def _train_forward(
|
488 |
+
self,
|
489 |
+
image: Tensor,
|
490 |
+
text: Tensor,
|
491 |
+
text_atts: Tensor,
|
492 |
+
idx: Tensor,
|
493 |
+
alpha: float,
|
494 |
+
) -> Tensor:
|
495 |
+
encoder_output = self.model_with_similarity(image, text, text_atts, idx)
|
496 |
+
|
497 |
+
# compute image-text contrastive loss
|
498 |
+
similarity_outputs = encoder_output.similarity
|
499 |
+
similarity_targets = encoder_output.sim_targets
|
500 |
+
itc_loss = self.itc_loss(
|
501 |
+
similarity_outputs.sim_i2t,
|
502 |
+
similarity_outputs.sim_t2i,
|
503 |
+
similarity_outputs.sim_i2t_m,
|
504 |
+
similarity_outputs.sim_t2i_m,
|
505 |
+
similarity_targets,
|
506 |
+
alpha,
|
507 |
+
)
|
508 |
+
|
509 |
+
# compute image-text matching loss
|
510 |
+
pos_embeddings = encoder_output.multimodal_embeddings[:, 0, :]
|
511 |
+
neg_embeddings = encoder_output.multimodal_embeddings_neg[:, 0, :]
|
512 |
+
vl_embeddings = torch.cat([pos_embeddings, neg_embeddings], dim=0)
|
513 |
+
vl_output = self.itm_head(vl_embeddings)
|
514 |
+
itm_labels = torch.cat(
|
515 |
+
[
|
516 |
+
torch.ones(pos_embeddings.size(0), dtype=torch.long),
|
517 |
+
torch.zeros(neg_embeddings.size(0), dtype=torch.long),
|
518 |
+
],
|
519 |
+
dim=0,
|
520 |
+
).to(vl_embeddings.device)
|
521 |
+
itm_loss = F.cross_entropy(vl_output, itm_labels)
|
522 |
+
|
523 |
+
loss = itc_loss + itm_loss
|
524 |
+
return loss
|
525 |
+
|
526 |
+
def _encode_image(
|
527 |
+
self,
|
528 |
+
image: Tensor,
|
529 |
+
) -> Tuple[Tensor, Tensor]:
|
530 |
+
image_embed = self.model_with_similarity.albef_model.vision_encoder(image)
|
531 |
+
image_feat = F.normalize(
|
532 |
+
self.model_with_similarity.vision_proj(image_embed[:, 0, :]), dim=-1
|
533 |
+
)
|
534 |
+
return image_embed, image_feat
|
535 |
+
|
536 |
+
def _encode_text(
|
537 |
+
self,
|
538 |
+
text: Tensor,
|
539 |
+
text_atts: Tensor,
|
540 |
+
) -> Tuple[Tensor, Tensor]:
|
541 |
+
text_embed = self.model_with_similarity.albef_model.text_encoder(
|
542 |
+
text, text_atts
|
543 |
+
).last_hidden_state
|
544 |
+
text_feat = F.normalize(
|
545 |
+
self.model_with_similarity.text_proj(text_embed[:, 0, :]), dim=-1
|
546 |
+
)
|
547 |
+
return text_embed, text_feat
|
548 |
+
|
549 |
+
def _image_text_matching_score(
|
550 |
+
self,
|
551 |
+
image: Tensor,
|
552 |
+
text: Tensor,
|
553 |
+
text_atts: Tensor,
|
554 |
+
) -> Tensor:
|
555 |
+
multimodal_embeds = self.model_with_similarity.albef_model.multimodal_encoder(
|
556 |
+
text,
|
557 |
+
text_atts,
|
558 |
+
image,
|
559 |
+
)
|
560 |
+
score = self.itm_head(multimodal_embeds[:, 0, :])[:, 1]
|
561 |
+
return score
|
562 |
+
|
563 |
+
def _eval_forward(
|
564 |
+
self,
|
565 |
+
input_type: str,
|
566 |
+
image: Optional[Tensor],
|
567 |
+
text: Optional[Tensor],
|
568 |
+
text_atts: Optional[Tensor],
|
569 |
+
) -> Union[Tensor, Tuple[Tensor, Tensor]]:
|
570 |
+
if input_type == "image":
|
571 |
+
assert image is not None, "image input tensor cannot be None"
|
572 |
+
return self._encode_image(image)
|
573 |
+
|
574 |
+
elif input_type == "text":
|
575 |
+
assert (
|
576 |
+
text is not None and text_atts is not None
|
577 |
+
), "text and text attention mask cannot be None"
|
578 |
+
return self._encode_text(text, text_atts)
|
579 |
+
|
580 |
+
elif input_type == "multimodal":
|
581 |
+
assert (
|
582 |
+
image is not None and text is not None and text_atts is not None
|
583 |
+
), "image embeddings, text embeddings, and text attention mask cannot be None"
|
584 |
+
return self._image_text_matching_score(image, text, text_atts)
|
585 |
+
|
586 |
+
else:
|
587 |
+
raise ValueError("input_type must be image, text, or multimodal")
|
588 |
+
|
589 |
+
def forward(
|
590 |
+
self,
|
591 |
+
image: Optional[Tensor] = None,
|
592 |
+
text: Optional[Tensor] = None,
|
593 |
+
text_atts: Optional[Tensor] = None,
|
594 |
+
idx: Optional[Tensor] = None,
|
595 |
+
alpha: Optional[Tensor] = 0.0,
|
596 |
+
input_type: Optional[str] = None,
|
597 |
+
is_train: Optional[bool] = True,
|
598 |
+
) -> Union[Tensor, Tuple[Tensor, Tensor]]:
|
599 |
+
if is_train:
|
600 |
+
return self._train_forward(
|
601 |
+
image,
|
602 |
+
text,
|
603 |
+
text_atts,
|
604 |
+
idx,
|
605 |
+
alpha,
|
606 |
+
)
|
607 |
+
else:
|
608 |
+
return self._eval_forward(
|
609 |
+
input_type,
|
610 |
+
image,
|
611 |
+
text,
|
612 |
+
text_atts,
|
613 |
+
)
|
614 |
+
|
615 |
+
|
616 |
+
def albef_model_for_vqa(
|
617 |
+
config: Dict[str, Any], pretrained: bool = False
|
618 |
+
) -> ALBEFModelForVQA:
|
619 |
+
vision_encoder = ALBEFVisionEncoder(**config["vision_encoder_args"])
|
620 |
+
text_encoder = bert_text_encoder(**config["text_encoder_args"])
|
621 |
+
question_multimodal_encoder = ALBEFMultimodalEncoder(
|
622 |
+
**config["multimodal_encoder_args"]
|
623 |
+
)
|
624 |
+
text_embeddings = BERTTextEmbeddings(**config["text_embeddings_args"])
|
625 |
+
answer_multimodal_encoder = ALBEFMultimodalEncoder(
|
626 |
+
**config["multimodal_encoder_args"]
|
627 |
+
)
|
628 |
+
prediction_head = PredictionHead(**config["prediction_head_args"])
|
629 |
+
albef_model = ALBEFModel(vision_encoder, text_encoder, question_multimodal_encoder)
|
630 |
+
decoder = ALBEFDecoder(text_embeddings, answer_multimodal_encoder, prediction_head)
|
631 |
+
loss = CausalLanguageModelingLoss()
|
632 |
+
model = ALBEFModelForVQA(albef_model, decoder, loss)
|
633 |
+
|
634 |
+
if pretrained:
|
635 |
+
checkpoint = torch.hub.load_state_dict_from_url(
|
636 |
+
_ALBEF_PRETRAINED_URLS["vqa"], map_location="cpu"
|
637 |
+
)
|
638 |
+
model.load_state_dict(checkpoint)
|
639 |
+
return model
|
640 |
+
|
641 |
+
|
642 |
+
def albef_model_for_retrieval(
|
643 |
+
config: Dict[str, Any], pretrained: bool = False
|
644 |
+
) -> ALBEFModelForRetrieval:
|
645 |
+
vision_encoder = ALBEFVisionEncoder(**config["vision_encoder_args"])
|
646 |
+
text_encoder = bert_text_encoder(**config["text_encoder_args"])
|
647 |
+
multimodal_encoder = ALBEFMultimodalEncoder(**config["multimodal_encoder_args"])
|
648 |
+
vision_proj = nn.Linear(**config["projection_args"])
|
649 |
+
text_proj = nn.Linear(**config["projection_args"])
|
650 |
+
|
651 |
+
albef_model = ALBEFModel(vision_encoder, text_encoder, multimodal_encoder)
|
652 |
+
albef_model_with_sim = ALBEFModelWithSimilarity(
|
653 |
+
albef_model, vision_proj, text_proj, **config["similarity_args"]
|
654 |
+
)
|
655 |
+
itc_loss = ImageTextContrastiveLoss()
|
656 |
+
|
657 |
+
model = ALBEFModelForRetrieval(
|
658 |
+
albef_model_with_sim, itc_loss, config["hidden_size"]
|
659 |
+
)
|
660 |
+
|
661 |
+
if pretrained:
|
662 |
+
checkpoint = torch.hub.load_state_dict_from_url(
|
663 |
+
_ALBEF_PRETRAINED_URLS["retrieval"], map_location="cpu"
|
664 |
+
)
|
665 |
+
model.load_state_dict(checkpoint)
|
666 |
+
return model
|
requirements.txt
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
opencv-python==4.6.0.66
|
2 |
+
pytorch-lightning==1.6.0
|
3 |
+
Pillow==9.0.1
|
4 |
+
ruamel_yaml==0.17.21
|
5 |
+
transformers==4.24.0
|
utils.py
ADDED
@@ -0,0 +1,127 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the BSD-style license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
8 |
+
# All rights reserved.
|
9 |
+
#
|
10 |
+
# This source code is licensed under the BSD-style license found in the
|
11 |
+
# LICENSE file in the root directory of this source tree.
|
12 |
+
|
13 |
+
import json
|
14 |
+
import os
|
15 |
+
|
16 |
+
import torch
|
17 |
+
import torch.distributed as dist
|
18 |
+
from torch import nn
|
19 |
+
|
20 |
+
|
21 |
+
def setup_for_distributed(is_master):
|
22 |
+
"""
|
23 |
+
This function disables printing when not in master process
|
24 |
+
"""
|
25 |
+
import builtins as __builtin__
|
26 |
+
|
27 |
+
builtin_print = __builtin__.print
|
28 |
+
|
29 |
+
def print(*args, **kwargs):
|
30 |
+
force = kwargs.pop("force", False)
|
31 |
+
if is_master or force:
|
32 |
+
builtin_print(*args, **kwargs)
|
33 |
+
|
34 |
+
__builtin__.print = print
|
35 |
+
|
36 |
+
|
37 |
+
def is_dist_avail_and_initialized():
|
38 |
+
if not dist.is_available():
|
39 |
+
return False
|
40 |
+
if not dist.is_initialized():
|
41 |
+
return False
|
42 |
+
return True
|
43 |
+
|
44 |
+
|
45 |
+
def get_world_size():
|
46 |
+
if not is_dist_avail_and_initialized():
|
47 |
+
return 1
|
48 |
+
return dist.get_world_size()
|
49 |
+
|
50 |
+
|
51 |
+
def get_rank():
|
52 |
+
if not is_dist_avail_and_initialized():
|
53 |
+
return 0
|
54 |
+
return dist.get_rank()
|
55 |
+
|
56 |
+
|
57 |
+
def is_main_process():
|
58 |
+
return get_rank() == 0
|
59 |
+
|
60 |
+
|
61 |
+
def init_distributed_mode(args):
|
62 |
+
if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
|
63 |
+
args["rank"] = int(os.environ["RANK"])
|
64 |
+
args["world_size"] = int(os.environ["WORLD_SIZE"])
|
65 |
+
args["gpu"] = int(os.environ["LOCAL_RANK"])
|
66 |
+
elif "SLURM_PROCID" in os.environ:
|
67 |
+
args["rank"] = int(os.environ["SLURM_PROCID"])
|
68 |
+
args["gpu"] = args["rank"] % torch.cuda.device_count()
|
69 |
+
else:
|
70 |
+
print("Not using distributed mode")
|
71 |
+
args["distributed"] = False
|
72 |
+
return
|
73 |
+
|
74 |
+
args["distributed"] = True
|
75 |
+
|
76 |
+
torch.cuda.set_device(args["gpu"])
|
77 |
+
args["dist_backend"] = "nccl"
|
78 |
+
print(
|
79 |
+
"| distributed init (rank {}): {}".format(args["rank"], args["dist_url"]),
|
80 |
+
flush=True,
|
81 |
+
)
|
82 |
+
torch.distributed.init_process_group(
|
83 |
+
backend=args["dist_backend"],
|
84 |
+
init_method=args["dist_url"],
|
85 |
+
world_size=args["world_size"],
|
86 |
+
rank=args["rank"],
|
87 |
+
)
|
88 |
+
torch.distributed.barrier()
|
89 |
+
setup_for_distributed(args["rank"] == 0)
|
90 |
+
|
91 |
+
|
92 |
+
def save_result(result, directory, file_name):
|
93 |
+
rank_path = os.path.join(directory, "{}_rank_{}.json".format(file_name, get_rank()))
|
94 |
+
main_path = os.path.join(directory, "{}.json".format(file_name))
|
95 |
+
json.dump(result, open(rank_path, "w"))
|
96 |
+
|
97 |
+
if is_dist_avail_and_initialized():
|
98 |
+
dist.barrier()
|
99 |
+
|
100 |
+
if is_main_process():
|
101 |
+
result = []
|
102 |
+
for rank in range(get_world_size()):
|
103 |
+
rank_path = os.path.join(
|
104 |
+
directory, "{}_rank_{}.json".format(file_name, rank)
|
105 |
+
)
|
106 |
+
rank_res = json.load(open(rank_path, "r"))
|
107 |
+
result += rank_res
|
108 |
+
json.dump(result, open(main_path, "w"))
|
109 |
+
|
110 |
+
if is_dist_avail_and_initialized():
|
111 |
+
dist.barrier()
|
112 |
+
|
113 |
+
|
114 |
+
def add_weight_decay(model: nn.Module, weight_decay: float) -> None:
|
115 |
+
decay = []
|
116 |
+
no_decay = []
|
117 |
+
for name, param in model.named_parameters():
|
118 |
+
if not param.requires_grad:
|
119 |
+
continue # skip weight_decay for momentum models
|
120 |
+
if len(param.shape) == 1 or name.endswith(".bias"):
|
121 |
+
no_decay.append(param)
|
122 |
+
else:
|
123 |
+
decay.append(param)
|
124 |
+
return [
|
125 |
+
{"params": no_decay, "weight_decay": 0.0},
|
126 |
+
{"params": decay, "weight_decay": weight_decay},
|
127 |
+
]
|
vqa_data.json
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
[{"image": "images/COCO_val2014_000000184359.jpg", "question": "Is this a train station?", "answers": ["no", "no", "no", "no", "no", "no", "no", "no", "no", "no"]}, {"image": "images/COCO_val2014_000000407072.jpg", "question": "Was this photo taken at night?", "answers": ["yes", "yes", "yes", "yes", "yes", "yes", "yes", "yes", "yes", "yes"]}, {"image": "images/COCO_val2014_000000111207.jpg", "question": "How many photos in one?", "answers": ["2", "2", "2", "2", "2", "2", "2", "2", "2", "2"]}, {"image": "images/COCO_val2014_000000057222.jpg", "question": "How many bears are there?", "answers": ["2", "3", "3", "4", "2", "2", "3", "3", "2", "3"]}, {"image": "images/COCO_val2014_000000159269.jpg", "question": "What time of the day it is?", "answers": ["evening", "evening", "dusk", "sunset", "sunset", "dusk", "morning", "dusk", "evening", "4 pm"]}, {"image": "images/COCO_val2014_000000026348.jpg", "question": "What color is the refrigerator handle?", "answers": ["white", "white", "white", "white", "white", "white", "white", "white", "white", "white"]}, {"image": "images/COCO_val2014_000000473994.jpg", "question": "What does this animal eat?", "answers": ["meat", "dog food", "dog food", "dog food", "dog food", "dog food", "frisbee", "dog food", "frisbee", "dog food"]}, {"image": "images/COCO_val2014_000000552075.jpg", "question": "Who is wearing a hat?", "answers": ["no one", "woman", "no one", "nobody", "no one", "nobody", "no", "nobody", "nobody", "man"]}]
|