Spaces:
Running
Running
import argparse
import os
import sys
from pathlib import Path

from caption import get_system_prompt, get_together_client, extract_captions, MODEL_ID
def optimize_prompt(user_prompt, captions_dir=None, captions_list=None):
    """Rewrite a user prompt so it follows the format of the LoRA training captions.

    All training captions are concatenated and sent to the chat model
    (``MODEL_ID``) together with the user's prompt; the model's reply is
    returned as the optimized prompt.

    Args:
        user_prompt (str): The simple user prompt to optimize.
        captions_dir (str, optional): Directory containing caption ``.txt``
            files; every ``*.txt`` file directly inside it is passed to
            ``extract_captions``. Ignored when ``captions_list`` is non-empty.
        captions_list (list, optional): Captions to use instead of loading
            them from files. Takes precedence over ``captions_dir``.

    Returns:
        str: The model's optimized prompt with surrounding whitespace stripped.

    Raises:
        ValueError: If no captions were supplied or found.
        RuntimeError: If the model response contains no message content.
    """
    if captions_list:
        all_captions = list(captions_list)
    else:
        all_captions = []
        if captions_dir:
            # Sort so the prompt sent to the model is identical from run to run;
            # Path.glob yields files in arbitrary, filesystem-dependent order.
            for file_path in sorted(Path(captions_dir).glob("*.txt")):
                all_captions.extend(extract_captions(file_path))

    if not all_captions:
        raise ValueError("Please provide either caption files or a list of captions!")

    # Concatenate all captions with newlines
    captions_text = "\n".join(all_captions)

    client = get_together_client()
    messages = [
        {"role": "system", "content": get_system_prompt()},
        {
            "role": "user",
            "content": (
                "These are all of the captions used to train the LoRA:\n\n"
                f"{captions_text}\n\n"
                "Now optimize this prompt to follow the caption format used in training: "
                f"{user_prompt}"
            ),
        },
    ]
    response = client.chat.completions.create(model=MODEL_ID, messages=messages)

    # OpenAI-style chat APIs type message content as optional; guard against an
    # empty/None reply instead of failing with an opaque AttributeError/IndexError.
    choices = response.choices
    content = choices[0].message.content if choices else None
    if content is None:
        raise RuntimeError("Model returned no content for the optimized prompt.")
    return content.strip()
def main(argv=None):
    """Command-line entry point: optimize ``--prompt`` using captions from ``--captions``.

    Args:
        argv (list[str], optional): Arguments to parse instead of the real
            command line. ``None`` (the default) makes argparse read
            ``sys.argv[1:]``, matching the original behavior.

    Returns:
        int: Process exit status - 0 on success, 1 if the captions path is not
        a directory or optimization fails.
    """
    parser = argparse.ArgumentParser(description='Optimize prompts based on existing captions.')
    parser.add_argument('--prompt', type=str, required=True, help='User prompt to optimize')
    parser.add_argument('--captions', type=str, required=True, help='Directory containing caption .txt files')
    args = parser.parse_args(argv)

    if not os.path.isdir(args.captions):
        # Diagnostics go to stderr so stdout carries only the optimized prompt.
        print(f"Error: Captions directory '{args.captions}' does not exist.", file=sys.stderr)
        return 1

    try:
        optimized_prompt = optimize_prompt(args.prompt, args.captions)
    except Exception as e:  # top-level CLI boundary: report any failure without a traceback
        print(f"Error optimizing prompt: {e}", file=sys.stderr)
        return 1

    print("\nOptimized Prompt:")
    print(optimized_prompt)
    return 0
if __name__ == "__main__":
    # Propagate main()'s status so shell callers can detect failure.
    sys.exit(main())