manandey committed on
Commit
9c9d8ef
1 Parent(s): e0a6493

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -3
app.py CHANGED
@@ -16,7 +16,7 @@ def generate(html, entity, website_desc, datasource, year, month, title, prompt)
16
  entity_text = entity_text + " |" + ent + "|"
17
  entity_text = "entity ||| <ENTITY_CHAIN>" + entity_text + " </ENTITY_CHAIN> "
18
  else:
19
- entity_text = ""
20
  website_desc_text = "Website Description: " + website_desc + " | " if website_desc != "" else ""
21
  datasource_text = "Datasource: " + datasource + " | " if datasource != "" else ""
22
  year_text = "Year: " + year + " | " if year != "" else ""
@@ -26,11 +26,12 @@ def generate(html, entity, website_desc, datasource, year, month, title, prompt)
26
  final_prompt = html_text + year_text + month_text + website_desc_text + title_text + datasource_text + entity_text + prompt
27
 
28
  model = AutoModelForCausalLM.from_pretrained("bs-modeling-metadata/checkpoints_all_04_23", subfolder="checkpoint-30000step")
29
- tokenizer = AutoTokenizer.from_pretrained("bs-modeling-metadata/checkpoints_all_04_23", subfolder="tokenizer")
 
30
 
31
  inputs = tokenizer(final_prompt, return_tensors="pt")
32
 
33
- outputs = model.generate(**inputs, max_new_tokens=128)
34
  return tokenizer.batch_decode(outputs, skip_special_tokens=True)
35
 
36
 
 
16
  entity_text = entity_text + " |" + ent + "|"
17
  entity_text = "entity ||| <ENTITY_CHAIN>" + entity_text + " </ENTITY_CHAIN> "
18
  else:
19
+ entity_text = "||| "
20
  website_desc_text = "Website Description: " + website_desc + " | " if website_desc != "" else ""
21
  datasource_text = "Datasource: " + datasource + " | " if datasource != "" else ""
22
  year_text = "Year: " + year + " | " if year != "" else ""
 
26
  final_prompt = html_text + year_text + month_text + website_desc_text + title_text + datasource_text + entity_text + prompt
27
 
28
  model = AutoModelForCausalLM.from_pretrained("bs-modeling-metadata/checkpoints_all_04_23", subfolder="checkpoint-30000step")
29
+ tokenizer = AutoTokenizer.from_pretrained("bs-modeling-metadata/checkpoints_all_04_23", subfolder="tokenizer", add_prefix_space=True)
30
+ bad_words_ids = tokenizer(["<ENTITY_CHAIN>", " </ENTITY_CHAIN> "]).input_ids
31
 
32
  inputs = tokenizer(final_prompt, return_tensors="pt")
33
 
34
+ outputs = model.generate(**inputs, max_new_tokens=128, bad_words_ids=bad_words_ids)
35
  return tokenizer.batch_decode(outputs, skip_special_tokens=True)
36
 
37