|
import sqlite3 |
|
import logging |
|
import gradio as gr |
|
import time |
|
import contextlib |
|
import os |
|
import pandas as pd |
|
import click |
|
|
|
logging.basicConfig(level=logging.INFO) |
|
logger = logging.getLogger(__name__) |
|
|
|
DISABLE_TIMEER = os.environ.get("DISABLE_TIMER", False) |
|
DEBUG = os.environ.get("DEBUG", False) |
|
|
|
|
|
@contextlib.contextmanager |
|
def timer(timer_name="timer", pbar=None, pos=0): |
|
if DISABLE_TIMEER: |
|
return |
|
|
|
start = time.time() |
|
yield |
|
end = time.time() |
|
if pbar is not None: |
|
pbar.display(f"Time taken in [{timer_name}]: {end - start:.2f}", pos=pos) |
|
else: |
|
logger.info(f"Time taken in [{timer_name}]: {end - start:.2f}") |
|
|
|
|
|
def get_tables_with_name_and_schema(cursor): |
|
|
|
cursor.execute("SELECT name, sql FROM sqlite_master WHERE type='table'") |
|
|
|
|
|
table_names = [] |
|
table_schemas = [] |
|
for result in cursor.fetchall(): |
|
table_name, table_schema = result["name"], result["sql"] |
|
table_names.append(table_name) |
|
table_schemas.append(table_schema) |
|
|
|
return table_names, table_schemas |
|
|
|
|
|
def load_rows(cursor, table_name): |
|
print(f"Loading table: {table_name}") |
|
pos_table_name = table_name + "_pos_extension" |
|
cursor.execute( |
|
f""" |
|
SELECT {table_name}.region_id, {table_name}.phrases, {pos_table_name}.nouns, {pos_table_name}.noun_chunks |
|
FROM {table_name} |
|
JOIN {pos_table_name} ON {table_name}.region_id = {pos_table_name}.region_id |
|
""" |
|
+ ("LIMIT 10" if DEBUG else "") |
|
) |
|
|
|
rows = cursor.fetchall() |
|
logger.info(f"Finished loading table: {table_name} with {len(rows)} rows") |
|
return rows |
|
|
|
|
|
def dict_factory(cursor, row): |
|
|
|
d = {} |
|
for idx, col in enumerate(cursor.description): |
|
d[col[0]] = row[idx] |
|
return d |
|
|
|
|
|
DB = "tmp/annotation_db/objects365-local/annotations.db" |
|
|
|
|
|
@click.command() |
|
@click.option("--db", help="Path to the database file", default=DB) |
|
def main(db): |
|
def load_db(db): |
|
if not os.path.exists(db): |
|
raise ValueError(f"Database file {db} does not exist.") |
|
conn = sqlite3.connect(db) |
|
|
|
conn.row_factory = dict_factory |
|
|
|
cursor = conn.cursor() |
|
table_names, _ = get_tables_with_name_and_schema(cursor) |
|
table_names = list(filter(lambda x: not x.endswith("_extension"), table_names)) |
|
|
|
logger.info(f"Table Names: {table_names}") |
|
|
|
rows_ls = [] |
|
for table_name in table_names: |
|
with timer(): |
|
rows = load_rows(cursor, table_name) |
|
rows_ls.append(rows) |
|
|
|
conn.close() |
|
return table_names, rows_ls |
|
|
|
class DataFrameWithBatchSlider: |
|
def __init__(self, rows, num_samples, batch_size=10): |
|
self.rows = rows |
|
|
|
with gr.Row(): |
|
self.num_samples = gr.Textbox( |
|
lines=1, value=str(num_samples), label="Number of samples", interactive=False |
|
) |
|
self.batch_idx = gr.Slider( |
|
minimum=0, maximum=num_samples, step=batch_size, value=0, label="batch_idx", interactive=True |
|
) |
|
self.batch_size = gr.Slider( |
|
minimum=1, maximum=num_samples, step=1, value=batch_size, label="batch_size", interactive=True |
|
) |
|
|
|
self.data_frame = gr.DataFrame(pd.DataFrame(rows[0 : 0 + batch_size])) |
|
|
|
self.update_slider(self.batch_idx) |
|
self.update_slider(self.batch_size) |
|
|
|
def update_data_frame(self, batch_idx, batch_size): |
|
new_rows = self.rows[batch_idx : batch_idx + batch_size] |
|
return pd.DataFrame(new_rows) |
|
|
|
def update_slider(self, obj): |
|
|
|
handle = obj.change(lambda value: gr.update(value=value), inputs=[obj], outputs=[obj]) |
|
|
|
if obj is self.batch_size: |
|
handle.then(lambda step: gr.update(step=step), inputs=[obj], outputs=[self.batch_idx]) |
|
handle.then( |
|
self.update_data_frame, |
|
inputs=[self.batch_idx, self.batch_size], |
|
outputs=[self.data_frame], |
|
) |
|
|
|
with gr.Blocks() as app: |
|
db_tb = gr.Textbox(lines=1, value=db, label="Input database path") |
|
|
|
table_names, rows_ls = load_db(db) |
|
for table_name, rows in zip(table_names, rows_ls): |
|
with gr.Accordion(label=table_name): |
|
num_samples = len(rows) |
|
DataFrameWithBatchSlider(rows, num_samples) |
|
|
|
app.launch() |
|
|
|
|
|
if __name__ == "__main__": |
|
main() |
|
|