2015-12-03 3 views
0

Следуя инструкциям here Я создал подкласс ndarray, который добавляет новые атрибуты в класс ndarray. Теперь я хочу определить оператор сравнения для нового класса, который помимо сравнения данных также сравнивает значения атрибутов. Так что я попытался это:Сравнение классов ndarray

def __eq__(self, other): 
    return (self._prop1 == other._prop1) and \ 
      (self._prop2 == other._prop2) and \ 
      (self.data == other.data) 

Это позволяет сравнивать как T1 == T2 и возвращает логическое значение. Однако, поскольку я хотел бы использовать эти массивы взаимозаменяемо с другими ndarrays, я бы хотел, чтобы сравнение возвращало булевский массив. Если я не определяю свою функцию __eq__, тогда сравнение возвращает логический массив, но тогда я не могу проверить атрибуты. Как я могу объединить эти два?

+0

Похож, что 'ndarray' возвращает скаляр True/False, если атрибуты (например, форма) не совпадают, а булевский массив - только если все они совпадают. С помощью одного или 'ifs' вы должны иметь возможность возвращать тесты атрибутов, если они терпят неудачу, а' else' возвращает тест 'data'. Позвольте себе несколько «возвратов». Так проще писать код. – hpaulj

+0

@hpaulj, разделяющее сравнение моих пользовательских атрибутов и базовых данных, является хорошим предложением. Однако проблема в том, что сравнение самих атрибутов данных возвращает логическое, а не массив. Как я могу вызвать элементарное сравнение numpy с моим оператором сравнения? – deepak

+1

Посмотрите, как маскированный массив обрабатывает это: 'np.ma.core.MaskedArray .__ eq__' – hpaulj

ответ

1

В соответствии с suggestion by hpaulj Я выяснил, как это сделать, посмотрев на np.ma.core.MaskedArray.__eq__. Вот минимальная реализация для справки. Основная идея - вызвать numpy __eq__() на вид self в виде базового класса DerivedArray.

class DerivedArray(np.ndarray): 
    def __new__(cls, input_array, prop1, prop2):  
     _baseclass = getattr(input_array, '_baseclass', type(input_array)) 
     obj = np.asarray(input_array).view(cls) 

     obj._prop1 = prop1 
     obj._prop2 = prop2 
     obj._baseclass = _baseclass 
     return obj 

    def __array_finalize__(self, obj): 
     if obj is None: 
      return 
     else: 
      if not isinstance(obj, np.ndarray): 
       _baseclass = type(obj) 
      else: 
       _baseclass = np.ndarray 

     self._prop1 = getattr(obj, '_prop1', None) 
     self._prop2 = getattr(obj, '_prop2', None) 
     self._baseclass= getattr(obj, '_baseclass', _baseclass) 

    def _get_data(self): 
     """Return the current data, as a view of the original 
     underlying data. 
     """ 
     return np.ndarray.view(self, self._baseclass) 

    _data = property(fget=_get_data) 
    data = property(fget=_get_data) 

    def __eq__(self, other): 
     attsame = (self._prop1 == other._prop1) and (self._prop2 == other._prop2) 
     if not attsame: return False 
     return self._data.__eq__(other) 
Смежные вопросы