Wonderplex commited on
Commit
e24fbee
1 Parent(s): fe95067

added env variable options for keys (#49)

Browse files
Files changed (2) hide show
  1. app.py +4 -2
  2. sotopia_pi_generate.py +6 -2
app.py CHANGED
@@ -8,8 +8,10 @@ from utils import Environment, Agent, get_context_prompt, dialogue_history_promp
8
  from functools import cache
9
  from sotopia_pi_generate import prepare_model, generate_action
10
 
11
- with open("openai_api.key", "r") as f:
12
- os.environ["OPENAI_API_KEY"] = f.read().strip()
 
 
13
 
14
  DEPLOYED = os.getenv("DEPLOYED", "true").lower() == "true"
15
  DEFAULT_MODEL_SELECTION = "gpt-3.5-turbo"
 
8
  from functools import cache
9
  from sotopia_pi_generate import prepare_model, generate_action
10
 
11
+ OPENAI_KEY_FILE="./openai_api.key"
12
+ if os.path.exists(OPENAI_KEY_FILE):
13
+ with open(OPENAI_KEY_FILE, "r") as f:
14
+ os.environ["OPENAI_API_KEY"] = f.read().strip()
15
 
16
  DEPLOYED = os.getenv("DEPLOYED", "true").lower() == "true"
17
  DEFAULT_MODEL_SELECTION = "gpt-3.5-turbo"
sotopia_pi_generate.py CHANGED
@@ -1,4 +1,5 @@
1
  import re
 
2
  from typing import TypeVar
3
  from functools import cache
4
  import logging
@@ -78,8 +79,11 @@ def generate_action(
78
  @cache
79
  def prepare_model(model_name, hf_token_key_file=HF_TOKEN_KEY_FILE):
80
  compute_type = torch.float16
81
- with open (hf_token_key_file, 'r') as f:
82
- hf_token = f.read().strip()
 
 
 
83
 
84
  if model_name == 'cmu-lti/sotopia-pi-mistral-7b-BC_SR':
85
  tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.1", token=hf_token)
 
1
  import re
2
+ import os
3
  from typing import TypeVar
4
  from functools import cache
5
  import logging
 
79
  @cache
80
  def prepare_model(model_name, hf_token_key_file=HF_TOKEN_KEY_FILE):
81
  compute_type = torch.float16
82
+ if os.path.exists(hf_token_key_file):
83
+ with open (hf_token_key_file, 'r') as f:
84
+ hf_token = f.read().strip()
85
+ else:
86
+ hf_token = os.environ["HF_TOKEN"]
87
 
88
  if model_name == 'cmu-lti/sotopia-pi-mistral-7b-BC_SR':
89
  tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.1", token=hf_token)