Triad: Dense Cross-Modal Feature Learning


I built Triad to explore dense feature correspondences between video, audio and text modalities - focusing on learning fine-grained, localized relationships rather than just global alignment. The goal was to create a model that could ground features between specific image regions, audio segments, and text spans simultaneously.

This is a very early research checkpoint for dense multi-modal learning, with lots of room for improvement and experimentation. The current model was trained on a subset of AudioSet (~400k videos, ~20% of the entire dataset) and CC3M (~2M image-text pairs) for just one epoch, so while it shows promising behavior, it's definitely not state-of-the-art yet.

TL;DR: The model encodes semantic concepts in dense features (image patches, audio segments, text spans) rather than in a single global embedding. For example, the embedding of an image patch that contains a cat has high cosine similarity with the embedding of the word "cat" and with that of an audio segment of a cat meowing.
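As a rough illustration of what "high cosine similarity between dense features" means, the sketch below compares every text-token vector with every patch vector. It uses random stand-in tensors with the shapes listed in the Output Key Reference section. `dense_cosine_similarity` is a hypothetical helper, not part of Triad, and whether the model's own similarity matrices are computed exactly this way is an assumption.

```python
import torch
import torch.nn.functional as F


def dense_cosine_similarity(query_feats: torch.Tensor, patch_feats: torch.Tensor) -> torch.Tensor:
    """Cosine similarity between every query vector and every patch vector.

    query_feats: (B, N_query, D), e.g. text tokens or audio segments
    patch_feats: (B, N_patches, D), e.g. image patches
    returns:     (B, N_query, N_patches)
    Hypothetical helper: Triad's internal similarity computation may differ.
    """
    # L2-normalize along the feature dimension so dot products become cosines
    q = F.normalize(query_feats, dim=-1)
    p = F.normalize(patch_feats, dim=-1)
    # Batched dot products between all query/patch pairs
    return torch.einsum("bqd,bpd->bqp", q, p)


# Random stand-ins with the documented Triad shapes (illustration only)
text_feats = torch.randn(1, 5, 512)      # e.g. 5 tokens of "a man riding a bicycle"
visual_feats = torch.randn(1, 256, 512)  # 256 image patches

sim = dense_cosine_similarity(text_feats, visual_feats)
print(sim.shape)  # torch.Size([1, 5, 256])

# Index of the patch that best matches token 2 (hypothetically "riding")
best_patch = sim[0, 2].argmax().item()
```

With real model outputs, a well-grounded word should have its highest scores on the patches that actually depict it.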

What Makes This Interesting?

Unlike models that learn global alignment between modalities (think CLIP, ImageBind), Triad learns to map specific parts of each modality to each other. This means it can:

  • Locate which parts of an image correspond to particular words or sounds
  • Ground audio segments to relevant visual regions
  • Connect text descriptions to precise areas in images
  • (Potentially) Learn transitive audio-text relationships through the shared visual space
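Locating a word in an image amounts to reading one row of the image-text similarity matrix as a spatial map. Below is a minimal sketch, assuming (this is not stated in the model card) that the 256 patches form a 16x16 grid in row-major order over the 224x224 input. `patch_heatmap` is a hypothetical helper, and the similarity values are random stand-ins for `output['vis_text_sim_matrix']`.

```python
import torch
import torch.nn.functional as F


def patch_heatmap(sim_row: torch.Tensor, grid_size: int = 16, image_size: int = 224) -> torch.Tensor:
    """Turn one row of a (query x patch) similarity matrix into an image-sized heatmap.

    sim_row: (N_patches,) similarities for a single word or audio segment.
    ASSUMPTION: patches form a grid_size x grid_size grid in row-major order
    covering the resized image_size x image_size input.
    """
    grid = sim_row.reshape(1, 1, grid_size, grid_size)  # (N, C, H, W) layout for interpolate
    heat = F.interpolate(grid, size=(image_size, image_size), mode="bilinear", align_corners=False)
    heat = heat[0, 0]
    # Min-max normalize to [0, 1] so the map can be overlaid on the image
    return (heat - heat.min()) / (heat.max() - heat.min() + 1e-8)


# Random stand-in for output["vis_text_sim_matrix"], shape (B, N_tokens, 256)
vis_text_sim = torch.randn(1, 5, 256)
heatmap = patch_heatmap(vis_text_sim[0, 1])  # token 1, hypothetically "man"
print(heatmap.shape)  # torch.Size([224, 224])
```

Bilinear upsampling is just one choice; nearest-neighbor interpolation keeps the patch boundaries visible instead of smoothing them.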

What's Next?

I've got lots of ideas for making this better: longer training, experimenting with the architecture, investigating some interesting behaviors I've noticed, and solving the major open problem of handling text and audio features that have no counterpart in the visual features.

I'm actively looking to push this research further and super interested in tackling more multimodal learning problems. Feel free to reach out if you're working in this space!

Inference

Triad Model

The model can process image, audio, and text inputs - either individually or together.

Installation & Loading

from safetensors.torch import load_file
from huggingface_hub import hf_hub_download
import torch
import json
import sys
from pathlib import Path

def load_model(path="SajayR/Triad", device="cpu"):
    # Download the weights, config, and model definition from the Hub (cached after the first call)
    model_path = hf_hub_download(repo_id=path, filename="model.safetensors")
    model_config = hf_hub_download(repo_id=path, filename="config.json")
    model_arch = hf_hub_download(repo_id=path, filename="hf_model.py")

    # Make the downloaded hf_model.py importable
    sys.path.append(str(Path(model_arch).parent))
    from hf_model import Triad

    # Build the model from its config, then load the pretrained weights
    with open(model_config) as f:
        model = Triad(**json.load(f))
    model.load_state_dict(load_file(model_path))

    # Move to the target device and switch to evaluation mode for inference
    return model.to(device).eval()

# Initialize model
model = load_model()  # Use load_model(device="cuda") for GPU

Single Modality Examples

Image Input

You can provide images as file paths or tensors:

# From file path
output = model(image="path/to/image.jpg")
output['visual_feats'].shape  # torch.Size([1, 256, 512])

# From tensor (already pre-processed)
from torchvision import transforms
from PIL import Image

# Load and preprocess image
image = Image.open("path/to/image.jpg").convert('RGB')
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
image_tensor = transform(image)  # Shape: [3, 224, 224]

# Pass to model
output = model(image=image_tensor)
output['visual_feats'].shape  # torch.Size([1, 256, 512])
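If you want to overlay patch-level results on the preprocessed image, you need to undo the normalization first. A small sketch that inverts the `transforms.Normalize` step above using the same mean and std; `denormalize` is a hypothetical helper, not part of Triad.

```python
import torch

# Same ImageNet mean/std as the transforms.Normalize call above, shaped for (3, H, W) broadcasting
MEAN = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
STD = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)


def denormalize(image_tensor: torch.Tensor) -> torch.Tensor:
    """Undo transforms.Normalize for a (3, H, W) tensor and clamp to [0, 1] for display."""
    return (image_tensor * STD + MEAN).clamp(0.0, 1.0)


# Round trip on a random stand-in image with values in [0, 1]
original = torch.rand(3, 224, 224)
normalized = (original - MEAN) / STD
restored = denormalize(normalized)
```

For display, `restored.permute(1, 2, 0).numpy()` gives an (H, W, 3) array that plotting libraries such as matplotlib accept.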

Audio Input

# Audio only - returns audio features (B, N_segments, D) 
# The model is currently trained on 1-second audio clips; longer sequences may perform worse
audio = torch.randn(1, 16331)  # Placeholder for a raw audio waveform of shape (B, num_samples)
output = model(audio=audio)
output['audio_feats'].shape  # torch.Size([1, 50, 512])
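Since the model is trained on roughly 1-second clips, one way to handle a longer recording is to cut it into windows of the same length as the example input (16,331 samples) and encode each window separately. The sketch below zero-pads the last window. `split_into_windows` is a hypothetical helper, and the model card does not state the expected sample rate, so check `hf_model.py` and resample your audio accordingly before using it.

```python
import math

import torch
import torch.nn.functional as F


def split_into_windows(waveform: torch.Tensor, window: int = 16331) -> torch.Tensor:
    """Split a mono waveform of shape (num_samples,) into (num_windows, window).

    The final window is zero-padded on the right if the length is not a multiple of `window`.
    Hypothetical helper; 16331 matches the example input length above.
    """
    num_windows = max(1, math.ceil(waveform.shape[-1] / window))
    pad = num_windows * window - waveform.shape[-1]
    padded = F.pad(waveform, (0, pad))  # zero-pad only the end of the last dimension
    return padded.reshape(num_windows, window)


long_audio = torch.randn(40000)  # random stand-in for a longer recording
windows = split_into_windows(long_audio)
print(windows.shape)  # torch.Size([3, 16331])

# Assuming the audio path accepts a batch like the (1, 16331) example,
# each row could then be encoded with model(audio=windows)
```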

Text Input

# Text only - returns text features (B, N_tokens, D)
text_list = ["a man riding a bicycle"]
output = model(text_list=text_list)
output['text_feats'].shape  # torch.Size([1, 5, 512])

Batch Processing

The model now supports batch processing for image inputs:

Batch of Image Paths

# Process a batch of image paths
image_paths = ["path/to/image1.jpg", "path/to/image2.jpg", "path/to/image3.jpg"]
output = model(image=image_paths)
output['visual_feats'].shape  # torch.Size([3, 256, 512])

Batch of Image Tensors

# Process a batch of image tensors
import torch
from torchvision import transforms
from PIL import Image

# Create a transform
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Load and preprocess images
images = []
for path in ["image1.jpg", "image2.jpg", "image3.jpg"]:
    img = Image.open(path).convert('RGB')
    images.append(transform(img))

# Stack into a batch
batch = torch.stack(images)  # Shape: [3, 3, 224, 224]

# Process the batch
output = model(image=batch)
output['visual_feats'].shape  # torch.Size([3, 256, 512])
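For large image collections, one option is to feed paths to the model in fixed-size chunks and concatenate the results, which keeps peak memory bounded. This is a sketch only: `encode_images_in_chunks` is a hypothetical helper, it assumes `model(image=list_of_paths)` behaves as in the batch example above, and the default chunk size of 32 is an arbitrary choice.

```python
import torch


def encode_images_in_chunks(model, image_paths, chunk_size=32):
    """Encode a non-empty list of image paths in fixed-size chunks.

    Returns a CPU tensor of shape (len(image_paths), N_patches, D).
    Hypothetical helper built on the documented model(image=[...]) call.
    """
    feats = []
    with torch.inference_mode():  # skip autograd bookkeeping during inference
        for start in range(0, len(image_paths), chunk_size):
            chunk = image_paths[start:start + chunk_size]
            feats.append(model(image=chunk)["visual_feats"].cpu())
    return torch.cat(feats, dim=0)
```

Usage would look like `all_feats = encode_images_in_chunks(model, image_paths, chunk_size=16)`. Moving each chunk to the CPU frees GPU memory between chunks at the cost of a device transfer.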

Multi-Modal Examples

Image and Audio Together

# Process image and audio together
output = model(
    audio=audio,
    image="path/to/image.jpg"
)

print(output.keys())  # dict_keys(['visual_feats', 'audio_feats', 'vis_audio_sim_matrix'])

# Output shapes:
# - audio_feats: [1, 50, 512]      # (batch, audio_segments, features)
# - visual_feats: [1, 256, 512]    # (batch, image_patches, features) 
# - vis_audio_sim_matrix: [1, 50, 256]  # (batch, audio_segments, image_patches)

The similarity matrix scores the correspondence between each audio segment and each image patch: entry [b, i, j] relates audio segment i to image patch j for batch item b.
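To read the audio-visual similarity matrix, you can take the best-matching patch for each audio segment, or aggregate over time to score each patch. A sketch with a random stand-in matrix; `best_patch_per_segment` is a hypothetical helper, and the 16x16 row-major patch layout is an assumption not confirmed by the model card.

```python
import torch


def best_patch_per_segment(sim_matrix: torch.Tensor, grid_size: int = 16):
    """For each audio segment, find the most similar image patch.

    sim_matrix: (B, N_segments, N_patches), e.g. output["vis_audio_sim_matrix"]
    Returns (rows, cols), each of shape (B, N_segments).
    ASSUMPTION: patches are laid out row-major on a grid_size x grid_size grid.
    """
    best = sim_matrix.argmax(dim=-1)  # (B, N_segments) patch indices
    rows = torch.div(best, grid_size, rounding_mode="floor")
    cols = best % grid_size
    return rows, cols


# Random stand-in for output["vis_audio_sim_matrix"], shape (1, 50, 256)
vis_audio_sim = torch.randn(1, 50, 256)
rows, cols = best_patch_per_segment(vis_audio_sim)

# Aggregate over time: the strongest response of each patch to any audio segment
patch_scores = vis_audio_sim.max(dim=1).values  # (1, 256)
```

Taking the max over segments highlights patches tied to a short sound event; averaging instead favors patches that match the clip throughout.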

Output Key Reference

Depending on which modalities you provide, the model returns different outputs:

  • visual_feats: (B, 256, 512), returned when you pass an image
  • audio_feats: (B, 50, 512), returned when you pass audio
  • text_feats: (B, N_tokens, 512), returned when you pass text
  • vis_text_sim_matrix: (B, N_tokens, 256), returned when you pass both image and text
  • vis_audio_sim_matrix: (B, 50, 256), returned when you pass both image and audio
  • text_audio_sim_matrix: (B, N_tokens, 50), returned when you pass both text and audio

Where:

  • B = batch size
  • 256 = number of image patches
  • 50 = number of audio segments
  • N_tokens = variable length of text tokens
  • 512 = embedding dimension
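The rules in the Output Key Reference can be written down as a small helper, which is handy as a sanity check on model outputs. `expected_output_keys` is hypothetical, and the behavior when all three modalities are passed at once (all six keys) is an assumption extrapolated from the pairwise rules.

```python
def expected_output_keys(image: bool = False, audio: bool = False, text: bool = False) -> set:
    """Keys the output dict should contain for a given combination of inputs.

    Encodes the Output Key Reference list; the three-modality case is an assumption.
    """
    keys = set()
    if image:
        keys.add("visual_feats")
    if audio:
        keys.add("audio_feats")
    if text:
        keys.add("text_feats")
    # Pairwise similarity matrices appear only when both modalities are present
    if image and text:
        keys.add("vis_text_sim_matrix")
    if image and audio:
        keys.add("vis_audio_sim_matrix")
    if text and audio:
        keys.add("text_audio_sim_matrix")
    return keys
```

For the image and audio example earlier, `set(output.keys()) == expected_output_keys(image=True, audio=True)` should hold.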

Model size: 248M params (Safetensors, F32)