gosha6037 committed on
Commit
a4f8b32
1 Parent(s): 62851f3

Added description for bloom

Browse files
Files changed (1) hide show
  1. app.py +5 -1
app.py CHANGED
@@ -25,6 +25,7 @@ model_DialoGPT_large = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-
25
 
26
  def predict_common_bloom(model, tokenizer, input_text, history, person_description, number_of_new_tokens):
27
  new_user_input_ids = tokenizer.encode(input_text + '\n', return_tensors='pt')
 
28
  print('Started predict_common_bloom')
29
  print(f'history: {history}')
30
  if history != []:
@@ -32,13 +33,15 @@ def predict_common_bloom(model, tokenizer, input_text, history, person_descripti
32
  else:
33
  bot_input_ids = new_user_input_ids
34
  print(f'bot_input_ids: {bot_input_ids}')
 
35
 
36
  history = model.generate(
37
- bot_input_ids,
38
  max_new_tokens=number_of_new_tokens,
39
  pad_token_id=tokenizer.eos_token_id
40
  ).tolist()
41
  print(f'history: {history}')
 
42
 
43
  decode_all = tokenizer.decode(history[0][:len(bot_input_ids[0])])
44
  all_responses = tokenizer.decode(history[0][len(bot_input_ids[0]):]).split('\n')
@@ -128,6 +131,7 @@ gr.Interface(
128
  'DialoGPT-medium',
129
  'DialoGPT-large',
130
  'bloom-petals',
 
131
  ]
132
  ),
133
  gr.Radio(
 
25
 
26
  def predict_common_bloom(model, tokenizer, input_text, history, person_description, number_of_new_tokens):
27
  new_user_input_ids = tokenizer.encode(input_text + '\n', return_tensors='pt')
28
+ person_description_ids = tokenizer.encode(person_description + '\n', return_tensors='pt')
29
  print('Started predict_common_bloom')
30
  print(f'history: {history}')
31
  if history != []:
 
33
  else:
34
  bot_input_ids = new_user_input_ids
35
  print(f'bot_input_ids: {bot_input_ids}')
36
+ input_with_desc_ids = torch.cat([person_description_ids, bot_input_ids], dim=-1)
37
 
38
  history = model.generate(
39
+ input_with_desc_ids,
40
  max_new_tokens=number_of_new_tokens,
41
  pad_token_id=tokenizer.eos_token_id
42
  ).tolist()
43
  print(f'history: {history}')
44
+ history[0] = history[0][len(person_description_ids[0]):]
45
 
46
  decode_all = tokenizer.decode(history[0][:len(bot_input_ids[0])])
47
  all_responses = tokenizer.decode(history[0][len(bot_input_ids[0]):]).split('\n')
 
131
  'DialoGPT-medium',
132
  'DialoGPT-large',
133
  'bloom-petals',
134
+ 'bloom-petals-cluster',
135
  ]
136
  ),
137
  gr.Radio(