sjrhuschlee commited on
Commit
914343a
1 Parent(s): 1961e9a

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +3 -4
README.md CHANGED
@@ -61,14 +61,13 @@ tokenizer = AutoTokenizer.from_pretrained(model_name)
61
  question = f'{tokenizer.cls_token}Where do I live?' # '<cls>Where do I live?'
62
  context = 'My name is Sarah and I live in London'
63
  encoding = tokenizer(question, context, return_tensors="pt")
64
- start_scores, end_scores, _ = model(
65
  encoding["input_ids"],
66
- attention_mask=encoding["attention_mask"],
67
- return_dict=False
68
  )
69
 
70
  all_tokens = tokenizer.convert_ids_to_tokens(encoding["input_ids"][0].tolist())
71
- answer_tokens = all_tokens[torch.argmax(start_scores):torch.argmax(end_scores) + 1]
72
  answer = tokenizer.decode(tokenizer.convert_tokens_to_ids(answer_tokens))
73
  # 'London'
74
  ```
 
61
  question = f'{tokenizer.cls_token}Where do I live?' # '<cls>Where do I live?'
62
  context = 'My name is Sarah and I live in London'
63
  encoding = tokenizer(question, context, return_tensors="pt")
64
+ output = model(
65
  encoding["input_ids"],
66
+ attention_mask=encoding["attention_mask"]
 
67
  )
68
 
69
  all_tokens = tokenizer.convert_ids_to_tokens(encoding["input_ids"][0].tolist())
70
+ answer_tokens = all_tokens[torch.argmax(output["start_logits"]):torch.argmax(output["end_logits"]) + 1]
71
  answer = tokenizer.decode(tokenizer.convert_tokens_to_ids(answer_tokens))
72
  # 'London'
73
  ```