lfoppiano commited on
Commit
01b5fcd
1 Parent(s): 3865d62

use models

Browse files
Files changed (1) hide show
  1. streamlit_app.py +8 -4
streamlit_app.py CHANGED
@@ -19,6 +19,10 @@ from document_qa.document_qa_engine import DocumentQAEngine
19
  from document_qa.grobid_processors import GrobidAggregationProcessor, decorate_text_with_annotations
20
  from grobid_client_generic import GrobidClientGeneric
21
 
 
 
 
 
22
  if 'rqa' not in st.session_state:
23
  st.session_state['rqa'] = {}
24
 
@@ -117,17 +121,17 @@ def clear_memory():
117
  # @st.cache_resource
118
  def init_qa(model, api_key=None):
119
  ## For debug add: callbacks=[PromptLayerCallbackHandler(pl_tags=["langchain", "chatgpt", "document-qa"])])
120
- if model == 'chatgpt-3.5-turbo':
121
  st.session_state['memory'] = ConversationBufferWindowMemory(k=4)
122
  if api_key:
123
- chat = ChatOpenAI(model_name="gpt-3.5-turbo",
124
  temperature=0,
125
  openai_api_key=api_key,
126
  frequency_penalty=0.1)
127
  embeddings = OpenAIEmbeddings(openai_api_key=api_key)
128
 
129
  else:
130
- chat = ChatOpenAI(model_name="gpt-3.5-turbo",
131
  temperature=0,
132
  frequency_penalty=0.1)
133
  embeddings = OpenAIEmbeddings()
@@ -241,7 +245,7 @@ with st.sidebar:
241
  # os.environ["HUGGINGFACEHUB_API_TOKEN"] = api_key
242
  st.session_state['rqa'][model] = init_qa(model)
243
 
244
- elif model == 'chatgpt-3.5-turbo' and model not in st.session_state['api_keys']:
245
  if 'OPENAI_API_KEY' not in os.environ:
246
  api_key = st.text_input('OpenAI API Key', type="password")
247
  st.markdown("Get it [here](https://platform.openai.com/account/api-keys)")
 
19
  from document_qa.grobid_processors import GrobidAggregationProcessor, decorate_text_with_annotations
20
  from grobid_client_generic import GrobidClientGeneric
21
 
22
+ OPENAI_MODELS = ['chatgpt-3.5-turbo',
23
+ "gpt-4",
24
+ "gpt-4-1106-preview"]
25
+
26
  if 'rqa' not in st.session_state:
27
  st.session_state['rqa'] = {}
28
 
 
121
  # @st.cache_resource
122
  def init_qa(model, api_key=None):
123
  ## For debug add: callbacks=[PromptLayerCallbackHandler(pl_tags=["langchain", "chatgpt", "document-qa"])])
124
+ if model in OPENAI_MODELS:
125
  st.session_state['memory'] = ConversationBufferWindowMemory(k=4)
126
  if api_key:
127
+ chat = ChatOpenAI(model_name=model,
128
  temperature=0,
129
  openai_api_key=api_key,
130
  frequency_penalty=0.1)
131
  embeddings = OpenAIEmbeddings(openai_api_key=api_key)
132
 
133
  else:
134
+ chat = ChatOpenAI(model_name=model,
135
  temperature=0,
136
  frequency_penalty=0.1)
137
  embeddings = OpenAIEmbeddings()
 
245
  # os.environ["HUGGINGFACEHUB_API_TOKEN"] = api_key
246
  st.session_state['rqa'][model] = init_qa(model)
247
 
248
+ elif model in OPENAI_MODELS and model not in st.session_state['api_keys']:
249
  if 'OPENAI_API_KEY' not in os.environ:
250
  api_key = st.text_input('OpenAI API Key', type="password")
251
  st.markdown("Get it [here](https://platform.openai.com/account/api-keys)")