USAMTS Checker App
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

240 lines
6.2 KiB

10 months ago
  1. #!/usr/bin/env python3
  2. import io
  3. import base64
  4. from matplotlib import pyplot as plt
  5. import numpy
  6. from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
  7. from matplotlib.figure import Figure
  8. class Grid:
  9. def __init__(self, M, N):
  10. self.N = N
  11. self.M = M
  12. self.grid = [[0 for _ in range(M+N-1)] for _ in range(M+N-1)]
  13. def check_bounds(self, x, y):
  14. return 1 <= x <= self.M+self.N-1 and\
  15. 1 <= y <= self.M+self.N-1 and\
  16. self.M+1 <= x+y <= 2*self.M + self.N - 1
  17. def get(self, x, y):
  18. if not self.check_bounds(x, y):
  19. return 0
  20. return self.grid[x-1][y-1]
  21. def select(self, x, y):
  22. assert self.check_bounds(x, y)
  23. self.grid[x-1][y-1] = 1
  24. def grid_parity(self):
  25. for x in range(1, self.M+self.N):
  26. nums = [self.get(x, y) for y in range(1, self.N + self.M)]
  27. if sum(nums) % 2 != 1:
  28. print(f"Found a contradiction at x={x}! {nums}")
  29. return False
  30. for y in range(1, self.M+self.N):
  31. nums = [self.get(x, y) for x in range(1, self.N + self.M)]
  32. if sum(nums) % 2 != 1:
  33. print(f"Found a contradiction at y={y}! {nums}")
  34. return False
  35. for s in range(self.M+1, 2*self.M+self.N):
  36. nums = [self.get(x, s-x) for x in range(1, s)]
  37. if sum(nums) % 2 != 1:
  38. print(f"Found a contradiction at x+y={s}! {nums}")
  39. return False
  40. return True
  41. def reflect(self):
  42. res = [[0 for _ in range(len(self.grid[i]))]
  43. for i in range(len(self.grid))]
  44. for x in range(self.M+self.N-1):
  45. # reflect across x + y = m + n - 2
  46. for y in range(self.M+self.N-1):
  47. # sm = 2 * (self.M+self.N-2) - (x+y)
  48. # diff = x - y
  49. res[self.M+self.N-2-x][self.M+self.N-2-y] = self.grid[x][y]
  50. self.grid = res
  51. tmp = self.M
  52. self.M = self.N
  53. self.N = tmp
  54. def plot(self) -> tuple[str, bool]:
  55. fig = Figure()
  56. ax = fig.add_subplot(1, 1, 1)
  57. ax.set_xticks(numpy.arange(1, self.M+self.N, 1))
  58. ax.set_yticks(numpy.arange(1, self.M+self.N, 1))
  59. ax.set_aspect("equal")
  60. ax.set_xbound(0, self.M+self.N)
  61. ax.set_ybound(0, self.M+self.N)
  62. ax.autoscale(enable=False)
  63. x = []
  64. y = []
  65. for i in range(1, self.M+self.N):
  66. for j in range(1, self.M+self.N):
  67. if self.get(i, j):
  68. x.append(i)
  69. y.append(j)
  70. ax.scatter(x, y, color='b')
  71. # plt.title(f"Construction for M={self.M}, N={self.N}")
  72. ax.grid()
  73. ax.plot([1, self.M], [self.M, 1], color='r')
  74. ax.plot([self.M, self.N+self.M-1],
  75. [self.N+self.M-1, self.M], color='r')
  76. ax.plot([1, 1, self.M], [self.M, self.M +
  77. self.N-1, self.M+self.N-1], color='r')
  78. ax.plot([self.M, self.M+self.N-1, self.M+self.N-1],
  79. [1, 1, self.M], color='r')
  80. pngImage = io.BytesIO()
  81. FigureCanvas(fig).print_png(pngImage)
  82. pngImageB64String = "data:image/png;base64,"
  83. pngImageB64String += base64.b64encode(
  84. pngImage.getvalue()).decode('utf8')
  85. return pngImageB64String, self.grid_parity()
  86. def one(M, N) -> Grid:
  87. assert (N-M) % 4 == 0
  88. assert M == 1 or N == 1
  89. reflect = M == 1
  90. if reflect:
  91. tmp = M
  92. M = N
  93. N = tmp
  94. g = Grid(M, N)
  95. for k in range(1, M+1):
  96. g.select(k, M+1-k)
  97. for k in range(1, M//2+1):
  98. g.select((M+1)//2, (M+1)//2+k)
  99. g.select(M, (M+1)//2+k)
  100. if reflect:
  101. g.reflect()
  102. return g
  103. def cong_mod4(M, N) -> Grid:
  104. assert (N-M) % 4 == 0 and N != 1 and M != 1
  105. reflect = M > N
  106. if reflect:
  107. tmp = M
  108. M = N
  109. N = tmp
  110. g = Grid(M, N)
  111. if M % 2 == 1:
  112. g.select(M, M)
  113. # Main Axis points
  114. for k in range(1, N//2+1):
  115. g.select(M, M+2*k-1)
  116. for k in range(1, (M-1)//2+1):
  117. g.select(M, M-2*k)
  118. for k in range(1, (N-1)//2+1):
  119. g.select(M+2*k, M)
  120. for k in range(1, M//2+1):
  121. g.select(M-2*k+1, M)
  122. # Points on the diagonal
  123. for k in range(1, M//2 + 1):
  124. g.select(M+2*k-1, M-2*k+1) # going down
  125. for k in range(1, (M-1)//2+1):
  126. g.select(M-2*k, M+2*k) # going up
  127. for k in range(1, (N-M)//2+1):
  128. g.select(
  129. M+2*(M//2) - 1 + 2*k,
  130. M-2*(M//2) + 2,
  131. )
  132. g.select(
  133. M - 2*((M-1)//2) + 1,
  134. M + 2*((M-1)//2) + 2*k,
  135. )
  136. if reflect:
  137. g.reflect()
  138. return g
  139. def mod_0_1(M, N):
  140. assert (M+N) % 4 == 1 and M % 4 in [0, 1]
  141. reflect = (M % 4) == 1
  142. if reflect:
  143. tmp = M
  144. M = N
  145. N = tmp
  146. g = Grid(M, N)
  147. g.select(M, M)
  148. # Main Axis Points
  149. for k in range(1, N//2+1):
  150. g.select(M, M+2*k-1)
  151. for k in range(1, (M-1)//2+1):
  152. g.select(M, M-2*k)
  153. for k in range(1, (N-1)//2+1):
  154. g.select(M+2*k, M)
  155. for k in range(1, M//2+1):
  156. g.select(M-2*k+1, M)
  157. # Step 3
  158. for k in range(0, M//2):
  159. g.select(M-2*k, 1+2*k)
  160. # Tail
  161. for k in range(1, (N-1)//2+1):
  162. g.select(2, M+2*k)
  163. g.select(M+2*k-1, 3)
  164. if reflect:
  165. g.reflect()
  166. return g
  167. def mod_2_3(M, N):
  168. assert (M+N) % 4 == 1 and M % 4 in [2, 3]
  169. reflect = (M % 4) == 2
  170. if reflect:
  171. tmp = M
  172. M = N
  173. N = tmp
  174. g = Grid(M, N)
  175. for k in range(1, M+1):
  176. g.select(k, M+1-k)
  177. for k in range((M+1)//2+1, M+1):
  178. g.select((M+1)//2, k)
  179. g.select(M+1, k)
  180. g.select((M+1)//2, M+1)
  181. for k in range(1, (N-2)//2 + 1):
  182. g.select(M+1, M+2*k)
  183. g.select(M+2*k, M)
  184. g.select(1, M+1+2*k)
  185. g.select(M+1+2*k, 1)
  186. if reflect:
  187. g.reflect()
  188. return g
  189. def construct(M, N) -> 'Grid | None':
  190. if (M-N) % 4 == 0:
  191. if M != 1 and N != 1:
  192. return cong_mod4(M, N)
  193. else:
  194. return one(M, N)
  195. elif (M+N) % 4 == 1:
  196. if (M % 4) in [0, 1]:
  197. return mod_0_1(M, N)
  198. else:
  199. return mod_2_3(M, N)
  200. else:
  201. assert ((N+M-1)*(N-M)) % 4 == 2
  202. return None