File size: 4,609 Bytes
d721e7b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d952c8a
d721e7b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d952c8a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d721e7b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d952c8a
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import os
import jsonlines
import pandas as pd
import time
from vllm import LLM, SamplingParams
from huggingface_hub import HfApi, Repository
import torch
from concurrent.futures import ThreadPoolExecutor
 
def generate_responses(llm, batch_texts, sampling_params):
    """Build a captioning prompt for each text and return the generated completions.

    Args:
        llm: Object exposing ``generate(prompts, sampling_params)``; each item
            it returns must have an ``outputs`` sequence whose elements carry
            a ``text`` attribute.
        batch_texts: Iterable of source texts to wrap in the captioning prompt.
        sampling_params: Passed through unchanged to ``llm.generate``.

    Returns:
        A list with one entry per returned item; each entry is the list of
        whitespace-stripped completion texts for that item.
    """
    print("Generating responses for the current batch...")

    # Wrap every source text in the fixed captioning instruction.
    prompts = []
    for text in batch_texts:
        prompts.append(
            f"you are a captioner, you only generate 3 single sentence long captions as though the text were an image, and return the captions in an enumerated list with each being one sentence long and in quotes, and each a description of a hypothetical image inspired by [{text}]"
        )

    results = llm.generate(prompts, sampling_params)

    # Collect every candidate completion per prompt, trimmed of whitespace.
    collected = []
    for result in results:
        collected.append([candidate.text.strip() for candidate in result.outputs])
    return collected
 
def process_file(llm, filepath, sampling_params):
    """Caption every ``TEXT`` row of a parquet file and stream results to JSONL.

    Rows are sent to ``generate_responses`` in batches whose size adapts to the
    measured throughput. After each batch the first response for every prompt
    is appended to ``<name>_processed.jsonl``, the processed rows are removed
    from the parquet file on disk, and every 10 batches (plus once at the end)
    the JSONL file is uploaded to the Hugging Face dataset repository. When no
    rows remain, the original parquet file is deleted.

    Args:
        llm: Model object forwarded to ``generate_responses``.
        filepath: Path to a ``.parquet`` file containing a ``TEXT`` column.
            Any other path is skipped without touching the file.
        sampling_params: Sampling configuration forwarded to
            ``generate_responses``.

    Returns:
        None.
    """
    print(f"Processing file: {filepath}")

    parquet_suffix = '.parquet'
    if not filepath.endswith(parquet_suffix):
        # Without this guard the output path would equal the input path, so
        # the input would be truncated by the writer and then deleted.
        print(f"Skipping non-parquet file: {filepath}")
        return

    batch_size = 128
    batch_increment = 32
    upload_every = 10  # Push to the hub after this many batches.
    repo_id = "AlignmentLab-AI/caption_creation_0.8"
    prev_eps = 0

    print("Reading from a parquet file...")
    df = pd.read_parquet(filepath)
    batch_texts = df['TEXT'].tolist()

    total_prompts = len(batch_texts)
    print(f"Total prompts found: {total_prompts}")

    # Replace only the trailing extension; str.replace would also rewrite a
    # '.parquet' occurring earlier in the path (e.g. in a directory name).
    new_filepath = filepath[:-len(parquet_suffix)] + '_processed.jsonl'
    print(f"Data will be saved to: {new_filepath}")

    api = HfApi()

    def upload_results():
        # Best-effort upload: a hub failure must not abort captioning.
        try:
            api.upload_file(
                path_or_fileobj=new_filepath,
                path_in_repo=new_filepath,
                repo_id=repo_id,
                repo_type="dataset",
            )
            print(f"Uploaded {new_filepath} to {repo_id} repository.")
        except Exception as e:
            print(f"Error uploading file: {e}")

    def rewrite_parquet(remaining_rows):
        # Runs on the worker thread; report failures instead of losing them
        # inside an unchecked Future.
        try:
            remaining_rows.to_parquet(filepath)
        except Exception as e:
            print(f"Error rewriting {filepath}: {e}")

    i = 0
    batch_counter = 0  # Counter to keep track of batches processed
    last_uploaded_batch = 0

    # flush=True keeps the file on disk current so mid-run uploads are complete.
    with jsonlines.open(new_filepath, 'w', flush=True) as writer:
        # A single worker serialises the parquet rewrites: they all target the
        # same path, so concurrent or out-of-order writes could corrupt it.
        with ThreadPoolExecutor(max_workers=1) as executor:
            while i < total_prompts:
                batch = batch_texts[i:i + batch_size]
                processed = len(batch)

                start_time = time.perf_counter()
                batch_responses = generate_responses(llm, batch, sampling_params)
                duration = time.perf_counter() - start_time
                eps = processed / duration if duration > 0 else float('inf')

                print("Writing to the new jsonl file...")
                for idx, text in enumerate(batch):
                    writer.write({'TEXT': text, 'RESPONSE': batch_responses[idx][0]})

                # Advance by the number of prompts actually processed, not by
                # the (possibly just-adjusted) batch size, so that no prompt is
                # skipped or captioned twice.
                i += processed
                batch_counter += 1

                # Delete the processed rows from the original parquet file.
                # df is already truncated to the unprocessed rows, so only the
                # current batch has to be dropped from its head.
                df = df.iloc[processed:]
                executor.submit(rewrite_parquet, df)

                # Adjust the size of the *next* batch based on examples per second.
                if eps > prev_eps and batch_size + batch_increment <= total_prompts - i:
                    batch_size += batch_increment
                    print(f"Increasing batch size to: {batch_size}")
                elif eps < prev_eps and batch_size - batch_increment > 0:
                    batch_size -= batch_increment
                    print(f"Decreasing batch size to: {batch_size}")
                prev_eps = eps

                print(f"Processed: {i}/{total_prompts}, Batch Size: {batch_size}, EPS: {eps:.2f}")

                # Push to hub every `upload_every` batches
                if batch_counter % upload_every == 0:
                    upload_results()
                    last_uploaded_batch = batch_counter

    # The writer is closed here, so the file is complete. Upload whatever the
    # periodic pushes have not covered yet (the trailing partial group).
    if batch_counter != last_uploaded_batch:
        upload_results()

    # Delete the original parquet file if it is empty. All pending rewrites
    # have finished because the executor context has exited.
    if df.empty:
        os.remove(filepath)
        print(f"Deleted the original file: {filepath}")
 
def main():
    """Load the model once and caption every parquet file in ``captionate``.

    Builds the sampling configuration, initialises the
    ``Open-Orca/Mistral-7B-OpenOrca`` model, then hands each ``.parquet``
    entry of the folder to ``process_file``.
    """
    source_dir = 'captionate'
    params = SamplingParams(temperature=0.7, top_p=0.95, max_tokens=100)

    print("Initializing the LLM model...")
    model = LLM("Open-Orca/Mistral-7B-OpenOrca")

    print("Iterating through the files in the folder...")
    for entry in os.listdir(source_dir):
        # Only parquet inputs are handled; everything else is ignored.
        if not entry.endswith(".parquet"):
            continue
        process_file(model, os.path.join(source_dir, entry), params)
 
# Run the captioning pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
 `