Spaces:
Runtime error
Runtime error
Refactored HF token ID access to a common function
Browse files- pages/005_LLM_Models.py +2 -3
- src/common.py +8 -0
- src/models.py +7 -5
pages/005_LLM_Models.py
CHANGED
@@ -11,7 +11,6 @@ if st_setup('LLM Models'):
|
|
11 |
HFLlamaChatModel.load_configs()
|
12 |
|
13 |
SESSION_KEY_CHAT_SERVER = 'chat_server'
|
14 |
-
HF_AUTH_KEY_SECRET = 'hf_token'
|
15 |
button_count = 0
|
16 |
|
17 |
|
@@ -56,7 +55,7 @@ if st_setup('LLM Models'):
|
|
56 |
with chat_container:
|
57 |
with st.chat_message("user"):
|
58 |
st.write(prompt)
|
59 |
-
chat_model = HFLlamaChatModel.
|
60 |
-
response = chat_model(prompt
|
61 |
with st.chat_message("assistant"):
|
62 |
st.write(response)
|
|
|
11 |
HFLlamaChatModel.load_configs()
|
12 |
|
13 |
SESSION_KEY_CHAT_SERVER = 'chat_server'
|
|
|
14 |
button_count = 0
|
15 |
|
16 |
|
|
|
55 |
with chat_container:
|
56 |
with st.chat_message("user"):
|
57 |
st.write(prompt)
|
58 |
+
chat_model = HFLlamaChatModel.for_name(st.session_state[SESSION_KEY_CHAT_SERVER])
|
59 |
+
response = chat_model(prompt)
|
60 |
with st.chat_message("assistant"):
|
61 |
st.write(response)
|
src/common.py
CHANGED
@@ -1,5 +1,13 @@
|
|
1 |
import os
|
|
|
2 |
|
3 |
|
4 |
data_dir = os.path.join(os.path.dirname(__file__), '..', 'data')
|
5 |
config_dir = os.path.join(os.path.dirname(__file__), '..', 'config')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import os
|
2 |
+
import streamlit as st
|
3 |
|
4 |
|
5 |
data_dir = os.path.join(os.path.dirname(__file__), '..', 'data')
|
6 |
config_dir = os.path.join(os.path.dirname(__file__), '..', 'config')
|
7 |
+
|
8 |
+
def hf_api_token() -> str:
|
9 |
+
#TODO: Need to consider how to make this more generic to look for a token elsewhere
|
10 |
+
token = st.secrets['hf_token']
|
11 |
+
if token is None:
|
12 |
+
raise ValueError('No HF access token found in streamlit secrets')
|
13 |
+
return token
|
src/models.py
CHANGED
@@ -3,7 +3,7 @@ import os
|
|
3 |
import requests
|
4 |
from typing import List
|
5 |
|
6 |
-
from src.common import config_dir
|
7 |
|
8 |
|
9 |
class HFLlamaChatModel:
|
@@ -16,13 +16,13 @@ class HFLlamaChatModel:
|
|
16 |
configs = json.load(f)['models']
|
17 |
cls.models = []
|
18 |
for cfg in configs:
|
19 |
-
if cls.
|
20 |
cls.models.append(HFLlamaChatModel(cfg['name'], cfg['id'], cfg['description']))
|
21 |
|
22 |
@classmethod
|
23 |
-
def
|
24 |
for m in cls.models:
|
25 |
-
if m.name ==
|
26 |
return m
|
27 |
|
28 |
@classmethod
|
@@ -38,10 +38,12 @@ class HFLlamaChatModel:
|
|
38 |
|
39 |
def __call__(self,
|
40 |
query: str,
|
41 |
-
auth_token: str,
|
42 |
system_prompt: str = None,
|
43 |
max_new_tokens: str = 256,
|
44 |
temperature: float = 1.0):
|
|
|
|
|
45 |
headers = {"Authorization": f"Bearer {auth_token}"}
|
46 |
api_url = f"https://api-inference.huggingface.co/models/{self.id}"
|
47 |
if system_prompt is None:
|
|
|
3 |
import requests
|
4 |
from typing import List
|
5 |
|
6 |
+
from src.common import config_dir, hf_api_token
|
7 |
|
8 |
|
9 |
class HFLlamaChatModel:
|
|
|
16 |
configs = json.load(f)['models']
|
17 |
cls.models = []
|
18 |
for cfg in configs:
|
19 |
+
if cls.for_name(cfg['name']) is None:
|
20 |
cls.models.append(HFLlamaChatModel(cfg['name'], cfg['id'], cfg['description']))
|
21 |
|
22 |
@classmethod
|
23 |
+
def for_name(cls, name: str):
|
24 |
for m in cls.models:
|
25 |
+
if m.name == name:
|
26 |
return m
|
27 |
|
28 |
@classmethod
|
|
|
38 |
|
39 |
def __call__(self,
|
40 |
query: str,
|
41 |
+
auth_token: str = None,
|
42 |
system_prompt: str = None,
|
43 |
max_new_tokens: str = 256,
|
44 |
temperature: float = 1.0):
|
45 |
+
if auth_token is None:
|
46 |
+
auth_token = hf_api_token() # Attempt look up if not provided
|
47 |
headers = {"Authorization": f"Bearer {auth_token}"}
|
48 |
api_url = f"https://api-inference.huggingface.co/models/{self.id}"
|
49 |
if system_prompt is None:
|