From 2a5c0883ca907e97110ea0050080f74ccbb143e2 Mon Sep 17 00:00:00 2001 From: Case Duckworth Date: Tue, 2 Apr 2024 21:04:32 -0500 Subject: Change arity assertion code --- core.lua | 117 ++++++++++++++++++++++++++++++++------------------------------- eval.lua | 88 ++++++++++++++++++++++++++++------------------- util.lua | 31 +++++------------ 3 files changed, 121 insertions(+), 115 deletions(-) diff --git a/core.lua b/core.lua index dec4a39..7891f3a 100644 --- a/core.lua +++ b/core.lua @@ -4,7 +4,8 @@ local m = {} local type = require "type" local isa, null = type.isa, type.null local math = math -local proc = require("util").proc +local util = require "util" +local assert_arity = util.assert_arity local function fold (kons, knil, r) if r == null then @@ -18,71 +19,73 @@ end m.env = { -- all functions here take R, which is the list of arguments ------- numbers - ["number?"] = proc(1, function (r) return isa(r[1], "number") end), + ["number?"] = + function (r) + assert_arity(r, 1, 1) + return isa(r[1], "number") + end, ["="] = - proc({0}, function (r) - if r[1] == nil then return true end - if r[2] == nil then return true end - while r[2] ~= null do - if r[1] ~= r[2][1] then return false end - r = r[2] - end - return true - end), + function (r) + if r[1] == nil then return true end + if r[2] == nil then return true end + while r[2] ~= null do + if r[1] ~= r[2][1] then return false end + r = r[2] + end + return true + end, ["<"] = - proc({0}, function (r) - if r[1] == nil then return true end - if r[2] == nil then return true end - while r[2] ~= null do - if r[1] >= r[2][1] then return false end - r = r[2] - end - return true - end), + function (r) + if r[1] == nil then return true end + if r[2] == nil then return true end + while r[2] ~= null do + if r[1] >= r[2][1] then return false end + r = r[2] + end + return true + end, [">"] = - proc({0}, function (r) - if r[1] == nil then return true end - if r[2] == nil then return true end - while r[2] ~= null do - if r[1] <= r[2][1] then return false end - r = r[2] - end - return true - end), - ["<="] = proc({0}, function (r) return not m.env[">"](r) end), - [">="] = proc({0}, function (r) return not m.env["<"](r) end), + function (r) + if r[1] == nil then return true end + if r[2] == nil then return true end + while r[2] ~= null do + if r[1] <= r[2][1] then return false end + r = r[2] + end + return true + end, + ["<="] = function (r) return not m.env[">"](r) end, + [">="] = function (r) return not m.env["<"](r) end, ------- math ["+"] = - proc({0}, function (r) - return fold(function (a, b) - return a + b - end, 0, r) - end), + function (r) + return fold(function (a, b) return a + b end, 0, r) + end, ["-"] = - proc({0}, function (r) - if r == null then return -1 end - if r[2] == null then return (- r[1]) end - return fold(function (a, b) - return a - b - end, r[1], r[2]) - end), + function (r) + if r == null then return -1 end + if r[2] == null then return (- r[1]) end + return fold(function (a, b) + return a - b + end, r[1], r[2]) + end, ["*"] = - proc({0}, function (r) - local function go (a, b) - if a == 0 or b == 0 then - return 0, 1 - end - return a * b + function (r) + local function go (a, b) + if a == 0 or b == 0 then + return 0, 1 end - return fold(go, 1, r) - end), + return a * b + end + return fold(go, 1, r) + end, ["/"] = - proc({1}, function (r) - if r[2] == null then return (1 / r[1]) end - return fold(function (a, b) - return a / b - end, r[1], r[2]) - end), + function (r) + assert_arity(r, 1) + if r[2] == null then return (1 / r[1]) end + return fold(function (a, b) return a / b end, + r[1], r[2]) + end, } -------- diff --git a/eval.lua b/eval.lua index 53292d0..60369a9 100644 --- a/eval.lua +++ b/eval.lua @@ -20,7 +20,7 @@ function m.environ (inner, outer) return setmetatable(inner, mt) end -local function call_proc (proc, r) +local function procedure_call (proc, r) local function doargs (p, r, e) if p == type.null and r == type.null then return e end if type.isa(p, "symbol") then @@ -50,57 +50,75 @@ function m.procedure (params, body, env) } local mt = { __type = "procedure", - __call = call_proc, + __call = procedure_call, } return setmetatable(t, mt) end +local function handle_quasiquote (r, e) + assert_arity(r, 1, 1) + local x = r[1] + if not type.islist(x) or x == type.null then + return x + end + local QQ, fin = {}, nil + local car, cdr = x[1], x[2] + while cdr do + if type.islist(car) then + if car[1] == "unquote" then + table.insert(QQ, m.eval(car[2][1], e)) + elseif car[1] == "unquote-splicing" then + local usl = m.eval(car[2][1], e) + if not type.islist(usl) then + fin = usl + break + end + while usl[2] do + table.insert(QQ, usl[1]) + usl = usl[2] + end + end + else + table.insert(QQ, car) + end + car, cdr = cdr[1], cdr[2] + end + return type.list(QQ, fin) +end + m.specials = { -- each of these takes R (a list of args) and E (an environment) - quote = function (r, e) return r[1] end, - quasiquote = + quote = function (r, e) - local x = r[1] - if not type.islist(x) or x == type.null then - return x - end - local QQ, fin = {}, nil - local car, cdr = x[1], x[2] - while cdr do - if type.islist(car) then - if car[1] == "unquote" then - table.insert(QQ, - m.eval(car[2][1], e)) - elseif car[1] == "unquote-splicing" then - local usl = m.eval(car[2][1], e) - if not type.islist(usl) then - fin = usl - break - end - while usl[2] do - table.insert(QQ, usl[1]) - usl = usl[2] - end - end - else - table.insert(QQ, car) - end - car, cdr = cdr[1], cdr[2] - end - return type.list(QQ, fin) + assert_arity(r, 1, 1) + return r[1] end, + quasiquote = handle_quasiquote, -- if not inside quasiquote, unquote and unquote-splicing are errors unquote = function () error("Unexpected unquote") end, ["unquote-splicing"] = function () error("Unexpected unquote-splicing") end, -- define variables - define = function (r, e) rawset(e, r[1], m.eval(r[2][1], e)) end, - ["set!"] = function (r, e) e[r[1]] = m.eval(r[2][1], e) end, + define = + function (r, e) + assert_arity(r, 2, 2) + rawset(e, r[1], m.eval(r[2][1], e)) + end, + ["set!"] = + function (r, e) + assert_arity(r, 2, 2) + e[r[1]] = m.eval(r[2][1], e) + end, -- y'know, ... lambda - lambda = function (r, e) return m.procedure(r[1], r[2], e) end, + lambda = + function (r, e) + assert_arity(r, 2) + return m.procedure(r[1], r[2], e) + end, -- control flow ["if"] = function (r, e) + assert_arity(r, 3, 3) local test, conseq, alt = r[1], r[2][1], r[2][2][1] if m.eval(test) diff --git a/util.lua b/util.lua index d151858..8fedbf7 100644 --- a/util.lua +++ b/util.lua @@ -7,29 +7,14 @@ function m.pop (tbl) return table.remove(tbl, 1) end -function m.proc (arity, fn) - --[[ Wrap RN in a check that for its ARITY. - ARITY can be a number, the minimum number of arguments, - or a table {MIN, MAX}. If MIN is nil or absent, it's 0; - if MAX is nil or absent, it's infinity. MIN and MAX are - both inclusive. - ]] - local rmin, rmax, rstr - if type(arity) ~= "table" then - rmin, rmax = arity, arity - rstr = rmin - else - rmin, rmax = arity[1] or 0, arity[2] or 1/0 -- infinity - rstr = rmin .. ".." .. rmax - end - return function (r) - local rlen = r and #r or 0 - if rlen < rmin or rlen > rmax then - error(string.format("Wrong arity: %s, need %s", - rlen, - rstr)) - end - return fn(r) +function m.assert_arity (r, min, max) + local rmin = min or 0 + local rmax = max or 1/0 -- infinity + local rlen = #r + if rlen < rmin or rlen > rmax then + error(string.format("Wrong arity: %s; expecting %s", + rlen, + rmin == rmax and rmin or (rmin..".."..rmax))) end end -- cgit 1.4.1-21-gabe81