该论文旨在利用无监督自适应技术来解决轨迹预测模型的跨域问题,提出了一个被称作Straj的模型。该模型整体架构由三个部分组成,具体细节见下文。
1.模型预训练
生成能够泛化到跨地理场景的高质量初始伪标签
- 输入:
- 已标注的源域数据
- 未标注的目标域数据
- 感兴趣智能体过去 $L_p$ 步的轨迹坐标 $P_a$
- 其他智能体 $P_o$
- 地图信息 $M$
- 现有轨迹预测器 $\mathcal{P}$
- $ϕ_a$ 智能体特征编码器
- $ϕ_m$ 地图特征编码器
- $ϕ_f$ 融合特征编码器
- 解码器 $φ$
- 过程:
- 特征编码:预测器 $\mathcal{P}$ 首先通过特征编码器将输入数据转换为原始特征
- 智能体特征 $\mathcal{A} = ϕ_a(P_a, P_o)$
- 地图特征 $\mathcal{M} = ϕ_m(P_a, M)$
- 融合特征 $\mathcal{F} = ϕ_f(\mathcal{A}, \mathcal{M})$
- 智能体和地图互补增强
- 根据智能体或车道段与目标智能体的距离计算遮蔽概率
- 对 r%的智能体和车道段进行遮蔽,生成弱增强和强增强后的特征
- 通过权衡策略生成被遮蔽的融合特征 $\mathcal{F_{(W,S)}} = ϕ_f \mathcal{(A_W,M_S)} ,$ 和 $\mathcal{F_{(S,W)}} = ϕ_f \mathcal{(A_S,M_W)}$
- 通过均方误差损失进行特征重建
- 计算被遮蔽的智能体特征 ($\mathcal{A_W}$, $\mathcal{A_S}$)、地图特征 ($\mathcal{M_W}$, $\mathcal{M_S}$) 和融合特征 ($\mathcal{F_{(W,S)}}$, $\mathcal{F_{(S,W)}}$) 与其对应的原始未遮蔽特征 ($\mathcal{A}$, $\mathcal{M}$, $\mathcal{F}$) 之间的均方误差。
- 使模型重建出所有被遮蔽的智能体特征、地图特征和融合特征,使其与原始特征保持一致
- 对于属于源域的智能体除了 $\mathcal{L}_{MSE}$ 之外,还会计算其预测轨迹与真实轨迹之间的预测误差 $\mathcal{L}_{pred}$
- 特征编码:预测器 $\mathcal{P}$ 首先通过特征编码器将输入数据转换为原始特征
- 输出:
- 经过初步训练的预测器 $\mathcal{P}$,对跨地理场景中的异构元素分布更具鲁棒性和泛化能力
- 基于这个预训练的 $\mathcal{P}$,生成目标域的初始伪标签$H_0$
2.伪标签更新
确保伪标签的一致性和稳定性,促进预测器的稳定训练
- 输入:
- 经过预训练的预测器$\mathcal{P}$。
- 第i个epoch下的目标域智能体 $a$ 的伪轨迹$F_i$.
- 历史epoch的伪轨迹集合 $H_{i−1}$(初始为 $H_0$)
- $F_i$ 和历史伪轨迹的置信度分数 $s(F_i)$ 和 $s(ĥ)$
- 预设的一致性阈值 $T_U$ 和置信度阈值 $T_C$
- 用于损失权重计算的缩放参数 $ρt$。
- 过程:
- 计算当前伪轨迹 $F_i$ 与 $H_{i−1}$ 中每个历史伪轨迹的余弦相似度U
- 从 $H_{i−1}$ 中选择与 $F_i$ 余弦相似度最高的历史伪轨迹
- 伪标签更新策略:
- 如果 $F_i$ 与 $ĥ$ 的一致性 $U(F̂i, ĥ)$ 低于 $T_U$,或 $F̂i$ 和 $ĥ$ 中任何一个的置信度低于 $Tc$,则不选择任何伪轨迹进行监督,并将 $F̂i$ 加入 $H_{i−1}$ 形成 $Hi$ 。
- 否则将 $ĥ$ 和 $F̂i$ 都加入 $Hi$,并选择置信度更高的作为更新后的伪轨迹 $b_i$
- 加权损失
$$\mathcal{L}^{tar}_{pred} = exp (U(\hat{F}_i, \hat{h})/ρt )\mathcal{L}_{pred} (\hat{F}_{i+1}, b_i)$$
使用基于一致性水平的指数函数计算目标域的预测损失权重,其目的是:- 关注高一致性的样本,优先让模型从那些更一致的样本中学习,从而提高伪标签的质量和训练的稳定性
- 防止误差累积: 通过一致性加权,降低模型在不一致样本上的学习权重,从而减少伪标签错误传播的风险。
- 输出:
- 每个目标域智能体更新后的伪轨迹 $b_i$,用于在下一个epoch 监督其预测
- 加权的目标域预测损失 $L^{tar}_{pred}$,用于最终的总损失计算.
- 更新后的历史伪轨迹集合 $H_i$.
3.轨迹诱导对比学习
通过增强相似轨迹的跨地理表示紧凑性和增强不同轨迹的表示可分离性,缓解跨地理区域智能体的表示偏差.
- 输入:
- 已更新得伪轨迹 $\hat{F}_a$ (target domain智能体) 或真值轨迹 $F_a$ (source domain智能体,用于减少误差累积)
- 过程:
- 根据轨迹一致性 $U(\hat{F}_a, \hat{F}_j)$ 和置信度,选择与当前智能体 $a$ 相似的轨迹作为正样本 $p$,选择不相似的轨迹作为负样本集 $Sn$
- 如果智能体 $a$ 属于源域,则用其真值轨迹 $Fa$ 替换预测轨迹 $\hat{F}_a$ 来进行对比学习,以减少误差累积.
- 对比损失计算:借鉴 InfoNCE loss,通过余弦相似度 $sim(x, y)$ 计算轨迹诱导对比损失 $L(d1,d2)$
$$\mathcal{L}(d_1,d_2) = −log \frac{exp (sim(F_a,F_p)/ρ_c)}{\sum_{j∈S_n∪p} exp (sim(F_a,F_j)/ρ_c)}$$
智能体 a 被称为锚点智能体,它来自域 d1,智能体 j 是从域 d2中选取的候选智能体,用于构成 a的正样本或负样本。d1 和 d2 可以是源域 Ds 或目标域 Dt,当 d1 = d2 时,a 和 j 来自同一个域,对应于域内对比损失,当 d1 ≠ d2时,a 和 j 来自不同的域,对应于域间对比损失。
分子计算锚点 $F_a$ 与正样本 $F_p$ 之间相似度的指数加权值。分母计算锚点 $F_a$与所有样本之间相似度的指数加权和。当正样本的相似度相对于所有样本的相似度越大时,损失越小。 - 总对比损失 $L_{con}$
- 域内对比损失 $\mathcal{L}(D_s,D_s)$ 和 $\mathcal{L}(D_t,D_t)$
- 域间对比损失 $\mathcal{L}(D_s,D_t)$ 和 $\mathcal{L}(D_t,D_s)$
- 输出:
- 总对比损失 $Lcon$,强制模型学习更具判别性的跨地理表示