File size: 3,919 Bytes
897c136
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
from swift.tuners import Swift #chinese toolkit for finetunin and inference


from swift.llm import (
    get_model_tokenizer, get_template, inference, ModelType,
    get_default_template_type, inference_stream
)
from swift.utils import seed_everything
import torch
from tqdm import tqdm
import time

# --- Model setup ---------------------------------------------------------------
# Phi-3 Vision (128k context) is the base model the fine-tuned checkpoint targets.
model_type = ModelType.phi3_vision_128k_instruct
template_type = get_default_template_type(model_type)
print(f'template_type: {template_type}')

# Checkpoint directory handed to Swift.from_pretrained below. The author used a
# LoRA checkpoint here; whether a merged checkpoint loads the same way is unverified.
model_path = "./phi3-1476"

# Base weights in bfloat16; device_map='auto' lets the loader place the layers.
model, tokenizer = get_model_tokenizer(
    model_type,
    torch.bfloat16,
    model_kwargs={'device_map': 'auto'},
)

# Default decoding settings stored on the model. The author found plain greedy
# decoding (do_sample=False) gave better captions than tuned top_p/temperature.
default_generation = {
    'max_new_tokens': 1256,
    'do_sample': False,
}
for option_name, option_value in default_generation.items():
    setattr(model.generation_config, option_name, option_value)

# Attach the fine-tuned weights from model_path (inference_mode=True).
model = Swift.from_pretrained(model, model_path, "lora", inference_mode=True)
template = get_template(template_type, tokenizer)
# For reproducible runs with sampling enabled, call seed_everything(6321) here.

# --- Paths and run bookkeeping ---------------------------------------------------
image_dir = './images/'  # input images to caption
txt_dir = './tags/'  # per-image tag files (e.g. from Danbooru or WD Tagger), named <image stem>.txt
maintxt_dir = './maintxt/'  # output directory for the natural-language captions

# Image file extensions to caption, compared case-insensitively so that names
# such as "IMG_001.JPG" or "photo.jpeg" are not silently skipped.
image_extensions = ('.jpg', '.jpeg')

# Create the output directory up front; writing captions into a missing
# directory would otherwise raise FileNotFoundError in the middle of the run.
os.makedirs(maintxt_dir, exist_ok=True)

# os.listdir returns entries in arbitrary order; sort for a reproducible run.
image_files = sorted(
    name for name in os.listdir(image_dir)
    if name.lower().endswith(image_extensions)
)
total_files = len(image_files)

progress_bar = tqdm(total=total_files, unit='file', bar_format='{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}{postfix}]')
total_elapsed_time = 0  # cumulative seconds spent in model inference
processed_files = 0  # number of images actually captioned

# Main captioning loop: for each image that has a matching tag file, ask the
# model for a natural-language caption and write it to maintxt_dir under the
# same file stem as the tag file.
for index, image_file in enumerate(image_files):
    image_path = os.path.join(image_dir, image_file)
    txt_file = os.path.splitext(image_file)[0] + '.txt'  # tag file and caption file share this name
    txt_path = os.path.join(txt_dir, txt_file)

    if not os.path.exists(image_path):
        # Guards against the file disappearing after the directory was listed.
        print(f"Image {image_file} not found.")
    elif not os.path.exists(txt_path):
        print(f"File {txt_file} doesn't exist.")
    else:
        with open(txt_path, 'r', encoding='utf-8') as f:
            tags = f.read().strip()

        text = f'<img>{image_path}</img> Make a caption that describe this image. Here is the tags describing image: {tags}\n Find the relevant character\'s names in the tags and use it.'
        print(text)

        step_start_time = time.perf_counter()
        # Greedy decoding, consistent with model.generation_config.do_sample = False
        # set at load time. do_sample=True must not be combined with temperature=0:
        # transformers requires a strictly positive temperature when sampling.
        # NOTE(review): whether swift's inference() forwards these kwargs to
        # generate() is not visible here - confirm against the installed version.
        response, _history = inference(model, template, text, do_sample=False, repetition_penalty=1.05)
        step_time = time.perf_counter() - step_start_time
        total_elapsed_time += step_time
        processed_files += 1

        # ETA = mean inference time per captioned image x images not yet visited
        # (index is 0-based, so index + 1 images have been visited so far).
        remaining_time = total_elapsed_time / processed_files * (total_files - (index + 1))
        remaining_hours, rest = divmod(int(remaining_time), 3600)
        remaining_minutes, remaining_seconds = divmod(rest, 60)
        progress_bar.set_postfix(
            eta=f'{remaining_hours:02d}:{remaining_minutes:02d}:{remaining_seconds:02d}',
            refresh=False,
        )
        print(f"\n\n\nFile {image_file}\nConsumed time: {step_time:.2f} s\n{response}")

        # The caption file reuses the tag file's name, placed in maintxt_dir.
        output_file = txt_file
        output_path = os.path.join(maintxt_dir, output_file)

        # Write the model's response to the caption file.
        with open(output_path, 'w', encoding='utf-8') as f:
            f.write(response)

        print(f"Caption saved in file: {output_file} \n")

    # Advance once per image (captioned or skipped) so the bar reaches total_files.
    progress_bar.update(1)
progress_bar.close()