中间代码优化基础

核心知识点讲解

1. 代码优化概述

代码优化是编译器的重要组成部分,它的目标是:

  1. 提高程序性能:减少执行时间和内存使用
  2. 减少代码大小:降低程序的存储空间需求
  3. 提高代码质量:使代码更易于理解和维护

代码优化可以在不同层次进行:

  • 源码级优化:在源代码层面进行优化
  • 中间代码级优化:在中间代码层面进行优化(本集重点)
  • 目标代码级优化:在目标代码层面进行优化

2. 窥孔优化

窥孔优化是一种局部优化技术,它通过检查和替换程序中的小规模代码模式来提高代码质量。

2.1 常见的窥孔优化模式

  1. 冗余指令消除:删除不必要的指令

    • 例如:x = y; z = x;z = y;
  2. 代数简化:使用代数规则简化表达式

    • 例如:x = x + 0; → 移除
    • 例如:x = x * 1; → 移除
  3. 强度削弱:用更高效的操作替换低效操作

    • 例如:x = x * 2;x = x << 1;
    • 例如:x = x / 2;x = x >> 1;
  4. 控制流优化:优化跳转指令

    • 例如:goto L1; L1: goto L2;goto L2;
  5. 窥孔合并:合并相邻的指令

    • 例如:x = a + b; y = x + c;y = a + b + c;

2.2 窥孔优化的实现方法

  1. 滑动窗口:使用一个固定大小的窗口在代码上滑动,检查窗口内的代码模式
  2. 模式匹配:定义优化模式和替换规则,当匹配到模式时进行替换
  3. 迭代应用:多次应用优化规则,直到没有更多优化机会

3. 常量折叠与传播

3.1 常量折叠

常量折叠是指在编译时计算常量表达式的值,而不是在运行时计算。

  • 示例x = 2 + 3 * 4;x = 14;
  • 优势:减少运行时计算,提高程序性能

3.2 常量传播

常量传播是指将常量值传播到使用该常量的地方。

  • 示例

    x = 5;
    y = x + 3;

    x = 5;
    y = 8;
  • 优势:进一步减少运行时计算,可能启用更多的优化

4. 死代码消除

死代码是指永远不会执行或者执行结果永远不会被使用的代码。

4.1 类型的死代码

  1. 不可达代码:永远不会被执行的代码

    • 例如:return; x = 5;x = 5; 是死代码)
  2. 无用赋值:赋值的结果永远不会被使用的代码

    • 例如:x = 5; y = 6;(如果 x 之后没有被使用,x = 5; 是死代码)
  3. 冗余计算:重复计算相同结果的代码

    • 例如:x = a + b; y = a + b;(第二个 a + b 是冗余计算)

4.2 死代码消除的实现方法

  1. 可达性分析:分析哪些代码是可达的
  2. 活跃变量分析:分析哪些变量是活跃的(即其值会被后续使用)
  3. 常量传播:通过常量传播发现更多的死代码

5. 公共子表达式消除

公共子表达式消除是指识别并消除重复计算的子表达式。

  • 局部公共子表达式消除:在基本块内消除公共子表达式
  • 全局公共子表达式消除:在整个函数内消除公共子表达式

示例

x = a + b * c;
y = d + b * c;

t1 = b * c;
x = a + t1;
y = d + t1;

6. 代码优化的原则

  1. 正确性优先:优化不能改变程序的语义
  2. 性能与复杂度平衡:优化算法本身的复杂度不应过高
  3. 针对目标:不同的目标平台可能需要不同的优化策略
  4. 可维护性:优化后的代码应保持可维护性

实用案例分析

案例 1:实现简单的窥孔优化器

我们将实现一个简单的窥孔优化器,用于优化三地址码。

class PeepholeOptimizer:
    """简单的窥孔优化器"""
    
    def __init__(self):
        pass
    
    def optimize(self, instructions):
        """优化三地址码指令"""
        optimized = []
        i = 0
        
        while i < len(instructions):
            # 检查是否可以应用优化
            optimized_instr, consumed = self.apply_optimizations(instructions, i)
            
            if optimized_instr:
                # 应用了优化
                optimized.extend(optimized_instr)
                i += consumed
            else:
                # 没有应用优化,添加原指令
                optimized.append(instructions[i])
                i += 1
        
        return optimized
    
    def apply_optimizations(self, instructions, index):
        """尝试应用优化规则"""
        # 规则 1:冗余赋值消除
        if self._can_optimize_redundant_assignment(instructions, index):
            return self._optimize_redundant_assignment(instructions, index)
        
        # 规则 2:常量折叠
        if self._can_optimize_constant_folding(instructions, index):
            return self._optimize_constant_folding(instructions, index)
        
        # 规则 3:代数简化
        if self._can_optimize_algebraic_simplification(instructions, index):
            return self._optimize_algebraic_simplification(instructions, index)
        
        # 规则 4:死代码消除
        if self._can_optimize_dead_code(instructions, index):
            return self._optimize_dead_code(instructions, index)
        
        # 没有可应用的优化
        return None, 0
    
    def _can_optimize_redundant_assignment(self, instructions, index):
        """检查是否可以优化冗余赋值"""
        if index + 1 < len(instructions):
            instr1 = instructions[index]
            instr2 = instructions[index + 1]
            
            # 检查是否是赋值指令
            if '=' in instr1 and '=' in instr2:
                # 解析指令
                parts1 = instr1.split('=')
                parts2 = instr2.split('=')
                
                if len(parts1) == 2 and len(parts2) == 2:
                    target1 = parts1[0].strip()
                    source1 = parts1[1].strip()
                    target2 = parts2[0].strip()
                    source2 = parts2[1].strip()
                    
                    # 检查是否是 x = y; z = x;
                    return source2 == target1
        return False
    
    def _optimize_redundant_assignment(self, instructions, index):
        """优化冗余赋值"""
        instr1 = instructions[index]
        instr2 = instructions[index + 1]
        
        parts1 = instr1.split('=')
        parts2 = instr2.split('=')
        
        target2 = parts2[0].strip()
        source1 = parts1[1].strip()
        
        # 生成优化后的指令
        optimized = [f"{target2} = {source1}"]
        return optimized, 2
    
    def _can_optimize_constant_folding(self, instructions, index):
        """检查是否可以优化常量折叠"""
        instr = instructions[index]
        if '=' in instr:
            parts = instr.split('=')
            if len(parts) == 2:
                expr = parts[1].strip()
                # 检查表达式是否是常量运算
                if '+' in expr or '-' in expr or '*' in expr or '/' in expr:
                    # 尝试解析操作数
                    try:
                        # 简单的常量表达式解析
                        return self._is_constant_expression(expr)
                    except:
                        pass
        return False
    
    def _is_constant_expression(self, expr):
        """检查表达式是否是常量表达式"""
        # 简单实现,仅支持二元运算
        ops = ['+', '-', '*', '/']
        for op in ops:
            if op in expr:
                left, right = expr.split(op)
                left = left.strip()
                right = right.strip()
                return left.isdigit() and right.isdigit()
        return False
    
    def _evaluate_constant_expression(self, expr):
        """计算常量表达式的值"""
        ops = ['+', '-', '*', '/']
        for op in ops:
            if op in expr:
                left, right = expr.split(op)
                left = int(left.strip())
                right = int(right.strip())
                
                if op == '+':
                    return left + right
                elif op == '-':
                    return left - right
                elif op == '*':
                    return left * right
                elif op == '/':
                    return left // right
        return None
    
    def _optimize_constant_folding(self, instructions, index):
        """优化常量折叠"""
        instr = instructions[index]
        parts = instr.split('=')
        target = parts[0].strip()
        expr = parts[1].strip()
        
        # 计算常量表达式的值
        value = self._evaluate_constant_expression(expr)
        
        # 生成优化后的指令
        optimized = [f"{target} = {value}"]
        return optimized, 1
    
    def _can_optimize_algebraic_simplification(self, instructions, index):
        """检查是否可以优化代数简化"""
        instr = instructions[index]
        if '=' in instr:
            parts = instr.split('=')
            if len(parts) == 2:
                expr = parts[1].strip()
                # 检查是否是 x + 0, x * 1 等模式
                return self._is_algebraic_simplifiable(expr)
        return False
    
    def _is_algebraic_simplifiable(self, expr):
        """检查表达式是否可以代数简化"""
        # 检查 x + 0 或 0 + x
        if '+ 0' in expr or '0 +' in expr:
            return True
        # 检查 x * 1 或 1 * x
        if '* 1' in expr or '1 *' in expr:
            return True
        # 检查 x - 0
        if '- 0' in expr:
            return True
        return False
    
    def _optimize_algebraic_simplification(self, instructions, index):
        """优化代数简化"""
        instr = instructions[index]
        parts = instr.split('=')
        target = parts[0].strip()
        expr = parts[1].strip()
        
        # 简化表达式
        if '+ 0' in expr:
            simplified = expr.replace(' + 0', '').strip()
        elif '0 +' in expr:
            simplified = expr.replace('0 + ', '').strip()
        elif '* 1' in expr:
            simplified = expr.replace(' * 1', '').strip()
        elif '1 *' in expr:
            simplified = expr.replace('1 * ', '').strip()
        elif '- 0' in expr:
            simplified = expr.replace(' - 0', '').strip()
        else:
            simplified = expr
        
        # 生成优化后的指令
        optimized = [f"{target} = {simplified}"]
        return optimized, 1
    
    def _can_optimize_dead_code(self, instructions, index):
        """检查是否可以优化死代码"""
        instr = instructions[index]
        # 检查是否是 return 之后的指令
        if index > 0:
            prev_instr = instructions[index - 1]
            if prev_instr.startswith('return'):
                return True
        return False
    
    def _optimize_dead_code(self, instructions, index):
        """优化死代码"""
        # 移除死代码
        return [], 1

# 测试窥孔优化器
def test_peephole_optimizer():
    # 测试用例
    test_cases = [
        # 测试冗余赋值消除
        [
            "x = y",
            "z = x",
            "print z"
        ],
        # 测试常量折叠
        [
            "x = 2 + 3",
            "y = x * 4",
            "print y"
        ],
        # 测试代数简化
        [
            "x = a + 0",
            "y = 1 * b",
            "z = c - 0",
            "print x, y, z"
        ],
        # 测试死代码消除
        [
            "return 0",
            "x = 5",  # 死代码
            "print x"   # 死代码
        ]
    ]
    
    optimizer = PeepholeOptimizer()
    
    for i, test_case in enumerate(test_cases):
        print(f"\nTest case {i+1}:")
        print("Original instructions:")
        for instr in test_case:
            print(f"  {instr}")
        
        optimized = optimizer.optimize(test_case)
        print("Optimized instructions:")
        for instr in optimized:
            print(f"  {instr}")

# 运行测试
test_peephole_optimizer()

运行结果分析:

Test case 1:
Original instructions:
  x = y
  z = x
  print z
Optimized instructions:
  z = y
  print z

Test case 2:
Original instructions:
  x = 2 + 3
  y = x * 4
  print y
Optimized instructions:
  x = 5
  y = x * 4
  print y

Test case 3:
Original instructions:
  x = a + 0
  y = 1 * b
  z = c - 0
  print x, y, z
Optimized instructions:
  x = a
  y = b
  z = c
  print x, y, z

Test case 4:
Original instructions:
  return 0
  x = 5
  print x
Optimized instructions:
  return 0

案例 2:实现常量传播和死代码消除

class ConstantPropagator:
    """常量传播器"""
    
    def __init__(self):
        self.constants = {}  # 存储变量到常量值的映射
    
    def propagate(self, instructions):
        """执行常量传播"""
        propagated = []
        
        for instr in instructions:
            # 检查是否是赋值指令
            if '=' in instr:
                parts = instr.split('=')
                if len(parts) == 2:
                    target = parts[0].strip()
                    source = parts[1].strip()
                    
                    # 检查源是否是常量
                    if source.isdigit():
                        # 记录常量
                        self.constants[target] = source
                        # 添加原指令
                        propagated.append(instr)
                    else:
                        # 检查源是否引用了常量
                        propagated_instr = self._replace_constants(source)
                        if propagated_instr != source:
                            # 生成新的指令
                            propagated.append(f"{target} = {propagated_instr}")
                            # 如果新的源是常量,记录
                            if propagated_instr.isdigit():
                                self.constants[target] = propagated_instr
                        else:
                            # 添加原指令
                            propagated.append(instr)
                            # 移除可能的常量记录(如果变量被重新赋值为非常量)
                            if target in self.constants:
                                del self.constants[target]
                else:
                    propagated.append(instr)
            else:
                # 检查是否使用了常量
                propagated_instr = self._replace_constants(instr)
                if propagated_instr != instr:
                    propagated.append(propagated_instr)
                else:
                    propagated.append(instr)
        
        return propagated
    
    def _replace_constants(self, expr):
        """替换表达式中的常量"""
        result = expr
        # 替换变量为常量
        for var, value in self.constants.items():
            # 简单的词法替换,实际实现需要更复杂的解析
            # 这里仅作为示例
            result = result.replace(var, value)
        return result

class DeadCodeEliminator:
    """死代码消除器"""
    
    def __init__(self):
        pass
    
    def eliminate(self, instructions):
        """执行死代码消除"""
        # 首先进行可达性分析
        reachable = self._analyze_reachability(instructions)
        
        # 然后进行活跃变量分析
        active = self._analyze_active_variables(instructions)
        
        # 消除死代码
        eliminated = []
        for i, instr in enumerate(instructions):
            if reachable[i] and self._is_active(instr, active):
                eliminated.append(instr)
        
        return eliminated
    
    def _analyze_reachability(self, instructions):
        """分析代码可达性"""
        reachable = [False] * len(instructions)
        stack = [0]  # 从第一条指令开始
        
        while stack:
            i = stack.pop()
            if i < len(instructions) and not reachable[i]:
                reachable[i] = True
                # 检查是否是 return 指令
                if instructions[i].startswith('return'):
                    continue
                # 检查是否是条件跳转
                elif 'goto' in instructions[i]:
                    # 简单实现,实际需要解析标签
                    pass
                # 下一条指令
                stack.append(i + 1)
        
        return reachable
    
    def _analyze_active_variables(self, instructions):
        """分析活跃变量"""
        # 简单实现,仅作为示例
        # 实际实现需要更复杂的数据流分析
        active = set()
        
        # 从后向前分析
        for instr in reversed(instructions):
            # 检查是否使用了变量
            # 简单实现,仅作为示例
            if 'print' in instr:
                # 假设 print 后面的变量是活跃的
                parts = instr.split('print')[1].strip()
                vars = parts.split(',')
                for var in vars:
                    active.add(var.strip())
            elif '=' in instr:
                parts = instr.split('=')
                target = parts[0].strip()
                source = parts[1].strip()
                # 如果目标变量在后续被使用,则它是活跃的
                if target in active:
                    # 源表达式中的变量也是活跃的
                    # 简单实现,仅作为示例
                    for var in source.split():
                        if var.isalpha():
                            active.add(var)
                else:
                    # 如果目标变量不活跃,移除它
                    if target in active:
                        active.remove(target)
        
        return active
    
    def _is_active(self, instr, active):
        """检查指令是否是活跃的"""
        # 检查是否是 return 指令
        if instr.startswith('return'):
            return True
        # 检查是否是 print 指令
        if 'print' in instr:
            return True
        # 检查是否是赋值指令
        if '=' in instr:
            parts = instr.split('=')
            target = parts[0].strip()
            # 如果目标变量是活跃的,或者源表达式中使用了活跃变量
            if target in active:
                return True
            source = parts[1].strip()
            for var in source.split():
                if var.strip() in active:
                    return True
        return False

# 测试常量传播和死代码消除
def test_constant_propagation_and_dead_code_elimination():
    # 测试用例
    instructions = [
        "x = 5",
        "y = x + 3",  # 可以常量传播为 y = 8
        "z = y * 2",  # 可以常量传播为 z = 16
        "a = 10",    # 死代码,因为 a 没有被使用
        "print z"
    ]
    
    print("Original instructions:")
    for instr in instructions:
        print(f"  {instr}")
    
    # 执行常量传播
    propagator = ConstantPropagator()
    propagated = propagator.propagate(instructions)
    print("\nAfter constant propagation:")
    for instr in propagated:
        print(f"  {instr}")
    
    # 执行死代码消除
    eliminator = DeadCodeEliminator()
    eliminated = eliminator.eliminate(propagated)
    print("\nAfter dead code elimination:")
    for instr in eliminated:
        print(f"  {instr}")

# 运行测试
test_constant_propagation_and_dead_code_elimination()

运行结果分析:

Original instructions:
  x = 5
  y = x + 3
  z = y * 2
  a = 10
  print z

After constant propagation:
  x = 5
  y = 5 + 3
  z = y * 2
  a = 10
  print z

After dead code elimination:
  x = 5
  y = 5 + 3
  z = y * 2
  print z

案例 3:实现公共子表达式消除

class CommonSubexpressionEliminator:
    """公共子表达式消除器"""
    
    def __init__(self):
        self.subexpressions = {}  # 存储子表达式到临时变量的映射
        self.temp_count = 0
    
    def eliminate(self, instructions):
        """执行公共子表达式消除"""
        eliminated = []
        self.subexpressions = {}
        self.temp_count = 0
        
        for instr in instructions:
            # 检查是否是赋值指令
            if '=' in instr:
                parts = instr.split('=')
                if len(parts) == 2:
                    target = parts[0].strip()
                    source = parts[1].strip()
                    
                    # 检查源是否是二元表达式
                    if self._is_binary_expression(source):
                        # 检查是否是公共子表达式
                        if source in self.subexpressions:
                            # 使用已有的临时变量
                            temp_var = self.subexpressions[source]
                            eliminated.append(f"{target} = {temp_var}")
                        else:
                            # 创建新的临时变量
                            temp_var = self._new_temp()
                            self.subexpressions[source] = temp_var
                            eliminated.append(f"{temp_var} = {source}")
                            eliminated.append(f"{target} = {temp_var}")
                    else:
                        # 添加原指令
                        eliminated.append(instr)
                else:
                    eliminated.append(instr)
            else:
                eliminated.append(instr)
        
        return eliminated
    
    def _is_binary_expression(self, expr):
        """检查是否是二元表达式"""
        ops = ['+', '-', '*', '/']
        for op in ops:
            if op in expr:
                return True
        return False
    
    def _new_temp(self):
        """生成新的临时变量"""
        temp = f"t{self.temp_count}"
        self.temp_count += 1
        return temp

# 测试公共子表达式消除
def test_common_subexpression_elimination():
    # 测试用例
    instructions = [
        "x = a + b * c",
        "y = d + b * c",  # b * c 是公共子表达式
        "z = a + b * c",  # a + b * c 是公共子表达式
        "print x, y, z"
    ]
    
    print("Original instructions:")
    for instr in instructions:
        print(f"  {instr}")
    
    eliminator = CommonSubexpressionEliminator()
    eliminated = eliminator.eliminate(instructions)
    
    print("\nAfter common subexpression elimination:")
    for instr in eliminated:
        print(f"  {instr}")

# 运行测试
test_common_subexpression_elimination()

运行结果分析:

Original instructions:
  x = a + b * c
  y = d + b * c
  z = a + b * c
  print x, y, z

After common subexpression elimination:
  t0 = a + b * c
  x = t0
  t1 = b * c
  t2 = d + b * c
  y = t2
  z = t0
  print x, y, z

总结

本集我们介绍了中间代码优化的基础知识,包括:

  1. 代码优化概述:代码优化的目标和层次

  2. 窥孔优化:通过检查和替换小规模代码模式来提高代码质量

    • 冗余指令消除
    • 代数简化
    • 强度削弱
    • 控制流优化
  3. 常量折叠与传播

    • 常量折叠:在编译时计算常量表达式的值
    • 常量传播:将常量值传播到使用该常量的地方
  4. 死代码消除:识别并删除永远不会执行或者执行结果永远不会被使用的代码

  5. 公共子表达式消除:识别并消除重复计算的子表达式

通过本集的学习,我们掌握了中间代码优化的基本技术,为后续学习更复杂的代码优化技术奠定了基础。在下一集中,我们将学习中间代码生成器的调试技术。

« 上一篇 中间代码生成实战(三)—— 函数 下一篇 » 中间代码生成器的调试