任务介绍
本工作将利用起不同session,不同task,不同subject的神经信号数据,运用Mae + transformer的结构训练出能够解码神经信号的decoder。
摘要与行业现状
脑机接口记录下的神经集群活动蕴含丰富的信息,现有的模型通常是针对单个实验设计的,数据量被设计为单个session内可收集的量,由于数据量的不足,dnn的训练存在过拟合的现象,并且无法处理神经元之间的时空关联性。本任务(ndt2)尝试使用大规模无监督训练出能够提取神经元之间的时空关联的模型,并证明该模型能够利用起跨任务跨会话(session)的神经活动信息。训练得到的模型能够快速微调适应上下文信息。
关于皮质信号:
背后存在着稳定的规律,通过对神经集群活动信号进行主成分分析能够提取出这种稳定的信号。现有的模型能够提取出单任务中背后的规律,本任务(ndt)尝试从跨task跨session中提取出一些规律,从而得到更general的模型。
领域现状:
神经数据集之间基本没有相似性,但训练模型的方法都极为相似,一般都采用MAE作为预训练目标,用transformer(或cnn)作为backbone。脑机接口的数据具有不稳定性,不稳定性来自于信号采集设备的高度敏感性。目前领域内主流的聚合不同数据的方法为stitching。ndt2期望研究随着数据量的增加是否有利于模型的表现,尽管数据之间的相关性会随着数据量的增加而降低。
stitching方法:
大致思路是通过对readin和readout的学习提取神经元信号之间的稳定的子空间,背后主要的算法是em算法,期望最大化算法。但由于实验发现stitching方法并不理想,此处不展开介绍。
模型介绍:
MAE:masked autoencoder,mae简单来说就是给图像加上很高比例的mask,然后设计编解码器来重建图像。大致流程如下图,先把加了mask的图片通过encoder进行编码,得到一堆tokens(蓝色块),然后再把这些tokens加上mask送进decoder重建得到图片。其中,给token加mask的位置与原图像mask的位置相同。加mask的比例非常高(75%左右),而且是完全随机采样的mask,防止模型能够轻易从有规律的图像采样中学到信息。
结合NDT2的设计来具体讲MAE的细节。
先看encoder,这里的图像就是把k条神经元的活动序列堆叠起来变成类似二维图像的形式,encoder的输入是unmasked的patches,也就是图像加了mask后的剩余的可见部分,然后和任务信息和图像位置信息等一起输入encoder进行训练。
decoder处理encoder编码得到的tokens,也就是encoder的输出,同时的decoder也要处理masked的token,也就是mask的位置信息,所有的mask token公用一个记录mask位置的向量。
作者设计了两个解码器,一个解码出神经活动信号,另一个直接解码出生物的行为,可以直接用于实时预测。
右边的图B讲的是数据的相关性与数据量的关系。随着不同task,不同subject,不同session数据集的加入,训练数据量在上升,但数据的相关性也在减弱。
然后编码器和解码器都是采用的vit(vision transformer)作为主结构,transformer能够处理序列信息,vit的做法是把二维图像切割开,分成若干个小块,然后排成一排,为每个小块加入位置信息,如此就把图像转化为了序列信息。然后把每个小块经过线性映射或者flatten操作后输入transformer encoder中,内部结构就是经典的多头注意力机制。
然后是ndt2的一些设计思想。ndt1是直接把单条神经元的活动序列输入transformer进行学习,在不同任务中同一个个神经元的活动模式可能并不相同,所以每次只输入一个神经元进行学习可能会影响效果,但一次把所有的数据都序列化输入模型在计算上是行不通的,所以ndt2的作者就折中,把k条记录叠成二维图像,用vit的思想去处理数据。
接下来讲实验。a图是在sorted data上进行的预训练的效果,b图是unsorted数据集上的结果,sort与否的区别就是,sort是按照对应刺激反射的波形分类后的。先看右边,这里的single-session应该代表的是单个神经元在不同时间段内的预测表现,这里的横杠不是连字符的意思,我的理解是single和multi指的是神经元的数量,session指的是时间,subject是任务类型,task是具体的任务,single-session指的就是单条神经元在不同时间做同一个动作,也就是scratch这个动作时的表现。r平方是统计学习里的决定系数,越大越好,nll是负的对数似然,越小越好。所以说越靠近左上角的点表现越好。可以看到ndt在两种类型的数据集中都是在multi-session上表现最好。也就是
主结构为mae+transformer。
难题1:data tokenization
本文贡献如下:
- 通过三种方法处理异构数据:1.时空attention 2.学习embedding 3.非对称的编解码器设计
- 快速微调
实验介绍
本文主要做了以下三个方面的实验
- 多任务预训练效果展示。
在本实验中使用了不同的数据集的数据进行预训练,测试所用的任务为速度预测,指标为r方和nll。该试验将ndt1+stitching,ndt1不加stiching和ndt2做对比。无论在sort或者unsortdata上做实验,ndt2的效果都显著优于ndt1。对于unsorteddata,ndt1+stitching确实有显著的效果提升。
- 数据量与模型性能的关系
- 解决下游任务
名词解释
Session(会话)
session 代表了一次完整的试验过程。它可以包括了多个不同的任务(tasks),并且在整个过程中,实验者会尽量保持相同的环境和条件,以保证数据的一致性和可靠性。例如,在脑电图(EEG)实验中,一次 session 可能包括了一系列的刺激任务,如观看视频、听音乐等,而在每个任务中,都会记录下大脑的电信号。这样的 session 通常会持续几分钟到几十分钟不等,视实验设计和目的而定。
Task(任务)
task 通常指的是被试者需要执行的一项特定的活动或任务。这些任务可以是视觉、听觉、运动或认知任务,旨在激发特定的神经信号活动。例如,在脑电图(EEG)或功能性磁共振成像(fMRI)等实验中,研究者会设计一系列的任务,例如观看特定类型的图像、听特定类型的音频、进行特定类型的运动或执行特定类型的认知任务。这些任务的目的是研究特定区域或神经网络在执行不同任务时的活动模式,以了解大脑的功能和结构。
Subject(受试者)
subject 是指参与实验的被试者或参与者。在神经信号试验中,通常会有多个 subject 参与到同一个实验中。每个 subject 都会完成一系列的 task,通常会有多个 session。通过对多个 subject 的数据进行分析,研究者可以得到更加全面和准确的结论,从而了解大脑的功能和结构。
Readin(输入信号)
Readin 是指从外部环境中接收到的信号,也就是输入到神经系统中的信号。这些信号可能是来自于外部感觉器官(例如视觉、听觉、触觉等),或者是来自于其他神经系统的输出。在脑电图(EEG)或功能性磁共振成像(fMRI)等神经信号实验中,readin 通常指的是被试者在实验中接收到的刺激(例如视觉刺激、听觉刺激等)所产生的神经信号。
Readout(输出信号)
Readout 是指从神经系统中输出的信号,也就是从神经系统中读取到的信号。这些信号可能是来自于大脑中的神经元活动、脑区的活动模式等。在脑电图(EEG)或功能性磁共振成像(fMRI)等神经信号实验中,readout 通常指的是被试者在实验中产生的神经信号,例如大脑的电信号或血氧水平等。
数据集nlb
其中迷宫数据集包含了猴子的神经活动,光标,猴子的目光所及和猴子手的位置。
cae
大致思路:
将可见部分送进encoder,得到可见部分的表征zv,regressor预测masked patch的表征,相当于基于 visible patches 去预测 masked patches,而非像 MAE 那样直接将 visible patches 的特征也输入到 Decoder。regressor的本质是一个cross-attention层,输入是encoder的输出,即可见部分的表征估计,和masked patches(在attention中作为query),由此可见,regressor的作用是从可见的视觉内容中通过cross-attention去捕捉所需要的语义,而后regressor把masked的表征送给decoder进行学习,此时的decoder学不到未加mask的原图像的表征。但是regressor是学到了原始图像的表征的,也就是说现在仍然没有做到只由encoder学习表征。为了使只有encoder学习原图想的表征,此时的做法是将regressor得到的masked token的表征与encoder编码masked patches后得到的表征对齐,需要进行一个操作,就是把原来加上mask的部分复原后输入decoder(encoder只输入原来被masked的部分的原图)。这么做相当于对regressor做了约束,即regressor学的masked的表征是由encoder产生的,相当于还是在对encoder进行学习,此时regressor偷学到的特征与encoder输出的结果进行对齐,相当于还是encoder在学习。如此便做到了全程只有encoder在学习。
stitching
这个x相当于是一个有着n个群神经元的神经集群的隐动力的线性表示,大a矩阵表示神经之间的时空关联性。aij如果非0则表示的是第j个神经元在t时刻对t+1时刻的第i个神经元有着统计学意义上的影响。不一定代表着真实的神经连接。大a矩阵用李雅普诺夫方程标出。大b矩阵代表着神经元对刺激的反应。
一塔t表示神经元之间的噪声,为一个0均值协方差矩阵为Q的多维高斯分布。每一个时刻的噪声q_t都可以用李雅普诺夫方程,用大a和大q矩阵表示出来。
再来看y。y代表着对整个共有n个神经元神经元集群中的k个神经元的观测值,这个y的表达式相当于用总体的隐动力x加上一些噪声和偏置得出来的。这里的矩阵c被称为测量矩阵,maesurement matraix,大小为N_k x N。
这个大c矩阵可以看做已知。
之后再用em算法学习出四个参数abqr便完成了对神经集群活动的建模。xt即为最终的表示结果。em算法大致分为两个步骤,e步骤求期望,m步骤最大化所求参数
作者源代码中对stitching的评价如下:
- 数据集们没有很好的聚合在一起,stitching某种程度上解决了这个问题,但是计算开销非常大
- stitching出于某种原因崩溃了。
个人思考
无论是mae或cae,都是把原始神经信号处理成为图像格式在进入transformer做序列分析,图像变为序列在解码序列信息,*相比直接从序列中学习神经活动的表征效果会变的有多差劲?是否存可以在序列优化上下手?
ndt2给出的无法序列化学习的原因是
数据集nlb
先简单介绍下一个神经数据集。Neural Latents Benchmark
模型mae
首先介绍下mae,masked autoencoder,等会要介绍的方法里用到了mae的思想。mae简单来说就是给图像加上很高比例的mask,然后设计编解码器来重建图像。大致流程直接看这个图,先把加了mask的图片通过encoder进行编码,得到一堆tokens,然后再把这些tokens加上mask送进decoder重建得到图片。其他的一些特点还有mask的比例非常高,而且是完全随机采样的mask,防止模型能够轻易从有规律的图像采样中学到信息。
然后结合ndt2的设计来具体讲mae的细节。
先看encoder,这里的图像就是把k条神经元的活动序列堆叠起来变成类似二维图像的形式,encoder的输入是unmasked的patches,patches就是图像加了mask后的剩余的可见部分,然后和任务信息和图像位置信息等一起输入encoder进行训练。
decoder处理encoder编码得到的tokens,也就是encoder的输出,同时的decoder也要处理masked的token,也就是mask的位置信息,所有的mask token公用一个向量。
作者设计了两个解码器,一个解码出神经活动信号,另一个直接解码出生物的行为,可以直接用于实时预测。
右边的图B讲的是数据的相关性与数据量的关系。随着不同时间,不同任务的数据加如训练数据集,训练数据量在上升,但数据的相关性也在减弱。
然后编码器和解码器都是采用的vit,vision transformer作为主结构,transformer是处理序列信息的结构嘛,vit就把二维图像切割开,分成若干个小块,然后排序,为每个小块加入位置信息,如此就把图像转化为了序列信息。然后把每个小块经过线性映射或者flatten操作后输入transformer encoder中,内部结构就是经典的多头注意力机制。如右图所示。模型大致设计思路就讲完了。
然后是ndt2的一些设计思想。ndt1是直接把单条神经元的活动序列输入transformer进行学习,在不同任务中同一个个神经元的活动模式可能并不相同,所以每次只输入一个神经元进行学习可能会影响效果,但一次把所有的数据都序列化输入模型在计算上是行不通的,所以ndt2的作者就折中,把k条记录叠成二维图像,用vit的思想去处理数据。
右图是ndt2的数据组成,一半是人类数据,一半是猴子做不同任务时的数据。而且还用到了未公开的临床试验数据。感觉这个任务数据集是个问题。现有的能用的数据集并不多。而且似乎只能找到人和猴子这两个物种的数据集。
接下来讲实验。a图是在sorted data上进行的预训练的效果,b图是unsorted数据集上的结果,sort与否的区别就是,sort是按照对应刺激反射的波形分类后的。先看右边,这里的single-session应该代表的是单个神经元在不同时间段内的预测表现,这里的横杠不是连字符的意思,我的理解是single和multi指的是神经元的数量,session指的是时间,subject是任务类型,task是具体的任务,single-session指的就是单条神经元在不同时间做同一个动作,也就是scratch这个动作时的表现。r平方是统计学习里的决定系数,越大越好,nll是负的对数似然,越小越好。所以说越靠近左上角的点表现越好。可以看到ndt在两种类型的数据集中都是在multi-session上表现最好。也就是
改进cae
对于ndt2的主要结构,作者用的是mae,mae的主要思想刚讲过,就是encoder只学习可见部分,decoder输入为mask的信息和encoder的输出。有一个比较明显的问题就是encoder学习表征不充分,因为在训练的过程中decoder学习了encoder编码的特征,在预训练的过程中相当于也会对decoder进行优化,也就是说原始数据集的一部分表征是在decoder中被学习的。当适应下游任务或进行微调时,decoder就被扔了,此时的encoder存在着表征学习不充分的可能性。
cae的做法是在encoder和decoder之间加了一层lantern regressor,是一个在潜在特征空间中捕捉语义的模块。在潜在表征空间中基于 visible patches 去预测 masked patches,而非像 MAE 那样直接将 visible patches 的特征也输入到 Decoder。
再来讲这个regressor。regressor 有两部分输入:一部分是 masked tokens,对应 masked patches 的表征估计,在 attention 中作为 query,它是1个可学习的向量,对于所有图片的所有 masked patches 都一样,可看作是一种特征统计分布;另一部分是 un-masked patches 经过 Encoder 编码后的表征,它们与 maksed tokens 拼接(concat)在一起,作为 cross attention 中的 key & value 角色。
数据处理stitching
总结
整体看下来还有非常多的问题需要解决。