目的:

开集目标检测器 检测任意对象 支持类别名称或指代表达进行人机交互

方法:

image-20240730133027550

特征提取和融合:

对于图像,用Swin Transformer等骨干网络提取多尺度图像特征,使用Bert等网络提取文本特征

在提取了图像和文本特征后,输入到特征增强其进行跨模态的特征融合

特征融合部分,对图像特征和文本特征采用了自注意力,并且设计了文本到图像和图像到文本的交叉注意力机制

image-20240730133917368

基于语言的查询选择:

利用输入文本来直到对象检测,选择与输入文本更相关的特征作为解码器的查询

# image_features: (bs, num_img_tokens, ndim)
# text_features: (bs, num_text_tokens, ndim)
# num_query: int

def language_guided_query_selection(image_features, text_features, num_query):
    """
    基于语言的查询选择算法

    参数:
        image_features: 图像特征, 形状为 (bs, num_img_tokens, ndim)
        text_features: 文本特征, 形状为 (bs, num_text_tokens, ndim)
        num_query: 查询数量

    返回:
        topk_proposals_idx: 前K个查询的索引, 形状为 (bs, num_query)
    """
    # 计算图像特征和文本特征之间的相关性得分
    logits = torch.einsum("bic,btc->bit", image_features, text_features)  # 形状为 (bs, num_img_tokens, num_text_tokens)

    # 对每个图像特征取最大值以获取相关性得分
    logits_per_image_feature = logits.max(dim=-1)[0]  # 形状为 (bs, num_img_tokens)

    # 选择得分最高的特征索引
    topk_proposals_idx = torch.topk(logits_per_image_feature, num_query, dim=1)[1]  # 形状为 (bs, num_query)

    return topk_proposals_idx

选择K个得分最高的特征索引来初始化解码器查询,同样的,位置部分是动态锚框,内容查询在训练过程中是可学习的

跨模态解码器:

跨模态查询会通过一个自注意力层,一个与图像的交叉注意力层,一个与文本的交叉注意力层,最后通过FFN,与DINO相比,增加了一个文本的交叉注意力层

子句级文本特征:

文本提示可以用句子或者词语表示,句子回提取短语并丢弃其他词语,可以消除词语之间的影响,会丢失细粒度信息,后者可以一次前向传递多个类别名称,但是回引入不必要的依赖,本文采用了注意力掩码,消除了不同类别之间的影响,同时保留每个词的特征

image-20240730212140238

损失函数设计:

与DINO类似,设计预测和真实值之间的二分匹配

总结:

1.添加文本特征,并且在提取特征时将文本特征和图像特征融合

2.用相关度高的融合特征来指导查询,并且后续与文本图像执行交叉注意力,充分将文本信息融入进去

3.设计注意力掩码,保留每个词特征的同时消除不同类别之间的影响