|
|
|
|
|
import os |
|
|
from queue import Queue |
|
|
from threading import Thread |
|
|
from typing import Any, Dict, List, Literal, Union |
|
|
|
|
|
import json |
|
|
import requests |
|
|
import torch.distributed as dist |
|
|
from accelerate.utils import gather_object |
|
|
from modelscope.hub.api import ModelScopeConfig |
|
|
from tqdm import tqdm |
|
|
|
|
|
from .env import is_master |
|
|
from .logger import get_logger |
|
|
from .utils import check_json_format |
|
|
|
|
|
# Module-level logger from `.logger.get_logger`; JsonlWriter uses it to report write failures.
logger = get_logger()
|
|
|
|
|
|
|
|
def download_ms_file(url: str, local_path: str, cookies=None) -> None:
    """Download ``url`` to ``local_path``, streaming the response body to disk.

    Args:
        url: URL of the file to download.
        local_path: Destination path; the file is created or overwritten.
        cookies: Cookies passed to ``requests.get``. When None, the cookies
            returned by ``ModelScopeConfig.get_cookies()`` are used.

    Raises:
        requests.HTTPError: If the server answers with an error status, so that
            an error page is never saved as if it were the requested file.
    """
    if cookies is None:
        cookies = ModelScopeConfig.get_cookies()
    chunk_size = 1024 * 1024
    # `with` closes the streamed response (and releases its connection) even on error.
    with requests.get(url, cookies=cookies, stream=True) as resp:
        resp.raise_for_status()
        with open(local_path, 'wb') as f:
            # `iter_content` yields the raw bytes unchanged. The previous `iter_lines`
            # split on newlines and dropped them, corrupting the written file.
            for data in tqdm(resp.iter_content(chunk_size=chunk_size)):
                if data:  # skip keep-alive empty chunks
                    f.write(data)
|
|
|
|
|
|
|
|
def read_from_jsonl(fpath: str, encoding: str = 'utf-8') -> List[Any]:
    """Read a JSON Lines file and return the decoded objects in file order.

    Empty or whitespace-only lines (for example a stray blank line at the end
    of the file) are skipped instead of raising ``json.JSONDecodeError``.

    Args:
        fpath: Path of the jsonl file to read.
        encoding: Text encoding of the file.

    Returns:
        One decoded JSON value per non-blank line.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    with open(fpath, 'r', encoding=encoding) as f:
        return [json.loads(line) for line in f if line.strip()]
|
|
|
|
|
|
|
|
def write_to_jsonl(fpath: str, obj_list: List[Any], encoding: str = 'utf-8') -> None:
    """Write ``obj_list`` to ``fpath`` as JSON Lines, replacing any existing file.

    Each object is serialized with ``ensure_ascii=False`` and terminated by a
    newline. An empty ``obj_list`` therefore produces an empty file rather than
    a lone blank line.

    Args:
        fpath: Destination path; the file is truncated and rewritten.
        obj_list: JSON-serializable objects, one per output line.
        encoding: Text encoding used for the file.
    """
    # Serialize everything before opening the file, so a non-serializable object
    # raises without truncating the existing target file.
    lines = [json.dumps(obj, ensure_ascii=False) + '\n' for obj in obj_list]
    with open(fpath, 'w', encoding=encoding) as f:
        f.writelines(lines)
|
|
|
|
|
|
|
|
class JsonlWriter:
    """Append JSON objects, one per line, to a jsonl file.

    Only the master process (as reported by ``is_master()``) writes. On other
    processes ``fpath`` is stored as ``None`` and ``_append`` returns before
    writing. With ``enable_async=True``, ``append`` hands work to a single
    background daemon thread through a queue. Call :meth:`flush` to block until
    every queued item has been processed.

    Args:
        fpath: Target file path (``~`` is expanded and the path made absolute).
        encoding: Text encoding used when writing.
        strict: If True, write errors are re-raised; otherwise they are logged.
        enable_async: If True, ``append`` enqueues work for a background thread.
    """

    def __init__(self, fpath: str, *, encoding: str = 'utf-8', strict: bool = True, enable_async: bool = False):
        self.fpath = os.path.abspath(os.path.expanduser(fpath)) if is_master() else None
        self.encoding = encoding
        self.strict = strict
        self.enable_async = enable_async
        self._queue = Queue()
        self._thread = None

    def _append_worker(self):
        """Background loop: consume queued ``append`` calls forever."""
        while True:
            item = self._queue.get()
            try:
                self._append(**item)
            except Exception as e:
                # An uncaught exception would kill this thread, after which every
                # later queued item would silently never be written. Log and keep going.
                logger.error(f'JsonlWriter: failed to append item in background thread: {e!r}')
            finally:
                # Pairs with `Queue.join()` in `flush`.
                self._queue.task_done()

    def _append(self, obj: Union[Dict, List[Dict]], gather_obj: bool = False):
        """Serialize ``obj`` (a dict, or a list/tuple of dicts) and write it.

        A list/tuple whose items are all dicts is written one line per item. Any
        other value is written as a single line. When ``gather_obj`` is set and
        torch.distributed is initialized, the objects are first collected with
        ``gather_object``. Only the master process writes.
        """
        if isinstance(obj, (list, tuple)) and all(isinstance(item, dict) for item in obj):
            obj_list = obj
        else:
            obj_list = [obj]
        if gather_obj and dist.is_initialized():
            obj_list = gather_object(obj_list)
        if not is_master():
            return
        obj_list = check_json_format(obj_list)
        # Build the output text without overwriting `obj_list` in place, so a
        # caller-owned list is never mutated and immutable sequences also work.
        text = ''.join(json.dumps(_obj, ensure_ascii=False) + '\n' for _obj in obj_list)
        self._write_buffer(text)

    def append(self, obj: Union[Dict, List[Dict]], gather_obj: bool = False):
        """Append ``obj``: synchronously, or via the background thread when async."""
        if self.enable_async:
            if self._thread is None:
                self._thread = Thread(target=self._append_worker, daemon=True)
                self._thread.start()
            self._queue.put({'obj': obj, 'gather_obj': gather_obj})
        else:
            self._append(obj, gather_obj=gather_obj)

    def flush(self) -> None:
        """Block until every item queued by async ``append`` has been processed.

        Does nothing when no background thread has been started (synchronous
        mode, or async mode before the first ``append``).
        """
        if self._thread is not None:
            self._queue.join()

    def _write_buffer(self, text: str):
        """Append ``text`` to ``self.fpath``, creating parent directories as needed.

        With ``strict`` enabled, write errors propagate; otherwise they are logged.
        """
        if not text:
            return
        assert is_master(), f'is_master(): {is_master()}'
        try:
            os.makedirs(os.path.dirname(self.fpath), exist_ok=True)
            with open(self.fpath, 'a', encoding=self.encoding) as f:
                f.write(text)
        except Exception:
            if self.strict:
                raise
            logger.error(f'Cannot write content to jsonl file. text: {text}')
|
|
|
|
|
|
|
|
def append_to_jsonl(fpath: str, obj: Union[Dict, List[Dict]], *, encoding: str = 'utf-8', strict: bool = True) -> None:
    """Append ``obj`` to the jsonl file at ``fpath`` in one call.

    Convenience wrapper: builds a throwaway ``JsonlWriter`` with default
    (synchronous) settings and appends ``obj`` to it without gathering.

    Args:
        fpath: Target jsonl file.
        obj: A dict, or a list of dicts (one line per dict).
        encoding: Text encoding used when writing.
        strict: If True, write errors are re-raised; otherwise they are logged.
    """
    writer_options = {'encoding': encoding, 'strict': strict}
    JsonlWriter(fpath, **writer_options).append(obj)
|
|
|
|
|
|
|
|
def get_file_mm_type(file_name: str) -> Literal['image', 'video', 'audio']:
    """Classify a media file as ``'video'``, ``'audio'`` or ``'image'`` by extension.

    Only the final extension is inspected, and the comparison is
    case-insensitive.

    Args:
        file_name: Path or bare file name.

    Returns:
        The media type associated with the file's extension.

    Raises:
        ValueError: If the extension is missing or not recognised.
    """
    ext = os.path.splitext(file_name)[1]
    suffix = ext.lower()
    # Checked in this order; the extension sets do not overlap.
    known_types = (
        ('video', {'.mp4', '.mkv', '.mov', '.avi', '.wmv', '.flv', '.webm'}),
        ('audio', {'.mp3', '.wav', '.aac', '.flac', '.ogg', '.m4a'}),
        ('image', {'.jpg', '.jpeg', '.png', '.gif', '.bmp', '.tiff', '.webp'}),
    )
    for mm_type, extensions in known_types:
        if suffix in extensions:
            return mm_type
    raise ValueError(f'file_name: {file_name}, ext: {ext}')
|
|
|