置顶公告:【置顶】关于临时开启评论区所有功能的公告(2022.10.22) | 【置顶】关于本站Widget恢复使用的公告
  • 你好~!欢迎来到萌娘百科镜像站!如需查看或编辑,请联系本站管理员注册账号。
  • 本镜像站和其他萌娘百科的镜像站无关,请注意分别。

Module:Sandbox/あめろ

猛汉♂百科,万男皆可猛的百科全书!转载请标注来源页面的网页链接,并声明引自猛汉百科。内容不可商用。
跳到导航 跳到搜索
Template-info.svg 模块文档  [创建] [刷新]
local fmt = string.format
local type = type
local ipairs = ipairs
local get_mt = getmetatable
local set_mt = setmetatable

-- Schema

--- Module table. The schema kinds defined below (Any, Nil, Boolean, Number,
--- String, Function, Table, Union, Truthy, Falsy) are attached to it, and it is
--- returned at the end of the file.
local schema = {}

--- Tell whether `val` is a plain table, i.e. a table with no metatable.
--- Every schema object in this module is created with a metatable, so this
--- separates plain lists (e.g. a list of validators) from schema objects.
---@param val any
---@return boolean
local function is_raw_table(val)
	if type(val) ~= 'table' then
		return false
	end
	return not get_mt(val)
end

--- Filter a sequence: keep, in order, the elements for which `filter` returns
--- a truthy value. Traversal follows ipairs, so it stops at the first nil.
---@param t table sequence to filter
---@param filter function predicate called once per element
---@return table kept new sequence of the kept elements
---@return number dropped how many elements were rejected
local function ifilter(t, filter)
	local kept = {}
	local dropped = 0
	for _, item in ipairs(t) do
		if not filter(item) then
			dropped = dropped + 1
		else
			kept[#kept + 1] = item
		end
	end
	return kept, dropped
end

--- Build a `test` method that accepts exactly the values whose Lua type is `ty`.
--- On failure the message reads "<testee> (type: <actual>) isn't [<a> ]<ty>".
---@param ty string expected result of `type(testee)`, e.g. 'number'
---@param a string? optional article placed before `ty` in the message, e.g. 'a'
---@return function checker `function(self, testee)` returning `true`, or `false` and a message
local function type_checker(ty, a)
	local fmt_str = "%s (type: %s) isn't "
	if a then
		fmt_str = fmt_str..a..' '
	end
	fmt_str = fmt_str..ty
	return function(self, testee)
		local actual = type(testee)
		if actual == ty then
			return true
		end
		-- Convert explicitly. Under Lua 5.1 (the version MediaWiki's Scribunto
		-- runs), string.format's %s raises "string expected" for nil, booleans,
		-- tables and functions. That would turn a failed check into a crash.
		return false, fmt(fmt_str, tostring(testee), actual)
	end
end


--- Normalize `val` to a list.
--- A plain table (no metatable) is returned unchanged: the very same object,
--- not a copy. Anything else, including schema objects, is wrapped as `{val}`;
--- for nil that yields an empty table.
local function ensure_wrapped(val)
	if is_raw_table(val) then
		return val
	end
	return {val}
end


-- Set of every metatable created through reg_mt. get_scm_type checks membership
-- here to recognise schema objects. Keys are weak, so this set on its own never
-- keeps a metatable alive.
local mts = {}  ---@type {[metatable]: true}
set_mt(mts, {__mode = 'k'})


--- Create the metatable for a schema kind and register it in `mts`.
---@param name string kind name, stored as `__name` (what get_scm_type reports)
---@param super_mt metatable? metatable of the parent kind whose methods are inherited
---@param without_override boolean? when true, reuse the parent's `__index` table
---  itself instead of layering a fresh one over it
---@return metatable
local function reg_mt(name, super_mt, without_override)
	local methods = super_mt and super_mt.__index
	if not methods then
		methods = {}
	end
	if not without_override then
		-- New layer falling back to `methods`; this kind's own overrides go here.
		methods = set_mt({}, {__index = methods})
	end
	local mt = {__name = name, __index = methods}
	mts[mt] = true
	return mt
end


--- Return the kind name of a schema object (e.g. 'Number', 'Union'), or nil
--- when `v` is not one, i.e. its metatable was not produced by reg_mt.
---@param v any
---@return string | nil
local function get_scm_type(v)
	local mt = get_mt(v)
	if mts[mt] then
		return mt.__name
	end
	return nil
end


--- Collect the extra validator functions given in a constraints table.
--- `validators` (checked first) or `validator` may hold one function or a
--- plain list of functions. Raises an error if any element is not a function.
---@param constraints table
---@return table | nil list of validator functions, or nil when none were given
local function get_validators_from_constraints(constraints)
	local given = constraints.validators or constraints.validator
	-- ifilter only walks the list here: the predicate either passes or raises.
	local list = ifilter(ensure_wrapped(given), function(candidate)
		return assert(type(candidate) == 'function', 'validator需要是函数或元素为函数的表')
	end)
	if list[1] then
		return list
	end
	return nil
end


--- Run each validator on `val`, stopping at the first failure.
--- A validator returns a truthy value on success, or a falsy value plus an
--- optional message on failure.
---@param validators function[]? validators to run; nil means no extra checks
---@param val any value under test
---@return boolean, string? `true`, or `false` and a failure message
local function validate_all(validators, val)
	if not validators then return true end
	for i, validator in ipairs(validators) do
		local valid, msg = validator(val)
		if not valid then
			-- Validators such as schema.Truthy's return a bare falsy value with no
			-- message. Without a fallback, Any:assert would end up calling error(nil).
			return false, msg or fmt('%s fails validator #%d', tostring(val), i)
		end
	end
	return true
end


-- Any: the root kind. schema.Any accepts every value; calling it derives a
-- schema that adds validators (see Any_mt:__call).
local Any_mt = reg_mt('Any')
schema.Any = set_mt({}, Any_mt)
-- Only the root owns this `test`; derived schemas fall back to Any_mt.__index.test.
schema.Any.test = function()
	return true
end

--- Derive a schema from `self` that additionally runs the given validators.
--- The result keeps self's metatable, so it stays the same kind; that is what
--- lets Function_mt reuse this constructor.
---@param constraints table may hold `validator` / `validators`
---@return table derived schema whose `super` is `self`
function Any_mt:__call(constraints)
	local derived = {super = self}
	derived.validators = get_validators_from_constraints(constraints)
	return set_mt(derived, get_mt(self))
end

--- Test `testee` against a derived schema: first the parent schema (`super`),
--- then this schema's own validators.
---@return boolean, string? `true`, or `false` and the first failure message
function Any_mt.__index:test(testee)
	if not self.super then
		return validate_all(self.validators, testee)
	end
	local ok, reason = self.super:test(testee)
	if not ok then
		return false, reason
	end
	return validate_all(self.validators, testee)
end

--- Like `test`, but raise the failure message instead of returning it.
---@return true when `testee` passes
function Any_mt.__index:assert(testee)
	local ok, reason = self:test(testee)
	if not ok then
		-- Level 2 attributes the error to the caller of :assert.
		error(reason, 2)
	end
	return true
end


-- Nil: accepts only nil. Nil_mt reuses Any's method table directly (no
-- override layer) and defines no __call, so it takes no constraints.
local Nil_mt = reg_mt('Nil', Any_mt, true)
schema.Nil = set_mt({}, Nil_mt)
schema.Nil.test = type_checker('nil')


-- Boolean: accepts true and false. Boolean_mt reuses Any's method table
-- directly and defines no __call, so it takes no constraints.
local Boolean_mt = reg_mt('Boolean', Any_mt, true)
schema.Boolean = set_mt({}, Boolean_mt)
schema.Boolean.test = type_checker('boolean', 'a')


-- Number: accepts numbers. Calling it (Number_mt:__call) derives a schema with
-- integer/range constraints; it gets its own method layer so it can override `test`.
local Number_mt = reg_mt('Number', Any_mt, false)
schema.Number = set_mt({}, Number_mt)
schema.Number.test = type_checker('number', 'a')

--- Derive a constrained number schema from `self`.
---@param constraints table recognised keys:
---  `int` (value must be an integer), `lt` / `gt` (exclusive bounds),
---  `le` or `max` / `ge` or `min` (inclusive bounds; the first name wins),
---  `ne` (forbidden value), `validator` / `validators`
---@return table derived schema whose `super` is `self`
function Number_mt:__call(constraints)
	local derived = {super = self}
	derived.int = constraints.int
	derived.lt = constraints.lt
	derived.gt = constraints.gt
	derived.le = constraints.le or constraints.max
	derived.ge = constraints.ge or constraints.min
	derived.ne = constraints.ne
	derived.validators = get_validators_from_constraints(constraints)
	return set_mt(derived, Number_mt)
end

--- Check `testee` against the parent schema, then the numeric constraints
--- (in the order int, lt, gt, le, ge, ne), then the extra validators.
---@return boolean, string? `true`, or `false` and the first failure message
function Number_mt.__index:test(testee)
	local parent = self.super
	if parent then
		local ok, reason = parent:test(testee)
		if not ok then
			return false, reason
		end
	end

	local x = testee
	if self.int and math.fmod(x, 1) ~= 0 then
		return false, fmt("%s isn't an integer", x)
	end

	-- Exclusive bounds.
	local lt = self.lt
	if lt and x >= lt then
		return false, fmt("%s isn't < %s", x, lt)
	end
	local gt = self.gt
	if gt and x <= gt then
		return false, fmt("%s isn't > %s", x, gt)
	end

	-- Inclusive bounds.
	local le = self.le
	if le and x > le then
		return false, fmt("%s isn't <= %s", x, le)
	end
	local ge = self.ge
	if ge and x < ge then
		return false, fmt("%s isn't >= %s", x, ge)
	end

	local ne = self.ne
	if ne and x == ne then
		return false, fmt('testee equals %s', ne)
	end

	return validate_all(self.validators, testee)
end


-- String: accepts strings. Calling it (String_mt:__call) adds length and
-- pattern constraints.
local String_mt = reg_mt('String', Any_mt, false)
schema.String = set_mt({}, String_mt)
schema.String.test = type_checker('string', 'a')

--- Derive a constrained string schema from `self`.
---@param constraints table recognised keys: `max_len`, `min_len` (compared with
---  the `#` byte length), `pattern` (Lua pattern given to string.match),
---  `validator` / `validators`
---@return table derived schema whose `super` is `self`
function String_mt:__call(constraints)
	local derived = {super = self}
	derived.max_len = constraints.max_len
	derived.min_len = constraints.min_len
	derived.pattern = constraints.pattern
	derived.validators = get_validators_from_constraints(constraints)
	return set_mt(derived, String_mt)
end

--- Check `testee` against the parent schema, then max_len, min_len and pattern,
--- then the extra validators.
---@return boolean, string? `true`, or `false` and the first failure message
function String_mt.__index:test(testee)
	local parent = self.super
	if parent then
		local ok, reason = parent:test(testee)
		if not ok then
			return false, reason
		end
	end

	local max_len = self.max_len
	if max_len and #testee > max_len then
		return false, fmt("the length of %q (%d) exceeds %s", testee, #testee, max_len)
	end
	local min_len = self.min_len
	if min_len and #testee < min_len then
		return false, fmt("the length of %q (%d) is under %s", testee, #testee, min_len)
	end
	local pattern = self.pattern
	if pattern and not testee:match(pattern) then
		return false, fmt("%q doesn't match the pattern %q", testee, pattern)
	end

	return validate_all(self.validators, testee)
end


-- Function: accepts functions. Refining it only adds validators, so it borrows
-- Any's constructor, which keeps the caller's metatable (here Function_mt).
local Function_mt = reg_mt('Function', Any_mt, false)
schema.Function = set_mt({}, Function_mt)
schema.Function.test = type_checker('function', 'a')

Function_mt.__call = Any_mt.__call

-- A refined Function schema is tested exactly like a refined Any schema:
-- parent chain (`super`) first, then its validators. Reuse that method rather
-- than keeping a line-for-line copy of it here.
Function_mt.__index.test = Any_mt.__index.test


-- Table: accepts tables. Calling it (Table_mt:__call) describes the fields the
-- table must have.
local Table_mt = reg_mt('Table', Any_mt, false)
schema.Table = set_mt({}, Table_mt)
schema.Table.test = type_checker('table', 'a')

--- Derive a table schema from `self`. Each entry of `constraints` maps a key
--- to the schema that the corresponding value must satisfy:
---  * a non-schema key names one specific field;
---  * a schema key of kind 'Literal' names the specific field `key.val`
---    (no 'Literal' kind is defined in this file; presumably provided
---    elsewhere, TODO confirm);
---  * any other schema key applies its value schema to every field whose key
---    passes that key schema.
--- The `validator` / `validators` entries are extra validators, not fields.
---@param constraints table
---@return table derived schema whose `super` is `self`
function Table_mt:__call(constraints)
	local specific, generic = {}, {}
	for key, value_scm in pairs(constraints) do
		local kind = get_scm_type(key)
		if kind == 'Literal' then
			specific[key.val] = value_scm
		elseif kind then
			generic[key] = value_scm
		elseif key ~= 'validators' and key ~= 'validator' then
			specific[key] = value_scm
		end
	end
	local derived = {super = self, specific = specific, generic = generic}
	derived.validators = get_validators_from_constraints(constraints)
	return set_mt(derived, Table_mt)
end

--- Validate `testee` against this table schema.
--- Order of checks:
---  1. the parent chain (`super`);
---  2. every generic key schema against each key of `testee` it accepts;
---  3. each specific field;
---  4. the extra validators.
---@return boolean, string? `true`, or `false` and the first failure message
function Table_mt.__index:test(testee)
	if self.super then
		local valid, msg = self.super:test(testee)
		if not valid then
			return false, msg
		end
	end
	for key_scm, val_scm in pairs(self.generic) do
		for testee_key, testee_val in pairs(testee) do
			if key_scm:test(testee_key) then
				local valid, msg = val_scm:test(testee_val)
				if not valid then
					return false, msg
				end
			end
		end
	end
	for key, val_scm in pairs(self.specific) do
		local field = testee[key]
		local valid, msg = val_scm:test(field)
		if not valid then
			if field == nil then
				-- tostring() is required: Lua 5.1's string.format rejects
				-- tables (and most non-strings) as %s arguments.
				return false, fmt('`%s` misses field `%s`', tostring(testee), tostring(key))
			end
			return false, msg
		end
	end
	return validate_all(self.validators, testee)
end


local Union_mt = reg_mt('Union', Any_mt)

--- Build a schema that accepts any value matching at least one alternative.
--- Each argument may be:
---  * a schema object, matched through its `test` method;
---  * nil, shorthand for schema.Nil;
---  * another Union, whose alternatives are merged in (flattened);
---  * any other value, matched literally with `==`.
--- Alternatives are kept in a separate `members` set rather than as keys of the
--- union object itself. Stored as keys, a literal such as 'test' or 'assert'
--- shadowed the inherited methods, so `schema.Union('test', 'prod')` could not
--- be tested at all.
---@return table Union schema
function schema.Union(...)
	local members = {}
	for i = 1, select('#', ...) do
		local sub_scm = select(i, ...)
		if sub_scm == nil then
			members[schema.Nil] = true
		elseif get_scm_type(sub_scm) == 'Union' then
			for member in next, sub_scm.members do
				members[member] = true
			end
		else
			members[sub_scm] = true
		end
	end
	return set_mt({members = members}, Union_mt)
end

--- Render one union alternative for an error message: the kind name for
--- schema objects, a quoted literal for strings, tostring() for anything else.
local function describe_union_member(member)
	local kind = get_scm_type(member)
	if kind then
		return kind
	elseif type(member) == 'string' then
		return fmt('%q', member)
	end
	return tostring(member)
end

--- Test whether `testee` matches at least one alternative of the union.
---@return boolean, string? `true`, or `false` and a message listing the alternatives
function Union_mt.__index:test(testee)
	for member in next, self.members do
		if get_scm_type(member) then
			if member:test(testee) then
				return true
			end
		elseif testee == member then
			return true
		end
	end
	-- Build the listing explicitly. Lua 5.1's string.format rejects non-string
	-- %s arguments, and tostring(self) would only show a kind and an address.
	-- Sorting makes the message independent of table traversal order.
	local described = {}
	for member in next, self.members do
		described[#described + 1] = describe_union_member(member)
	end
	table.sort(described)
	return false, fmt('testee `%s` fails to match each value in the union: %s',
		tostring(testee), table.concat(described, ', '))
end

-- Let every schema kind registered above combine into a Union with `|`
-- (Lua 5.3+ only) or `/` (any version), e.g. `schema.Number / schema.String`.
for mt in pairs(mts) do
	mt.__div = schema.Union
	mt.__bor = schema.Union
end

-- Truthy accepts every value except nil and false; Falsy accepts only nil and false.
schema.Truthy = schema.Any({
	validator = function(value)
		return value
	end,
})
schema.Falsy = schema.Any({
	validator = function(value)
		return not value
	end,
})


return schema