import streamlit as st
from transformers import DistilBertTokenizer, DistilBertForSequenceClassification
import torch

# Load the optimized model (use your pruned, quantized, or distilled checkpoint here)
model = DistilBertForSequenceClassification.from_pretrained('path_to_optimized_model')
tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
model.eval()  # inference mode: disables dropout

# Classify a query and return the predicted class index
def get_response(query):
    inputs = tokenizer(query, return_tensors="pt", padding=True, truncation=True)
    with torch.no_grad():  # no gradient tracking needed for inference
        outputs = model(**inputs)
    logits = outputs.logits
    prediction = torch.argmax(logits, dim=-1)
    return prediction.item()  # Return predicted class index

# Streamlit UI
st.title("Conversational Financial Assistant")
st.write("Ask me any financial question!")

# Input box
user_input = st.text_input("Enter your question:")

# Generate response when input is provided
if user_input:
    prediction = get_response(user_input)
    response = f"Predicted Class: {prediction}"  # Map this index to a human-readable answer for your label set
    st.write(response)
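
# Usage sketch (assumptions): this file is saved as app.py and 'path_to_optimized_model'
# points at a local directory containing your optimized DistilBERT checkpoint.
# Launch the app with:
#   streamlit run app.py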