jerinaj committed on
Commit
47d6d47
·
1 Parent(s): 587fdf0
Files changed (1) hide show
  1. app.py +9 -1
app.py CHANGED
@@ -21,7 +21,9 @@ hf_token = os.environ.get("HF_TOKEN")
21
  if hf_token:
22
  huggingface_hub.login(token=hf_token)
23
 
24
- # Export model to OpenVINO format on first run if not already done
 
 
25
  if not os.path.isdir(OV_MODEL_DIR):
26
  print(f"OpenVINO model not found at '{OV_MODEL_DIR}', exporting now...")
27
  subprocess.run(
@@ -29,6 +31,7 @@ if not os.path.isdir(OV_MODEL_DIR):
29
  "optimum-cli", "export", "openvino",
30
  "--model", model_name,
31
  "--task", "text-generation-with-past",
 
32
  OV_MODEL_DIR + "/",
33
  ],
34
  check=True,
@@ -169,6 +172,11 @@ async def generate(request: Request):
169
 
170
  prompt = build_prompt(messages, tools)
171
  inputs = tokenizer(prompt, return_tensors="pt")
 
 
 
 
 
172
  outputs = model.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False, use_cache=True)
173
 
174
  prompt_tokens = inputs["input_ids"].shape[-1]
 
21
  if hf_token:
22
  huggingface_hub.login(token=hf_token)
23
 
24
+ # Export model to OpenVINO format on first run if not already done.
25
+ # --disable-stateful avoids the static sliding-window shape (512) that gets
26
+ # baked in during tracing, which causes shape mismatches for long prompts.
27
  if not os.path.isdir(OV_MODEL_DIR):
28
  print(f"OpenVINO model not found at '{OV_MODEL_DIR}', exporting now...")
29
  subprocess.run(
 
31
  "optimum-cli", "export", "openvino",
32
  "--model", model_name,
33
  "--task", "text-generation-with-past",
34
+ "--disable-stateful",
35
  OV_MODEL_DIR + "/",
36
  ],
37
  check=True,
 
172
 
173
  prompt = build_prompt(messages, tools)
174
  inputs = tokenizer(prompt, return_tensors="pt")
175
+
176
+ # Truncate from the left if prompt exceeds model's context window (8192 tokens).
177
+ MAX_INPUT_TOKENS = 8192 - max_new_tokens
178
+ if inputs["input_ids"].shape[-1] > MAX_INPUT_TOKENS:
179
+ inputs = {k: v[:, -MAX_INPUT_TOKENS:] for k, v in inputs.items()}
180
  outputs = model.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False, use_cache=True)
181
 
182
  prompt_tokens = inputs["input_ids"].shape[-1]