debisoft commited on
Commit
ec9bb5c
1 Parent(s): 3efb331

n_shot_learning

Browse files
Files changed (1) hide show
  1. app.py +8 -3
app.py CHANGED
@@ -11,9 +11,14 @@ _ = load_dotenv(find_dotenv())
11
  databricks_token = os.getenv('DATABRICKS_TOKEN')
12
  model_uri = "https://dbc-eb788f31-6c73.cloud.databricks.com/serving-endpoints/Mpt-7b-tester/invocations"
13
 
14
- def extract_json(gen_text):
15
  start_index = gen_text.index("### Response:\n{") + 14
16
- end_index = gen_text.index("}\n\n### End") + 1
 
 
 
 
 
17
  return gen_text[start_index:end_index]
18
 
19
  def score_model(model_uri, databricks_token, prompt):
@@ -43,7 +48,7 @@ def greet(input):
43
  response = get_completion(input)
44
  gen_text = response["predictions"][0]["generated_text"]
45
 
46
- return extract_json(gen_text)
47
  #return json.dumps(response)
48
 
49
  iface = gr.Interface(fn=greet, inputs=[gr.Textbox(label="Prompt", lines=3)], outputs="text")
 
11
  databricks_token = os.getenv('DATABRICKS_TOKEN')
12
  model_uri = "https://dbc-eb788f31-6c73.cloud.databricks.com/serving-endpoints/Mpt-7b-tester/invocations"
13
 
14
+ def extract_json(gen_text, n_shot_learning=0):
15
  start_index = gen_text.index("### Response:\n{") + 14
16
+ if(n_shot_learning > 0) :
17
+ for i in range(0, n_shot_learning):
18
+ gen_text = gen_text[start_index:]
19
+ start_index = gen_text.index("### Response:\n{") + 14
20
+
21
+ end_index = gen_text.index("}\n\n### ") + 1
22
  return gen_text[start_index:end_index]
23
 
24
  def score_model(model_uri, databricks_token, prompt):
 
48
  response = get_completion(input)
49
  gen_text = response["predictions"][0]["generated_text"]
50
 
51
+ return extract_json(gen_text, 3)
52
  #return json.dumps(response)
53
 
54
  iface = gr.Interface(fn=greet, inputs=[gr.Textbox(label="Prompt", lines=3)], outputs="text")