| | import os
|
| | import openai
|
| | import threading
|
| | from concurrent.futures import ThreadPoolExecutor
|
| | from openai import APIError
|
| |
|
# API key for the DeepSeek endpoint; falls back to a placeholder if the env var is unset.
API_KEY = os.getenv("DEEPSEEK_API_KEY", "your_api_key")
|
| |
|
class ThreadSafeWriter:
    """Thread-safe, append-mode line writer with a processed-line counter.

    All writes are serialized through a lock so multiple worker threads can
    share one instance. Each line is flushed immediately after writing so
    partial progress survives a crash mid-run.
    """

    def __init__(self, output_path: str):
        # 'a' (append, write-only) replaces the original 'a+': the read
        # capability was never used anywhere in this class.
        self.file = open(output_path, 'a', encoding='utf-8')
        self.lock = threading.Lock()
        # Number of lines successfully written; guarded by self.lock.
        self.counter = 0

    def __enter__(self):
        """Support `with ThreadSafeWriter(path) as w:` so the file is always closed."""
        return self

    def __exit__(self, exc_type, exc, tb):
        # Always close; never suppress the caller's exception.
        self.close()
        return False

    def write_line(self, content: str) -> None:
        """Append `content` plus a newline, flush, and bump the counter atomically."""
        with self.lock:
            self.file.write(content + '\n')
            # Flush per line so progress is on disk even if the process dies.
            self.file.flush()
            self.counter += 1

    def get_progress(self) -> int:
        """Return the number of lines written so far (lock-protected read)."""
        with self.lock:
            return self.counter

    def close(self) -> None:
        """Close the underlying file; safe to call more than once (idempotent)."""
        if not self.file.closed:
            self.file.close()
|
| |
|
class DeepSeekBatchProcessor:
    """Fan chat-completion requests for a batch of input lines out over a thread pool.

    A shared `threading.Event` aborts all workers once a fatal error occurs
    (e.g. API balance exhausted), and a `threading.Semaphore` caps in-flight
    API calls at 20 regardless of `max_workers`.
    """

    def __init__(self, max_workers: int = 100):
        # Client targets the DeepSeek OpenAI-compatible endpoint.
        self.client = openai.OpenAI(
            api_key=API_KEY,
            base_url="https://api.deepseek.com/v1"
        )
        self.max_workers = max_workers
        # Set by any worker on a fatal error; all others observe it and stop early.
        self.error_flag = threading.Event()
        # Limits concurrent API requests to 20 even when max_workers is larger.
        self.rate_limiter = threading.Semaphore(20)

    def process_batch(self, batch, writer: ThreadSafeWriter):
        """Process a batch of (line_num, line) pairs, one thread-pool task each.

        Stops submitting new tasks as soon as the error flag is set; already
        submitted tasks check the flag themselves and exit quickly.
        """
        futures = []
        with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
            for line_num, line in batch:
                if self.error_flag.is_set():
                    # Fatal error seen elsewhere; do not queue further work.
                    break
                futures.append(
                    executor.submit(
                        self._process_single_line,
                        line_num,
                        line,
                        writer
                    )
                )
            # Propagate any uncaught worker exception (most errors are handled
            # inside _process_single_line, so this usually just waits).
            for future in futures:
                future.result()

    def _process_single_line(self, line_num: int, line: str, writer: ThreadSafeWriter):
        """Parse one input line, query the model (with up to 3 attempts), write the result.

        Expected line format: "<keywords><colon><original text>", where the
        colon may be fullwidth (U+FF1A) or ASCII and keywords are separated by
        fullwidth commas.
        """
        if self.error_flag.is_set():
            return

        # Prefer the fullwidth colon ':' (U+FF1A); fall back to ASCII ':'.
        separator = None
        if ':' in line:
            separator = ':'
        elif ':' in line:
            separator = ':'

        if not separator:
            # Malformed line: record it in the output so no input is silently lost.
            print(f"\n行 {line_num} 格式错误")
            writer.write_line(f"格式错误:{line}")
            return

        # Split only on the first separator; the text itself may contain colons.
        keywords_part, original_text = line.split(separator, 1)

        # Keywords are fullwidth-comma separated; drop empty entries.
        keywords = [kw.strip() for kw in keywords_part.split(",") if kw.strip()]
        if not keywords:
            keywords = ["无关键词"]

        prompt = "请根据以下关键词写一首诗:" + ",".join(keywords)
        messages = [{"role": "user", "content": prompt}]

        # Up to 3 attempts; the loop also exits early on a global fatal error.
        retries = 0
        while retries < 3 and not self.error_flag.is_set():
            try:
                # Semaphore held across the whole API call to bound concurrency.
                with self.rate_limiter:
                    response = self.client.chat.completions.create(
                        model="deepseek-reasoner",
                        messages=messages,
                        temperature=0.1
                    )

                # NOTE(review): assumes the reasoner model always returns a
                # non-None `reasoning_content`; if it is None the AttributeError
                # falls into the generic except below and is retried — confirm.
                reasoning_content = response.choices[0].message.reasoning_content.replace('\n', '').replace('\r', '')
                # Newlines in the poem become '/' so the record stays one line.
                poem_original = response.choices[0].message.content.replace('\n', '/').replace('\r', '')

                final_line = f"{','.join(keywords)}<think>{reasoning_content}</think>:{poem_original}"
                writer.write_line(final_line)
                progress = writer.get_progress()
                # '\r' keeps the progress on one console line.
                print(f"\r已处理 {progress} 条", end='')
                break

            except APIError as e:
                if e.status_code == 402:
                    # Insufficient balance: fatal for the whole run, flag everyone.
                    print(f"\n行 {line_num} 处理失败:API余额不足")
                    self.error_flag.set()
                    return
                elif e.status_code == 429:
                    # Rate limited: retry this line (no backoff sleep here).
                    print(f"\n行 {line_num} 速率受限,重试中...")
                    retries += 1
                    if retries >= 3:
                        print(f"\n行 {line_num} 重试次数耗尽")
                else:
                    # NOTE(review): other API errors return WITHOUT writing a
                    # failure record, unlike format errors and exhausted
                    # retries — the line vanishes from the output. Confirm
                    # whether this asymmetry is intentional.
                    print(f"\n行 {line_num} API错误[{e.status_code}]:{e.message}")
                    return

            except Exception as e:
                # Network hiccups, parse failures, etc. — retried like a 429.
                print(f"\n行 {line_num} 处理异常:{str(e)}")
                retries += 1
                if retries >= 3:
                    print(f"\n行 {line_num} 重试次数耗尽")

        # Retries exhausted (and no global fatal error): record the failure.
        if retries >= 3 and not self.error_flag.is_set():
            writer.write_line(f"处理失败:{line}")
|
def process_first_1000_lines(input_path: str, output_path: str, max_workers: int = 100, limit: int = 1000):
    """Read at most the first `limit` lines of `input_path` and process the
    non-blank ones concurrently, appending results to `output_path`.

    Args:
        input_path: UTF-8 text file, one task per line.
        output_path: results file (opened in append mode by the writer).
        max_workers: thread-pool size for the batch processor.
        limit: maximum number of input lines to read (default 1000,
            preserving the original function's behavior and name).
    """
    processor = DeepSeekBatchProcessor(max_workers)
    writer = ThreadSafeWriter(output_path)
    batch = []
    try:
        with open(input_path, 'r', encoding='utf-8') as f:
            for line_num, line in enumerate(f, 1):
                # Bug fix: check the limit BEFORE the blank-line skip. The
                # original `continue` jumped past the `>= 1000` stop check,
                # so a blank line at the boundary let the loop read (and
                # batch) lines beyond the intended limit.
                if line_num > limit:
                    break
                stripped = line.strip()
                if stripped:
                    batch.append((line_num, stripped))

        total = len(batch)
        print(f"总数据量:{total} 条")
        processor.process_batch(batch, writer)
        print("\n处理完成!")
    finally:
        # Writer is closed even if processing raised mid-batch.
        writer.close()
|
| |
|
if __name__ == '__main__':
    # Script entry point: process the first 1000 lines of the raw data file
    # and append the generated records to the CoT output file.
    process_first_1000_lines(
        "data/DSdata.txt",
        "data/CoTdata.txt",
        max_workers=100,
    )
|
| |
|