2119 字
11 分钟
连续思维机(CTM)本地部署与使用教程:Sakana AI 新型 AI 模型架构实践

Continuous Thought Machines 是什么#

Continuous Thought Machines(CTM,连续思维机)是日本 AI 初创公司 Sakana AI 提出的一种新型神经网络架构,由 “Transformer 八子” 之一的 Llion Jones 联合创立团队开发。论文发表于 2025 年 5 月(arXiv:2505.05522),并被 NeurIPS 2025 接收为 Spotlight 论文。

CTM 的核心思想是重新引入神经时序与同步机制——这些在标准架构(如 Transformer)中被抽象掉的关键生物学特征。它让神经网络不再是”一步到位”的静态计算,而是沿着一条内部的”时间轴”逐步推理,像人脑一样”思考”。

核心创新:

特性说明
神经元级时序处理每个神经元拥有私有权重参数,处理自身接收到的信号历史,而非静态激活函数
神经同步作为隐层表示利用神经元同步矩阵直接编码信息、调节注意力并生成输出
内部 “Tick” 维度独立于输入数据的内部时间轴,支持迭代推理和自适应计算
自适应计算简单任务自动提前停止,复杂任务持续推理,按需分配算力
开源协议Apache-2.0,商用友好

与传统架构的区别#

对比维度Transformer / CNNCTM
计算方式固定层数,单次前向传播沿内部时间轴逐步展开
神经元静态激活函数(ReLU、GELU)每个神经元拥有私有权重 + 历史记录
推理方式一次前馈逐步推理,类似人类”思考”过程
计算量任务固定自适应——简单早停,困难继续
位置编码需要显式位置编码时序自然涌现
可解释性有限高——可实时观察注意力随时间变化
校准性通常需要温度缩放天然良好的置信度校准

CTM 不以刷榜为目标,而是代表一种哲学上的转变——从追求基准分数的极致优化,转向打造更接近生物大脑、具备持续思考能力的 AI 系统。

系统要求#

项目最低要求推荐配置
Python>= 3.103.12
PyTorch>= 2.02.5+
CUDA>= 11.812.4+
GPU 显存~8 GB(CIFAR)24 GB+(ImageNet 训练)
磁盘空间~10 GB50 GB+(含 ImageNet 数据集)

安装#

环境准备#

建议使用 Conda 创建独立环境:

Terminal window
# 创建并激活虚拟环境
conda create --name=ctm python=3.12
conda activate ctm
# 或者使用 venv
python -m venv ctm_env
source ctm_env/bin/activate

克隆代码库#

Terminal window
git clone https://github.com/SakanaAI/continuous-thought-machines.git
cd continuous-thought-machines

安装依赖#

Terminal window
# 先安装 PyTorch(根据 CUDA 版本选择)
pip install torch torchvision --index-url https://download.pytorch.org/whl/cu124
# 再安装其余依赖
pip install -r requirements.txt

如果遇到 CUDA 版本不匹配,可先卸载 PyTorch 再按需安装:pip uninstall torch torchvision,然后从 pytorch.org 获取正确版本。

验证安装#

Terminal window
# 检查 Python 版本
python --version # 应 >= 3.10
# 检查 CUDA 是否可用
python -c "import torch; print(f'PyTorch {torch.__version__}, CUDA available: {torch.cuda.is_available()}, GPU: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else \"N/A\"}')"

项目结构#

了解代码库的组织方式,有助于后续操作:

continuous-thought-machines/
├── tasks/ # 各任务的训练与评估脚本
│ ├── image_classification/ # 图像分类(CIFAR / ImageNet)
│ │ ├── train.py # 训练入口
│ │ ├── plotting.py # 可视化
│ │ └── analysis/ # 分析评估
│ ├── mazes/ # 2D 迷宫求解
│ │ ├── train.py
│ │ └── analysis/
│ ├── sort/ # 排序任务
│ ├── parity/ # 奇偶校验
│ ├── qamnist/ # MNIST 问答
│ └── rl/ # 强化学习(CartPole, Acrobot 等)
├── models/ # CTM 模型架构实现
│ ├── ctm.py # CTM 主模型
│ ├── modules.py # 神经元级模型及突触 UNet
│ ├── utils.py # 同步机制工具函数
│ ├── resnet.py # ResNet 特征提取器封装
│ ├── ff.py / lstm.py # 基线模型
├── utils/ # 通用工具
│ ├── housekeeping.py
│ ├── losses.py
│ └── schedulers.py
├── data/ # 数据集存放目录(自动下载)
├── checkpoints/ # 模型检查点保存目录
└── requirements.txt # 依赖列表

运行训练任务#

所有任务都通过 python -m 方式从项目根目录运行。以下是各任务的详细说明。

图像分类(CIFAR-10 / CIFAR-100)#

适合入门验证 CTM 效果,数据量小,训练快:

Terminal window
# CIFAR-10 训练(小型实验)
python -m tasks.image_classification.train \
--dataset cifar10 \
--batch_size 128 \
--epochs 50
# CIFAR-100 训练(消融实验)
python -m tasks.image_classification.train \
--dataset cifar100 \
--batch_size 128 \
--epochs 50

图像分类(ImageNet-1K)#

需要较大的 GPU 显存和较长的训练时间:

Terminal window
python -m tasks.image_classification.train \
--dataset imagenet \
--batch_size 256 \
--temporal_depth 8

关键训练参数:

参数说明建议值
--temporal_depth时序处理层数8(默认)
--batch_size批次大小根据显存调整
--dataset数据集cifar10 / cifar100 / imagenet
--epochs训练轮数50–200(依任务而定)

2D 迷宫求解#

CTM 的标志性任务——模型学会逐”tick”追踪路径,即使从未训练过位置编码:

Terminal window
python -m tasks.mazes.train

迷宫相关参数可在 data/custom_datasets.py 中找到,包括 maze_route_length 等配置。

排序任务#

Terminal window
python -m tasks.sort.train

奇偶校验(Parity)#

Terminal window
python -m tasks.parity.train

强化学习(RL)#

Terminal window
# CartPole
python -m tasks.rl.train
# 或使用预配置脚本
bash tasks/rl/scripts/cartpole/run.sh

各 RL 环境的脚本位于 tasks/rl/scripts/ 下,包括 4rooms/acrobot/cartpole/

使用预训练模型#

Sakana AI 提供了在 ImageNet-1K 上预训练的模型权重,托管在 Hugging Face:

下载预训练权重#

# 方式一:Hugging Face 自动下载
from huggingface_hub import snapshot_download
model_path = snapshot_download("SakanaAI/ctm-imagenet")
print(f"模型下载至: {model_path}")
Terminal window
# 方式二:命令行下载
pip install huggingface_hub
huggingface-cli download SakanaAI/ctm-imagenet --local-dir ./pretrained/ctm-imagenet

加载与评估#

将预训练权重放入 checkpoints/ 目录后,使用分析脚本进行评估:

Terminal window
# ImageNet 分析评估
python -m tasks.image_classification.analysis.run_imagenet_analysis

可视化分析#

CTM 的可解释性是其最大亮点之一。分析脚本会生成以下可视化内容:

  • 神经元激活热力图(按图像类别分组)
  • 注意力在图像上的移动轨迹(GIF 动画)
  • UMAP 降维后的神经元活动模式
  • 低频行波现象的可视化
# 在 Python 中手动可视化
from tasks.image_classification.plotting import plot_activation_map
plot_activation_map(
checkpoint_path='checkpoints/imagenet/model.pth'
)

论文中展示了 CTM 识别大猩猩时注意力从”眼睛→鼻子→嘴巴”逐步移动的过程,类似于人类的观察方式。

调优建议#

参数说明建议
temporal_depth时序处理层数ImageNet 建议 8;小数据集可降至 4 以加速
synch_decay同步衰减系数0.9–0.99,数值越大同步信号保留越久
phase_lr相位学习率建议 1e-4
batch_size批次大小越大越稳定,但受显存限制
内部 Tick 数每次推理的循环步数默认 10,简单任务可提前停止

常见问题#

Q:CUDA out of memory

按优先级尝试以下方案:

  1. 降低 batch_size(如 256 → 128 → 64)
  2. 降低 temporal_depth(如 8 → 4)
  3. 使用更小的数据集先做验证(CIFAR-10 替代 ImageNet)
  4. 启用梯度累积(需手动修改训练脚本)

Q:WSL2 中检测不到 GPU

确保 Windows 已安装最新的 NVIDIA 驱动,并在 WSL2 内运行 nvidia-smi 验证:

Terminal window
nvidia-smi # 应显示 GPU 信息

如果未识别,请参考 NVIDIA WSL2 用户指南

Q:国内 Hugging Face 下载缓慢

设置 HF 镜像:

Terminal window
export HF_ENDPOINT=https://hf-mirror.com

Q:训练时损失不下降

  • 检查学习率是否合适,建议从 1e-4 开始
  • 确认数据集正确加载(可打印几张样本验证)
  • 尝试减少 temporal_depth,过深的时序层在小数据集上可能过拟合

Q:CTM 能用在语言模型上吗?

论文主要验证了图像、迷宫、排序、强化学习等任务。语言建模是论文中明确指出的未来方向之一,但目前官方仓库尚未提供 NLP 任务的训练脚本。

局限与展望#

当前局限#

  1. 计算成本高:顺序推理导致训练时间较长,神经元级模型带来额外参数开销
  2. 准确率差距:ImageNet 上 72.47% 的 Top-1 准确率尚未达到 ViT / ConvNeXt 等 SOTA 水平
  3. 工具生态不成熟:现有分析和调试工具主要针对静态模型
  4. “过度思考”风险:在部分任务上可能出现循环绕圈或误差累积

未来方向#

  • 更大规模、更高维度的同步表示
  • 语言建模与视频理解等序列任务
  • 借鉴更多生物机制(赫布学习、STDP 等)
  • 更高效的训练策略

参考资源#

连续思维机(CTM)本地部署与使用教程:Sakana AI 新型 AI 模型架构实践
https://blog.syomega.top/posts/ctm-local-setup-guide/
作者
酱w
发布于
2026-05-19
许可协议
CC BY-NC-SA 4.0