File size: 1,486 Bytes
691730f
6d2322e
 
22dadd4
1070f39
22dadd4
 
 
 
6b5d959
6d2322e
691730f
6b5d959
0f9a053
6d2322e
 
 
 
691730f
6b5d959
1070f39
fd50112
 
e89aaa7
fd50112
 
 
 
 
 
 
 
 
 
 
 
 
 
e89aaa7
fd50112
6b5d959
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
import gradio as gr
from datasets import load_dataset
from transformers import AutoTokenizer, TFAutoModelForSeq2SeqLM

# Load model and tokenizer
# NOTE(review): runs at import time — downloads the fine-tuned T5 checkpoint
# from the Hugging Face Hub on first run (network I/O, can be slow).
model_name = "NinaMwangi/T5_finbot"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = TFAutoModelForSeq2SeqLM.from_pretrained(model_name)

# Load dataset
# Only the "train" split is used; it serves as the retrieval corpus for
# exact-match context lookup in get_context_for_question below.
dataset = load_dataset("virattt/financial-qa-10K")["train"]

# Function to retrieve context
def get_context_for_question(question, records=None):
    """Return the 10-K context passage whose question exactly matches *question*.

    Matching is case-insensitive after stripping surrounding whitespace.

    Args:
        question: The user's question text.
        records: Optional iterable of dicts with "question" and "context"
            keys. Defaults to the module-level `dataset` split, so existing
            callers are unaffected.

    Returns:
        The matching record's "context" string, or the sentinel
        "No relevant context found." when nothing matches.
    """
    if records is None:
        records = dataset
    # Hoist the loop-invariant normalization out of the scan — the original
    # re-normalized the query on every iteration.
    needle = question.strip().lower()
    for item in records:
        if item["question"].strip().lower() == needle:
            return item["context"]
    return "No relevant context found."

# Predict function
def generate_answer(question):
    """Answer a finance question using the fine-tuned T5 model.

    Retrieves the matching 10-K context passage, builds a QA prompt,
    generates with beam search, and returns the decoded answer text.
    """
    context = get_context_for_question(question)
    prompt = f"Q: {question} Context: {context} A:"

    # Tokenize to fixed-length TF tensors; long prompts are truncated.
    encoded = tokenizer(
        prompt,
        return_tensors="tf",
        padding="max_length",
        truncation=True,
        max_length=256,
    )

    # Beam-search decoding, capped at 64 new tokens.
    generated = model.generate(
        **encoded,
        max_new_tokens=64,
        num_beams=4,
        early_stopping=True,
    )

    return tokenizer.decode(generated[0], skip_special_tokens=True).strip()

# Interface
# Wires generate_answer to a simple text-in/text-out Gradio UI.
interface = gr.Interface(
    fn=generate_answer,
    inputs=gr.Textbox(lines=2, placeholder="Ask a finance question..."),
    outputs="text",
    title="Finance QA Chatbot",
    description="Built using a fine-tuned T5 Transformer. Ask a finance-related question and get an accurate, concise answer."
)

# Starts the local web server (blocking call) when the module is run.
interface.launch()