dejanseo commited on
Commit
eb404b7
1 Parent(s): f8a17aa

Upload 2 files

Browse files
Files changed (2) hide show
  1. synth.py +111 -0
  2. train.py +90 -0
synth.py ADDED
@@ -0,0 +1,111 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import csv
2
+ import torch
3
+ from transformers import pipeline
4
+
5
+ # Initialize the chatbot with half-precision
6
+ chatbot = pipeline(
7
+ "text-generation",
8
+ model="mistralai/Mistral-7B-Instruct-v0.3",
9
+ torch_dtype=torch.float16,
10
+ device=0 # Assuming you are using a GPU
11
+ )
12
+
13
+ # Sentiments and their labels
14
+ sentiments = ["Positive", "Neutral", "Negative"]
15
+
16
+ # List of content formats to cycle through
17
+ formats = [
18
+ "Feature Stories", "Instructional Manuals", "FAQs", "Policy Documents", "Live Stream Descriptions",
19
+ "Editorial Content", "Research Papers", "User Manuals", "Commentaries", "Opinion Pieces",
20
+ "Newsletters", "Online Courses", "Photo Essays", "Annual Reports", "User-Generated Content",
21
+ "Testimonials", "DIY Content", "How-To Videos", "Campaign Reports", "Legal Briefs",
22
+ "Blog Posts", "Case Studies", "Tutorials", "Interviews", "Press Releases",
23
+ "eBooks", "Infographics", "Webinars", "Podcast Descriptions", "Video Scripts",
24
+ "Advertisements", "Forum Discussions", "Whitepapers", "Surveys", "Product Reviews",
25
+ "Event Summaries", "Opinion Editorials", "Letters to the Editor", "Round-Up Posts",
26
+ "Buying Guides", "Checklists", "Cheat Sheets", "Recipes", "Travel Guides",
27
+ "Profiles", "Lists", "Q&A Sessions", "Debates", "Polls"
28
+ ]
29
+
30
+ # List of topics to cycle through
31
+ topics = [
32
+ "Family", "Travel", "Politics", "Science", "Health", "Technology", "Sports",
33
+ "Education", "Environment", "Economics", "Culture", "History", "Music",
34
+ "Literature", "Food", "Art", "Fashion", "Entertainment", "Business",
35
+ "Relationships", "Fitness", "Automotive", "Finance", "Real Estate", "Law",
36
+ "Psychology", "Philosophy", "Religion", "Gardening", "DIY", "Hobbies",
37
+ "Pets", "Career", "Marketing", "Customer Service", "Networking", "Innovation",
38
+ "Artificial Intelligence", "Sustainability", "Social Issues", "Digital Media",
39
+ "Programming", "Cybersecurity", "Astronomy", "Geography", "Travel Tips",
40
+ "Cooking", "Parenting", "Productivity", "Mindfulness", "Mental Health",
41
+ "Self-Improvement", "Leadership", "Teamwork", "Volunteering", "Nonprofits",
42
+ "Gaming", "E-commerce", "Photography", "Videography", "Film", "Television",
43
+ "Streaming Services", "Podcasts", "Public Speaking", "Event Planning",
44
+ "Interior Design", "Architecture", "Urban Development", "Agriculture",
45
+ "Climate Change", "Renewable Energy", "Space Exploration", "Biotechnology",
46
+ "Cryptocurrency", "Blockchain", "Robotics", "Automated Systems", "Genetics",
47
+ "Medicine", "Pharmacy", "Veterinary Science", "Marine Biology", "Ecology",
48
+ "Conservation", "Wildlife", "Botany", "Zoology", "Geology", "Meteorology",
49
+ "Aviation", "Maritime", "Logistics", "Supply Chain", "Human Resources",
50
+ "Diversity and Inclusion", "Ethics", "Corporate Governance", "Public Relations",
51
+ "Journalism", "Advertising", "Sales", "Customer Experience", "Retail",
52
+ "Hospitality", "Tourism", "Luxury Goods", "Consumer Electronics", "Fashion Design",
53
+ "Textiles", "Jewelry", "Cosmetics", "Skincare", "Perfume", "Toys", "Gadgets",
54
+ "Home Appliances", "Furniture", "Home Improvement", "Landscaping", "Real Estate Investment"
55
+ ]
56
+
57
+ # CSV file setup with utf-8 encoding and quoting minimal
58
+ csv_file = "sentences.csv"
59
+ with open(csv_file, mode='w', newline='', encoding='utf-8') as file:
60
+ writer = csv.writer(file, delimiter=',', quotechar='"', quoting=csv.QUOTE_MINIMAL)
61
+ writer.writerow(["text", "label"])
62
+
63
+ # Function to ensure correct quoting
64
+ def ensure_correct_quoting(text):
65
+ # Check if the text is already properly quoted
66
+ if text.startswith('"') and text.endswith('"'):
67
+ return text
68
+ else:
69
+ return f'"{text}"' # Add quotes if not already present
70
+
71
+ # Collect and save responses until reaching 100,000 rows
72
+ row_count = 0
73
+ format_index = 0
74
+ topic_index = 0
75
+
76
+ while row_count < 100000:
77
+ for idx, sentiment in enumerate(sentiments):
78
+ format_type = formats[format_index % len(formats)]
79
+ format_index += 1
80
+ topic = topics[topic_index % len(topics)]
81
+ topic_index += 1
82
+
83
+ # Add the current sentiment prompt with the format and topic
84
+ prompt = f"Write a single sentence of web content in Croatian. Content type: {format_type}. Topic: {topic}. Sentiment: {sentiment}."
85
+
86
+ response = chatbot(prompt, max_new_tokens=100) # Adjusted max_new_tokens for longer responses
87
+
88
+ # Debug print to check response format
89
+ print(f"Full model response: {response}")
90
+
91
+ # Extract the generated text from the response structure
92
+ generated_text = response[0]['generated_text']
93
+
94
+ # Remove any part of the prompt from the generated text if it exists
95
+ clean_text = generated_text.replace(prompt, "").strip().split('\n')[0]
96
+
97
+ # Ensure the text starts and ends with quotes only if it doesn't already
98
+ correctly_quoted_text = ensure_correct_quoting(clean_text)
99
+
100
+ # Append the clean response text to the CSV
101
+ with open(csv_file, mode='a', newline='', encoding='utf-8') as file:
102
+ writer = csv.writer(file, delimiter=',', quotechar='"', quoting=csv.QUOTE_MINIMAL)
103
+ writer.writerow([correctly_quoted_text, idx])
104
+
105
+ row_count += 1
106
+ print(f"Response for sentiment '{sentiment}' saved to {csv_file}. Total rows: {row_count}")
107
+
108
+ if row_count >= 100000:
109
+ break
110
+
111
+ print("All responses saved. Total rows:", row_count)
train.py ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+ from sklearn.model_selection import train_test_split
3
+ from transformers import ElectraTokenizer, ElectraForSequenceClassification, Trainer, TrainingArguments
4
+ import torch
5
+ from datasets import Dataset
6
+ import wandb
7
+ from sklearn.metrics import precision_recall_fscore_support, accuracy_score
8
+
9
+ # Load dataset
10
+ data = pd.read_csv('sentences.csv')
11
+
12
+ # Split dataset into train and eval sets
13
+ train_df, eval_df = train_test_split(data, test_size=0.2, random_state=42)
14
+
15
+ # Convert to Hugging Face Dataset
16
+ train_dataset = Dataset.from_pandas(train_df)
17
+ eval_dataset = Dataset.from_pandas(eval_df)
18
+
19
+ # Initialize the tokenizer and model
20
+ model_name = 'classla/bcms-bertic'
21
+ tokenizer = ElectraTokenizer.from_pretrained(model_name)
22
+ model = ElectraForSequenceClassification.from_pretrained(model_name, num_labels=3)
23
+
24
+ # Tokenize the datasets
25
+ def tokenize_function(examples):
26
+ return tokenizer(examples['text'], truncation=True, padding='max_length', max_length=128)
27
+
28
+ train_dataset = train_dataset.map(tokenize_function, batched=True)
29
+ eval_dataset = eval_dataset.map(tokenize_function, batched=True)
30
+
31
+ # Set format for PyTorch
32
+ train_dataset.set_format(type='torch', columns=['input_ids', 'attention_mask', 'label'])
33
+ eval_dataset.set_format(type='torch', columns=['input_ids', 'attention_mask', 'label'])
34
+
35
+ # Define the compute_metrics function
36
+ def compute_metrics(p):
37
+ preds = p.predictions.argmax(-1)
38
+ precision, recall, f1, _ = precision_recall_fscore_support(p.label_ids, preds, average='weighted')
39
+ acc = accuracy_score(p.label_ids, preds)
40
+ return {
41
+ 'accuracy': acc,
42
+ 'precision': precision,
43
+ 'recall': recall,
44
+ 'f1': f1
45
+ }
46
+
47
+ # Define the training arguments
48
+ training_args = TrainingArguments(
49
+ output_dir='./results',
50
+ evaluation_strategy='epoch',
51
+ save_strategy='epoch',
52
+ learning_rate=1e-5,
53
+ per_device_train_batch_size=128,
54
+ per_device_eval_batch_size=128,
55
+ num_train_epochs=20,
56
+ weight_decay=0.01,
57
+ warmup_steps=500,
58
+ logging_dir='./logs',
59
+ logging_steps=10,
60
+ save_total_limit=20,
61
+ load_best_model_at_end=True,
62
+ metric_for_best_model='accuracy',
63
+ report_to='wandb',
64
+ run_name='sentiment-classification',
65
+ )
66
+
67
+ # Initialize WandB
68
+ wandb.init(project="sentiment-classification", entity="dejan")
69
+
70
+ # Define Trainer
71
+ trainer = Trainer(
72
+ model=model,
73
+ args=training_args,
74
+ train_dataset=train_dataset,
75
+ eval_dataset=eval_dataset,
76
+ compute_metrics=compute_metrics
77
+ )
78
+
79
+ # Train the model
80
+ trainer.train()
81
+
82
+ # Evaluate the model
83
+ trainer.evaluate()
84
+
85
+ # Finish the WandB run
86
+ wandb.finish()
87
+
88
+ # Save the model
89
+ model.save_pretrained('./sentiment-model')
90
+ tokenizer.save_pretrained('./sentiment-model')