Problem
While doing transfer learning with Mask R-CNN, I needed the loss dict, but even when I called model(image, target) I only got prediction results dumped out, like this:
[{'boxes': tensor([[ 23.0990, 377.5876, 305.9265, 562.2125],
[ 11.2159, 242.4047, 369.8665, 428.1471],
[155.2247, 96.2208, 580.0000, 248.9312],
[ 5.0806, 275.1395, 101.7696, 426.0195],
[ 20.9149, 266.0938, 331.9341, 614.7516],
[ 11.2956, 134.7541, 386.5382, 542.4605],
[ 57.4241, 228.0799, 580.0000, 568.5257],
[492.2102, 366.9597, 499.2005, 382.1749],
[282.4087, 288.1212, 302.7863, 431.6761],
[ 5.5812, 0.0000, 75.4029, 258.5290],
[289.0428, 344.2142, 316.8184, 432.1634],
[509.7052, 182.2039, 579.2445, 325.4382],
[498.1463, 293.8111, 521.6642, 371.7359],
[ 7.9055, 414.9242, 46.2961, 423.6565],
[494.6889, 365.7891, 555.0450, 536.9102],
[ 46.7819, 309.8593, 274.7211, 524.3120],
[ 52.5683, 52.7769, 407.5523, 271.2253],
[228.2233, 96.3352, 346.0507, 455.8103],
[ 94.7609, 113.7210, 282.1691, 583.6635],
[207.3311, 91.4893, 325.0174, 227.4353],
[ 4.4843, 416.8586, 42.4637, 426.7223],
[ 16.9198, 171.7563, 210.8070, 454.9992],
[449.9134, 361.9514, 580.0000, 584.3832],
[216.7644, 590.3699, 543.6064, 620.0000],
[228.1008, 376.4395, 246.4781, 389.9498],
[ 40.7714, 352.4985, 324.2688, 474.4106],
[217.5527, 92.1453, 580.0000, 382.5197],
[ 19.8337, 146.3252, 148.0808, 620.0000],
[224.0854, 240.1711, 285.3246, 540.4830],
[489.8106, 363.9995, 496.3597, 384.0640],
[493.2275, 421.7141, 573.0145, 574.1880],
[ 2.3869, 307.9571, 42.8903, 355.5377],
[ 1.6328, 125.9515, 80.3102, 267.5670],
[127.8913, 309.1691, 247.5878, 577.4413],
[307.7375, 370.2474, 580.0000, 548.8074],
[495.7788, 370.1597, 504.5872, 381.0194],
[ 64.2821, 393.1695, 229.6148, 438.3313],
[ 59.6702, 0.0000, 580.0000, 325.5826],
[397.7323, 399.7869, 575.2143, 471.3532],
[507.9253, 372.6577, 554.8649, 381.1511],
[489.5278, 379.2029, 508.8386, 388.5149],
[203.5288, 113.5201, 296.9296, 505.5798],
[ 9.0284, 310.6707, 437.4026, 487.6050],
[188.9025, 124.0381, 346.7412, 328.8367],
[244.5682, 208.2422, 561.3989, 410.0283],
[334.6353, 405.7608, 452.6267, 476.4648],
[256.0843, 263.3940, 277.3746, 428.9407],
[479.5076, 431.9332, 575.5471, 504.5226],
[470.5065, 380.4351, 576.2295, 517.2133],
[110.8656, 476.6021, 580.0000, 620.0000],
[490.3259, 371.3167, 496.1457, 393.5844],
[403.1678, 254.0272, 411.1543, 316.9261],
[255.5108, 146.3589, 322.1245, 384.9491],
[ 6.3780, 0.0000, 317.1211, 422.5948],
[497.1475, 372.6072, 523.1251, 380.2904],
[448.6642, 143.2047, 580.0000, 365.4225],
[ 35.0093, 392.8786, 217.1972, 620.0000],
[496.7905, 346.5901, 522.3054, 377.3755],
[ 4.0886, 148.9180, 65.3176, 620.0000],
[ 46.0636, 389.9944, 162.5583, 423.9978],
[509.6417, 377.1959, 562.0063, 387.4168],
[247.4742, 246.4299, 270.3476, 413.3393],
[511.0940, 443.1800, 580.0000, 472.2426],
[ 3.4275, 211.6860, 71.5894, 447.3386],
[493.6158, 373.0334, 501.5971, 386.6621],
[491.8876, 380.9993, 505.5616, 390.8474],
[234.0815, 389.1766, 260.1580, 423.7631],
[219.6255, 193.5732, 314.6744, 338.8667],
[277.3090, 341.4319, 306.5358, 445.0870],
[207.4244, 214.3731, 243.2476, 325.8946],
[373.3044, 229.2868, 413.4512, 337.2397],
[386.5466, 376.0895, 575.7539, 441.1465],
[266.3353, 269.5884, 580.0000, 488.0046],
[453.1906, 379.2785, 461.6151, 390.1317],
[530.6581, 234.1370, 580.0000, 268.2140],
[277.4349, 158.0855, 341.2065, 392.5967],
[456.1600, 368.0881, 463.8938, 379.4635],
[202.2168, 182.6642, 287.5092, 305.6806],
[204.0509, 24.7877, 314.8493, 358.7611],
[347.6951, 431.7468, 421.4155, 481.5295],
[ 49.5914, 142.2849, 401.4806, 346.4820],
[271.4739, 53.1259, 328.7435, 86.3012],
[505.5680, 335.6811, 579.0605, 377.9692],
[471.0669, 116.7345, 528.4269, 289.6050],
[396.8392, 255.4695, 406.2174, 316.7011],
[ 91.1898, 179.1382, 193.6283, 327.8245],
[461.2997, 96.6441, 560.4382, 489.6231],
[175.1966, 9.9618, 578.2646, 143.8588],
[372.1141, 610.3049, 522.4876, 620.0000],
[331.8203, 247.1843, 441.0966, 355.5180],
[345.6376, 412.5702, 413.1440, 461.4788],
[508.3494, 105.3887, 557.5039, 383.9501],
[495.4020, 377.5519, 521.1812, 387.7235],
[281.8093, 346.1822, 352.3070, 489.6473],
[276.7371, 247.1393, 321.6436, 510.5059],
[392.1492, 302.0256, 572.5671, 620.0000],
[497.5034, 374.3366, 526.1511, 383.4103],
[542.1813, 380.0844, 580.0000, 391.6952],
[496.7894, 335.7095, 512.9001, 373.5984],
[490.6134, 371.9872, 512.5731, 380.0279]], device='cuda:1'), 'labels': tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1], device='cuda:1'), 'scores': tensor([0.9183, 0.9150, 0.8968, 0.8913, 0.8883, 0.8803, 0.8702, 0.8692, 0.8649,
0.8634, 0.8566, 0.8565, 0.8558, 0.8552, 0.8541, 0.8529, 0.8521, 0.8502,
0.8457, 0.8447, 0.8443, 0.8442, 0.8439, 0.8406, 0.8364, 0.8328, 0.8325,
0.8310, 0.8301, 0.8275, 0.8269, 0.8267, 0.8266, 0.8261, 0.8259, 0.8253,
0.8253, 0.8248, 0.8242, 0.8235, 0.8207, 0.8193, 0.8185, 0.8171, 0.8166,
0.8153, 0.8148, 0.8147, 0.8143, 0.8133, 0.8121, 0.8110, 0.8108, 0.8107,
0.8106, 0.8082, 0.8080, 0.8075, 0.8074, 0.8066, 0.8056, 0.8043, 0.8042,
0.8040, 0.8036, 0.8034, 0.8034, 0.8027, 0.8018, 0.8018, 0.8017, 0.8015,
0.8012, 0.8001, 0.7997, 0.7991, 0.7977, 0.7973, 0.7958, 0.7958, 0.7957,
0.7954, 0.7953, 0.7937, 0.7936, 0.7919, 0.7911, 0.7903, 0.7900, 0.7899,
0.7897, 0.7883, 0.7873, 0.7869, 0.7868, 0.7863, 0.7859, 0.7855, 0.7840,
0.7833], device='cuda:1'), 'masks': tensor([[[[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
...,
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.]]],
[[[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
...,
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.]]],
[[[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
...,
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.]]],
...,
[[[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
...,
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.]]],
[[[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
...,
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.]]],
[[[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
...,
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.]]]], device='cuda:1')}]
Solution
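Calling model.train() in __init__ had no effect, but calling self.model.train() at every step made the loss dict come out correctly. This is most likely because torchvision detection models return losses only in training mode, and PyTorch Lightning switches the module to eval mode before the validation loop, which undoes the call made in __init__.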
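To illustrate why the call in __init__ does not stick, here is a minimal sketch in plain Python with no PyTorch involved. ToyDetector, toy_validation_loop, step_without_fix and step_with_fix are hypothetical names made up for this illustration. ToyDetector only imitates the way a torchvision detection model returns a loss dict in training mode and a list of predictions otherwise, and toy_validation_loop imitates a framework that switches to eval mode before running a validation step.

```python
class ToyDetector:
    """Toy stand-in for a detection model (hypothetical, not the real API)."""

    def __init__(self):
        self.training = True

    def train(self, mode=True):
        self.training = mode
        return self

    def eval(self):
        return self.train(False)

    def __call__(self, images, targets=None):
        if self.training:
            # Training mode: one value per loss term.
            return {"loss_classifier": 0.0, "loss_mask": 0.0}
        # Eval mode: one prediction dict per image; targets are ignored.
        return [{"boxes": [], "labels": [], "scores": []} for _ in images]


def toy_validation_loop(model, step_fn, batch):
    """Imitates a framework switching to eval mode before validation."""
    model.eval()
    return step_fn(model, batch)


def step_without_fix(model, batch):
    images, targets = batch
    return model(images, targets)


def step_with_fix(model, batch):
    model.train()  # re-enable training mode inside the step
    images, targets = batch
    return model(images, targets)


model = ToyDetector()
model.train()  # set once up front, as in __init__
batch = (["img0", "img1"], [{}, {}])

out_before = toy_validation_loop(model, step_without_fix, batch)
out_after = toy_validation_loop(model, step_with_fix, batch)
print(type(out_before).__name__)  # list
print(type(out_after).__name__)   # dict
```

The mode set up front is overwritten by the time the step runs, so only the call placed inside the step has any effect.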
def validation_step(self, batch, batch_idx):
    # The module is in eval mode during validation, so switch back to
    # training mode to make the model return a loss dict.
    self.model.train()
    images, targets = batch
    x = list(images)
    # Build one target dict per image from the batched targets.
    y = []
    for idx in range(len(x)):
        target = {}  # renamed from `dict` to avoid shadowing the builtin
        for key in self.keys:
            target[key] = targets[key][idx]
        y.append(target)
    # y = [{k: v for k, v in t.items()} for t in targets]
    loss_dict = self.model(x, y)
    print(loss_dict)
    losses = sum(loss for loss in loss_dict.values())
    self.log("validation_loss", loss_dict, on_step=True, on_epoch=True, sync_dist=True)
    return losses
'''
{'loss_classifier': tensor(0.8349, device='cuda:1'), 'loss_box_reg': tensor(0.1718, device='cuda:1'), 'loss_mask': tensor(2.8173, device='cuda:1'), 'loss_objectness': tensor(0.0024, device='cuda:1'), 'loss_rpn_box_reg': tensor(0.0054, device='cuda:1')}
'''
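The loop inside validation_step assumes that targets arrives as a dict whose values are indexed per image, while the model expects a list with one dict per image. A minimal sketch of that reshaping as a standalone helper follows. The name split_targets is hypothetical, and plain lists stand in for tensors so the snippet runs without PyTorch.

```python
def split_targets(targets, keys, batch_size):
    """Turn {key: [v0, v1, ...]} into [{key: v0}, {key: v1}, ...]."""
    return [
        {key: targets[key][idx] for key in keys}
        for idx in range(batch_size)
    ]


# Two images, one box and one label each (made-up values).
batched = {
    "boxes": [[[0, 0, 10, 10]], [[5, 5, 20, 20]]],
    "labels": [[1], [1]],
}
per_image = split_targets(batched, keys=("boxes", "labels"), batch_size=2)
print(per_image[1])
```

Inside validation_step this would correspond to y = split_targets(targets, self.keys, len(x)), assuming self.keys lists the target fields the model needs.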