import gradio as gr import numpy as np import matplotlib.pyplot as plt from GAN.diffusion import build_model, GaussianDiffusion, DiffusionModel import tensorflow as tf from tensorflow.python.types.core import TensorLike import imageio import tempfile import os import io from PIL import Image EPS = 1e-18 class TSFeatureScaler: """Global time series scaler that scales all features to [0,1] then normalizes to [-1,1]""" def __init__(self) -> None: self.min_val = None self.max_val = None def fit(self, X: TensorLike) -> "TSFeatureScaler": """ Fit scaler to data Args: X: Input tensor of shape [N, T, D] (N: samples, T: timesteps, D: features) """ # 计算整个数据集的全局最大最小值 self.min_val = np.min(X) self.max_val = np.max(X) return self def transform(self, X: TensorLike) -> TensorLike: """ Transform data in two steps: 1. Scale to [0,1] using min-max scaling 2. Normalize to [-1,1] """ if self.min_val is None or self.max_val is None: raise ValueError("Scaler must be fitted before transform") # 1. 缩放到[0,1] X_scaled = (X - self.min_val) / (self.max_val - self.min_val + EPS) # 2. 归一化到[-1,1] X_normalized = 2.0 * X_scaled - 1.0 return X_normalized def fit_transform(self, X: TensorLike) -> TensorLike: """Fit to data, then transform it""" return self.fit(X).transform(X) def create_animation(frames, fps=1): """将帧列表转换为GIF动画数据""" import tempfile import os temp_dir = tempfile.gettempdir() temp_path = os.path.join(temp_dir, f"temp_{id(frames)}.gif") # 将fps转换为duration (毫秒) #duration = int(1000 / fps) # 1000ms = 1s duration = min(1, 6.55) # 保存为GIF文件,设置循环播放 imageio.mimsave(temp_path, frames, format='GIF', duration=duration, loop=0) # loop=0 表示无限循环 return temp_path # def create_animation(frames, duration=0.1): # """创建GIF动画""" # duration = min(duration, 6.55) # # 数据标准化到[0,1]范围 # frames = np.array(frames) # frames = (frames - frames.min()) / (frames.max() - frames.min()) # # 转换为RGB图像 # frames_rgb = [] # for frame in frames: # # 创建彩色图像 # plt.figure(figsize=(6, 4)) # # 对每个特征分别绘制 # for i in range(frame.shape[-1]): # 遍历最后一个维度(特征维度) # plt.plot(frame[:, i], label=f'Feature {i+1}') # plt.grid(True) # plt.ylim(-0.1, 1.1) # plt.legend() # # 保存到内存 # buf = io.BytesIO() # plt.savefig(buf, format='png') # plt.close() # # 读取图像 # buf.seek(0) # img = Image.open(buf) # frames_rgb.append(np.array(img)) # buf.close() # # 创建临时文件 # with tempfile.NamedTemporaryFile(suffix='.gif', delete=False) as temp_file: # temp_path = temp_file.name # # 保存GIF # imageio.mimsave(temp_path, frames_rgb, format='GIF', duration=duration, loop=0) # return temp_path def generate_timeseries(input_file, num_samples=16): try: # 加载数据 real_data = np.load(input_file.name) scaler = TSFeatureScaler() real_data = scaler.fit_transform(real_data) print(f"Loaded data shape: {real_data.shape}") # 确保数据形状正确 expected_shape = (None, 96, 3) if len(real_data.shape) != 3 or real_data.shape[1:] != expected_shape[1:]: return None, None # 创建模型和必要的组件 network = build_model( time_len=96, fea_num=3, d_model=16, n_heads=2, encoder_type='dual' ) ema_network = build_model( time_len=96, fea_num=3, d_model=16, n_heads=2, encoder_type='dual' ) ema_network.set_weights(network.get_weights()) noise_util = GaussianDiffusion(timesteps=10) print("Creating model...") model = DiffusionModel( network=network, ema_network=ema_network, timesteps=10, gdf_util=noise_util, data=real_data[:num_samples] ) # 加载预训练权重 checkpoint_path = "checkpoint/cp.ckpt" print(f"Loading weights from {checkpoint_path}") model.load_weights(checkpoint_path) # 生成加噪过程的动画 print("Generating noising animation...") noise_frames = model.plot_noise_process_app(num_samples) noise_gif = create_animation(noise_frames) # 生成去噪过程的动画 print("Generating denoising animation...") denoise_frames = model.plot_denoise_process_app(num_samples)[1:] denoise_gif = create_animation(denoise_frames) return noise_gif, denoise_gif except Exception as e: import traceback error_msg = f"Error: {str(e)}\n{traceback.format_exc()}" print(error_msg) return None, None def update_example_gifs(num_samples): """根据选择的样本数更新示例GIF""" return f"noising_example_{num_samples}.gif", f"denoising_example_{num_samples}.gif" # 创建Gradio界面 with gr.Blocks(title="Wearable Sensors Time-Series Generation") as demo: with gr.Column(elem_id="container"): # Logo gr.Image("logo.webp", elem_id="logo", show_label=False, container=False) # 标题和副标题 gr.Markdown( """ # Wearable Sensors Time-Series Generation