Upload make_conditions.py
Browse files- make_conditions.py +138 -0
make_conditions.py
ADDED
@@ -0,0 +1,138 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import jsonlines
|
3 |
+
import pandas as pd
|
4 |
+
import time
|
5 |
+
from vllm import LLM, SamplingParams
|
6 |
+
from huggingface_hub import HfApi, Repository
|
7 |
+
import torch
|
8 |
+
from concurrent.futures import ThreadPoolExecutor
|
9 |
+
|
10 |
+
import random
|
11 |
+
|
12 |
+
|
13 |
+
def generate_responses(llm, batch_texts, sampling_params):
|
14 |
+
print("Generating responses for the current batch...")
|
15 |
+
appended_prompts = [
|
16 |
+
f"""<<SYS>> You are a highly intelligent, empathic, helpful, respectful, and honest assistant with high emotional intelligence. Always answer as helpfully and honest as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature. If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information. <</SYS>> TEXT TO ANALYZE: {text} INSTRUCTION: Write with 1 or 2 phrases to which category/genre the previous text belongs. Then list the central themes of the previous text as a list of keywords. Finally, write a summary of the previous text in one to a maximum of three sentences. Use the format: "CATEGORY/GENRE: ... KEYWORDS: ... SUMMARY: ... ": \nRESPONSE:"""
|
17 |
+
for text in batch_texts ]
|
18 |
+
|
19 |
+
|
20 |
+
outputs = llm.generate(appended_prompts, sampling_params)
|
21 |
+
|
22 |
+
responses1 = [[output.outputs[k].text.strip() for k in range(len(output.outputs))] for output in outputs]
|
23 |
+
|
24 |
+
appended_prompts = [
|
25 |
+
f"""<<SYS>> You are a highly intelligent, empathic, helpful, respectful, and honest assistant with high emotional intelligence. Always answer as helpfully and honest as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature. If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information. <</SYS>> TEXT TO ANALYZE: {text} INSTRUCTION: First, write in one sentence to what extent and in which degree the following text contains violence, physical violence, and psychological violence. Secondly, write in one sentence to what extent and in which degree the following text contains sexual content. Thirdly, explain to what extent this is still appropriate for children and non-adult teenagers. Finally, suggest an age rating from the following list [ "Suitable for kids & people of all ages", "Suitable for kids & people of age 6 or higher", "Suitable for teenagers & people of age 12 or higher", "Suitable for teenagers & people of age 16 or higher", "Suitable for adults of age 18 or higher"]: \nRESPONSE:"""
|
26 |
+
for text in batch_texts ]
|
27 |
+
|
28 |
+
|
29 |
+
outputs = llm.generate(appended_prompts, sampling_params)
|
30 |
+
|
31 |
+
responses2 = [[output.outputs[k].text.strip() for k in range(len(output.outputs))] for output in outputs]
|
32 |
+
|
33 |
+
responses= []
|
34 |
+
try:
|
35 |
+
for i in range(len(responses1)):
|
36 |
+
responses.append([responses1[i],responses2[i]])
|
37 |
+
|
38 |
+
except:
|
39 |
+
pass
|
40 |
+
|
41 |
+
return responses
|
42 |
+
|
43 |
+
def process_file(llm, filepath, sampling_params):
|
44 |
+
print(f"Processing file: {filepath}")
|
45 |
+
BATCH_SIZE = 128
|
46 |
+
BATCH_INCREMENT = 32
|
47 |
+
prev_eps = 0
|
48 |
+
batch_texts = []
|
49 |
+
df = pd.DataFrame()
|
50 |
+
batch_counter = 0 # Counter to keep track of batches processed
|
51 |
+
|
52 |
+
if filepath.endswith('.parquet'):
|
53 |
+
print("Reading from a parquet file...")
|
54 |
+
df = pd.read_parquet(filepath)
|
55 |
+
batch_texts = df['TEXT'].tolist()
|
56 |
+
|
57 |
+
|
58 |
+
|
59 |
+
total_prompts = len(batch_texts)
|
60 |
+
print(f"Total prompts found: {total_prompts}")
|
61 |
+
|
62 |
+
i = 0
|
63 |
+
new_filepath = filepath.replace('.parquet', '_processed.jsonl')
|
64 |
+
print(f"Data will be saved to: {new_filepath}")
|
65 |
+
|
66 |
+
with jsonlines.open(new_filepath, 'w') as writer:
|
67 |
+
with ThreadPoolExecutor() as executor:
|
68 |
+
while i < total_prompts:
|
69 |
+
batch = batch_texts[i:i+BATCH_SIZE]
|
70 |
+
|
71 |
+
start_time = time.time()
|
72 |
+
batch_responses = generate_responses(llm, batch, sampling_params)
|
73 |
+
end_time = time.time()
|
74 |
+
|
75 |
+
duration = end_time - start_time
|
76 |
+
eps = len(batch) / duration
|
77 |
+
|
78 |
+
# Adjust batch size based on examples per second
|
79 |
+
if eps > prev_eps and BATCH_SIZE + BATCH_INCREMENT <= total_prompts - i:
|
80 |
+
BATCH_SIZE += BATCH_INCREMENT
|
81 |
+
print(f"Increasing batch size to: {BATCH_SIZE}")
|
82 |
+
elif eps < prev_eps and BATCH_SIZE - BATCH_INCREMENT > 0:
|
83 |
+
BATCH_SIZE -= BATCH_INCREMENT
|
84 |
+
print(f"Decreasing batch size to: {BATCH_SIZE}")
|
85 |
+
|
86 |
+
prev_eps = eps
|
87 |
+
|
88 |
+
# Print progress and write to file after every batch.
|
89 |
+
print(f"Processed: {min(i + BATCH_SIZE, total_prompts)}/{total_prompts}, Batch Size: {BATCH_SIZE}, EPS: {eps:.2f}")
|
90 |
+
print("Writing to the new jsonl file...")
|
91 |
+
for idx, text in enumerate(batch):
|
92 |
+
writer.write({'TEXT': text, 'CONDITIONING': batch_responses[idx][0][0]+ "\n"+batch_responses[idx][1][0]})
|
93 |
+
|
94 |
+
# Delete the processed rows from the original parquet file
|
95 |
+
if not df.empty:
|
96 |
+
df = df.iloc[i + BATCH_SIZE:]
|
97 |
+
executor.submit(df.to_parquet, filepath)
|
98 |
+
|
99 |
+
i += BATCH_SIZE
|
100 |
+
batch_counter += 1
|
101 |
+
|
102 |
+
# Push to hub every 10 batches
|
103 |
+
if batch_counter % 10 == 0:
|
104 |
+
# Initialize the HuggingFace API
|
105 |
+
api = HfApi()
|
106 |
+
|
107 |
+
# Upload the processed file to the repository
|
108 |
+
try:
|
109 |
+
api.upload_file(
|
110 |
+
path_or_fileobj=new_filepath,
|
111 |
+
path_in_repo=new_filepath,
|
112 |
+
repo_id="AlignmentLab-AI/caption_creation_0.8",
|
113 |
+
repo_type="dataset",
|
114 |
+
)
|
115 |
+
print(f"Uploaded {new_filepath} to AlignmentLab-AI/caption_creation_0.8 repository.")
|
116 |
+
except Exception as e:
|
117 |
+
print(f"Error uploading file: {e}")
|
118 |
+
|
119 |
+
# Delete the original parquet file if it is empty
|
120 |
+
if df.empty:
|
121 |
+
os.remove(filepath)
|
122 |
+
print(f"Deleted the original file: {filepath}")
|
123 |
+
|
124 |
+
def main():
|
125 |
+
folder_name = 'generate_conditions'
|
126 |
+
sampling_params = SamplingParams(temperature=0.4, top_p=0.95, max_tokens=200)
|
127 |
+
|
128 |
+
print("Initializing the LLM model...")
|
129 |
+
llm = LLM("Open-Orca/Mistral-7B-OpenOrca")
|
130 |
+
|
131 |
+
print("Iterating through the files in the folder...")
|
132 |
+
for filename in os.listdir(folder_name):
|
133 |
+
if filename.endswith(".parquet"):
|
134 |
+
process_file(llm, os.path.join(folder_name, filename), sampling_params)
|
135 |
+
|
136 |
+
if __name__ == "__main__":
|
137 |
+
main()
|
138 |
+
`
|