File size: 3,834 Bytes
244b0b6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import os
import json
import pyzipper
from typing import Dict, Any, Optional
from voyager import Index, Space, StorageDataType

class DataManager:
    """
    Load and serve the face-recognition data: the faces JSON file, the
    encrypted performer database and the two voyager vector indices.

    Everything is loaded once, from the constructor. A load failure is
    reported with print() and leaves the affected attribute at its empty
    default instead of raising.
    """

    def __init__(self, faces_path: str = "data/faces.json",
                 performers_zip: str = "data/persons.zip",
                 facenet_index_path: str = "data/face_facenet.voy",
                 arc_index_path: str = "data/face_arc.voy"):
        """
        Initialize the data manager and immediately load all data.

        Parameters:
        faces_path: Path to the faces.json file
        performers_zip: Path to the performers zip file
        facenet_index_path: Path to the facenet index file
        arc_index_path: Path to the arc index file
        """
        self.faces_path = faces_path
        self.performers_zip = performers_zip
        self.facenet_index_path = facenet_index_path
        self.arc_index_path = arc_index_path

        # Empty 512-dimensional cosine indices. They act as placeholders and
        # are replaced by whatever _load_indices() reads from disk.
        self.index_arc = Index(Space.Cosine, num_dimensions=512, storage_data_type=StorageDataType.E4M3)
        self.index_facenet = Index(Space.Cosine, num_dimensions=512, storage_data_type=StorageDataType.E4M3)

        # Empty defaults so lookups still work if a load step fails.
        self.faces = {}
        self.performer_db = {}
        self.load_data()

    def load_data(self) -> None:
        """Load all data from files (faces, performer database, indices)."""
        self._load_faces()
        self._load_performer_db()
        self._load_indices()

    def _load_faces(self) -> None:
        """
        Load faces from the JSON file at self.faces_path.

        On any error the message is printed and self.faces is reset to {}.
        """
        try:
            # JSON is exchanged as UTF-8; do not depend on the platform's
            # default locale encoding when decoding it.
            with open(self.faces_path, 'r', encoding='utf-8') as f:
                self.faces = json.load(f)
        except Exception as e:
            print(f"Error loading faces: {e}")
            self.faces = {}

    def _load_performer_db(self) -> None:
        """
        Load the performer database from the encrypted zip file.

        The archive password is read from the VISAGE_KEY environment
        variable and the 'performers.json' member is parsed as JSON.
        On any error the message is printed and self.performer_db is reset
        to {}.
        """
        try:
            with pyzipper.AESZipFile(self.performers_zip) as zf:
                password = os.getenv("VISAGE_KEY", "").encode('ascii')
                zf.setpassword(password)
                self.performer_db = json.loads(zf.read('performers.json'))
        except Exception as e:
            print(f"Error loading performer database: {e}")
            self.performer_db = {}

    def _load_index_file(self, index, path: str):
        """
        Load a single index file.

        Parameters:
        index: Index object whose load() is used to read the file
        path: Path to the index file

        Returns:
        The value returned by index.load(), or the index passed in
        (unchanged) if the file could not be opened or loaded
        """
        try:
            with open(path, 'rb') as f:
                return index.load(f)
        except Exception as e:
            print(f"Error loading indices: {e}")
            return index

    def _load_indices(self) -> None:
        """
        Load face recognition indices.

        Each file is loaded independently, so a missing or corrupt arc index
        does not prevent the facenet index from loading (and vice versa).
        An index that fails to load keeps its current value.
        """
        self.index_arc = self._load_index_file(self.index_arc, self.arc_index_path)
        self.index_facenet = self._load_index_file(self.index_facenet, self.facenet_index_path)

    def get_performer_info(self, stash_id: str, confidence: float) -> Optional[Dict[str, Any]]:
        """
        Get performer information from the database

        Parameters:
        stash_id: Stash ID of the performer
        confidence: Confidence score (0-1)

        Returns:
        Dictionary with performer information or None if not found.
        'confidence' and 'distance' both carry the confidence as an integer
        percentage; 'hits' is always 1.

        Raises:
        KeyError: if the stored record lacks 'name', 'image' or 'country'
        """
        performer = self.performer_db.get(stash_id)
        if not performer:
            return None

        # Truncate to a whole percent, but round away binary floating-point
        # noise first: 0.29 * 100 is 28.999999999999996, and a bare int()
        # would report 28 instead of 29.
        confidence_int = int(round(confidence * 100, 9))
        return {
            'id': stash_id,
            "name": performer['name'],
            "confidence": confidence_int,
            'image': performer['image'],
            'country': performer['country'],
            'hits': 1,
            'distance': confidence_int,
            'performer_url': f"https://stashdb.org/performers/{stash_id}"
        }

    def query_facenet_index(self, embedding, limit):
        """Query the facenet index with an embedding, returning up to `limit` matches"""
        return self.index_facenet.query(embedding, limit)

    def query_arc_index(self, embedding, limit):
        """Query the arc index with an embedding, returning up to `limit` matches"""
        return self.index_arc.query(embedding, limit)