Unverified Commit 97e8c6c7 authored by Yiyao Wei

black formatter

parent 0ef2d81a
import torch
CLASS_NAME = ["cl0", "cl1", "cl2", "cl3", "cl4"]
def skew_symmetric(axag_unit):
"""
@@ -17,14 +18,31 @@ def skew_symmetric(axag_unit):
sh = axag_unit.shape
axag_unit_exp = torch.unsqueeze(torch.unsqueeze(axag_unit, 2), 3)
row1 = torch.cat(
[
torch.zeros((sh[0], 1, 1), dtype=torch.float64).cuda(),
-axag_unit_exp[:, 2, :, :],
axag_unit_exp[:, 1, :, :],
],
dim=2,
)
row2 = torch.cat(
[
axag_unit_exp[:, 2, :, :],
torch.zeros((sh[0], 1, 1), dtype=torch.float64).cuda(),
-axag_unit_exp[:, 0, :, :],
],
dim=2,
)
row3 = torch.cat(
[
-axag_unit_exp[:, 1, :, :],
axag_unit_exp[:, 0, :, :],
torch.zeros((sh[0], 1, 1), dtype=torch.float64).cuda(),
],
dim=2,
)
axag_unit_ss = torch.cat([row1, row2, row3], dim=1)
return axag_unit_ss
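
# Editor's sketch (not part of the original commit): skew_symmetric(u) builds
# the cross-product matrix [u]_x, so S @ v should reproduce u x v. Assumes a
# CUDA device, since the function hard-codes .cuda():
def _check_skew_symmetric():
    u = torch.tensor([[0.3, -0.5, 0.8]], dtype=torch.float64).cuda()
    v = torch.tensor([[1.0, 2.0, 3.0]], dtype=torch.float64).cuda()
    S = skew_symmetric(u)  # (1, 3, 3)
    lhs = torch.matmul(S, v.unsqueeze(-1)).squeeze(-1)
    assert torch.allclose(lhs, torch.cross(u, v, dim=1))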
@@ -50,18 +68,38 @@ def exponential_map(axag, EPS=1e-2):
theta_pow_6 = theta_sq * theta_sq * theta_sq
theta_pow_8 = theta_sq * theta_sq * theta_sq * theta_sq
term_1 = torch.where(
is_angle_small,
1
- (theta_sq / 6.0)
+ (theta_pow_4 / 120)
- (theta_pow_6 / 5040)
+ (theta_pow_8 / 362880),
torch.sin(theta) / theta,
)
term_2 = torch.where(
is_angle_small,
0.5
- (theta_sq / 24.0)
+ (theta_pow_4 / 720)
- (theta_pow_6 / 40320)
+ (theta_pow_8 / 3628800),
(1 - torch.cos(theta)) / theta_sq,
)
term_1_expand = torch.unsqueeze(torch.unsqueeze(term_1, 1), 2)
term_2_expand = torch.unsqueeze(torch.unsqueeze(term_2, 1), 2)
batch_identity = (
torch.eye(3, dtype=torch.float64)
.unsqueeze(0)
.repeat(axag.shape[0], 1, 1)
.cuda()
)
axag_exp = (
batch_identity
+ torch.mul(term_1_expand, ss)
+ torch.mul(term_2_expand, torch.matmul(ss, ss))
)
return axag_exp
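
# Editor's note (a sketch, not part of this commit): exponential_map appears to
# implement Rodrigues' formula R = I + (sin t / t) * S + ((1 - cos t) / t^2) * S @ S
# with S the skew matrix of axag and t = ||axag||, falling back to the Taylor
# expansions above when t < EPS. Either branch should yield a valid rotation;
# a quick check (again assumes a CUDA device):
def _check_exponential_map():
    axag = torch.tensor([[0.1, -0.2, 0.3]], dtype=torch.float64).cuda()
    R = exponential_map(axag)  # (1, 3, 3)
    eye = torch.eye(3, dtype=torch.float64).cuda().unsqueeze(0)
    assert torch.allclose(torch.matmul(R, R.transpose(1, 2)), eye, atol=1e-9)
    assert torch.allclose(torch.det(R), torch.ones(1, dtype=torch.float64).cuda())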
@@ -109,10 +147,12 @@ def logarithm(R, b_deal_with_sym=False, EPS=1e-2):
theta_pow_6 = theta_pow_2 * theta_pow_4
# ss = (R - tf.matrix_transpose(R))
ss = R - R.transpose(1, 2)
mul_expand = torch.where(
is_angle_small,
0.5 + (theta_pow_2 / 12) + (7 * theta_pow_4 / 720) + (31 * theta_pow_6 / 30240),
theta / (2 * torch.sin(theta)),
)
if b_deal_with_sym:
log_R = torch.unsqueeze(torch.unsqueeze(mul_expand, 2), 3) * ss
else:
@@ -122,12 +162,12 @@ def get_rotation_error(pred, label):
def get_rotation_error(pred, label):
"""
Return the (mean) rotation error as the angular distance in SO(3)
:param pred: B,3 tensor
:param label: B,3 tensor
:return: mean error (scalar) and per-sample errors
"""
pred_expMap = exponential_map(pred)
label_expMap = exponential_map(label)
@@ -145,33 +185,37 @@ def get_translation_error(pred, label):
def get_loss(end_points):
translate_pred = end_points["translate_pred"]
translate_label = end_points["translate_label"]
axag_pred = end_points["axag_pred"]
axag_label = end_points["axag_label"]
point_class = end_points["point_clouds"][:, 0, 3:].double()
trans_loss, trans_perLoss = get_translation_error(
translate_pred.double(), translate_label.double()
)
axag_loss, axag_perLoss = get_rotation_error(
axag_pred.double(), axag_label.double()
)
total_loss = 10 * trans_loss + axag_loss
total_perloss = 10 * trans_perLoss + axag_perLoss
trans_perLoss = torch.unsqueeze(trans_perLoss, dim=0).t() * point_class
trans_clsLoss = torch.sum(trans_perLoss, dim=0) / torch.sum(point_class, dim=0)
axag_perLoss = torch.unsqueeze(axag_perLoss, dim=0).t() * point_class
axag_clsLoss = torch.sum(axag_perLoss, dim=0) / torch.sum(point_class, dim=0)
total_perloss = torch.unsqueeze(total_perloss, dim=0).t() * point_class
total_clsLoss = torch.sum(total_perloss, dim=0) / torch.sum(point_class, dim=0)
end_points["trans_loss"] = trans_loss
end_points["axag_loss"] = axag_loss
end_points["total_loss"] = total_loss
return total_loss, end_points
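
# Editor's sketch (not in the commit; shapes are assumptions read off the code
# above): get_loss wants B,3 pose tensors plus a point cloud whose trailing
# channels hold a per-point one-hot class, consumed at point_clouds[:, 0, 3:].
def _demo_get_loss(B=4, N=256):
    cls_idx = torch.randint(len(CLASS_NAME), (B,))
    one_hot = torch.eye(len(CLASS_NAME))[cls_idx].unsqueeze(1).repeat(1, N, 1)
    end_points = {
        "translate_pred": torch.rand(B, 3).cuda(),
        "translate_label": torch.rand(B, 3).cuda(),
        "axag_pred": torch.rand(B, 3).cuda(),
        "axag_label": torch.rand(B, 3).cuda(),
        "point_clouds": torch.cat([torch.rand(B, N, 3), one_hot], dim=2).cuda(),
    }
    total_loss, end_points = get_loss(end_points)
    return total_loss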
if __name__ == "__main__":
label = torch.tensor([[0.6977, 0.8248, 0.9367]], dtype=torch.float64).cuda()
pred = torch.tensor([[-2.100418, -2.167796, 0.2733]], dtype=torch.float64).cuda()
# print(torch.matmul(pred, label))
@@ -3,7 +3,6 @@ import torch.utils.data
import torch.nn.functional as F
class CloudPose_trans(nn.Module):
def __init__(self, channel=3, num_class=5):
super(CloudPose_trans, self).__init__()
@@ -101,28 +100,31 @@ class CloudPose_all(nn.Module):
self.rot = CloudPose_rot(self.channel, self.num_class)
def forward(self, input):
point_clouds = input["point_clouds"]
point_clouds_tp = point_clouds.transpose(1, 2) # b 8 256
base_xyz = torch.mean(point_clouds_tp[:, : self.channel, :], dim=2)
point_clouds_res = point_clouds_tp[:, : self.channel, :] - base_xyz.unsqueeze(
-1
) # b 3 1
point_clouds_res_with_cls = torch.cat(
(point_clouds_res, point_clouds_tp[:, self.channel :, :]), dim=1
)  # channel columns first, class one-hot last
t, ind_t = self.trans(point_clouds_res_with_cls)
r, ind_r = self.rot(point_clouds_res_with_cls) # better than point_clouds_tp
# r, ind_r = self.rot(point_clouds_tp)
end_points = {}
end_points["translate_pred"] = t + base_xyz
end_points["axag_pred"] = r
return end_points
if __name__ == "__main__":
sim_data = torch.rand(32, 2500, 3 + 5)
input = {}
input["point_clouds"] = sim_data
feat = CloudPose_all(3, 5)
end_points = feat(input)
print(end_points["translate_pred"], end_points["axag_pred"])
""" Modified based on Ref: https://github.com/erikwijmans/Pointnet2_PyTorch """
from typing import List, Tuple
import torch
import torch.nn as nn
class SharedMLP(nn.Sequential):
def __init__(
self,
args: List[int],
*,
bn: bool = False,
activation=nn.ReLU(inplace=True),
preact: bool = False,
first: bool = False,
name: str = ""
):
super().__init__()
for i in range(len(args) - 1):
self.add_module(
name + "layer{}".format(i),
Conv2d(
args[i],
args[i + 1],
bn=(not first or not preact or (i != 0)) and bn,
activation=activation
if (not first or not preact or (i != 0))
else None,
preact=preact,
),
)
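
# Editor's sketch (not in the original commit): SharedMLP stacks 1x1 Conv2d
# blocks over the channel axis, so a (B, C, N, 1) tensor keeps its N points
# while channels are remapped, e.g.:
#   mlp = SharedMLP([8, 64, 128, 1024], bn=True)
#   out = mlp(torch.rand(2, 8, 256, 1))  # -> (2, 1024, 256, 1)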
class _BNBase(nn.Sequential):
def __init__(self, in_size, batch_norm=None, name=""):
super().__init__()
self.add_module(name + "bn", batch_norm(in_size))
@@ -44,40 +42,36 @@ class _BNBase(nn.Sequential):
class BatchNorm1d(_BNBase):
def __init__(self, in_size: int, *, name: str = ""):
super().__init__(in_size, batch_norm=nn.BatchNorm1d, name=name)
class BatchNorm2d(_BNBase):
def __init__(self, in_size: int, name: str = ""):
super().__init__(in_size, batch_norm=nn.BatchNorm2d, name=name)
class BatchNorm3d(_BNBase):
def __init__(self, in_size: int, name: str = ""):
super().__init__(in_size, batch_norm=nn.BatchNorm3d, name=name)
class _ConvBase(nn.Sequential):
def __init__(
self,
in_size,
out_size,
kernel_size,
stride,
padding,
activation,
bn,
init,
conv=None,
batch_norm=None,
bias=True,
preact=False,
name="",
):
super().__init__()
@@ -88,7 +82,7 @@ class _ConvBase(nn.Sequential):
kernel_size=kernel_size,
stride=stride,
padding=padding,
bias=bias,
)
init(conv_unit.weight)
if bias:
@@ -102,37 +96,36 @@ class _ConvBase(nn.Sequential):
if preact:
if bn:
self.add_module(name + "bn", bn_unit)
if activation is not None:
self.add_module(name + "activation", activation)
self.add_module(name + "conv", conv_unit)
if not preact:
if bn:
self.add_module(name + "bn", bn_unit)
if activation is not None:
self.add_module(name + "activation", activation)
class Conv1d(_ConvBase):
def __init__(
self,
in_size: int,
out_size: int,
*,
kernel_size: int = 1,
stride: int = 1,
padding: int = 0,
activation=nn.ReLU(inplace=True),
bn: bool = False,
init=nn.init.kaiming_normal_,
bias: bool = True,
preact: bool = False,
name: str = ""
):
super().__init__(
in_size,
@@ -147,26 +140,25 @@ class Conv1d(_ConvBase):
batch_norm=BatchNorm1d,
bias=bias,
preact=preact,
name=name,
)
class Conv2d(_ConvBase):
def __init__(
self,
in_size: int,
out_size: int,
*,
kernel_size: Tuple[int, int] = (1, 1),
stride: Tuple[int, int] = (1, 1),
padding: Tuple[int, int] = (0, 0),
activation=nn.ReLU(inplace=True),
bn: bool = False,
init=nn.init.kaiming_normal_,
bias: bool = True,
preact: bool = False,
name: str = ""
):
super().__init__(
in_size,
@@ -181,26 +173,25 @@ class Conv2d(_ConvBase):
batch_norm=BatchNorm2d,
bias=bias,
preact=preact,
name=name,
)
class Conv3d(_ConvBase):
def __init__(
self,
in_size: int,
out_size: int,
*,
kernel_size: Tuple[int, int, int] = (1, 1, 1),
stride: Tuple[int, int, int] = (1, 1, 1),
padding: Tuple[int, int, int] = (0, 0, 0),
activation=nn.ReLU(inplace=True),
bn: bool = False,
init=nn.init.kaiming_normal_,
bias: bool = True,
preact: bool = False,
name: str = ""
):
super().__init__(
in_size,
@@ -215,22 +206,21 @@ class Conv3d(_ConvBase):
batch_norm=BatchNorm3d,
bias=bias,
preact=preact,
name=name,
)
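
# Editor's note (a sketch): Conv1d/Conv2d/Conv3d only swap the underlying
# nn.ConvNd and BatchNormNd pair; e.g. Conv1d(8, 64, bn=True) runs
# conv -> bn -> ReLU on a (B, 8, N) tensor, preserving N since kernel_size=1.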
class FC(nn.Sequential):
def __init__(
self,
in_size: int,
out_size: int,
*,
activation=nn.ReLU(inplace=True),
bn: bool = False,
init=None,
preact: bool = False,
name: str = ""
):
super().__init__()
@@ -242,19 +232,19 @@ class FC(nn.Sequential):
if preact:
if bn:
self.add_module(name + "bn", BatchNorm1d(in_size))
if activation is not None:
self.add_module(name + "activation", activation)
self.add_module(name + "fc", fc)
if not preact:
if bn:
self.add_module(name + "bn", BatchNorm1d(out_size))
if activation is not None:
self.add_module(name + "activation", activation)
def set_bn_momentum_default(bn_momentum):
@@ -266,16 +256,10 @@ def set_bn_momentum_default(bn_momentum):
class BNMomentumScheduler(object):
def __init__(self, model, bn_lambda, last_epoch=-1, setter=set_bn_momentum_default):
if not isinstance(model, nn.Module):
raise RuntimeError(
"Class '{}' is not a PyTorch nn Module".format(
type(model).__name__
)
"Class '{}' is not a PyTorch nn Module".format(type(model).__name__)
)
self.model = model
@@ -20,31 +20,79 @@ from tensorboardX import SummaryWriter
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
ROOT_DIR = BASE_DIR
sys.path.append(os.path.join(ROOT_DIR, "utils"))
sys.path.append(os.path.join(ROOT_DIR, "mydataset"))
sys.path.append(os.path.join(ROOT_DIR, "models"))
from pt_utils import BNMomentumScheduler
from tf_visualizer import Visualizer as TfVisualizer
parser = argparse.ArgumentParser()
parser.add_argument(
"--dataset",
default="mydataset",
help="Dataset name. sunrgbd or scannet. [default: sunrgbd]",
)
parser.add_argument(
"--checkpoint_path", default=None, help="Model checkpoint path [default: None]"
)
parser.add_argument(
"--log_dir", default="log", help="Dump dir to save model checkpoint [default: log]"
)
parser.add_argument(
"--num_point",
type=int,