import streamlit as st
import pandas as pd
import torch
import requests
import os
from transformers import AutoTokenizer, AutoModelForCausalLM
from huggingface_hub import login

# Hugging Face access token, read from the "Allie" secret/env var; only log in if it is set
HF_TOKEN = os.getenv("Allie")
if HF_TOKEN:
    login(HF_TOKEN)

# Define model map
model_map = {
    "InvestLM": {"id": "yixuantt/InvestLM-mistral-AWQ", "local": False},
    "FinLLaMA": {"id": "us4/fin-llama3.1-8b", "local": False},
    "FinanceConnect": {"id": "ceadar-ie/FinanceConnect-13B", "local": True},
    "Sujet-Finance": {"id": "sujet-ai/Sujet-Finance-8B-v0.1", "local": True},
    "FinGPT (LoRA)": {"id": "FinGPT/fingpt-mt_llama2-7b_lora", "local": True}  # Placeholder, special handling below
}

# Load question list
@st.cache_data
def load_questions():
    df = pd.read_csv("questions.csv")
    return df["Question"].dropna().tolist()

# Load local models
@st.cache_resource
def load_local_model(model_id):
    tokenizer = AutoTokenizer.from_pretrained(model_id, token=HF_TOKEN)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.float32,
        device_map="auto",
        token=HF_TOKEN
    )
    return model, tokenizer

# Prompt template
PROMPT_TEMPLATE = (
    "You are FinGPT, a highly knowledgeable and reliable financial assistant.\n"
    "Explain the following finance/tax/controlling question clearly, including formulas, examples, and reasons why it matters.\n"
    "\n"
    "Question: {question}\n"
    "Answer:"
)

# Local generation
def query_local_model(model_id, prompt):
    model, tokenizer = load_local_model(model_id)
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    outputs = model.generate(
        **inputs,
        max_new_tokens=400,
        temperature=0.7,
        top_p=0.9,
        top_k=40,
        repetition_penalty=1.2,
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id,
        eos_token_id=tokenizer.eos_token_id
    )
    return tokenizer.decode(outputs[0], skip_special_tokens=True)

# Remote HF inference
def query_remote_model(model_id, prompt):
    headers = {"Authorization": f"Bearer {HF_TOKEN}"}
    payload = {"inputs": prompt, "parameters": {"max_new_tokens": 400}}
    response = requests.post(
        f"https://api-inference.huggingface.co/models/{model_id}",
        headers=headers,
        json=payload,
        timeout=120
    )
    result = response.json()
    if isinstance(result, list) and result and "generated_text" in result[0]:
        return result[0]["generated_text"]
    # The Inference API returns a dict (e.g. {"error": "..."}) when something went wrong
    return result.get("generated_text", f"ERROR: {result.get('error', 'unexpected response')}")

# Route to appropriate model
def query_model(model_entry, question):
    prompt = PROMPT_TEMPLATE.format(question=question)
    if model_entry["id"] == "FinGPT/fingpt-mt_llama2-7b_lora":
        return "⚠️ FinGPT (LoRA) integration requires manual loading with PEFT and is not available via HF API."
    elif model_entry["local"]:
        return query_local_model(model_entry["id"], prompt)
    else:
        return query_remote_model(model_entry["id"], prompt)

# === UI ===
st.set_page_config(page_title="Finanzmodell Tester", layout="centered")
st.title("📊 Finanzmodell Vergleichs-Interface")

questions = load_questions()
question_choice = st.selectbox("Wähle eine Frage", questions)
model_choice = st.selectbox("Wähle ein Modell", list(model_map.keys()))

if st.button("Antwort generieren"):
    with st.spinner("Antwort wird generiert..."):
        model_entry = model_map[model_choice]
        try:
            answer = query_model(model_entry, question_choice)
        except Exception as e:
            answer = f"[Fehler: {str(e)}]"
        st.text_area("💬 Antwort des Modells:", value=answer, height=400, disabled=True)