# NewContest / model3pLOCAL.py
# ffzeroHua's picture
# Upload 7 files
# 201fbbb verified
import json
import gzip
import torch
import pathlib
import requests
import traceback
import numpy as np
from torch import nn, Tensor
from torch.nn import functional as F
from torch.nn.utils.rnn import pack_padded_sequence, pad_sequence
from torch.distributions import Normal, Categorical
from typing import *
from functools import partial
from itertools import permutations
# Import the native `libriichi3p` extension. If the regular import fails, fall
# back to loading the compiled shared object from a fixed absolute path.
try:
    from libriichi3p.mjai import Bot
    from libriichi3p.consts import obs_shape, oracle_obs_shape, ACTION_SPACE, GRP_SIZE
except ImportError:
    import importlib.util
    import sys
    import os

    # WARNING: this must be the absolute path of the extension inside your
    # Colab environment. The default assumes the file lives in the
    # "MahjongTest" folder of the mounted drive and is named libriichi3p.so;
    # edit this line if your file has another name or location.
    SO_FILE_PATH = "/content/drive/MyDrive/MahjongTest/libriichi3p.so"

    # 1. Check whether the file exists at all.
    if not os.path.exists(SO_FILE_PATH):
        print(f"❌ 致命错误:在路径 {SO_FILE_PATH} 下根本找不到文件!请检查路径拼写。")
    else:
        print(f"✅ 找到文件: {SO_FILE_PATH},正在尝试强行加载...")
        try:
            # 2. Build a module spec from the absolute path. The first argument
            #    is the module name Python will know it by, the second is the
            #    file to load.
            spec = importlib.util.spec_from_file_location("libriichi3p", SO_FILE_PATH)
            # 3. Instantiate the module object.
            libriichi3p_module = importlib.util.module_from_spec(spec)
            # 4. Register it in sys.modules so that any later
            #    `import libriichi3p` (here or in other files) resolves to it.
            sys.modules["libriichi3p"] = libriichi3p_module
            # 5. Run the module's initialisation code.
            spec.loader.exec_module(libriichi3p_module)
            # 6. Bind the names the rest of this file relies on. Without this
            #    step the fallback would "succeed" yet leave Bot, obs_shape,
            #    etc. undefined, and the definitions below would fail with
            #    NameError.
            from libriichi3p.mjai import Bot
            from libriichi3p.consts import obs_shape, oracle_obs_shape, ACTION_SPACE, GRP_SIZE
            print("🎉 强行导入成功!现在可以在代码里正常使用了。")
        except Exception as e:
            # Best effort: report the underlying error instead of crashing here.
            print(f"❌ 导入失败,暴露出真实报错: {e}")
# ========== Online Server =========== #
# Timeout passed to `requests.post` when querying the remote inference server.
OT_REQUEST_TIMEOUT = 2

# Defaults for the remote inference server; entries may be overridden by an
# `ot_settings.json` file placed next to this module.
ot_settings = {
    "server": "http://example.com",
    "online": False,
    "api_key": "example_api_key",
}

# Updated by MortalEngine.react_batch: True after a successful remote request,
# False after a failed one.
is_online = False


def online_settings_init():
    """Overlay `ot_settings.json` (if present next to this file) on the defaults.

    The file's top-level keys are merged into the module-level `ot_settings`
    dict in place, so keys missing from the file keep their default values
    instead of disappearing. A missing file leaves the defaults untouched.

    Raises:
        ValueError: if the file does not contain a JSON object.
        json.JSONDecodeError: if the file is not valid JSON.
    """
    settings_path = pathlib.Path(__file__).parent / 'ot_settings.json'
    try:
        with open(settings_path, 'r', encoding='utf-8') as f:
            loaded = json.load(f)
    except FileNotFoundError:
        return
    if not isinstance(loaded, dict):
        raise ValueError(
            f'{settings_path} must contain a JSON object, got {type(loaded).__name__}'
        )
    ot_settings.update(loaded)


online_settings_init()
# ==================================== #
class ChannelAttention(nn.Module):
    """Re-weight channels with a gate computed from pooled statistics.

    The input is pooled over its last axis twice (mean and max). Both pooled
    vectors go through the same two-layer MLP, the results are summed and
    squashed with a sigmoid, and the resulting per-channel weight multiplies
    the input.

    Args:
        channels: size of the channel axis (second-to-last axis of the input).
        ratio: reduction factor for the MLP's hidden width (`channels // ratio`).
        actv_builder: zero-argument callable producing the hidden activation.
        bias: whether the linear layers carry a bias; biases start at zero.
    """

    def __init__(self, channels, ratio=16, actv_builder=nn.ReLU, bias=True):
        super().__init__()
        hidden_width = channels // ratio
        self.shared_mlp = nn.Sequential(
            nn.Linear(channels, hidden_width, bias=bias),
            actv_builder(),
            nn.Linear(hidden_width, channels, bias=bias),
        )
        if not bias:
            return
        linear_layers = (m for m in self.modules() if isinstance(m, nn.Linear))
        for layer in linear_layers:
            nn.init.zeros_(layer.bias)

    def forward(self, x: Tensor):
        """Return `x` scaled per channel by the attention gate."""
        pooled_mean = x.mean(-1)
        pooled_max = x.amax(-1)
        gate = (self.shared_mlp(pooled_mean) + self.shared_mlp(pooled_max)).sigmoid()
        return gate.unsqueeze(-1) * x
class ResBlock(nn.Module):
    """Residual unit: two 3-wide Conv1d layers, channel attention, skip add.

    With `pre_actv=True` each convolution is preceded by norm + activation and
    nothing is applied after the skip connection. Otherwise the layout is
    conv-norm-actv-conv-norm and an activation follows the skip addition.

    Args:
        channels: channel count, identical for input and output.
        norm_builder: zero-argument callable producing a normalisation layer.
        actv_builder: zero-argument callable producing an activation layer.
        pre_actv: selects the pre-activation layout described above.
    """

    def __init__(
        self,
        channels,
        *,
        norm_builder = nn.Identity,
        actv_builder = nn.ReLU,
        pre_actv = False,
    ):
        super().__init__()
        self.pre_actv = pre_actv

        def conv3():
            # Length-preserving convolution without bias.
            return nn.Conv1d(channels, channels, kernel_size=3, padding=1, bias=False)

        if pre_actv:
            stages = [
                norm_builder(), actv_builder(), conv3(),
                norm_builder(), actv_builder(), conv3(),
            ]
        else:
            stages = [
                conv3(), norm_builder(), actv_builder(),
                conv3(), norm_builder(),
            ]
        self.res_unit = nn.Sequential(*stages)
        if not pre_actv:
            # Only the post-activation layout needs a trailing activation.
            self.actv = actv_builder()
        self.ca = ChannelAttention(channels, actv_builder=actv_builder, bias=True)

    def forward(self, x):
        """Apply the residual branch, gate it, and add the skip connection."""
        gated = self.ca(self.res_unit(x))
        summed = gated + x
        if self.pre_actv:
            return summed
        return self.actv(summed)
class ResNet(nn.Module):
    """Conv1d residual tower followed by a flattening linear projection.

    Layout: stem convolution, `num_blocks` ResBlocks (with norm + activation
    placed before or after the tower depending on `pre_actv`), a convolution
    down to 32 channels, an activation, flatten, and a linear layer to 1024
    features. The final linear layer expects 32 * 34 inputs, so the input's
    last axis must have length 34 (all convolutions preserve length).

    Args:
        in_channels: channel count of the input.
        conv_channels: channel count used throughout the tower.
        num_blocks: number of ResBlocks.
        norm_builder: zero-argument callable producing a normalisation layer.
        actv_builder: zero-argument callable producing an activation layer.
        pre_actv: use the pre-activation layout.
    """

    def __init__(
        self,
        in_channels,
        conv_channels,
        num_blocks,
        *,
        norm_builder = nn.Identity,
        actv_builder = nn.ReLU,
        pre_actv = False,
    ):
        super().__init__()
        tower = [
            ResBlock(
                conv_channels,
                norm_builder = norm_builder,
                actv_builder = actv_builder,
                pre_actv = pre_actv,
            )
            for _ in range(num_blocks)
        ]
        stem = nn.Conv1d(in_channels, conv_channels, kernel_size=3, padding=1, bias=False)
        if pre_actv:
            body = [stem, *tower, norm_builder(), actv_builder()]
        else:
            body = [stem, norm_builder(), actv_builder(), *tower]
        head = [
            nn.Conv1d(conv_channels, 32, kernel_size=3, padding=1),
            actv_builder(),
            nn.Flatten(),
            nn.Linear(32 * 34, 1024),
        ]
        self.net = nn.Sequential(*body, *head)

    def forward(self, x):
        """Run the whole stack and return the 1024-feature projection."""
        return self.net(x)
class Brain(nn.Module):
    """Observation encoder built on ResNet, with version-specific details.

    Version 1 uses ReLU, the post-activation layout and an extra latent head
    returning `(mu, logsig)`. Versions 2-4 use Mish with the pre-activation
    layout and return the activated encoder output; versions 3 and 4
    additionally use `eps=1e-3` in BatchNorm.

    Args:
        conv_channels: channel width of the residual tower.
        num_blocks: number of residual blocks.
        is_oracle: if True, `invisible_obs` is concatenated to `obs` along
            dim 1 and the input channel count grows accordingly.
        version: model version, one of 1, 2, 3, 4.

    Raises:
        ValueError: for any other version.
    """

    def __init__(self, *, conv_channels, num_blocks, is_oracle=False, version=1):
        super().__init__()
        self.is_oracle = is_oracle
        self.version = version

        in_channels = obs_shape(version)[0]
        if is_oracle:
            in_channels += oracle_obs_shape(version)[0]

        # Defaults shared by versions 2-4; version 1 overrides them below.
        bn_kwargs = {'momentum': 0.01}
        actv_builder = partial(nn.Mish, inplace=True)
        pre_actv = True

        if version == 1:
            actv_builder = partial(nn.ReLU, inplace=True)
            pre_actv = False
            self.latent_net = nn.Sequential(
                nn.Linear(1024, 512),
                nn.ReLU(inplace=True),
            )
            self.mu_head = nn.Linear(512, 512)
            self.logsig_head = nn.Linear(512, 512)
        elif version == 2:
            pass
        elif version == 3 or version == 4:
            bn_kwargs['eps'] = 1e-3
        else:
            raise ValueError(f'Unexpected version {self.version}')

        norm_builder = partial(nn.BatchNorm1d, conv_channels, **bn_kwargs)

        self.encoder = ResNet(
            in_channels = in_channels,
            conv_channels = conv_channels,
            num_blocks = num_blocks,
            norm_builder = norm_builder,
            actv_builder = actv_builder,
            pre_actv = pre_actv,
        )
        self.actv = actv_builder()

        # When True, BatchNorm layers are kept in eval mode (running
        # statistics are used) even while the module is training.
        self._freeze_bn = False

    def forward(self, obs: Tensor, invisible_obs: Optional[Tensor] = None) -> Union[Tuple[Tensor, Tensor], Tensor]:
        """Encode observations.

        Returns `(mu, logsig)` for version 1 and the activated feature tensor
        for versions 2-4. Dropout (p=0.1) is applied to the encoder output in
        training mode only.
        """
        if self.is_oracle:
            assert invisible_obs is not None
            obs = torch.cat((obs, invisible_obs), dim=1)

        phi = F.dropout(self.encoder(obs), p=0.1, training=self.training)

        if self.version == 1:
            hidden = self.latent_net(phi)
            return self.mu_head(hidden), self.logsig_head(hidden)
        if self.version in (2, 3, 4):
            return self.actv(phi)
        raise ValueError(f'Unexpected version {self.version}')

    def train(self, mode=True):
        """Set train/eval mode, then force BatchNorm to eval if BN is frozen."""
        super().train(mode)
        if not self._freeze_bn:
            return self
        for mod in self.modules():
            if isinstance(mod, nn.BatchNorm1d):
                # Gradients are deliberately left enabled; the original author
                # saw no benefit in also calling requires_grad_(False).
                mod.eval()
        return self

    def reset_running_stats(self):
        """Reset the running statistics of every BatchNorm1d layer."""
        bn_layers = (m for m in self.modules() if isinstance(m, nn.BatchNorm1d))
        for bn in bn_layers:
            bn.reset_running_stats()

    def freeze_bn(self, value: bool):
        """Set the BN-freeze flag and re-apply the current mode so it takes effect."""
        self._freeze_bn = value
        return self.train(self.training)
class AuxNet(nn.Module):
    """Bias-free linear head whose output is split into several chunks.

    Args:
        dims: iterable of chunk sizes. The linear layer maps 1024 features to
            `sum(dims)` outputs and `forward` splits them back into pieces of
            these sizes. Required despite the `None` default, which is kept
            only for signature compatibility.

    Raises:
        TypeError: if `dims` is None.
    """

    def __init__(self, dims=None):
        super().__init__()
        if dims is None:
            raise TypeError('AuxNet requires `dims`: an iterable of output sizes')
        # Materialise once: `dims` is consumed both by sum() here and by
        # split() in forward, so a one-shot iterable would be exhausted.
        self.dims = list(dims)
        self.net = nn.Linear(1024, sum(self.dims), bias=False)

    def forward(self, x):
        """Return a tuple of tensors, one per entry of `dims`, split on the last axis."""
        return self.net(x).split(self.dims, dim=-1)
class DQN(nn.Module):
def __init__(self, *, version=1):
super().__init__()
self.version = version
match version:
case 1:
self.v_head = nn.Linear(512, 1)
self.a_head = nn.Linear(512, ACTION_SPACE)
case 2 | 3:
hidden_size = 512 if version == 2 else 256
self.v_head = nn.Sequential(
nn.Linear(1024, hidden_size),
nn.Mish(inplace=True),
nn.Linear(hidden_size, 1),
)
self.a_head = nn.Sequential(
nn.Linear(1024, hidden_size),
nn.Mish(inplace=True),
nn.Linear(hidden_size, ACTION_SPACE),
)
case 4:
self.net = nn.Linear(1024, 1 + ACTION_SPACE)
nn.init.constant_(self.net.bias, 0)
def forward(self, phi, mask):
if self.version == 4:
v, a = self.net(phi).split((1, ACTION_SPACE), dim=-1)
else:
v = self.v_head(phi)
a = self.a_head(phi)
a_sum = a.masked_fill(~mask, 0.).sum(-1, keepdim=True)
mask_sum = mask.sum(-1, keepdim=True)
a_mean = a_sum / mask_sum
q = (v + a - a_mean).masked_fill(~mask, -1e9)
return q
class MortalEngine:
def __init__(
self,
brain,
dqn,
is_oracle,
version,
device = None,
stochastic_latent = False,
enable_amp = False,
enable_quick_eval = True,
enable_rule_based_agari_guard = False,
name = 'NoName',
boltzmann_epsilon = 0,
boltzmann_temp = 1,
top_p = 1,
):
self.engine_type = 'mortal'
self.device = device or torch.device('cpu')
assert isinstance(self.device, torch.device)
self.brain = brain.to(self.device).eval()
self.dqn = dqn.to(self.device).eval()
self.is_oracle = is_oracle
self.version = version
self.stochastic_latent = stochastic_latent
self.enable_amp = enable_amp
self.enable_quick_eval = enable_quick_eval
self.enable_rule_based_agari_guard = enable_rule_based_agari_guard
self.name = name
self.boltzmann_epsilon = boltzmann_epsilon
self.boltzmann_temp = boltzmann_temp
self.top_p = top_p
def react_batch(self, obs, masks, invisible_obs):
# ========== Online Server =========== #
global ot_settings, is_online
# print('Reacting Batch')
if ot_settings['online']:
try:
list_obs = [o.tolist() for o in obs]
list_masks = [m.tolist() for m in masks]
post_data = {
'obs': list_obs,
'masks': list_masks,
}
data = json.dumps(post_data, separators=(',', ':'))
compressed_data = gzip.compress(data.encode('utf-8'))
headers = {
'Authorization': ot_settings['api_key'],
'Content-Encoding': 'gzip',
}
r = requests.post(
f'{ot_settings["server"]}/react_batch_3p',
headers=headers,
data=compressed_data,
timeout=OT_REQUEST_TIMEOUT
)
assert r.status_code == 200
is_online = True
r_json = r.json()
return r_json['actions'], r_json['q_out'], r_json['masks'], r_json['is_greedy']
except:
is_online = False
pass
# ==================================== #
try:
with (
torch.autocast(self.device.type, enabled=self.enable_amp),
torch.inference_mode(),
):
return self._react_batch(obs, masks, invisible_obs)
except Exception as ex:
raise Exception(f'{ex}\n{traceback.format_exc()}')
def _react_batch(self, obs, masks, invisible_obs):
obs = torch.as_tensor(np.stack(obs, axis=0), device=self.device)
masks = torch.as_tensor(np.stack(masks, axis=0), device=self.device)
invisible_obs = None
if self.is_oracle:
invisible_obs = torch.as_tensor(np.stack(invisible_obs, axis=0), device=self.device)
batch_size = obs.shape[0]
match self.version:
case 1:
mu, logsig = self.brain(obs, invisible_obs)
if self.stochastic_latent:
latent = Normal(mu, logsig.exp() + 1e-6).sample()
else:
latent = mu
q_out = self.dqn(latent, masks)
case 2 | 3 | 4:
phi = self.brain(obs)
q_out = self.dqn(phi, masks)
if self.boltzmann_epsilon > 0:
is_greedy = torch.full((batch_size,), 1-self.boltzmann_epsilon, device=self.device).bernoulli().to(torch.bool)
logits = (q_out / self.boltzmann_temp).masked_fill(~masks, -torch.inf)
sampled = sample_top_p(logits, self.top_p)
actions = torch.where(is_greedy, q_out.argmax(-1), sampled)
else:
is_greedy = torch.ones(batch_size, dtype=torch.bool, device=self.device)
actions = q_out.argmax(-1)
return actions.tolist(), q_out.tolist(), masks.tolist(), is_greedy.tolist()
def sample_top_p(logits, p):
    """Sample indices along the last axis with nucleus (top-p) filtering.

    Args:
        logits: unnormalised log-probabilities; sampling is over the last axis.
        p: nucleus threshold. `p >= 1` samples from the full distribution,
            `p <= 0` returns the argmax. Otherwise an entry is dropped when
            the probability mass of the strictly more likely entries already
            exceeds `p`.

    Returns:
        Tensor of sampled indices with the last axis removed.
    """
    if p <= 0:
        return logits.argmax(-1)
    if p >= 1:
        return Categorical(logits=logits).sample()

    sorted_probs, sorted_idx = logits.softmax(-1).sort(-1, descending=True)
    # Cumulative mass of everything ranked strictly above each entry.
    mass_before = sorted_probs.cumsum(-1) - sorted_probs
    kept_probs = sorted_probs.masked_fill(mass_before > p, 0.)
    # multinomial draws a position in sorted order; map it back to the
    # original index.
    picked = kept_probs.multinomial(1)
    return sorted_idx.gather(-1, picked).squeeze(-1)
def load_model(seat: int, model: Optional[str] = None) -> Bot:
    """Build a Bot for `seat` from a checkpoint stored next to this file.

    Args:
        seat: seat index forwarded to `Bot`.
        model: checkpoint file name relative to this file's directory.
            Anything from the first '?' onwards is discarded. `None` selects
            'Elite4zWeightedBest5.pth'.

    Returns:
        A `Bot` wrapping a non-oracle `MortalEngine` on CUDA when available,
        otherwise on CPU.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # latest binary model
    if model is None:
        model = 'Elite4zWeightedBest5.pth'
    model = str(model).split('?')[0]

    # Resolve the checkpoint relative to this file's directory.
    control_state_file = pathlib.Path(__file__).parent / model
    # NOTE(review): torch.load unpickles the file; only load checkpoints from
    # trusted sources, and confirm `model` cannot be attacker-controlled.
    state = torch.load(control_state_file, map_location=device)

    config = state['config']
    version = config['control']['version']
    resnet_cfg = config['resnet']

    mortal = Brain(
        version=version,
        conv_channels=resnet_cfg['conv_channels'],
        num_blocks=resnet_cfg['num_blocks'],
    ).eval()
    dqn = DQN(version=version).eval()
    mortal.load_state_dict(state['mortal'])
    dqn.load_state_dict(state['current_dqn'])
    # Report success only once the weights are actually in place.
    print(model, 'loaded')

    engine = MortalEngine(
        mortal,
        dqn,
        is_oracle = False,
        version = version,
        device = device,
        enable_amp = False,
        enable_quick_eval = False,
        enable_rule_based_agari_guard = True,
        name = 'mortal',
        top_p = 1,
    )
    return Bot(engine, seat)