fork download
  1. import itertools
  2. import math
  3. import random
  4.  
  5.  
  6. def rank(p: list[int], elements: set[int]) -> int:
  7. """Rank a permutation of k out of n things.
  8.  
  9. Args:
  10. p: The permutation to rank
  11. elements: The elements to choose from
  12.  
  13. Returns:
  14. The rank of the permutation within all choices of
  15. k out of n things
  16. """
  17.  
  18. k = len(p)
  19.  
  20. if k == 1:
  21. return len(list(e for e in elements if e < p[0]))
  22.  
  23. head, tail = p[0], p[1:]
  24.  
  25. n = len(elements) - 1
  26. num_preceding = len([e for e in elements if e < head])
  27.  
  28. preceding = num_preceding * (
  29. math.factorial(n)
  30. / math.factorial(n - k + 1)
  31. )
  32.  
  33. return int(preceding) + rank(tail, elements - {head})
  34.  
  35.  
  36. def _update_p(p: list[int], i: int) -> None:
  37. """Update a candidate unranked permutation
  38.  
  39. Args:
  40. p: The permutation to update
  41. i: The index of the leftmost element to attempt to
  42. increment
  43.  
  44. Returns:
  45. None
  46. """
  47.  
  48. p[i] += 1
  49. prev = p[:i]
  50. while p[i] in prev:
  51. p[i] += 1
  52. for j in range(i + 1, len(p)):
  53. p[j] = 0
  54. prev = p[:j]
  55. while p[j] in prev:
  56. p[j] += 1
  57.  
  58.  
  59. def unrank(index, k, n):
  60. """Unrank a permutation of k out of n things.
  61.  
  62. Args:
  63. index: The rank of the permutation to unrank
  64. k: The cardinality of the permutation
  65. n: The size of the set the permutation is chosen
  66. from, which is enumerated consecutively from zero
  67.  
  68. Returns:
  69. The unranked permutation
  70. """
  71.  
  72. elements = sorted(range(n))
  73. p = elements[:k]
  74. for i in range(k):
  75. if p[i] == n - 1:
  76. continue
  77. while True:
  78. r = rank(p, set(elements))
  79. if r > index:
  80. break
  81. prev = p[:]
  82. _update_p(p, i)
  83. p = prev
  84. return p
  85.  
  86.  
  87. # Test rank and unrank by checking reversibility
  88. num_tests = 1
  89. max_n = 800
  90. max_k = 24
  91.  
  92.  
  93. for _ in range(num_tests):
  94. n = random.randint(5, max_n)
  95. k = random.randint(2, min(max_k, n))
  96.  
  97. p = random.sample(range(n), k)
  98. elements = set(range(n))
  99.  
  100. rank_p = rank(p, elements)
  101. unranked = unrank(rank_p, k, n)
  102.  
  103. if unranked != p:
  104. print((p, n, rank_p, unranked, rank(unranked, elements)))
  105. break
  106.  
  107. print('Done.')
Success #stdin #stdout 0.34s 14388KB
stdin
Standard input is empty
stdout
Done.