第195集:优化实战(一)—— 局部优化

1. 什么是局部优化?

局部优化(Local Optimization)是一种编译器优化技术,它在单个基本块(Basic Block)内部进行优化,不考虑基本块之间的控制流。局部优化是最基础的优化技术,也是其他高级优化技术的基础。

局部优化的基本思想

  1. 分析基本块:分析基本块内的指令及其依赖关系
  2. 构建中间表示:如DAG(有向无环图)
  3. 应用优化变换:如公共子表达式消除、常量传播等
  4. 重新生成代码:基于优化后的中间表示生成优化的代码

局部优化的优势

  • 实现简单:只需要分析单个基本块,不需要考虑控制流
  • 开销小:优化过程的时间和空间开销较小
  • 效果明显:对于包含大量冗余计算的代码,局部优化可以显著提高性能
  • 为其他优化做准备:局部优化可以为更高级的优化技术创造条件

2. 基本块

2.1 基本块的定义

基本块是指程序中一段连续的代码,其中:

  • 只有一个入口点:即第一个指令
  • 只有一个出口点:即最后一个指令
  • 执行时从入口点进入,从出口点退出,不会在中间跳转或停止

2.2 基本块的识别

def identify_basic_blocks(function):
    """识别函数中的基本块"""
    instructions = function.instructions
    blocks = []
    
    # 1. 识别基本块的边界
    leaders = set()
    leaders.add(0)  # 第一个指令总是一个基本块的开始
    
    for i, instr in enumerate(instructions):
        # 跳转指令的目标是一个基本块的开始
        if is_branch_instruction(instr):
            target = get_branch_target(instr)
            if target is not None:
                leaders.add(target)
        
        # 跳转指令的下一条指令是一个基本块的开始
        if is_jump_instruction(instr):
            if i + 1 < len(instructions):
                leaders.add(i + 1)
    
    # 2. 构建基本块
    sorted_leaders = sorted(leaders)
    for i in range(len(sorted_leaders)):
        start = sorted_leaders[i]
        end = sorted_leaders[i + 1] if i + 1 < len(sorted_leaders) else len(instructions)
        block = BasicBlock(start, end, instructions[start:end])
        blocks.append(block)
    
    return blocks

2.3 基本块的表示

class BasicBlock:
    """表示一个基本块"""
    def __init__(self, start, end, instructions):
        self.start = start  # 基本块的开始指令索引
        self.end = end  # 基本块的结束指令索引
        self.instructions = instructions  # 基本块内的指令
        self.successors = []  # 后继基本块
        self.predecessors = []  # 前驱基本块
    
    def __str__(self):
        return f"BasicBlock({self.start}-{self.end})"

3. DAG构造

3.1 DAG的基本概念

DAG(Directed Acyclic Graph,有向无环图)是局部优化的重要工具,它可以有效地表示基本块内的计算关系:

  • 节点:表示操作数或计算结果
  • :表示操作数与计算之间的依赖关系
  • 标签:节点可以有标签,表示变量名或常量值

3.2 DAG的构造算法

def build_dag(block):
    """为基本块构建DAG"""
    dag = {}
    node_map = {}  # 变量名到节点的映射
    
    for instr in block.instructions:
        if is_assignment_instruction(instr):
            # 处理赋值指令
            dest = instr.destination
            op = instr.operator
            args = instr.arguments
            
            # 为参数查找或创建节点
            arg_nodes = []
            for arg in args:
                if arg in node_map:
                    arg_nodes.append(node_map[arg])
                elif is_constant(arg):
                    # 创建常量节点
                    const_node = create_constant_node(arg)
                    arg_nodes.append(const_node)
                    node_map[arg] = const_node
                else:
                    # 创建变量节点
                    var_node = create_variable_node(arg)
                    arg_nodes.append(var_node)
                    node_map[arg] = var_node
            
            # 创建操作节点
            op_node = create_operation_node(op, arg_nodes)
            
            # 更新变量到节点的映射
            node_map[dest] = op_node
        elif is_load_instruction(instr):
            # 处理加载指令
            dest = instr.destination
            addr = instr.address
            
            # 创建加载节点
            load_node = create_load_node(addr)
            node_map[dest] = load_node
        elif is_store_instruction(instr):
            # 处理存储指令
            addr = instr.address
            value = instr.value
            
            # 为值查找或创建节点
            if value in node_map:
                value_node = node_map[value]
            elif is_constant(value):
                value_node = create_constant_node(value)
            else:
                value_node = create_variable_node(value)
            
            # 创建存储节点
            create_store_node(addr, value_node)
    
    return dag

3.3 DAG的优化

基于DAG,可以应用多种优化技术:

  • 公共子表达式消除:如果两个节点表示相同的计算,则只需要计算一次
  • 常量折叠:如果所有操作数都是常量,则可以在编译时计算结果
  • 死代码消除:如果一个节点的结果没有被使用,则可以删除该节点
  • 代数简化:应用代数规则简化表达式,如 x + 0 → x

4. 公共子表达式消除

4.1 什么是公共子表达式?

公共子表达式(Common Subexpression)是指在一个基本块中多次出现的相同表达式。例如:

x = a + b;
y = c * d;
z = a + b; // 这里的 a + b 是公共子表达式

4.2 公共子表达式消除的实现

def eliminate_common_subexpressions(block):
    """消除基本块中的公共子表达式"""
    # 1. 构建DAG
    dag, node_map = build_dag(block)
    
    # 2. 从DAG重新生成代码
    new_instructions = []
    temp_vars = {}  # 表达式到临时变量的映射
    var_map = {}  # 原始变量到新变量的映射
    
    # 遍历原始指令,保持副作用
    for instr in block.instructions:
        if is_assignment_instruction(instr):
            dest = instr.destination
            op = instr.operator
            args = instr.arguments
            
            # 构建表达式的规范化表示
            expr = (op, tuple(args))
            
            if expr in temp_vars:
                # 公共子表达式,复用临时变量
                temp_var = temp_vars[expr]
                new_instructions.append(create_assignment(dest, temp_var))
                var_map[dest] = temp_var
            else:
                # 新表达式,生成新的临时变量
                temp_var = create_temp_variable()
                temp_vars[expr] = temp_var
                new_instructions.append(create_assignment(temp_var, op, args))
                new_instructions.append(create_assignment(dest, temp_var))
                var_map[dest] = temp_var
        else:
            # 其他指令,保持不变
            new_instructions.append(instr)
    
    # 3. 更新基本块
    block.instructions = new_instructions
    
    return block

4.3 示例

// 优化前
x = a + b;
y = c * d;
z = a + b;
w = x + y;

// 优化后
_t1 = a + b;
x = _t1;
y = c * d;
z = _t1;
w = _t1 + y;

5. 常量折叠与传播

5.1 常量折叠

常量折叠(Constant Folding)是指在编译时计算常量表达式的值,例如:

x = 2 + 3; // 可以折叠为 x = 5

5.2 常量传播

常量传播(Constant Propagation)是指将常量值传播到使用该常量的地方,例如:

x = 5;
y = x + 3; // 可以传播为 y = 5 + 3,然后折叠为 y = 8

5.3 实现

def constant_folding_and_propagation(block):
    """常量折叠与传播"""
    # 1. 收集常量定义
    constants = {}
    for instr in block.instructions:
        if is_assignment_instruction(instr) and all(is_constant(arg) for arg in instr.arguments):
            # 常量折叠
            value = evaluate_constant_expression(instr.operator, instr.arguments)
            constants[instr.destination] = value
        elif is_assignment_instruction(instr):
            # 检查是否可以进行常量传播
            new_args = []
            all_constants = True
            for arg in instr.arguments:
                if arg in constants:
                    new_args.append(constants[arg])
                elif is_constant(arg):
                    new_args.append(arg)
                else:
                    new_args.append(arg)
                    all_constants = False
            
            # 更新指令
            instr.arguments = new_args
            
            # 如果所有参数都是常量,进行常量折叠
            if all_constants:
                value = evaluate_constant_expression(instr.operator, new_args)
                constants[instr.destination] = value
                # 替换为直接赋值
                instr.operator = '='
                instr.arguments = [value]
        elif is_load_instruction(instr):
            # 加载指令通常不能进行常量折叠
            pass
        elif is_store_instruction(instr):
            # 检查是否可以进行常量传播
            if instr.value in constants:
                instr.value = constants[instr.value]
    
    # 2. 常量传播到其他指令
    for instr in block.instructions:
        if is_assignment_instruction(instr) and not all(is_constant(arg) for arg in instr.arguments):
            new_args = []
            for arg in instr.arguments:
                if arg in constants:
                    new_args.append(constants[arg])
                else:
                    new_args.append(arg)
            instr.arguments = new_args
        elif is_load_instruction(instr):
            pass
        elif is_store_instruction(instr):
            if instr.value in constants:
                instr.value = constants[instr.value]
    
    return block

5.4 示例

// 优化前
x = 2 + 3;
y = x * 4;
z = y - 5;

// 优化后
x = 5;
y = 20;
z = 15;

6. 死代码消除

6.1 什么是死代码?

死代码(Dead Code)是指那些计算结果永远不会被使用的代码。死代码不仅浪费执行时间,还会增加程序的大小。

6.2 死代码消除的实现

def eliminate_dead_code(block):
    """消除基本块中的死代码"""
    # 1. 分析变量的使用情况
    used_vars = set()
    definitions = {}
    
    # 从后向前分析,找出被使用的变量
    for instr in reversed(block.instructions):
        if is_assignment_instruction(instr):
            dest = instr.destination
            if dest in used_vars:
                # 变量被使用,标记其定义的参数也可能被使用
                for arg in instr.arguments:
                    if not is_constant(arg):
                        used_vars.add(arg)
            else:
                # 变量未被使用,标记为死代码
                instr.dead = True
        elif is_load_instruction(instr):
            dest = instr.destination
            if dest in used_vars:
                # 加载的变量被使用,标记地址可能被使用
                addr = instr.address
                if not is_constant(addr):
                    used_vars.add(addr)
            else:
                # 加载的变量未被使用,标记为死代码
                instr.dead = True
        elif is_store_instruction(instr):
            # 存储指令有副作用,不能删除
            # 标记值和地址可能被使用
            value = instr.value
            if not is_constant(value):
                used_vars.add(value)
            addr = instr.address
            if not is_constant(addr):
                used_vars.add(addr)
        elif is_branch_instruction(instr):
            # 分支指令的条件可能被使用
            cond = instr.condition
            if not is_constant(cond):
                used_vars.add(cond)
    
    # 2. 删除死代码
    new_instructions = []
    for instr in block.instructions:
        if not hasattr(instr, 'dead') or not instr.dead:
            new_instructions.append(instr)
    
    block.instructions = new_instructions
    
    return block

6.3 示例

// 优化前
x = a + b;
y = c * d;
// y 的值从未被使用
z = x + 5;

// 优化后
x = a + b;
z = x + 5;

7. 代数简化

7.1 代数简化的规则

代数简化(Algebraic Simplification)是指应用代数规则简化表达式,例如:

  • x + 0 → x
  • x * 1 → x
  • x * 0 → 0
  • x - 0 → x
  • x / 1 → x
  • x && true → x
  • x || false → x

7.2 代数简化的实现

def algebraic_simplification(block):
    """代数简化"""
    for instr in block.instructions:
        if is_assignment_instruction(instr):
            op = instr.operator
            args = instr.arguments
            
            # 应用代数简化规则
            simplified = False
            
            if op == '+':
                if len(args) == 2:
                    a, b = args
                    if is_constant(b) and b == 0:
                        # x + 0 → x
                        instr.operator = '='
                        instr.arguments = [a]
                        simplified = True
                    elif is_constant(a) and a == 0:
                        # 0 + x → x
                        instr.operator = '='
                        instr.arguments = [b]
                        simplified = True
            elif op == '*':
                if len(args) == 2:
                    a, b = args
                    if is_constant(b) and b == 1:
                        # x * 1 → x
                        instr.operator = '='
                        instr.arguments = [a]
                        simplified = True
                    elif is_constant(a) and a == 1:
                        # 1 * x → x
                        instr.operator = '='
                        instr.arguments = [b]
                        simplified = True
                    elif (is_constant(a) and a == 0) or (is_constant(b) and b == 0):
                        # x * 0 → 0 或 0 * x → 0
                        instr.operator = '='
                        instr.arguments = [0]
                        simplified = True
            elif op == '-':
                if len(args) == 2:
                    a, b = args
                    if is_constant(b) and b == 0:
                        # x - 0 → x
                        instr.operator = '='
                        instr.arguments = [a]
                        simplified = True
            elif op == '/':
                if len(args) == 2:
                    a, b = args
                    if is_constant(b) and b == 1:
                        # x / 1 → x
                        instr.operator = '='
                        instr.arguments = [a]
                        simplified = True
            # 可以添加更多的代数简化规则
            
            if simplified:
                # 尝试进行常量折叠
                if all(is_constant(arg) for arg in instr.arguments):
                    value = evaluate_constant_expression(instr.operator, instr.arguments)
                    instr.operator = '='
                    instr.arguments = [value]
    
    return block

7.3 示例

// 优化前
x = a + 0;
y = b * 1;
z = c * 0;

// 优化后
x = a;
y = b;
z = 0;

8. 局部优化的实现步骤

8.1 整体流程

def optimize_basic_block(block):
    """优化基本块"""
    # 1. 代数简化
    block = algebraic_simplification(block)
    
    # 2. 常量折叠与传播
    block = constant_folding_and_propagation(block)
    
    # 3. 公共子表达式消除
    block = eliminate_common_subexpressions(block)
    
    # 4. 死代码消除
    block = eliminate_dead_code(block)
    
    # 5. 再次应用代数简化
    block = algebraic_simplification(block)
    
    return block

def optimize_function(function):
    """优化函数"""
    # 1. 识别基本块
    blocks = identify_basic_blocks(function)
    
    # 2. 优化每个基本块
    optimized_blocks = []
    for block in blocks:
        optimized_block = optimize_basic_block(block)
        optimized_blocks.append(optimized_block)
    
    # 3. 重新生成函数代码
    new_instructions = []
    for block in optimized_blocks:
        new_instructions.extend(block.instructions)
    
    function.instructions = new_instructions
    
    return function

8.2 示例

// 优化前
x = a + b;
y = c * d;
z = a + b;
w = x + y;
if (w > 0) {
    t = 5 * 0;
    u = t + 3;
}

// 优化后
_t1 = a + b;
x = _t1;
y = c * d;
z = _t1;
w = _t1 + y;
if (w > 0) {
    t = 0;
    u = 3;
}

9. 局部优化的挑战

9.1 表达式识别

  • 表达式的规范化:不同形式的表达式可能表示相同的计算,需要规范化
  • 复杂表达式:对于复杂表达式,识别公共子表达式变得困难
  • 副作用:带有副作用的表达式不能随意消除或重排序

9.2 依赖分析

  • 数据依赖:需要正确分析指令间的数据依赖关系
  • 控制依赖:局部优化不考虑控制依赖,可能会影响其他优化
  • 内存依赖:对于内存操作,依赖分析更加复杂

9.3 代码生成

  • 寄存器分配:优化后的代码可能需要重新进行寄存器分配
  • 指令选择:需要选择合适的指令来实现优化后的表达式
  • 代码布局:优化后的代码布局可能影响缓存性能

10. 局部优化的工具与实践

10.1 编译器选项

# GCC 中的局部优化选项
gcc -O1 source.c -o program  # 基本优化,包括局部优化
gcc -O2 source.c -o program  # 更高级的优化,包括局部优化
gcc -O3 source.c -o program  # 最高级的优化,包括局部优化

# 查看优化信息
gcc -O2 -fopt-info-vec source.c

10.2 优化器的调试

# 生成优化前后的中间表示
gcc -O0 -fdump-tree-original source.c
gcc -O2 -fdump-tree-optimized source.c

# 比较优化前后的代码
diff -u source.c.001t.original source.c.004t.optimized

10.3 实践建议

  • 从小规模开始:先在小规模的代码上应用局部优化,熟悉优化效果
  • 分析瓶颈:使用性能分析工具找出程序的性能瓶颈,针对性地应用优化
  • 验证结果:确保优化后的代码与原始代码的行为一致
  • 权衡利弊:优化可能会增加编译时间,需要在编译时间和运行时间之间取得平衡

11. 局部优化的效果评估

11.1 基准测试

// benchmark.c
int main() {
    int a = 1, b = 2, c = 3, d = 4;
    int x, y, z, w;
    
    // 包含大量冗余计算的代码
    for (int i = 0; i < 1000000000; i++) {
        x = a + b;
        y = c * d;
        z = a + b;  // 公共子表达式
        w = x + y;
    }
    
    printf("x = %d, y = %d, z = %d, w = %d\n", x, y, z, w);
    return 0;
}

11.2 编译与运行

# 无优化
gcc -O0 benchmark.c -o no_opt

# 有优化
gcc -O2 benchmark.c -o with_opt

# 比较性能
time ./no_opt
time ./with_opt

11.3 结果分析

  • 无优化:代码中包含大量冗余计算,执行速度慢
  • 有优化:编译器应用了公共子表达式消除等局部优化,执行速度快

12. 总结

局部优化是一种基础但有效的编译器优化技术,它通过分析和变换单个基本块内的代码,消除冗余计算,提高程序的性能。局部优化包括多种技术:

  • 公共子表达式消除:消除重复计算的表达式
  • 常量折叠与传播:在编译时计算常量表达式的值
  • 死代码消除:删除计算结果未被使用的代码
  • 代数简化:应用代数规则简化表达式

尽管局部优化只考虑单个基本块内的代码,不考虑控制流,但它是其他高级优化技术的基础,也是编译器优化中最常用的技术之一。通过合理应用局部优化,可以显著提高程序的性能,同时保持代码的正确性。

在实际应用中,编译器会根据代码的特点自动选择合适的优化技术,以达到最佳的优化效果。作为程序员,了解局部优化的原理和技术,可以帮助我们编写更易于优化的代码,从而获得更好的性能。

13. 练习

  1. 手动优化:手动优化以下代码,应用局部优化技术

    int foo(int a, int b, int c) {
        int x = a + b;
        int y = c * 2;
        int z = a + b;  // 公共子表达式
        int w = x + y;
        int t = 5 * 0;  // 常量折叠
        return w + t;
    }
  2. DAG构造:为以下代码构建DAG

    x = a + b;
    y = x * c;
    z = a + b;  // 公共子表达式
    w = y + z;
  3. 实现简单的局部优化器:实现一个简单的局部优化器,支持公共子表达式消除和常量折叠

  4. 优化效果分析:编译以下代码,比较优化前后的性能差异

    void compute(int *a, int *b, int *c, int n) {
        for (int i = 0; i < n; i++) {
            c[i] = a[i] * 2 + b[i] + a[i] * 2;  // 包含公共子表达式
        }
    }
  5. 代码优化建议:分析以下代码,提出局部优化的建议

    int calculate(int x, int y, int z) {
        int result = 0;
        result += x * y + z;
        result += x * y + z;  // 重复计算
        result += x * 0;  // 可以简化
        return result;
    }
« 上一篇 并行化 下一篇 » 优化实战(二)—— 循环优化