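"""Generate RAM tag captions for a directory of (degraded) images.

Loads the RAM tagging model together with the DAPE weights from the vendored
SeeSR repository, then writes one "filename: caption" line per image.
"""
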
import argparse
import os
import sys

from PIL import Image

# Add the thirdparty/SeeSR directory to the Python path so `ram` is importable
sys.path.append(os.path.abspath("./thirdparty/SeeSR"))

import torch
from torchvision import transforms
from ram.models.ram_lora import ram
from ram import inference_ram as inference

def load_ram_model(ram_model_path: str, dape_model_path: str):
    """
    Load the RAM tagging model from the given checkpoint paths.

    Args:
        ram_model_path (str): Path to the pretrained RAM checkpoint.
        dape_model_path (str): Path to the pretrained DAPE checkpoint.

    Returns:
        torch.nn.Module: Loaded RAM model in eval mode, on GPU if available.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # Load the RAM model with the DAPE weights as its pretrained condition
    tag_model = ram(
        pretrained=ram_model_path,
        pretrained_condition=dape_model_path,
        image_size=384,
        vit="swin_l",
    )
    tag_model.eval()
    return tag_model.to(device)

def generate_caption(image_path: str, tag_model) -> str:
    """
    Generate a tag caption for a degraded image using the RAM model.

    Args:
        image_path (str): Path to the degraded input image.
        tag_model (torch.nn.Module): Preloaded RAM model.

    Returns:
        str: Generated caption for the image.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Convert to a tensor, then resize and normalize with ImageNet statistics
    # to match the 384x384 input the RAM model expects
    tensor_transforms = transforms.Compose([
        transforms.ToTensor(),
    ])
    ram_transforms = transforms.Compose([
        transforms.Resize((384, 384)),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    # Load and preprocess the image
    image = Image.open(image_path).convert("RGB")
    image_tensor = tensor_transforms(image).unsqueeze(0).to(device)
    image_tensor = ram_transforms(image_tensor)

    # Run tagging without building an autograd graph; `inference` returns a
    # sequence whose first element is the tag string
    with torch.no_grad():
        caption = inference(image_tensor, tag_model)
    return caption[0]

def process_images_in_directory(input_dir: str, output_file: str, tag_model):
    """
    Process all images in a directory, generate captions using the RAM model,
    and save the captions to a file.

    Args:
        input_dir (str): Path to the directory containing input images.
        output_file (str): Path to the file where captions will be saved.
        tag_model (torch.nn.Module): Preloaded RAM model.
    """
    # Open the output file for writing captions
    with open(output_file, "w") as f:
        # Iterate through all files in the input directory
        for filename in os.listdir(input_dir):
            # Skip files that are not images
            if not filename.lower().endswith((".png", ".jpg", ".jpeg")):
                continue
            image_path = os.path.join(input_dir, filename)
            try:
                # Generate a caption and write one "filename: caption" line
                caption = generate_caption(image_path, tag_model)
                f.write(f"{filename}: {caption}\n")
                print(f"Processed {filename}: {caption}")
            except Exception as e:
                print(f"Error processing {filename}: {e}")

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Generate captions for images using the RAM and DAPE models.")
    parser.add_argument("--input_dir", type=str, default="data/val",
                        help="Path to the directory containing input images.")
    parser.add_argument("--output_file", type=str, default="data/val_captions.txt",
                        help="Path to the file where captions will be saved.")
    parser.add_argument("--ram_model", type=str,
                        default="/home/erbachj/scratch2/projects/var_post_samp/thirdparty/SeeSR/preset/model/ram_swin_large_14m.pth",
                        help="Path to the pretrained RAM model.")
    parser.add_argument("--dape_model", type=str,
                        default="/home/erbachj/scratch2/projects/var_post_samp/thirdparty/SeeSR/preset/model/DAPE.pth",
                        help="Path to the pretrained DAPE model.")
    args = parser.parse_args()

    # Load the RAM model once, then caption every image in the directory
    tag_model = load_ram_model(args.ram_model, args.dape_model)
    process_images_in_directory(args.input_dir, args.output_file, tag_model)
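
# Example invocation (the script name and relative model paths below are
# illustrative; adjust them to your checkout):
#   python generate_captions.py \
#       --input_dir data/val \
#       --output_file data/val_captions.txt \
#       --ram_model thirdparty/SeeSR/preset/model/ram_swin_large_14m.pth \
#       --dape_model thirdparty/SeeSR/preset/model/DAPE.pth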