From 434f9c1b3102ef1bbce5f73f17babdbf7b55d974 Mon Sep 17 00:00:00 2001 From: Case Duckworth Date: Tue, 2 Apr 2024 12:46:16 -0500 Subject: Implement arity checks --- core.lua | 102 ++++++++++++++++++++++++++++++++++++--------------------------- util.lua | 28 +++++++++++++++--- 2 files changed, 82 insertions(+), 48 deletions(-) diff --git a/core.lua b/core.lua index e8ad42b..dec4a39 100644 --- a/core.lua +++ b/core.lua @@ -4,6 +4,7 @@ local m = {} local type = require "type" local isa, null = type.isa, type.null local math = math +local proc = require("util").proc local function fold (kons, knil, r) if r == null then @@ -17,58 +18,71 @@ end m.env = { -- all functions here take R, which is the list of arguments ------- numbers - ["number?"] = function (r) return isa(r[1], "number") end, + ["number?"] = proc(1, function (r) return isa(r[1], "number") end), ["="] = - function (r) - local function go (a, b) - if a ~= b then return false, 1 end - return b - end - return fold(go, r[1], r[2]) and true - end, + proc({0}, function (r) + if r[1] == nil then return true end + if r[2] == nil then return true end + while r[2] ~= null do + if r[1] ~= r[2][1] then return false end + r = r[2] + end + return true + end), ["<"] = - function (r) - local function go (a, b) - if a >= b then return false, 1 end - return b - end - return fold(go, r[1], r[2]) and true - end, + proc({0}, function (r) + if r[1] == nil then return true end + if r[2] == nil then return true end + while r[2] ~= null do + if r[1] >= r[2][1] then return false end + r = r[2] + end + return true + end), [">"] = - function (r) - local function go (a, b) - if a <= b then return false, 1 end - return b - end - return fold(go, r[1], r[2]) and true - end, - ["<="] = function (r) return not m.env[">"](r) end, - [">="] = function (r) return not m.env["<"](r) end, + proc({0}, function (r) + if r[1] == nil then return true end + if r[2] == nil then return true end + while r[2] ~= null do + if r[1] <= r[2][1] then return false end + r = r[2] + end + return true + end), + ["<="] = proc({0}, function (r) return not m.env[">"](r) end), + [">="] = proc({0}, function (r) return not m.env["<"](r) end), ------- math ["+"] = - function (r) - return fold(function (a, b) return a + b end, 0, r) - end, + proc({0}, function (r) + return fold(function (a, b) + return a + b + end, 0, r) + end), ["-"] = - function (r) - if r == null then return -1 end - if r[2] == null then return (- r[1]) end - return fold(function (a, b) return a-b end, r[1], r[2]) - end, + proc({0}, function (r) + if r == null then return -1 end + if r[2] == null then return (- r[1]) end + return fold(function (a, b) + return a - b + end, r[1], r[2]) + end), ["*"] = - function (r) - local function go (a, b) - if a == 0 or b == 0 then return 0, 1 end - return a * b - end - return fold(go, 1, r) - end, + proc({0}, function (r) + local function go (a, b) + if a == 0 or b == 0 then + return 0, 1 + end + return a * b + end + return fold(go, 1, r) + end), ["/"] = - function (r) - if r == null then error("Wrong arity") end - if r[2] == null then return (1 / r[1]) end - return fold(function (a, b) return a/b end, r[1], r[2]) - end, + proc({1}, function (r) + if r[2] == null then return (1 / r[1]) end + return fold(function (a, b) + return a / b + end, r[1], r[2]) + end), } -------- diff --git a/util.lua b/util.lua index b5a57b1..d151858 100644 --- a/util.lua +++ b/util.lua @@ -7,10 +7,30 @@ function m.pop (tbl) return table.remove(tbl, 1) end -function m.arity (r, min, max) - --[[ Return whether R is within MIN and MAX (inclusive). ]] - local len = #r - return len >= min and len <= max +function m.proc (arity, fn) + --[[ Wrap RN in a check that for its ARITY. + ARITY can be a number, the minimum number of arguments, + or a table {MIN, MAX}. If MIN is nil or absent, it's 0; + if MAX is nil or absent, it's infinity. MIN and MAX are + both inclusive. + ]] + local rmin, rmax, rstr + if type(arity) ~= "table" then + rmin, rmax = arity, arity + rstr = rmin + else + rmin, rmax = arity[1] or 0, arity[2] or 1/0 -- infinity + rstr = rmin .. ".." .. rmax + end + return function (r) + local rlen = r and #r or 0 + if rlen < rmin or rlen > rmax then + error(string.format("Wrong arity: %s, need %s", + rlen, + rstr)) + end + return fn(r) + end end --- -- cgit 1.4.1-21-gabe81