Ashmi Banerjee commited on
Commit
316b9d4
1 Parent(s): c526665

bug fixes potentially

Browse files
Files changed (2) hide show
  1. models/gemini.py +11 -1
  2. setup/vertex_ai_setup.py +13 -9
models/gemini.py CHANGED
@@ -1,15 +1,25 @@
 
1
  from typing import Optional
2
  import sys
3
 
 
 
 
4
  sys.path.append("../")
5
  from setup.vertex_ai_setup import initialize_vertexai_params
6
  from vertexai.generative_models import GenerativeModel
7
 
 
 
 
8
 
9
- def get_gemini_response(prompt_text, model, parameters: Optional = None) -> str:
10
  initialize_vertexai_params()
 
11
  if model is None or parameters is None:
12
  model = "gemini-1.0-pro"
13
  model = GenerativeModel(model)
 
14
  model_response = model.generate_content(prompt_text)
 
15
  return model_response.text
 
1
+ import os
2
  from typing import Optional
3
  import sys
4
 
5
+ import vertexai
6
+ from dotenv import load_dotenv
7
+
8
  sys.path.append("../")
9
  from setup.vertex_ai_setup import initialize_vertexai_params
10
  from vertexai.generative_models import GenerativeModel
11
 
12
+ load_dotenv()
13
+ VERTEXAI_PROJECT = os.environ["VERTEXAI_PROJECT"]
14
+
15
 
16
+ def get_gemini_response(prompt_text, model, parameters: Optional = None, location: Optional[str] = "us-central1") -> str:
17
  initialize_vertexai_params()
18
+
19
  if model is None or parameters is None:
20
  model = "gemini-1.0-pro"
21
  model = GenerativeModel(model)
22
+
23
  model_response = model.generate_content(prompt_text)
24
+
25
  return model_response.text
setup/vertex_ai_setup.py CHANGED
@@ -7,23 +7,27 @@ import json
7
  import base64
8
 
9
  load_dotenv()
10
-
11
- # TODO: fix it in spaces
12
-
13
  VERTEXAI_PROJECT = os.environ["VERTEXAI_PROJECT"]
14
 
15
 
16
  def decode_service_key():
17
  encoded_key = os.environ["GOOGLE_APPLICATION_CREDENTIALS"]
18
  original_service_key = json.loads(base64.b64decode(encoded_key).decode('utf-8'))
19
- return original_service_key
 
 
20
 
21
 
22
  def initialize_vertexai_params(location: Optional[str] = "us-central1"):
23
- GOOGLE_APPLICATION_CREDENTIALS = decode_service_key()
24
- service_account.Credentials.from_service_account_info(GOOGLE_APPLICATION_CREDENTIALS, scopes=["https://www"
25
- ".googleapis.com/auth/cloud-platform"])
 
 
 
26
 
27
- print("service account worked")
 
 
 
28
  vertexai.init(project=VERTEXAI_PROJECT, location=location)
29
- print("init worked")
 
7
  import base64
8
 
9
  load_dotenv()
 
 
 
10
  VERTEXAI_PROJECT = os.environ["VERTEXAI_PROJECT"]
11
 
12
 
13
  def decode_service_key():
14
  encoded_key = os.environ["GOOGLE_APPLICATION_CREDENTIALS"]
15
  original_service_key = json.loads(base64.b64decode(encoded_key).decode('utf-8'))
16
+ if original_service_key:
17
+ return original_service_key
18
+ return None
19
 
20
 
21
  def initialize_vertexai_params(location: Optional[str] = "us-central1"):
22
+ credentials = decode_service_key()
23
+ creds_file_name = os.getcwd() + "/.config/application_default_credentials.json"
24
+ with open(creds_file_name, 'w', encoding='utf-8') as file:
25
+ json.dump(credentials, file, ensure_ascii=False, indent=4)
26
+
27
+ os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = creds_file_name
28
 
29
+ service_account.Credentials.from_service_account_file(
30
+ filename=os.environ["GOOGLE_APPLICATION_CREDENTIALS"],
31
+ scopes=["https://www.googleapis.com/auth/cloud-platform"],
32
+ )
33
  vertexai.init(project=VERTEXAI_PROJECT, location=location)