-- Cache frequently used globals in locals.
local fmt = string.format
local type, ipairs = type, ipairs
local get_mt, set_mt = getmetatable, setmetatable

-- Schema

---Module table: the public schemas are attached to it and it is returned at the end of the file.
local schema = {}

---Tell whether `val` is a plain table, i.e. a table that has no metatable.
---@param val any
---@return boolean
local function is_raw_table(val)
    if type(val) ~= 'table' then
        return false
    end
    return not get_mt(val)
end

---Filter an array: build a new sequence holding the elements for which `filter`
---returns a truthy value, in their original order.
---The input is walked with `ipairs`, so traversal stops at the first nil.
---@param t any[] array to filter
---@param filter fun(v: any): any predicate called once per element
---@return any[] kept new array with the accepted elements
---@return integer dropped number of rejected elements
local function ifilter(t, filter)
    local kept, kept_count = {}, 0
    local dropped = 0
    for _, item in ipairs(t) do
        if filter(item) then
            kept_count = kept_count + 1
            kept[kept_count] = item
        else
            dropped = dropped + 1
        end
    end
    return kept, dropped
end

---Build a `test` method that accepts exactly the values whose Lua type is `ty`.
---On failure the message reads "<value> (type: <actual type>) isn't [<a> ]<ty>".
---@param ty string Lua type name to require, e.g. 'number'
---@param a string? article placed before the type name in the message, e.g. 'a'
---@return fun(self: any, testee: any): boolean, string? checker returning `true`, or `false` and a message
local function type_checker(ty, a)
    local fmt_str = "%s (type: %s) isn't "
    if a then
        fmt_str = fmt_str .. a .. ' '
    end
    fmt_str = fmt_str .. ty
    return function(self, testee)
        local actual = type(testee)
        if actual == ty then
            return true
        end
        -- tostring() first: on Lua 5.1, string.format's '%s' raises for nil/boolean/table
        -- arguments, which would turn a failed check into an error. On 5.2+ the output
        -- is identical because '%s' already applies the same conversion.
        return false, fmt(fmt_str, tostring(testee), actual)
    end
end


---Wrap `val` in a table unless it already is a plain table.
---For a value `x` a new table `{x}` is returned. For a plain table `{x}`, that very
---same table object is returned, not a copy. `nil` yields an empty table.
---@param val any
---@return table
local function ensure_wrapped(val)
    if is_raw_table(val) then
        return val
    end
    return { val }
end


---Registry of the metatables created by `reg_mt`; `get_scm_type` consults it to
---recognise schema objects. Keys are weak, so the registry alone never keeps a
---metatable alive.
local mts = {} ---@type {[metatable]: true}
set_mt(mts, { __mode = 'k' })


---Create the metatable for a schema type and add it to the `mts` registry.
---Methods live in the metatable's `__index` table.
---* By default that table is new and falls back to the parent's `__index`, so methods
---  defined on it override the parent's without modifying them.
---* With `without_override`, the parent's `__index` table itself is reused, so any method
---  added to the new type would land in the parent's table as well.
---@param name string type name, stored as `__name` (what `get_scm_type` reports)
---@param super_mt metatable? parent schema metatable to inherit methods from
---@param without_override boolean? share the parent's `__index` table instead of deriving one
---@return metatable
local function reg_mt(name, super_mt, without_override)
    local parent_index = super_mt and super_mt.__index or {}
    local index
    if without_override then
        index = parent_index
    else
        index = set_mt({}, { __index = parent_index })
    end
    local mt = { __name = name, __index = index }
    mts[mt] = true
    return mt
end


---Return the schema type name of `v` (the `__name` of its metatable, e.g. 'Number'),
---or nil when `v`'s metatable was not created by `reg_mt`.
---@param v any
---@return string | nil
local function get_scm_type(v)
    local mt = get_mt(v)
    if mts[mt] then
        return mt.__name
    end
    return nil
end


---Collect the custom validator functions declared in a constraint table.
---`validators` may be a function or a list of functions; `validator` is an alias that
---is only consulted when `validators` is absent. Raises if any element is not a function.
---@param constraints table? constraint table given to a schema call; nil means "no constraints"
---@return function[] | nil validators the list of validators, or nil when none were given
local function get_validators_from_constraints(constraints)
    if constraints == nil then
        -- Allow `schema.X()` without a constraint table instead of failing with
        -- "attempt to index a nil value".
        return nil
    end
    local validators = ifilter(
        ensure_wrapped(constraints.validators or constraints.validator),
        function(v)
            -- Message: "validator must be a function or a table whose elements are functions".
            return assert(type(v) == 'function', 'validator需要是函数或元素为函数的表')
        end
    )
    return validators[1] and validators or nil
end


---Run every validator on `val` in order and stop at the first one that fails.
---Each validator returns a truthy value on success, or a falsy value plus an
---optional message on failure.
---@param validators function[]? nil means "nothing to check"
---@param val any
---@return boolean ok
---@return string? msg message of the first failing validator, if it gave one
local function validate_all(validators, val)
    if validators then
        for _, check in ipairs(validators) do
            local ok, err = check(val)
            if not ok then
                return false, err
            end
        end
    end
    return true
end


local Any_mt = reg_mt('Any', nil)

---Root schema: accepts every value. Calling it with a constraint table derives a new schema.
schema.Any = set_mt({
    test = function()
        return true
    end,
}, Any_mt)

---Derive a schema that first runs `self`'s test, then the custom validators from
---`constraints` (`validator` / `validators`). The result shares `self`'s metatable,
---so it keeps `self`'s type name and methods.
---@param constraints table? nil is treated as an empty constraint table
---@return table derived schema
function Any_mt:__call(constraints)
    return set_mt({
        super = self,
        validators = get_validators_from_constraints(constraints or {}),
    }, get_mt(self))
end

---Test `testee` against the parent schema (if any), then this schema's validators.
---@param testee any
---@return boolean ok
---@return string? msg failure message, when one is available
function Any_mt.__index:test(testee)
    if self.super then
        local valid, msg = self.super:test(testee)
        if not valid then
            return false, msg
        end
    end
    return validate_all(self.validators, testee)
end

---Like `test`, but raises an error blamed on the caller (level 2) instead of returning false.
---@param testee any
---@return true
function Any_mt.__index:assert(testee)
    local valid, msg = self:test(testee)
    if valid then
        return true
    end
    -- Custom validators may fail without a message (e.g. the one behind `schema.Truthy`);
    -- never raise a nil error object.
    error(msg or fmt('%s failed the validation', tostring(testee)), 2)
end


-- `schema.Nil` accepts only `nil`. The `true` flag makes Nil_mt share Any's method table,
-- and nothing in this file assigns `Nil_mt.__call`, so it cannot be refined with constraints.
local Nil_mt = reg_mt('Nil', Any_mt, true)
schema.Nil = set_mt({ test = type_checker('nil') }, Nil_mt)


-- `schema.Boolean` accepts `true` and `false`. Like Nil, it shares Any's method table and
-- nothing in this file assigns `Boolean_mt.__call`, so it takes no constraints.
local Boolean_mt = reg_mt('Boolean', Any_mt, true)
schema.Boolean = set_mt({ test = type_checker('boolean', 'a') }, Boolean_mt)


local Number_mt = reg_mt('Number', Any_mt)

---Matches any number. Call it with a constraint table to restrict the accepted values.
schema.Number = set_mt({
    test = type_checker('number', 'a'),
}, Number_mt)

---Derive a number schema with extra constraints.
---Supported keys:
---* `int`: value must be integral
---* `lt`, `gt`: exclusive bounds
---* `le` (alias `max`), `ge` (alias `min`): inclusive bounds
---* `ne`: excluded value
---* `validator` / `validators`: custom checks
---@param constraints table? nil is treated as an empty constraint table
---@return table derived schema
function Number_mt:__call(constraints)
    -- Without this, `schema.Number()` fails with "attempt to index a nil value".
    constraints = constraints or {}
    return set_mt({
        super = self,
        int = constraints.int,
        lt = constraints.lt,
        gt = constraints.gt,
        le = constraints.le or constraints.max,
        ge = constraints.ge or constraints.min,
        ne = constraints.ne,
        validators = get_validators_from_constraints(constraints),
    }, Number_mt)
end

---Test `testee` against the parent schema, then each numeric constraint in a fixed
---order (int, lt, gt, le, ge, ne, validators), reporting the first failure.
---@param testee any
---@return boolean ok
---@return string? msg
function Number_mt.__index:test(testee)
    if self.super then
        local valid, msg = self.super:test(testee)
        if not valid then
            return false, msg
        end
    end
    -- Schemas built by `__call` chain back to `schema.Number`, whose type check has
    -- passed by now, so `testee` is a number below.
    if self.int and math.fmod(testee, 1) ~= 0 then
        return false, fmt("%s isn't an integer", testee)
    end
    if self.lt and testee >= self.lt then
        return false, fmt("%s isn't < %s", testee, self.lt)
    end
    if self.gt and testee <= self.gt then
        return false, fmt("%s isn't > %s", testee, self.gt)
    end
    if self.le and testee > self.le then
        return false, fmt("%s isn't <= %s", testee, self.le)
    end
    if self.ge and testee < self.ge then
        return false, fmt("%s isn't >= %s", testee, self.ge)
    end
    if self.ne and testee == self.ne then
        return false, fmt('testee equals %s', self.ne)
    end
    return validate_all(self.validators, testee)
end


local String_mt = reg_mt('String', Any_mt)

---Matches any string. Call it with a constraint table to restrict the accepted strings.
schema.String = set_mt({
    test = type_checker('string', 'a')
}, String_mt)

---Derive a string schema with extra constraints.
---Supported keys:
---* `max_len`, `min_len`: length limits, counted in bytes as the `#` operator does
---* `pattern`: a Lua pattern passed to `string.match`; it is unanchored unless it uses `^`/`$`
---* `validator` / `validators`: custom checks
---@param constraints table? nil is treated as an empty constraint table
---@return table derived schema
function String_mt:__call(constraints)
    -- Without this, `schema.String()` fails with "attempt to index a nil value".
    constraints = constraints or {}
    return set_mt({
        super = self,
        max_len = constraints.max_len,
        min_len = constraints.min_len,
        pattern = constraints.pattern,
        validators = get_validators_from_constraints(constraints),
    }, String_mt)
end

---Test `testee` against the parent schema, then the length, pattern and custom checks,
---stopping at the first failure.
---@param testee any
---@return boolean ok
---@return string? msg
function String_mt.__index:test(testee)
    if self.super then
        local valid, msg = self.super:test(testee)
        if not valid then
            return false, msg
        end
    end
    -- Schemas built by `__call` chain back to `schema.String`'s type check, so
    -- `testee` is a string below.
    if self.max_len and #testee > self.max_len then
        return false, fmt("the length of %q (%d) exceeds %s", testee, #testee, self.max_len)
    end
    if self.min_len and #testee < self.min_len then
        return false, fmt("the length of %q (%d) is under %s", testee, #testee, self.min_len)
    end
    if self.pattern and not testee:match(self.pattern) then
        return false, fmt("%q doesn't match the pattern %q", testee, self.pattern)
    end
    return validate_all(self.validators, testee)
end


local Function_mt = reg_mt('Function', Any_mt)

---Matches any function. Call it with a constraint table (`validator` / `validators`)
---to add custom checks.
schema.Function = set_mt({
    test = type_checker('function', 'a'),
}, Function_mt)

-- Functions take no constraints beyond custom validators, so both the constructor and
-- the test method are exactly Any's. Reuse them instead of keeping a duplicated copy
-- of the test body.
Function_mt.__call = Any_mt.__call
Function_mt.__index.test = Any_mt.__index.test


local Table_mt = reg_mt('Table', Any_mt)

---Matches any table. Call it with a constraint table to describe the expected fields.
schema.Table = set_mt({
    test = type_checker('table', 'a')
}, Table_mt)

---Derive a table schema from a field description. Keys of `constraints` are read as:
---* a schema of type 'Literal': a specific field named by its `val` field
---  (NOTE(review): 'Literal' is not defined in this file; presumably registered elsewhere).
---* any other schema object: a generic field. Every testee entry whose key passes the
---  key schema must have a value passing the value schema.
---* `validator` / `validators`: custom validators, not fields.
---* any other key: a specific field; `testee[key]` must pass the value schema.
---@param constraints table? nil is treated as an empty constraint table
---@return table derived schema
function Table_mt:__call(constraints)
    -- Without this, `schema.Table()` fails inside `pairs(nil)`.
    constraints = constraints or {}
    local specific = {}
    local generic = {}
    for k, v in pairs(constraints) do
        local scm_type = get_scm_type(k)
        if scm_type then
            if scm_type == 'Literal' then
                specific[k.val] = v
            else
                generic[k] = v
            end
        elseif k ~= 'validators' and k ~= 'validator' then
            specific[k] = v
        end
    end
    return set_mt({
        super = self,
        specific = specific,
        generic = generic,
        validators = get_validators_from_constraints(constraints)
    }, Table_mt)
end

---Test `testee` against the parent schema, then the generic fields, the specific
---fields and finally the custom validators, reporting the first failure.
---@param testee any
---@return boolean ok
---@return string? msg
function Table_mt.__index:test(testee)
    if self.super then
        local valid, msg = self.super:test(testee)
        if not valid then
            return false, msg
        end
    end
    for key_scm, val_scm in pairs(self.generic) do
        for testee_key, testee_val in pairs(testee) do
            if key_scm:test(testee_key) then
                local valid, msg = val_scm:test(testee_val)
                if not valid then
                    return false, msg
                end
            end
        end
    end
    for key, val_scm in pairs(self.specific) do
        local field = testee[key]
        local valid, msg = val_scm:test(field)
        if not valid then
            if field == nil then
                -- tostring(): on Lua 5.1, string.format's '%s' raises for table/boolean
                -- arguments, and `testee` is always a table here.
                return false, fmt('`%s` misses field `%s`', tostring(testee), tostring(key))
            end
            return false, msg
        end
    end
    return validate_all(self.validators, testee)
end


local Union_mt = reg_mt('Union', Any_mt)

---Build a union schema that accepts a value matching any member.
---Each argument is either a schema object (the value must pass its `test`) or a
---literal value (the value must be `==` to it). A `nil` argument adds `schema.Nil`,
---and union arguments are flattened into the result.
---NOTE(review): members are stored as the union table's own keys, so a literal such as
---'test' or 'assert' shadows the method of the same name. Avoid such literals.
---@param ... any
---@return table union schema
function schema.Union(...)
    local union = {}
    for i = 1, select('#', ...) do
        local sub_scm = select(i, ...)
        if sub_scm == nil then
            union[schema.Nil] = true
        elseif get_scm_type(sub_scm) == 'Union' then
            for scm_in_union in next, sub_scm do
                union[scm_in_union] = true
            end
        else
            union[sub_scm] = true
        end
    end
    return set_mt(union, Union_mt)
end

---Succeed when `testee` passes any schema member or equals any literal member.
---@param testee any
---@return boolean ok
---@return string? msg
function Union_mt.__index:test(testee)
    for allowed_val in next, self do
        if get_scm_type(allowed_val) then
            if allowed_val:test(testee) then
                return true
            end
        elseif testee == allowed_val then
            return true
        end
    end
    -- tostring(): on Lua 5.1, string.format's '%s' raises for non-string/number arguments.
    -- `self` is always a table, and `testee` can be anything.
    return false, fmt('testee `%s` fails to match each value in the union: %s',
        tostring(testee), tostring(self))
end

-- Let every registered schema type be combined into a union with `|` (Lua 5.3+ bitwise-or
-- metamethod) or `/` (division metamethod, available in every version),
-- e.g. `schema.Number | schema.String` or `schema.Number / nil`.
for mt in pairs(mts) do
    mt.__bor = schema.Union
    mt.__div = schema.Union
end

---Accepts any value except `nil` and `false`.
schema.Truthy = schema.Any { validator = function(v) return v end }
---Accepts only `nil` and `false`.
schema.Falsy = schema.Any { validator = function(v) return not v end }

return schema