2016-07-02 2 views
1

Недавно я попытался получить матрицу замешательства для одной из моих обучаемых моделей, чтобы увидеть, насколько она точна. Я загрузил этот скрипт и накормил свою модель. К моему удивлению, точность, рассчитанная по сценарию, сильно отличается от точности, Caffe отчетов.Является ли сообщение Caffe достоверным?

Я использовал this script вычислить матрицу путаницы, это, однако, сообщает точность, а также, проблемы является точность, выдаваемой этим сценарием является способом отличается, что один сообщило Caffe!
Например, Caffe сообщает, что точность позволяет указывать для CIFAR10, как 92,34%, тогда как, когда модель загружается в скрипт для вычисления матрицы путаницы и ее точности, это приводит, например, к чему-то вроде 86,5%.!

Какой один из этих точностей являются правильным, и могут быть представлены в работах или по сравнению с результатами других работ, таких как те, here?

Я также увидел что-то странное снова, я обучил две идентичные модели, только с одной разницей, что один использовал xavier, а другой использовал msra для инициализации.
Первый сообщает о точности 94,25, а остальные отчеты 94,26 в кофе. когда эти модели загружаются в скрипт, который я связывал выше, для вычисления матрицы смешения. их точность составила 89,2% и 87,4% соответственно!
Это нормально? в чем причина этого? MSRA?

Я действительно не знаю, достоверны ли данные о кофе, правда, достойны и надежны. Я был бы признателен, если бы кто-нибудь мог пролить свет на этот вопрос.

P.N: Точность в сценарии рассчитывается как (complete script):

for i, image, label in reader: 
     image_caffe = image.reshape(1, *image.shape) 
     out = net.forward_all(data=np.asarray([ image_caffe ])) 
     plabel = int(out['prob'][0].argmax(axis=0)) 

     count += 1 
     iscorrect = label == plabel 
     correct += (1 if iscorrect else 0) 
     matrix[(label, plabel)] += 1 
     labels_set.update([label, plabel]) 

     if not iscorrect: 
      print("\rError: i=%s, expected %i but predicted %i" \ 
        % (i, label, plabel)) 

     sys.stdout.write("\rAccuracy: %.1f%%" % (100.*correct/count)) 
     sys.stdout.flush() 

    print(", %i/%i corrects" % (correct, count)) 

Какой IMHO это нормально и правильно. количество правильных прогнозов, деленное на общее количество экземпляров в наборе данных.

ответ

1

Я нашел причину. Причина несоответствия между точностью, полученной Caffe, и точностью, сгенерированной рассматриваемым скриптом, была вызвана исключительно вычитанием, которое было сделано в caffe, а не в скрипте. This - это модифицированная версия скрипта, которая учитывает это и, надеюсь, все в порядке.

# Author: Axel Angel, copyright 2015, license GPLv3. 
# added mean subtraction so that, the accuracy can be reported accurately just like caffe when doing a mean subtraction 
# Seyyed Hossein Hasan Pour 
# [email protected] 
# 7/3/2016 

import sys 
import caffe 
import numpy as np 
import lmdb 
import argparse 
from collections import defaultdict 

def flat_shape(x): 
    "Returns x without singleton dimension, eg: (1,28,28) -> (28,28)" 
    return x.reshape(filter(lambda s: s > 1, x.shape)) 

def lmdb_reader(fpath): 
    import lmdb 
    lmdb_env = lmdb.open(fpath) 
    lmdb_txn = lmdb_env.begin() 
    lmdb_cursor = lmdb_txn.cursor() 

    for key, value in lmdb_cursor: 
     datum = caffe.proto.caffe_pb2.Datum() 
     datum.ParseFromString(value) 
     label = int(datum.label) 
     image = caffe.io.datum_to_array(datum).astype(np.uint8) 
     yield (key, flat_shape(image), label) 

def leveldb_reader(fpath): 
    import leveldb 
    db = leveldb.LevelDB(fpath) 

    for key, value in db.RangeIter(): 
     datum = caffe.proto.caffe_pb2.Datum() 
     datum.ParseFromString(value) 
     label = int(datum.label) 
     image = caffe.io.datum_to_array(datum).astype(np.uint8) 
     yield (key, flat_shape(image), label) 

def npz_reader(fpath): 
    npz = np.load(fpath) 

    xs = npz['arr_0'] 
    ls = npz['arr_1'] 

    for i, (x, l) in enumerate(np.array([ xs, ls ]).T): 
     yield (i, x, l) 

if __name__ == "__main__": 
    parser = argparse.ArgumentParser() 
    parser.add_argument('--proto', type=str, required=True) 
    parser.add_argument('--model', type=str, required=True) 
    parser.add_argument('--mean', type=str, required=True) 
    group = parser.add_mutually_exclusive_group(required=True) 
    group.add_argument('--lmdb', type=str, default=None) 
    group.add_argument('--leveldb', type=str, default=None) 
    group.add_argument('--npz', type=str, default=None) 
    args = parser.parse_args() 

# Extract mean from the mean image file 
    mean_blobproto_new = caffe.proto.caffe_pb2.BlobProto() 
    f = open(args.mean, 'rb') 
    mean_blobproto_new.ParseFromString(f.read()) 
    mean_image = caffe.io.blobproto_to_array(mean_blobproto_new) 
    f.close() 

    count = 0 
    correct = 0 
    matrix = defaultdict(int) # (real,pred) -> int 
    labels_set = set() 

    # CNN reconstruction and loading the trained weights 
    net = caffe.Net(args.proto, args.model, caffe.TEST) 
    caffe.set_mode_cpu() 
    print "args", vars(args) 
    if args.lmdb != None: 
     reader = lmdb_reader(args.lmdb) 
    if args.leveldb != None: 
     reader = leveldb_reader(args.leveldb) 
    if args.npz != None: 
     reader = npz_reader(args.npz) 

    for i, image, label in reader: 
     image_caffe = image.reshape(1, *image.shape) 
     out = net.forward_all(data=np.asarray([ image_caffe ])- mean_image) 
     plabel = int(out['prob'][0].argmax(axis=0)) 

     count += 1 
     iscorrect = label == plabel 
     correct += (1 if iscorrect else 0) 
     matrix[(label, plabel)] += 1 
     labels_set.update([label, plabel]) 

     if not iscorrect: 
      print("\rError: i=%s, expected %i but predicted %i" \ 
        % (i, label, plabel)) 

     sys.stdout.write("\rAccuracy: %.1f%%" % (100.*correct/count)) 
     sys.stdout.flush() 

    print(", %i/%i corrects" % (correct, count)) 

    print "" 
    print "Confusion matrix:" 
    print "(r , p) | count" 
    for l in labels_set: 
     for pl in labels_set: 
      print "(%i , %i) | %i" % (l, pl, matrix[(l,pl)]) 
Смежные вопросы