ranamhamoud committed
Commit 9905ae2 • 1 Parent(s): bda6d90

Update app.py

Files changed (1): app.py +9 −8
app.py CHANGED
@@ -56,7 +56,14 @@ def generate(
     top_k: int = 50,
     repetition_penalty: float = 1.2,
 ) -> Iterator[str]:
-    if model == "A":
+
+    conversation = []
+    if system_prompt:
+        conversation.append({"role": "system", "content": system_prompt})
+    for user, assistant in chat_history:
+        conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}])
+    conversation.append({"role": "user", "content": message})
+    if model == "A":
         model = modelA
         tokenizer = tokenizerA
         enc = tokenizer(make_prompt(message), return_tensors="pt", padding=True, truncation=True)
@@ -66,13 +73,7 @@ def generate(
         model = modelB
         tokenizer = tokenizerB
         input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt")
-    conversation = []
-    if system_prompt:
-        conversation.append({"role": "system", "content": system_prompt})
-    for user, assistant in chat_history:
-        conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}])
-    conversation.append({"role": "user", "content": message})
-
+
     if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
         input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
         gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")