|
import os |
|
import huggingface_hub |
|
import streamlit as st |
|
from vllm import LLM, SamplingParams |
|
|
|
@st.cache_resource
def _load_llm():
    """Log in to the Hugging Face Hub (when a token is configured) and load the model.

    Streamlit re-executes this whole script on every widget interaction.
    ``st.cache_resource`` makes the slow, memory-hungry model load happen once
    per server process instead of on every rerun.
    """
    hf_token = os.getenv("HF_TOKEN")
    if hf_token:
        # With token=None, login() falls back to an interactive prompt, which
        # cannot be answered from inside a Streamlit server.
        huggingface_hub.login(token=hf_token)
    return LLM(model="InvestmentResearchAI/LLM-ADE-small-v0.1.0")


llm = _load_llm()

tok = llm.get_tokenizer()
# Chat turns in `template` end with <|eot_id|>; make it the tokenizer's EOS token.
tok.eos_token = '<|eot_id|>'
|
|
|
|
|
# Single-turn, Llama-3-style chat prompt; fill it with
# ``template.format(user_message=...)``.
# Each header is followed by exactly "\n\n", and the string ends right after
# the assistant header so the model continues with its reply. Explicit escapes
# keep this significant whitespace safe from editors and formatters.
template = (
    "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n"
    "You are a helpful financial assistant that answers the user as "
    "accurately, truthfully, and concisely as possible.<|eot_id|>"
    "<|start_header_id|>user<|end_header_id|>\n\n"
    "{user_message}<|eot_id|>"
    "<|start_header_id|>assistant<|end_header_id|>\n\n"
)
|
|
|
|
|
def get_response(prompt, temperature=0.3, top_p=0.95, max_tokens=1024):
    """Return the model's reply to a single user message.

    The message is wrapped in the module-level chat ``template`` and sent to
    the module-level ``llm`` as one prompt.

    Args:
        prompt: The user's message text.
        temperature: Sampling temperature passed to ``SamplingParams``.
        top_p: Nucleus-sampling probability mass passed to ``SamplingParams``.
        max_tokens: Upper bound on generated tokens. This is set explicitly
            because vLLM's ``SamplingParams`` default (16) truncates answers.

    Returns:
        The generated text, or a string starting with "An error occurred:"
        if generation fails or produces no output.
    """
    try:
        prompts = [template.format(user_message=prompt)]
        sampling_params = SamplingParams(
            temperature=temperature,
            top_p=top_p,
            max_tokens=max_tokens,
        )
        outputs = llm.generate(prompts, sampling_params)
    except Exception as e:
        # UI boundary: show the failure on the page instead of crashing the app.
        return f"An error occurred: {e}"

    # A single prompt was submitted; take the first completion of the first result.
    if not outputs or not outputs[0].outputs:
        return "An error occurred: the model returned no output."
    return outputs[0].outputs[0].text
|
|
|
def main():
    """Render the demo page: a text area, a Generate button, and the model's reply."""
    st.title("LLM-ADE 9B Demo")

    input_text = st.text_area("Enter your text here:", value="", height=200)

    if not st.button("Generate"):
        return

    # Whitespace-only input would send an empty user turn to the model.
    if not input_text.strip():
        st.warning("Please enter some text to generate a response.")
        return

    with st.spinner('Generating response...'):
        response_text = get_response(input_text)
    st.write(response_text)


if __name__ == "__main__":
    # Render the UI when this file is executed as the main script
    # (e.g. via `streamlit run`).
    main()
|
|
|
|