SunDou commited on
Commit
48128bb
·
verified ·
1 Parent(s): a51dbcc

Upload data2/step22/gemini_generation.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. data2/step22/gemini_generation.py +160 -0
data2/step22/gemini_generation.py ADDED
@@ -0,0 +1,160 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import threading
3
+ import time
4
+ from concurrent.futures import ThreadPoolExecutor, as_completed
5
+ from pathlib import Path
6
+ import jsonlines
7
+ from tqdm import tqdm
8
+
9
+ import vertexai
10
+ from vertexai.generative_models import GenerativeModel
11
+
12
+ # ======================
13
+ # 1. 初始化
14
+ # ======================
15
+ os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = "/home/weifengsun/tangou1/step2/gemini.json"
16
+ vertexai.init(project="tangou")
17
+ model = GenerativeModel("gemini-2.5-flash-lite")
18
+
19
+ # ======================
20
+ # 2. Gemini 价格配置
21
+ # ======================
22
+ PRICE_INPUT_PER_1M = 0.1
23
+ PRICE_OUTPUT_PER_1M = 0.4
24
+
25
+ # ======================
26
+ # 3. 全局消耗 & 熔断器
27
+ # ======================
28
+ class APIMonitor:
29
+ def __init__(self, max_usd: float):
30
+ self.max_usd = max_usd
31
+ self.input_tokens = 0
32
+ self.output_tokens = 0
33
+ self.total_cost = 0.0
34
+ self.lock = threading.Lock()
35
+ self.stop_event = threading.Event()
36
+ self.start_time = time.time()
37
+
38
+ @staticmethod
39
+ def estimate_tokens(text: str) -> int:
40
+ return max(1, len(text) // 3)
41
+
42
+ def reserve_input(self, prompt: str):
43
+ est = self.estimate_tokens(prompt)
44
+ est_cost = est / 1_000_000 * PRICE_INPUT_PER_1M
45
+ with self.lock:
46
+ if self.total_cost + est_cost > self.max_usd:
47
+ self.stop_event.set()
48
+ raise RuntimeError("💥 API budget exceeded (input)")
49
+ self.input_tokens += est
50
+ self.total_cost += est_cost
51
+
52
+ def record_output(self, text: str):
53
+ est = self.estimate_tokens(text)
54
+ cost = est / 1_000_000 * PRICE_OUTPUT_PER_1M
55
+ with self.lock:
56
+ self.output_tokens += est
57
+ self.total_cost += cost
58
+ if self.total_cost > self.max_usd:
59
+ self.stop_event.set()
60
+ raise RuntimeError("💥 API budget exceeded (output)")
61
+
62
+ def snapshot(self):
63
+ with self.lock:
64
+ return {
65
+ "input_tokens": self.input_tokens,
66
+ "output_tokens": self.output_tokens,
67
+ "total_cost": round(self.total_cost, 6),
68
+ "elapsed": round(time.time() - self.start_time, 2),
69
+ }
70
+
71
+ monitor = APIMonitor(max_usd=100.0)
72
+
73
+ # ======================
74
+ # 4. 推理函数
75
+ # ======================
76
+ def infer_one(prompt: str, idx: int):
77
+ if monitor.stop_event.is_set():
78
+ return {"idx": idx, "status": "stopped", "output": ""}
79
+
80
+ try:
81
+ monitor.reserve_input(prompt)
82
+ resp = model.generate_content(prompt)
83
+ text = resp.text or ""
84
+ monitor.record_output(text)
85
+ return {"idx": idx, "status": "ok", "output": text}
86
+ except Exception as e:
87
+ return {"idx": idx, "status": "error", "error": str(e)}
88
+
89
+ # ======================
90
+ # 5. 读取输入 prompt
91
+ # ======================
92
+ prompt_template = Path("prompt.txt").read_text(encoding="utf-8")
93
+ input_file = "/home/weifengsun/tangou1/step2/step22/output/function_filtered_scores.jsonl"
94
+ inputs = []
95
+
96
+ amount = 500000
97
+ with jsonlines.open(input_file, "r") as reader:
98
+ for obj in reader:
99
+ if amount == 0:
100
+ break
101
+ amount -= 1
102
+ prompt = prompt_template.replace("<<<CODE>>>", obj["code_content"]).replace(
103
+ "<<<README>>>", obj["md_summary"]
104
+ )
105
+ inputs.append(prompt)
106
+
107
+ # ======================
108
+ # 6. 断点续跑
109
+ # ======================
110
+ output_file = "/home/weifengsun/tangou1/step2/step22/output/gemini_results.jsonl"
111
+ completed_idx = set()
112
+ if os.path.exists(output_file):
113
+ with jsonlines.open(output_file, "r") as reader:
114
+ for obj in reader:
115
+ completed_idx.add(obj["idx"])
116
+
117
+ # 只处理未完成的
118
+ tasks = [(idx, prompt) for idx, prompt in enumerate(inputs) if idx not in completed_idx]
119
+ total_tasks = len(inputs)
120
+ remaining_tasks = len(tasks)
121
+ print(f"Total: {total_tasks}, Completed: {len(completed_idx)}, Remaining: {remaining_tasks}")
122
+
123
+ # ======================
124
+ # 7. 并行执行 + 即时写入 + 进度条
125
+ # ======================
126
+ write_lock = threading.Lock()
127
+ MAX_WORKERS = 8
128
+
129
+ with jsonlines.open(output_file, mode="a", flush=True) as writer, ThreadPoolExecutor(max_workers=MAX_WORKERS) as executor:
130
+ futures = {executor.submit(infer_one, prompt, idx): idx for idx, prompt in tasks}
131
+
132
+ pbar = tqdm(total=remaining_tasks, desc="Generating", unit="item")
133
+
134
+ for future in as_completed(futures):
135
+ result = future.result()
136
+
137
+ # 写入 JSONL
138
+ with write_lock:
139
+ writer.write(result)
140
+
141
+ # 更新进度条
142
+ pbar.update(1)
143
+
144
+ # 显示 ETA 与成本
145
+ snap = monitor.snapshot()
146
+ pbar.set_postfix({
147
+ "cost": f"${snap['total_cost']}",
148
+ "in_tok": snap["input_tokens"],
149
+ "out_tok": snap["output_tokens"],
150
+ "elapsed_s": snap["elapsed"]
151
+ })
152
+
153
+ # 超预算停止
154
+ if monitor.stop_event.is_set():
155
+ print("🛑 Budget limit reached. Stopping all requests.")
156
+ break
157
+
158
+ pbar.close()
159
+
160
+ print("✅ All done.")