about summary refs log tree commit diff stats
path: root/utf8.lua
blob: 2bfcf8ad0ecf968d4b2f0338a7ad9fd8bfbed6f3 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
--- lam.utf8 --- I did not write this code.  I did lightly edit it to work with
-- this project.  This module was written by GitHub user meepen and is under the
-- CC0 1.0 license.  The source can be found here:
-- https://github.com/meepen/Lua-5.1-UTF-8/
local utf8 = {}
local bit = bit
local error = error
local ipairs = ipairs
local string = string
local table = table
local unpack = table.unpack or unpack

-- Pattern that can be used with the string library to match a single UTF-8
-- byte-sequence.  This expects the string to contain valid UTF-8 data.
-- The first class matches one lead byte: NUL (%z), 0x01-0x7F or 0xC2-0xF4.
-- The second class matches zero or more continuation bytes (0x80-0xBF); the
-- pattern itself does not check that their count agrees with the lead byte.
utf8.charpattern = "[%z\x01-\x7F\xC2-\xF4][\x80-\xBF]*"

-- Converts every index passed after `str` into an absolute, 1-based byte
-- position.  Indices that are zero or negative count back from the end of the
-- string, the way the string library treats them (-1 is the last byte).
-- Raises "bad index to string (out of range)" when a resolved position falls
-- outside 1..#str; the error is reported two call levels above this helper.
-- Returns the resolved positions in the order they were given.
local function strRelToAbs(str, ...)
	local positions = { ... }
	for slot, pos in ipairs(positions) do
		local size = #str
		local absolute
		if pos > 0 then
			absolute = pos
		else
			-- Relative index: wrap around from the end of the string
			absolute = size + pos + 1
		end
		if absolute < 1 or absolute > size then
			error("bad index to string (out of range)", 3)
		end
		positions[slot] = absolute
	end
	return unpack(positions)
end

-- Decodes a single UTF-8 byte-sequence from a string, ensuring it is valid.
--   str      : string to inspect
--   startPos : byte index of the sequence's first byte (defaults to 1;
--              negative values count from the end of the string)
-- Returns the index of the first and last byte of the sequence, or nil when
-- the bytes at startPos do not form a well-formed sequence.  A sequence is
-- rejected when its lead byte is not 0x00-0x7F or 0xC2-0xF4, when the string
-- ends before all continuation bytes are present, when a continuation byte is
-- outside 0x80-0xBF, when the encoding is overlong, or when it encodes a value
-- above U+10FFFF.  Surrogate code points (U+D800-U+DFFF) are still accepted
-- because utf8.char in this module will produce them.
-- Raises (via strRelToAbs) if startPos is outside the string.
local function decode(str, startPos)
	startPos = strRelToAbs(str, startPos or 1)
	local b1 = str:byte(startPos)
	-- Single-byte sequence
	if b1 < 0x80 then
		return startPos, startPos
	end
	-- Validate first byte of multi-byte sequence
	if b1 < 0xC2 or b1 > 0xF4 then
		return nil
	end
	-- Get 'supposed' amount of continuation bytes from primary byte
	local contByteCount = (b1 >= 0xF0 and 3) or (b1 >= 0xE0 and 2) or 1
	local endPos = startPos + contByteCount
	-- A sequence cut short by the end of the string is not valid
	if endPos > #str then
		return nil
	end
	-- The first continuation byte has a narrower range for some lead bytes:
	-- 0xE0 and 0xF0 would otherwise allow overlong encodings, and 0xF4 would
	-- allow values above U+10FFFF
	local lo, hi = 0x80, 0xBF
	if b1 == 0xE0 then
		lo = 0xA0
	elseif b1 == 0xF0 then
		lo = 0x90
	elseif b1 == 0xF4 then
		hi = 0x8F
	end
	-- Validate our continuation bytes
	for i = startPos + 1, endPos do
		local bX = str:byte(i)
		if bX < lo or bX > hi then
			return nil
		end
		-- Every continuation byte after the first uses the full range
		lo, hi = 0x80, 0xBF
	end
	return startPos, endPos
end

-- Encodes each integer argument as UTF-8 and returns the encodings joined into
-- one string (the empty string when called without arguments).
-- Raises "bad argument #k to char (out of range)" at the caller's level when
-- the k'th argument is below 0 or above 0x10FFFF.
function utf8.char(...)
	local pieces = {}
	for idx, cp in ipairs { ... } do
		if cp < 0 or cp > 0x10FFFF then
			error("bad argument #" .. idx ..
			       " to char (out of range)", 2)
		end
		local encoded
		if cp < 0x80 then
			-- One byte: the value is its own encoding
			encoded = string.char(cp)
		elseif cp < 0x800 then
			-- Two bytes: 110xxxxx 10xxxxxx
			encoded = string.char(
				bit.bor(0xC0, bit.band(bit.rshift(cp, 6), 0x1F)),
				bit.bor(0x80, bit.band(cp, 0x3F)))
		elseif cp < 0x10000 then
			-- Three bytes: 1110xxxx 10xxxxxx 10xxxxxx
			encoded = string.char(
				bit.bor(0xE0, bit.band(bit.rshift(cp, 12), 0x0F)),
				bit.bor(0x80, bit.band(bit.rshift(cp, 6), 0x3F)),
				bit.bor(0x80, bit.band(cp, 0x3F)))
		else
			-- Four bytes: 11110xxx 10xxxxxx 10xxxxxx 10xxxxxx
			encoded = string.char(
				bit.bor(0xF0, bit.band(bit.rshift(cp, 18), 0x07)),
				bit.bor(0x80, bit.band(bit.rshift(cp, 12), 0x3F)),
				bit.bor(0x80, bit.band(bit.rshift(cp, 6), 0x3F)),
				bit.bor(0x80, bit.band(cp, 0x3F)))
		end
		pieces[#pieces + 1] = encoded
	end
	return table.concat(pieces)
end

-- Returns an iterator function for use in a generic for loop.  Each call
-- yields the byte index at which the next sequence starts and the sequence
-- itself as a string; it yields nil once the whole string has been walked.
-- The iterator raises "invalid UTF-8 code" when it meets a sequence that
-- decode rejects.
function utf8.codes(str)
	local nextPos = 1
	local function step()
		if nextPos <= #str then
			local first, last = decode(str, nextPos)
			if not first then
				error("invalid UTF-8 code", 2)
			end
			-- Resume right after this sequence on the next call
			nextPos = last + 1
			return first, str:sub(first, last)
		end
		-- Walked past the last byte: end the iteration
		return nil
	end
	return step
end

-- Returns the code points (as integers) of every UTF-8 sequence in str that
-- starts between byte positions startPos and endPos, both included.
--   startPos : defaults to 1
--   endPos   : defaults to startPos
-- Negative positions count from the end of the string.  Nothing is returned
-- when startPos lies after endPos.
-- Raises "invalid UTF-8 code" when a sequence in the range is malformed, and
-- (via strRelToAbs) when a position is outside the string.
function utf8.codepoint(str, startPos, endPos)
	startPos, endPos = strRelToAbs(str,
				       startPos or 1,
				       endPos or startPos or 1)
	local ret = {}
	-- A while loop (rather than repeat/until) keeps an empty range empty
	while startPos <= endPos do
		local seqStartPos, seqEndPos = decode(str, startPos)
		-- Also refuse a sequence that would run past the end of the string
		if not seqStartPos or seqEndPos > #str then
			error("invalid UTF-8 code", 2)
		end
		local b1 = str:byte(seqStartPos)
		-- Amount of bytes making up our sequence
		local len = seqEndPos - seqStartPos + 1
		local cp
		if len == 1 then -- Single-byte codepoint
			cp = b1
		else -- Multi-byte codepoint
			-- The lead byte carries 5, 4 or 3 payload bits for a
			-- 2-, 3- or 4-byte sequence: masks 0x1F, 0x0F, 0x07
			cp = bit.band(b1, bit.rshift(0x7F, len))
			for i = seqStartPos + 1, seqEndPos do
				-- Each continuation byte adds six payload bits
				cp = bit.bor(bit.lshift(cp, 6),
					     bit.band(str:byte(i), 0x3F))
			end
		end
		ret[#ret + 1] = cp
		-- Increment current string index
		startPos = seqEndPos + 1
	end
	return unpack(ret)
end

-- Returns the number of UTF-8 sequences in str that start between byte
-- positions startPos and endPos, both included.
--   startPos : defaults to 1
--   endPos   : defaults to -1 (the last byte)
-- Negative positions count from the end of the string.  Returns 0 for an
-- empty string measured over its whole range and for a range where startPos
-- lies after endPos.  When an invalid sequence is hit, returns false followed
-- by the byte index of that sequence.
-- Raises (via strRelToAbs) when a position is outside a non-empty string.
function utf8.len(str, startPos, endPos)
	-- strRelToAbs rejects every index into an empty string, so answer the
	-- "whole string" request for "" here instead of raising
	if #str == 0 and (startPos or 1) == 1 and (endPos or -1) == -1 then
		return 0
	end
	startPos, endPos = strRelToAbs(str, startPos or 1, endPos or -1)
	local len = 0
	-- A while loop (rather than repeat/until) keeps an empty range at zero
	while startPos <= endPos do
		local seqStartPos, seqEndPos = decode(str, startPos)
		-- Hit an invalid sequence, or one cut short by the end of str?
		if not seqStartPos or seqEndPos > #str then
			return false, startPos
		end
		-- Increment current string pointer
		startPos = seqEndPos + 1
		-- Increment length
		len = len + 1
	end
	return len
end

-- Returns the byte-index at which the n'th UTF-8 sequence starts, counting
-- from the byte-index startPos (nil if there is none).
--   n > 0 : counts sequence starts forwards; the sequence at startPos is the
--           first one.
--   n < 0 : counts sequence starts backwards; the sequence at startPos is the
--           first one.
--   n = 0 : returns the byte-index where the sequence that startPos lies
--           within begins.
-- startPos defaults to 1 when n is non-negative and to the last byte of the
-- string when n is negative, so utf8.offset(str, -1) yields the start of the
-- last sequence.  Negative startPos values count from the end of the string.
-- Raises "initial position is not beginning of a valid sequence" when n is
-- non-zero and a caller-supplied startPos is not the start of a sequence, and
-- (via strRelToAbs) when startPos is outside the string.
function utf8.offset(str, n, startPos)
	local startDefaulted = startPos == nil
	startPos = strRelToAbs(str, startPos or (n >= 0 and 1) or #str)
	-- Find the beginning of the sequence over startPos
	if n == 0 then
		for i = startPos, 1, -1 do
			local seqStartPos = decode(str, i)
			if seqStartPos then
				return seqStartPos
			end
		end
		return nil
	end
	-- The defaulted start for a negative n is the last byte, which is a
	-- continuation byte whenever the string ends in a multi-byte sequence.
	-- Only a position chosen by the caller has to sit on a sequence start.
	local mustBeBoundary = not (startDefaulted and n < 0)
	if mustBeBoundary and not decode(str, startPos) then
		error("initial position is not beginning of a valid sequence",
		      2)
	end
	-- Walk forwards to the end of the string, or backwards to its start
	local itEnd, itStep = #str, 1
	if n < 0 then
		n = -n
		itEnd, itStep = 1, -1
	end
	for i = startPos, itEnd, itStep do
		-- decode succeeds only on the first byte of a sequence
		if decode(str, i) then
			n = n - 1
			if n == 0 then
				return i
			end
		end
	end
	return nil
end

-- Forces a string to contain only valid UTF-8 data.
-- Walks str one sequence at a time: sequences that decode accepts are copied
-- unchanged, and every byte that does not start such a sequence is replaced
-- with U+FFFD (the replacement character).  Returns the resulting string; an
-- empty input yields an empty string.
function utf8.force(str)
	local buf = {}
	local curPos, endPos = 1, #str
	-- Encode the replacement once; utf8.char is this module's encoder
	local replacement = utf8.char(0xFFFD)
	-- A while loop (rather than repeat/until) so that decode is never asked
	-- to read position 1 of an empty string
	while curPos <= endPos do
		local seqStartPos, seqEndPos = decode(str, curPos)
		-- Treat a sequence that would run past the end as invalid too
		if seqStartPos and seqEndPos <= endPos then
			buf[#buf + 1] = str:sub(seqStartPos, seqEndPos)
			curPos = seqEndPos + 1
		else
			buf[#buf + 1] = replacement
			-- Skip just the offending byte and try again
			curPos = curPos + 1
		end
	end
	return table.concat(buf)
end

--- Hand the module table back to the code that loaded this file.
return utf8