hunthinn committed on
Commit
a7538a7
1 Parent(s): fc4c647
Files changed (1) hide show
  1. modules/inference.py +11 -7
modules/inference.py CHANGED
@@ -1,11 +1,15 @@
1
- from transformers import T5Tokenizer, T5ForConditionalGeneration
2
 
3
- tokenizer = T5Tokenizer.from_pretrained("t5-small")
4
- model = T5ForConditionalGeneration.from_pretrained("t5-small")
 
 
5
 
6
 
7
- def infer_t5(input):
8
- input_ids = tokenizer(input, return_tensors="pt").input_ids
9
- outputs = model.generate(input_ids)
10
 
11
- return tokenizer.decode(outputs[0], skip_special_tokens=True)
 
 
 
 
 
 
1
from transformers import GPT2Tokenizer, GPT2LMHeadModel

# Checkpoint to load. NOTE(review): the previous comment described this as
# "the fine-tuned model", but "gpt2" is the base pretrained checkpoint from
# the Hub — point model_path at the fine-tuned weights directory if one exists.
model_path = "gpt2"

# Tokenizer and language-model head are loaded once at import time and shared
# by every call to infer_t5 below.
tokenizer = GPT2Tokenizer.from_pretrained(model_path)
model = GPT2LMHeadModel.from_pretrained(model_path)
7
 
8
 
 
 
 
9
 
10
def infer_t5(input):
    """Generate an answer for *input* using the module-level GPT-2 model.

    The question is wrapped in a ``"Q: ... A:"`` prompt and the decoded
    generation (prompt included) is returned as a single string.

    NOTE: despite the name, this runs GPT-2, not T5 — the name is kept
    unchanged for backward compatibility with existing callers.

    Args:
        input: The question text to answer. (Parameter name shadows the
            builtin but is preserved as part of the public interface.)

    Returns:
        The decoded model output with special tokens stripped.
    """
    prompt = "Q: " + input + " A:"
    # Calling the tokenizer (rather than .encode) also yields the attention
    # mask, which generate() needs to tell real tokens from padding.
    encoded = tokenizer(prompt, return_tensors="pt")
    # Bug fix: the original passed temperature=0.7 while leaving greedy
    # decoding (do_sample defaults to False), so the temperature was silently
    # ignored; do_sample=True makes it take effect as intended.
    # pad_token_id is set explicitly because GPT-2 defines no pad token,
    # which otherwise triggers a runtime warning in generate().
    output = model.generate(
        encoded.input_ids,
        attention_mask=encoded.attention_mask,
        max_length=50,
        temperature=0.7,
        do_sample=True,
        num_return_sequences=1,
        pad_token_id=tokenizer.eos_token_id,
    )
    return tokenizer.decode(output[0], skip_special_tokens=True)