pseudotensor committed on
Commit df5eeb7
Parent: 935bf6f

Update with h2oGPT hash 23aaa9c9839867b3f0c86e7722cc7fbdae414fc4
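The commit extracts the per-user database-state helpers (set_userid, set_userid_direct, get_userid_direct, get_username_direct, get_dbid, set_dbid, length_db1) out of src/gpt_langchain.py into a new src/db_utils.py module, points src/gpt_langchain.py and src/gradio_runner.py at that module instead, and threads visible_models and h2ogpt_key from kwargs into the initial model state in src/gradio_runner.py.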

Files changed (3):
  1. src/db_utils.py +54 -0
  2. src/gpt_langchain.py +2 -51
  3. src/gradio_runner.py +5 -2
src/db_utils.py ADDED
@@ -0,0 +1,54 @@
+ import uuid
+
+ from enums import LangChainMode
+
+
+ def set_userid(db1s, requests_state1, get_userid_auth):
+     db1 = db1s[LangChainMode.MY_DATA.value]
+     assert db1 is not None and len(db1) == length_db1()
+     if not db1[1]:
+         db1[1] = get_userid_auth(requests_state1)
+     if not db1[2]:
+         username1 = None
+         if 'username' in requests_state1:
+             username1 = requests_state1['username']
+         db1[2] = username1
+
+
+ def set_userid_direct(db1s, userid, username):
+     db1 = db1s[LangChainMode.MY_DATA.value]
+     db1[1] = userid
+     db1[2] = username
+
+
+ def get_userid_direct(db1s):
+     return db1s[LangChainMode.MY_DATA.value][1] if db1s is not None else ''
+
+
+ def get_username_direct(db1s):
+     return db1s[LangChainMode.MY_DATA.value][2] if db1s is not None else ''
+
+
+ def get_dbid(db1):
+     return db1[1]
+
+
+ def set_dbid(db1):
+     # can only call this after function called so for specific user, not in gr.State() that occurs during app init
+     assert db1 is not None and len(db1) == length_db1()
+     if db1[1] is None:
+         # uuid in db is used as user ID
+         db1[1] = str(uuid.uuid4())
+
+
+ def length_db1():
+     # For MyData:
+     # 0: db
+     # 1: userid and dbid
+     # 2: username
+
+     # For others:
+     # 0: db
+     # 1: dbid
+     # 2: None
+     return 3
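
For orientation, here is a minimal usage sketch of the new helpers. It is not part of the commit: it assumes both the h2oGPT repo root and its src/ directory are importable, and get_userid_auth, 'user-1234', and 'alice' are hypothetical stand-ins for the app's real auth callback and values.

from enums import LangChainMode
from src.db_utils import (length_db1, set_userid, set_dbid, get_dbid,
                          get_userid_direct, get_username_direct)


def get_userid_auth(requests_state1):
    # hypothetical stand-in for the auth callback gradio_runner passes in
    return 'user-1234'


# per-session state: each LangChainMode value maps to a 3-slot list [db, userid/dbid, username]
db1s = {LangChainMode.MY_DATA.value: [None] * length_db1()}
requests_state1 = {'username': 'alice'}  # illustrative request state

set_userid(db1s, requests_state1, get_userid_auth)
assert get_userid_direct(db1s) == 'user-1234'
assert get_username_direct(db1s) == 'alice'

# for non-MyData collections, set_dbid fills slot 1 with a fresh uuid4
db1 = [None] * length_db1()
set_dbid(db1)
print(get_dbid(db1))  # e.g. '6f2c...'
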
src/gpt_langchain.py CHANGED
@@ -37,6 +37,8 @@ from langchain.tools import PythonREPLTool
  from langchain.tools.json.tool import JsonSpec
  from tqdm import tqdm

+ from src.db_utils import length_db1, set_dbid, set_userid, get_dbid, get_userid_direct, get_username_direct, \
+     set_userid_direct
  from utils import wrapped_partial, EThread, import_matplotlib, sanitize_filename, makedirs, get_url, flatten_list, \
      get_device, ProgressParallel, remove, hash_file, clear_torch_cache, NullContext, get_hf_server, FakeTokenizer, \
      have_libreoffice, have_arxiv, have_playwright, have_selenium, have_tesseract, have_doctr, have_pymupdf, set_openai, \
@@ -4655,57 +4657,6 @@ def get_sources_answer(query, docs, answer, scores, show_rank,
      return ret, extra


- def set_userid(db1s, requests_state1, get_userid_auth):
-     db1 = db1s[LangChainMode.MY_DATA.value]
-     assert db1 is not None and len(db1) == length_db1()
-     if not db1[1]:
-         db1[1] = get_userid_auth(requests_state1)
-     if not db1[2]:
-         username1 = None
-         if 'username' in requests_state1:
-             username1 = requests_state1['username']
-         db1[2] = username1
-
-
- def set_userid_direct(db1s, userid, username):
-     db1 = db1s[LangChainMode.MY_DATA.value]
-     db1[1] = userid
-     db1[2] = username
-
-
- def get_userid_direct(db1s):
-     return db1s[LangChainMode.MY_DATA.value][1] if db1s is not None else ''
-
-
- def get_username_direct(db1s):
-     return db1s[LangChainMode.MY_DATA.value][2] if db1s is not None else ''
-
-
- def get_dbid(db1):
-     return db1[1]
-
-
- def set_dbid(db1):
-     # can only call this after function called so for specific user, not in gr.State() that occurs during app init
-     assert db1 is not None and len(db1) == length_db1()
-     if db1[1] is None:
-         # uuid in db is used as user ID
-         db1[1] = str(uuid.uuid4())
-
-
- def length_db1():
-     # For MyData:
-     # 0: db
-     # 1: userid and dbid
-     # 2: username
-
-     # For others:
-     # 0: db
-     # 1: dbid
-     # 2: None
-     return 3
-
-
  def get_any_db(db1s, langchain_mode, langchain_mode_paths, langchain_mode_types,
                 dbs=None,
                 load_db_if_exists=None, db_type=None,
src/gradio_runner.py CHANGED
@@ -20,6 +20,7 @@ from iterators import TimeoutIterator

  from gradio_utils.css import get_css
  from gradio_utils.prompt_form import make_chatbots
+ from src.db_utils import set_userid, get_username_direct

  # This is a hack to prevent Gradio from phoning home when it gets imported
  os.environ['GRADIO_ANALYTICS_ENABLED'] = 'False'
@@ -459,7 +460,6 @@ def go_gradio(**kwargs):
          if not requests_state1.get('host2', '') and hasattr(request, 'client') and hasattr(request.client, 'host'):
              requests_state1.update(dict(host2=request.client.host))
          if not requests_state1.get('username', '') and hasattr(request, 'username'):
-             from src.gpt_langchain import get_username_direct
              # use already-defined username instead of keep changing to new uuid
              # should be same as in requests_state1
              db_username = get_username_direct(db1s)
@@ -469,7 +469,6 @@ def go_gradio(**kwargs):

      def user_state_setup(db1s, requests_state1, request: gr.Request, *args):
          requests_state1 = get_request_state(requests_state1, request, db1s)
-         from src.gpt_langchain import set_userid
          set_userid(db1s, requests_state1, get_userid_auth)
          args_list = [db1s, requests_state1] + list(args)
          return tuple(args_list)
@@ -500,6 +499,8 @@ def go_gradio(**kwargs):
              inference_server=kwargs['inference_server'],
              prompt_type=kwargs['prompt_type'],
              prompt_dict=kwargs['prompt_dict'],
+             visible_models=kwargs['visible_models'],
+             h2ogpt_key=kwargs['h2ogpt_key'],
          )
      )

@@ -3746,6 +3747,8 @@ def go_gradio(**kwargs):
          base_model=model_name, tokenizer_base_model=tokenizer_base_model,
          lora_weights=lora_weights, inference_server=server_name,
          prompt_type=prompt_type1, prompt_dict=prompt_dict1,
+         # FIXME: not typically required, unless want to expose adding h2ogpt endpoint in UI
+         visible_models=None, h2ogpt_key=None,
      )

      max_max_new_tokens1 = get_max_max_new_tokens(model_state_new, **kwargs)
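
Note that src/db_utils.py depends only on enums, so gradio_runner.py can now import set_userid and get_username_direct once at module level rather than through the deferred in-function imports from gpt_langchain removed above, which were presumably there to dodge a circular import. The remaining hunks pass kwargs['visible_models'] and kwargs['h2ogpt_key'] into the startup model state, while the model-add path passes None for both since, per the FIXME, they are only needed when exposing an h2ogpt inference endpoint from the UI.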