import gradio as gr
import random
import json
import re
import os
import numpy as np
from sklearn.feature_extraction.text import TfidfVectorizer
from concurrent.futures import ThreadPoolExecutor
import threading
import nltk
from nltk.corpus import wordnet
from nltk.stem import WordNetLemmatizer

# Point NLTK to a local data directory so the Space does not need to
# download corpora at runtime
current_dir = os.getcwd()
print(f"Current directory: {current_dir}")

nltk_data_path = os.path.join(current_dir, "nltk_data")
print(f"Setting NLTK data path to: {nltk_data_path}")

# Insert at position 0 so the local directory is searched first
nltk.data.path.insert(0, nltk_data_path)
print(f"NLTK will search in: {nltk.data.path}")

# Try to load the tagger from the local directory
try:
    # Try to directly load the tagger model
    from nltk.tag.perceptron import PerceptronTagger
    tagger = PerceptronTagger()
    print("Successfully loaded PerceptronTagger")
except Exception as e:
    print(f"Error loading tagger: {e}")
    # nltk.download('averaged_perceptron_tagger_eng')

# Header shown at the top of the demo page
WEBSITE = """<div class="embed_hidden" style="text-align: center;">
    <h1>SINC: Spatial Composition of 3D Human Motions for Simultaneous Action Generation</h1>
    <h2 style="margin: 1em 0; font-size: 2em;">
        <span style="font-weight: normal; font-style: italic;">ICCV 2023</span>
    </h2>
    <h3>
        <a href="https://atnikos.github.io/" target="_blank" rel="noopener noreferrer">Nikos Athanasiou</a><sup>*1</sup>,
        <a href="https://mathis.petrovich.fr/" target="_blank" rel="noopener noreferrer">Mathis Petrovich</a><sup>*1,2</sup>,
        <br>
        <a href="https://ps.is.mpg.de/person/black" target="_blank" rel="noopener noreferrer">Michael J. Black</a><sup>1</sup>,
        <a href="https://gulvarol.github.io/" target="_blank" rel="noopener noreferrer">Gül Varol</a><sup>2</sup>
    </h3>
    <h3>
        <sup>1</sup>MPI for Intelligent Systems, Tübingen, Germany<br>
        <sup>2</sup>LIGM, École des Ponts, Univ Gustave Eiffel, CNRS, France
    </h3>
</div>
<div style="display:flex; gap: 0.3rem; justify-content: center; align-items: center;" align="center">
    <a href='https://arxiv.org/abs/2304.10417'><img src='https://img.shields.io/badge/Arxiv-2304.10417-A42C25?style=flat&logo=arXiv&logoColor=A42C25'></a>
    <a href='https://sinc.is.tue.mpg.de'><img src='https://img.shields.io/badge/Project-Page-%23df5b46?style=flat&logo=Google%20chrome&logoColor=%23df5b46'></a>
</div>
<h2 align="center">
    Download the
    <a href="https://drive.google.com/drive/folders/1ks9wvNN_arrgBcd0GxN5nRLf5ASPkUgc?usp=sharing" target="_blank" rel="noopener noreferrer">SINC synthetic data</a>
    if you want to train your models with spatial composition from AMASS.
    <br>
    The data explored in this demo were created with the
    <a href='https://github.com/atnikos/sinc/blob/main/create_synthetic_babel.py' target="_blank" rel="noopener noreferrer">code in our repo that composes motions from AMASS</a>.<sup>**</sup>
</h2>
"""

# Action examples
ACTION_EXAMPLES = [
    "walk forward on balance beam", "walk counterclockwise", "sit on chair",
    "kick a ball", "jump up", "hold on to rail with right hand",
    "pick up an object", "wave with the right hand", "throw a ball", "bow"
]
ACTION_EXAMPLES_SIMULTANEOUS = [
    "walk forward on balance beam while holding rail with right hand",
    "walk counterclockwise while waving with left hand",
    "sit on chair and wave with left hand",
    "pick up an object while bowing",
    "walk forward on balance beam while waving left hand"
]

# Global caches for expensive operations
SIMILARITY_CACHE = {}
SEARCH_RESULTS_CACHE = {}
GPT_SIMILARITY_CACHE = {}
GPT_SEARCH_RESULTS_CACHE = {}
SYNONYM_CACHE = {}
MAX_WORKERS = 4  # For ThreadPoolExecutor

# Cache for TF-IDF
TFIDF_VECTORIZER = None
TFIDF_MATRIX = None
MOTION_TEXTS = []
MOTION_KEYS = []
GPT_TEXTS = []
GPT_KEYS = []

# Initialize lemmatizer
lemmatizer = WordNetLemmatizer()

# Movement action word mappings - manually defined synonyms for common motion words
ACTION_SYNONYMS = {
    'walk': ['move', 'stroll', 'pace', 'stride', 'wander', 'stalk', 'amble', 'saunter', 'tread', 'step'],
    'run': ['sprint', 'jog', 'dash', 'race', 'bolt', 'scamper', 'rush', 'hurry'],
    'jump': ['leap', 'hop', 'spring', 'bounce', 'vault', 'bound', 'skip'],
    'turn': ['rotate', 'spin', 'twist', 'revolve', 'pivot', 'swivel', 'whirl'],
    'wave': ['signal', 'gesture', 'flap', 'flutter', 'waggle', 'shake', 'brandish'],
    'sit': ['perch', 'recline', 'rest', 'squat'],
    'stand': ['rise', 'upright', 'erect', 'vertical'],
    'throw': ['toss', 'hurl', 'fling', 'chuck', 'lob', 'pitch', 'cast'],
    'grab': ['grasp', 'clutch', 'seize', 'grip', 'hold', 'take', 'catch'],
    'pick': ['lift', 'raise', 'hoist', 'elevate'],
    'kick': ['boot', 'punt', 'strike'],
    'bow': ['bend', 'stoop', 'incline', 'nod'],
    'dance': ['twirl', 'sway', 'shimmy', 'boogie', 'groove', 'swing'],
    'balance': ['steady', 'stabilize', 'poise', 'equilibrium'],
    'forward': ['ahead', 'onward', 'frontward', 'forth'],
    'backward': ['back', 'rearward', 'reverse', 'retreat'],
    'clockwise': ['right', 'rightward', 'rightways'],
    'counterclockwise': ['left', 'leftward', 'leftways', 'anticlockwise'],
    'hold': ['grip', 'grasp', 'clutch', 'clasp', 'clench', 'possess']
}

# Build reverse mapping for faster lookups
REVERSE_SYNONYMS = {}
for word, synonyms in ACTION_SYNONYMS.items():
    REVERSE_SYNONYMS[word] = word  # A word is its own synonym
    for synonym in synonyms:
        REVERSE_SYNONYMS[synonym] = word

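# Illustration of the reverse mapping (values taken from ACTION_SYNONYMS above):
# every synonym points back to its canonical action word, so a lookup like
#   REVERSE_SYNONYMS['stroll']  -> 'walk'
#   REVERSE_SYNONYMS['sprint']  -> 'run'
#   REVERSE_SYNONYMS['walk']    -> 'walk'   (canonical words map to themselves)
# lets get_synonyms() normalize a query term in O(1) before consulting WordNet.
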
def get_wordnet_pos(word):
    """Map POS tag to the first character used by the WordNet lemmatizer,
    with a fallback for errors"""
    try:
        tag = nltk.tag.pos_tag([word])[0][1][0].upper()
        tag_dict = {"J": wordnet.ADJ,
                    "N": wordnet.NOUN,
                    "V": wordnet.VERB,
                    "R": wordnet.ADV}
        return tag_dict.get(tag, wordnet.NOUN)
    except Exception as e:
        print(f"POS tagging error for word '{word}': {e}")
        # Default to NOUN as fallback
        return wordnet.NOUN

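# Sketch of the mapping in practice (hedged: the exact tags depend on the
# tagger model loaded above):
#   get_wordnet_pos('walking') -> wordnet.VERB  (a 'V...' tag)
#   get_wordnet_pos('chair')   -> wordnet.NOUN  (an 'N...' tag)
# Any unexpected tag, or a tagger failure, falls back to wordnet.NOUN.
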
def get_synonyms(word):
    """Get all synonyms for a word using WordNet and our custom action mappings"""
    if word in SYNONYM_CACHE:
        return SYNONYM_CACHE[word]
    synonyms = set()
    # Add the word itself
    synonyms.add(word)
    # Check our custom action mappings first (faster and more domain-specific)
    if word in REVERSE_SYNONYMS:
        canonical_word = REVERSE_SYNONYMS[word]
        synonyms.add(canonical_word)
        synonyms.update(ACTION_SYNONYMS.get(canonical_word, []))
    # Then check WordNet (more general but can be noisy)
    try:
        word_lemma = lemmatizer.lemmatize(word, get_wordnet_pos(word))
        for syn in wordnet.synsets(word_lemma):
            for lemma in syn.lemmas():
                synonyms.add(lemma.name().lower().replace('_', ' '))
    except Exception as e:
        print(f"Error getting WordNet synonyms for '{word}': {e}")
    SYNONYM_CACHE[word] = synonyms
    return synonyms

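# Sketch of what the lookup returns (the exact set varies with the local
# WordNet data): the word itself, its canonical action word and the manual
# synonyms, plus WordNet lemmas. For example, since 'jog' appears under 'run'
# in ACTION_SYNONYMS:
#   get_synonyms('jog')  includes  {'jog', 'run', 'sprint', 'dash', ...}
# Results are memoized in SYNONYM_CACHE, so repeated calls are dictionary hits.
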
def expand_query_with_synonyms(query):
    """Expand a query with synonyms for each term"""
    try:
        words = nltk.word_tokenize(query.lower())
    except Exception as e:
        print(f"Tokenization error: {e}")
        # Fall back to a simple split if tokenization fails
        words = query.lower().split()
    expanded_terms = []
    for word in words:
        if len(word) > 2:  # Only expand words longer than 2 chars to avoid stop words
            synonyms = get_synonyms(word)
            expanded_terms.extend(synonyms)
        else:
            expanded_terms.append(word)
    # Join back into a space-separated string
    return ' '.join(expanded_terms)

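# Example expansion (illustrative; the exact terms depend on WordNet):
#   expand_query_with_synonyms('walk forward')
#   -> 'walk move stroll pace ... forward ahead onward frontward forth ...'
# The expanded string is only used as TF-IDF input, so word order does not
# matter for the bag-of-words representation.
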
def create_example_buttons(textbox, loftexts):
    """Creates clickable buttons for example actions"""
    return gr.Examples(
        examples=loftexts,
        inputs=textbox,
        label="Example Actions"
    )

# Load motion data
def load_json_dict(file_path):
    with open(file_path, "r") as f:
        return json.load(f)

# Load data at startup
print("Loading motion data...")
motion_dict = load_json_dict("for_website_v4.json")
motion_dict = {
    key: value for key, value in motion_dict.items()
    if "guide forward walk" not in value['source_annot'].lower()
    and "guide forward walk" not in value['target_annot'].lower()
}

print("Loading GPT labels...") | |
GPT_LABELS_LIST = load_json_dict('gpt3-labels-list.json') | |
GPT_LABELS_LIST = {k: v[2] for k, v in GPT_LABELS_LIST.items()} | |
# TF-IDF based similarity implementation with synonym expansion
def initialize_tfidf():
    """Initialize TF-IDF vectorizer and precompute matrices"""
    global TFIDF_VECTORIZER, TFIDF_MATRIX, MOTION_TEXTS, MOTION_KEYS
    print("Initializing TF-IDF vectorizer...")
    # Extract text descriptions from the motion dictionary for TF-IDF
    MOTION_TEXTS = []
    MOTION_KEYS = []
    for key, motion in motion_dict.items():
        # Combine source and target annotations
        text = f"{motion['source_annot']} {motion['target_annot']}".lower()
        MOTION_TEXTS.append(text)
        MOTION_KEYS.append(key)
    # Initialize the TF-IDF vectorizer
    TFIDF_VECTORIZER = TfidfVectorizer(
        lowercase=True,
        stop_words='english',
        ngram_range=(1, 2),  # Include bigrams for better matching
        max_features=20000,  # Increased to accommodate synonym expansions
        min_df=1             # Lower threshold to catch less frequent terms
    )
    # Fit and transform to get TF-IDF vectors
    TFIDF_MATRIX = TFIDF_VECTORIZER.fit_transform(MOTION_TEXTS)
    print(f"TF-IDF matrix created with shape {TFIDF_MATRIX.shape}")
    # Also create the GPT labels matrix
    initialize_gpt_tfidf()

def initialize_gpt_tfidf():
    """Initialize TF-IDF inputs for GPT labels"""
    global GPT_TEXTS, GPT_KEYS
    print("Initializing TF-IDF for GPT labels...")
    GPT_TEXTS = []
    GPT_KEYS = []
    for key, text in GPT_LABELS_LIST.items():
        GPT_TEXTS.append(text.lower())
        GPT_KEYS.append(key)

def compute_tfidf_similarity(query, top_k=10):
    """Compute similarity using TF-IDF vectors with synonym expansion"""
    global TFIDF_VECTORIZER, TFIDF_MATRIX, MOTION_TEXTS, MOTION_KEYS
    # Original query for the cache key
    original_query = query.lower().strip()
    # Check cache first
    cache_key = f"tfidf_{original_query}_{top_k}"
    if cache_key in SIMILARITY_CACHE:
        return SIMILARITY_CACHE[cache_key]
    try:
        # Expand query with synonyms
        expanded_query = expand_query_with_synonyms(original_query)
        # Transform query to TF-IDF space
        query_vector = TFIDF_VECTORIZER.transform([expanded_query])
        # TfidfVectorizer L2-normalizes rows by default, so this sparse dot
        # product equals the cosine similarity between the query and every
        # motion text
        similarities = (query_vector @ TFIDF_MATRIX.T).toarray().flatten()
        # Get indices of the top_k highest similarity scores
        top_indices = np.argsort(similarities)[-top_k:][::-1]
        # Get the corresponding entries and scores
        top_entries = [motion_dict[MOTION_KEYS[idx]] for idx in top_indices]
        top_scores = [similarities[idx] for idx in top_indices]
        result = (top_entries, top_scores)
    except Exception as e:
        print(f"Error in TF-IDF similarity computation: {e}")
        # Fall back to random motions if TF-IDF fails
        result = (get_random_motions(top_k), ['NA'] * top_k)
    SIMILARITY_CACHE[cache_key] = result
    return result

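# Usage sketch (assumes initialize_tfidf() has run, as it does at startup):
#   entries, scores = compute_tfidf_similarity("walk while waving", top_k=5)
#   entries[0]['source_annot'], scores[0]  # best match and its cosine score
# Scores lie in [0, 1]; a repeated query is served from SIMILARITY_CACHE.
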
def compute_gpt_tfidf_similarity(query):
    """Compute similarity between the query and GPT labels using TF-IDF with synonym expansion"""
    global TFIDF_VECTORIZER, GPT_TEXTS, GPT_KEYS
    # Original query for the cache key
    original_query = query.lower().strip()
    # Check cache first
    cache_key = f"gpt_tfidf_{original_query}"
    if cache_key in GPT_SIMILARITY_CACHE:
        return GPT_SIMILARITY_CACHE[cache_key]
    try:
        # Expand query with synonyms
        expanded_query = expand_query_with_synonyms(original_query)
        # Transform the query and all GPT texts to TF-IDF space
        query_vector = TFIDF_VECTORIZER.transform([expanded_query])
        gpt_vectors = TFIDF_VECTORIZER.transform(GPT_TEXTS)
        # Compute cosine similarity between the query and all GPT texts
        similarities = (query_vector @ gpt_vectors.T).toarray().flatten()
        # Get the index of the highest similarity score
        best_idx = np.argmax(similarities)
        best_key = GPT_KEYS[best_idx]
        best_text = GPT_LABELS_LIST[best_key]
        best_sim = similarities[best_idx]
        result = (best_key, best_text, best_sim)
    except Exception as e:
        print(f"Error in GPT TF-IDF similarity computation: {e}")
        # Fall back to the first GPT label if the computation fails
        if GPT_KEYS:
            result = (GPT_KEYS[0], GPT_LABELS_LIST[GPT_KEYS[0]], 0.5)
        else:
            result = (None, None, 0)
    GPT_SIMILARITY_CACHE[cache_key] = result
    return result

# Precompiled regex pattern for word tokenization
WORD_PATTERN = re.compile(r'\b\w+\b')

# Cache the word sets to avoid repeated tokenization
SOURCE_WORDS_CACHE = {}
TARGET_WORDS_CACHE = {}

def get_words(text):
    """Tokenize text and cache the results"""
    if text in SOURCE_WORDS_CACHE:
        return SOURCE_WORDS_CACHE[text]
    words = set(WORD_PATTERN.findall(text.lower()))
    SOURCE_WORDS_CACHE[text] = words
    return words

def exact_string_search(action1, action2):
    """Search for exact substring matches first"""
    exact_results = []
    action1_lower = action1.lower().strip()
    action2_lower = action2.lower().strip()
    for k, v in motion_dict.items():
        source_lower = v["source_annot"].lower()
        target_lower = v["target_annot"].lower()
        # Check for exact matches in either annotation
        cond1 = action1_lower in source_lower or action1_lower in target_lower
        cond2 = action2_lower in source_lower or action2_lower in target_lower
        if cond1 and cond2:
            exact_results.append(v)
    return exact_results

def search_motions_two_actions(action1, action2):
    """Enhanced word-level search with synonym expansion"""
    # Create a cache key for this query
    cache_key = f"{action1.lower().strip()}_{action2.lower().strip()}"
    # Check if we already have results for this query
    if cache_key in SEARCH_RESULTS_CACHE:
        return SEARCH_RESULTS_CACHE[cache_key]
    try:
        # Convert actions into sets of words
        action1_words = set(action1.lower().strip().split())
        action2_words = set(action2.lower().strip().split())
        # Expand with synonyms. Note: the expanded sets are not read below;
        # building them warms SYNONYM_CACHE before the per-motion loop
        expanded_action1_words = set()
        for word in action1_words:
            if len(word) > 2:  # Only consider words longer than 2 chars
                expanded_action1_words.update(get_synonyms(word))
            else:
                expanded_action1_words.add(word)
        expanded_action2_words = set()
        for word in action2_words:
            if len(word) > 2:
                expanded_action2_words.update(get_synonyms(word))
            else:
                expanded_action2_words.add(word)
        results = []
        for k, v in motion_dict.items():
            # Get or compute tokenized words from the caches
            if v["source_annot"] not in SOURCE_WORDS_CACHE:
                SOURCE_WORDS_CACHE[v["source_annot"]] = set(WORD_PATTERN.findall(v["source_annot"].lower()))
            if v["target_annot"] not in TARGET_WORDS_CACHE:
                TARGET_WORDS_CACHE[v["target_annot"]] = set(WORD_PATTERN.findall(v["target_annot"].lower()))
            source_words = SOURCE_WORDS_CACHE[v["source_annot"]]
            target_words = TARGET_WORDS_CACHE[v["target_annot"]]
            # For each word in action1, check whether the word or any synonym matches
            if action1_words:
                matches = 0
                for word in action1_words:
                    word_matches = False
                    if len(word) <= 2:  # For short words, just check the exact match
                        if word in source_words or word in target_words:
                            word_matches = True
                    else:  # For longer words, check all synonyms
                        for syn in get_synonyms(word):
                            if syn in source_words or syn in target_words:
                                word_matches = True
                                break
                    if word_matches:
                        matches += 1
                # Consider a match if at least 70% of words (or their synonyms) are found
                cond1 = (matches / len(action1_words)) >= 0.7
            else:
                cond1 = True
            # Same check for action2
            if action2_words:
                matches = 0
                for word in action2_words:
                    word_matches = False
                    if len(word) <= 2:
                        if word in source_words or word in target_words:
                            word_matches = True
                    else:
                        for syn in get_synonyms(word):
                            if syn in source_words or syn in target_words:
                                word_matches = True
                                break
                    if word_matches:
                        matches += 1
                cond2 = (matches / len(action2_words)) >= 0.7
            else:
                cond2 = True
            if cond1 and cond2:
                results.append(v)
    except Exception as e:
        print(f"Error in substring search: {e}")
        results = []
    # Cache the results
    SEARCH_RESULTS_CACHE[cache_key] = results
    return results

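# Example of the 70% rule: for action1 = "walk forward on balance beam"
# (5 words), at least 4 words (matches / 5 >= 0.7), or their synonyms, must
# appear in a motion's source or target annotation for cond1 to hold. For a
# single-word query like "walk", that one word or a synonym such as "stroll"
# must appear.
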
def search_motions_semantic(action1, action2, top_k=10):
    """Semantic search using TF-IDF similarity with synonym expansion"""
    query_text = (action1.strip() + " " + action2.strip()).strip().lower()
    if not query_text:
        return [], []
    # Check cache first
    cache_key = f"{query_text}_{top_k}"
    if cache_key in SEARCH_RESULTS_CACHE:
        return SEARCH_RESULTS_CACHE[cache_key]
    # Use TF-IDF similarity and store the result so the cache check above can hit
    result = compute_tfidf_similarity(query_text, top_k)
    SEARCH_RESULTS_CACHE[cache_key] = result
    return result

def get_random_motions(n_motions):
    all_vals = list(motion_dict.values())
    return random.sample(all_vals, min(n_motions, len(all_vals)))

def search_gpt_semantic(action, top_k=1):
    """Search GPT labels using TF-IDF similarity with synonym expansion"""
    query_text = action.strip().lower()
    if not query_text:
        return None, None, None
    # Check cache first
    if query_text in GPT_SEARCH_RESULTS_CACHE:
        return GPT_SEARCH_RESULTS_CACHE[query_text]
    # Use TF-IDF similarity for GPT labels
    result = compute_gpt_tfidf_similarity(query_text)
    GPT_SEARCH_RESULTS_CACHE[query_text] = result
    return result

def search_motions_combined(action1, action2, n_motions):
    """Combined search that prioritizes exact matches over word-level and semantic matches"""
    # Create a cache key for this query
    cache_key = f"{action1.lower().strip()}_{action2.lower().strip()}_{n_motions}"
    # Check if we already have results for this query
    if cache_key in SEARCH_RESULTS_CACHE:
        return SEARCH_RESULTS_CACHE[cache_key]
    # 1. First try exact string matches
    exact_results = exact_string_search(action1, action2)
    if len(exact_results) >= n_motions:
        # If we have enough exact matches, return them
        result = (random.sample(exact_results, n_motions), ['EXACT'] * n_motions)
        SEARCH_RESULTS_CACHE[cache_key] = result
        return result
    # 2. If not enough exact matches, try the enhanced word-level search with synonyms
    string_results = search_motions_two_actions(action1, action2)
    # Filter out any results that are already in exact_results
    string_results = [r for r in string_results if r not in exact_results]
    # Start from the exact matches
    combined_results = list(exact_results)
    combined_scores = ['EXACT'] * len(exact_results)
    if len(combined_results) + len(string_results) >= n_motions:
        # If we have enough combined results, use them
        needed = n_motions - len(combined_results)
        if needed > 0:
            combined_results.extend(random.sample(string_results, needed))
            combined_scores.extend(['SUBSTR'] * needed)
        result = (combined_results[:n_motions], combined_scores[:n_motions])
    else:
        # 3. If still not enough, add all substring matches and then use semantic search
        combined_results.extend(string_results)
        combined_scores.extend(['SUBSTR'] * len(string_results))
        # Use semantic search for the remaining needed motions
        needed = n_motions - len(combined_results)
        if needed > 0:
            sem_list, sem_score_list = search_motions_semantic(action1, action2, top_k=2 * needed)
            # Filter out duplicates
            used_combo = {m["motion_combo"] for m in combined_results}
            for item, score in zip(sem_list, sem_score_list):
                if item["motion_combo"] not in used_combo:
                    combined_results.append(item)
                    combined_scores.append(score)
                    used_combo.add(item["motion_combo"])
                    if len(combined_results) == n_motions:
                        break
            # Still short? Fill with random motions
            if len(combined_results) < n_motions:
                needed2 = n_motions - len(combined_results)
                rnd = get_random_motions(needed2)
                for r in rnd:
                    if r["motion_combo"] not in used_combo:
                        combined_results.append(r)
                        combined_scores.append('RANDOM')
                        used_combo.add(r["motion_combo"])
                        if len(combined_results) == n_motions:
                            break
        result = (combined_results[:n_motions], combined_scores[:n_motions])
    # Cache the results
    SEARCH_RESULTS_CACHE[cache_key] = result
    return result

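# Fallback ordering in practice: a query like ("walk", "wave") with many exact
# annotation hits returns n_motions 'EXACT' entries; a rarer pair first
# exhausts the 'EXACT' and 'SUBSTR' tiers, then tops up with TF-IDF matches
# (numeric cosine scores) and finally 'RANDOM' samples, so the result has
# n_motions entries whenever the dataset is large enough.
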
def safe_video_update(motion_data, semantic_score, visible=True):
    """Optimized video update with match-type display"""
    # Prepare the annotation text based on the match type
    if semantic_score == 'EXACT':
        match_info = "Exact Match"
    elif semantic_score == 'SUBSTR':
        match_info = "Substring Match"
    elif semantic_score == 'RANDOM':
        match_info = "Random Result"
    else:
        # For semantic matches, round to 2 decimal places
        ssim = str(round(semantic_score, 2)) if semantic_score != 'NA' else ''
        match_info = f"Semantic Match (sim: {ssim})"
    actual_annot = f"{motion_data['annotation']} | {match_info}"
    return [
        gr.update(value=url, visible=visible)
        for url in (motion_data["motion_combo"],
                    motion_data["motion_a"],
                    motion_data["motion_b"])
    ] + [gr.update(value=actual_annot, visible=visible)]

def update_videos(motions, n_visible, semantic_scores):
    """Update video components with motion data, with parallel video processing"""
    updates = []
    if not motions:
        # Hide the first slot's three videos and show the message in its text
        # box, then hide the remaining 7 slots (4 components each), so the
        # number of updates matches the 8 x 4 = 32 wired output components
        updates.extend([
            gr.update(value=None, visible=False),
            gr.update(value=None, visible=False),
            gr.update(value=None, visible=False),
            gr.update(value='incompatible combination', visible=True)
        ])
        for _ in range(7):
            updates.extend([
                gr.update(value=None, visible=False),
                gr.update(value=None, visible=False),
                gr.update(value=None, visible=False),
                gr.update(value=None, visible=False)
            ])
    else:
        try:
            # Prepare all updates in parallel using ThreadPoolExecutor
            with ThreadPoolExecutor(max_workers=min(8, n_visible)) as executor:
                # Submit all video update tasks
                future_updates = [
                    executor.submit(safe_video_update, motion, semantic_scores[jj], True)
                    for jj, motion in enumerate(motions[:n_visible])
                ]
                # Collect all updates as they complete
                for future in future_updates:
                    updates.extend(future.result())
            # Hide the remaining slots
            remaining = 8 - len(motions[:n_visible])
            for _ in range(remaining):
                updates.extend([
                    gr.update(value=None, visible=False),
                    gr.update(value=None, visible=False),
                    gr.update(value=None, visible=False),
                    gr.update(value=None, visible=False)
                ])
        except Exception as e:
            print(f"Error updating videos: {e}")
            # Fall back to sequential processing if the parallel path fails
            updates = []
            for i in range(8):
                if i < len(motions[:n_visible]):
                    motion = motions[i]
                    score = semantic_scores[i]
                    # Handle the different score types
                    if score == 'EXACT':
                        match_info = "Exact Match"
                    elif score == 'SUBSTR':
                        match_info = "Substring Match"
                    elif score == 'RANDOM':
                        match_info = "Random Result"
                    else:
                        # For semantic matches, round to 2 decimal places
                        ssim = str(round(score, 2)) if score != 'NA' else ''
                        match_info = f"Semantic Match (sim: {ssim})"
                    actual_annot = f"{motion['annotation']} | {match_info}"
                    updates.extend([
                        gr.update(value=motion["motion_combo"], visible=True),
                        gr.update(value=motion["motion_a"], visible=True),
                        gr.update(value=motion["motion_b"], visible=True),
                        gr.update(value=actual_annot, visible=True)
                    ])
                else:
                    updates.extend([
                        gr.update(value=None, visible=False),
                        gr.update(value=None, visible=False),
                        gr.update(value=None, visible=False),
                        gr.update(value=None, visible=False)
                    ])
    return updates

def parse_gpt_labels(text):
    """Parse GPT labels from text"""
    if text.startswith("Answer: "):
        text = text[len("Answer: "):]  # Remove the "Answer: " prefix
    return text.split("\n")  # Split by newline

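# Illustrative parse (the label strings here are hypothetical; the real ones
# come from gpt3-labels-list.json):
#   parse_gpt_labels("Answer: left arm\nright arm") -> ['left arm', 'right arm']
# handle_interaction() below intersects the two parsed lists to detect
# conflicting action pairs.
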
def failure_update(message, n_motions=None):
    """Create UI updates for failure cases"""
    updates = []
    # For the first motion slot: hide the videos and show the message in the text box
    updates.append(gr.update(value=None, visible=False))    # video_combo for motion 1
    updates.append(gr.update(value=None, visible=False))    # video_a for motion 1
    updates.append(gr.update(value=None, visible=False))    # video_b for motion 1
    updates.append(gr.update(value=message, visible=True))  # annotation text for motion 1
    # For the remaining 7 motion slots, hide all components
    for _ in range(7):
        updates.extend([
            gr.update(value=None, visible=False),
            gr.update(value=None, visible=False),
            gr.update(value=None, visible=False),
            gr.update(value=None, visible=False)
        ])
    return updates

def handle_interaction(action1, action2, n_motions):
    """Handle user interaction, with caching for faster responses"""
    # Create a cache key for the entire interaction
    cache_key = f"interaction_{action1.strip().lower()}_{action2.strip().lower()}_{n_motions}"
    # Check if we have cached results for this interaction
    if cache_key in SEARCH_RESULTS_CACHE:
        return SEARCH_RESULTS_CACHE[cache_key]
    try:
        if not action1.strip() and not action2.strip():
            # Both empty => random
            motions = get_random_motions(n_motions)
            result = update_videos(motions, n_motions, ['NA'] * len(motions))
        else:
            # Process the GPT labels for both actions in parallel
            with ThreadPoolExecutor(max_workers=2) as executor:
                if action1 in GPT_LABELS_LIST:
                    future_act1 = executor.submit(lambda: parse_gpt_labels(GPT_LABELS_LIST[action1]))
                else:
                    future_act1 = executor.submit(search_gpt_semantic, action1, 1)
                if action2 in GPT_LABELS_LIST:
                    future_act2 = executor.submit(lambda: parse_gpt_labels(GPT_LABELS_LIST[action2]))
                else:
                    future_act2 = executor.submit(search_gpt_semantic, action2, 1)
                # Get results
                try:
                    if action1 in GPT_LABELS_LIST:
                        gpt_act1 = future_act1.result()
                    else:
                        best_key, best_text, best_sim = future_act1.result()
                        if not best_text:
                            result = failure_update("Action 1 not recognized.")
                            SEARCH_RESULTS_CACHE[cache_key] = result
                            return result
                        gpt_act1 = parse_gpt_labels(best_text)
                    if action2 in GPT_LABELS_LIST:
                        gpt_act2 = future_act2.result()
                    else:
                        best_key, best_text, best_sim = future_act2.result()
                        if not best_text:
                            result = failure_update("Action 2 not recognized.")
                            SEARCH_RESULTS_CACHE[cache_key] = result
                            return result
                        gpt_act2 = parse_gpt_labels(best_text)
                except Exception as e:
                    print(f"Error processing GPT labels: {e}")
                    result = failure_update("Error processing actions. Please try again.")
                    SEARCH_RESULTS_CACHE[cache_key] = result
                    return result
            # Check for conflicts (overlapping GPT labels)
            if set(gpt_act1) & set(gpt_act2):
                failure_message = "Incompatible action pair. Please select actions that are not conflicting."
                result = failure_update(failure_message)
            else:
                motions, sem_mot_scores = search_motions_combined(action1, action2, n_motions)
                result = update_videos(motions, n_motions, sem_mot_scores)
    except Exception as e:
        print(f"Error in handle_interaction: {e}")
        result = failure_update("An error occurred. Please try again.")
    # Cache the result
    SEARCH_RESULTS_CACHE[cache_key] = result
    return result

# Custom CSS
CUSTOM_CSS = """
button.compact-button {
    width: auto !important;      /* Let the button shrink to fit its text */
    min-width: unset !important; /* Remove any forced min-width */
    padding: 4px 8px !important;
    font-size: 20px !important;
    line-height: 1 !important;
}
"""

# Build the Gradio UI
with gr.Blocks(css=CUSTOM_CSS) as demo:
    gr.HTML(WEBSITE)
    with gr.Tabs():
        with gr.Tab("SINC-Synth exploration"):
            with gr.Row():
                with gr.Column():
                    with gr.Row():
                        with gr.Column():
                            action1_textbox = gr.Textbox(
                                label="Action 1",
                                placeholder="Select an action or type the first action, e.g. 'walk'",
                            )
                            create_example_buttons(action1_textbox, ACTION_EXAMPLES[:5])
                        with gr.Column():
                            action2_textbox = gr.Textbox(
                                label="Action 2",
                                placeholder="Select an action or type the second action, e.g. 'wave'"
                            )
                            create_example_buttons(action2_textbox, ACTION_EXAMPLES[5:])
                with gr.Column():
                    n_motions_radio = gr.Radio(
                        choices=[2, 4, 6, 8],
                        label="Number of motions to be shown from the SINC-Synthetic data",
                        value=2,
                        show_label=True,
                        container=True,
                    )
            with gr.Row():
                search_button = gr.Button("Search", elem_classes=["compact-button"])
                random_button = gr.Button("Random", elem_classes=["compact-button"])

            # Up to 8 motion slots, 2 per row
            motion_components = []
            videos_per_row = 2
            max_motions = 8
            num_rows = (max_motions + videos_per_row - 1) // videos_per_row  # Ceiling division
            for i in range(num_rows):
                with gr.Row():
                    for j in range(videos_per_row):
                        motion_index = i * videos_per_row + j
                        if motion_index >= max_motions:
                            break
                        with gr.Column():
                            video_combo = gr.Video(
                                label=f"Motion {motion_index + 1}",
                                visible=False,
                                width=640,
                                height=512
                            )
                            with gr.Row():
                                video_a = gr.Video(
                                    label="Motion A",
                                    visible=False,
                                    width=320,
                                    height=256
                                )
                                video_b = gr.Video(
                                    label="Motion B",
                                    visible=False,
                                    width=320,
                                    height=256
                                )
                            text = gr.Textbox(
                                visible=False,
                                interactive=False
                            )
                            motion_components.extend([video_combo, video_a, video_b, text])

            search_button.click(
                fn=handle_interaction,
                inputs=[action1_textbox, action2_textbox, n_motions_radio],
                outputs=motion_components
            )
            random_button.click(
                fn=lambda n: handle_interaction("", "", n),
                inputs=[n_motions_radio],
                outputs=motion_components
            )
gr.HTML((""" | |
<div style='text-align: center; margin-top: 20px; font-size: 16px;'> | |
<p><sup>**</sup>Our data in the official paper are using on the fly compositions, | |
which means than are not computed and filtered offline. This is a minimally | |
processed version of ~124k motions ranging between 3-7 seconds.</p> | |
<p>Made with ❤️ by Nikos Athanasiou</p> | |
</div> | |
""") | |
) | |
with gr.Tab("Simultaneous Motion Generation with SINC model"): | |
gr.HTML("<h2>Motion Generation from Text [TBD. Currenly under construction.]</h2>") | |
with gr.Row(): | |
text_input_gen = gr.Textbox( | |
label="Motion Description", | |
placeholder="Describe the motion, e.g. 'A person walking forward while waving'" | |
) | |
create_example_buttons(text_input_gen, ACTION_EXAMPLES_SIMULTANEOUS) | |
generate_button = gr.Button("Generate Motion", | |
elem_classes=["compact-button"]) | |
with gr.Row(): | |
output_video = gr.Video( | |
label="Generated Motion", | |
visible=True, | |
width=320, | |
height=180 | |
) | |
def generate_motion(text): | |
# Placeholder function - replace with actual model inference | |
# Return None instead of a string path to avoid schema conversion issues | |
return None | |
generate_button.click( | |
fn=generate_motion, | |
inputs=[text_input_gen], | |
outputs=[output_video] | |
) | |
# Initialize TF-IDF at startup
initialize_tfidf()

# Precompute synonyms for common action words
print("Precomputing synonyms for common action words...")
for action in ACTION_SYNONYMS:
    get_synonyms(action)

# Video prefetching
def prefetch_videos():
    """Warm up the caches for some common queries at startup"""
    print("Prefetching common videos...")
    try:
        # A small set of motions and common action pairs to prefetch
        random_motions = get_random_motions(4)
        common_actions = [("walk", "wave"), ("sit", "bow"), ("jump", "throw")]
        with ThreadPoolExecutor(max_workers=8) as executor:
            futures = []
            # Touch the video fields of the random motions
            for motion in random_motions:
                futures.append(executor.submit(
                    lambda m: (m["motion_combo"], m["motion_a"], m["motion_b"]),
                    motion
                ))
            # Running the combined search also fills SEARCH_RESULTS_CACHE
            # and SYNONYM_CACHE for these common pairs
            for act1, act2 in common_actions:
                motions, _ = search_motions_combined(act1, act2, 2)
                if motions:
                    for motion in motions:
                        futures.append(executor.submit(
                            lambda m: (m["motion_combo"], m["motion_a"], m["motion_b"]),
                            motion
                        ))
            # Wait for all prefetch operations to complete
            for future in futures:
                future.result()
        print("Video prefetching complete")
    except Exception as e:
        print(f"Error in video prefetching: {e}")

# Start prefetching in a separate thread so it does not block startup
threading.Thread(target=prefetch_videos).start()

# Print ready message
print("Demo ready! Exact matching is prioritized over synonym-enhanced TF-IDF similarity.")

# Launch the demo
demo.launch(server_name="0.0.0.0", server_port=7860, share=False)