第177集:公共子表达式消除

核心知识点讲解

什么是公共子表达式?

公共子表达式(Common Subexpression)是指在程序中多次出现的、计算结果相同的表达式。例如:

x = a + b * c;
y = d + a + b * c;

这里的 a + b * c 就是一个公共子表达式,它在两个不同的赋值语句中出现。如果我们能够识别并消除这种重复计算,就可以提高程序的执行效率。

公共子表达式消除的类型

  1. 局部公共子表达式消除(Local CSE):在单个基本块内识别和消除公共子表达式
  2. 全局公共子表达式消除(Global CSE):跨多个基本块识别和消除公共子表达式

局部公共子表达式消除

局部CSE的实现相对简单,主要通过以下步骤:

  1. 构建基本块的DAG(有向无环图):将基本块中的表达式表示为DAG节点
  2. 识别公共子表达式:在DAG中,相同的子表达式会被表示为同一个节点
  3. 重构代码:从DAG生成优化后的代码,确保每个公共子表达式只计算一次

全局公共子表达式消除

全局CSE需要考虑整个程序的控制流,使用数据流分析技术来识别在不同基本块中出现的公共子表达式。主要涉及:

  1. 可用表达式分析:确定在程序的每个点上哪些表达式的值是可用的(即已经计算过且值未改变)
  2. 表达式哈希:为每个表达式计算哈希值,以便快速识别相同的表达式
  3. 插入计算点:在适当的位置插入公共子表达式的计算,并在其他位置引用该计算结果

可用表达式分析

可用表达式分析是一种数据流分析,用于确定在程序的每个点上哪些表达式是可用的。它基于以下两个集合:

  1. **IN[B]**:进入基本块B时可用的表达式集合
  2. **OUT[B]**:离开基本块B时可用的表达式集合

分析规则如下:

  • 边界条件:对于程序的第一个基本块,IN[B]为空集
  • 传递函数:OUT[B] = gen[B] ∪ (IN[B] - kill[B]),其中:
    • gen[B]:基本块B中生成的新可用表达式
    • kill[B]:基本块B中杀死的可用表达式(即修改了表达式中变量的值)
  • 交汇运算:对于有多个前驱的基本块B,IN[B]是所有前驱基本块的OUT集合的交集

实用案例分析

局部公共子表达式消除示例

原始代码

x = (a + b) * c;
y = (a + b) * d;
z = x + y;

优化过程

  1. 构建DAG

    • 节点1:a + b
    • 节点2:节点1 * c (x)
    • 节点3:节点1 * d (y)
    • 节点4:节点2 + 节点3 (z)
  2. 生成优化后代码

t1 = a + b;
x = t1 * c;
y = t1 * d;
z = x + y;

全局公共子表达式消除示例

原始代码

if (condition) {
    x = a + b * c;
    // 其他操作,不修改a、b、c
} else {
    y = d + e;
    // 其他操作,不修改a、b、c
}
z = a + b * c + f;

优化过程

  1. 可用表达式分析

    • 在if分支中计算了a + b * c
    • 在else分支中没有修改a、b、c
    • 因此,在if和else分支之后,a + b * c都是可用的
  2. 生成优化后代码

t1 = a + b * c;
if (condition) {
    x = t1;
    // 其他操作
} else {
    y = d + e;
    // 其他操作
}
z = t1 + f;

代码实现

局部公共子表达式消除实现

class LocalCSE:
    def __init__(self):
        self.temp_count = 0
    
    def new_temp(self):
        """生成新的临时变量"""
        self.temp_count += 1
        return f"t{self.temp_count}"
    
    def eliminate(self, basic_block):
        """对基本块执行局部公共子表达式消除"""
        # 表达式到临时变量的映射
        expr_map = {}
        optimized_code = []
        
        for instr in basic_block:
            if '=' in instr:
                left, right = instr.split('=', 1)
                left = left.strip()
                right = right.strip()
                
                # 检查是否为表达式
                if '+' in right or '-' in right or '*' in right or '/' in right:
                    # 检查是否是公共子表达式
                    if right in expr_map:
                        # 使用已有的临时变量
                        optimized_code.append(f"{left} = {expr_map[right]}")
                    else:
                        # 生成新的临时变量
                        temp = self.new_temp()
                        expr_map[right] = temp
                        optimized_code.append(f"{temp} = {right}")
                        optimized_code.append(f"{left} = {temp}")
                else:
                    # 简单赋值,直接添加
                    optimized_code.append(instr)
            else:
                # 非赋值语句,直接添加
                optimized_code.append(instr)
        
        return optimized_code

# 测试示例
basic_block = [
    "x = a + b * c",
    "y = d + a + b * c",
    "z = x + y"
]

cse = LocalCSE()
optimized = cse.eliminate(basic_block)
print("原始代码:")
for line in basic_block:
    print(f"  {line}")
print("\n优化后代码:")
for line in optimized:
    print(f"  {line}")

可用表达式分析实现

class AvailableExpressions:
    def __init__(self, cfg):
        self.cfg = cfg  # 控制流图
        self.expressions = set()  # 程序中所有的表达式
        self.gen = {}  # 每个基本块生成的表达式
        self.kill = {}  # 每个基本块杀死的表达式
        self.in_set = {}  # 进入基本块时的可用表达式
        self.out_set = {}  # 离开基本块时的可用表达式
        self._initialize()
    
    def _extract_expressions(self, block):
        """从基本块中提取表达式"""
        exprs = set()
        vars = set()
        
        for instr in block:
            if '=' in instr:
                left, right = instr.split('=', 1)
                left = left.strip()
                right = right.strip()
                
                # 记录被修改的变量
                vars.add(left)
                
                # 提取表达式
                if any(op in right for op in ['+', '-', '*', '/']):
                    exprs.add(right)
        
        return exprs, vars
    
    def _initialize(self):
        """初始化gen和kill集合"""
        all_exprs = set()
        all_vars = set()
        
        # 首先收集所有表达式和变量
        for block_name, block in self.cfg.items():
            exprs, vars = self._extract_expressions(block)
            all_exprs.update(exprs)
            all_vars.update(vars)
        
        self.expressions = all_exprs
        
        # 初始化gen和kill
        for block_name, block in self.cfg.items():
            block_exprs, block_vars = self._extract_expressions(block)
            self.gen[block_name] = block_exprs
            
            # kill集合包含所有引用了被修改变量的表达式
            kill_set = set()
            for expr in all_exprs:
                if any(var in expr for var in block_vars):
                    kill_set.add(expr)
            self.kill[block_name] = kill_set
            
            # 初始化in和out集合
            self.in_set[block_name] = set()
            self.out_set[block_name] = set()
    
    def analyze(self):
        """执行可用表达式分析"""
        changed = True
        
        while changed:
            changed = False
            
            for block_name in self.cfg:
                # 计算新的out集合
                new_out = self.gen[block_name].union(self.in_set[block_name] - self.kill[block_name])
                
                if new_out != self.out_set[block_name]:
                    self.out_set[block_name] = new_out
                    changed = True
                
                # 计算新的in集合(处理前驱)
                predecessors = self._get_predecessors(block_name)
                if not predecessors:
                    # 入口块
                    new_in = set()
                else:
                    # 交汇运算:所有前驱的out集合的交集
                    new_in = self.out_set[predecessors[0]].copy()
                    for pred in predecessors[1:]:
                        new_in.intersection_update(self.out_set[pred])
                
                if new_in != self.in_set[block_name]:
                    self.in_set[block_name] = new_in
                    changed = True
    
    def _get_predecessors(self, block_name):
        """获取基本块的前驱"""
        # 简化实现,实际需要根据CFG结构确定
        # 这里假设每个基本块的前驱是前一个编号的块
        predecessors = []
        block_num = int(block_name[1:])  # 假设块名格式为B1, B2等
        if block_num > 1:
            predecessors.append(f"B{block_num - 1}")
        return predecessors

# 测试示例
cfg = {
    "B1": ["x = a + b", "y = c * d"],
    "B2": ["z = a + b", "w = y + z"],
    "B3": ["a = 5", "v = c * d"]
}

ae = AvailableExpressions(cfg)
ae.analyze()

print("可用表达式分析结果:")
for block_name in cfg:
    print(f"\n{block_name}:")
    print(f"  IN: {ae.in_set[block_name]}")
    print(f"  OUT: {ae.out_set[block_name]}")
    print(f"  GEN: {ae.gen[block_name]}")
    print(f"  KILL: {ae.kill[block_name]}")

全局公共子表达式消除实现

class GlobalCSE:
    def __init__(self, cfg):
        self.cfg = cfg
        self.ae = AvailableExpressions(cfg)
        self.ae.analyze()
        self.temp_count = 0
    
    def new_temp(self):
        """生成新的临时变量"""
        self.temp_count += 1
        return f"t{self.temp_count}"
    
    def eliminate(self):
        """执行全局公共子表达式消除"""
        optimized_cfg = {}
        expr_to_temp = {}  # 表达式到临时变量的映射
        
        for block_name, block in self.cfg.items():
            optimized_block = []
            available = self.ae.in_set[block_name].copy()
            
            for instr in block:
                if '=' in instr:
                    left, right = instr.split('=', 1)
                    left = left.strip()
                    right = right.strip()
                    
                    # 检查是否为表达式
                    if any(op in right for op in ['+', '-', '*', '/']):
                        # 检查是否是公共子表达式且可用
                        if right in available:
                            # 使用已有的临时变量
                            optimized_block.append(f"{left} = {expr_to_temp[right]}")
                        else:
                            # 生成新的临时变量
                            temp = self.new_temp()
                            expr_to_temp[right] = temp
                            optimized_block.append(f"{temp} = {right}")
                            optimized_block.append(f"{left} = {temp}")
                            # 将新表达式添加到可用集合
                            available.add(right)
                    else:
                        # 简单赋值,直接添加
                        optimized_block.append(instr)
                        # 如果是变量赋值,更新可用集合
                        if left in right:
                            # 杀死所有引用了left的表达式
                            to_remove = set()
                            for expr in available:
                                if left in expr:
                                    to_remove.add(expr)
                            available -= to_remove
                else:
                    # 非赋值语句,直接添加
                    optimized_block.append(instr)
            
            optimized_cfg[block_name] = optimized_block
        
        return optimized_cfg

# 测试示例
cfg = {
    "B1": ["x = a + b", "y = c * d"],
    "B2": ["z = a + b", "w = y + z"],
    "B3": ["a = 5", "v = c * d"]
}

global_cse = GlobalCSE(cfg)
optimized_cfg = global_cse.eliminate()

print("全局公共子表达式消除结果:")
for block_name, block in optimized_cfg.items():
    print(f"\n{block_name}:")
    for line in block:
        print(f"  {line}")

技术要点总结

  1. 公共子表达式消除是一种重要的代码优化技术,可以减少程序中的重复计算,提高执行效率。

  2. 局部CSE通过构建DAG来识别和消除单个基本块内的公共子表达式,实现简单且效果明显。

  3. 全局CSE需要使用可用表达式分析来识别跨基本块的公共子表达式,实现较为复杂但效果更全面。

  4. 可用表达式分析是一种数据流分析,通过计算每个基本块的IN和OUT集合,确定在程序每个点上哪些表达式是可用的。

  5. 实现考虑

    • 需要准确识别表达式,包括处理运算符优先级和括号
    • 需要高效的哈希和比较机制来识别相同的表达式
    • 需要合理选择插入临时变量的位置,避免引入新的依赖关系
  6. 优化效果

    • 减少计算次数,提高程序执行速度
    • 可能会增加代码大小(由于引入临时变量),但通常这种权衡是值得的
    • 对于计算密集型程序,效果尤为明显

通过公共子表达式消除技术,编译器可以自动识别和优化程序中的重复计算,使程序运行得更快、更高效。这是代码优化中一项基础而重要的技术,也是现代编译器必备的优化手段之一。

« 上一篇 死代码消除 下一篇 » 复写传播