|
import concurrent.futures |
|
from collections import defaultdict |
|
from typing import Any, Dict, List, Optional, Set, Tuple |
|
|
|
from sklearn.feature_extraction.text import TfidfVectorizer |
|
from sklearn.metrics.pairwise import cosine_similarity |
|
|
|
from ctm.configs import BaseConsciousnessTuringMachineConfig |
|
from ctm.processors import BaseProcessor |
|
from ctm.supervisors import BaseSupervisor |
|
|
|
|
|
class BaseConsciousnessTuringMachine(object):
    """Orchestrate a set of processors and one supervisor in an answer loop.

    Each iteration of :meth:`forward` asks every processor in parallel,
    runs an up-tree competition to pick one winning gist, and asks the
    supervisor to turn that gist into an ``(answer, score)`` pair. When the
    score does not clear the threshold, the winning gist is broadcast to the
    other processors and processor groups are re-linked by gist similarity
    before the next iteration.
    """

    def __init__(self, ctm_name: Optional[str] = None) -> None:
        """Build the machine from a named config, or the default config.

        Args:
            ctm_name: Passed to ``BaseConsciousnessTuringMachineConfig.from_ctm``
                when truthy; otherwise ``BaseConsciousnessTuringMachineConfig()``
                is used.
        """
        super().__init__()
        if ctm_name:
            self.config = BaseConsciousnessTuringMachineConfig.from_ctm(ctm_name)
        else:
            self.config = BaseConsciousnessTuringMachineConfig()
        # One entry per processor: {"processor_name": ..., "processor_instance": ...}.
        self.processor_list: List[Dict[str, Any]] = []
        # processor name -> group name. A defaultdict, so indexing an unknown
        # name yields (and stores) the empty group "".
        self.processor_group_map: Dict[str, str] = defaultdict(str)
        self.load_ctm()

    def __call__(
        self,
        query: str,
        text: Optional[str] = None,
        image: Optional[Any] = None,
        audio: Optional[Any] = None,
        video_frames: Optional[Any] = None,
    ) -> Tuple[str, float]:
        """Alias for :meth:`forward` with default loop settings."""
        return self.forward(query, text, image, audio, video_frames)

    def add_processor(
        self, processor_name: str, group_name: Optional[str] = 'default_group'
    ) -> None:
        """Instantiate a ``BaseProcessor`` and register it.

        Args:
            processor_name: Name passed to ``BaseProcessor`` and used as the
                processor's key in outputs.
            group_name: Group to place the processor in. When falsy, no group
                entry is recorded (the processor is then treated as belonging
                to the empty group ``""``).
        """
        processor_instance = BaseProcessor(processor_name)
        self.processor_list.append(
            {
                "processor_name": processor_name,
                "processor_instance": processor_instance,
            }
        )
        if group_name:
            self.processor_group_map[processor_name] = group_name

    def add_supervisor(self, supervisor_name: str) -> None:
        """Instantiate a ``BaseSupervisor`` and set it as the sole supervisor."""
        supervisor_instance = BaseSupervisor(supervisor_name)
        self.supervisor: Dict[str, Any] = {
            "supervisor_name": supervisor_name,
            "supervisor_instance": supervisor_instance,
        }

    @staticmethod
    def ask_processor(
        processor: Dict[str, Any],
        query: str,
        text: Optional[str] = None,
        image: Optional[Any] = None,
        audio: Optional[Any] = None,
        video_frames: Optional[Any] = None,
    ) -> Dict[str, Any]:
        """Ask a single registered processor and package its reply.

        Args:
            processor: An entry of ``processor_list``.
            query, text, image, audio, video_frames: Forwarded as keyword
                arguments to the processor instance's ``ask``.

        Returns:
            ``{"name": <processor name>, "gist": <gist>, "score": <score>}``,
            where ``(gist, score)`` is what ``ask`` returned.
        """
        processor_instance = processor["processor_instance"]
        processor_name = processor["processor_name"]
        print(processor_name)
        gist, score = processor_instance.ask(
            query=query,
            text=text,
            image=image,
            audio=audio,
            video_frames=video_frames,
        )
        return {"name": processor_name, "gist": gist, "score": score}

    def ask_processors(
        self,
        query: str,
        text: Optional[str] = None,
        image: Optional[Any] = None,
        audio: Optional[Any] = None,
        video_frames: Optional[Any] = None,
    ) -> Dict[str, Dict[str, Any]]:
        """Ask every registered processor concurrently on a thread pool.

        Returns:
            Mapping of processor name -> ``{"gist": ..., "score": ...}``. Keys
            are in ``processor_list`` order, regardless of which thread
            finished first.

        Raises:
            AssertionError: If two registered processors share a name (their
                outputs would collide in the result dict).
        """
        with concurrent.futures.ThreadPoolExecutor() as executor:
            futures = [
                executor.submit(
                    self.ask_processor,
                    processor,
                    query,
                    text,
                    image,
                    audio,
                    video_frames,
                )
                for processor in self.processor_list
            ]
            # Collect in submission order rather than completion order so the
            # output (and every tie-break that iterates it) is deterministic.
            results = [future.result() for future in futures]

        output: Dict[str, Dict[str, Any]] = {
            result["name"]: {"gist": result["gist"], "score": result["score"]}
            for result in results
        }
        assert len(output) == len(self.processor_list), (
            "processor names must be unique"
        )
        return output

    def uptree_competition(
        self, processor_output: Dict[str, Dict[str, Any]]
    ) -> Dict[str, Any]:
        """Pick the winning processor: best per group, then best overall.

        Processors without a recorded group compete in the empty group ``""``.
        Within a group and across groups, the first processor (in
        ``processor_output`` order) with the highest score wins ties.

        Returns:
            ``{"name": ..., "gist": ..., "score": ...}`` for the winner.

        Raises:
            ValueError: If no processor has a score greater than ``-inf``
                (including when ``processor_output`` is empty).
        """
        best_by_group: Dict[str, Tuple[str, float]] = {}
        for name, info in processor_output.items():
            # ``.get`` rather than indexing: an ungrouped processor must not
            # crash the competition nor be written into the group map here.
            group = self.processor_group_map.get(name, "")
            score = info["score"]
            current_best = best_by_group.get(group)
            best_score = float("-inf") if current_best is None else current_best[1]
            if score > best_score:
                best_by_group[group] = (name, score)

        if not best_by_group:
            raise ValueError(
                "No valid processor found."
            )

        best_name = max(best_by_group.values(), key=lambda entry: entry[1])[0]
        winner = processor_output[best_name]
        return {
            "name": best_name,
            "gist": winner["gist"],
            "score": winner["score"],
        }

    def ask_supervisor(
        self, query: str, processor_info: Dict[str, Any]
    ) -> Tuple[str, float]:
        """Ask the supervisor to answer ``query`` from the winning gist.

        Returns:
            The ``(final_answer, score)`` pair returned by the supervisor's ``ask``.
        """
        final_answer, score = self.supervisor["supervisor_instance"].ask(
            query, processor_info["gist"]
        )
        return final_answer, score

    def downtree_broadcast(self, winning_output: Dict[str, Any]) -> None:
        """Send the winning gist to every processor except the winner itself."""
        winning_processor_name = winning_output["name"]
        winning_processor_gist = winning_output["gist"]
        for processor in self.processor_list:
            if processor["processor_name"] != winning_processor_name:
                processor["processor_instance"].update_info(winning_processor_gist)

    def calc_processor_sim(
        self, processor_output: Dict[str, Dict[str, Any]]
    ) -> Any:
        """Return the pairwise TF-IDF cosine-similarity matrix of the gists.

        Row/column ``i`` corresponds to the ``i``-th entry of
        ``processor_output``. When the gists contain no usable token (e.g.
        all are empty, or there are none), scikit-learn cannot build a
        vocabulary; an all-zero ``n x n`` matrix is returned instead so that
        no links are formed.
        """
        processor_gists = [info["gist"] for info in processor_output.values()]
        tfidf_vectorizer = TfidfVectorizer()
        try:
            tfidf_matrix = tfidf_vectorizer.fit_transform(processor_gists)
        except ValueError as err:
            if "empty vocabulary" not in str(err):
                raise
            import numpy as np

            size = len(processor_gists)
            return np.zeros((size, size))
        return cosine_similarity(tfidf_matrix, tfidf_matrix)

    def link_form(
        self,
        processor_output: Dict[str, Dict[str, Any]],
        *,
        link_threshold: float = 0.5,
    ) -> None:
        """Merge the groups of processors whose gists are similar.

        For every pair with similarity strictly above ``link_threshold``, both
        processors are moved into whichever of their two groups currently has
        more members (the second processor's group on a tie). Group sizes are
        recounted per pair because earlier merges change them.

        Args:
            processor_output: Output of :meth:`ask_processors`.
            link_threshold: Similarity above which two processors are linked.
        """
        sim = self.calc_processor_sim(processor_output)
        print(sim)

        # Hoisted: the i-th similarity row corresponds to the i-th name.
        names = list(processor_output.keys())
        for i in range(len(names)):
            for j in range(i + 1, len(names)):
                if sim[i][j] > link_threshold:
                    processor1_name = names[i]
                    processor2_name = names[j]
                    group1 = self.processor_group_map[processor1_name]
                    group2 = self.processor_group_map[processor2_name]
                    group1_count = sum(
                        1
                        for group in self.processor_group_map.values()
                        if group == group1
                    )
                    group2_count = sum(
                        1
                        for group in self.processor_group_map.values()
                        if group == group2
                    )
                    group_name = group1 if group1_count > group2_count else group2
                    self.processor_group_map[processor1_name] = group_name
                    self.processor_group_map[processor2_name] = group_name

    def processor_fuse(
        self, infos: List[str], scores: List[float]
    ) -> Tuple[List[str], List[float]]:
        """Placeholder fusion step: returns ``infos`` and ``scores`` unchanged."""
        return infos, scores

    def forward(
        self,
        query: str,
        text: Optional[str] = None,
        image: Optional[Any] = None,
        audio: Optional[Any] = None,
        video_frames: Optional[Any] = None,
        *,
        answer_threshold: float = 0.5,
        max_iter: int = 3,
    ) -> Tuple[str, float]:
        """Run the ask/compete/supervise loop until the answer is good enough.

        Each iteration asks all processors, picks a winner via
        :meth:`uptree_competition`, and asks the supervisor. The loop stops
        as soon as the supervisor's score is strictly above
        ``answer_threshold``; otherwise the winner is broadcast and groups
        are re-linked before the next iteration (including after the last
        one).

        Args:
            query, text, image, audio, video_frames: Forwarded to processors.
            answer_threshold: Supervisor score that ends the loop early.
            max_iter: Maximum number of iterations; must be at least 1.

        Returns:
            The supervisor's ``(answer, score)`` from the last iteration run.

        Raises:
            ValueError: If ``max_iter`` is less than 1.
        """
        if max_iter < 1:
            raise ValueError("max_iter must be at least 1.")

        for i in range(max_iter):
            print("start the {}-th iteration".format(i + 1))
            processor_output = self.ask_processors(
                query=query,
                text=text,
                image=image,
                audio=audio,
                video_frames=video_frames,
            )
            winning_output = self.uptree_competition(processor_output)
            answer, score = self.ask_supervisor(query, winning_output)
            if score > answer_threshold:
                break
            self.downtree_broadcast(winning_output)
            self.link_form(processor_output)
        return answer, score

    def load_ctm(self) -> None:
        """Register every configured processor (by group) and the supervisor."""
        for group_name, processor_list in self.config.groups_of_processors.items():
            for processor_name in processor_list:
                self.add_processor(processor_name, group_name=group_name)
        self.add_supervisor(self.config.supervisor)
|
|