| import gradio as gr |
| import torch |
| from transformers import pipeline |
| import re |
|
|
| |
| pipeline = pipeline(model="jeevana/GenerativeQnASystem", max_new_tokens=60) |
|
|
def predict(input):
    """Generate an answer for the given question via the text-generation pipeline.

    Parameters
    ----------
    input : str
        The user's question. (The name shadows the ``input`` builtin; it is
        kept unchanged for backward compatibility with existing callers.)

    Returns
    -------
    str
        The model's continuation with the echoed prompt and any leading
        "Answer" / "Answer:" marker removed, stripped of surrounding
        whitespace. Returns "" if the pipeline yields no generated text.
    """
    outputs = pipeline(input)
    # HF text-generation pipelines return a list of dicts keyed by
    # "generated_text". Guard against a missing key: slicing None would
    # raise TypeError in the original code.
    generated = outputs[0].get("generated_text") or ""
    # Causal LMs echo the prompt; keep only the continuation after it.
    continuation = generated[len(input):]
    # Models often prefix answers with "Answer:"/"Answer" — strip that
    # marker (case-insensitively) wherever it appears.
    answer_marker = re.compile(r'\bAnswer:|\bAnswer\b', re.IGNORECASE)
    return answer_marker.sub('', continuation).strip()
|
|
|
|
# Build the single-question → single-answer web UI around predict().
question_box = gr.Textbox(label="Question", lines=3)
answer_box = gr.Textbox(label="Answer", lines=6)

app = gr.Interface(
    fn=predict,
    inputs=[question_box],
    outputs=[answer_box],
    title="Generative QnA System",
    description="Generative QnA with GPT2",
)
# share=True exposes a public Gradio link; debug=True surfaces errors inline.
app.launch(share=True, debug=True)
|
|
|
|
| |