assertAlmostEqual in Python unit tests for collections of floats

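The assertAlmostEqual(x, y) method in Python's unit testing framework tests whether x and y are approximately equal, assuming they are floats.

The problem with assertAlmostEqual() is that it only works on floats. I'm looking for a method like assertAlmostEqual() that works on lists of floats, sets of floats, dictionaries of floats, tuples of floats, lists of tuples of floats, sets of lists of floats, and so on.

For instance, let x = 0.1234567890 and y = 0.1234567891. x and y are almost equal because they agree on every digit except the last one. Therefore self.assertAlmostEqual(x, y) is True, because assertAlmostEqual() works for floats.

I'm looking for a more generic assertAlmostEqual() which also evaluates the following calls to True: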

  • self.assertAlmostEqual_generic([x, x, x], [y, y, y]).
  • self.assertAlmostEqual_generic({1: x, 2: x, 3: x}, {1: y, 2: y, 3: y}).
  • self.assertAlmostEqual_generic([(x,x)], [(y,y)]).

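Is there such a method, or do I have to implement it myself?

Clarifications:

  • assertAlmostEqual() has an optional parameter named places; the numbers are compared by computing their difference rounded to that number of decimal places. By default places=7, hence self.assertAlmostEqual(0.5, 0.4) is False while self.assertAlmostEqual(0.12345678, 0.12345679) is True. My speculative assertAlmostEqual_generic() should have the same functionality.

  • Two lists are considered almost equal if they have almost equal numbers in exactly the same order.

  • Similarly, two sets are considered almost equal if they can be converted to almost equal lists (by assigning an order to each set).

  • Similarly, two dictionaries are considered almost equal if the key set of each dictionary is almost equal to the key set of the other dictionary, and for each such almost equal key pair there is a corresponding almost equal value.

  • In general: I consider two collections almost equal if they are equal except for some corresponding floats which are just almost equal to each other. In other words, I would like to really compare objects, but with a low (customized) precision when comparing floats along the way.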
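To make the places semantics described above concrete, here is a minimal sketch of the check for a single pair of floats. The helper name almost_equal_places is made up for illustration and is not part of unittest; it only mirrors the "round the difference to places decimal places and compare with zero" rule.

```python
def almost_equal_places(a, b, places=7):
    """Hypothetical helper mirroring the 'places' rule of assertAlmostEqual."""
    # Round the absolute difference to `places` decimal places and compare with zero.
    return round(abs(a - b), places) == 0


print(almost_equal_places(0.5, 0.4))                # False
print(almost_equal_places(0.12345678, 0.12345679))  # True
```

Note that this is an absolute tolerance, so it says nothing about how many significant digits two very large or very small numbers share.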


There is no such method, you'd have to do it yourself.

For lists and tuples the definition is obvious, but note that the other cases you mention aren't obvious, so it's no wonder such a function isn't provided. For instance, is {1.00001: 1.00002} almost equal to {1.00002: 1.00001}? Handling such cases requires making a choice about whether closeness depends on keys or values or both. For sets you are unlikely to find a meaningful definition, since sets are unordered, so there is no notion of "corresponding" elements.

You may have to implement it yourself. While it's true that lists and sets can be iterated the same way, dictionaries are a different story: iterating a dictionary yields its keys, not its values. The third example also seems a bit ambiguous to me: do you mean to compare each value within the set, or each value from each set?

Here's a simple code snippet.

def almost_equal(value_1, value_2, accuracy=10**-8):
    return abs(value_1 - value_2) < accuracy


x = [1, 2, 3, 4]
y = [1, 2, 4, 5]
# Fails with AssertionError for these lists: 3 vs 4 and 4 vs 5 differ by 1.
assert all(almost_equal(*values) for values in zip(x, y))
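One thing to watch with the zip-based check: zip stops at the shorter input, so sequences of different lengths can compare as "almost equal". A minimal sketch of a stricter variant follows; the name all_almost_equal is an assumption for illustration, not an existing function.

```python
def all_almost_equal(xs, ys, accuracy=1e-8):
    """Hypothetical helper: element-wise comparison that also checks the lengths."""
    xs, ys = list(xs), list(ys)
    if len(xs) != len(ys):
        return False
    return all(abs(a - b) < accuracy for a, b in zip(xs, ys))


short, longer = [1.0, 2.0], [1.0, 2.0, 3.0]
# zip alone ignores the extra trailing element.
print(all(abs(a - b) < 1e-8 for a, b in zip(short, longer)))  # True
print(all_almost_equal(short, longer))                        # False
```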

If you don't mind using NumPy (which comes with your Python(x,y)), you may want to look at the np.testing module, which defines, among others, an assert_almost_equal function.

The signature is np.testing.assert_almost_equal(actual, desired, decimal=7, err_msg='', verbose=True)

>>> import numpy as np
>>> x = 1.000001
>>> y = 1.000002
>>> np.testing.assert_almost_equal(x, y)
AssertionError:
Arrays are not almost equal to 7 decimals
ACTUAL: 1.000001
DESIRED: 1.000002
>>> np.testing.assert_almost_equal(x, y, 5)
>>> np.testing.assert_almost_equal([x, x, x], [y, y, y], 5)
>>> np.testing.assert_almost_equal((x, x, x), (y, y, y), 5)

Here's how I've implemented a generic is_almost_equal(first, second) function:

First, duplicate the objects you need to compare (first and second), but don't make an exact copy: cut the insignificant decimal digits of any float you encounter inside the object.

Now that you have copies of first and second for which the insignificant decimal digits are gone, just compare first and second using the == operator.

Let's assume we have a cut_insignificant_digits_recursively(obj, places) function which duplicates obj but keeps only the first places decimal digits of each float in the original obj. Here's a working implementation of is_almost_equal(first, second, places):

from insignificant_digit_cutter import cut_insignificant_digits_recursively


def is_almost_equal(first, second, places):
    '''returns True if first and second equal.
    returns true if first and second aren't equal but have exactly the same
    structure and values except for a bunch of floats which are just almost
    equal (floats are almost equal if they're equal when we consider only the
    [places] most significant digits of each).'''
    if first == second:
        return True
    cut_first = cut_insignificant_digits_recursively(first, places)
    cut_second = cut_insignificant_digits_recursively(second, places)
    return cut_first == cut_second

And here's a working implementation of cut_insignificant_digits_recursively(obj, places):

def cut_insignificant_digits(number, places):
    '''cut the least significant decimal digits of a number,
    leave only [places] decimal digits'''
    if type(number) != float:
        return number
    number_as_str = str(number)
    end_of_number = number_as_str.find('.') + places + 1
    if end_of_number > len(number_as_str):
        return number
    return float(number_as_str[:end_of_number])


def cut_insignificant_digits_lazy(iterable, places):
    for obj in iterable:
        yield cut_insignificant_digits_recursively(obj, places)


def cut_insignificant_digits_recursively(obj, places):
    '''return a copy of obj except that every float loses its least significant
    decimal digits remaining only [places] decimal digits'''
    t = type(obj)
    if t == float:
        return cut_insignificant_digits(obj, places)
    if t in (list, tuple, set):
        return t(cut_insignificant_digits_lazy(obj, places))
    if t == dict:
        return {cut_insignificant_digits_recursively(key, places):
                cut_insignificant_digits_recursively(val, places)
                for key, val in obj.items()}
    return obj

The code and its unit tests are available here: https://github.com/snakile/approximate_comparator. I welcome any improvement and bug fix.
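One caveat worth checking before relying on digit cutting: truncation is not the same as the rounded-difference rule that assertAlmostEqual uses, so two floats that sit on either side of a digit boundary can be reported as different even though they are extremely close. The sketch below re-creates only the truncation idea in a standalone helper (truncate is a made-up name, not the function from the repository) and compares it with the rounded-difference rule.

```python
def truncate(number, places):
    """Keep only `places` digits after the decimal point, via the string form."""
    text = str(number)
    end = text.find('.') + places + 1
    return number if end > len(text) else float(text[:end])


a, b = 0.19999999, 0.2
truncated_equal = truncate(a, 3) == truncate(b, 3)  # compares 0.199 with 0.2
rounded_equal = round(abs(a - b), 3) == 0           # rounded-difference rule with places=3

print(truncated_equal)  # False
print(rounded_equal)    # True
```

If this edge case matters for your data, comparing differences (for example with math.isclose) avoids it.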

If you don't mind using the numpy package, then numpy.testing has the assert_array_almost_equal method.

This works for array_like objects, so it is fine for arrays, lists and tuples of floats, but it does not work for sets and dictionaries.

See the numpy.testing documentation for details.

As of Python 3.5 you can compare floats using

math.isclose(a, b, rel_tol=1e-9, abs_tol=0.0)

as described in PEP 485. The implementation should be equivalent to

abs(a-b) <= max( rel_tol * max(abs(a), abs(b)), abs_tol )
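math.isclose only compares two numbers, so it needs a small recursive wrapper to cover the collections from the question. The following is a minimal sketch under a few assumptions: the name nested_isclose is hypothetical, dictionary keys are matched exactly (only values are compared approximately), and sets fall back to exact equality because they have no notion of corresponding elements.

```python
import math
from collections.abc import Mapping, Sequence


def nested_isclose(a, b, rel_tol=1e-9, abs_tol=0.0):
    """Hypothetical helper: walk two structures, using math.isclose for floats."""
    if isinstance(a, float) and isinstance(b, float):
        return math.isclose(a, b, rel_tol=rel_tol, abs_tol=abs_tol)
    if isinstance(a, (str, bytes)) or isinstance(b, (str, bytes)):
        return a == b  # strings are sequences too, but must not be recursed into
    if isinstance(a, Mapping) and isinstance(b, Mapping):
        # Keys must match exactly; values are compared approximately.
        return a.keys() == b.keys() and all(
            nested_isclose(a[k], b[k], rel_tol, abs_tol) for k in a)
    if isinstance(a, Sequence) and isinstance(b, Sequence):
        return type(a) == type(b) and len(a) == len(b) and all(
            nested_isclose(x, y, rel_tol, abs_tol) for x, y in zip(a, b))
    return a == b  # sets and all other objects: exact equality


x, y = 0.1234567890, 0.1234567891
print(nested_isclose([x, x, x], [y, y, y], rel_tol=1e-7))                    # True
print(nested_isclose({1: x, 2: x, 3: x}, {1: y, 2: y, 3: y}, rel_tol=1e-7))  # True
print(nested_isclose([(x, x)], [(y, y)], rel_tol=1e-7))                      # True
print(nested_isclose([1.0, 2.0], [1.0, 2.1], rel_tol=1e-7))                  # False
```

In a test you could then write self.assertTrue(nested_isclose(actual, expected, rel_tol=1e-7)), at the cost of a less informative failure message than assertEqual gives.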

An alternative approach is to convert your data into a comparable form, e.g. by turning each float into a string with fixed precision.

def comparable(data):
    """Converts `data` to a comparable structure by converting any floats to a string with fixed precision."""
    if isinstance(data, (int, str)):
        return data
    if isinstance(data, float):
        return '{:.4f}'.format(data)
    if isinstance(data, list):
        return [comparable(el) for el in data]
    if isinstance(data, tuple):
        return tuple([comparable(el) for el in data])
    if isinstance(data, dict):
        return {k: comparable(v) for k, v in data.items()}
    # Any other type (e.g. None or a set) is returned unchanged instead of
    # silently becoming None.
    return data

Then you can:

self.assertEqual(comparable(value1), comparable(value2))

None of these answers work for me. The following code should work for Python collections, classes, dataclasses, and namedtuples. I might have forgotten something, but so far this works for me.

import unittest
from collections import namedtuple, OrderedDict
from dataclasses import dataclass
from typing import Any


def are_almost_equal(o1: Any, o2: Any, max_abs_ratio_diff: float, max_abs_diff: float) -> bool:
    """
    Compares two objects by recursively walking through them. Equality is as usual except for floats.
    Floats are compared according to the two measures defined below.

    :param o1: The first object.
    :param o2: The second object.
    :param max_abs_ratio_diff: The maximum allowed absolute value of the relative difference
        `abs(1 - (o1 / o2))` (or `abs(1 - (o2 / o1))` if o2 == 0.0). Ignored if <= 0.
    :param max_abs_diff: The maximum allowed absolute difference `abs(o1 - o2)`. Ignored if <= 0.
    :return: Whether the two objects are almost equal.
    """
    if type(o1) != type(o2):
        return False

    composite_type_passed = False

    if hasattr(o1, '__slots__'):
        if len(o1.__slots__) != len(o2.__slots__):
            return False
        if any(not are_almost_equal(getattr(o1, s1), getattr(o2, s2),
                                    max_abs_ratio_diff, max_abs_diff)
               for s1, s2 in zip(sorted(o1.__slots__), sorted(o2.__slots__))):
            return False
        else:
            composite_type_passed = True

    if hasattr(o1, '__dict__'):
        if len(o1.__dict__) != len(o2.__dict__):
            return False
        if any(not are_almost_equal(k1, k2, max_abs_ratio_diff, max_abs_diff)
               or not are_almost_equal(v1, v2, max_abs_ratio_diff, max_abs_diff)
               for ((k1, v1), (k2, v2))
               in zip(sorted(o1.__dict__.items()), sorted(o2.__dict__.items()))
               if not k1.startswith('__')):  # avoid infinite loops
            return False
        else:
            composite_type_passed = True

    if isinstance(o1, dict):
        if len(o1) != len(o2):
            return False
        if any(not are_almost_equal(k1, k2, max_abs_ratio_diff, max_abs_diff)
               or not are_almost_equal(v1, v2, max_abs_ratio_diff, max_abs_diff)
               for ((k1, v1), (k2, v2)) in zip(sorted(o1.items()), sorted(o2.items()))):
            return False

    elif any(issubclass(o1.__class__, c) for c in (list, tuple, set)):
        if len(o1) != len(o2):
            return False
        if any(not are_almost_equal(v1, v2, max_abs_ratio_diff, max_abs_diff)
               for v1, v2 in zip(o1, o2)):
            return False

    elif isinstance(o1, float):
        if o1 == o2:
            return True
        else:
            if max_abs_ratio_diff > 0:  # if max_abs_ratio_diff <= 0, it is ignored
                if o2 != 0:
                    if abs(1.0 - (o1 / o2)) > max_abs_ratio_diff:
                        return False
                else:  # if both == 0, we already returned True
                    if abs(1.0 - (o2 / o1)) > max_abs_ratio_diff:
                        return False
            if 0 < max_abs_diff < abs(o1 - o2):  # if max_abs_diff <= 0, it is ignored
                return False
            return True

    else:
        if not composite_type_passed:
            return o1 == o2

    return True




class EqualityTest(unittest.TestCase):

    def test_floats(self) -> None:
        o1 = ('hi', 3, 3.4)
        o2 = ('hi', 3, 3.400001)
        self.assertTrue(are_almost_equal(o1, o2, 0.0001, 0.0001))
        self.assertFalse(are_almost_equal(o1, o2, 0.00000001, 0.00000001))

    def test_ratio_only(self):
        o1 = ['hey', 10000, 123.12]
        o2 = ['hey', 10000, 123.80]
        self.assertTrue(are_almost_equal(o1, o2, 0.01, -1))
        self.assertFalse(are_almost_equal(o1, o2, 0.001, -1))

    def test_diff_only(self):
        o1 = ['hey', 10000, 1234567890.12]
        o2 = ['hey', 10000, 1234567890.80]
        self.assertTrue(are_almost_equal(o1, o2, -1, 1))
        self.assertFalse(are_almost_equal(o1, o2, -1, 0.1))

    def test_both_ignored(self):
        o1 = ['hey', 10000, 1234567890.12]
        o2 = ['hey', 10000, 0.80]
        o3 = ['hi', 10000, 0.80]
        self.assertTrue(are_almost_equal(o1, o2, -1, -1))
        self.assertFalse(are_almost_equal(o1, o3, -1, -1))

    def test_different_lengths(self):
        o1 = ['hey', 1234567890.12, 10000]
        o2 = ['hey', 1234567890.80]
        self.assertFalse(are_almost_equal(o1, o2, 1, 1))

    def test_classes(self):
        class A:
            d = 12.3

            def __init__(self, a, b, c):
                self.a = a
                self.b = b
                self.c = c

        o1 = A(2.34, 'str', {1: 'hey', 345.23: [123, 'hi', 890.12]})
        o2 = A(2.34, 'str', {1: 'hey', 345.231: [123, 'hi', 890.121]})
        self.assertTrue(are_almost_equal(o1, o2, 0.1, 0.1))
        self.assertFalse(are_almost_equal(o1, o2, 0.0001, 0.0001))

        o2.hello = 'hello'
        self.assertFalse(are_almost_equal(o1, o2, -1, -1))

    def test_namedtuples(self):
        B = namedtuple('B', ['x', 'y'])
        o1 = B(3.3, 4.4)
        o2 = B(3.4, 4.5)
        self.assertTrue(are_almost_equal(o1, o2, 0.2, 0.2))
        self.assertFalse(are_almost_equal(o1, o2, 0.001, 0.001))

    def test_classes_with_slots(self):
        class C(object):
            __slots__ = ['a', 'b']

            def __init__(self, a, b):
                self.a = a
                self.b = b

        o1 = C(3.3, 4.4)
        o2 = C(3.4, 4.5)
        self.assertTrue(are_almost_equal(o1, o2, 0.3, 0.3))
        self.assertFalse(are_almost_equal(o1, o2, -1, 0.01))

    def test_dataclasses(self):
        @dataclass
        class D:
            s: str
            i: int
            f: float

        @dataclass
        class E:
            f2: float
            f4: str
            d: D

        o1 = E(12.3, 'hi', D('hello', 34, 20.01))
        o2 = E(12.1, 'hi', D('hello', 34, 20.0))
        self.assertTrue(are_almost_equal(o1, o2, -1, 0.4))
        self.assertFalse(are_almost_equal(o1, o2, -1, 0.001))

        o3 = E(12.1, 'hi', D('ciao', 34, 20.0))
        self.assertFalse(are_almost_equal(o2, o3, -1, -1))

    def test_ordereddict(self):
        o1 = OrderedDict({1: 'hey', 345.23: [123, 'hi', 890.12]})
        o2 = OrderedDict({1: 'hey', 345.23: [123, 'hi', 890.0]})
        self.assertTrue(are_almost_equal(o1, o2, 0.01, -1))
        self.assertFalse(are_almost_equal(o1, o2, 0.0001, -1))

I would still use self.assertEqual(), because it stays the most informative when shit hits the fan. You can do that by rounding, e.g.

self.assertEqual(round_tuple((13.949999999999999, 1.121212), 2), (13.95, 1.12))

where round_tuple is

def round_tuple(t: tuple, ndigits: int) -> tuple:
    return tuple(round(e, ndigits=ndigits) for e in t)


def round_list(l: list, ndigits: int) -> list:
    return [round(e, ndigits=ndigits) for e in l]

According to the Python docs (see https://stackoverflow.com/a/41407651/1031191) you can get away with rounding issues like 13.949999999999999, because 13.949999999999999 == 13.95 is True: both literals map to the same binary float.
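The same rounding idea can be extended to nested structures with a small recursive helper. This is only a sketch: round_nested is a hypothetical name, it handles plain lists, tuples, sets and dicts, and it would need extra care for namedtuples, whose constructors do not accept a single iterable.

```python
def round_nested(obj, ndigits):
    """Hypothetical helper: return a copy of obj with every float rounded."""
    if isinstance(obj, float):
        return round(obj, ndigits)
    if isinstance(obj, (list, tuple, set, frozenset)):
        # Rebuild the same container type from the rounded elements.
        return type(obj)(round_nested(e, ndigits) for e in obj)
    if isinstance(obj, dict):
        # Keys are left untouched; only values are rounded.
        return {k: round_nested(v, ndigits) for k, v in obj.items()}
    return obj


data = {'a': [(13.949999999999999, 1.121212)], 'b': 'text'}
print(round_nested(data, 2) == {'a': [(13.95, 1.12)], 'b': 'text'})  # True
```

Keep in mind that rounding compares buckets rather than distances: values that straddle a rounding boundary, such as 1.0049999 and 1.0050001 at two digits, round to different results even though they are very close.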

Use Pandas

Another way is to convert each of the two dicts etc into pandas dataframes and then use pd.testing.assert_frame_equal() to compare the two. I have used this successfully to compare lists of dicts.

Previous answers often don't work on structures involving dictionaries, but this one should. I haven't exhaustively tested this on highly nested structures, but imagine pandas would handle them correctly.

Example 1: compare two dicts

To illustrate this I will use your example data of a dict, since the other methods don't work with dicts. Your dict was:

x, y = 0.1234567890, 0.1234567891
{1: x, 2: x, 3: x}, {1: y, 2: y, 3: y}

Then we can do:

import pandas as pd

pd.testing.assert_frame_equal(
    pd.DataFrame.from_dict({1: x, 2: x, 3: x}, orient='index'),
    pd.DataFrame.from_dict({1: y, 2: y, 3: y}, orient='index'))

This doesn't raise an error, meaning that they are equal to a certain degree of precision.

However if we were to do

pd.testing.assert_frame_equal(
    pd.DataFrame.from_dict({1: x, 2: x, 3: x}, orient='index'),
    pd.DataFrame.from_dict({1: y, 2: y, 3: y + 1}, orient='index'))  # add 1 to last value

then we are rewarded with the following informative message:

AssertionError: DataFrame.iloc[:, 0] (column name="0") are different


DataFrame.iloc[:, 0] (column name="0") values are different (33.33333 %)
[index]: [1, 2, 3]
[left]:  [0.123456789, 0.123456789, 0.123456789]
[right]: [0.1234567891, 0.1234567891, 1.1234567891]

For further details see the pd.testing.assert_frame_equal documentation, particularly the check_exact, rtol and atol parameters, which control the required precision (relative or absolute).

Example 2: Nested dict of dicts

a = {i*10 : {1:1.1,2:2.1} for i in range(4)}
b = {i*10 : {1:1.1000001,2:2.100001} for i in range(4)}
# a = {0: {1: 1.1, 2: 2.1}, 10: {1: 1.1, 2: 2.1}, 20: {1: 1.1, 2: 2.1}, 30: {1: 1.1, 2: 2.1}}
# b = {0: {1: 1.1000001, 2: 2.100001}, 10: {1: 1.1000001, 2: 2.100001}, 20: {1: 1.1000001, 2: 2.100001}, 30: {1: 1.1000001, 2: 2.100001}}

and then do

pd.testing.assert_frame_equal(pd.DataFrame(a), pd.DataFrame(b))

This doesn't raise an error, because all values are close enough. However, if we change a value, e.g.

b[30][2] += 1
# b = {0: {1: 1.1000001, 2: 2.100001}, 10: {1: 1.1000001, 2: 2.100001}, 20: {1: 1.1000001, 2: 2.100001}, 30: {1: 1.1000001, 2: 3.100001}}

and then run the same test, we get the following clear error message:

AssertionError: DataFrame.iloc[:, 3] (column name="30") are different


DataFrame.iloc[:, 3] (column name="30") values are different (50.0 %)
[index]: [1, 2]
[left]:  [1.1, 2.1]
[right]: [1.1000001, 3.100001]

You can also recursively call the existing unittest.TestCase.assertAlmostEqual() and keep track of which element you are comparing, by adding a method to your test case class.

E.g. for lists of lists and list of tuples of floats:

def assertListAlmostEqual(self, first, second, delta=None, context=None):
    """Asserts lists of lists or tuples to check if they compare and
    shows which element is wrong when comparing two lists
    """
    self.assertEqual(len(first), len(second), msg="Lists have different lengths")
    # context = [top-level first, top-level second, index path so far...]
    context = [first, second] if context is None else context
    for i in range(len(first)):
        if isinstance(first[i], (list, tuple)):
            # Recurse with a new list so sibling elements do not inherit this index.
            self.assertListAlmostEqual(first[i], second[i], delta, context=context + [i])
        elif isinstance(first[i], float):
            msg = "Difference in \n{} and \n{}\nFaulty element index={}".format(
                context[0], context[1], context[2:] + [i])
            # delta must be passed by keyword: the third positional parameter is places.
            self.assertAlmostEqual(first[i], second[i], delta=delta, msg=msg)

Outputs something like:

line 23, in assertListAlmostEqual
self.assertAlmostEqual(first[i], second[i], delta, msg=msg)
AssertionError: 5.0 != 6.0 within 7 places (1.0 difference) : Difference in
[(0.0, 5.0), (8.0, 2.0), (10.0, 1.999999), (11.0, 1.9999989090909092)] and
[(0.0, 6.0), (8.0, 2.0), (10.0, 1.999999), (11.0, 1.9999989)]
Faulty element index=[0, 1]

Looking at this myself, I used the addTypeEqualityFunc method of the unittest library in combination with math.isclose.

Sample setup:

import math
from unittest import TestCase


class SomeFixtures(TestCase):
    @classmethod
    def float_comparer(cls, a, b, msg=None):
        if len(a) != len(b):
            raise cls.failureException(msg)
        if not all(map(lambda args: math.isclose(*args), zip(a, b))):
            raise cls.failureException(msg)

    def some_test(self):
        self.addTypeEqualityFunc(list, self.float_comparer)
        self.assertEqual([1.0, 2.0, 3.0], [1.0, 2.0, 3.0])
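To see the addTypeEqualityFunc approach in action on values that are close but not identical, here is a runnable sketch. The class name, the method name assertFloatListsClose and the tolerances are assumptions chosen for illustration; the comparer is registered in setUp so that every test in the class uses it, and it only applies when both operands of assertEqual are exactly of type list.

```python
import math
import unittest


class FloatListTest(unittest.TestCase):
    def setUp(self):
        # From now on, assertEqual on two lists dispatches to our comparer.
        self.addTypeEqualityFunc(list, self.assertFloatListsClose)

    def assertFloatListsClose(self, first, second, msg=None):
        same_length = len(first) == len(second)
        if not same_length or not all(
                math.isclose(a, b, rel_tol=1e-9, abs_tol=1e-9)
                for a, b in zip(first, second)):
            default = "{!r} != {!r} within tolerance".format(first, second)
            raise self.failureException(msg or default)

    def test_close_lists_pass(self):
        self.assertEqual([0.1 + 0.2, 1.0], [0.3, 1.0])

    def test_distant_lists_fail(self):
        with self.assertRaises(AssertionError):
            self.assertEqual([1.0, 2.0], [1.0, 2.5])


# Run the two tests programmatically and collect the outcome.
suite = unittest.defaultTestLoader.loadTestsFromTestCase(FloatListTest)
result = unittest.TestResult()
suite.run(result)
print(result.testsRun, result.wasSuccessful())
```

The comparer here only handles flat lists of numbers; nested lists would need a recursive comparison inside it.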