creativity / app.py
liujch1998's picture
Initial commit
25f66ac
raw
history blame
4.99 kB
import gradio as gr
import datetime
import json
import requests
from constants import *
def process(query_type, index_desc, **kwargs):
    """POST a query to the backend API and return its decoded JSON response.

    The request payload contains a source tag ('hf', or 'hf-dev' when DEBUG),
    a local timestamp, the query type, and the index looked up from
    ``index_desc`` in ``INDEX_BY_DESC``. Any extra keyword arguments are
    merged in last, so they may override those base fields.

    Args:
        query_type: Name of the backend operation (e.g. ``'creativity'``).
        index_desc: Human-readable corpus description; must be a key of
            ``INDEX_BY_DESC``.
        **kwargs: Additional payload fields (e.g. ``query=...``).

    Returns:
        The parsed JSON body of an HTTP 200 response.

    Raises:
        KeyError: If ``index_desc`` is not a key of ``INDEX_BY_DESC``.
        ValueError: If ``API_URL`` is unset, the request times out or fails,
            the server answers with a non-200 status, or the 200 body is not
            valid JSON.
    """
    timestamp = datetime.datetime.now().strftime('%Y%m%d-%H%M%S')
    index = INDEX_BY_DESC[index_desc]
    data = {
        'source': 'hf' if not DEBUG else 'hf-dev',
        'timestamp': timestamp,
        'query_type': query_type,
        'index': index,
    }
    data.update(kwargs)
    print(json.dumps(data))
    if API_URL is None:
        raise ValueError('API_URL envvar is not set!')
    try:
        response = requests.post(API_URL, json=data, timeout=10)
    except requests.exceptions.Timeout as e:
        raise ValueError('Web request timed out. Please try again later.') from e
    except requests.exceptions.RequestException as e:
        raise ValueError(f'Web request error: {e}') from e
    if response.status_code != 200:
        # An error body is not guaranteed to be JSON (e.g. a gateway HTML
        # page); fall back to the raw text so a decode error never hides the
        # HTTP status. requests' JSON decode errors subclass ValueError.
        try:
            detail = response.json()
        except ValueError:
            detail = response.text
        raise ValueError(f'HTTP error {response.status_code}: {detail}')
    try:
        result = response.json()
    except ValueError as e:
        raise ValueError(f'Invalid JSON in API response: {e}') from e
    if DEBUG:
        print(result)
    return result
def _highlight_mask(rs, num_tokens, n):
    """Return one bool per token: True if the token lies in a match of >= n tokens.

    ``rs[l]`` is treated as the exclusive end of the match starting at token
    ``l`` (the span ``[l, rs[l])`` is flagged when its length is at least
    ``n``). Assumes every ``rs[l] <= num_tokens`` — TODO confirm with backend.
    """
    highlighteds = [False] * num_tokens
    covered_until = 0
    for l, r in enumerate(rs):
        if r - l < n:
            continue
        # Spans start at increasing l, so everything in [l, covered_until)
        # is already flagged; only mark the uncovered tail.
        for i in range(max(covered_until, l), r):
            highlighteds[i] = True
        covered_until = max(covered_until, r)
    return highlighteds


def _render_tokens_html(tokens, highlighteds):
    """Render tokens as an HTML paragraph, wrapping flagged tokens in a highlight span.

    'Ġ' in a token is shown as a non-breaking space and a bare 'Ċ' token as a
    literal '\\n' followed by a line break. Lines are soft-wrapped after
    roughly 100 characters, but only before a token that starts with 'Ġ'.
    """
    from html import escape  # local: the caller uses `html` as a variable name

    color = '0, 0, 255, 0.5'
    parts = []
    line_len = 0
    for i, (token, highlighted) in enumerate(zip(tokens, highlighteds)):
        if line_len >= 100 and token.startswith('Ġ'):
            parts.append('<br/>')
            line_len = 0
        is_linebreak = token == 'Ċ'
        if is_linebreak:
            disp_token = '\\n'
        else:
            # Escape first so user-supplied text cannot inject markup, then
            # turn the space marker into an entity (escape leaves 'Ġ' alone).
            disp_token = escape(token, quote=False).replace('Ġ', '&nbsp;')
        if highlighted:
            parts.append(
                f'<span id="hldoc-token-{i}" style="background-color: rgba({color});">'
                f'{disp_token}</span>'
            )
        else:
            parts.append(disp_token)
        if is_linebreak:
            parts.append('<br/>')
            line_len = 0
        else:
            line_len += len(token)
    body = ''.join(parts).strip(' ')
    return '<div><p id="hldoc" style="font-size: 16px;">' + body + '</p></div>'


def creativity(index_desc, query):
    """Query the backend for ``query`` and compute its Creativity Index.

    For each n-gram length n in [NGRAM_LEN_MIN, NGRAM_LEN_MAX], the
    "uniqueness" is the fraction of tokens not covered by any backend match
    of at least n tokens. The index is the mean uniqueness over all n.

    Args:
        index_desc: Corpus description forwarded to ``process``.
        query: Text to analyse.

    Returns:
        A 4-tuple ``(latency, ci, ngram_len, html)``:
        - latency: the backend's ``'latency'`` value formatted to 3 decimals,
          or '' if absent;
        - ci: the index as a percentage string, or the backend's error
          message when the response contains ``'error'``;
        - ngram_len: always NGRAM_LEN_DEFAULT;
        - html: the coverage visualisation for ``ngram_len`` ('' on error).
    """
    result = process('creativity', index_desc, query=query)
    latency = '' if 'latency' not in result else f'{result["latency"]:.3f}'
    ngram_len = NGRAM_LEN_DEFAULT
    if 'error' in result:
        return latency, result['error'], ngram_len, ''

    rs = result['rs']
    tokens = result['tokens']
    num_tokens = len(tokens)
    highlighteds_by_n = {}
    uniqueness_by_n = {}
    for n in range(NGRAM_LEN_MIN, NGRAM_LEN_MAX + 1):
        highlighteds = _highlight_mask(rs, num_tokens, n)
        highlighteds_by_n[n] = highlighteds
        # Guard the empty-token case instead of dividing by zero.
        uniqueness_by_n[n] = highlighteds.count(False) / num_tokens if num_tokens else 0.0
    ci = f'{sum(uniqueness_by_n.values()) / len(uniqueness_by_n):.2%}'

    # Render the mask for the displayed n-gram length (not whichever mask
    # the loop above happened to compute last).
    html = _render_tokens_html(tokens, highlighteds_by_n[ngram_len])
    return latency, ci, ngram_len, html
# Build the Gradio UI: corpus picker + query box on the left, Creativity
# Index, n-gram slider and coverage HTML on the right.
with gr.Blocks() as demo:
    with gr.Column():
        gr.HTML(
            # `text-align` is not an HTML attribute; centring needs inline CSS.
            # NOTE(review): the "Creativity Index" link has an empty href —
            # fill in the intended URL.
            '''<h1 style="text-align: center;">Creativity Index</h1>
            <p style='font-size: 16px;'>Compute the <a href="">Creativity Index</a> of a piece of text.</p>
            <p style='font-size: 16px;'>The computed Creativity Index is based on verbatim match and is supported by <a href="https://infini-gram.io">infini-gram</a>.</p>
            '''
        )
        with gr.Row():
            with gr.Column(scale=1, min_width=240):
                index_desc = gr.Radio(choices=INDEX_DESCS, label='Corpus', value=INDEX_DESCS[0])
            with gr.Column(scale=3):
                creativity_query = gr.Textbox(placeholder='Enter a piece of text here', label='Query', interactive=True, lines=10)
                with gr.Row():
                    creativity_clear = gr.ClearButton(value='Clear', variant='secondary', visible=True)
                    creativity_submit = gr.Button(value='Submit', variant='primary', visible=True)
                creativity_latency = gr.Textbox(label='Latency (milliseconds)', interactive=False, lines=1)
            with gr.Column(scale=4):
                creativity_ci = gr.Label(value='', label='Creativity Index')
                # NOTE(review): the slider is only written by `creativity`; no
                # change handler re-renders the coverage when the user moves it.
                creativity_ngram_len = gr.Slider(minimum=NGRAM_LEN_MIN, maximum=NGRAM_LEN_MAX, value=NGRAM_LEN_DEFAULT, step=1, label='Length of n-gram')
                creativity_html = gr.HTML(value='', label='Coverage')
        creativity_clear.add([creativity_query, creativity_latency, creativity_ci, creativity_html])
        creativity_submit.click(creativity, inputs=[index_desc, creativity_query], outputs=[creativity_latency, creativity_ci, creativity_ngram_len, creativity_html], api_name=False)

# Enable request queuing (limits read from constants) and start the server.
demo.queue(
    default_concurrency_limit=DEFAULT_CONCURRENCY_LIMIT,
    max_size=MAX_SIZE,
    api_open=False,
).launch(
    max_threads=MAX_THREADS,
    debug=DEBUG,
    show_api=False,
)