文章介绍:
文章使用了交互式分割的半监督学习方法,对视频目标进行分割。与处理三维数据的3DU-Net等3D网络不同,对视频目标进行分割时,会从每个视频中取出相邻或相近的固定个帧,训练时只使用这几个帧进行训练,测试时将全部帧加载进模型中进行分割,因此网络中test和train部分相差较大。文章中的交互式分割是在分割完成后,根据人为提示内容,对分割进行进一步修改。
代码介绍:
train_SAQ.py:
加载数据时,会根据不同用途加载不同量的数据。train时会加载sample_per_volume的帧,该变量在option中设置,而test时会加载一个volume的全部数据。这个设置也会在网络的初始化时进行,网络创建时设置内部phase为train或test。
net = STM(opt.keydim, opt.valdim, 'train', mode=opt.mode, iou_threshold=opt.iou_threshold)
train_loss = train(trainloader, model=net, criterion=criterion, optimizer=optimizer, epoch=epoch, use_cuda=True, iter_size=opt.iter_size, mode=opt.mode, threshold=opt.iou_threshold) if (epoch + 1) % opt.epoch_per_test == 0: net.module.phase = 'test' test_loss = test(testloader, model=net.module, criterion=criterion, epoch=epoch, use_cuda=True)
在train和test函数中,被选中的帧frame与它的mask分别传入model,计算损失函数
out, quality, ious = model(frame=frames, mask=masks, num_objects=objs, criterion=mask_iou_loss)
model_SAQ.py:
获得数据后,模型的训练过程分为以下几步:
memorize:
具体来说,就是使用上一个帧的分割结果以及内容,指导当前帧的分割,当没有上一个帧时,使用的就是当前帧和当前帧的mask,即当前为参考帧
if t - 1 == 0 or self.mode == 'mask': tmp_mask = mask[idx, t - 1:t] elif self.mode == 'recurrent': tmp_mask = out else: pred_mask = out[0, 1:num_object + 1] iou = mask_iou(pred_mask, mask[idx, t - 1, 1:num_object + 1]) if iou > self.iou_threshold: tmp_mask = out else: tmp_mask = mask[idx, t - 1:t] key, val, _ = self.memorize(frame=frame[idx, t - 1:t], masks=tmp_mask, num_objects=num_object)
在memorize函数中,分别获取帧,标签以及背景,送入内存编码器
frame_batch = [] mask_batch = [] bg_batch = [] # print(' ') # print(num_objects) try: for o in range(1, num_objects + 1): # 1 - no frame_batch.append(frame) mask_batch.append(masks[:, o]) for o in range(1, num_objects + 1): bg_batch.append(torch.clamp(1.0 - masks[:, o], min=0.0, max=1.0)) # make Batch frame_batch = torch.cat(frame_batch, dim=0) mask_batch = torch.cat(mask_batch, dim=0) bg_batch = torch.cat(bg_batch, dim=0) except RuntimeError as re: print(re) print(num_objects) raise re r4, _, _, _ = self.Encoder_M(frame_batch, mask_batch, bg_batch) # no, c, h, w
从编码器中获得特征,并由特征获得键与值
k4, v4 = self.KV_M_r4(memfeat) k4 = k4.permute(0, 2, 3, 1).contiguous().view(num_objects, -1, self.keydim) v4 = v4.permute(0, 2, 3, 1).contiguous().view(num_objects, -1, self.valdim) return k4, v4, r4
segment:
根据得到的key和value,使用Encoder_Q对当前帧进行分割
# segment tmp_key = torch.cat(batch_keys, dim=1) tmp_val = torch.cat(batch_vals, dim=1) logits, ps, r4 = self.segment(frame=frame[idx, t:t + 1], keys=tmp_key, values=tmp_val, num_objects=num_object, max_obj=max_obj) r4s.append(r4) out = torch.softmax(logits, dim=1) tmp_out.append(out)
将所有的分割结果全部堆叠起来,获得batch_out,并根据batch_out获得当前分割的质量以及损失函数,将这三者返回后,一轮训练结束