Module:Sandbox/あめろ

Template-info.svg 模块文档  [创建] [刷新]
  -- Localized standard-library functions: shorter names and local-register access.
  local fmt = string.format
  local type, ipairs = type, ipairs
  local get_mt, set_mt = getmetatable, setmetatable

  --- Schema module: composable runtime checks for Lua values.
  local schema = {}
  --- Is `val` a plain table, i.e. a table whose `getmetatable` is nil/false?
  ---@param val any
  ---@return boolean
  local function is_raw_table(val)
    if type(val) ~= 'table' then
      return false
    end
    return not get_mt(val)
  end
  --- Filter an array.
  -- Builds a new array holding, in order, the elements of `t` for which
  -- `filter(element)` is truthy; `t` itself is not modified. Traversal uses
  -- `ipairs`, so it stops at the first nil element.
  ---@param t any[] input array
  ---@param filter fun(v: any): any predicate
  ---@return any[] kept the elements that passed
  ---@return integer dropped how many elements were rejected
  local function ifilter(t, filter)
    local kept, dropped = {}, 0
    for _, value in ipairs(t) do
      if filter(value) then
        kept[#kept + 1] = value
      else
        dropped = dropped + 1
      end
    end
    return kept, dropped
  end
  --- Build a `test` method that checks a value's Lua type.
  -- The returned function has the method signature `(self, testee)`, so it can
  -- be stored as the `test` field of a schema object. It returns `true` when
  -- `type(testee) == ty`. Otherwise it returns `false` plus a message such as
  -- "x (type: string) isn't a number".
  ---@param ty string expected Lua type name, as returned by `type`
  ---@param a string? article placed before the type name in the message (e.g. 'a')
  ---@return fun(self: any, testee: any): boolean, string?
  local function type_checker(ty, a)
    local fmt_str = "%s (type: %s) isn't "
    if a then
      fmt_str = fmt_str .. a .. ' '
    end
    fmt_str = fmt_str .. ty
    return function(self, testee)
      local actual = type(testee)
      if actual == ty then
        return true
      end
      -- tostring() is required: on Lua 5.1 (e.g. Scribunto) string.format's
      -- %s accepts only strings and numbers. Without it, a nil/boolean/table/
      -- function testee would raise instead of producing this message.
      return false, fmt(fmt_str, tostring(testee), actual)
    end
  end
  --- Wrap a value in an array unless it already is a plain table.
  -- If `val` is a table without a metatable, that same table object is
  -- returned. Otherwise a new one-element table `{val}` is returned.
  -- Because `{nil}` is empty, `ensure_wrapped(nil)` yields an empty table.
  ---@param val any
  ---@return table
  local function ensure_wrapped(val)
    if is_raw_table(val) then
      return val
    end
    return { val }
  end
  --- Registry of every metatable created by `reg_mt` (weak keys). Membership
  -- is what identifies a value as a schema object.
  local mts = set_mt({}, { __mode = 'k' }) ---@type {[metatable]: true}

  --- Create and register a schema metatable.
  ---@param name string stored as `__name` (the schema type name)
  ---@param super_mt metatable? parent whose `__index` methods are inherited
  ---@param without_override boolean? when true, reuse the parent's `__index`
  ---   table itself instead of layering a fresh one over it; methods defined
  ---   on the result's `__index` then land on the parent's table too
  ---@return metatable
  local function reg_mt(name, super_mt, without_override)
    local parent_index = super_mt and super_mt.__index or {}
    local own_index
    if without_override then
      own_index = parent_index
    else
      -- A new method table that falls back to the parent's methods.
      own_index = set_mt({}, { __index = parent_index })
    end
    local mt = { __name = name, __index = own_index }
    mts[mt] = true
    return mt
  end
  --- Schema type name of `v` (e.g. 'Number', 'Union').
  -- Returns nil when `v`'s metatable was not created by `reg_mt`, including
  -- when `v` has no metatable at all.
  ---@param v any
  ---@return string | nil
  local function get_scm_type(v)
    local mt = get_mt(v)
    return mts[mt] and mt.__name or nil
  end
  --- Collect the custom validator functions from a constraints table.
  -- Reads `constraints.validators`, falling back to `constraints.validator`.
  -- Either may be a single function or an array of functions. Any element
  -- that is not a function raises an error. A nil `constraints` is treated
  -- as an empty table, so `schema.Any()` works.
  ---@param constraints table? constraints passed to a schema constructor
  ---@return function[] | nil validators, or nil when none were given
  local function get_validators_from_constraints(constraints)
    constraints = constraints or {}
    local t = ifilter(
      ensure_wrapped(constraints.validators or constraints.validator),
      function(v)
        -- assert() returns true on success, so nothing is ever dropped here:
        -- the filter only rejects non-function entries.
        return assert(type(v) == 'function', 'validator需要是函数或元素为函数的表')
      end
    )
    return t[1] and t or nil
  end
  --- Run custom validators on `val` in order, stopping at the first failure.
  -- A validator fails when its first return value is falsy. Its second
  -- return value is used as the message. If it gives none, a default message
  -- is substituted, so that callers (e.g. `:assert`) never raise a nil error.
  ---@param validators function[]? validators to run; nil means "no checks"
  ---@param val any value under test
  ---@return boolean valid
  ---@return string? msg failure message when `valid` is false
  local function validate_all(validators, val)
    if not validators then return true end
    for i, validator in ipairs(validators) do
      local valid, msg = validator(val)
      if not valid then
        return false, msg or fmt('%s fails validator #%d', tostring(val), i)
      end
    end
    return true
  end
  -- Any: the root schema metatable; the other schema metatables inherit its methods.
  local Any_mt = reg_mt('Any', nil)

  --- The base `Any` schema accepts every value.
  schema.Any = set_mt({
    test = function() return true end,
  }, Any_mt)

  --- Derive a schema from `self` that additionally runs custom validators.
  -- The result shares `self`'s metatable and records `self` as `super`.
  ---@param constraints table table holding `validator` / `validators`
  ---@return table derived schema
  function Any_mt:__call(constraints)
    local derived = {
      super = self,
      validators = get_validators_from_constraints(constraints),
    }
    return set_mt(derived, get_mt(self))
  end

  --- Check the parent schema (if any), then this schema's own validators.
  ---@return boolean, string?
  function Any_mt.__index:test(testee)
    local parent = self.super
    if parent then
      local ok, err = parent:test(testee)
      if not ok then return false, err end
    end
    return validate_all(self.validators, testee)
  end

  --- Like `test`, but raise the failure message (blamed on the caller).
  ---@return true
  function Any_mt.__index:assert(testee)
    local ok, err = self:test(testee)
    if not ok then
      error(err, 2)
    end
    return true
  end
  -- Nil and Boolean reuse Any's method table directly (without_override = true);
  -- only their `test` fields differ. Their metatables define no `__call`.
  local Nil_mt = reg_mt('Nil', Any_mt, true)
  local Boolean_mt = reg_mt('Boolean', Any_mt, true)

  --- Accepts only nil.
  schema.Nil = set_mt({ test = type_checker('nil') }, Nil_mt)
  --- Accepts only true/false.
  schema.Boolean = set_mt({ test = type_checker('boolean', 'a') }, Boolean_mt)
  local Number_mt = reg_mt('Number', Any_mt)

  --- Base `Number` schema: accepts any value whose Lua type is number.
  schema.Number = set_mt({
    test = type_checker('number', 'a'),
  }, Number_mt)

  --- Derive a constrained number schema from `self`.
  -- Recognized keys:
  --   * `int` — the value must have no fractional part;
  --   * `lt`, `gt` — strict bounds;
  --   * `le` (alias `max`), `ge` (alias `min`) — inclusive bounds;
  --   * `ne` — a value that is not allowed;
  --   * `validator` / `validators` — custom validator functions.
  -- Calling with no argument yields a schema that only applies the parent's checks.
  ---@param constraints table?
  ---@return table derived Number schema
  function Number_mt:__call(constraints)
    -- Allow `schema.Number()` without an argument.
    constraints = constraints or {}
    return set_mt({
      super = self,
      int = constraints.int,
      lt = constraints.lt,
      gt = constraints.gt,
      le = constraints.le or constraints.max,
      ge = constraints.ge or constraints.min,
      ne = constraints.ne,
      validators = get_validators_from_constraints(constraints),
    }, Number_mt)
  end

  --- Check the parent schema, then each numeric constraint, then the custom
  -- validators, returning at the first failure.
  -- The `super` chain ends at the base `schema.Number`, whose type check
  -- ensures `testee` is a number before the arithmetic below runs.
  ---@return boolean, string?
  function Number_mt.__index:test(testee)
    if self.super then
      local valid, msg = self.super:test(testee)
      if not valid then
        return false, msg
      end
    end
    if self.int and math.fmod(testee, 1) ~= 0 then
      return false, fmt("%s isn't an integer", testee)
    end
    if self.lt and testee >= self.lt then
      return false, fmt("%s isn't < %s", testee, self.lt)
    end
    if self.gt and testee <= self.gt then
      return false, fmt("%s isn't > %s", testee, self.gt)
    end
    if self.le and testee > self.le then
      return false, fmt("%s isn't <= %s", testee, self.le)
    end
    if self.ge and testee < self.ge then
      return false, fmt("%s isn't >= %s", testee, self.ge)
    end
    if self.ne and testee == self.ne then
      return false, fmt('testee equals %s', self.ne)
    end
    return validate_all(self.validators, testee)
  end
  local String_mt = reg_mt('String', Any_mt)

  --- Base `String` schema: accepts any value whose Lua type is string.
  schema.String = set_mt({
    test = type_checker('string', 'a')
  }, String_mt)

  --- Derive a constrained string schema from `self`.
  -- Recognized keys:
  --   * `max_len`, `min_len` — bounds on `#testee`;
  --   * `pattern` — a Lua pattern passed to `testee:match`;
  --   * `validator` / `validators` — custom validator functions.
  -- NOTE(review): `#` counts bytes, so for non-ASCII text (e.g. CJK) the length
  -- bounds are byte counts, not character counts. Confirm this is intended.
  -- Calling with no argument yields a schema that only applies the parent's checks.
  ---@param constraints table?
  ---@return table derived String schema
  function String_mt:__call(constraints)
    -- Allow `schema.String()` without an argument.
    constraints = constraints or {}
    return set_mt({
      super = self,
      max_len = constraints.max_len,
      min_len = constraints.min_len,
      pattern = constraints.pattern,
      validators = get_validators_from_constraints(constraints),
    }, String_mt)
  end

  --- Check the parent schema, then the length and pattern constraints, then
  -- the custom validators, returning at the first failure.
  -- The `super` chain ends at the base `schema.String`, whose type check
  -- ensures `testee` is a string before `#` and `:match` are used.
  ---@return boolean, string?
  function String_mt.__index:test(testee)
    if self.super then
      local valid, msg = self.super:test(testee)
      if not valid then
        return false, msg
      end
    end
    if self.max_len and #testee > self.max_len then
      return false, fmt("the length of %q (%d) exceeds %s", testee, #testee, self.max_len)
    end
    if self.min_len and #testee < self.min_len then
      return false, fmt("the length of %q (%d) is under %s", testee, #testee, self.min_len)
    end
    if self.pattern and not testee:match(self.pattern) then
      return false, fmt("%q doesn't match the pattern %q", testee, self.pattern)
    end
    return validate_all(self.validators, testee)
  end
  local Function_mt = reg_mt('Function', Any_mt)

  --- Base `Function` schema: accepts any value whose Lua type is function.
  schema.Function = set_mt({
    test = type_checker('function', 'a'),
  }, Function_mt)

  -- Deriving (`schema.Function{validator = ...}`) works exactly like `Any`:
  -- the result keeps this metatable and records `super` and `validators`.
  Function_mt.__call = Any_mt.__call

  -- A derived function schema checks its parent, then its own validators.
  -- That is exactly Any's `test`, so share that implementation.
  Function_mt.__index.test = Any_mt.__index.test
  local Table_mt = reg_mt('Table', Any_mt)

  --- Base `Table` schema: accepts any value whose Lua type is table.
  schema.Table = set_mt({
    test = type_checker('table', 'a')
  }, Table_mt)

  --- Derive a table schema from `self`.
  -- Each key of `constraints` is classified as:
  --   * a schema object whose type is 'Literal' — its `.val` becomes a specific
  --     field name (no Literal schema is defined in this section of the module);
  --   * any other schema object — a generic key schema: every testee key it
  --     accepts must have a value accepted by the paired value schema;
  --   * the strings 'validator' / 'validators' — custom validator functions;
  --   * anything else — a specific field whose value must satisfy the paired schema.
  -- Calling with no argument yields a schema with no field constraints.
  ---@param constraints table?
  ---@return table derived Table schema
  function Table_mt:__call(constraints)
    -- Allow `schema.Table()` without an argument.
    constraints = constraints or {}
    local specific = {}
    local generic = {}
    for k, v in pairs(constraints) do
      local scm_type = get_scm_type(k)
      if scm_type then
        if scm_type == 'Literal' then
          specific[k.val] = v
        else
          generic[k] = v
        end
      elseif k ~= 'validators' and k ~= 'validator' then
        specific[k] = v
      end
    end
    return set_mt({
      super = self,
      specific = specific,
      generic = generic,
      validators = get_validators_from_constraints(constraints)
    }, Table_mt)
  end

  --- Check the parent schema, then the generic and specific field schemas,
  -- then the custom validators, returning at the first failure.
  ---@return boolean, string?
  function Table_mt.__index:test(testee)
    if self.super then
      local valid, msg = self.super:test(testee)
      if not valid then
        return false, msg
      end
    end
    -- Generic entries: any testee key matched by a key schema must have a
    -- value that satisfies the paired value schema.
    for key_scm, val_scm in pairs(self.generic) do
      for testee_key, testee_val in pairs(testee) do
        if key_scm:test(testee_key) then
          local valid, msg = val_scm:test(testee_val)
          if not valid then
            return false, msg
          end
        end
      end
    end
    for key, val_scm in pairs(self.specific) do
      local field = testee[key]
      local valid, msg = val_scm:test(field)
      if not valid then
        if field == nil then
          -- tostring(): on Lua 5.1, string.format's %s raises for tables and
          -- booleans, and `testee` is always a table here.
          return false, fmt('`%s` misses field `%s`', tostring(testee), tostring(key))
        end
        return false, msg
      end
    end
    return validate_all(self.validators, testee)
  end
  local Union_mt = reg_mt('Union', Any_mt)

  --- Build a union schema that accepts a value matching any alternative.
  -- Each argument may be:
  --   * `nil` — stored as `schema.Nil`;
  --   * another union — its members are merged in (unions are flattened);
  --   * any other schema object — matched through its `test` method;
  --   * any other value — matched by `==` equality (a literal).
  -- Members are stored as the keys of the returned table itself.
  -- NOTE(review): because members are raw keys, a literal member named 'test'
  -- or 'assert' shadows the union's methods, and a NaN argument raises
  -- "table index is NaN".
  ---@param ... any alternatives
  ---@return table union schema
  function schema.Union(...)
    local union = {}
    for i = 1, select('#', ...) do
      local sub_scm = select(i, ...)
      if sub_scm == nil then
        union[schema.Nil] = true
      elseif get_scm_type(sub_scm) == 'Union' then
        for scm_in_union in next, sub_scm do
          union[scm_in_union] = true
        end
      else
        union[sub_scm] = true
      end
    end
    return set_mt(union, Union_mt)
  end

  --- Describe one union member for an error message.
  ---@param member any
  ---@return string schema type name, quoted string literal, or tostring() form
  local function describe_union_member(member)
    local scm_type = get_scm_type(member)
    if scm_type then
      return scm_type
    elseif type(member) == 'string' then
      return fmt('%q', member)
    end
    return tostring(member)
  end

  --- Succeed if `testee` satisfies any member: a schema member via its
  -- `test`, a literal member via `==`.
  ---@return boolean, string?
  function Union_mt.__index:test(testee)
    for allowed_val in next, self do
      if get_scm_type(allowed_val) then
        if allowed_val:test(testee) then
          return true
        end
      elseif testee == allowed_val then
        return true
      end
    end
    -- Formatting `self` (a table) or a non-string testee with %s raises on
    -- Lua 5.1. Build a readable, deterministically ordered member list instead.
    local members = {}
    for member in next, self do
      members[#members + 1] = describe_union_member(member)
    end
    table.sort(members)
    return false, fmt('testee `%s` fails to match each value in the union: %s',
      tostring(testee), table.concat(members, ', '))
  end
  -- Make `A | B` (Lua 5.3+ `__bor`) and `A / B` (`__div`, also usable on
  -- Lua 5.1) build a union for every schema metatable registered so far.
  for registered_mt in pairs(mts) do
    registered_mt.__div = schema.Union
    registered_mt.__bor = schema.Union
  end

  --- Accepts every value except nil and false.
  schema.Truthy = schema.Any({
    validator = function(value) return value end,
  })
  --- Accepts only nil and false.
  schema.Falsy = schema.Any({
    validator = function(value) return not value end,
  })

  return schema
返回顶部
页面反馈