neuralleap committed
Commit ba89683
Parent(s): 35b74b4

Update app.py

Files changed (1): app.py +8 -8
app.py CHANGED
@@ -63,31 +63,31 @@ terminators = [
 ]
 
 
-def generate_response(input_ids, generate_kwargs):
+def generate_response(input_ids, generate_kwargs, output_queue):
     try:
-        # Generate the output using the model
         output = model.generate(**generate_kwargs)
-        return output
+        output_queue.append(output)
     except Exception as e:
         print(f"Error during generation: {e}")
+        output_queue.append(None)
 
 
 @spaces.GPU(duration=120)
 def chat_llama3_8b(message, history, temperature=0.95, max_new_tokens=512):
     # Prepare conversation context
-    conversation = [{"role": "user", "content": message}] + [{"role": "assistant", "content": reply} for reply in history]
-    input_ids = tokenizer(conversation, return_tensors="pt", padding=True, truncation=True).input_ids.to(model.device)
+    conversation = [message] + [msg for pair in history for msg in pair]
+    inputs = tokenizer(conversation, return_tensors="pt", padding=True, truncation=True).input_ids.to(model.device)
 
     generate_kwargs = {
-        "input_ids": input_ids,
-        "max_length": input_ids.shape[1] + max_new_tokens,
+        "input_ids": inputs,
+        "max_length": inputs.shape[1] + max_new_tokens,
         "temperature": temperature,
         "num_return_sequences": 1
     }
 
     # Thread for generating model response
     output_queue = []
-    response_thread = Thread(target=generate_response, args=(input_ids, generate_kwargs, output_queue))
+    response_thread = Thread(target=generate_response, args=(inputs, generate_kwargs, output_queue))
     response_thread.start()
     response_thread.join()  # Wait for the thread to complete
 
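
For readers skimming the diff: the pre-change code crashed because chat_llama3_8b spawned the thread with three arguments while generate_response accepted only two, and the worker's return value was discarded (a Thread target's return goes nowhere). The commit threads a shared list through as a result slot instead. Below is a minimal, runnable sketch of that pattern; "sshleifer/tiny-gpt2" is an illustrative stand-in checkpoint, not the Llama-3-8B model this Space serves, and the diff's unused input_ids parameter is dropped for brevity.

# Minimal runnable sketch of the worker-thread pattern this commit lands on.
# "sshleifer/tiny-gpt2" is a stand-in checkpoint for illustration only.
from threading import Thread

from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("sshleifer/tiny-gpt2")
model = AutoModelForCausalLM.from_pretrained("sshleifer/tiny-gpt2")

def generate_response(generate_kwargs, output_queue):
    try:
        output = model.generate(**generate_kwargs)
        output_queue.append(output)   # success: push the generated token ids
    except Exception as e:
        print(f"Error during generation: {e}")
        output_queue.append(None)     # failure: push a sentinel the caller can check

input_ids = tokenizer("Hello", return_tensors="pt").input_ids
generate_kwargs = {"input_ids": input_ids, "max_length": input_ids.shape[1] + 16}

output_queue = []
response_thread = Thread(target=generate_response, args=(generate_kwargs, output_queue))
response_thread.start()
response_thread.join()  # as in the commit, the caller blocks until generation finishes

if output_queue and output_queue[0] is not None:
    print(tokenizer.decode(output_queue[0][0], skip_special_tokens=True))

A plain list works as a single-slot result holder here because list.append is atomic under CPython's GIL. Note that start() followed immediately by join() adds no concurrency yet; presumably it sets the stage for streaming the output later.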
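Separately, the new conversation line joins raw strings with the latest message first and no role markers, and passing a list of strings to the tokenizer produces a batch rather than one sequence. For chat-tuned checkpoints like Llama-3-8B-Instruct, the usual transformers route is tokenizer.apply_chat_template. The sketch below shows that alternative (it is not what this commit does); it assumes history holds Gradio-style (user, assistant) string pairs, as the `for pair in history` comprehension suggests, with tokenizer and model taken from the surrounding app.py scope.

# Alternative sketch using the tokenizer's chat template (not what this commit does).
# Assumes history is a list of (user, assistant) string pairs.
conversation = []
for user_msg, assistant_msg in history:
    conversation.append({"role": "user", "content": user_msg})
    conversation.append({"role": "assistant", "content": assistant_msg})
conversation.append({"role": "user", "content": message})  # latest turn goes last

input_ids = tokenizer.apply_chat_template(
    conversation,
    add_generation_prompt=True,  # end with the assistant header so the model replies
    return_tensors="pt",
).to(model.device)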