Continuous Thought Machines 是什么
Continuous Thought Machines(CTM,连续思维机)是日本 AI 初创公司 Sakana AI 提出的一种新型神经网络架构,由 “Transformer 八子” 之一的 Llion Jones 联合创立团队开发。论文发表于 2025 年 5 月(arXiv:2505.05522),并被 NeurIPS 2025 接收为 Spotlight 论文。
CTM 的核心思想是重新引入神经时序与同步机制——这些在标准架构(如 Transformer)中被抽象掉的关键生物学特征。它让神经网络不再是”一步到位”的静态计算,而是沿着一条内部的”时间轴”逐步推理,像人脑一样”思考”。
核心创新:
| 特性 | 说明 |
|---|---|
| 神经元级时序处理 | 每个神经元拥有私有权重参数,处理自身接收到的信号历史,而非静态激活函数 |
| 神经同步作为隐层表示 | 利用神经元同步矩阵直接编码信息、调节注意力并生成输出 |
| 内部 “Tick” 维度 | 独立于输入数据的内部时间轴,支持迭代推理和自适应计算 |
| 自适应计算 | 简单任务自动提前停止,复杂任务持续推理,按需分配算力 |
| 开源协议 | Apache-2.0,商用友好 |
与传统架构的区别
| 对比维度 | Transformer / CNN | CTM |
|---|---|---|
| 计算方式 | 固定层数,单次前向传播 | 沿内部时间轴逐步展开 |
| 神经元 | 静态激活函数(ReLU、GELU) | 每个神经元拥有私有权重 + 历史记录 |
| 推理方式 | 一次前馈 | 逐步推理,类似人类”思考”过程 |
| 计算量 | 任务固定 | 自适应——简单早停,困难继续 |
| 位置编码 | 需要显式位置编码 | 时序自然涌现 |
| 可解释性 | 有限 | 高——可实时观察注意力随时间变化 |
| 校准性 | 通常需要温度缩放 | 天然良好的置信度校准 |
CTM 不以刷榜为目标,而是代表一种哲学上的转变——从追求基准分数的极致优化,转向打造更接近生物大脑、具备持续思考能力的 AI 系统。
系统要求
| 项目 | 最低要求 | 推荐配置 |
|---|---|---|
| Python | >= 3.10 | 3.12 |
| PyTorch | >= 2.0 | 2.5+ |
| CUDA | >= 11.8 | 12.4+ |
| GPU 显存 | ~8 GB(CIFAR) | 24 GB+(ImageNet 训练) |
| 磁盘空间 | ~10 GB | 50 GB+(含 ImageNet 数据集) |
安装
环境准备
建议使用 Conda 创建独立环境:
# 创建并激活虚拟环境conda create --name=ctm python=3.12conda activate ctm
# 或者使用 venvpython -m venv ctm_envsource ctm_env/bin/activate克隆代码库
git clone https://github.com/SakanaAI/continuous-thought-machines.gitcd continuous-thought-machines安装依赖
# 先安装 PyTorch(根据 CUDA 版本选择)pip install torch torchvision --index-url https://download.pytorch.org/whl/cu124
# 再安装其余依赖pip install -r requirements.txt如果遇到 CUDA 版本不匹配,可先卸载 PyTorch 再按需安装:
pip uninstall torch torchvision,然后从 pytorch.org 获取正确版本。
验证安装
# 检查 Python 版本python --version # 应 >= 3.10
# 检查 CUDA 是否可用python -c "import torch; print(f'PyTorch {torch.__version__}, CUDA available: {torch.cuda.is_available()}, GPU: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else \"N/A\"}')"项目结构
了解代码库的组织方式,有助于后续操作:
continuous-thought-machines/├── tasks/ # 各任务的训练与评估脚本│ ├── image_classification/ # 图像分类(CIFAR / ImageNet)│ │ ├── train.py # 训练入口│ │ ├── plotting.py # 可视化│ │ └── analysis/ # 分析评估│ ├── mazes/ # 2D 迷宫求解│ │ ├── train.py│ │ └── analysis/│ ├── sort/ # 排序任务│ ├── parity/ # 奇偶校验│ ├── qamnist/ # MNIST 问答│ └── rl/ # 强化学习(CartPole, Acrobot 等)├── models/ # CTM 模型架构实现│ ├── ctm.py # CTM 主模型│ ├── modules.py # 神经元级模型及突触 UNet│ ├── utils.py # 同步机制工具函数│ ├── resnet.py # ResNet 特征提取器封装│ ├── ff.py / lstm.py # 基线模型├── utils/ # 通用工具│ ├── housekeeping.py│ ├── losses.py│ └── schedulers.py├── data/ # 数据集存放目录(自动下载)├── checkpoints/ # 模型检查点保存目录└── requirements.txt # 依赖列表运行训练任务
所有任务都通过 python -m 方式从项目根目录运行。以下是各任务的详细说明。
图像分类(CIFAR-10 / CIFAR-100)
适合入门验证 CTM 效果,数据量小,训练快:
# CIFAR-10 训练(小型实验)python -m tasks.image_classification.train \ --dataset cifar10 \ --batch_size 128 \ --epochs 50
# CIFAR-100 训练(消融实验)python -m tasks.image_classification.train \ --dataset cifar100 \ --batch_size 128 \ --epochs 50图像分类(ImageNet-1K)
需要较大的 GPU 显存和较长的训练时间:
python -m tasks.image_classification.train \ --dataset imagenet \ --batch_size 256 \ --temporal_depth 8关键训练参数:
| 参数 | 说明 | 建议值 |
|---|---|---|
--temporal_depth | 时序处理层数 | 8(默认) |
--batch_size | 批次大小 | 根据显存调整 |
--dataset | 数据集 | cifar10 / cifar100 / imagenet |
--epochs | 训练轮数 | 50–200(依任务而定) |
2D 迷宫求解
CTM 的标志性任务——模型学会逐”tick”追踪路径,即使从未训练过位置编码:
python -m tasks.mazes.train迷宫相关参数可在 data/custom_datasets.py 中找到,包括 maze_route_length 等配置。
排序任务
python -m tasks.sort.train奇偶校验(Parity)
python -m tasks.parity.train强化学习(RL)
# CartPolepython -m tasks.rl.train
# 或使用预配置脚本bash tasks/rl/scripts/cartpole/run.sh各 RL 环境的脚本位于
tasks/rl/scripts/下,包括4rooms/、acrobot/、cartpole/。
使用预训练模型
Sakana AI 提供了在 ImageNet-1K 上预训练的模型权重,托管在 Hugging Face:
下载预训练权重
# 方式一:Hugging Face 自动下载from huggingface_hub import snapshot_download
model_path = snapshot_download("SakanaAI/ctm-imagenet")print(f"模型下载至: {model_path}")# 方式二:命令行下载pip install huggingface_hubhuggingface-cli download SakanaAI/ctm-imagenet --local-dir ./pretrained/ctm-imagenet加载与评估
将预训练权重放入 checkpoints/ 目录后,使用分析脚本进行评估:
# ImageNet 分析评估python -m tasks.image_classification.analysis.run_imagenet_analysis可视化分析
CTM 的可解释性是其最大亮点之一。分析脚本会生成以下可视化内容:
- 神经元激活热力图(按图像类别分组)
- 注意力在图像上的移动轨迹(GIF 动画)
- UMAP 降维后的神经元活动模式
- 低频行波现象的可视化
# 在 Python 中手动可视化from tasks.image_classification.plotting import plot_activation_map
plot_activation_map( checkpoint_path='checkpoints/imagenet/model.pth')论文中展示了 CTM 识别大猩猩时注意力从”眼睛→鼻子→嘴巴”逐步移动的过程,类似于人类的观察方式。
调优建议
| 参数 | 说明 | 建议 |
|---|---|---|
temporal_depth | 时序处理层数 | ImageNet 建议 8;小数据集可降至 4 以加速 |
synch_decay | 同步衰减系数 | 0.9–0.99,数值越大同步信号保留越久 |
phase_lr | 相位学习率 | 建议 1e-4 |
batch_size | 批次大小 | 越大越稳定,但受显存限制 |
| 内部 Tick 数 | 每次推理的循环步数 | 默认 10,简单任务可提前停止 |
常见问题
Q:CUDA out of memory
按优先级尝试以下方案:
- 降低
batch_size(如 256 → 128 → 64) - 降低
temporal_depth(如 8 → 4) - 使用更小的数据集先做验证(CIFAR-10 替代 ImageNet)
- 启用梯度累积(需手动修改训练脚本)
Q:WSL2 中检测不到 GPU
确保 Windows 已安装最新的 NVIDIA 驱动,并在 WSL2 内运行 nvidia-smi 验证:
nvidia-smi # 应显示 GPU 信息如果未识别,请参考 NVIDIA WSL2 用户指南。
Q:国内 Hugging Face 下载缓慢
设置 HF 镜像:
export HF_ENDPOINT=https://hf-mirror.comQ:训练时损失不下降
- 检查学习率是否合适,建议从 1e-4 开始
- 确认数据集正确加载(可打印几张样本验证)
- 尝试减少
temporal_depth,过深的时序层在小数据集上可能过拟合
Q:CTM 能用在语言模型上吗?
论文主要验证了图像、迷宫、排序、强化学习等任务。语言建模是论文中明确指出的未来方向之一,但目前官方仓库尚未提供 NLP 任务的训练脚本。
局限与展望
当前局限
- 计算成本高:顺序推理导致训练时间较长,神经元级模型带来额外参数开销
- 准确率差距:ImageNet 上 72.47% 的 Top-1 准确率尚未达到 ViT / ConvNeXt 等 SOTA 水平
- 工具生态不成熟:现有分析和调试工具主要针对静态模型
- “过度思考”风险:在部分任务上可能出现循环绕圈或误差累积
未来方向
- 更大规模、更高维度的同步表示
- 语言建模与视频理解等序列任务
- 借鉴更多生物机制(赫布学习、STDP 等)
- 更高效的训练策略