|
import streamlit as st |
|
from transformers import DistilBertTokenizer, DistilBertForSequenceClassification |
|
import torch |
|
|
|
|
|
@st.cache_resource
def _load_model_and_tokenizer(model_path, tokenizer_name):
    """Load the classifier and its tokenizer once per Streamlit server process.

    Streamlit re-runs this whole script on every user interaction; caching the
    loader with ``st.cache_resource`` prevents the model weights and tokenizer
    vocabulary from being re-read from disk on each rerun.

    Args:
        model_path: Directory or hub id of the fine-tuned
            DistilBertForSequenceClassification checkpoint.
        tokenizer_name: Directory or hub id of the tokenizer to pair with it.

    Returns:
        tuple: ``(model, tokenizer)`` ready for inference.
    """
    loaded_model = DistilBertForSequenceClassification.from_pretrained(model_path)
    # Make inference mode explicit so dropout layers are disabled.
    loaded_model.eval()
    loaded_tokenizer = DistilBertTokenizer.from_pretrained(tokenizer_name)
    return loaded_model, loaded_tokenizer


# TODO: 'path_to_optimized_model' is a placeholder; point it at the real checkpoint.
# NOTE(review): the tokenizer comes from the base 'distilbert-base-uncased' hub id
# rather than the model directory — confirm the fine-tuned model used that vocab.
model, tokenizer = _load_model_and_tokenizer(
    'path_to_optimized_model', 'distilbert-base-uncased'
)
|
|
|
|
|
def get_response(query):
    """Classify *query* with the module-level DistilBERT model.

    Args:
        query: The user's question text. A single string is expected:
            ``.item()`` below requires exactly one prediction.

    Returns:
        int: Index of the highest-scoring class (argmax over the logits).
    """
    inputs = tokenizer(query, return_tensors="pt", padding=True, truncation=True)

    # Inference only: disable autograd so no gradient graph is built,
    # which saves memory and time on every request.
    with torch.no_grad():
        outputs = model(**inputs)

    prediction = torch.argmax(outputs.logits, dim=-1)
    return prediction.item()
|
|
|
|
|
# Static page copy, kept together so the header text is easy to find and edit.
_PAGE_TITLE = "Conversational Financial Assistant"
_PAGE_INTRO = "Ask me any financial question!"

st.title(_PAGE_TITLE)
st.write(_PAGE_INTRO)
|
|
|
|
|
user_input = st.text_input("Enter your question:")

# Ignore empty or whitespace-only input so the model is never run on a blank query.
question = user_input.strip()
if question:
    prediction = get_response(question)
    # NOTE(review): this shows only the raw class index; mapping it to a
    # human-readable label (e.g. via the model config) is not done here — confirm intent.
    response = f"Predicted Class: {prediction}"
    st.write(response)
|
|