• 首页 首页 icon
  • 工具库 工具库 icon
    • IP查询 IP查询 icon
  • 内容库 内容库 icon
    • 快讯库 快讯库 icon
    • 精品库 精品库 icon
    • 问答库 问答库 icon
  • 更多 更多 icon
    • 服务条款 服务条款 icon

DS-NeRF代码

武飞扬头像
威尔士矮脚狗
帮助1

打开debug:

一.加载数据

llff:
加载稀疏点云以及深度信息:

if args.colmap_depth:
    depth_gts = load_colmap_depth(args.datadir, factor=args.factor, bd_factor=.75)

load_colmap_depth():

这个函数主要读取了两个类型的变量:
记录了 images 和 points 的对应关系,即:
每张图像上的特征点与3D点云中的哪些点匹配,3D点云中的点和图像中的哪些点匹配。

images:
返回值有以下属性:
id:对应数据集中的每张图片
qvec:用于生成旋转矩阵R的向量
tvec:用于生成平移矩阵T的向量
(C2W实现了基变换的功能,包含旋转和平移RT两个矩阵,由它们拼接而成)
camera_id :拍摄的相机,这里可能不同的相机对应这不同的内参数
name:图片的名称
xys:特征点在2D图像上的位置
point3D_ids:每个特征点对应3D点云中的id
(关键点仅仅是图像在特征点匹配时生成的点,最后可能没有生成点云,所以存在大量匹配不上的点,在这里,如果关键点没有与其匹配的3D点云,3D_point_id就会被定义为-1,匹配到3D点云时会标记为3D点云的下标)

images[image_id] = Image(
    id=image_id, qvec=qvec, tvec=tvec,
    camera_id=camera_id, name=image_name,
    xys=xys, point3D_ids=point3D_ids)

points:
id:3D点的下标
xyz:在世界坐标系下的位置
rgb:颜色
error:每个3D点在图像上都有一个与之对应的2D点,3D点投影到图像上后可能会与对应的2D点有一些偏差,因为3D点是由多个2D图像经过匹配生成的,包含了多张图像的信息,投影到一张图像上会有偏移。
image_ids:一个点对应多张图像的编号
point2D_idx:每张图像上特征点对应的坐标(x,y)

points3D[point3D_id] = Point3D(
    id=point3D_id, xyz=xyz, rgb=rgb,
    error=error, image_ids=image_ids,
    point2D_idxs=point2D_idxs)

get_poses:获取每张图片对应的相机位姿(一个 4*4 的c2w矩阵)。
确定near平面与far平面,其中factor为下采样倍数

_, bds_raw, _ = _load_data(basedir, factor=factor) # factor=8 downsamples original imgs by 8x
bds_raw = np.moveaxis(bds_raw, -1, 0).astype(np.float32)

(这部分好像不对,我再研究一下)
nerf使用的是NDC坐标系,要将所有坐标点归一化,所以这里根据bds中提供的场景信息和缩放的比例确定缩放系数sc的大小,初始值为1(NDC)。

sc:设置缩放比例:表示场景的大小
默认为1,sc 根据下采样倍数和近平面的最小值有关。

sc = 1. if bd_factor is None else 1./(bds_raw.min() * bd_factor)

如果找到平面上有一点与3D空间中一个点云对应,将图像上的3D点云转到相机坐标系下并获取它的深度。

point3D = points[id_3D].xyz
depth = (poses[id_im-1,:3,2].T @ (point3D - poses[id_im-1,:3,3])) * sc
if depth < bds_raw[id_im-1,0] * sc or depth > bds_raw[id_im-1,1] * sc:
    continue

接着,保存误差,权重(论文中的公式),深度值,坐标值。
最后放入存放每张图片信息的字典中。

err = points[id_3D].error
weight = 2 * np.exp(-(err/Err_mean)**2)
depth_list.append(depth)
coord_list.append(point2D/factor)
weight_list.append(weight)
if len(depth_list) > 0:
    print(id_im, len(depth_list), np.min(depth_list), np.max(depth_list), np.mean(depth_list))
    data_list.append({"depth":np.array(depth_list), "coord":np.array(coord_list), "error":np.array(weight_list)})
else:
    print(id_im, len(depth_list))

接下来就是加载训练数据,和nerf中的一样。

images, poses, bds, render_poses, i_test = load_llff_data(args.datadir, args.factor,
                                                          recenter=True, bd_factor=.75,
                                                          spherify=args.spherify)

二.创建模型

创建nerf模型,一切都和nerf-torch代码一样

三.获取训练数据

和nerf一样,获取每张图片的射线,对应的方法是从相机原点向每个像素的中心发射一条射线,获取方向向量再乘上每张图片对应的C2W矩阵,最终生成世界坐标系下的射线,并将相机原点广播到和Rays一样的维度,并将它们拼接起来,最终为形状为(o,d)的张量。之后将每个像素的RGB信息并入这个张量,生成rays_rgb(o,d,rgb)。
只是在训练的时候加入了rays_depth表示用于深度监督的射线,这就利用到了之前提到的depth_gt,存储了每张图片中与3D点云相关的像素点。将每个深度点以及权重深度值拼接到一起。形成一个形状为 (N,(o,d,weight,depth_value),3) 的张量。

for i in i_train:
    rays_depth = np.stack(get_rays_by_coord_np(H, W, focal, poses[i,:3,:4], depth_gts[i]['coord']), axis=0) # 2 x N x 3
    # print(rays_depth.shape)
    rays_depth = np.transpose(rays_depth, [1,0,2])
    depth_value = np.repeat(depth_gts[i]['depth'][:,None,None], 3, axis=2) # N x 1 x 3
    weights = np.repeat(depth_gts[i]['error'][:,None,None], 3, axis=2) # N x 1 x 3
    rays_depth = np.concatenate([rays_depth, d公式epth_value, weights], axis=1) # N x 4 x 3
    rays_depth_list.append(rays_depth)

这样就获取了用于深度监督的每条射线的原点,方向,权重以及深度。将射线拼接起来再打乱。

rays_depth = np.concatenate(rays_depth_list, axis=0)
rays_depth = rays_depth.astype(np.float32)
np.random.shuffle(rays_depth)

四.训练

depth:(B,(o,d,w,depth),3) -> ((o,d) B,3 ),(w,B,3),(depth,B,3),将ray_batch拆分成三组,射线方向,权重,深度值。

if args.colmap_depth:
    # batch_depth = rays_depth[i_batch:i_batch N_rand]
    try:
        batch_depth = next(raysDepth_iter).to(device)
    except StopIteration:
        raysDepth_iter = iter(DataLoader(RayDataset(rays_depth), batch_size = N_depth, shuffle=True, num_workers=0))
        batch_depth = next(raysDepth_iter).to(device)
    batch_depth = torch.transpose(batch_depth, 0, 1)
    batch_rays_depth = batch_depth[:2] # 2 x B x 3
    target_depth = batch_depth[2,:,0] # B
    ray_weights = batch_depth[3,:,0]

再将rgb射线和深度射线拼接起来,输入到渲染的网络中。

if args.colmap_depth:
    N_batch = batch_rays.shape[1]
    batch_rays = torch.cat([batch_rays, batch_rays_depth], 1) # (2, 2 * N_rand, 3)
render_rays:渲染光线

生成一个(2batch,11(如果没有视角信息,这个维度为8)),包含了深度和RGB的所有射线。
定义了near(2
batch,1),far平面(2batch,1),拼接进张量中。
(4096,11)<—>(2
batch,o d view near far)

    if c2w is not None:
        # special case to render full image
        rays_o, rays_d = get_rays(H, W, focal, c2w)
    else:
        # use provided ray batch
        rays_o, rays_d = rays

    if use_viewdirs:
        # provide ray directions as input
        viewdirs = rays_d
        if c2w_staticcam is not None:
            # special case to visualize effect of viewdirs
            rays_o, rays_d = get_rays(H, W, focal, c2w_staticcam)
        viewdirs = viewdirs / torch.norm(viewdirs, dim=-1, keepdim=True)
        viewdirs = torch.reshape(viewdirs, [-1,3]).float()

    sh = rays_d.shape # [..., 3]
    if ndc:
        # for forward facing scenes
        rays_o, rays_d = ndc_rays(H, W, focal, 1., rays_o, rays_d)

    # Create ray batch
    rays_o = torch.reshape(rays_o, [-1,3]).float()
    rays_d = torch.reshape(rays_d, [-1,3]).float()

    near, far = near * torch.ones_like(rays_d[...,:1]), far * torch.ones_like(rays_d[...,:1])
    rays = torch.cat([rays_o, rays_d, near, far], -1) # B x 8
    if depths is not None:
        rays = torch.cat([rays, depths.reshape(-1,1)], -1)
    if use_viewdirs:
        rays = torch.cat([rays, viewdirs], -1)

计算深度的方法和计算rgb的方法相同,根据射线上的权重在乘上在射线上的位置。

rgb_map = torch.sum(weights[...,None] * rgb, -2)  # [N_rays, 3]
depth_map = torch.sum(weights * z_vals, -1)

返回时,返回射线深度和colmap稀疏点云对应的深度。

depth, depth_col = depth[:N_batch], depth[N_batch:]

稀疏点云深度loss有两个选项,第一个是mse就是传统的均方误差;
第二个为相对性loss,计算得到的深度与提供的gt深度之差占真实深度的比例,再平方。
(现在论文中的损失被换为了KL散度)

elif args.relative_loss:
    depth_loss = torch.mean(((depth_col - target_depth) / target_depth)**2)
else:
    depth_loss = img2mse(depth_col, target_depth)

loss 的公式
rgbloss depthloss sigmaloss(体积密度损失)

loss = img_loss   args.depth_lambda * depth_loss   args.sigma_lambda * sigma_loss

DONE

这篇好文章是转载于:学新通技术网

  • 版权申明: 本站部分内容来自互联网,仅供学习及演示用,请勿用于商业和其他非法用途。如果侵犯了您的权益请与我们联系,请提供相关证据及您的身份证明,我们将在收到邮件后48小时内删除。
  • 本站站名: 学新通技术网
  • 本文地址: /boutique/detail/tanhgejajc
系列文章
更多 icon
同类精品
更多 icon
继续加载