How scalar evolution is used for optimization in LuaJIT
Several months ago I fixed a bug in LuaJIT related to an invalid use of the Scalar Evolution result: https://github.com/LuaJIT/LuaJIT/pull/1115. It is interesting to see how the Scalar Evolution result is used for optimization.
What is Scalar Evolution?
Scalar Evolution analysis (SE) can become complex in a static compiler when it tries to capture more cases (example). In LuaJIT, however, it only records information about the index of a `for` loop:
```c
typedef struct ScEvEntry {
  MRef pc;      /* Bytecode PC of FORI. */
  IRRef1 idx;   /* Index reference. */
  IRRef1 start; /* Constant start reference. */
  IRRef1 stop;  /* Constant stop reference. */
  IRRef1 step;  /* Constant step reference. */
  IRType1 t;    /* Scalar type. */
  uint8_t dir;  /* Direction. 1: +, 0: -. */
} ScEvEntry;
```
For code like

```lua
for i = 1, n do
  sum = sum + arr[i]
end
```
Basically, it records information about `i`: start `1`, stop `n`, and step `1`, so the direction is increasing.
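The recorded fields can be illustrated with a minimal C sketch. It is not LuaJIT's code. `ScEvSketch`, `scev_for_loop()`, `scev_lo()` and `scev_hi()` are hypothetical names, and the fields hold plain integer values instead of the IR references (`IRRef1`) that the real `ScEvEntry` stores. Note that `stop` is only a bound. For `for i = 10, 2, -3` the index takes the values 10, 7 and 4, and never reaches 2.

```c
#include <assert.h>
#include <stdint.h>

/* Hypothetical, simplified mirror of LuaJIT's ScEvEntry:
 * plain values instead of IR references. */
typedef struct ScEvSketch {
  int32_t start;  /* Loop start value. */
  int32_t stop;   /* Loop stop value (a bound, not necessarily reached). */
  int32_t step;   /* Loop step value, never 0. */
  uint8_t dir;    /* Direction. 1: +, 0: -. */
} ScEvSketch;

/* Describe the numeric loop `for i = start, stop, step`. */
static ScEvSketch scev_for_loop(int32_t start, int32_t stop, int32_t step)
{
  ScEvSketch s;
  assert(step != 0);
  s.start = start;
  s.stop = stop;
  s.step = step;
  s.dir = step > 0;
  return s;
}

/* Lower bound on the index: it never goes below this value. */
static int32_t scev_lo(const ScEvSketch *s)
{
  return s->dir ? s->start : s->stop;
}

/* Upper bound on the index: it never goes above this value. */
static int32_t scev_hi(const ScEvSketch *s)
{
  return s->dir ? s->stop : s->start;
}
```

For `for i = 1, n do` with `n = 100`, the sketch gives `dir == 1` and the index range `[1, 100]`.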
Two kinds of loop with different performance
SE is used for array bounds check elimination (ABC elimination). Bounds checks have a non-negligible performance cost in a safe language, because they usually sit inside loops, and loops tend to be hotspots.
The idea of ABC elimination is this: if we can prove that the index used to access the array in the loop is always within the array's bounds, it is safe to remove the ABC.
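A small C sketch shows why checking the two ends of the index range is enough. The helpers are hypothetical. The sketch assumes that the array part has `asize` slots indexed `0 .. asize-1`. It also assumes that a bounds check is a single unsigned comparison, so a negative index becomes a huge unsigned value and fails. If both `start` and `stop` pass, every index between them passes too. One check before the loop can therefore replace one check per iteration.

```c
#include <assert.h>
#include <stdbool.h>
#include <stdint.h>

/* One bounds check: the unsigned comparison also rejects i < 0,
 * because a negative int32_t becomes a huge uint32_t. */
static bool bounds_ok(int32_t i, int32_t asize)
{
  return (uint32_t)i < (uint32_t)asize;
}

/* Unoptimized behaviour: check the index on every iteration. */
static bool all_checks_ok(int32_t start, int32_t stop, int32_t step,
                          int32_t asize)
{
  assert(step != 0);
  for (int64_t i = start; step > 0 ? i <= stop : i >= stop; i += step)
    if (!bounds_ok((int32_t)i, asize))
      return false;
  return true;
}

/* Hoisted behaviour: the index stays between start and stop,
 * so checking both ends once, before the loop, covers every iteration. */
static bool hoisted_check_ok(int32_t start, int32_t stop, int32_t step,
                             int32_t asize)
{
  assert(step != 0);
  if (step > 0 ? start > stop : start < stop)
    return true;  /* The loop body never runs. */
  return bounds_ok(start, asize) && bounds_ok(stop, asize);
}
```

The hoisted check is conservative. For `for i = 1, 150, 200` only `i = 1` is visited, yet the check on `stop` fails with `asize = 101`. It can reject a safe loop, but it never accepts an unsafe one.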
For the following code
```lua
local function faster(arr, n)
  local sum = 0
  for i = 1, n do
    sum = sum + arr[i]
  end
  return sum
end

local myarr = {}
for i = 1, 100 do
  myarr[#myarr + 1] = i
end
faster(myarr, 100)
```
Its trace looks like this:
```
---- TRACE 2 start test.lua:3
0006    TGETV    7   0   6
0007    ADDVV    2   2   7
0008    FORL     3 => 0006
---- TRACE 2 IR
0001    int SLOAD  #6    RI
0002 >  int LE     0001  +2147483646
0003 >  int SLOAD  #5    TI
0004 >  tab SLOAD  #2    T
0005    int FLOAD  0004  tab.asize
0006 >  p32 ABC    0005  0001
0007    p64 FLOAD  0004  tab.array
0008    p64 AREF   0007  0003
0009 >  int ALOAD  0008
0010 >  int SLOAD  #4    T
0011 >+ int ADDOV  0010  0009
0012 +  int ADD    0003  +1
0013 >  int LE     0012  0001
0014 ------ LOOP ------------
0015    p64 AREF   0007  0012
0016 >  int ALOAD  0015
0017 >+ int ADDOV  0016  0011
0018 +  int ADD    0012  +1
0019 >  int LE     0018  0001
0020    int PHI    0012  0018
0021    int PHI    0011  0017
```
As we can see, there is no ABC in the loop body (0014 to 0021).
If we change the Lua code slightly to use a `while` loop:
```lua
local function slower(arr, n)
  local sum, i = 0, 1
  while i <= n do
    sum = sum + arr[i]
    i = i + 1
  end
  return sum
end

local myarr = {}
for i = 1, 100 do
  myarr[#myarr + 1] = i
end
slower(myarr, 100)
```
its trace becomes:
```
---- TRACE 3 start test.lua:11
0006    TGETV    4   0   3
0007    ADDVV    2   2   4
0008    ADDVN    3   3   0  ; 1
0009    JMP      4 => 0003
0003    ISGT     3   1
0004    JMP      4 => 0010
0005    LOOP     4 => 0010
---- TRACE 3 IR
0001 >  tab SLOAD  #2    T
0002 >  int SLOAD  #5    T
0003    int FLOAD  0001  tab.asize
0004 >  int ABC    0003  0002
0005    p64 FLOAD  0001  tab.array
0006    p64 AREF   0005  0002
0007 >  int ALOAD  0006
0008 >  int SLOAD  #4    T
0009 >+ int ADDOV  0008  0007
0010 >+ int ADDOV  0002  +1
0011 >  int SLOAD  #3    T
0012 >  int GE     0011  0010
0013 ------ LOOP ------------
0014 >  int ABC    0003  0010
0015    p64 AREF   0005  0010
0016 >  int ALOAD  0015
0017 >+ int ADDOV  0016  0009
0018 >+ int ADDOV  0010  +1
0019 >  int LE     0018  0011
0020    int PHI    0010  0018
0021    int PHI    0009  0017
```
This time, the ABC at 0014 is not eliminated from the loop body.
Here is the comparison:
```
// faster                               // slower
---- TRACE 2 IR                         ---- TRACE 3 IR
0001    int SLOAD  #6    RI             0001 >  tab SLOAD  #2    T
0002 >  int LE     0001  +2147483646    0002 >  int SLOAD  #5    T
0003 >  int SLOAD  #5    TI             0003    int FLOAD  0001  tab.asize
0004 >  tab SLOAD  #2    T              0004 >  int ABC    0003  0002
0005    int FLOAD  0004  tab.asize      0005    p64 FLOAD  0001  tab.array
0006 >  p32 ABC    0005  0001           0006    p64 AREF   0005  0002
0007    p64 FLOAD  0004  tab.array      0007 >  int ALOAD  0006
0008    p64 AREF   0007  0003           0008 >  int SLOAD  #4    T
0009 >  int ALOAD  0008                 0009 >+ int ADDOV  0008  0007
0010 >  int SLOAD  #4    T              0010 >+ int ADDOV  0002  +1
0011 >+ int ADDOV  0010  0009           0011 >  int SLOAD  #3    T
0012 +  int ADD    0003  +1             0012 >  int GE     0011  0010
0013 >  int LE     0012  0001           0013 ------ LOOP ------------
0014 ------ LOOP ------------           0014 >  int ABC    0003  0010
0015    p64 AREF   0007  0012           0015    p64 AREF   0005  0010
0016 >  int ALOAD  0015                 0016 >  int ALOAD  0015
0017 >+ int ADDOV  0016  0011           0017 >+ int ADDOV  0016  0009
0018 +  int ADD    0012  +1             0018 >+ int ADDOV  0010  +1
0019 >  int LE     0018  0001           0019 >  int LE     0018  0011
0020    int PHI    0012  0018           0020    int PHI    0010  0018
0021    int PHI    0011  0017           0021    int PHI    0009  0017
```
In the faster case, there is no ABC (array bounds check) in the loop body. The ABC in the loop preamble is `0006 ABC 0005 0001`. Here `0001` is the loop's stop value `n`, loaded from the interpreter's Lua stack. Its value does not change while the loop runs, so it is loop-invariant. This reference is the scalar evolution analysis result `J->scev.stop`. The index only moves from the constant `start` (1) up towards `stop`, so checking `stop` against the array size once covers every iteration. This ABC is generated in `rec_idx_abc()`. `J->scev` is set up by the `rec_for_loop()` call made during `lj_record_setup()`.
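So during loop code generation, `0006 ABC 0005 0001` is not regenerated. Its operands do not change, and ABC is a normal instruction (`IRM_N`, as opposed to a load, store or allocation). `loop_unroll()` therefore reuses the preamble instruction instead of copying it into the loop body.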
```c
// static void loop_unroll(LoopState *lps)  (excerpt, lj_opt_loop.c)
/* Substitute instruction operands. */
ir = IR(ins);
op1 = ir->op1;
if (!irref_isk(op1)) op1 = subst[op1];
op2 = ir->op2;
if (!irref_isk(op2)) op2 = subst[op2];
if (irm_kind(lj_ir_mode[ir->o]) == IRM_N &&
    op1 == ir->op1 && op2 == ir->op2) {  /* Regular invariant ins? */
  subst[ins] = (IRRef1)ins;  /* Shortcut. */
}
```
In comparison, the slower loop contains an ABC. This ABC is generated in `loop_unroll()` because its second operand, `0002`, is not loop-invariant. Its value is updated by `0010 ADDOV 0002 +1`, so after substitution the check becomes `0014 ABC 0003 0010`.
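The shortcut in the `loop_unroll()` excerpt can be modeled with a toy C sketch and applied to both traces. Everything here is hypothetical and simplified. `ToyIns` keeps only two operand references and a flag standing in for `irm_kind(...) == IRM_N`. Constants are not modeled (the real code leaves them unsubstituted). `subst[]` maps each preamble reference to the reference used in the loop body.

```c
#include <assert.h>
#include <stdbool.h>
#include <stdint.h>

typedef uint16_t Ref;  /* Hypothetical IR reference (index into subst). */

/* Hypothetical toy IR instruction: two operands and a kind flag. */
typedef struct ToyIns {
  Ref op1, op2;
  bool normal;  /* Stand-in for irm_kind(lj_ir_mode[op]) == IRM_N. */
} ToyIns;

/* Same test as the shortcut in loop_unroll(): a normal instruction whose
 * operands are unchanged by substitution is reused, not copied. */
static bool reuse_in_loop(const ToyIns *ins, const Ref *subst)
{
  assert(ins && subst);
  Ref op1 = subst[ins->op1];
  Ref op2 = subst[ins->op2];
  return ins->normal && op1 == ins->op1 && op2 == ins->op2;
}
```

With an identity `subst[]`, the faster loop's `ABC 0005 0001` is reused as-is. For the slower loop, setting `subst[2] = 10` (the index `0002` is replaced by `0010` in the loop body) makes `ABC 0003 0002` fail the test. It is therefore copied as `ABC 0003 0010`, which is instruction `0014` in TRACE 3.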
It is not invariant because `rec_idx_abc()` does not find a match between the ABC's index operand and the scalar evolution result, so it emits the check on the index itself. Scalar evolution analysis does not work with the `LOOP` bytecode (BC) that a `while` loop compiles to; it only works with the `FORL` BC of a numeric `for` loop.
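The decision in `rec_idx_abc()` can be sketched in C as follows. This is a hypothetical simplification rather than LuaJIT's code. IR references are plain integers. `SCEV_NONE` is an assumed marker meaning there is no scalar evolution result, as for a `while` loop. `abc_operand()` returns the reference that the ABC should compare against `asize`. The real function handles more cases, such as a constant offset added to the index; they are omitted here.

```c
#include <assert.h>
#include <stdint.h>

typedef uint16_t Ref;  /* Hypothetical stand-in for an IR reference. */
#define SCEV_NONE 0    /* Assumed marker: no scalar evolution result. */

/* Hypothetical subset of the scalar evolution result. */
typedef struct ScEvRefs {
  Ref idx;   /* Reference of the for-loop index, or SCEV_NONE. */
  Ref stop;  /* Reference of the loop-invariant stop value. */
} ScEvRefs;

/* Pick the reference the bounds check compares against asize.
 * key_ref:  reference of the key used to index the array.
 * stop_val: runtime value of the loop's stop, seen while recording. */
static Ref abc_operand(const ScEvRefs *scev, Ref key_ref,
                       int32_t stop_val, uint32_t asize)
{
  assert(key_ref != SCEV_NONE);
  if (scev->idx != SCEV_NONE && key_ref == scev->idx &&
      (uint32_t)stop_val < asize)
    return scev->stop;  /* Invariant check: same operands every iteration. */
  return key_ref;       /* Otherwise check the real key. */
}
```

With the references of TRACE 2 (index `0003`, stop `0001`) and an in-bounds stop, the sketch picks `0001`, matching `0006 ABC 0005 0001`. For the `while` loop there is no result to match, so the key `0002` itself is checked, as in `0004 ABC 0003 0002`.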
Strength reduction
When there is a `CONV` IR instruction that converts an integer with sign extension, the Scalar Evolution result can also be used to simplify (eliminate) the `CONV`.
Say `CONV 0002 SEXT` sign-extends `0002`. If the Scalar Evolution result tells us that the minimum value of `0002` is non-negative (>= 0), sign extension cannot change the value, so this `CONV` is not needed. From the SE result, a lower bound on the index is `start` if the direction is increasing, and `stop` if the direction is decreasing.
This is implemented in the FOLD rule `simplify_conv_sext` (in `lj_opt_fold.c`).
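The reasoning can be sketched in C. It is a hypothetical simplification, not the code of `simplify_conv_sext`. `ScEvBounds`, `scev_lower_bound()`, `sext_is_redundant()`, `sext32()` and `zext32()` are made-up names, and values are plain integers instead of IR references. The last two functions show why a non-negative value makes sign extension unnecessary: both widenings then produce the same 64-bit value.

```c
#include <assert.h>
#include <stdbool.h>
#include <stdint.h>

/* Hypothetical, simplified scalar evolution result with plain values. */
typedef struct ScEvBounds {
  int32_t start, stop;
  uint8_t dir;  /* Direction. 1: +, 0: -. */
} ScEvBounds;

/* Lower bound of the loop index: start when counting up, stop when down. */
static int32_t scev_lower_bound(const ScEvBounds *s)
{
  assert(s->dir <= 1);
  return s->dir ? s->start : s->stop;
}

/* Sign extension is redundant when the index is never negative. */
static bool sext_is_redundant(const ScEvBounds *s)
{
  return scev_lower_bound(s) >= 0;
}

/* The two 32-to-64-bit widenings being compared. */
static int64_t sext32(int32_t x) { return (int64_t)x; }
static int64_t zext32(int32_t x) { return (int64_t)(uint32_t)x; }
```

For `for i = 1, n do` the lower bound is the start value 1, so a sign extension of `i` can be dropped. For `for i = 100, -5, -1` the lower bound is -5, and the sign extension must stay.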