awinml committed
Commit 244a3e0
1 Parent(s): c658480

Upload 2 files

Files changed (2):
  1. app.py   +4 -5
  2. utils.py +15 -1
app.py CHANGED

@@ -14,7 +14,8 @@ from utils import (
     format_query,
     get_flan_alpaca_xl_model,
     generate_alpaca_ner_prompt,
-    generate_entities_flan_alpaca,
+    generate_entities_flan_alpaca_checkpoint,
+    generate_entities_flan_alpaca_inference_api,
     format_entities_flan_alpaca,
     generate_flant5_prompt_instruct_chunk_context,
     generate_flant5_prompt_instruct_chunk_context_single,
@@ -56,9 +57,7 @@ col1, col2 = st.columns([3, 3], gap="medium")
 with st.sidebar:
     ner_choice = st.selectbox("Select NER Model", ["Alpaca", "Spacy"])
 
-    if ner_choice == "Alpaca":
-        ner_model, ner_tokenizer = get_flan_alpaca_xl_model()
-    else:
+    if ner_choice == "Spacy":
         ner_model = get_spacy_model()
 
 with col1:
@@ -70,7 +69,7 @@ with col1:
 
     if ner_choice == "Alpaca":
         ner_prompt = generate_alpaca_ner_prompt(query_text)
-        entity_text = generate_entities_flan_alpaca(ner_model, ner_tokenizer, ner_prompt)
+        entity_text = generate_entities_flan_alpaca_inference_api(ner_prompt)
         company_ent, quarter_ent, year_ent = format_entities_flan_alpaca(entity_text)
     else:
         company_ent, quarter_ent, year_ent = extract_entities(query_text, ner_model)
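With this change, the Alpaca path in app.py still builds the NER prompt locally but runs generation through the hosted Inference API helper instead of a locally loaded checkpoint. A minimal usage sketch of that flow is shown below; it is illustrative only, assumes the utils.py helpers are importable exactly as named in the import block above, and uses a made-up query.

```python
# Sketch of the new Alpaca NER path (not part of the commit itself).
from utils import (
    generate_alpaca_ner_prompt,
    generate_entities_flan_alpaca_inference_api,
    format_entities_flan_alpaca,
)

query_text = "What was AAPL revenue in Q2 2022?"  # hypothetical example query

# Build the instruction-style NER prompt, send it to the hosted
# flan-alpaca-xl endpoint, then parse the generated entity string.
ner_prompt = generate_alpaca_ner_prompt(query_text)
entity_text = generate_entities_flan_alpaca_inference_api(ner_prompt)
company_ent, quarter_ent, year_ent = format_entities_flan_alpaca(entity_text)
```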
utils.py CHANGED

@@ -1,4 +1,6 @@
 import re
+import json
+import requests
 
 import openai
 import pandas as pd
@@ -513,8 +515,20 @@ Company - Cisco, Quarter - none, Year - none
 ### Response:"""
     return prompt
 
+def generate_entities_flan_alpaca_inference_api(prompt):
+    API_URL = "https://api-inference.huggingface.co/models/declare-lab/flan-alpaca-xl"
+    payload = {
+        "inputs": prompt,
+        "parameters": {"do_sample": True, "temperature": 0.1, "max_length": 80},
+        "options": {"use_cache": True, "wait_for_model": True}
+    }
+    data = json.dumps(payload)
+    response = requests.request("POST", API_URL, data=data)
+    output = json.loads(response.content.decode("utf-8"))[0]["generated_text"]
+    return output
+
 
-def generate_entities_flan_alpaca(model, tokenizer, prompt):
+def generate_entities_flan_alpaca_checkpoint(model, tokenizer, prompt):
     model_inputs = tokenizer(prompt, return_tensors="pt")
     input_ids = inputs["input_ids"]
     generation_output = model.generate(
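The renamed generate_entities_flan_alpaca_checkpoint helper is cut off in this diff, and the visible lines read input_ids from inputs even though the tokenizer output is bound to model_inputs. Below is a self-contained sketch of what a local-checkpoint version of this call could look like, mirroring the sampling parameters used in the Inference API payload above; this is an assumption, not the repository's actual continuation of the function.

```python
# Hypothetical sketch of local generation with a Flan-Alpaca checkpoint;
# not the repository's exact implementation (the diff cuts off mid-function).
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

def generate_entities_flan_alpaca_checkpoint_sketch(model, tokenizer, prompt):
    model_inputs = tokenizer(prompt, return_tensors="pt")
    input_ids = model_inputs["input_ids"]  # read from model_inputs, not inputs
    generation_output = model.generate(
        input_ids=input_ids,
        do_sample=True,      # same sampling settings as the Inference API payload
        temperature=0.1,
        max_length=80,
    )
    return tokenizer.decode(generation_output[0], skip_special_tokens=True)

# Example wiring (downloads the same checkpoint the API helper above calls):
# tokenizer = AutoTokenizer.from_pretrained("declare-lab/flan-alpaca-xl")
# model = AutoModelForSeq2SeqLM.from_pretrained("declare-lab/flan-alpaca-xl")
```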