blog/static/files/ib002/postcondition-ambiguity/test_sort.py

140 lines
3.4 KiB
Python
Raw Normal View History

import itertools
from collections import Counter

import pytest
from hypothesis import given, settings
from hypothesis.strategies import integers, lists
def compare_by_value(pair):
    """
    Key function that ranks `(index, value)` pairs by their value.

    Args:
        pair (Tuple[int, T]): Two-element pair in the form produced by
            `enumerate`: index first, element second.

    Returns:
        The element (second component) of the pair.
    """
    # unpacking (rather than indexing) insists on exactly two components
    _, element = pair
    return element
def maximum(arr, n):
"""
Finds index of maximum element of first n elements in list.
Args:
arr (List): List of values.
n (int): Upper bound of index where to search for maximum (inclusive).
Returns:
Index of a maximum element in `arr[:n]`.
"""
first_n_elements = itertools.islice(enumerate(arr), n)
index, value = max(first_n_elements, key=compare_by_value)
return index
# Precondition: n = |A|
# Loop invariant:
# A[i + 1 : n] is sorted AND
# all elements of A[i + 1 : n] are greater than or equal to all elements of A[0 : i + 1]
# Postcondition: A is sorted
def select_sort(arr, n):
    """
    Sorts list `arr` in place using the select sort algorithm.

    The loop invariant is verified before the loop and after every iteration
    via `check_loop_invariant`.

    Args:
        arr (List): List of values.
        n (int): Size of the list; must be equal to `len(arr)`.

    Returns:
        The same list `arr`, now sorted.
    """
    # precondition
    assert n == len(arr)

    # with i == n the "sorted tail" is empty, so the invariant holds trivially
    check_loop_invariant(arr, n, n)

    # fill the positions from the back: n - 1, n - 2, ..., 1
    for last in range(n - 1, 0, -1):
        # a maximum of arr[: last + 1] belongs to position `last`
        largest = maximum(arr, last + 1)
        arr[largest], arr[last] = arr[last], arr[largest]
        check_loop_invariant(arr, n, last)

    return arr
def broken_select_sort(arr, n):
    """
    Intentionally broken "sort" with the same signature as `select_sort`.

    Instead of rearranging the elements, every position `i` is overwritten
    with `max(arr) + i`. The result is in increasing order and keeps the loop
    invariant satisfied, yet it is generally not a permutation of the input.

    Args:
        arr (List): List of values; overwritten in place.
        n (int): Size of the list; must be equal to `len(arr)`.

    Returns:
        The overwritten list `arr`, or `None` when `arr` is empty.
    """
    # precondition
    assert n == len(arr)

    if not arr:
        # `max` below cannot handle an empty list
        return None

    # computed once, before any element gets overwritten
    ceiling = max(arr)

    check_loop_invariant(arr, n, n)
    for position in range(n - 1, -1, -1):
        arr[position] = ceiling + position
        check_loop_invariant(arr, n, position)

    return arr
def check_loop_invariant(arr, n, i):
    """
    Asserts the loop invariant of the select sort for the given index `i`.

    The invariant has two parts:
      * the tail `arr[i + 1 : n]` is sorted (non-decreasing), and
      * every element of the tail is greater than or equal to every element
        of the prefix `arr[: i + 1]`.

    Args:
        arr (List): List being sorted.
        n (int): Size of the list.
        i (int): Index separating the unsorted prefix from the sorted tail.

    Raises:
        AssertionError: If any part of the invariant is violated.
    """
    # part 1: neighbouring elements of the tail must be in order
    tail = itertools.islice(arr, i + 1, n)
    tail_successors = itertools.islice(arr, i + 2, n)
    assert all(left <= right for left, right in zip(tail, tail_successors))

    # part 2: only meaningful when the tail has at least one element
    if i + 1 < n:
        # the tail is sorted, so its first element is its minimum; comparing
        # the prefix against that single value covers the whole tail
        tail_minimum = arr[i + 1]
        prefix = itertools.islice(arr, i + 1)
        assert all(tail_minimum >= candidate for candidate in prefix)
def check_vague_postcondition(original_arr, arr):
    """
    Asserts the weak postcondition: `arr` is in non-decreasing order.

    `original_arr` is accepted but never inspected, so this check cannot tell
    whether `arr` still holds the elements of the input.

    Args:
        original_arr (List): Input of the sort (ignored).
        arr (List): Output of the sort; an empty or `None` result passes.

    Raises:
        AssertionError: If two neighbouring elements are out of order.
    """
    if not arr:
        # nothing to compare
        return None

    # pair every element with its right-hand neighbour
    successors = itertools.islice(arr, 1, len(arr))
    assert all(left <= right for left, right in zip(arr, successors))
def check_postcondition(original_arr, arr):
    """
    Asserts the full postcondition of a sort.

    The result must be in non-decreasing order AND be a permutation of the
    input, i.e. contain exactly the same values the same number of times.

    Args:
        original_arr (List): Input of the sort, untouched by the sort.
        arr (List): Output of the sort. An empty or `None` result is accepted
            only when the input is empty as well.

    Raises:
        AssertionError: If `arr` is out of order or is not a permutation of
            `original_arr`.
    """
    if not arr:
        # An empty (or missing) result is correct only for an empty input.
        # Returning unconditionally here would let a sort that drops every
        # element slip through the permutation check.
        assert not original_arr
        return

    # check ordering
    for x, y in zip(arr, itertools.islice(arr, 1, len(arr))):
        assert x <= y

    # arr is a permutation of original_arr exactly when both hold the same
    # multiset of values
    assert Counter(arr) == Counter(original_arr)
@given(lists(integers()))
@settings(max_examples=1000)
@pytest.mark.parametrize(
    "postcondition", [check_vague_postcondition, check_postcondition]
)
@pytest.mark.parametrize("sorting_function", [select_sort, broken_select_sort])
def test_select_sort(sorting_function, postcondition, numbers):
    """
    Runs each sorting function against each postcondition check on
    Hypothesis-generated lists of integers.
    """
    # sort a copy so the generated input stays intact for the comparison
    working_copy = list(numbers)
    postcondition(numbers, sorting_function(working_copy, len(working_copy)))