artificialguybr committed on
Commit
4c1f576
1 Parent(s): 8ab1b3b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -3
app.py CHANGED
@@ -53,15 +53,21 @@ def _launch_demo(args, model, tokenizer, config):
53
  def predict(_query, _chatbot, _task_history):
54
  print(f"User: {_parse_text(_query)}")
55
  _chatbot.append((_parse_text(_query), ""))
56
- input_ids = input_ids.to('cuda')
57
- attention_mask = torch.ones(input_ids.shape).to('cuda')
58
- pad_token_id = tokenizer.eos_token_id
59
  # Tokenize the input
60
  input_ids = tokenizer.encode(_query, return_tensors='pt')
61
  print("Input IDs:", input_ids)
 
 
 
 
 
 
 
62
  # Generate a response using the model
63
  generated_ids = model.generate(input_ids, max_length=300)
64
  print("Generated IDs:", generated_ids)
 
65
  # Decode the generated IDs to text
66
  full_response = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
67
 
@@ -73,6 +79,7 @@ def _launch_demo(args, model, tokenizer, config):
73
  _task_history.append((_query, full_response))
74
  print(f"OpenHermes: {_parse_text(full_response)}")
75
 
 
76
  def regenerate(_chatbot, _task_history):
77
  if not _task_history:
78
  yield _chatbot
 
53
  def predict(_query, _chatbot, _task_history):
54
  print(f"User: {_parse_text(_query)}")
55
  _chatbot.append((_parse_text(_query), ""))
56
+
 
 
57
  # Tokenize the input
58
  input_ids = tokenizer.encode(_query, return_tensors='pt')
59
  print("Input IDs:", input_ids)
60
+
61
+ # Move input_ids to CUDA if available
62
+ input_ids = input_ids.to('cuda')
63
+
64
+ # Generate attention_mask
65
+ attention_mask = torch.ones(input_ids.shape).to('cuda')
66
+
67
  # Generate a response using the model
68
  generated_ids = model.generate(input_ids, max_length=300)
69
  print("Generated IDs:", generated_ids)
70
+
71
  # Decode the generated IDs to text
72
  full_response = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
73
 
 
79
  _task_history.append((_query, full_response))
80
  print(f"OpenHermes: {_parse_text(full_response)}")
81
 
82
+
83
  def regenerate(_chatbot, _task_history):
84
  if not _task_history:
85
  yield _chatbot