File size: 1,717 Bytes
6f9ea8e
 
 
 
 
 
 
 
 
1b06ea8
 
6f9ea8e
 
9f71c9f
6f9ea8e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
# -*- coding: utf-8 -*-
"""Job-coach-document-testing-site.ipynb

Automatically generated by Colaboratory.

Original file is located at
    https://colab.research.google.com/drive/1ECyP45v5tfkn0I-0W24qWGQotWyqPIQS
"""

#! pip install gradio
#! pip install transformers
import functools

import gradio as gr

#import torch
from transformers import AutoTokenizer, AutoModelForQuestionAnswering
from transformers import pipeline

_QA_MODEL_NAME = "bert-large-uncased-whole-word-masking-finetuned-squad"


@functools.lru_cache(maxsize=1)
def _get_question_answerer(model_name=_QA_MODEL_NAME):
    """Build the question-answering pipeline once and reuse it.

    Loading a BERT-large checkpoint is slow, so the pipeline is cached
    instead of being rebuilt on every request.
    """
    return pipeline("question-answering", model=model_name)


def _split_questions(queries):
    """Split *queries* on "?" and return the stripped, non-empty questions."""
    return [part.strip() for part in (queries or "").split("?") if part.strip()]


def QA_function(context, queries, *, score_threshold=0.01, question_answerer=None):
    """Answer each "?"-separated question in *queries* against *context*.

    Args:
        context: Free text the answers are extracted from.
        queries: One or more questions separated by "?". A final question
            without a trailing "?" is still asked; empty segments are skipped.
        score_threshold: An answer is reported only when the pipeline's
            ``score`` is strictly greater than this value.
        question_answerer: Optional callable with the Hugging Face QA pipeline
            call signature ``(question=..., context=...) -> {"score", "answer"}``.
            Defaults to a cached ``pipeline("question-answering", ...)``.

    Returns:
        A text report listing answered questions (with their answers) and
        then questions that were not answered confidently enough.
    """
    questions = _split_questions(queries)
    answered = []
    not_answered = []

    if not (context or "").strip():
        # The QA pipeline rejects an empty context; nothing can be answered.
        not_answered = [f"{question}?" for question in questions]
    else:
        if question_answerer is None:
            question_answerer = _get_question_answerer()
        for question in questions:
            result = question_answerer(question=question, context=context)
            if result["score"] > score_threshold:
                answered.append(f"{question}?   Answer: {result['answer']}")
            else:
                not_answered.append(f"{question}?")

    # Build the report with explicit newlines so source indentation does not
    # leak into the text shown to the user.
    return (
        "Question Answered:\n"
        + ",\n".join(answered)
        + "\nQuestion Not Answered:\n"
        + ",\n".join(not_answered)
    )

title = "Testing demo"
description = "A testing site for Job Coach Documents"

# Two text inputs (context, "?"-separated questions) feed QA_function; its
# report is shown in a single output textbox. title/description were
# previously defined but never passed to the Interface.
gradio_ui = gr.Interface(
    QA_function,
    [
        gr.inputs.Textbox(lines=10, label="Context"),
        gr.inputs.Textbox(lines=10, label="Question"),
    ],
    gr.outputs.Textbox(label="Answer"),
    title=title,
    description=description,
)

if __name__ == "__main__":
    # Only start the server when run as a script, not when imported.
    gradio_ui.launch(debug=True)