In [1]:
!pip install -q transformers datasets

In [2]:
import valohai

valohai.prepare(
    step='train-model',
    image='pytorch/pytorch:1.10.0-cuda11.3-cudnn8-runtime',    
    default_parameters={        
        'epochs': 10,
        'model': 'google/mt5-small',
    }
)
output_path = valohai.outputs().path('model')



In [3]:
import torch

torch_device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(torch_device)

  from .autonotebook import tqdm as notebook_tqdm


cuda


In [4]:
! nvidia-smi

Mon Mar 27 07:02:29 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 470.129.06   Driver Version: 470.129.06   CUDA Version: 11.4     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA RTX A6000    On   | 00000000:05:00.0 Off |                  Off |
| 30%   31C    P8    15W / 300W |      3MiB / 48685MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [5]:
from datasets import load_dataset

train_data = load_dataset('wikisql', split='train+validation')
test_data = load_dataset('wikisql', split='test')

Found cached dataset wikisql (/root/.cache/huggingface/datasets/wikisql/default/0.1.0/7037bfe6a42b1ca2b6ac3ccacba5253b1825d31379e9cc626fc79a620977252d)
Found cached dataset wikisql (/root/.cache/huggingface/datasets/wikisql/default/0.1.0/7037bfe6a42b1ca2b6ac3ccacba5253b1825d31379e9cc626fc79a620977252d)


In [6]:
def format_dataset(example):
 return {'input': 'translate to SQL: ' + example['question'] + ' table ID: ' + ', '.join(str(x) for x in example['table']['header']), 'target': example['sql']['human_readable']}

In [7]:
train_data = train_data.map(format_dataset, remove_columns=train_data.column_names)

Loading cached processed dataset at /root/.cache/huggingface/datasets/wikisql/default/0.1.0/7037bfe6a42b1ca2b6ac3ccacba5253b1825d31379e9cc626fc79a620977252d/cache-1ea43016a8276f85.arrow


In [8]:
test_data = test_data.map(format_dataset, remove_columns=test_data.column_names)

Loading cached processed dataset at /root/.cache/huggingface/datasets/wikisql/default/0.1.0/7037bfe6a42b1ca2b6ac3ccacba5253b1825d31379e9cc626fc79a620977252d/cache-b9e3da7e258b7aa5.arrow


In [9]:
!pip install sentencepiece
!pip install protobuf==3.20.*

Collecting protobuf==3.20.*
  Downloading protobuf-3.20.3-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.whl (1.0 MB)
[K     |â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 1.0 MB 4.4 MB/s eta 0:00:01
[?25hInstalling collected packages: protobuf
  Attempting uninstall: protobuf
    Found existing installation: protobuf 4.22.1
    Uninstalling protobuf-4.22.1:
      Successfully uninstalled protobuf-4.22.1
Successfully installed protobuf-3.20.3


In [10]:
CKPT = valohai.parameters("model").value
from transformers import AutoTokenizer, T5ForConditionalGeneration
tokenizer = AutoTokenizer.from_pretrained(CKPT)
model = T5ForConditionalGeneration.from_pretrained(CKPT).to(torch_device)

  "The sentencepiece tokenizer that you are converting to a fast tokenizer uses the byte fallback option"
You are using a model of type mt5 to instantiate a model of type t5. This is not supported for all configurations of models and can yield errors.
Downloading pytorch_model.bin: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 1.20G/1.20G [00:16<00:00, 72.6MB/s]
Downloading (â€¦)neration_config.json: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 147/147 [00:00<00:00, 31.2kB/s]


In [11]:
# map article and summary len to dict as well as if sample is longer than 512 tokens
def map_to_length(x):
  x["input_len"] = len(tokenizer(x["input"]).input_ids)
  x["input_longer_256"] = int(x["input_len"] > 256)
  x["input_longer_128"] = int(x["input_len"] > 128)
  x["input_longer_64"] = int(x["input_len"] > 64)
  x["out_len"] = len(tokenizer(x["target"]).input_ids)
  x["out_longer_256"] = int(x["out_len"] > 256)
  x["out_longer_128"] = int(x["out_len"] > 128)
  x["out_longer_64"] = int(x["out_len"] > 64)
  return x

sample_size = 10000
data_stats = train_data.select(range(sample_size)).map(map_to_length, num_proc=4)

def compute_and_print_stats(x):
  if len(x["input_len"]) == sample_size:
    print(
        "Input Mean: {}, %-Input > 256:{},  %-Input > 128:{}, %-Input > 64:{} Output Mean:{}, %-Output > 256:{}, %-Output > 128:{}, %-Output > 64:{}".format(
            sum(x["input_len"]) / sample_size,
            sum(x["input_longer_256"]) / sample_size,
            sum(x["input_longer_128"]) / sample_size,
            sum(x["input_longer_64"]) / sample_size,   
            sum(x["out_len"]) / sample_size,
            sum(x["out_longer_256"]) / sample_size,
            sum(x["out_longer_128"]) / sample_size,
            sum(x["out_longer_64"]) / sample_size,
        )
    )

output = data_stats.map(
  compute_and_print_stats, 
  batched=True,
  batch_size=-1,
)    

                                                                                

Input Mean: 47.4798, %-Input > 256:0.0,  %-Input > 128:0.001, %-Input > 64:0.0684 Output Mean:19.4288, %-Output > 256:0.0, %-Output > 128:0.0002, %-Output > 64:0.0004




In [12]:
# tokenize the examples
def convert_to_features(example_batch):
    input_encodings = tokenizer.batch_encode_plus(example_batch['input'], pad_to_max_length=True, max_length=100, truncation=True)
    target_encodings = tokenizer.batch_encode_plus(example_batch['target'], pad_to_max_length=True, max_length=100, truncation=True)

    encodings = {
        'input_ids': input_encodings['input_ids'], 
        'attention_mask': input_encodings['attention_mask'],
        'labels': target_encodings['input_ids'],
        'decoder_attention_mask': target_encodings['attention_mask']
    }

    return encodings 

In [13]:
train_data = train_data.map(convert_to_features, batched=True, remove_columns=train_data.column_names)
test_data = test_data.map(convert_to_features, batched=True, remove_columns=test_data.column_names)

columns = ['input_ids', 'attention_mask', 'labels', 'decoder_attention_mask']

train_data.set_format(type='torch', columns=columns)
test_data.set_format(type='torch', columns=columns)

                                                                   

In [14]:
from transformers import Seq2SeqTrainer
from transformers import Seq2SeqTrainingArguments

In [15]:
# set training arguments - Feel free to adapt it
training_args = Seq2SeqTrainingArguments(
    output_dir=output_path,
    per_device_train_batch_size=16,
    num_train_epochs=valohai.parameters("epochs").value,
    per_device_eval_batch_size=16,
    predict_with_generate=True,
    evaluation_strategy="epoch",
    do_train=True,
    do_eval=True,
    logging_steps=500,
    save_strategy="epoch",
    #save_steps=1000,
    #eval_steps=1000,
    overwrite_output_dir=True,
    save_total_limit=1,
    load_best_model_at_end=True,
    push_to_hub=False
    #fp16=True, 
)

In [16]:
! pip install -q rouge_score

In [17]:
from datasets import load_metric
rouge = load_metric("rouge")

def compute_metrics(pred):
    labels_ids = pred.label_ids
    pred_ids = pred.predictions

    # all unnecessary tokens are removed
    pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    labels_ids[labels_ids == -100] = tokenizer.pad_token_id
    label_str = tokenizer.batch_decode(labels_ids, skip_special_tokens=True)

    rouge_output = rouge.compute(predictions=pred_str, references=label_str, rouge_types=["rouge2"])["rouge2"].mid

    return {
        "rouge2_precision": round(rouge_output.precision, 4),
        "rouge2_recall": round(rouge_output.recall, 4),
        "rouge2_fmeasure": round(rouge_output.fmeasure, 4),
    }

  


In [18]:
# instantiate trainer
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    compute_metrics=compute_metrics,
    train_dataset=train_data,
    eval_dataset=test_data,
)

In [19]:
trainer.evaluate()

{'eval_loss': 42.09397506713867,
 'eval_rouge2_precision': 0.002,
 'eval_rouge2_recall': 0.0009,
 'eval_rouge2_fmeasure': 0.0012,
 'eval_runtime': 77.1,
 'eval_samples_per_second': 205.94,
 'eval_steps_per_second': 12.879}

In [20]:
trainer.train()



Epoch,Training Loss,Validation Loss,Rouge2 Precision,Rouge2 Recall,Rouge2 Fmeasure
1,0.1032,0.051379,0.901,0.8173,0.8497
2,0.0658,0.038024,0.9174,0.8382,0.8693
3,0.0547,0.033012,0.923,0.8441,0.875
4,0.0459,0.030169,0.9286,0.8473,0.88
5,0.0401,0.02873,0.9308,0.8498,0.8824
6,0.0393,0.027651,0.9318,0.8507,0.8833
7,0.036,0.027332,0.9329,0.852,0.8846
8,0.0335,0.026453,0.9331,0.8523,0.8849
9,0.0328,0.026168,0.9342,0.8531,0.8858
10,0.0323,0.026122,0.9343,0.8531,0.8859


TrainOutput(global_step=40490, training_loss=0.2895770631857524, metrics={'train_runtime': 7927.7437, 'train_samples_per_second': 81.708, 'train_steps_per_second': 5.107, 'total_flos': 6.689509761024e+16, 'train_loss': 0.2895770631857524, 'epoch': 10.0})

In [21]:
trainer.save_model(output_path)
tokenizer.save_pretrained(output_path)

('/valohai/outputs/model/tokenizer_config.json',
 '/valohai/outputs/model/special_tokens_map.json',
 '/valohai/outputs/model/spiece.model',
 '/valohai/outputs/model/added_tokens.json',
 '/valohai/outputs/model/tokenizer.json')

In [22]:
CKPT = output_path

tokenizer = AutoTokenizer.from_pretrained(CKPT, local_files_only=True)
model = T5ForConditionalGeneration.from_pretrained(CKPT, local_files_only=True).to(torch_device)

In [23]:
!pip install sentencepiece
!pip install pandasql
!pip install python-Levenshtein
!pip install sacremoses

Collecting pandasql
  Downloading pandasql-0.7.3.tar.gz (26 kB)
Collecting sqlalchemy
  Downloading SQLAlchemy-2.0.7-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (2.7 MB)
[K     |â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 2.7 MB 6.9 MB/s eta 0:00:01
Collecting greenlet!=0.4.17
  Downloading greenlet-2.0.2-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (566 kB)
[K     |â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 566 kB 111.8 MB/s eta 0:00:01
Building wheels for collected packages: pandasql
  Building wheel for pandasql (setup.py) ... [?25ldone
[?25h  Created wheel for pandasql: filename=pandasql-0.7.3-py3-none-any.whl size=26782 sha256=110b83989487b7b983fb80e3ede92a519027d1ddfd6988e2012175878ee93522
  Stored in directory: /root/.cache/pip/wheels/5c/4b/ec/41f4e116c8053c3654e2c2a47c62b4fca34cc67ef7b55deb7f
Successfully built pandasql
Installing collected

In [24]:
import Levenshtein
import re
from collections import Counter

#Get columns in query
def get_columns_name_in_query(query):
  cols_from_select = get_cols_name_for_select(query) 
  cols_from_where = get_cols_name_for_where(query)
  return list(set(cols_from_select + cols_from_where))

#Translate query in natural language from italian to english (input: string; output: string)
def translate2en(query):
  translated = model_t.generate(**tokenizer_t(query, return_tensors="pt", padding=True))
  query = [tokenizer_t.decode(t, skip_special_tokens=True) for t in translated]
  return query

# Sometime column name maybe ill-defined. This function replace weird chars with underscore (input:list; output:string)
def replace_nonalphanumeric_chars_with_us(l):
  well_defined = [re.sub('[^0-9a-zA-Z]+', '_', s) for s in l]
  return well_defined

# Adjust column name using columns name from original table (input: column name in SQL query (string), 
#list of columns names from table (string); output: corrected column name (if needed) (string))
def adjust_col_name(col_name, columns_available): 
  columns_available = [x.upper() for x in columns_available]
  if col_name.upper() in set(columns_available):
    return col_name
  else:
    max = -100
    most_similar_column = 'column123456789011'
    for col in columns_available:      
      score = -Levenshtein.distance(col_name, col)               
      if score > max:
        most_similar_column = col  
        max = score           
    return most_similar_column

def min_positive(a,b):
  if (b < a) and (b > 0): return b
  else: return a

#Return corrected syntax for aggregator operators (input: string; output: string)
#USE only for wikisql dataset
def aggregator_parser(query): 
  query = query.upper() 
  if query.find('SELECT MAX') > -1:
    end = min_positive(query.find('FROM'), query.find(','))    
    adjusted_query = query.replace(query[10:end],'(' + query[11:end-1] + ') ')
    return adjusted_query
  elif query.find('SELECT COUNT') > -1:
    end = min_positive(query.find('FROM'), query.find(','))
    adjusted_query = query.replace(query[12:end],'(' + query[13:end-1] + ') ')
    return adjusted_query
  elif query.find('SELECT MIN') > -1:
    end = min_positive(query.find('FROM'), query.find(','))
    adjusted_query = query.replace(query[10:end],'(' + query[11:end-1] + ') ')
    return adjusted_query
  #elif query.find('SELECT DISTINCT') > -1:
   #end = query.find('FROM')
    #adjusted_query = query.replace(query[15:end],'(' + query[16:end-1] + ') ')
    #return adjusted_query
  else: 
    return query

#Return columns name from SELECT operator (input: string; output: list)
def get_cols_name_for_select(query):
  query = query.upper()  
  if query.find('SELECT DISTINCT') > -1:
    end = query.find('FROM')
    cols = query[15:end-1].split(',')
  elif query.find('SELECT MAX') > -1:
    end = query.find('FROM')
    cols = query[10:end-1].split(',')  
  elif query.find('SELECT MIN') > -1:
    end = query.find('FROM')
    cols = query[10:end-1].split(',')     
  elif query.find('SELECT COUNT') > -1:
    end = query.find('FROM')
    cols = query[13:end-1].split(',')    
  elif query.find('SELECT') > -1:
    end = query.find('FROM')
    cols = query[7:end-1].split(',')    
  else:  
    cols = ['']    
  return [x.replace(' ','').replace(')','').replace('(','').upper() for x in cols]

def get_indexes(l):
  ops = []
  idx = []
  for i in range(len(l)):
    if l[i] in ['=', '>', '<', '>=', '<=', '<>', 'LIKE', 'AND', 'OR']:
      idx.append(i)
  return idx

def add_spaces_cmp_operators(string):
  ops = ['=', '>', '<', '>=', '<=', '<>']
  for op in ops:
    string = string.replace(op, ' ' + op + ' ') 
  return ' '.join(string.split())

#Check if string and add quotes (input: string; output: string)
#USE only for wikisql dataset
def add_quotes_to_string(query):
  query = query.upper()
  if query.find('WHERE') > 0:
    query_list = query.split(' ')
    query_list = [x.replace(' ','') for x in query_list]
    query_list[:] = filter(None, query_list)  
    idx_list = get_indexes(query_list)  
    idx_list.append(len(query_list))  
    subs = []
    for i in range(len(idx_list)):
      if i % 2 == 0:
        b = idx_list[i] + 1
        e = idx_list[i+1] - 1
        if b != e:
          s = ''
          for ix in range(b,e + 1):          
            s = s + query_list[ix] + ' ' 
          s = s[:-1]   
        else:
          s = query_list[b]     
        if not(s.isnumeric()):
          s = "'" + s + "'"
        subs.append((idx_list[i] + 1, idx_list[i+1] - 1, s))  
    subs = subs[::-1]       
    for i in range(len(subs)):
      e = subs[i]
      if e[0] == e[1]:
        query_list[e[0]] = e[2]
      else:
        query_list[e[0]] = e[2]
        idx = e[1]
        while idx > e[0]:
          query_list.pop(idx)
          idx = idx - 1
    final_query = ''
    for word in query_list:
      final_query = final_query + word + ' '     
    return final_query[:-1]
  else:
    return query

#Get values from where clause (input: string; output: list)
def get_values_for_query_filter(query):
  query = query.upper()
  if query.find('WHERE') > 0:
    query_list = query.split(' ')
    query_list = [x.replace(' ','') for x in query_list]
    query_list[:] = filter(None, query_list)  
    idx_list = get_indexes(query_list)  
    idx_list.append(len(query_list))  
    subs = []
    for i in range(len(idx_list)):
      if i % 2 == 0:
        b = idx_list[i] + 1
        e = idx_list[i+1] - 1
        if b != e:
          s = ''
          for ix in range(b,e + 1):          
            s = s + query_list[ix] + ' ' 
          s = s[:-1]   
        else:
          s = query_list[b]        
        subs.append(s.replace("'",""))
  return subs


# Get columns name after where (input: string, output: list)
def get_cols_name_for_where(query):
  query = query.upper()
  subs = []  
  if query.find('WHERE') > 0:
    query_list = query.split(' ')
    query_list = [x.replace(' ','') for x in query_list]
    query_list[:] = filter(None, query_list)  
    idx_list = get_indexes(query_list)  
    #idx_list.append(len(query_list))
    idx_list.insert(0, query_list.index('WHERE'))      
    for i in range(len(idx_list)-1):
      if i % 2 == 0:     
        b = idx_list[i] + 1
        e = idx_list[i+1] - 1
        if b != e:
          s = ''
          for ix in range(b,e + 1):          
            s = s + query_list[ix] + ' ' 
          s = s[:-1]   
        else:
          s = query_list[b]
        subs.append(s)    
  return subs   

def check_if_number(s):
  try:
    a = float(s)
    return True
  except:
    return False

#Correct missing compare operator (input: string; output: string)
#T5 seems to have problem with '<' operator so if there is none this is used.
def check_if_correct_cmp_operators(query):
  query = query.upper()
  if query.find('WHERE') > 0:
    query = add_spaces_cmp_operators(query)
    query_list = query.split(' ')
    w = query_list.index('WHERE')
    cmp_operators = ['=', '>', '<', '>=', '<=', '<>', 'LIKE']
    s = 0
    for op in cmp_operators:
      s = s + query_list.count(op)
    if s == 0:      
      if check_if_number(query_list[-1]):
        query_list.insert(len(query_list)-1,'<')
      else:
        query_list.insert(len(query_list)-1,'=')
      return ' '.join(query_list)
    else:
      return query
  else: return query
    


#Correct SQL syntax using info from table (input: string, list; ouput:string)
#Use only for wikisql dataset
def correct_query(query, table_columns):  
    query = check_if_correct_cmp_operators(query)
    query = add_spaces_cmp_operators(query)    
  #try: 
    query = aggregator_parser(query) 
  #except: pass 
  #try: 
    query = add_quotes_to_string(query) 
  #except: pass 
  #try:
    cols_name = get_columns_name_in_query(query)      
    for col in cols_name:    
      corrected_col = adjust_col_name(col, table_columns)      
      query = query.replace(col, corrected_col)
  #except: pass
    return query

def correct_mispelling(question, query):  
  query = query.upper()
  if query.find('WHERE') > 0:
    question = question.upper()
    corrections = []
    values = get_values_for_query_filter(query)
    for value in values:    
      l = len(value.split(' '))
      tokens = question.replace('  ', ' ').split(' ')
      l_gram = ''
      max = -100
      for i in range(0, len(tokens)-l+1, 1):
        filter = ' '.join(tokens[i:i+l]).strip('.,?')
        #filter = re.sub(r"[,.;@#?!&$]+\ *", " ", filter).strip()    
        score = -Levenshtein.distance(value, filter)        
        if score > max:
          max = score
          correct_filter = filter        
      corrections.append([value, correct_filter])    
    for corr in corrections:
      query = query.replace(corr[0], corr[1])
  return query

In [25]:
def translate_to_sql(text):
    inputs = tokenizer(text, padding='longest', max_length=64, return_tensors='pt').to(torch_device)
    input_ids = inputs.input_ids
    attention_mask = inputs.attention_mask
    output = model.generate(input_ids, attention_mask=attention_mask, max_length=64)

    return tokenizer.decode(output[0], skip_special_tokens=True)

In [None]:
test_data = load_dataset('wikisql', split='test')

print(len(test_data))
n =10000

count = 0
correct_samples = 0
for i in range(0,n,1):
  #print('processed', 100*(i+1)/n,'%')  
  question = 'translate to SQL: ' + test_data[i]['question'] + ' table ID: ' + ', '.join(str(x) for x in test_data[i]['table']['header'])   
  sql = translate_to_sql(question)
  #print(sql, test_data[i]['question'])
  #output = correct_query(sql, test_data[i]['table']['header'])  
  #output = correct_mispelling(test_data[i]['question'], output)
  #target = correct_query(test_data[i]['sql']['human_readable'], test_data[i]['table']['header'])
  try:     
    output = correct_query(sql, test_data[i]['table']['header'])
    output = correct_mispelling(test_data[i]['question'], output)
    target = correct_query(test_data[i]['sql']['human_readable'], test_data[i]['table']['header'])
    #output = sql
    #target = test_data[i]['sql']['human_readable']
    correct_samples = correct_samples + 1
    if output.lower() == target.lower():
      count = count + 1     
    else:
      #print(question)
      #print(output)   
      #print(target)      
      pass
    if i % 50 == 0:
        print(count/correct_samples, 100*(i+1)/n,'%')   
  except Exception as err:
    #print(f"Unexpected {err=}, {type(err)=}")
    #print('---Error-- ')  
    #print(sql) 
    #print(test_data[i]['sql']['human_readable'])
    #print(test_data[i]['table']['header'])
    pass
  #output = translate_to_sql(question)
  #target = test_data[i]['sql']['human_readable']
  #print(question)
  #print(output)  
  #print(target) 
print(count/n)
print(count/correct_samples)
print(correct_samples)



15878
0.0 0.01 %
0.6595744680851063 0.51 %
0.7263157894736842 1.01 %
0.6928571428571428 1.51 %
0.6878306878306878 2.01 %
0.7058823529411765 2.51 %
0.7142857142857143 3.01 %
0.6795252225519288 3.51 %
0.6909090909090909 4.01 %
0.7020785219399538 4.51 %
0.7 5.01 %
0.7 5.51 %
0.6972318339100346 6.01 %
0.6990445859872612 6.51 %
0.7071005917159763 7.01 %
0.7158620689655173 7.51 %
0.7174193548387097 8.01 %
0.7127272727272728 8.51 %
0.7177142857142857 9.01 %
0.7096424702058505 9.51 %
0.7057613168724279 10.01 %
0.7045009784735812 10.51 %
0.7052238805970149 11.01 %
0.6978609625668449 11.51 %
0.7028181041844578 12.01 %
0.7 12.51 %
0.7021276595744681 13.01 %
0.7025796661608498 13.51 %
0.706140350877193 14.01 %
0.7080394922425952 14.51 %
0.7100954979536153 15.01 %
0.7058047493403694 15.51 %
0.7049808429118773 16.01 %
0.7054455445544554 16.51 %
0.7063063063063063 17.01 %
0.7046117921774664 17.51 %
0.7060830017055145 18.01 %
0.7037037037037037 18.51 %
0.7002152852529602 19.01 %
0.7017819706498952 19.

In [None]:
with valohai.logger() as logger:
    logger.log('accuracy', count/correct_samples)
    

In [None]:
print(count)
print(correct_samples)