Flexy0 committed on
Commit 5736201 • 1 Parent(s): 9f08cf9

Create app.py

Files changed (1)
  1. app.py +43 -0
app.py ADDED
@@ -0,0 +1,43 @@
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, DataCollatorForSeq2Seq, Seq2SeqTrainer, Seq2SeqTrainingArguments
+ # Load the tokenizer and model from the same checkpoint so their vocabularies match
+ tokenizer = AutoTokenizer.from_pretrained("Kaludi/chatgpt-gpt4-prompts-bart-large-cnn-samsum")
+ model = AutoModelForSeq2SeqLM.from_pretrained("Kaludi/chatgpt-gpt4-prompts-bart-large-cnn-samsum", from_tf=True)
+
+ # Assuming you have your own dataset for fine-tuning
+ # Replace this with loading your dataset as needed
+ # For example, you can use the datasets library to load and tokenize it
+ # (a minimal sketch of this is given after the diff)
+
+ # Define data collator for sequence-to-sequence modeling
+ data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)
+
+ # Define training arguments
+ training_args = Seq2SeqTrainingArguments(
+     output_dir="./gpt4-text-gen",
+     overwrite_output_dir=True,
+     per_device_train_batch_size=4,
+     save_steps=10_000,
+     save_total_limit=2,
+ )
+
+ # Create Seq2SeqTrainer
+ trainer = Seq2SeqTrainer(
+     model=model,
+     args=training_args,
+     data_collator=data_collator,
+     train_dataset=your_training_dataset,  # Replace with your training dataset
+ )
+
+ # Train the model
+ trainer.train()
+
+ # Save the fine-tuned model and tokenizer
+ model.save_pretrained("./gpt4-text-gen")
+ tokenizer.save_pretrained("./gpt4-text-gen")
38
+ # Generate text using the fine-tuned model
39
+ input_text = "Once upon a time"
40
+ input_ids = tokenizer.encode(input_text, return_tensors="pt")
41
+ output = model.generate(input_ids, max_length=100, num_return_sequences=1, no_repeat_ngram_size=2, top_k=50, top_p=0.95)
42
+ generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
43
+ print("Generated Text: ", generated_text)