archit11 committed on
Commit a3db774
1 Parent(s): 7f7274f

Update app.py

Files changed (1)
  1. app.py +31 -34
app.py CHANGED
@@ -30,25 +30,6 @@ MODEL_OPTIONS = [
 models = {}
 tokenizers = {}

-# Custom chat templates
-MISTRAL_TEMPLATE = """<s>[INST] {instruction} [/INST]
-{response}
-</s>
-<s>[INST] {instruction} [/INST]
-"""
-
-LLAMA_TEMPLATE = """<s>[INST] <<SYS>>
-You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.
-
-If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.
-<</SYS>>
-
-{instruction} [/INST]
-{response}
-</s>
-<s>[INST] {instruction} [/INST]
-"""
-
 for model_id in MODEL_OPTIONS:
     tokenizers[model_id] = AutoTokenizer.from_pretrained(model_id)
     models[model_id] = AutoModelForCausalLM.from_pretrained(
@@ -58,11 +39,9 @@ for model_id in MODEL_OPTIONS:
     )
     models[model_id].eval()

-    # Set custom chat templates
-    if "Navarna" in model_id:
-        tokenizers[model_id].chat_template = MISTRAL_TEMPLATE
-    elif "OpenHathi" in model_id:
-        tokenizers[model_id].chat_template = LLAMA_TEMPLATE
+    # Set pad_token_id to eos_token_id if it's not set
+    if tokenizers[model_id].pad_token_id is None:
+        tokenizers[model_id].pad_token_id = tokenizers[model_id].eos_token_id

 # Initialize Flask app
 app = Flask(__name__)
@@ -74,6 +53,25 @@ def log_results():
     print("Logged:", json.dumps(data, indent=2))
     return jsonify({"status": "success"}), 200

+def prepare_input(model_id: str, message: str, chat_history: List[Tuple[str, str]]):
+    if "OpenHathi" in model_id:
+        # OpenHathi model doesn't use a specific chat template
+        full_prompt = message
+        for history_message in chat_history:
+            full_prompt = f"{history_message[0]}\n{history_message[1]}\n{full_prompt}"
+        return tokenizers[model_id](full_prompt, return_tensors="pt")
+    elif "Navarna" in model_id:
+        # Navarna model uses a chat template
+        conversation = []
+        for user, assistant in chat_history:
+            conversation.extend([
+                {"role": "user", "content": user},
+                {"role": "assistant", "content": assistant},
+            ])
+        conversation.append({"role": "user", "content": message})
+        prompt = tokenizers[model_id].apply_chat_template(conversation, tokenize=False, add_generation_prompt=True)
+        return tokenizers[model_id](prompt, return_tensors="pt")
+
 @spaces.GPU(duration=90)
 def generate(
     model_id: str,
@@ -86,29 +84,28 @@ def generate(
     model = models[model_id]
     tokenizer = tokenizers[model_id]

-    conversation = []
-    for user, assistant in chat_history:
-        conversation.extend([
-            {"role": "user", "content": user},
-            {"role": "assistant", "content": assistant},
-        ])
-    conversation.append({"role": "user", "content": message})
-
-    input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt")
+    inputs = prepare_input(model_id, message, chat_history)
+    input_ids = inputs.input_ids
+    attention_mask = inputs.attention_mask
+
     if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
         input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
+        attention_mask = attention_mask[:, -MAX_INPUT_TOKEN_LENGTH:]
         gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
     input_ids = input_ids.to(model.device)
+    attention_mask = attention_mask.to(model.device)

     streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
     generate_kwargs = dict(
         input_ids=input_ids,
+        attention_mask=attention_mask,
         streamer=streamer,
         max_new_tokens=max_new_tokens,
         do_sample=True,
         top_p=top_p,
         temperature=temperature,
         num_beams=1,
+        pad_token_id=tokenizer.eos_token_id,
     )
     t = Thread(target=model.generate, kwargs=generate_kwargs)
     t.start()
@@ -215,5 +212,5 @@ if __name__ == "__main__":
     flask_thread = Thread(target=app.run, kwargs={"host": "0.0.0.0", "port": 5000})
     flask_thread.start()

-    # Start Gradio app
-    demo.queue(max_size=10).launch()
+    # Start Gradio app with public link
+    demo.queue(max_size=10).launch(share=True)
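
For reference, here is a minimal sketch of how the pieces touched by this commit fit together outside the Gradio/Flask wiring: the OpenHathi-style prompt building from prepare_input, the pad_token_id fallback, and threaded generation through TextIteratorStreamer. This is not part of the commit; the model ID, chat history, and sampling values below are placeholders rather than the actual MODEL_OPTIONS entries.

# Sketch only: exercises the same prompt building, pad_token_id fallback, and
# threaded streaming that app.py uses after this change.
from threading import Thread

from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

model_id = "example-org/OpenHathi-style-model"  # placeholder, not a real MODEL_OPTIONS entry
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)
model.eval()

# Same fallback the commit adds: some tokenizers ship without a pad token.
if tokenizer.pad_token_id is None:
    tokenizer.pad_token_id = tokenizer.eos_token_id

# OpenHathi branch of prepare_input: fold the chat history into a plain prompt.
chat_history = [("Hello", "Hi! How can I help?")]
message = "Tell me about Hampi."
full_prompt = message
for user, assistant in chat_history:
    full_prompt = f"{user}\n{assistant}\n{full_prompt}"

# The Navarna branch would instead go through the tokenizer's chat template:
# prompt = tokenizer.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True)
inputs = tokenizer(full_prompt, return_tensors="pt").to(model.device)

# Stream tokens from a background generate() call, mirroring generate() in app.py.
streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
Thread(target=model.generate, kwargs=dict(
    input_ids=inputs.input_ids,
    attention_mask=inputs.attention_mask,
    streamer=streamer,
    max_new_tokens=256,
    do_sample=True,
    top_p=0.9,
    temperature=0.7,
    num_beams=1,
    pad_token_id=tokenizer.eos_token_id,
)).start()

print("".join(streamer))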