Volumetric memory network for interactive medical image segmentation代码分析

文章介绍:

文章使用了交互式分割的半监督学习方法,对视频目标进行分割。与处理三维数据的3DU-Net等3D网络不同,对视频目标进行分割时,会从每个视频中取出相邻或相近的固定个帧,训练时只使用这几个帧进行训练,测试时将全部帧加载进模型中进行分割,因此网络中test和train部分相差较大。文章中的交互式分割是在分割完成后,根据人为提示内容,对分割进行进一步修改。

代码介绍:

train_SAQ.py:

加载数据时,会根据不同用途加载不同量的数据。train时会加载sample_per_volume的帧,该变量在option中设置,而test时会加载一个volume的全部数据。这个设置也会在网络的初始化时进行,网络创建时设置内部phase为train或test。

net = STM(opt.keydim, opt.valdim, 'train',
              mode=opt.mode, iou_threshold=opt.iou_threshold)
train_loss = train(trainloader,
                           model=net,
                           criterion=criterion,
                           optimizer=optimizer,
                           epoch=epoch,
                           use_cuda=True,
                           iter_size=opt.iter_size,
                           mode=opt.mode,
                           threshold=opt.iou_threshold)

        if (epoch + 1) % opt.epoch_per_test == 0:
            net.module.phase = 'test'
            test_loss = test(testloader,
                             model=net.module,
                             criterion=criterion,
                             epoch=epoch,
                             use_cuda=True)

在train和test函数中,被选中的帧frame与它的mask分别传入model,计算损失函数

out, quality, ious = model(frame=frames, mask=masks, num_objects=objs, criterion=mask_iou_loss)

model_SAQ.py:

获得数据后,模型的训练过程分为以下几步:

memorize:

具体来说,就是使用上一个帧的分割结果以及内容,指导当前帧的分割,当没有上一个帧时,使用的就是当前帧和当前帧的mask,即当前为参考帧

                    if t - 1 == 0 or self.mode == 'mask':
                        tmp_mask = mask[idx, t - 1:t]
                    elif self.mode == 'recurrent':
                        tmp_mask = out
                    else:
                        pred_mask = out[0, 1:num_object + 1]
                        iou = mask_iou(pred_mask, mask[idx, t - 1, 1:num_object + 1])

                        if iou > self.iou_threshold:
                            tmp_mask = out
                        else:
                            tmp_mask = mask[idx, t - 1:t]

                    key, val, _ = self.memorize(frame=frame[idx, t - 1:t], masks=tmp_mask,
                                                num_objects=num_object)

在memorize函数中,分别获取帧,标签以及背景,送入内存编码器

frame_batch = []
        mask_batch = []
        bg_batch = []
        # print('
')
        # print(num_objects)
        try:
            for o in range(1, num_objects + 1):  # 1 - no
                frame_batch.append(frame)
                mask_batch.append(masks[:, o])

            for o in range(1, num_objects + 1):
                bg_batch.append(torch.clamp(1.0 - masks[:, o], min=0.0, max=1.0))

            # make Batch
            frame_batch = torch.cat(frame_batch, dim=0)
            mask_batch = torch.cat(mask_batch, dim=0)
            bg_batch = torch.cat(bg_batch, dim=0)
        except RuntimeError as re:
            print(re)
            print(num_objects)
            raise re

        r4, _, _, _ = self.Encoder_M(frame_batch, mask_batch, bg_batch)  # no, c, h, w

从编码器中获得特征,并由特征获得键与值

k4, v4 = self.KV_M_r4(memfeat)
        k4 = k4.permute(0, 2, 3, 1).contiguous().view(num_objects, -1, self.keydim)
        v4 = v4.permute(0, 2, 3, 1).contiguous().view(num_objects, -1, self.valdim)

        return k4, v4, r4

segment:

根据得到的key和value,使用Encoder_Q对当前帧进行分割

# segment
                    tmp_key = torch.cat(batch_keys, dim=1)
                    tmp_val = torch.cat(batch_vals, dim=1)
                    logits, ps, r4 = self.segment(frame=frame[idx, t:t + 1], keys=tmp_key, values=tmp_val,
                                                  num_objects=num_object, max_obj=max_obj)

                    r4s.append(r4)
                    out = torch.softmax(logits, dim=1)
                    tmp_out.append(out)

将所有的分割结果全部堆叠起来,获得batch_out,并根据batch_out获得当前分割的质量以及损失函数,将这三者返回后,一轮训练结束