Spaces:
Sleeping
Sleeping
from argparse import ArgumentParser | |
from typing import * | |
from flask import Flask | |
from flask import request | |
from sftp import SpanPredictor, Span | |
parser = ArgumentParser() | |
parser.add_argument('model', metavar='MODEL_PATH', type=str) | |
parser.add_argument('-p', metavar='PORT', type=int, default=7749) | |
parser.add_argument('-d', metavar='DEVICE', type=int, default=-1) | |
args = parser.parse_args() | |
template = open('tools/demo/flask_template.html').read() | |
predictor = SpanPredictor.from_path(args.model, cuda_device=args.d) | |
app = Flask(__name__) | |
default_sentence = '因为 आरजू です vegan , هي купил soja .' | |
def visualized_prediction(inputs: List[str], prediction: Span, prefix=''): | |
spans = list() | |
span2event = [[] for _ in inputs] | |
for event_idx, event in enumerate(prediction): | |
for arg_idx, arg in enumerate(event): | |
for token_idx in range(arg.start_idx, arg.end_idx+1): | |
span2event[token_idx].append((event_idx, arg_idx)) | |
for token_idx, token in enumerate(inputs): | |
class_labels = ' '.join( | |
['token'] + [f'{prefix}-arg-{event_idx}-{arg_idx}' for event_idx, arg_idx in span2event[token_idx]] | |
) | |
spans.append(f'<span id="{prefix}-token-{token_idx}" class="{class_labels}" style="background-color">{token} </span>\n') | |
for event_idx, event in enumerate(prediction): | |
spans[event.start_idx] = ( | |
f'<span class="highlight bottom blue" ' | |
f' onmouseenter="highlight_args({event_idx}, \'{prefix}\')" onmouseleave="cancel_highlight(\'{prefix}\')">' | |
'<span class="highlight__content" align="center">' | |
f'<span class="event" id="{prefix}-event-{event_idx}">' | |
+ spans[event.start_idx] | |
) | |
spans[event.end_idx] += f'</span></span><span class="highlight__label"><center>{event.label}</center></span>' | |
arg_tips = [] | |
for arg_idx, arg in enumerate(event): | |
arg_tips.append(f'<span class="{prefix}-arg-{event_idx}-{arg_idx}">{arg.label}</span>') | |
if len(arg_tips) > 0: | |
arg_tips = '<br>'.join(arg_tips) | |
spans[event.end_idx] += f'<span class="highlight__tooltip">{arg_tips}</span>\n' | |
spans[event.end_idx] += '\n</span>' | |
return( | |
'<div class="passage model__content__summary highlight-container highlight-container--bottom-labels">\n' + | |
'\n'.join(spans) + '\n</div>' | |
) | |
def structured_prediction(inputs, prediction): | |
ret = list() | |
for event in prediction: | |
event_text, event_label = ' '.join(inputs[event.start_idx: event.end_idx+1]), event.label | |
ret.append(f'<li class="list-group-item list-group-item-info">' | |
f'<strong>{event_label}</strong>: {event_text}</li>') | |
for arg in event: | |
arg_text = ' '.join(inputs[arg.start_idx: arg.end_idx+1]) | |
ret.append( | |
f'<li class="list-group-item"> <strong>{arg.label}</strong>: {arg_text}</li>' | |
) | |
content = '\n'.join(ret) | |
return '\n<ul class="list-group">\n' + content + '\n</ul>' | |
def sftp(): | |
ret = template | |
tokens = request.args.get('sentence') | |
if tokens is not None: | |
ret = ret.replace('DEFAULT_SENTENCE', tokens) | |
sentences = tokens.split('\n') | |
model_outputs = predictor.predict_batch_sentences(sentences, max_tokens=512) | |
vis_pred, str_pred = list(), list() | |
for sent_idx, output in enumerate(model_outputs): | |
vis_pred.append(visualized_prediction(output.sentence, output.span, f'sent{sent_idx}')) | |
str_pred.append(structured_prediction(output.sentence, output.span)) | |
ret = ret.replace('VISUALIZED_PREDICTION', '<hr>'.join(vis_pred)) | |
ret = ret.replace('STRUCTURED_PREDICTION', '<hr>'.join(str_pred)) | |
else: | |
ret = ret.replace('DEFAULT_SENTENCE', default_sentence) | |
ret = ret.replace('VISUALIZED_PREDICTION', '') | |
ret = ret.replace('STRUCTURED_PREDICTION', '') | |
return ret | |
app.run(port=args.p) | |