Unverified Commit 8eff7416 authored by wawatt's avatar wawatt Committed by GitHub
Browse files

Update losses.py

parent dac8093c
......@@ -167,14 +167,7 @@ def get_loss(end_points):
end_points['trans_loss'] = trans_loss
end_points['axag_loss'] = axag_loss
end_points['total_loss'] = total_loss
TPer_loss = 'TPer_loss_'
APer_loss = 'APer_loss_'
TotPer_loss = 'TotPer_loss_'
for i in range(point_class.size(-1)):
end_points[TPer_loss+CLASS_NAME[i]] = trans_clsLoss[i]
end_points[APer_loss+CLASS_NAME[i]] = axag_clsLoss[i]
end_points[TotPer_loss+CLASS_NAME[i]] = total_clsLoss[i]
return total_loss, end_points
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment