def clean_output(decoded_list):
    """Remove duplicates (preserving first-seen order) and trim whitespace.

    Args:
        decoded_list: Iterable of decoded strings (e.g. generated questions).

    Returns:
        List of unique, stripped, non-empty strings in original order.
    """
    # dict.fromkeys keeps insertion order (Python 3.7+), giving an
    # order-preserving dedupe; empty/whitespace-only entries are dropped.
    return list(dict.fromkeys([q.strip() for q in decoded_list if q.strip()]))


def preprocess_context(context):
    """Prefix the stripped context with the T5-style task prompt."""
    return f"generate question: {context.strip()}"


def get_shap_values(tokenizer, model, prompt):
    """Compute SHAP attributions over the input tokens of a seq2seq model.

    Args:
        tokenizer: HuggingFace tokenizer matching ``model``.
        model: HuggingFace seq2seq model exposing ``generate`` and ``device``.
        prompt: Raw text whose tokens should be attributed.

    Returns:
        Tuple of (SHAP value array for the first/only example, list of
        token strings for that example).
    """
    # Tokenize input (assumes the tensor lives on CPU so .numpy() below
    # works — TODO confirm for GPU-resident pipelines).
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True)
    input_ids = inputs["input_ids"]

    def f(x):
        """SHAP prediction wrapper: numpy token-id batch -> (batch, 1) scores.

        BUGFIX: the previous version returned a constant ``np.ones`` dummy
        and never used the generated output, so every SHAP value was
        zero/meaningless. We now return a scalar per example that depends
        on the input: the number of non-pad tokens in the generated
        sequence.
        """
        x = torch.tensor(x).long().to(model.device)  # SHAP hands us numpy; model needs LongTensor
        with torch.no_grad():
            out = model.generate(
                input_ids=x,
                max_length=64,
                do_sample=False,
                num_beams=2,
            )
        # Fall back to 0 when the tokenizer defines no pad token.
        pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0
        lengths = (out != pad_id).sum(dim=1)
        # Shape (batch, 1), as SHAP expects for a single-output model.
        return lengths.float().cpu().numpy().reshape(-1, 1)

    # SHAP explainer over the token-id background/input.
    explainer = shap.Explainer(f, input_ids.numpy())
    shap_values = explainer(input_ids.numpy())
    tokens = tokenizer.convert_ids_to_tokens(input_ids[0])
    return shap_values.values[0], tokens