Here is my model:

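from keras.layers import Input, Conv2D, ReLU, Add, UpSampling2D
from keras.models import Model

# High-resolution target size; the input is 1400 // 4 = 350 pixels per side
img_height, img_width, img_depth = 1400, 1400, 3

filters = 256
kernel_size = 3
strides = 1
factor = 4  # upscaling factor

inputLayer = Input(shape=(img_height//factor, img_width//factor, img_depth))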
conv1 = Conv2D(filters, kernel_size, strides=strides, padding='same')(inputLayer)

res = Conv2D(filters, kernel_size, strides=strides, padding='same')(conv1)
act = ReLU()(res)
res = Conv2D(filters, kernel_size, strides=strides, padding='same')(act)
res_rec = Add()([conv1, res])

for _ in range(15):  # 16 residual blocks in total: the one above plus 15 more
    res1 = Conv2D(filters, kernel_size, strides=strides, padding='same')(res_rec)
    act = ReLU()(res1)
    res2 = Conv2D(filters, kernel_size, strides=strides, padding='same')(act)
    res_rec = Add()([res_rec, res2])

conv2 = Conv2D(filters, kernel_size, strides=strides, padding='same')(res_rec)
a = Add()([conv1, conv2])
up = UpSampling2D(size=factor)(a)  # 350x350 -> 1400x1400
outputLayer = Conv2D(filters=3,
                     kernel_size=1,
                     strides=1,
                     padding='same')(up)

model = Model(inputs=inputLayer, outputs=outputLayer)

model.summary() shows:

__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_1 (InputLayer)            (None, 350, 350, 3)  0                                            
__________________________________________________________________________________________________
conv2d_1 (Conv2D)               (None, 350, 350, 256 7168        input_1[0][0]                    
__________________________________________________________________________________________________
conv2d_2 (Conv2D)               (None, 350, 350, 256 590080      conv2d_1[0][0]                   
__________________________________________________________________________________________________
re_lu_1 (ReLU)                  (None, 350, 350, 256 0           conv2d_2[0][0]                   
__________________________________________________________________________________________________
conv2d_3 (Conv2D)               (None, 350, 350, 256 590080      re_lu_1[0][0]                    
__________________________________________________________________________________________________
add_1 (Add)                     (None, 350, 350, 256 0           conv2d_1[0][0]                   
                                                                 conv2d_3[0][0]                   
__________________________________________________________________________________________________
conv2d_4 (Conv2D)               (None, 350, 350, 256 590080      add_1[0][0]                      
__________________________________________________________________________________________________
re_lu_2 (ReLU)                  (None, 350, 350, 256 0           conv2d_4[0][0]                   
__________________________________________________________________________________________________
conv2d_5 (Conv2D)               (None, 350, 350, 256 590080      re_lu_2[0][0]                    
__________________________________________________________________________________________________
add_2 (Add)                     (None, 350, 350, 256 0           add_1[0][0]                      
                                                                 conv2d_5[0][0]                   
__________________________________________________________________________________________________
conv2d_6 (Conv2D)               (None, 350, 350, 256 590080      add_2[0][0]                      
__________________________________________________________________________________________________
re_lu_3 (ReLU)                  (None, 350, 350, 256 0           conv2d_6[0][0]                   
__________________________________________________________________________________________________
conv2d_7 (Conv2D)               (None, 350, 350, 256 590080      re_lu_3[0][0]                    
__________________________________________________________________________________________________
add_3 (Add)                     (None, 350, 350, 256 0           add_2[0][0]                      
                                                                 conv2d_7[0][0]                   

...... this goes on for a long time ......

__________________________________________________________________________________________________
add_15 (Add)                    (None, 350, 350, 256 0           add_14[0][0]                     
                                                                 conv2d_31[0][0]                  
__________________________________________________________________________________________________
conv2d_32 (Conv2D)              (None, 350, 350, 256 590080      add_15[0][0]                     
__________________________________________________________________________________________________
re_lu_16 (ReLU)                 (None, 350, 350, 256 0           conv2d_32[0][0]                  
__________________________________________________________________________________________________
conv2d_33 (Conv2D)              (None, 350, 350, 256 590080      re_lu_16[0][0]                   
__________________________________________________________________________________________________
add_16 (Add)                    (None, 350, 350, 256 0           add_15[0][0]                     
                                                                 conv2d_33[0][0]                  
__________________________________________________________________________________________________
conv2d_34 (Conv2D)              (None, 350, 350, 256 590080      add_16[0][0]                     
__________________________________________________________________________________________________
add_17 (Add)                    (None, 350, 350, 256 0           conv2d_1[0][0]                   
                                                                 conv2d_34[0][0]                  
__________________________________________________________________________________________________
up_sampling2d_1 (UpSampling2D)  (None, 1400, 1400, 2 0           add_17[0][0]                     
__________________________________________________________________________________________________
conv2d_35 (Conv2D)              (None, 1400, 1400, 3 771         up_sampling2d_1[0][0]            
==================================================================================================
Total params: 19,480,579
Trainable params: 19,480,579
Non-trainable params: 0
__________________________________________________________________________________________________
None
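The parameter counts in this summary can be reproduced by hand: a Conv2D layer has kernel_size x kernel_size x in_channels weights per filter plus one bias per filter, while ReLU, Add and UpSampling2D layers have no weights. A minimal sketch in plain Python (the helper name `conv2d_params` is hypothetical, not a Keras function):

```python
def conv2d_params(in_channels, filters, kernel_size):
    """Weights (k * k * in_channels per filter) plus one bias per filter."""
    return kernel_size * kernel_size * in_channels * filters + filters

first = conv2d_params(3, 256, 3)    # conv2d_1: RGB input -> 256 feature maps
body = conv2d_params(256, 256, 3)   # each of conv2d_2 ... conv2d_34
last = conv2d_params(256, 3, 1)     # conv2d_35: 1x1 conv back to 3 channels

# Two 3x3 convs per residual block x 16 blocks, plus conv2 after the loop
n_body = 2 * 16 + 1
total = first + n_body * body + last

# UpSampling2D(size=4) only repeats pixels: 350 -> 1400 per side, no weights
out_side = 350 * 4

print(first, body, last, n_body)  # 7168 590080 771 33
print(f"{total:,}", out_side)     # 19,480,579 1400
```

So the 33 body convolutions account for almost all of the roughly 19.5 million trainable parameters.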

The important part is at the very end, near the output:

__________________________________________________________________________________________________
add_17 (Add)                    (None, 350, 350, 256 0           conv2d_1[0][0]                   
                                                                 conv2d_34[0][0]                  
__________________________________________________________________________________________________
up_sampling2d_1 (UpSampling2D)  (None, 1400, 1400, 2 0           add_17[0][0]                     
__________________________________________________________________________________________________
conv2d_35 (Conv2D)              (None, 1400, 1400, 3 771         up_sampling2d_1[0][0]            
==================================================================================================

Now look at the error I get when I run the network:

Traceback (most recent call last):
  File "C:/Users/payne/PycharmProjects/PixelEnhancer/trainTest.py", line 280, in <module>
    setUpImages()
  File "C:/Users/payne/PycharmProjects/PixelEnhancer/trainTest.py", line 96, in setUpImages
    setUpData(trainData, testData)
  File "C:/Users/payne/PycharmProjects/PixelEnhancer/trainTest.py", line 135, in setUpData
    setUpModel(X_train, Y_train, validateTestData, trainingTestData)
  File "C:/Users/payne/PycharmProjects/PixelEnhancer/trainTest.py", line 176, in setUpModel
    train(model, X_train, Y_train, validateTestData, trainingTestData)
  File "C:/Users/payne/PycharmProjects/PixelEnhancer/trainTest.py", line 192, in train
    batch_size=32)
  File "C:\Users\payne\Anaconda3\envs\ml-gpu\lib\site-packages\keras\engine\training.py", line 950, in fit
    batch_size=batch_size)
  File "C:\Users\payne\Anaconda3\envs\ml-gpu\lib\site-packages\keras\engine\training.py", line 787, in _standardize_user_data
    exception_prefix='target')
  File "C:\Users\payne\Anaconda3\envs\ml-gpu\lib\site-packages\keras\engine\training_utils.py", line 137, in standardize_input_data
    str(data_shape))
ValueError: Error when checking target: expected conv2d_35 to have shape (1400, 1400, 1) but got array with shape (1400, 1400, 3)

Why does my last convolution expect a (1400, 1400, 1) tensor but get a (1400, 1400, 3) tensor, when the summary says UpSampling2D should return a (1400, 1400, 2) tensor?

To give a bit more context: this is supposed to be a network that takes a 350x350x3 image and outputs a 1400x1400x3 image.
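The `(None, 1400, 1400, 2` in the summary is most likely not a 2-channel tensor. Keras 2.x prints each summary row in fixed-width columns and silently cuts off anything that does not fit; at the default `line_length=98` the Output Shape column has room for only 20 characters. The sketch below re-implements that row formatting in plain Python, assuming the column split Keras 2.x uses for functional models (33%, 55%, 67% and 100% of the line). It shows that the real shape `(None, 1400, 1400, 256)` prints exactly as in the summary; the `(None, 350, 350, 256` rows are cut the same way.

```python
LINE_LENGTH = 98
# Column boundaries (assumed Keras 2.x defaults for functional models): 32, 53, 65, 98
POSITIONS = [int(LINE_LENGTH * p) for p in (0.33, 0.55, 0.67, 1.0)]

def format_row(fields, positions=POSITIONS):
    """Lay fields out in fixed-width columns, truncating overlong values."""
    line = ''
    for i, field in enumerate(fields):
        if i > 0:
            line = line[:-1] + ' '  # last char of the previous column becomes a space
        line += str(field)
        line = line[:positions[i]]  # anything past the column boundary is dropped
        line += ' ' * (positions[i] - len(line))
    return line

row = format_row(['up_sampling2d_1 (UpSampling2D)', (None, 1400, 1400, 256), 0, 'add_17[0][0]'])
shape_cell = row[POSITIONS[0]:POSITIONS[1] - 1]
print(repr(shape_cell))  # '(None, 1400, 1400, 2'
```

To see the untruncated shapes, you can pass a wider `line_length` to `model.summary()`, or read `model.get_layer('up_sampling2d_1').output_shape` directly.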

payne, Sep 22, 2018 at 05:19


Accepted answer

So it turns out the error message was not about the shape that conv2d_35 itself produces, but about the target array that the loss function compares against the output of the network's last layer (conv2d_35).

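Since I had chosen sparse_categorical_crossentropy as the loss function, Keras expected a single integer class label per pixel, i.e. a target whose last dimension is 1, hence the expected shape (1400, 1400, 1).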

Setting the loss to mean_squared_error fixed it.
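The reasoning above can be made concrete with the target-shape rule that Keras 2.x applies inside `fit()` (while standardizing the data, the `_standardize_user_data` frame in the traceback): for `sparse_categorical_crossentropy` the expected target is the output shape with its last axis replaced by 1, while for losses such as `mean_squared_error` it is the output shape itself. Below is a simplified, channels_last-only sketch; `expected_target_shape` is a hypothetical helper, not a Keras API.

```python
def expected_target_shape(output_shape, loss):
    """Simplified sketch of the target shape Keras 2.x checks (channels_last only)."""
    if loss == 'sparse_categorical_crossentropy':
        # One integer class index per position instead of one value per channel
        return output_shape[:-1] + (1,)
    # Element-wise losses such as MSE compare target and output directly
    return output_shape

output_shape = (None, 1400, 1400, 3)  # conv2d_35
target_shape = (1400, 1400, 3)        # one RGB image from Y_train, batch axis dropped

for loss in ('sparse_categorical_crossentropy', 'mean_squared_error'):
    expected = expected_target_shape(output_shape, loss)[1:]  # drop the batch axis
    status = 'ok' if expected == target_shape else 'mismatch'
    print(f'{loss}: expects {expected}, got {target_shape} -> {status}')
```

With the sparse loss the expected shape comes out as (1400, 1400, 1), which is exactly the shape named in the ValueError, while the MSE case matches the 3-channel targets.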

payne, Sep 22, 2018 at 18:58