Spaces:
Sleeping
Sleeping
File size: 1,486 Bytes
691730f 6d2322e 22dadd4 1070f39 22dadd4 6b5d959 6d2322e 691730f 6b5d959 0f9a053 6d2322e 691730f 6b5d959 1070f39 fd50112 e89aaa7 fd50112 e89aaa7 fd50112 6b5d959 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 |
import gradio as gr
from datasets import load_dataset
from transformers import AutoTokenizer, TFAutoModelForSeq2SeqLM
# Load model and tokenizer
# NOTE: these calls download weights from the Hugging Face Hub on first run
# (module import time), so startup requires network access.
model_name = "NinaMwangi/T5_finbot"  # fine-tuned T5 seq2seq checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_name)
# TF (TensorFlow) variant of the model; inputs must use return_tensors="tf".
model = TFAutoModelForSeq2SeqLM.from_pretrained(model_name)
# Load dataset
# Train split of a financial 10-K QA dataset; used below as the retrieval
# corpus for exact-match question -> context lookup.
dataset = load_dataset("virattt/financial-qa-10K")["train"]
# Function to retrieve context
def get_context_for_question(question):
    """Return the 10-K context passage whose question matches *question*.

    Matching is exact after whitespace-stripping and lower-casing. Returns
    the fallback string "No relevant context found." when nothing matches.

    The original implementation scanned the entire dataset on every call
    (O(n) per query). Since the dataset is fixed at runtime, we build a
    normalized question -> context index once, lazily, and reuse it.
    """
    index = getattr(get_context_for_question, "_index", None)
    if index is None:
        # Built on first call; later duplicate questions keep the first
        # occurrence's context, matching the original first-match behavior
        # only if built in reverse — so insert without overwriting.
        index = {}
        for item in dataset:
            index.setdefault(item["question"].strip().lower(), item["context"])
        get_context_for_question._index = index
    return index.get(question.strip().lower(), "No relevant context found.")
# Predict function
def generate_answer(question):
    """Answer a finance question with the fine-tuned T5 model.

    Retrieves the matching 10-K passage for *question*, builds a
    "Q: ... Context: ... A:" prompt, and decodes a beam-searched answer.
    """
    matched_context = get_context_for_question(question)
    prompt = f"Q: {question} Context: {matched_context} A:"

    # Tokenize to fixed-length TF tensors (pad/truncate to 256 tokens).
    encoded = tokenizer(
        prompt,
        return_tensors="tf",
        padding="max_length",
        truncation=True,
        max_length=256,
    )

    # Beam search (4 beams, early stop) keeps answers short and stable.
    generated = model.generate(
        **encoded,
        max_new_tokens=64,
        num_beams=4,
        early_stopping=True,
    )
    return tokenizer.decode(generated[0], skip_special_tokens=True).strip()
# Interface
# Single-input/single-output Gradio UI wrapping generate_answer.
interface = gr.Interface(
    fn=generate_answer,
    inputs=gr.Textbox(lines=2, placeholder="Ask a finance question..."),
    outputs="text",
    title="Finance QA Chatbot",
    description="Built using a fine-tuned T5 Transformer. Ask a finance-related question and get an accurate, concise answer."
)
# Module-level launch: starts the web server as a side effect of import/run.
interface.launch()
|