ajeetkumar01 commited on
Commit
ab65d01
1 Parent(s): 1854a46

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -8
app.py CHANGED
@@ -6,22 +6,20 @@ model_name = "mistralai/Mistral-7B-Instruct-v0.2"
6
  tokenizer = AutoTokenizer.from_pretrained(model_name)
7
  model = AutoModelForCausalLM.from_pretrained(model_name)
8
 
9
def generate_response(messages):
    """
    Generate a chat-model response for a list of role-tagged messages.

    Parameters:
    - messages (list): A list of dictionaries containing user messages with roles.

    Returns:
    - response (str): The generated response (raw decode, special tokens included).
    """
    # Render the conversation through the model's chat template and tokenize.
    prompt_ids = tokenizer.apply_chat_template(messages, return_tensors="pt")
    # Sample up to 1000 new tokens beyond the prompt.
    output_ids = model.generate(prompt_ids, max_new_tokens=1000, do_sample=True)
    # batch_decode returns one string per sequence; we generate a single one.
    return tokenizer.batch_decode(output_ids)[0]
26
 
27
  # Define Gradio interface components
 
6
  tokenizer = AutoTokenizer.from_pretrained(model_name)
7
  model = AutoModelForCausalLM.from_pretrained(model_name)
8
 
9
def generate_response(input_text):
    """
    Generate a model response for the given user input text.

    Parameters:
    - input_text (str): A single string containing all user messages.

    Returns:
    - response (str): The generated response, with the echoed prompt and
      special tokens stripped.
    """
    # Tokenize the input text, truncating overly long prompts to 512 tokens.
    inputs = tokenizer.encode(input_text, return_tensors="pt", max_length=512, truncation=True)
    # Use max_new_tokens so the generation budget applies only to the answer.
    # The previous max_length=1024 counted the prompt tokens too, silently
    # shrinking the room left for the response on long inputs.
    generated_ids = model.generate(inputs, max_new_tokens=1024, do_sample=True)
    # generate() returns prompt + continuation; slice off the prompt tokens so
    # the caller does not receive their own input echoed back.
    new_token_ids = generated_ids[0][inputs.shape[-1]:]
    response = tokenizer.decode(new_token_ids, skip_special_tokens=True)
    return response
24
 
25
  # Define Gradio interface components