| import json
|
| import httpx
|
| from datetime import datetime, timedelta, timezone
|
| from typing import Optional, Dict, Any
|
| from urllib.parse import urlparse
|
| from config import Config
|
| from cache_manager import cache
|
|
|
async def get_cid(force: bool = False) -> str:
    """Return the client id (CID), using the cache unless a refresh is forced.

    Args:
        force: Skip ``cache.get_cid()`` and always query ``Config.get_cid_url()``.

    Returns:
        The cached CID, or the non-empty ``cid`` field of the endpoint's JSON
        response (which is then stored via ``cache.set_cid``).

    Raises:
        ValueError: The response carries no usable ``cid`` and ``cache.cid``
            is empty.
        Exception: Any request/HTTP/JSON error is re-raised when
            ``cache.cid`` is empty.
    """
    if not force:
        cached = cache.get_cid()
        if cached:
            return cached

    try:
        url = Config.get_cid_url()
        async with httpx.AsyncClient(timeout=Config.TIMEOUT) as client:
            response = await client.get(url)
            response.raise_for_status()
            data = response.json()

        # Reject non-object payloads and empty/null ids instead of caching
        # them; an empty CID would otherwise be used to build the login URL.
        cid = data.get('cid') if isinstance(data, dict) else None
        if not cid:
            raise ValueError("CID not found in response")

        cache.set_cid(cid)
        return cid
    except Exception:
        # Best effort: fall back to the stored attribute directly (not
        # cache.get_cid()), presumably so an expired value is still usable.
        if cache.cid:
            return cache.cid
        raise
|
|
|
async def get_auth(force: bool = False, retry_count: int = 0) -> Dict[str, Any]:
    """Log in and return the auth dict, using the cache unless forced.

    Args:
        force: Skip ``cache.get_auth()`` and always perform a login.
        retry_count: Internal counter. When a login error message mentions
            "cid", the login is retried (at most twice) with a freshly
            fetched CID.

    Returns:
        ``{'access_token', 'vms_host', 'vms_uid'}`` with ``vms_host`` stripped
        of trailing slashes; the dict is also stored via ``cache.set_auth``.

    Raises:
        ValueError: Login was rejected or returned incomplete auth data, and
            no cached auth could be used.
        Exception: Any request/HTTP/JSON error, when no cached auth could be
            used. ``cache.auth`` is only used as a fallback on the first
            (non-retry) call.
    """
    if not force:
        cached = cache.get_auth()
        if cached:
            return cached

    try:
        # On retries the previous CID was rejected, so force a fresh one.
        cid = await get_cid(force=(retry_count > 0))
        login_url = Config.get_login_url(cid)

        async with httpx.AsyncClient(timeout=Config.TIMEOUT) as client:
            response = await client.get(login_url)
            response.raise_for_status()
            data = response.json()

        if data.get('code') != 'OK':
            # 'message' may be present but null; never call .lower() on None.
            error_msg = str(data.get('message') or 'Unknown error')
            if 'cid' in error_msg.lower() and retry_count < 2:
                return await get_auth(force=True, retry_count=retry_count + 1)
            raise ValueError(f"Login failed: {error_msg}")

        # product_config is normally a JSON-encoded string; tolerate a
        # missing/null value or an already-decoded object.
        raw_config = data.get('product_config') or '{}'
        product_config = json.loads(raw_config) if isinstance(raw_config, str) else raw_config
        if not isinstance(product_config, dict):
            raise ValueError("Incomplete auth data")

        auth = {
            'access_token': data.get('access_token'),
            'vms_host': str(product_config.get('vms_host') or '').rstrip('/'),
            'vms_uid': product_config.get('vms_uid'),
        }

        # Missing fields now surface here as a clear error instead of KeyError.
        if not all(auth.values()):
            raise ValueError("Incomplete auth data")

        cache.set_auth(auth)
        return auth
    except Exception:
        if cache.auth and retry_count == 0:
            return cache.auth
        raise
|
|
|
async def get_channels(auth: Dict[str, Any], force: bool = False, retry_count: int = 0) -> list:
    """Return the channel list, using the cache unless forced.

    Only channels with truthy ``id``, ``no``, ``name`` and ``playpath`` are kept.

    Args:
        auth: Auth dict; ``auth['vms_uid']`` is used to build the list URL.
        force: Skip ``cache.get_channels()`` and always query upstream.
        retry_count: Internal counter bounding token refreshes on 401/403
            (at most two), so a persistently rejected token cannot recurse
            forever.

    Returns:
        The filtered channel list (also stored via ``cache.set_channels``),
        or ``cache.channels`` when the request fails.

    Raises:
        ValueError: No valid channels were returned and nothing is cached.
        Exception: Any request/HTTP/JSON error when nothing is cached.
    """
    if not force:
        cached = cache.get_channels()
        if cached:
            return cached

    try:
        url = Config.get_list_url(auth['vms_uid'], with_epg=False)
        headers = {
            'Referer': Config.REQUIRED_REFERER,
            'User-Agent': 'Mozilla/5.0'
        }

        async with httpx.AsyncClient(timeout=Config.TIMEOUT) as client:
            response = await client.get(url, headers=headers)
            response.raise_for_status()
            data = response.json()

        channels = [
            ch for ch in (data.get('result') or [])
            if ch.get('id') and ch.get('no') and ch.get('name') and ch.get('playpath')
        ]
        if not channels:
            raise ValueError("No channels found")

        cache.set_channels(channels)
        return channels

    except httpx.HTTPStatusError as e:
        # Likely an expired token: refresh and retry, but with a bound.
        # get_auth(force=True) may return the same stale cached token.
        if e.response.status_code in (401, 403) and retry_count < 2:
            new_auth = await get_auth(force=True)
            return await get_channels(new_auth, force=True, retry_count=retry_count + 1)
        # Other HTTP errors (and exhausted retries) get the same cache
        # fallback as network errors.
        if cache.channels:
            return cache.channels
        raise

    except Exception:
        if cache.channels:
            return cache.channels
        raise
|
|
|
async def fetch_epg(vid: str, date: str, auth: dict, retry_count: int = 0) -> list:
    """Get the EPG of one channel for one JST day, preferring the cache.

    The upstream response holds programmes for several days. Every day found
    is cached (sorted by start time) under ``(vid, day)``, so later calls for
    other days are served from the cache.

    Args:
        vid: Channel id passed to ``Config.get_epg_url``.
        date: Day to return. It is compared against keys produced by
            ``get_jst_date`` (``YYYY-MM-DD`` in JST).
        auth: Auth dict; ``auth['vms_uid']`` is used to build the URL.
        retry_count: Internal counter. A 401/403 response triggers at most
            two token refreshes via ``get_auth(force=True)``.

    Returns:
        A new list of that day's programmes sorted by ``time``. Returns
        ``[]`` when there are none; the empty result is cached too.

    Raises:
        Any request, HTTP-status or JSON-decoding error is propagated.
    """
    cached = cache.get_epg(vid, date)
    if cached is not None:
        return cached

    url = Config.get_epg_url(auth['vms_uid'], vid)
    headers = {
        'Referer': Config.REQUIRED_REFERER,
        'User-Agent': 'Mozilla/5.0'
    }
    async with httpx.AsyncClient(timeout=Config.TIMEOUT) as client:
        response = await client.get(url, headers=headers)

    # Retry outside the client context so the connection is released before
    # the recursive call opens a new one.
    if response.status_code in (401, 403) and retry_count < 2:
        new_auth = await get_auth(force=True)
        return await fetch_epg(vid, date, new_auth, retry_count + 1)

    response.raise_for_status()
    data = response.json()

    results = data.get('result')
    if not results or not results[0].get('record_epg'):
        cache.set_epg(vid, date, [])
        return []

    # record_epg is itself a JSON-encoded list of programme dicts.
    full_epg = json.loads(results[0]['record_epg'])

    daily_epg = {}
    for i, program in enumerate(full_epg):
        if not program.get('time'):
            continue
        if not program.get('time_end'):
            # Missing end time: borrow the next entry's start time. The last
            # entry, or one followed by an untimed entry, is dropped.
            if i + 1 < len(full_epg) and full_epg[i + 1].get('time'):
                program['time_end'] = full_epg[i + 1]['time']
            else:
                continue
        # 'time' is presumably a Unix timestamp in seconds. Build an aware
        # UTC datetime so the JST day does not depend on the host time zone.
        dt = datetime.fromtimestamp(program['time'], tz=timezone.utc)
        daily_epg.setdefault(get_jst_date(dt), []).append(program)

    for day, programs in daily_epg.items():
        cache.set_epg(vid, day, sorted(programs, key=lambda p: p['time']))

    programs_for_date = daily_epg.get(date)
    if not programs_for_date:
        cache.set_epg(vid, date, [])
        return []
    # Return a list separate from the cached one.
    return sorted(programs_for_date, key=lambda p: p['time'])
|
|
|
|
|
async def get_all_epg(
    auth: Dict[str, Any], force: bool = False, retry_count: int = 0
) -> Dict[str, list]:
    """Fetch EPG data for every channel in one request, preferring the cache.

    The combined result is cached under ``('_all_', 'full')``. Each channel's
    programmes are also split by JST day and cached under
    ``(channel_id, day)``, sorted by start time.

    Failures while fetching are not raised. The function falls back to the
    cached combined result, or ``{}`` if there is none.

    Args:
        auth: Auth dict; ``auth['vms_uid']`` is used to build the list URL.
        force: Skip the cached combined result and always query upstream.
        retry_count: Internal counter bounding token refreshes on 401/403
            (at most two).

    Returns:
        Mapping of channel id to that channel's programmes, in upstream order
        with ``time_end`` filled in. Channels whose EPG is missing or
        malformed map to ``[]``.
    """
    if not force:
        cached = cache.get_epg('_all_', 'full')
        if cached:
            return cached

    try:
        url = Config.get_list_url(auth['vms_uid'], with_epg=True)
        headers = {
            'Referer': Config.REQUIRED_REFERER,
            'User-Agent': 'Mozilla/5.0'
        }
        async with httpx.AsyncClient(timeout=Config.TIMEOUT) as client:
            response = await client.get(url, headers=headers)

        # Same bounded token-refresh policy as the other fetchers.
        if response.status_code in (401, 403) and retry_count < 2:
            new_auth = await get_auth(force=True)
            return await get_all_epg(new_auth, force=True, retry_count=retry_count + 1)

        response.raise_for_status()
        data = response.json()

        result = {}
        for channel in data.get('result') or []:
            channel_id = channel.get('id')
            if not channel_id:
                continue

            record_epg = channel.get('record_epg')
            if not record_epg:
                result[channel_id] = []
                continue

            try:
                # record_epg is a JSON-encoded list of programme dicts.
                epg_list = json.loads(record_epg)

                processed_programs = []
                daily_epg = {}
                for i, program in enumerate(epg_list):
                    if not program.get('time'):
                        continue
                    if not program.get('time_end'):
                        # Borrow the next entry's start time. The last entry,
                        # or one followed by an untimed entry, is dropped.
                        if i + 1 < len(epg_list) and epg_list[i + 1].get('time'):
                            program['time_end'] = epg_list[i + 1]['time']
                        else:
                            continue
                    processed_programs.append(program)
                    # Aware UTC datetime: the JST day must not depend on the
                    # host time zone.
                    dt = datetime.fromtimestamp(program['time'], tz=timezone.utc)
                    daily_epg.setdefault(get_jst_date(dt), []).append(program)

                sorted_days = {
                    day: sorted(programs, key=lambda p: p['time'])
                    for day, programs in daily_epg.items()
                }
            except (ValueError, TypeError, AttributeError, OverflowError, OSError):
                # One malformed channel (bad JSON, non-string payload,
                # non-dict entry, out-of-range timestamp) must not discard
                # every other channel's data.
                result[channel_id] = []
                continue

            for day, programs in sorted_days.items():
                cache.set_epg(channel_id, day, programs)
            result[channel_id] = processed_programs

        cache.set_epg('_all_', 'full', result)
        return result

    except Exception:
        # Deliberate best effort: serve whatever was cached last, else nothing.
        cached = cache.get_epg('_all_', 'full')
        if cached:
            return cached
        return {}
|
|
|
|
|
def get_jst_date(dt: Optional[datetime] = None) -> str:
    """Format the Japan Standard Time (UTC+9) calendar date of *dt*.

    Args:
        dt: Moment to convert; defaults to the current time. Aware datetimes
            are converted from their own zone. Naive ones are interpreted as
            system local time, which is what ``datetime.astimezone`` does
            with naive values.

    Returns:
        The JST date as ``YYYY-MM-DD``.
    """
    tokyo = timezone(timedelta(hours=9))
    moment = datetime.now(tokyo) if dt is None else dt.astimezone(tokyo)
    return moment.strftime('%Y-%m-%d')
|
|
|
def rewrite_m3u8(content: str, current_path: str, worker_base: str) -> str:
    """Point every URI line of an M3U8 playlist at the worker.

    Blank lines and lines whose stripped text starts with ``#`` are kept
    verbatim. Every other line becomes ``worker_base + path``, where ``path``
    is:

    * for an absolute ``http://``/``https://`` URL, its path plus any
      non-empty query (the original scheme and host are dropped);
    * a root-relative line (starting with ``/``) as-is;
    * otherwise, the line appended to the directory of ``current_path``.

    If that path has no ``?``, the query string of ``current_path`` (if any)
    is appended, so request parameters propagate to child URIs.

    Args:
        content: Playlist text, split on ``\\n``.
        current_path: Path (optionally with ``?query``) the playlist was
            requested at.
        worker_base: Prefix prepended to every rewritten path.

    Returns:
        The rewritten playlist, joined with ``\\n``.
    """
    head, sep, tail = current_path.rpartition('?')
    if sep:
        path_only, inherited_query = head, tail
    else:
        path_only, inherited_query = current_path, ''
    directory = path_only[:path_only.rfind('/') + 1]

    def _rewrite(raw: str) -> str:
        uri = raw.strip()
        if not uri or uri.startswith('#'):
            return raw

        if uri.startswith(('http://', 'https://')):
            pieces = urlparse(uri)
            target = pieces.path + (f"?{pieces.query}" if pieces.query else '')
        elif uri.startswith('/'):
            target = uri
        else:
            target = directory + uri

        if inherited_query and '?' not in target:
            target = f"{target}?{inherited_query}"
        return worker_base + target

    return '\n'.join(_rewrite(raw) for raw in content.split('\n'))
|
|
|
def extract_playlist_url(content: str, base_url: str) -> Optional[str]:
    """Return the first playlist URI found in an M3U8 document.

    Blank lines and ``#`` tag/comment lines are skipped. The first remaining
    line that is an absolute ``http://``/``https://`` URL is returned as-is.
    A relative line is returned only if its path ends in ``.m3u8`` (any
    case, ignoring a query string or fragment); it is resolved against
    ``base_url``.

    Args:
        content: Playlist text, split on ``\\n``.
        base_url: URL the playlist was fetched from, used to resolve
            relative URIs.

    Returns:
        The resolved URL, or ``None`` if no suitable line exists.
    """
    # Local import: the module header only brings in urlparse.
    from urllib.parse import urljoin

    for line in content.split('\n'):
        trimmed = line.strip()
        if not trimmed or trimmed.startswith('#'):
            continue

        if trimmed.startswith(('http://', 'https://')):
            return trimmed

        # Ignore any query/fragment when checking the extension, so that
        # 'chunklist.m3u8?token=...' is recognised as a playlist.
        path_part = trimmed.split('?', 1)[0].split('#', 1)[0]
        if path_part.lower().endswith('.m3u8'):
            # urljoin correctly handles root-relative, '../', protocol-relative
            # URIs, and base URLs with no path or with a query string.
            return urljoin(base_url, trimmed)

    return None