sayakpaul HF Staff committed on
Commit
555f59f
·
verified ·
1 Parent(s): 9ec9de7

Create prompt_expander.py

Browse files
Files changed (1) hide show
  1. prompt_expander.py +67 -0
prompt_expander.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ from typing import List
3
+ from diffusers.modular_pipelines import (
4
+ PipelineState,
5
+ ModularPipelineBlocks,
6
+ InputParam,
7
+ OutputParam,
8
+ )
9
+ import google.generativeai as genai
10
+ import os
11
+
12
# Default system instruction handed to the Gemini model: it asks the model to
# turn a short user description into a detailed image-generation prompt.
# The sentences are joined with single spaces to form one instruction string.
SYSTEM_PROMPT = " ".join(
    [
        "You are an expert image generation assistant.",
        "Take the user's short description and expand it into a vivid, detailed, and clear image generation prompt.",
        "Ensure rich colors, depth, realistic lighting, and an imaginative composition.",
        "Avoid vague terms — be specific about style, perspective, and mood.",
        "Try to keep the output under 512 tokens.",
    ]
)
19
+
20
class GeminiPromptExpander(ModularPipelineBlocks):
    """Pipeline block that expands the user's ``prompt`` with a Gemini model.

    On each call the block reads ``prompt`` from the pipeline state, sends it
    to the configured Gemini model, stores the generated text back as
    ``prompt`` and keeps the original text under ``old_prompt``.
    """

    def __init__(self, model_id="gemini-2.5-flash-lite", system_prompt=SYSTEM_PROMPT):
        """Configure the Gemini client and build the generative model.

        Args:
            model_id: Name of the Gemini model to use.
            system_prompt: System instruction passed to the model.

        Raises:
            ValueError: If the ``GOOGLE_API_KEY`` environment variable is
                unset or empty.
        """
        super().__init__()
        api_key = os.getenv("GOOGLE_API_KEY")
        # An empty value is as unusable as a missing one; fail early with the
        # same error instead of letting the first request fail on auth.
        if not api_key:
            raise ValueError("Must provide an API key for Gemini through the `GOOGLE_API_KEY` env variable.")
        genai.configure(api_key=api_key)
        self.model = genai.GenerativeModel(model_name=model_id, system_instruction=system_prompt)

    @property
    def expected_components(self) -> list:
        """Return the components this block requires (none)."""
        return []

    @property
    def inputs(self) -> List[InputParam]:
        """Declare the single required input: the user's ``prompt`` string."""
        return [
            InputParam(
                "prompt",
                type_hint=str,
                required=True,
                description="Prompt to use",
            )
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        """Declare the outputs: the expanded ``prompt`` and the ``old_prompt``."""
        return [
            OutputParam(
                "prompt",
                type_hint=str,
                description="Expanded prompt by the LLM",
            ),
            OutputParam(
                "old_prompt",
                type_hint=str,
                description="Old prompt provided by the user",
            ),
        ]

    def __call__(self, components, state: PipelineState) -> tuple:
        """Replace ``prompt`` in the state with the Gemini-expanded version.

        Args:
            components: Pipeline components; returned unchanged.
            state: Pipeline state holding the ``prompt`` to expand.

        Returns:
            The ``(components, state)`` pair, with ``prompt`` and
            ``old_prompt`` written to the state.

        Raises:
            ValueError: If the Gemini response exposes no text.
        """
        block_state = self.get_block_state(state)

        # Keep the user's text so later blocks can still read it.
        block_state.old_prompt = block_state.prompt
        response = self.model.generate_content(block_state.old_prompt)
        try:
            expanded_prompt = response.text
        except ValueError as err:
            # NOTE(review): per the google-generativeai docs, `.text` raises
            # ValueError when the response has no usable candidate/part (for
            # example a blocked prompt) — confirm against the installed version.
            raise ValueError(
                "Gemini returned no text for the given prompt; the response may be empty or blocked."
            ) from err

        block_state.prompt = expanded_prompt
        self.set_block_state(state, block_state)

        return components, state