# NLP / app.py
# krs426's picture
# Update app.py
# b655ef1 verified
import streamlit as st
from transformers import DistilBertTokenizer, DistilBertForSequenceClassification
import torch
# Load optimized model (use pruned, quantized, or distilled model here).
@st.cache_resource
def _load_model_and_tokenizer():
    """Load the classifier and its tokenizer, cached across script reruns.

    Streamlit re-executes this script from top to bottom on every user
    interaction; caching the loaded objects avoids reloading the weights
    and vocabulary each time the user submits a question.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    loaded_model = DistilBertForSequenceClassification.from_pretrained('path_to_optimized_model')
    loaded_tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
    return loaded_model, loaded_tokenizer


# Module-level names are kept so the rest of the script can use them as before.
model, tokenizer = _load_model_and_tokenizer()
# Function to generate response
def get_response(query: str) -> int:
    """Classify a query and return the predicted class index.

    The query is tokenized with the module-level ``tokenizer`` and passed
    through the module-level ``model``; the index of the largest logit is
    returned.

    Args:
        query: A single query string. ``.item()`` below requires exactly
            one prediction, so a batch of queries is not supported.

    Returns:
        The predicted class index as a plain Python ``int``.
    """
    inputs = tokenizer(query, return_tensors="pt", padding=True, truncation=True)
    # Inference only: disable gradient tracking so no autograd graph is
    # built for the forward pass.
    with torch.no_grad():
        logits = model(**inputs).logits
    prediction = torch.argmax(logits, dim=-1)
    return prediction.item()  # Return predicted class index
# Streamlit UI
st.title("Conversational Financial Assistant")
st.write("Ask me any financial question!")

# Input box
user_input = st.text_input("Enter your question:")

# Generate a response only when the input has real content; a
# whitespace-only string is truthy but carries nothing to classify.
if user_input and user_input.strip():
    prediction = get_response(user_input)
    # You can map this to a response based on your classes.
    response = f"Predicted Class: {prediction}"
    st.write(response)