2015-06-25 2 views
1

Я пытаюсь обучить нейронную сеть, чтобы узнать функцию y = x1 + x2 + x3. Цель состоит в том, чтобы поиграть с Caffe, чтобы лучше узнать и понять его. Необходимые данные синтетически генерируются в питоне и записываются в память в виде файла базы данных lmdb.Caffe: Чрезвычайно высокие потери при изучении простых линейных функций

Код для генерации данных:

import numpy as np 
import lmdb 
import caffe 

Ntrain = 100 
Ntest = 20 
K = 3 
H = 1 
W = 1 

Xtrain = np.random.randint(0,1000, size = (Ntrain,K,H,W)) 
Xtest = np.random.randint(0,1000, size = (Ntest,K,H,W)) 

ytrain = Xtrain[:,0,0,0] + Xtrain[:,1,0,0] + Xtrain[:,2,0,0] 
ytest = Xtest[:,0,0,0] + Xtest[:,1,0,0] + Xtest[:,2,0,0] 

env = lmdb.open('expt/expt_train') 

for i in range(Ntrain): 
    datum = caffe.proto.caffe_pb2.Datum() 
    datum.channels = Xtrain.shape[1] 
    datum.height = Xtrain.shape[2] 
    datum.width = Xtrain.shape[3] 
    datum.data = Xtrain[i].tobytes() 
    datum.label = int(ytrain[i]) 
    str_id = '{:08}'.format(i) 

    with env.begin(write=True) as txn: 
     txn.put(str_id.encode('ascii'), datum.SerializeToString()) 


env = lmdb.open('expt/expt_test') 

for i in range(Ntest): 
    datum = caffe.proto.caffe_pb2.Datum() 
    datum.channels = Xtest.shape[1] 
    datum.height = Xtest.shape[2] 
    datum.width = Xtest.shape[3] 
    datum.data = Xtest[i].tobytes() 
    datum.label = int(ytest[i]) 
    str_id = '{:08}'.format(i) 

    with env.begin(write=True) as txn: 
     txn.put(str_id.encode('ascii'), datum.SerializeToString()) 

Solver.prototext файл:

net: "expt/expt.prototxt" 

display: 1 
max_iter: 200 
test_iter: 20 
test_interval: 100 

base_lr: 0.000001 
momentum: 0.9 
# weight_decay: 0.0005 

lr_policy: "inv" 
# gamma: 0.5 
# stepsize: 10 
# power: 0.75 

snapshot_prefix: "expt/expt" 
snapshot_diff: true 

solver_mode: CPU 
solver_type: SGD 

debug_info: true 

Caffe модель:

name: "expt" 


layer { 
    name: "Expt_Data_Train" 
    type: "Data" 
    top: "data" 
    top: "label"  

    include { 
     phase: TRAIN 
    } 

    data_param { 
     source: "expt/expt_train" 
     backend: LMDB 
     batch_size: 1 
    } 
} 


layer { 
    name: "Expt_Data_Validate" 
    type: "Data" 
    top: "data" 
    top: "label"  

    include { 
     phase: TEST 
    } 

    data_param { 
     source: "expt/expt_test" 
     backend: LMDB 
     batch_size: 1 
    } 
} 


layer { 
    name: "IP" 
    type: "InnerProduct" 
    bottom: "data" 
    top: "ip" 

    inner_product_param { 
     num_output: 1 

     weight_filler { 
      type: 'constant' 
     } 

     bias_filler { 
      type: 'constant' 
     } 
    } 
} 


layer { 
    name: "Loss" 
    type: "EuclideanLoss" 
    bottom: "ip" 
    bottom: "label" 
    top: "loss" 
} 

Потери на тестовых данных, которые я получаю 233,655. Это шокирует, поскольку потеря на три порядка больше, чем числа в наборах учебных и тестовых данных. Кроме того, функция, которую нужно изучить, представляет собой простую линейную функцию. Я не могу понять, что не так в коде. Любые предложения/вклады очень ценятся.

ответ

1

генерируется Потеря много в этом случае, поскольку Caffe принимает только данные (т.е. datum.data) в формате uint8 и этикеток (datum.label) в формате int32. Тем не менее, для ярлыков формат numpy.int64 также работает. Я думаю, что datum.data принимается только в формате uint8, потому что Caffe был разработан в первую очередь для задач Computer Vision, где входы представляют собой изображения, которые имеют значения RGB в диапазоне [0,255]. uint8 может записывать это, используя наименьший объем памяти. Я сделал следующие изменения в код генерации данных:

Xtrain = np.uint8(np.random.randint(0,256, size = (Ntrain,K,H,W))) 
Xtest = np.uint8(np.random.randint(0,256, size = (Ntest,K,H,W))) 

ytrain = int(Xtrain[:,0,0,0]) + int(Xtrain[:,1,0,0]) + int(Xtrain[:,2,0,0]) 
ytest = int(Xtest[:,0,0,0]) + int(Xtest[:,1,0,0]) + int(Xtest[:,2,0,0]) 

После игры вокруг с чистыми параметрами (скорость обучения, количество итераций и т.д.) Я получаю сообщение об ошибке порядка 10^(- 6), который я считаю довольно хорошим!

Смежные вопросы