File size: 1,685 Bytes
1a9dcdb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
import pandas as pd
from hashlib import sha256
from typing import List, Tuple, Dict, Any
import math
import re

# Module-level prompt text; nothing in the visible part of this file reads it.
# NOTE(review): presumably consumed by an answer-extraction step elsewhere — confirm against callers.
EXTRACTION_PROMPT = "All attempted answers, correct and incorrect"

def regex_compare(a: str, b: str) -> bool:
    """
    Check whether the word characters of a appear, in order and contiguously,
    inside the word characters of b.

    Both strings are first reduced to their word characters (letters, digits
    and underscore); whitespace and punctuation are discarded. The comparison
    is one-directional: a must be contained in b, not the other way round.
    An a with no word characters at all therefore matches any b.
    """
    def word_chars(text: str) -> str:
        return "".join(re.findall(r'\w', text))

    needle = word_chars(a)
    haystack = word_chars(b)
    # Equal strings are also substrings of each other, so one containment
    # check covers both the "identical" and the "contained" case.
    return needle in haystack

def print_info(db_connection):
    """
    Print each table of the database followed by its columns.

    Issues "SHOW TABLES" through db_connection.execute, then one
    "DESCRIBE <table>" per returned row. For every table a header line is
    printed, then one "  - <name> (<type>)" line per column, then a blank
    separator line.
    """
    run_query = db_connection.execute
    for table_row in run_query("SHOW TABLES").fetchall():
        # The table name is the first field of each SHOW TABLES row.
        table_name = table_row[0]
        print(f"Table: {table_name}")

        # The first two fields of each DESCRIBE row are shown as name and type.
        for col in run_query(f"DESCRIBE {table_name}").fetchall():
            print(f"  - {col[0]} ({col[1]})")

        # Blank line so consecutive tables stay visually separated.
        print()

def query_format_models(models: List[str]) -> str:
    """
    Render model names as a parenthesised, single-quoted SQL list suitable
    for a `WHERE <column> IN <models>` clause.

    Each name is prefixed with "completions-". Names are inserted verbatim
    (no quote escaping), and an empty list renders as "('')".
    """
    joined = "','".join("completions-" + name for name in models)
    return f"('{joined}')"

def get_completions(db_connector, query: str, **query_kwargs) -> pd.DataFrame:
    """
    Run a templated query and reduce the result to one completion per group.

    The query string is filled in with str.format(**query_kwargs), executed
    via db_connector.sql(...), and converted to a DataFrame with .df().
    Rows are then grouped by (prompt_id, model, solution, prompt) and the
    "first" aggregate of the completion column is kept, so a model with
    several completions for the same prompt contributes only one row.
    The grouping keys are returned as ordinary columns.
    """
    rendered_query = query.format(**query_kwargs)
    raw = db_connector.sql(rendered_query).df()

    group_keys = ["prompt_id", "model", "solution", "prompt"]
    first_per_group = raw.groupby(group_keys).agg({"completion": "first"})

    # Turn the group keys from index levels back into regular columns.
    return first_per_group.reset_index()

def sha256_hash(text: str) -> str:
    """Return the hexadecimal SHA-256 digest of text encoded as UTF-8."""
    hasher = sha256()
    hasher.update(bytes(text, "utf-8"))
    return hasher.hexdigest()