Spaces:
Runtime error
Runtime error
from typing import Tuple, List | |
import nltk | |
import numpy as np | |
import torch | |
from sklearn.cluster import KMeans | |
from constants import STYLE_INDEX, STRUCT_INDEX | |
# Tokenizer / POS-tagger data needed by Segmentor (nltk.word_tokenize, nltk.pos_tag).
# Recent NLTK releases (3.9+) look these up as "punkt_tab" and
# "averaged_perceptron_tagger_eng"; the legacy ids are kept for older releases.
# nltk.download reports an unknown id and returns False instead of raising,
# so requesting both generations is safe on either version.
nltk.download('punkt')
nltk.download('punkt_tab')
nltk.download('averaged_perceptron_tagger')
nltk.download('averaged_perceptron_tagger_eng')

# Self-segmentation technique taken from Prompt Mixing:
# https://github.com/orpatashnik/local-prompt-mixing
class Segmentor:
    """Self-segment attention maps and locate the pixels belonging to object nouns.

    Technique adapted from Prompt Mixing
    (https://github.com/orpatashnik/local-prompt-mixing).

    Attention tensors are pushed in through ``update_attention``.
    ``STYLE_INDEX`` / ``STRUCT_INDEX`` (from ``constants``) index the first
    dimension of the stored tensors to select what is clustered and labelled.
    """

    # A cluster is labelled with a noun only when its best noun score exceeds this
    # multiple of the mean best score of the other clusters (excluding the top one).
    BG_SCORE_RATIO = 1.4

    def __init__(self, prompt: str, object_nouns: List[str], num_segments: int = 5, res: int = 32):
        """
        Args:
            prompt: text prompt; its nouns (per NLTK POS tagging) are the
                candidate labels for clusters.
            object_nouns: words whose clusters count as "object" in ``create_mask``.
            num_segments: number of KMeans clusters.
            res: clustering resolution used by ``__call__`` (32 or 64).
        """
        self.prompt = prompt
        self.num_segments = num_segments
        self.resolution = res
        self.object_nouns = object_nouns
        tokenized_prompt = nltk.word_tokenize(prompt)
        # Generic image words are never candidate nouns.
        forbidden_words = {word.upper() for word in ("photo", "image", "picture")}
        # (token position, word) for every NN* tag (NN, NNS, NNP, NNPS) in the prompt.
        self.nouns = [(i, word) for i, (word, pos) in enumerate(nltk.pos_tag(tokenized_prompt))
                      if pos[:2] == "NN" and word.upper() not in forbidden_words]

    def update_attention(self, attn, is_cross):
        """Store an attention tensor for later clustering / labelling.

        The spatial resolution is inferred as ``sqrt(attn.shape[2])`` (assumes
        dim 2 indexes flattened pixels -- confirm against the caller).
        Cross-attention at 16x16 / 32x32 is stored as ``cross_attention_32`` /
        ``cross_attention_64`` because ``cluster2noun`` upsamples it 2x to
        match 32x32 / 64x64 self-attention clusters. Other resolutions are ignored.
        """
        res = int(attn.shape[2] ** 0.5)
        if is_cross:
            if res == 16:
                self.cross_attention_32 = attn
            elif res == 32:
                self.cross_attention_64 = attn
        elif res == 32:
            self.self_attention_32 = attn
        elif res == 64:
            self.self_attention_64 = attn

    def _self_attention_for(self, res: int):
        """Return the stored self-attention tensor for clustering resolution 32 or 64."""
        if res == 32:
            return self.self_attention_32
        if res == 64:
            return self.self_attention_64
        raise ValueError(f"Unsupported resolution {res}; expected 32 or 64")

    def _cross_attention_for(self, res: int):
        """Return the stored cross-attention tensor paired with clustering resolution 32 or 64."""
        if res == 32:
            return self.cross_attention_32
        if res == 64:
            return self.cross_attention_64
        raise ValueError(f"Unsupported resolution {res}; expected 32 or 64")

    def __call__(self, *args, **kwargs):
        """Cluster at ``self.resolution`` and label the clusters of both indices.

        Positional / keyword arguments are accepted for call compatibility and ignored.

        Returns:
            ``(style_cluster2noun, struct_cluster2noun)``: dicts as returned by
            ``cluster2noun`` for ``STYLE_INDEX`` and ``STRUCT_INDEX``.
        """
        style_clusters, struct_clusters = self.cluster(res=self.resolution)
        cross_attn = self._cross_attention_for(self.resolution)
        return (self.cluster2noun(style_clusters, cross_attn, STYLE_INDEX),
                self.cluster2noun(struct_clusters, cross_attn, STRUCT_INDEX))

    def cluster(self, res: int = 32):
        """KMeans-cluster the stored self-attention for both indices.

        Args:
            res: 32 or 64; selects ``self_attention_32`` / ``self_attention_64``.

        Returns:
            ``(style_clusters, struct_clusters)``: integer label arrays of shape (res, res).

        Raises:
            ValueError: if ``res`` is not 32 or 64.
        """
        # KMeans below gets no random_state, so it draws from NumPy's global RNG;
        # seeding it keeps the segmentation reproducible.
        np.random.seed(1)
        self_attn = self._self_attention_for(res)
        style_clusters = self._kmeans_labels(self_attn[STYLE_INDEX], res)
        struct_clusters = self._kmeans_labels(self_attn[STRUCT_INDEX], res)
        return style_clusters, struct_clusters

    def _kmeans_labels(self, attn, res: int):
        """Average ``attn`` over dim 0, KMeans-cluster its rows, return labels as (res, res)."""
        features = attn.mean(dim=0).cpu().numpy()
        kmeans = KMeans(n_clusters=self.num_segments, n_init=10).fit(features)
        return kmeans.labels_.reshape(res, res)

    @staticmethod
    def _normalize_noun_maps(nouns_maps):
        """Upsample each (res, res, k) noun map 2x and rescale it as (m - |min m|) / max m."""
        normalized = np.zeros_like(nouns_maps).repeat(2, axis=0).repeat(2, axis=1)
        for i in range(nouns_maps.shape[-1]):
            noun_map = nouns_maps[:, :, i].repeat(2, axis=0).repeat(2, axis=1)
            # NOTE(review): not standard min-max scaling (divides by max, not max - min);
            # kept as-is to avoid changing segmentation results.
            normalized[:, :, i] = (noun_map - np.abs(noun_map.min())) / noun_map.max()
        return normalized

    @staticmethod
    def _cluster_noun_scores(clusters, c, noun_maps):
        """Mean normalized attention of each noun over the pixels of cluster ``c``."""
        cluster_mask = np.zeros_like(clusters)
        cluster_mask[clusters == c] = 1
        cluster_size = cluster_mask.sum()
        num_nouns = noun_maps.shape[-1]
        if cluster_size == 0:
            # Empty cluster: avoid 0/0 NaNs that would corrupt the max/mean comparisons.
            return [0.0] * num_nouns
        return [(cluster_mask * noun_maps[:, :, i]).sum() / cluster_size for i in range(num_nouns)]

    def cluster2noun(self, clusters, cross_attn, attn_index):
        """Map each cluster id to the prompt noun it attends to most, or ``"BG"``.

        Args:
            clusters: (H, W) integer label array. H and W must be twice the
                cross-attention resolution, because the noun maps are upsampled 2x.
            cross_attn: attention tensor whose dim 2 is a flattened res x res
                pixel grid and whose last dim indexes text tokens.
            attn_index: index into ``cross_attn``'s first dimension; that slice
                is averaged over its own first dim.

        Returns:
            ``{cluster_id: (token_index, word) from self.nouns, or "BG"}`` for
            every id in ``range(self.num_segments)``.
            All clusters are ``"BG"`` when the prompt has no candidate noun.
        """
        nouns_indices = [index for index, _ in self.nouns]
        if not nouns_indices:
            return {c: "BG" for c in range(self.num_segments)}

        res = int(cross_attn.shape[2] ** 0.5)
        attn_maps = cross_attn[attn_index].mean(dim=0).reshape(res, res, -1)
        # NOTE(review): the +1 presumably skips a start-of-sequence token prepended by
        # the text encoder, assuming one encoder token per NLTK word -- confirm.
        nouns_maps = attn_maps.cpu().numpy()[:, :, [i + 1 for i in nouns_indices]]
        normalized_nouns_maps = self._normalize_noun_maps(nouns_maps)

        # Scores are computed once and reused for both the threshold and the labelling.
        cluster_scores = [self._cluster_noun_scores(clusters, c, normalized_nouns_maps)
                          for c in range(self.num_segments)]
        best_scores = [max(scores) for scores in cluster_scores]

        # Reference level: mean best score of every cluster except the single top one.
        # With no other cluster to compare against, fall back to 0.
        others = list(best_scores)
        others.remove(max(best_scores))
        mean_score = sum(others) / len(others) if others else 0.0

        result = {}
        for c, scores in enumerate(cluster_scores):
            if max(scores) > self.BG_SCORE_RATIO * mean_score:
                result[c] = self.nouns[int(np.argmax(scores))]
            else:
                result[c] = "BG"
        return result

    def create_mask(self, clusters, cross_attention, attn_index, device=None):
        """Mask the pixels whose cluster is labelled with one of ``self.object_nouns``.

        Args:
            clusters, cross_attention, attn_index: as for ``cluster2noun``.
            device: target device for the mask; defaults to ``cross_attention.device``.

        Returns:
            Tensor with ``clusters``' shape and dtype: 1 on object pixels, 0 elsewhere.
        """
        cluster2noun = self.cluster2noun(clusters, cross_attention, attn_index)
        # "BG" is a plain string; indexing it ("BG"[1] == "G") must be avoided.
        obj_segments = [c for c, noun in cluster2noun.items()
                        if noun != "BG" and noun[1] in self.object_nouns]
        mask = np.isin(clusters, obj_segments).astype(clusters.dtype)
        if device is None:
            device = cross_attention.device
        return torch.from_numpy(mask).to(device)

    def get_object_masks(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Compute object masks for both indices at 32x32 and 64x64.

        Returns:
            ``(mask_style_32, mask_struct_32, mask_style_64, mask_struct_64)``
        """
        clusters_style_32, clusters_struct_32 = self.cluster(res=32)
        clusters_style_64, clusters_struct_64 = self.cluster(res=64)
        mask_style_32 = self.create_mask(clusters_style_32, self.cross_attention_32, STYLE_INDEX)
        mask_struct_32 = self.create_mask(clusters_struct_32, self.cross_attention_32, STRUCT_INDEX)
        mask_style_64 = self.create_mask(clusters_style_64, self.cross_attention_64, STYLE_INDEX)
        mask_struct_64 = self.create_mask(clusters_struct_64, self.cross_attention_64, STRUCT_INDEX)
        return mask_style_32, mask_struct_32, mask_style_64, mask_struct_64