alfraser committed
Commit 5117e0a · 1 Parent(s): 33de646

Refactored HF token ID access to a common function

Files changed (3):
  1. pages/005_LLM_Models.py +2 -3
  2. src/common.py +8 -0
  3. src/models.py +7 -5
pages/005_LLM_Models.py CHANGED
@@ -11,7 +11,6 @@ if st_setup('LLM Models'):
     HFLlamaChatModel.load_configs()
 
     SESSION_KEY_CHAT_SERVER = 'chat_server'
-    HF_AUTH_KEY_SECRET = 'hf_token'
     button_count = 0
 
 
@@ -56,7 +55,7 @@ if st_setup('LLM Models'):
     with chat_container:
         with st.chat_message("user"):
             st.write(prompt)
-        chat_model = HFLlamaChatModel.get_model(st.session_state[SESSION_KEY_CHAT_SERVER])
-        response = chat_model(prompt, st.secrets[HF_AUTH_KEY_SECRET])
+        chat_model = HFLlamaChatModel.for_name(st.session_state[SESSION_KEY_CHAT_SERVER])
+        response = chat_model(prompt)
         with st.chat_message("assistant"):
             st.write(response)
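
With this change the page no longer names the hf_token secret directly, but the secret still has to be resolvable by the new helper for chat calls to succeed (locally that typically means an entry in .streamlit/secrets.toml, or a secret configured on the hosting platform). A purely illustrative smoke check, not part of this commit, could surface a missing token early; the st.error reporting is an assumed choice:

import streamlit as st
from src.common import hf_api_token

try:
    hf_api_token()            # fails fast if no token can be found
except Exception as err:      # broad catch: missing-key behaviour depends on Streamlit internals
    st.error(f'Hugging Face token is not configured: {err}')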
src/common.py CHANGED
@@ -1,5 +1,13 @@
 import os
+import streamlit as st
 
 
 data_dir = os.path.join(os.path.dirname(__file__), '..', 'data')
 config_dir = os.path.join(os.path.dirname(__file__), '..', 'config')
+
+def hf_api_token() -> str:
+    #TODO: Need to consider how to make this more generic to look for a token elsewhere
+    token = st.secrets['hf_token']
+    if token is None:
+        raise ValueError('No HF access token found in streamlit secrets')
+    return token
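
One possible direction for the TODO above, sketched here as an illustration rather than anything in this commit, is to fall back to an environment variable when the Streamlit secret is absent; the HF_TOKEN variable name and the Mapping-style st.secrets.get() lookup are assumptions:

import os
import streamlit as st


def hf_api_token() -> str:
    token = None
    try:
        token = st.secrets.get('hf_token')   # Mapping-style lookup; no KeyError on a missing key
    except Exception:
        pass                                 # e.g. no secrets file is present at all
    token = token or os.environ.get('HF_TOKEN')  # illustrative environment-variable fallback
    if token is None:
        raise ValueError('No HF access token found in streamlit secrets or the environment')
    return token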
src/models.py CHANGED
@@ -3,7 +3,7 @@ import os
 import requests
 from typing import List
 
-from src.common import config_dir
+from src.common import config_dir, hf_api_token
 
 
 class HFLlamaChatModel:
@@ -16,13 +16,13 @@ class HFLlamaChatModel:
         configs = json.load(f)['models']
         cls.models = []
         for cfg in configs:
-            if cls.get_model(cfg['name']) is None:
+            if cls.for_name(cfg['name']) is None:
                 cls.models.append(HFLlamaChatModel(cfg['name'], cfg['id'], cfg['description']))
 
     @classmethod
-    def get_model(cls, model: str):
+    def for_name(cls, name: str):
         for m in cls.models:
-            if m.name == model:
+            if m.name == name:
                 return m
 
     @classmethod
@@ -38,10 +38,12 @@
 
     def __call__(self,
                  query: str,
-                 auth_token: str,
+                 auth_token: str = None,
                  system_prompt: str = None,
                  max_new_tokens: str = 256,
                  temperature: float = 1.0):
+        if auth_token is None:
+            auth_token = hf_api_token()  # Attempt look up if not provided
         headers = {"Authorization": f"Bearer {auth_token}"}
         api_url = f"https://api-inference.huggingface.co/models/{self.id}"
         if system_prompt is None:
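
Taken together, callers can now rely on the common token lookup or still supply a token explicitly. A hedged usage sketch based on the signatures above; the model name below is a placeholder, since real names come from the JSON config that load_configs() reads:

from src.models import HFLlamaChatModel

HFLlamaChatModel.load_configs()
model = HFLlamaChatModel.for_name('some-configured-model')    # returns None for unknown names

reply = model('What is a llama?')                             # token resolved via hf_api_token()
reply = model('What is a llama?', auth_token='hf_...')        # or override the lookup explicitly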