File size: 1,915 Bytes
f0a38b0
 
 
 
 
 
 
 
 
 
 
02940d1
 
 
f0a38b0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
import gradio as gr
from task import tasks_config
from transformers import pipeline


def review_training_choices(choice):
    """Return a ``gr.Row`` whose visibility depends on the selected choice.

    Args:
        choice: The selected option label.

    Returns:
        ``gr.Row`` that is visible only when *choice* equals "Use Pipeline".
    """
    show_row = choice == "Use Pipeline"
    return gr.Row(visible=show_row)

def task_dropdown_choices():
    """Build ``(label, value)`` pairs for the task dropdown.

    Returns:
        A list with one ``(entry["name"], key)`` tuple per item of
        ``tasks_config``, in the dict's iteration order.
    """
    return [(entry["name"], key) for key, entry in tasks_config.items()]


def handle_task_change(task):
    """Produce widget updates for a newly selected task.

    Args:
        task: Key into ``tasks_config`` for the selected task.

    Returns:
        A 3-tuple:
        - a ``gr.update`` that is visible only for "question-answering"
          (presumably the context input; confirm against the event wiring);
        - a ``gr.Dropdown`` listing the task's configured models, with
          custom values allowed;
        - a ``gr.Dropdown`` carrying the task's ``info`` text.
    """
    task_entry = tasks_config[task]
    is_qa = task == "question-answering"
    model_names = task_entry["config"]["models"]

    context_update = gr.update(visible=is_qa)
    model_dropdown = gr.Dropdown(
        choices=[(name, name) for name in model_names],
        label="Model",
        allow_custom_value=True,
        interactive=True,
    )
    info_dropdown = gr.Dropdown(info=task_entry["info"])
    return context_update, model_dropdown, info_dropdown


def _format_key_values(items):
    """Render every dict in *items* as ``key=value`` lines joined by newlines.

    Position keys (``start``, ``end``, ``index``) are omitted from the output.
    """
    skipped = ("start", "end", "index")
    return "\n".join(
        f"{key}={value}"
        for item in items
        for key, value in item.items()
        if key not in skipped
    )


def test_pipeline(task, model=None, prompt=None, context=None):
    """Run a Transformers ``pipeline`` for *task* and format its output.

    NOTE(review): despite the ``test_`` prefix this is application code; pytest
    will try to collect it if a test module imports it by name.

    Args:
        task: Pipeline task id, e.g. ``"ner"`` or ``"question-answering"``.
        model: Optional model id. When falsy, the pipeline's default model
            for *task* is used.
        prompt: Input text; for question answering this is the question.
        context: Passage to answer from; required for "question-answering".

    Returns:
        The plain string "Context is required" when question answering is
        requested without a context; otherwise a ``gr.TextArea`` holding the
        formatted pipeline output.
    """
    # Validate before building the pipeline: constructing it may download and
    # load a model, which is wasted work if we are going to bail out anyway.
    if task == "question-answering" and not context:
        return "Context is required"

    # Extra per-task pipeline options. ``aggregation_strategy="simple"`` is
    # the non-deprecated equivalent of ``grouped_entities=True``.
    options = {"ner": {"aggregation_strategy": "simple"}}
    # ``model=None`` lets ``pipeline`` pick the task's default model, which
    # matches the original "no model selected" path.
    runner = pipeline(task, model=model or None, **options.get(task, {}))

    if task == "question-answering":
        result = runner(question=prompt, context=context)
    else:
        result = runner(prompt)

    # Convert raw pipeline output into display text; tasks without a
    # dedicated formatter fall back to ``str``.
    formatters = {
        "text-generation": lambda r: r[0]["generated_text"],
        "fill-mask": lambda r: r[0]["sequence"],
        "summarization": lambda r: r[0]["summary_text"],
        "ner": _format_key_values,
        # The QA pipeline returns a single dict for one answer, a list otherwise.
        "question-answering": lambda r: _format_key_values(
            r if isinstance(r, list) else [r]
        ),
    }
    return gr.TextArea(formatters.get(task, str)(result))