zpbrent commited on
Commit
deef385
1 Parent(s): 5f1e856

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +6 -6
README.md CHANGED
@@ -14,10 +14,10 @@ This model is a non-finetuned RAG-Sequence model and was created as follows:
14
  ```python
15
  from transformers import RagTokenizer, RagRetriever, RagSequenceForGeneration, AutoTokenizer
16
 
17
- model = RagSequenceForGeneration.from_pretrained_question_encoder_generator("facebook/dpr-question_encoder-single-nq-base", "facebook/bart-large")
18
 
19
- question_encoder_tokenizer = AutoTokenizer.from_pretrained("facebook/dpr-question_encoder-single-nq-base")
20
- generator_tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large")
21
 
22
  tokenizer = RagTokenizer(question_encoder_tokenizer, generator_tokenizer)
23
  model.config.use_dummy_dataset = True
@@ -40,9 +40,9 @@ The model can be fine-tuned as follows:
40
  ```python
41
  from transformers import RagTokenizer, RagRetriever, RagTokenForGeneration
42
 
43
- tokenizer = RagTokenizer.from_pretrained("facebook/rag-sequence-base")
44
- retriever = RagRetriever.from_pretrained("facebook/rag-sequence-base")
45
- model = RagTokenForGeneration.from_pretrained("facebook/rag-sequence-base", retriever=retriever)
46
 
47
  input_dict = tokenizer.prepare_seq2seq_batch("who holds the record in 100m freestyle", "michael phelps", return_tensors="pt")
48
 
 
14
  ```python
15
  from transformers import RagTokenizer, RagRetriever, RagSequenceForGeneration, AutoTokenizer
16
 
17
+ model = RagSequenceForGeneration.from_pretrained_question_encoder_generator("repo_name")
18
 
19
+ question_encoder_tokenizer = AutoTokenizer.from_pretrained("repo_name")
20
+ generator_tokenizer = AutoTokenizer.from_pretrained("repo_name")
21
 
22
  tokenizer = RagTokenizer(question_encoder_tokenizer, generator_tokenizer)
23
  model.config.use_dummy_dataset = True
 
40
  ```python
41
  from transformers import RagTokenizer, RagRetriever, RagTokenForGeneration
42
 
43
+ tokenizer = RagTokenizer.from_pretrained("repo_name")
44
+ retriever = RagRetriever.from_pretrained("repo_name")
45
+ model = RagTokenForGeneration.from_pretrained("repo_name", retriever=retriever)
46
 
47
  input_dict = tokenizer.prepare_seq2seq_batch("who holds the record in 100m freestyle", "michael phelps", return_tensors="pt")
48