Geraldine commited on
Commit
5ed8ff3
·
verified ·
1 Parent(s): 971c673

Update pages/app_api_completion.py

Browse files
Files changed (1) hide show
  1. pages/app_api_completion.py +128 -128
pages/app_api_completion.py CHANGED
@@ -1,128 +1,128 @@
1
- import requests
2
- import json
3
- import os
4
- import streamlit as st
5
- from clients import OllamaClient, NvidiaClient, GroqClient
6
-
7
- st.set_page_config(
8
- page_title="QA Inference Streamlit App using Ollama, Nvidia and Groq APIs"
9
- )
10
-
11
-
12
- # Cache the header of the app to prevent re-rendering on each load
13
- @st.cache_resource
14
- def display_app_header():
15
- """Display the header of the Streamlit app."""
16
- st.title("QA Inference with Ollama & Nvidia & Groq as LLMs providers")
17
- st.subheader("ChatBot based on provider's OpenAI-like APIs and clients")
18
-
19
-
20
- # Display the header of the app
21
- display_app_header()
22
-
23
- # UI sidebar ##########################################
24
- st.sidebar.subheader("Models")
25
-
26
- # LLM
27
- llm_providers = {
28
- "Local Ollama": "ollama",
29
- "Cloud Nvidia": "nvidia",
30
- "Cloud Groq": "groq",
31
- }
32
- llm_provider = st.sidebar.radio(
33
- "Choose your LLM Provider", llm_providers.keys(), key="llm_provider"
34
- )
35
- if llm_provider == "Local Ollama":
36
- ollama_list_models = OllamaClient().list_models()
37
- if ollama_list_models:
38
- ollama_models = [x["name"] for x in ollama_list_models["models"]]
39
- ollama_llm = st.sidebar.radio(
40
- "Select your Ollama model", ollama_models, key="ollama_llm"
41
- ) # retrive with st.session_state["ollama_llm"]
42
- else:
43
- st.sidebar.error('Ollama is not running')
44
- elif llm_provider == "Cloud Nvidia":
45
- if nvidia_api_token := st.sidebar.text_input("Enter your Nvidia API Key"):
46
- st.sidebar.info("Nvidia authentification ok")
47
- nvidia_list_models = NvidiaClient().list_models() # api_key is not needed to list the available models
48
- nvidia_models = [x["id"] for x in nvidia_list_models["data"]]
49
- nvidia_llm = st.sidebar.radio(
50
- "Select your Nvidia LLM", nvidia_models, key="nvidia_llm"
51
- )
52
- else:
53
- st.sidebar.warning("You must enter your Nvidia API key")
54
- elif llm_provider == "Cloud Groq":
55
- if groq_api_token := st.sidebar.text_input("Enter your Groq API Key"):
56
- st.sidebar.info("Groq authentification ok")
57
- groq_list_models = GroqClient(api_key=groq_api_token).list_models()
58
- groq_models = [x["id"] for x in groq_list_models["data"]]
59
- groq_llm = st.sidebar.radio("Choose your Groq LLM", groq_models, key="groq_llm")
60
- else:
61
- st.sidebar.warning("You must enter your Groq API key")
62
-
63
- # LLM parameters
64
- st.sidebar.subheader("Parameters")
65
- max_tokens = st.sidebar.number_input("Token numbers", value=1024, key="max_tokens")
66
- temperature = st.sidebar.slider(
67
- "Temperature", min_value=0.0, max_value=1.0, value=0.5, step=0.1, key="temperature"
68
- )
69
- top_p = st.sidebar.slider(
70
- "Top P", min_value=0.0, max_value=1.0, value=0.7, step=0.1, key="top_p"
71
- )
72
-
73
-
74
- # LLM response function ########################################
75
- def get_llm_response(provider, prompt):
76
- options = dict(
77
- max_tokens=st.session_state["max_tokens"],
78
- top_p=st.session_state["top_p"],
79
- temperature=st.session_state["temperature"],
80
- )
81
- if provider == "ollama":
82
- return OllamaClient(
83
- api_key="ollama",
84
- model=st.session_state["ollama_llm"],
85
- ).api_chat_completion(
86
- prompt, **options
87
- ) # or .client_chat_completion(prompt,**options)
88
- elif provider == "nvidia":
89
- return NvidiaClient(
90
- api_key=nvidia_api_token,
91
- model=st.session_state["nvidia_llm"],
92
- ).api_chat_completion(
93
- prompt, **options
94
- ) # or .client_chat_completion(prompt,**options)
95
- elif provider == "groq":
96
- return GroqClient(
97
- api_key=groq_api_token,
98
- model=st.session_state["groq_llm"],
99
- ).api_chat_completion(
100
- prompt, **options
101
- ) # or .client_chat_completion(prompt,**options)
102
-
103
-
104
- # UI main #####################################################
105
- # Initialize chat history
106
- if "messages" not in st.session_state:
107
- st.session_state.messages = []
108
-
109
- # Display chat messages from history on app rerun
110
- for message in st.session_state.messages:
111
- with st.chat_message(message["role"]):
112
- st.markdown(message["content"])
113
-
114
- # React to user input
115
- if prompt := st.chat_input("What is up?"):
116
- # Display user message in chat message container
117
- with st.chat_message("user"):
118
- st.markdown(prompt)
119
- # Add user message to chat history
120
- st.session_state.messages.append({"role": "user", "content": prompt})
121
-
122
- response = f"Echo: {prompt}"
123
- # Display assistant response in chat message container
124
- with st.chat_message("assistant"):
125
- response = get_llm_response(llm_providers[st.session_state["llm_provider"]], prompt)
126
- st.markdown(response)
127
- # Add assistant response to chat history
128
- st.session_state.messages.append({"role": "assistant", "content": response})
 
1
+ import requests
2
+ import json
3
+ import os
4
+ import streamlit as st
5
+ from clients import OllamaClient, NvidiaClient, GroqClient
6
+
7
+ st.set_page_config(
8
+ page_title="QA Inference Streamlit App using Ollama, Nvidia and Groq APIs"
9
+ )
10
+
11
+
12
+ # Cache the header of the app to prevent re-rendering on each load
13
+ @st.cache_resource
14
+ def display_app_header():
15
+ """Display the header of the Streamlit app."""
16
+ st.title("QA Inference with Ollama & Nvidia & Groq as LLMs providers")
17
+ st.subheader("ChatBot based on provider's OpenAI-like APIs and clients")
18
+
19
+
20
+ # Display the header of the app
21
+ display_app_header()
22
+
23
+ # UI sidebar ##########################################
24
+ st.sidebar.subheader("Models")
25
+
26
+ # LLM
27
+ llm_providers = {
28
+ "Local Ollama": "ollama",
29
+ "Cloud Nvidia": "nvidia",
30
+ "Cloud Groq": "groq",
31
+ }
32
+ llm_provider = st.sidebar.radio(
33
+ "Choose your LLM Provider", llm_providers.keys(), key="llm_provider"
34
+ )
35
+ if llm_provider == "Local Ollama":
36
+ ollama_list_models = OllamaClient().list_models()
37
+ if ollama_list_models:
38
+ ollama_models = [x["name"] for x in ollama_list_models["models"]]
39
+ ollama_llm = st.sidebar.radio(
40
+ "Select your Ollama model", ollama_models, key="ollama_llm"
41
+ ) # retrive with st.session_state["ollama_llm"]
42
+ else:
43
+ st.sidebar.error('Ollama is not running')
44
+ elif llm_provider == "Cloud Nvidia":
45
+ if nvidia_api_token := st.sidebar.text_input("Enter your Nvidia API Key", type="password"):
46
+ st.sidebar.info("Nvidia authentification ok")
47
+ nvidia_list_models = NvidiaClient().list_models() # api_key is not needed to list the available models
48
+ nvidia_models = [x["id"] for x in nvidia_list_models["data"]]
49
+ nvidia_llm = st.sidebar.radio(
50
+ "Select your Nvidia LLM", nvidia_models, key="nvidia_llm"
51
+ )
52
+ else:
53
+ st.sidebar.warning("You must enter your Nvidia API key")
54
+ elif llm_provider == "Cloud Groq":
55
+ if groq_api_token := st.sidebar.text_input("Enter your Groq API Key", type="password"):
56
+ st.sidebar.info("Groq authentification ok")
57
+ groq_list_models = GroqClient(api_key=groq_api_token).list_models()
58
+ groq_models = [x["id"] for x in groq_list_models["data"]]
59
+ groq_llm = st.sidebar.radio("Choose your Groq LLM", groq_models, key="groq_llm")
60
+ else:
61
+ st.sidebar.warning("You must enter your Groq API key")
62
+
63
+ # LLM parameters
64
+ st.sidebar.subheader("Parameters")
65
+ max_tokens = st.sidebar.number_input("Token numbers", value=1024, key="max_tokens")
66
+ temperature = st.sidebar.slider(
67
+ "Temperature", min_value=0.0, max_value=1.0, value=0.5, step=0.1, key="temperature"
68
+ )
69
+ top_p = st.sidebar.slider(
70
+ "Top P", min_value=0.0, max_value=1.0, value=0.7, step=0.1, key="top_p"
71
+ )
72
+
73
+
74
+ # LLM response function ########################################
75
+ def get_llm_response(provider, prompt):
76
+ options = dict(
77
+ max_tokens=st.session_state["max_tokens"],
78
+ top_p=st.session_state["top_p"],
79
+ temperature=st.session_state["temperature"],
80
+ )
81
+ if provider == "ollama":
82
+ return OllamaClient(
83
+ api_key="ollama",
84
+ model=st.session_state["ollama_llm"],
85
+ ).api_chat_completion(
86
+ prompt, **options
87
+ ) # or .client_chat_completion(prompt,**options)
88
+ elif provider == "nvidia":
89
+ return NvidiaClient(
90
+ api_key=nvidia_api_token,
91
+ model=st.session_state["nvidia_llm"],
92
+ ).api_chat_completion(
93
+ prompt, **options
94
+ ) # or .client_chat_completion(prompt,**options)
95
+ elif provider == "groq":
96
+ return GroqClient(
97
+ api_key=groq_api_token,
98
+ model=st.session_state["groq_llm"],
99
+ ).api_chat_completion(
100
+ prompt, **options
101
+ ) # or .client_chat_completion(prompt,**options)
102
+
103
+
104
+ # UI main #####################################################
105
+ # Initialize chat history
106
+ if "messages" not in st.session_state:
107
+ st.session_state.messages = []
108
+
109
+ # Display chat messages from history on app rerun
110
+ for message in st.session_state.messages:
111
+ with st.chat_message(message["role"]):
112
+ st.markdown(message["content"])
113
+
114
+ # React to user input
115
+ if prompt := st.chat_input("What is up?"):
116
+ # Display user message in chat message container
117
+ with st.chat_message("user"):
118
+ st.markdown(prompt)
119
+ # Add user message to chat history
120
+ st.session_state.messages.append({"role": "user", "content": prompt})
121
+
122
+ response = f"Echo: {prompt}"
123
+ # Display assistant response in chat message container
124
+ with st.chat_message("assistant"):
125
+ response = get_llm_response(llm_providers[st.session_state["llm_provider"]], prompt)
126
+ st.markdown(response)
127
+ # Add assistant response to chat history
128
+ st.session_state.messages.append({"role": "assistant", "content": response})