|
import gradio as gr |
|
|
|
|
|
# Lookup tables populated by init_demo() and read by curation():
#   metadata    -- entries decoded from the metadata JSON file; curation()
#                  matches input text against them.
#   entry_count -- per-entry counts, aligned index-for-index with metadata.
# Both stay None until init_demo() has run.
metadata = entry_count = None
|
|
|
def init_demo(
    metadata_path="metadata.json",
    entry_counts_path="metaclip/entry_counts_400m.json",
):
    """Load the curation lookup tables into the module-level globals.

    Sets the global ``metadata`` to the JSON-decoded contents of
    *metadata_path*. Sets the global ``entry_count`` to a ``uint64`` NumPy
    array that holds one count per ``metadata`` entry, in ``metadata``
    order. Each count is looked up in the JSON object stored at
    *entry_counts_path*.

    Args:
        metadata_path: Path to the metadata JSON file. Defaults to the
            location the demo has always used.
        entry_counts_path: Path to a JSON object keyed by metadata entry
            whose values are the entry counts. Defaults to the location the
            demo has always used.

    Raises:
        OSError: If either file cannot be opened.
        ValueError: If either file is not valid UTF-8 JSON
            (``json.JSONDecodeError`` / ``UnicodeDecodeError``).
        KeyError: If some metadata entry has no count in *entry_counts_path*.
    """
    import json

    import numpy as np

    global metadata, entry_count

    # JSON text is UTF-8. Without an explicit encoding, open() falls back to
    # the locale's preferred encoding, and non-ASCII entries can then fail to
    # decode (e.g. cp1252 on Windows).
    with open(metadata_path, encoding="utf-8") as f:
        metadata = json.load(f)

    with open(entry_counts_path, encoding="utf-8") as f:
        entry_count_json = json.load(f)

    # Build one count per metadata entry, in metadata order, so that an index
    # into metadata is also a valid index into entry_count.
    entry_count = np.array(
        [entry_count_json[entry] for entry in metadata], dtype=np.uint64
    )
|
|
|
|
|
def curation(text):
    """Decide whether *text* would be kept by the balanced curation step.

    Matches *text* against the global ``metadata`` entries using
    ``substr_matching`` and then asks ``balance_sampling`` whether to keep
    it. ``init_demo()`` must already have populated the module-level
    ``metadata`` and ``entry_count`` globals. This function never modifies
    those globals.

    Args:
        text: The input string to curate.

    Returns:
        A string of the form ``"curation_prob=<p>, curated=<result>"``.
        ``<p>`` is the sum of the matched entries' sampling probabilities,
        capped at 1.0 and formatted to three decimals. ``<result>`` is
        whatever ``balance_sampling`` returned.
    """
    import sys

    # Make the local ``metaclip`` package importable. The guard keeps repeated
    # requests from appending the same entry to sys.path on every call.
    if "./" not in sys.path:
        sys.path.append("./")

    from metaclip.substr_matching import substr_matching

    from metaclip.balancing import balance_sampling

    t = 20000  # balancing threshold: entries with count <= t get probability 1

    # Floor every count at t and derive per-entry probabilities t / count,
    # which lie in (0, 1]. clip() returns a new array, so the shared global
    # entry_count is not mutated by this request handler.
    entry_prob = t / entry_count.clip(min=t)

    matched_entry_ids = substr_matching(text, metadata)

    # NOTE(review): summing matched probabilities and capping at 1.0 is an
    # upper bound on the keep probability. If balance_sampling draws
    # independently per matched entry, the exact value would be
    # 1 - prod(1 - p). Confirm against metaclip.balancing before relying on
    # the displayed number.
    curation_prob = min(entry_prob[matched_entry_ids].sum(), 1.0)
    curated = balance_sampling(matched_entry_ids, entry_prob)

    return f"curation_prob={curation_prob:.3f}, curated={curated}"
|
|
|
|
|
# Populate the lookup tables at import time. The app is then ready whether
# this file is run as a script or imported by another module.
init_demo()

# A single text input wired to curation(), with the result shown as text.
demo = gr.Interface(curation, inputs="text", outputs="text")

if __name__ == "__main__":
    # Start the server only when executed directly, not on import.
    demo.launch(show_api=False)
|
|