본문 바로가기

인공지능/Python

model 학습 시 loss 대신 predict 결과가 나올 경우

Problem

mask rcnn transfer learning 중에 loss dict를 구하는 상황이었는데, model(image, target)을 넣어도 다음처럼 predict 결과만 나왔다.

{'boxes': tensor([[ 23.0990, 377.5876, 305.9265, 562.2125],
        [ 11.2159, 242.4047, 369.8665, 428.1471],
        [155.2247,  96.2208, 580.0000, 248.9312],
        [  5.0806, 275.1395, 101.7696, 426.0195],
        [ 20.9149, 266.0938, 331.9341, 614.7516],
        [ 11.2956, 134.7541, 386.5382, 542.4605],
        [ 57.4241, 228.0799, 580.0000, 568.5257],
        [492.2102, 366.9597, 499.2005, 382.1749],
        [282.4087, 288.1212, 302.7863, 431.6761],
        [  5.5812,   0.0000,  75.4029, 258.5290],
        [289.0428, 344.2142, 316.8184, 432.1634],
        [509.7052, 182.2039, 579.2445, 325.4382],
        [498.1463, 293.8111, 521.6642, 371.7359],
        [  7.9055, 414.9242,  46.2961, 423.6565],
        [494.6889, 365.7891, 555.0450, 536.9102],
        [ 46.7819, 309.8593, 274.7211, 524.3120],
        [ 52.5683,  52.7769, 407.5523, 271.2253],
        [228.2233,  96.3352, 346.0507, 455.8103],
        [ 94.7609, 113.7210, 282.1691, 583.6635],
        [207.3311,  91.4893, 325.0174, 227.4353],
        [  4.4843, 416.8586,  42.4637, 426.7223],
        [ 16.9198, 171.7563, 210.8070, 454.9992],
        [449.9134, 361.9514, 580.0000, 584.3832],
        [216.7644, 590.3699, 543.6064, 620.0000],
        [228.1008, 376.4395, 246.4781, 389.9498],
        [ 40.7714, 352.4985, 324.2688, 474.4106],
        [217.5527,  92.1453, 580.0000, 382.5197],
        [ 19.8337, 146.3252, 148.0808, 620.0000],
        [224.0854, 240.1711, 285.3246, 540.4830],
        [489.8106, 363.9995, 496.3597, 384.0640],
        [493.2275, 421.7141, 573.0145, 574.1880],
        [  2.3869, 307.9571,  42.8903, 355.5377],
        [  1.6328, 125.9515,  80.3102, 267.5670],
        [127.8913, 309.1691, 247.5878, 577.4413],
        [307.7375, 370.2474, 580.0000, 548.8074],
        [495.7788, 370.1597, 504.5872, 381.0194],
        [ 64.2821, 393.1695, 229.6148, 438.3313],
        [ 59.6702,   0.0000, 580.0000, 325.5826],
        [397.7323, 399.7869, 575.2143, 471.3532],
        [507.9253, 372.6577, 554.8649, 381.1511],
        [489.5278, 379.2029, 508.8386, 388.5149],
        [203.5288, 113.5201, 296.9296, 505.5798],
        [  9.0284, 310.6707, 437.4026, 487.6050],
        [188.9025, 124.0381, 346.7412, 328.8367],
        [244.5682, 208.2422, 561.3989, 410.0283],
        [334.6353, 405.7608, 452.6267, 476.4648],
        [256.0843, 263.3940, 277.3746, 428.9407],
        [479.5076, 431.9332, 575.5471, 504.5226],
        [470.5065, 380.4351, 576.2295, 517.2133],
        [110.8656, 476.6021, 580.0000, 620.0000],
        [490.3259, 371.3167, 496.1457, 393.5844],
        [403.1678, 254.0272, 411.1543, 316.9261],
        [255.5108, 146.3589, 322.1245, 384.9491],
        [  6.3780,   0.0000, 317.1211, 422.5948],
        [497.1475, 372.6072, 523.1251, 380.2904],
        [448.6642, 143.2047, 580.0000, 365.4225],
        [ 35.0093, 392.8786, 217.1972, 620.0000],
        [496.7905, 346.5901, 522.3054, 377.3755],
        [  4.0886, 148.9180,  65.3176, 620.0000],
        [ 46.0636, 389.9944, 162.5583, 423.9978],
        [509.6417, 377.1959, 562.0063, 387.4168],
        [247.4742, 246.4299, 270.3476, 413.3393],
        [511.0940, 443.1800, 580.0000, 472.2426],
        [  3.4275, 211.6860,  71.5894, 447.3386],
        [493.6158, 373.0334, 501.5971, 386.6621],
        [491.8876, 380.9993, 505.5616, 390.8474],
        [234.0815, 389.1766, 260.1580, 423.7631],
        [219.6255, 193.5732, 314.6744, 338.8667],
        [277.3090, 341.4319, 306.5358, 445.0870],
        [207.4244, 214.3731, 243.2476, 325.8946],
        [373.3044, 229.2868, 413.4512, 337.2397],
        [386.5466, 376.0895, 575.7539, 441.1465],
        [266.3353, 269.5884, 580.0000, 488.0046],
        [453.1906, 379.2785, 461.6151, 390.1317],
        [530.6581, 234.1370, 580.0000, 268.2140],
        [277.4349, 158.0855, 341.2065, 392.5967],
        [456.1600, 368.0881, 463.8938, 379.4635],
        [202.2168, 182.6642, 287.5092, 305.6806],
        [204.0509,  24.7877, 314.8493, 358.7611],
        [347.6951, 431.7468, 421.4155, 481.5295],
        [ 49.5914, 142.2849, 401.4806, 346.4820],
        [271.4739,  53.1259, 328.7435,  86.3012],
        [505.5680, 335.6811, 579.0605, 377.9692],
        [471.0669, 116.7345, 528.4269, 289.6050],
        [396.8392, 255.4695, 406.2174, 316.7011],
        [ 91.1898, 179.1382, 193.6283, 327.8245],
        [461.2997,  96.6441, 560.4382, 489.6231],
        [175.1966,   9.9618, 578.2646, 143.8588],
        [372.1141, 610.3049, 522.4876, 620.0000],
        [331.8203, 247.1843, 441.0966, 355.5180],
        [345.6376, 412.5702, 413.1440, 461.4788],
        [508.3494, 105.3887, 557.5039, 383.9501],
        [495.4020, 377.5519, 521.1812, 387.7235],
        [281.8093, 346.1822, 352.3070, 489.6473],
        [276.7371, 247.1393, 321.6436, 510.5059],
        [392.1492, 302.0256, 572.5671, 620.0000],
        [497.5034, 374.3366, 526.1511, 383.4103],
        [542.1813, 380.0844, 580.0000, 391.6952],
        [496.7894, 335.7095, 512.9001, 373.5984],
        [490.6134, 371.9872, 512.5731, 380.0279]], device='cuda:1'), 'labels': tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1], device='cuda:1'), 'scores': tensor([0.9183, 0.9150, 0.8968, 0.8913, 0.8883, 0.8803, 0.8702, 0.8692, 0.8649,
        0.8634, 0.8566, 0.8565, 0.8558, 0.8552, 0.8541, 0.8529, 0.8521, 0.8502,
        0.8457, 0.8447, 0.8443, 0.8442, 0.8439, 0.8406, 0.8364, 0.8328, 0.8325,
        0.8310, 0.8301, 0.8275, 0.8269, 0.8267, 0.8266, 0.8261, 0.8259, 0.8253,
        0.8253, 0.8248, 0.8242, 0.8235, 0.8207, 0.8193, 0.8185, 0.8171, 0.8166,
        0.8153, 0.8148, 0.8147, 0.8143, 0.8133, 0.8121, 0.8110, 0.8108, 0.8107,
        0.8106, 0.8082, 0.8080, 0.8075, 0.8074, 0.8066, 0.8056, 0.8043, 0.8042,
        0.8040, 0.8036, 0.8034, 0.8034, 0.8027, 0.8018, 0.8018, 0.8017, 0.8015,
        0.8012, 0.8001, 0.7997, 0.7991, 0.7977, 0.7973, 0.7958, 0.7958, 0.7957,
        0.7954, 0.7953, 0.7937, 0.7936, 0.7919, 0.7911, 0.7903, 0.7900, 0.7899,
        0.7897, 0.7883, 0.7873, 0.7869, 0.7868, 0.7863, 0.7859, 0.7855, 0.7840,
        0.7833], device='cuda:1'), 'masks': tensor([[[[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]]],


        [[[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]]],


        [[[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]]],


        ...,


        [[[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]]],


        [[[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]]],


        [[[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]]]], device='cuda:1')}]

Solution

__init__에 model.train()을 선언해두었지만 효과가 없었고, every step마다 self.model.train()을 해주니 loss가 정상적으로 출력되었다. (torchvision detection model은 train 모드일 때만 loss dict를 반환하고 eval 모드에서는 predict 결과를 반환하는데, PyTorch Lightning이 validation 시작 전에 model을 eval 모드로 전환하기 때문에 __init__에서의 설정이 덮어써진 것이다.)

def validation_step(self, batch, batch_idx):
    """Run one validation step and return the summed Mask R-CNN loss.

    torchvision detection models return a loss dict only while in train
    mode; in eval mode they return predictions instead. Lightning switches
    the model to eval mode before validation, so we force train mode here
    to obtain the loss dict (no optimizer step runs, so weights are safe).

    Args:
        batch: ``(images, targets)`` where ``targets`` is a dict of batched
            annotation tensors keyed by the names in ``self.keys``.
        batch_idx: index of the batch within the validation epoch (unused).

    Returns:
        Scalar tensor: the sum of all loss components.
    """
    self.model.train()
    images, targets = batch
    x = list(images)
    # Re-assemble per-image target dicts: `targets` arrives as one dict of
    # batched values, but the model expects a list with one dict per image.
    y = [{key: targets[key][idx] for key in self.keys} for idx in range(len(x))]
    loss_dict = self.model(x, y)
    print(loss_dict)
    losses = sum(loss for loss in loss_dict.values())
    # Bug fix: log the scalar total loss -- `self.log` expects a scalar
    # metric, not a dict (use `self.log_dict` for per-component losses).
    self.log("validation_loss", losses, on_step=True, on_epoch=True, sync_dist=True)
    return losses

'''
{'loss_classifier': tensor(0.8349, device='cuda:1'), 'loss_box_reg': tensor(0.1718, device='cuda:1'), 'loss_mask': tensor(2.8173, device='cuda:1'), 'loss_objectness': tensor(0.0024, device='cuda:1'), 'loss_rpn_box_reg': tensor(0.0054, device='cuda:1')}
'''