# FinanceModel / app.py
import streamlit as st
import pandas as pd
import torch
import requests
import os
from transformers import AutoTokenizer, AutoModelForCausalLM
from huggingface_hub import login
# Read the Hugging Face token from the "Allie" secret / environment variable.
HF_TOKEN = os.getenv("Allie")
if HF_TOKEN:
    login(HF_TOKEN)
# Define model map
model_map = {
    "InvestLM": {"id": "yixuantt/InvestLM-mistral-AWQ", "local": False},
    "FinLLaMA": {"id": "us4/fin-llama3.1-8b", "local": False},
    "FinanceConnect": {"id": "ceadar-ie/FinanceConnect-13B", "local": True},
    "Sujet-Finance": {"id": "sujet-ai/Sujet-Finance-8B-v0.1", "local": True},
    "FinGPT (LoRA)": {"id": "FinGPT/fingpt-mt_llama2-7b_lora", "local": True},  # Placeholder, special handling below
}
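# Entries with "local": True are loaded in-process with transformers;
# the others are sent to the hosted Hugging Face Inference API (see query_model below).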
# Load question list
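# questions.csv is expected to contain a "Question" column; empty rows are dropped.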
@st.cache_data
def load_questions():
    df = pd.read_csv("questions.csv")
    return df["Question"].dropna().tolist()
# Load local models
@st.cache_resource
def load_local_model(model_id):
    tokenizer = AutoTokenizer.from_pretrained(model_id, token=HF_TOKEN)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.float32,  # full precision for CPU; switch to float16 on GPU if memory allows
        device_map="auto",          # place weights on the available device(s) automatically
        token=HF_TOKEN,
    )
    return model, tokenizer
# Prompt template
PROMPT_TEMPLATE = (
    "You are FinGPT, a highly knowledgeable and reliable financial assistant.\n"
    "Explain the following finance/tax/controlling question clearly, including formulas, examples, and reasons why it matters.\n"
    "\n"
    "Question: {question}\n"
    "Answer:"
)
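# Note: this is a plain instruction-style prompt shared by all models. Chat-tuned
# models may respond better if the question is wrapped with the tokenizer's
# apply_chat_template(...) instead, but the plain template keeps the comparison uniform.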
# Local generation
def query_local_model(model_id, prompt):
    model, tokenizer = load_local_model(model_id)
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    outputs = model.generate(
        **inputs,
        max_new_tokens=400,
        temperature=0.7,
        top_p=0.9,
        top_k=40,
        repetition_penalty=1.2,
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id,
        eos_token_id=tokenizer.eos_token_id,
    )
    # Decode only the newly generated tokens so the prompt is not echoed back to the UI.
    new_tokens = outputs[0][inputs["input_ids"].shape[-1]:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True)
# Remote HF inference
def query_remote_model(model_id, prompt):
headers = {"Authorization": f"Bearer {HF_TOKEN}"}
payload = {"inputs": prompt, "parameters": {"max_new_tokens": 400}}
response = requests.post(
f"https://api-inference.huggingface.co/models/{model_id}",
headers=headers,
json=payload
)
result = response.json()
return result[0]["generated_text"] if isinstance(result, list) else result.get("generated_text", "ERROR")
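# Optional alternative (not wired into the UI): huggingface_hub's InferenceClient wraps
# the same Inference API and raises explicit exceptions instead of returning error dicts.
# Sketch only; assumes the deployed huggingface_hub version provides InferenceClient.
def query_remote_model_via_client(model_id, prompt):
    from huggingface_hub import InferenceClient
    client = InferenceClient(model=model_id, token=HF_TOKEN)
    return client.text_generation(prompt, max_new_tokens=400)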
# Route to appropriate model
def query_model(model_entry, question):
    prompt = PROMPT_TEMPLATE.format(question=question)
    if model_entry["id"] == "FinGPT/fingpt-mt_llama2-7b_lora":
        return "⚠️ FinGPT (LoRA) integration requires manual loading with PEFT and is not available via HF API."
    elif model_entry["local"]:
        return query_local_model(model_entry["id"], prompt)
    else:
        return query_remote_model(model_entry["id"], prompt)
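# Sketch (not wired into the UI): FinGPT publishes LoRA adapter weights only, so it
# would have to be loaded by attaching the adapter to its Llama-2 base model via PEFT.
# This assumes the `peft` package is installed and enough memory for the 7B base model;
# the base model id below is an assumption — check the adapter's model card.
def load_fingpt_lora(base_model_id="meta-llama/Llama-2-7b-hf",
                     adapter_id="FinGPT/fingpt-mt_llama2-7b_lora"):
    from peft import PeftModel  # local import so the app still runs without peft installed
    tokenizer = AutoTokenizer.from_pretrained(base_model_id, token=HF_TOKEN)
    base = AutoModelForCausalLM.from_pretrained(
        base_model_id,
        torch_dtype=torch.float32,
        device_map="auto",
        token=HF_TOKEN,
    )
    # Attach the LoRA adapter weights on top of the base model.
    model = PeftModel.from_pretrained(base, adapter_id)
    return model, tokenizer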
# === UI ===
st.set_page_config(page_title="Finance Model Tester", layout="centered")
st.title("📊 Finance Model Comparison Interface")
questions = load_questions()
question_choice = st.selectbox("Choose a question", questions)
model_choice = st.selectbox("Choose a model", list(model_map.keys()))
if st.button("Generate answer"):
    with st.spinner("Generating answer..."):
        model_entry = model_map[model_choice]
        try:
            answer = query_model(model_entry, question_choice)
        except Exception as e:
            answer = f"[Error: {str(e)}]"
        st.text_area("💬 Model answer:", value=answer, height=400, disabled=True)