Spaces:
Runtime error
Runtime error
import gradio as gr | |
import os | |
import torch | |
import transformers | |
import huggingface_hub | |
import datetime | |
import json | |
import shutil | |
# Run on the GPU whenever torch can see one; otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Turn tokenizer parallelism off for this process to silence the warning
#   huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
os.environ.update({"TOKENIZERS_PARALLELISM": "false"})
# Access tokens and the run mode ('debug' or 'prod') are read from environment
# variables; a missing variable raises KeyError at import time.
HF_TOKEN_DOWNLOAD, HF_TOKEN_UPLOAD, MODE = (
    os.environ[key] for key in ('HF_TOKEN_DOWNLOAD', 'HF_TOKEN_UPLOAD', 'MODE')
)
MODEL_NAME = 'liujch1998/vera'
DATASET_REPO_URL = 'https://huggingface.co/datasets/liujch1998/cd-pi-dataset'
# Local checkout of the dataset repo, and the JSON-lines log file inside it.
DATA_DIR = 'data'
DATA_PATH = os.path.join(DATA_DIR, 'data.jsonl')
# Start from a fresh clone of the dataset repo used for logging.
# Any leftover local copy is removed first. ignore_errors=True tolerates a
# missing directory (or any other rmtree failure) just like the previous bare
# `except: pass`, but without also swallowing KeyboardInterrupt/SystemExit.
shutil.rmtree(DATA_DIR, ignore_errors=True)
repo = huggingface_hub.Repository(
    local_dir=DATA_DIR,
    clone_from=DATASET_REPO_URL,
    token=HF_TOKEN_UPLOAD,
    repo_type='dataset',
)
repo.git_pull()
class Interactive:
    """Vera statement verifier: a T5 encoder followed by a one-unit linear head.

    When ``MODE == 'debug'`` only the tokenizer is loaded, and ``run``/``runs``
    return fixed placeholder results (logits 0.0, scores 0.5) without
    touching the model.
    """

    def __init__(self):
        self.tokenizer = transformers.AutoTokenizer.from_pretrained(MODEL_NAME, use_auth_token=HF_TOKEN_DOWNLOAD)
        if MODE == 'debug':
            return
        self.model = transformers.T5EncoderModel.from_pretrained(
            MODEL_NAME,
            use_auth_token=HF_TOKEN_DOWNLOAD,
            low_cpu_mem_usage=True,
            device_map='auto',
            torch_dtype='auto',
        )
        self.model.D = self.model.shared.embedding_dim
        # The scoring head's weight (row 32099), bias (entry [32098, 0]) and
        # the calibration temperature (entry [32097, 0]) are all taken from
        # the model's shared embedding matrix.
        self.linear = torch.nn.Linear(self.model.D, 1, dtype=self.model.dtype).to(device)
        self.linear.weight = torch.nn.Parameter(self.model.shared.weight[32099, :].unsqueeze(0))  # (1, D)
        self.linear.bias = torch.nn.Parameter(self.model.shared.weight[32098, 0].unsqueeze(0))  # (1)
        self.model.eval()
        self.t = self.model.shared.weight[32097, 0].item()

    @staticmethod
    def _timestamp():
        """Return the current local time formatted as ``YYYYmmdd-HHMMSS``."""
        return datetime.datetime.now().strftime('%Y%m%d-%H%M%S')

    @staticmethod
    def _record(statement, logit, logit_calibrated, score, score_calibrated):
        """Build the result dict for one statement.

        The numeric arguments may be Python numbers or single-element tensors;
        they are stored as Python floats.
        """
        return {
            'timestamp': Interactive._timestamp(),
            'statement': statement,
            'logit': float(logit),
            'logit_calibrated': float(logit_calibrated),
            'score': float(score),
            'score_calibrated': float(score_calibrated),
        }

    def run(self, statement):
        """Score a single statement; returns one result dict (see ``runs``)."""
        return self.runs([statement])[0]

    def runs(self, statements):
        """Score a batch of statements.

        Args:
            statements: list of statement strings.

        Returns:
            One dict per statement, in input order, with keys 'timestamp',
            'statement', 'logit', 'logit_calibrated' (logit divided by the
            temperature ``self.t``), 'score' and 'score_calibrated' (the
            sigmoids of the two logits).
        """
        if not statements:
            return []
        if MODE == 'debug':
            # Bind each statement by name; the placeholder must echo it back.
            return [self._record(statement, 0.0, 0.0, 0.5, 0.5) for statement in statements]
        tok = self.tokenizer.batch_encode_plus(statements, return_tensors='pt', padding='longest')
        input_ids = tok.input_ids.to(device)
        attention_mask = tok.attention_mask.to(device)
        with torch.no_grad():
            output = self.model(input_ids=input_ids, attention_mask=attention_mask)
            # Position of the last non-padding token in each row.
            last_indices = attention_mask.sum(dim=1, keepdim=True) - 1  # (B, 1)
            last_indices = last_indices.unsqueeze(-1).expand(-1, -1, self.model.D)  # (B, 1, D)
            last_hidden_state = output.last_hidden_state.to(device)  # (B, L, D)
            hidden = last_hidden_state.gather(dim=1, index=last_indices).squeeze(1)  # (B, D)
            logits = self.linear(hidden).squeeze(-1)  # (B)
            logits_calibrated = logits / self.t
            scores = logits.sigmoid()
            scores_calibrated = logits_calibrated.sigmoid()
        return [
            self._record(*fields)
            for fields in zip(statements, logits, logits_calibrated, scores, scores_calibrated)
        ]
# Build the verifier once at import time; `predict` reuses this instance for
# every request. Outside debug mode this loads the model weights.
interactive = Interactive()
# def predict(statement, do_save=True): | |
# output_raw = interactive.run(statement) | |
# output = { | |
# 'True': output_raw['score_calibrated'], | |
# 'False': 1 - output_raw['score_calibrated'], | |
# } | |
# if do_save: | |
# with open(DATA_PATH, 'a') as f: | |
# json.dump(output_raw, f, ensure_ascii=False) | |
# f.write('\n') | |
# commit_url = repo.push_to_hub() | |
# print('Logged statement to dataset:') | |
# print('Commit URL:', commit_url) | |
# print(output_raw) | |
# print() | |
# return output, output_raw, gr.update(visible=False), gr.update(visible=True), gr.update(visible=True), gr.update(value='Please provide your feedback before trying out another statement.') | |
# def record_feedback(output_raw, feedback, do_save=True): | |
# if do_save: | |
# output_raw.update({ 'feedback': feedback }) | |
# with open(DATA_PATH, 'a') as f: | |
# json.dump(output_raw, f, ensure_ascii=False) | |
# f.write('\n') | |
# commit_url = repo.push_to_hub() | |
# print('Logged feedback to dataset:') | |
# print('Commit URL:', commit_url) | |
# print(output_raw) | |
# print() | |
# return gr.update(visible=True), gr.update(visible=False), gr.update(visible=False), gr.update(value='Thanks for your feedback! Now you can enter another statement.') | |
# def record_feedback_agree(output_raw, do_save=True): | |
# return record_feedback(output_raw, 'agree', do_save) | |
# def record_feedback_disagree(output_raw, do_save=True): | |
# return record_feedback(output_raw, 'disagree', do_save) | |
def predict(statements, do_saves):
    """Score a batch of statements and optionally log them to the dataset repo.

    Registered with ``batch=True``, so Gradio passes one list per input
    component, aligned by position.

    Args:
        statements: statement strings from the textbox (may arrive as a
            tuple; converted to a list because the tokenizer takes a list).
        do_saves: "Store data" checkbox values; only items whose value is
            truthy are appended to ``DATA_PATH`` and pushed.

    Returns:
        Six lists, one per output component of ``submit.click``: label
        confidences ({'True': p, 'False': 1 - p} with p the calibrated score),
        raw result dicts, and ``gr.update`` objects for the submit button,
        the two feedback buttons and the acknowledgement text.
    """
    output_raws = interactive.runs(list(statements))
    outputs = [
        {'True': raw['score_calibrated'], 'False': 1 - raw['score_calibrated']}
        for raw in output_raws
    ]
    to_save = [raw for raw, do_save in zip(output_raws, do_saves) if do_save]
    if to_save:
        print('Logging statements to dataset:')
        # ensure_ascii=False writes non-ASCII characters verbatim, so open the
        # file as UTF-8 rather than with the platform's default encoding.
        with open(DATA_PATH, 'a', encoding='utf-8') as f:
            for raw in to_save:
                print(raw)
                json.dump(raw, f, ensure_ascii=False)
                f.write('\n')
        repo.push_to_hub()
    return (
        outputs,
        output_raws,
        [gr.update(visible=False) for _ in statements],
        [gr.update(visible=True) for _ in statements],
        [gr.update(visible=True) for _ in statements],
        [gr.update(value='Please provide your feedback before trying out another statement.') for _ in statements],
    )
def record_feedback(output_raws, feedback, do_saves):
    """Append user feedback on a batch of predictions to the dataset repo.

    Args:
        output_raws: raw result dicts produced by ``predict`` (read back from
            the hidden ``gr.JSON`` component).
        feedback: value stored under the ``'feedback'`` key, e.g. 'agree'.
        do_saves: "Store data" checkbox values; only items whose value is
            truthy are written and pushed.

    Returns:
        Four lists of ``gr.update`` objects (submit button, the two feedback
        buttons, acknowledgement text), one entry per item.
    """
    # Copy instead of mutating the dicts Gradio handed us.
    records = [
        {**output_raw, 'feedback': feedback}
        for output_raw, do_save in zip(output_raws, do_saves)
        if do_save
    ]
    if records:
        print('Logged feedbacks to dataset:')
        # ensure_ascii=False writes non-ASCII characters verbatim, so open the
        # file as UTF-8 rather than with the platform's default encoding.
        with open(DATA_PATH, 'a', encoding='utf-8') as f:
            for record in records:
                print(record)
                json.dump(record, f, ensure_ascii=False)
                f.write('\n')
        repo.push_to_hub()
    return (
        [gr.update(visible=True) for _ in output_raws],
        [gr.update(visible=False) for _ in output_raws],
        [gr.update(visible=False) for _ in output_raws],
        [gr.update(value='Thanks for your feedback! Now you can enter another statement.') for _ in output_raws],
    )
def record_feedback_agree(output_raws, do_saves):
    """Handler for the 'Agree' button: record the batch with feedback 'agree'."""
    return record_feedback(output_raws, feedback='agree', do_saves=do_saves)
def record_feedback_disagree(output_raws, do_saves):
    """Handler for the 'Disagree' button: record the batch with feedback 'disagree'."""
    return record_feedback(output_raws, feedback='disagree', do_saves=do_saves)
# Example statements offered below the input box. They are listed in pairs,
# tagged with the dataset each pair was taken from, and flattened in order
# into a plain list of strings.
examples = [
    statement
    for _source, pair in (
        ('openbookqa', (
            'If a person walks in the opposite direction of a compass arrow they are walking south.',
            'If a person walks in the opposite direction of a compass arrow they are walking north.',
        )),
        ('arc_easy', (
            'A pond is different from a lake because ponds are smaller and shallower.',
            'A pond is different from a lake because ponds have moving water.',
        )),
        ('arc_hard', (
            'Hunting strategies are more likely to be learned rather than inherited.',
            'A spotted coat is more likely to be learned rather than inherited.',
        )),
        ('ai2_science_elementary', (
            'Photosynthesis uses carbon from the air to make food for plants.',
            'Respiration uses carbon from the air to make food for plants.',
        )),
        ('ai2_science_middle', (
            'The barometer measures atmospheric pressure.',
            'The thermometer measures atmospheric pressure.',
        )),
        ('commonsenseqa', (
            'People aim to complete a job at work.',
            'People aim to kill animals at work.',
        )),
        ('qasc', (
            'Climate is generally described in terms of local weather conditions.',
            'Climate is generally described in terms of forests.',
        )),
        ('physical_iqa', (
            'ice box will turn into a cooler if you add water to it.',
            'ice box will turn into a cooler if you add soda to it.',
        )),
        ('social_iqa', (
            'Kendall opened their mouth to speak and what came out shocked everyone. Kendall is a very aggressive and talkative person.',
            'Kendall opened their mouth to speak and what came out shocked everyone. Kendall is a very quiet person.',
        )),
        ('winogrande_xl', (
            'Sarah was a much better surgeon than Maria so Maria always got the easier cases.',
            'Sarah was a much better surgeon than Maria so Sarah always got the easier cases.',
        )),
        ('com2sense_paired', (
            'If you want a quick snack, getting one banana would be a good choice generally.',
            'If you want a snack, getting twenty bananas would be a good choice generally.',
        )),
        ('sciq', (
            'Each specific polypeptide has a unique linear sequence of amino acids.',
            'Each specific polypeptide has a unique linear sequence of fatty acids.',
        )),
        ('quarel', (
            'Tommy glided across the marble floor with ease, but slipped and fell on the wet floor because wet floor has more resistance.',
            'Tommy glided across the marble floor with ease, but slipped and fell on the wet floor because marble floor has more resistance.',
        )),
        ('quartz', (
            'If less waters falls on an area of land it will cause less plants to grow in that area.',
            'If less waters falls on an area of land it will cause more plants to grow in that area.',
        )),
        ('cycic_mc', (
            'In U.S. spring, Rob visits the financial district every day. In U.S. winter, Rob visits the park every day. Rob will go to the park on January 20.',
            'In U.S. spring, Rob visits the financial district every day. In U.S. winter, Rob visits the park every day. Rob will go to the financial district on January 20.',
        )),
        ('comve_a', (
            'Summer in North America is great for swimming, boating, and fishing.',
            'Summer in North America is great for skiing, snowshoeing, and making a snowman.',
        )),
        ('csqa2', (
            'Gas is always capable of turning into liquid under high pressure.',
            'Cotton candy is sometimes made out of cotton.',
        )),
        ('symkd_anno', (
            'James visits a famous landmark. As a result, James learns about the world.',
            'Cliff and Andrew enter the castle. But before, Cliff needed to have been a student at the school.',
        )),
        ('gengen_anno', (
            'Generally, bar patrons are capable of taking care of their own drinks.',
            'Generally, ocean currents have little influence over storm intensity.',
        )),
    )
    for statement in pair
]
# Disabled candidates:
# 'If A sits next to B and B sits next to C, then A must sit next to C.',
# 'If A sits next to B and B sits next to C, then A might not sit next to C.',
# input_statement = gr.Dropdown(choices=examples, label='Statement:') | |
# input_model = gr.Textbox(label='Commonsense statement verification model:', value=MODEL_NAME, interactive=False) | |
# output = gr.outputs.Label(num_top_classes=2) | |
# description = '''This is a demo for Vera, a commonsense statement verification model. Under development. | |
# β οΈ Data Collection: by default, we are collecting the inputs entered in this app to further improve and evaluate the model. Do not share any personal or sensitive information while using the app!''' | |
# gr.Interface( | |
# fn=predict, | |
# inputs=[input_statement, input_model], | |
# outputs=output, | |
# title="Vera", | |
# description=description, | |
# ).launch() | |
# Gradio UI: a statement box with a "Store data" opt-in, the model's
# True/False confidences, and agree/disagree feedback buttons that are shown
# only after a prediction.
with gr.Blocks() as demo:
    with gr.Column():
        gr.Markdown(
            '''# Vera

This is a demo for Vera, a commonsense statement verification model. Under development.

⚠️ Data Collection: by default, we are collecting the inputs entered in this app to further improve and evaluate the model. Do not share any personal or sensitive information while using the app!
'''
        )
        with gr.Row():
            with gr.Column(scale=2):
                do_save = gr.Checkbox(
                    value=True,
                    label="Store data",
                    info="You agree to the storage of your input data for research and development purposes:")
                statement = gr.Textbox(placeholder='Enter a commonsense statement here, or select an example from below', label='Statement', interactive=True)
                submit = gr.Button(value='Submit', variant='primary', visible=True)
            with gr.Column(scale=1):
                output = gr.Label(num_top_classes=2, interactive=False)
                # Hidden holder for the raw result dicts, read back by the
                # feedback handlers.
                output_raw = gr.JSON(visible=False)
                with gr.Row():
                    feedback_agree = gr.Button(value='👍 Agree', variant='secondary', visible=False)
                    feedback_disagree = gr.Button(value='👎 Disagree', variant='secondary', visible=False)
                feedback_ack = gr.Markdown(value='', visible=True, interactive=False)
        with gr.Row():
            # NOTE: clicking an example only fills the textbox. fn/outputs are
            # inert while run_on_click=False and cache_examples=False; they do
            # not match predict's signature (which also needs do_save and
            # returns six values), so fix them before enabling either flag.
            gr.Examples(
                examples=examples,
                fn=predict,
                inputs=[statement],
                outputs=[output, output_raw, statement, submit, feedback_agree, feedback_disagree, feedback_ack],
                examples_per_page=100,
                cache_examples=False,
                run_on_click=False,  # If we want this to be True, I suspect we need to enable the statement.submit()
            )
    # All handlers run batched (batch=True), so each receives and returns
    # lists with one entry per queued request.
    submit.click(predict, inputs=[statement, do_save], outputs=[output, output_raw, submit, feedback_agree, feedback_disagree, feedback_ack], batch=True, max_batch_size=16)
    # statement.submit(predict, inputs=[statement], outputs=[output, output_raw])
    feedback_agree.click(record_feedback_agree, inputs=[output_raw, do_save], outputs=[submit, feedback_agree, feedback_disagree, feedback_ack], batch=True, max_batch_size=16)
    feedback_disagree.click(record_feedback_disagree, inputs=[output_raw, do_save], outputs=[submit, feedback_agree, feedback_disagree, feedback_ack], batch=True, max_batch_size=16)
demo.queue(concurrency_count=1).launch(debug=True)
# Concurrency, Batching | |
# Theme, CSS | |