File size: 1,640 Bytes
d425e71
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
"""Utility functions for interacting with the SQLite database."""
import io
import logging
import sqlite3
from typing import Any, List, Optional

import torch


def select_tensors(
        db_path: str,
        table_name: str,
        keys: Optional[List[str]] = None,
        sql_where: Optional[str] = None,
        ) -> List[Any]:
    """Select rows from a SQLite table and deserialize their stored tensors.

    Each returned row is a dict mapping every selected column name to its
    value; the ``'tensor'`` entry is replaced by the object obtained from
    ``torch.load`` on the stored bytes (mapped to CPU).

    Args:
        db_path (str): Path to the SQLite database file.
        table_name (str): Name of the table to query.
        keys (List[str]): Column names to select. Defaults to
            ``['layer', 'pooling_method', 'tensor_dim', 'tensor']``.
            ``'tensor'`` is added automatically (with a warning) when absent.
            The list passed by the caller is never modified.
        sql_where (str): Optional SQL clause appended to the query; it must
            start with ``WHERE`` (case-insensitive, leading whitespace ignored).

    Returns:
        List[Any]: One dict per matching row, keyed by the selected columns.

    Raises:
        AssertionError: If ``sql_where`` is given but does not start with
            ``WHERE``.

    Note:
        ``table_name``, ``keys`` and ``sql_where`` are interpolated into the
        SQL text verbatim, so they must come from trusted code, never from
        end-user input. Errors raised by ``sqlite3`` or ``torch.load`` are
        not caught here.
    """
    # Work on a private copy so neither the caller's list nor a shared default
    # is mutated when 'tensor' has to be appended.
    if keys is None:
        columns = ['layer', 'pooling_method', 'tensor_dim', 'tensor']
    else:
        columns = list(keys)
    if 'tensor' not in columns:
        logging.warning("'tensor' key should be included to retrieve tensors; automatically adding it.")
        columns.append('tensor')

    query = f'SELECT {", ".join(columns)} FROM {table_name}'
    if sql_where:
        # Raised explicitly (rather than via ``assert``) so the check still
        # runs under ``python -O``; AssertionError is kept so existing callers
        # see the same exception type.
        if not sql_where.strip().lower().startswith('where'):
            raise AssertionError("sql_where should start with 'WHERE'")
        query += f' {sql_where}'

    final_results = []
    connection = sqlite3.connect(db_path)
    try:
        cursor = connection.cursor()
        cursor.execute(query)
        for row in cursor.fetchall():
            result_item = dict(zip(columns, row))
            # NOTE(security): torch.load unpickles the blob; only read
            # databases whose contents are trusted.
            result_item['tensor'] = torch.load(io.BytesIO(result_item['tensor']), map_location='cpu')
            final_results.append(result_item)
    finally:
        # ``with sqlite3.connect(...)`` only manages the transaction and does
        # not close the connection, so close it explicitly.
        connection.close()
    return final_results