Files
RUL-Mamba/assistant.py
T
eason 79db6e5c96 feat: 添加RUL-Mamba模型及相关组件
新增锂电池剩余使用寿命预测模型RUL-Mamba,包含以下主要组件:
1. 添加Mamba模块作为核心时序建模组件
2. 实现特征注意力网络(FAN)和门控残差网络(GRN)
3. 新增数据预处理和归一化层
4. 添加模型训练和评估脚本
5. 补充README文档说明使用方法
6. 添加可视化辅助工具Helper_Plot.py
7. 实现多种时间序列处理层(Embedding、AutoCorrelation等)
8. 添加配置文件requirements.txt
9. 补充测试数据集TJU battery data
2026-01-09 08:53:50 +08:00

47 lines
1.4 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import random
import torch
import shutil
import numpy as np
import time
import copy
import os
import argparse
import subprocess
import torch.backends.cudnn as cudnn
import glob
import logging
import sys
import yaml
import json
# 寻找现存最大的显卡编号
def get_gpus_memory_info():
"""Get the maximum free usage memory of gpu"""
rst = subprocess.run('nvidia-smi -q -d Memory', stdout=subprocess.PIPE, shell=True).stdout.decode('utf-8')
rst = rst.strip().split('\n')
memory_available = [int(line.split(':')[1].split(' ')[1]) for line in rst if 'Free' in line][::2]
id = int(np.argmax(memory_available))
return id, memory_available
# 设置random, numpy torchcpu和cuda)的seed
def set_seed(seed):
"""
set seed of numpy and torch
:param seed:
:return:
"""
if seed is None:
seed = np.random.randint(1e6)
random.seed(seed)
np.random.seed(seed)
os.environ['PYTHONHASHSEED'] = str(seed) # 为了禁止hash随机化,使得实验可复现。
torch.manual_seed(seed) # 为CPU设置随机种子
if torch.cuda.is_available():
torch.cuda.manual_seed(seed) # 为当前GPU设置随机种子
torch.cuda.manual_seed_all(seed) # if you are using multi-GPU,为所有GPU设置随机种子
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
return seed