# RelightVid / misc_utils/video_ptp_utils.py
# (file-viewer header: aleafy's picture; commit message "Start fresh"; commit 0a63786)
import torch
from modules.damo_text_to_video.unet_sd import UNetSD
from misc_utils.train_utils import instantiate_from_config
from omegaconf import OmegaConf
from modules.damo_text_to_video.text_model import FrozenOpenCLIPEmbedder
from typing import List, Tuple, Union
from diff_match_patch import diff_match_patch
import difflib
from misc_utils.ptp_utils import get_text_embedding_openclip, encode_text_openclip, Text, Edit, Insert, Delete
def get_models_of_damo_model(
    unet_config: str,
    unet_ckpt: str,
    vae_config: str,
    vae_ckpt: str,
    text_model_ckpt: str,
):
    """Build the VAE, UNet and OpenCLIP text encoder, each in half precision on CUDA.

    Args:
        unet_config: Path of a config loadable by ``OmegaConf.load``; its
            ``model.model_cfg`` entry is expanded into ``UNetSD`` keyword
            arguments.
        unet_ckpt: Path of the UNet state dict, read with ``torch.load``.
        vae_config: Path of a config loadable by ``OmegaConf.load``; its
            ``vae`` entry is handed to ``instantiate_from_config``.
        vae_ckpt: Path of the VAE state dict, read with ``torch.load``.
        text_model_ckpt: Passed as ``version`` to ``FrozenOpenCLIPEmbedder``.

    Returns:
        The tuple ``(vae, unet, text_model)``.
    """
    # Both configs are parsed up front, before any weights are touched.
    vae_settings = OmegaConf.load(vae_config)
    unet_settings = OmegaConf.load(unet_config)

    # Checkpoints are deserialized onto the CPU first; the module is moved to
    # the GPU only after its weights are in place.
    vae = instantiate_from_config(vae_settings.vae)
    vae_weights = torch.load(vae_ckpt, map_location='cpu')
    vae.load_state_dict(vae_weights)
    vae = vae.half().cuda()

    unet = UNetSD(**unet_settings.model.model_cfg)
    unet_weights = torch.load(unet_ckpt, map_location='cpu')
    unet.load_state_dict(unet_weights)
    unet = unet.half().cuda()

    text_model = FrozenOpenCLIPEmbedder(
        version=text_model_ckpt,
        layer='penultimate',
    ).half().cuda()

    return vae, unet, text_model
def compute_diff_old(old_sentence: str, new_sentence: str) -> List[Union[Text, Edit, Insert, Delete]]:
    """Diff two sentences with diff-match-patch and return prompt-edit operations.

    The raw diff from ``diff_main`` is passed through ``diff_cleanupSemantic``
    and then converted into a flat list: unchanged spans become ``Text``, a
    deletion immediately followed by an insertion becomes a single ``Edit``,
    and any other deletion / insertion becomes ``Delete`` / ``Insert``.

    Args:
        old_sentence: The original sentence.
        new_sentence: The modified sentence.

    Returns:
        A list of ``Text`` / ``Edit`` / ``Insert`` / ``Delete`` objects in
        diff order. (The previous annotation advertised a list of tuples,
        which this function never produced.)
    """
    dmp = diff_match_patch()
    diff = dmp.diff_main(old_sentence, new_sentence)
    dmp.diff_cleanupSemantic(diff)

    result = []
    i = 0
    while i < len(diff):
        op, data = diff[i]
        if op == 0:  # Equal
            result.append(Text(text=data))
        elif op == -1:  # Delete
            if i + 1 < len(diff) and diff[i + 1][0] == 1:
                # Delete directly followed by Insert: fold the pair into one
                # Edit and skip over the Insert entry we just consumed.
                result.append(Edit(old=data, new=diff[i + 1][1]))
                i += 1
            else:
                result.append(Delete(text=data))
        elif op == 1:  # Insert
            # An Insert that follows a Delete is always consumed (and skipped)
            # by the branch above, so any Insert reached here is standalone;
            # no look-behind check is needed.
            result.append(Insert(text=data))
        i += 1
    return result
def compute_diff(old_sentence: str, new_sentence: str) -> List[Union[Text, Edit, Insert, Delete]]:
    """Word-level diff of two sentences expressed as prompt-edit operations.

    Both sentences are split on whitespace and compared with
    ``difflib.Differ``. Consecutive diff lines sharing the same tag are
    joined with single spaces into one ``Text`` (' '), ``Delete`` ('-') or
    ``Insert`` ('+'). Afterwards every neighbouring Delete/Insert pair, in
    either order, is collapsed into one ``Edit``.

    Args:
        old_sentence: The original sentence.
        new_sentence: The modified sentence.

    Returns:
        A list of ``Text`` / ``Edit`` / ``Insert`` / ``Delete`` objects in
        diff order.
    """
    op_for_tag = {' ': Text, '-': Delete, '+': Insert}
    word_diff = difflib.Differ().compare(old_sentence.split(), new_sentence.split())

    # Group consecutive lines that carry the same leading tag character.
    # Each line is "<tag><space><word>", so the word starts at offset 2.
    runs = []
    for line in word_diff:
        tag, word = line[0], line[2:]
        if runs and runs[-1][0] == tag:
            runs[-1][1].append(word)
        else:
            runs.append((tag, [word]))

    # Lines with any other tag (Differ also emits '? ' hint lines) yield no
    # operation, but they were kept as their own run above so that they
    # still separate the runs on either side of them.
    ops = [
        op_for_tag[tag](text=' '.join(words))
        for tag, words in runs
        if tag in op_for_tag
    ]

    # Greedy left-to-right pairing: an operation merges with the one just
    # before it only if that one is still a bare Delete/Insert, so an Edit
    # produced by a merge is not merged a second time.
    merged = []
    for op in ops:
        previous = merged[-1] if merged else None
        if isinstance(previous, Delete) and isinstance(op, Insert):
            merged[-1] = Edit(old=previous.text, new=op.text)
        elif isinstance(previous, Insert) and isinstance(op, Delete):
            merged[-1] = Edit(old=op.text, new=previous.text)
        else:
            merged.append(op)
    return merged