第177集:公共子表达式消除
核心知识点讲解
什么是公共子表达式?
公共子表达式(Common Subexpression)是指在程序中多次出现的、计算结果相同的表达式。例如:
x = a + b * c;
y = d + a + b * c;这里的 a + b * c 就是一个公共子表达式,它在两个不同的赋值语句中出现。如果我们能够识别并消除这种重复计算,就可以提高程序的执行效率。
公共子表达式消除的类型
- 局部公共子表达式消除(Local CSE):在单个基本块内识别和消除公共子表达式
- 全局公共子表达式消除(Global CSE):跨多个基本块识别和消除公共子表达式
局部公共子表达式消除
局部CSE的实现相对简单,主要通过以下步骤:
- 构建基本块的DAG(有向无环图):将基本块中的表达式表示为DAG节点
- 识别公共子表达式:在DAG中,相同的子表达式会被表示为同一个节点
- 重构代码:从DAG生成优化后的代码,确保每个公共子表达式只计算一次
全局公共子表达式消除
全局CSE需要考虑整个程序的控制流,使用数据流分析技术来识别在不同基本块中出现的公共子表达式。主要涉及:
- 可用表达式分析:确定在程序的每个点上哪些表达式的值是可用的(即已经计算过且值未改变)
- 表达式哈希:为每个表达式计算哈希值,以便快速识别相同的表达式
- 插入计算点:在适当的位置插入公共子表达式的计算,并在其他位置引用该计算结果
可用表达式分析
可用表达式分析是一种数据流分析,用于确定在程序的每个点上哪些表达式是可用的。它基于以下两个集合:
- **IN[B]**:进入基本块B时可用的表达式集合
- **OUT[B]**:离开基本块B时可用的表达式集合
分析规则如下:
- 边界条件:对于程序的第一个基本块,IN[B]为空集
- 传递函数:OUT[B] = gen[B] ∪ (IN[B] - kill[B]),其中:
- gen[B]:基本块B中生成的新可用表达式
- kill[B]:基本块B中杀死的可用表达式(即修改了表达式中变量的值)
- 交汇运算:对于有多个前驱的基本块B,IN[B]是所有前驱基本块的OUT集合的交集
实用案例分析
局部公共子表达式消除示例
原始代码
x = (a + b) * c;
y = (a + b) * d;
z = x + y;优化过程
构建DAG:
- 节点1:a + b
- 节点2:节点1 * c (x)
- 节点3:节点1 * d (y)
- 节点4:节点2 + 节点3 (z)
生成优化后代码:
t1 = a + b;
x = t1 * c;
y = t1 * d;
z = x + y;全局公共子表达式消除示例
原始代码
if (condition) {
x = a + b * c;
// 其他操作,不修改a、b、c
} else {
y = d + e;
// 其他操作,不修改a、b、c
}
z = a + b * c + f;优化过程
可用表达式分析:
- 在if分支中计算了
a + b * c - 在else分支中没有修改a、b、c
- 因此,在if和else分支之后,
a + b * c都是可用的
- 在if分支中计算了
生成优化后代码:
t1 = a + b * c;
if (condition) {
x = t1;
// 其他操作
} else {
y = d + e;
// 其他操作
}
z = t1 + f;代码实现
局部公共子表达式消除实现
class LocalCSE:
def __init__(self):
self.temp_count = 0
def new_temp(self):
"""生成新的临时变量"""
self.temp_count += 1
return f"t{self.temp_count}"
def eliminate(self, basic_block):
"""对基本块执行局部公共子表达式消除"""
# 表达式到临时变量的映射
expr_map = {}
optimized_code = []
for instr in basic_block:
if '=' in instr:
left, right = instr.split('=', 1)
left = left.strip()
right = right.strip()
# 检查是否为表达式
if '+' in right or '-' in right or '*' in right or '/' in right:
# 检查是否是公共子表达式
if right in expr_map:
# 使用已有的临时变量
optimized_code.append(f"{left} = {expr_map[right]}")
else:
# 生成新的临时变量
temp = self.new_temp()
expr_map[right] = temp
optimized_code.append(f"{temp} = {right}")
optimized_code.append(f"{left} = {temp}")
else:
# 简单赋值,直接添加
optimized_code.append(instr)
else:
# 非赋值语句,直接添加
optimized_code.append(instr)
return optimized_code
# 测试示例
basic_block = [
"x = a + b * c",
"y = d + a + b * c",
"z = x + y"
]
cse = LocalCSE()
optimized = cse.eliminate(basic_block)
print("原始代码:")
for line in basic_block:
print(f" {line}")
print("\n优化后代码:")
for line in optimized:
print(f" {line}")可用表达式分析实现
class AvailableExpressions:
def __init__(self, cfg):
self.cfg = cfg # 控制流图
self.expressions = set() # 程序中所有的表达式
self.gen = {} # 每个基本块生成的表达式
self.kill = {} # 每个基本块杀死的表达式
self.in_set = {} # 进入基本块时的可用表达式
self.out_set = {} # 离开基本块时的可用表达式
self._initialize()
def _extract_expressions(self, block):
"""从基本块中提取表达式"""
exprs = set()
vars = set()
for instr in block:
if '=' in instr:
left, right = instr.split('=', 1)
left = left.strip()
right = right.strip()
# 记录被修改的变量
vars.add(left)
# 提取表达式
if any(op in right for op in ['+', '-', '*', '/']):
exprs.add(right)
return exprs, vars
def _initialize(self):
"""初始化gen和kill集合"""
all_exprs = set()
all_vars = set()
# 首先收集所有表达式和变量
for block_name, block in self.cfg.items():
exprs, vars = self._extract_expressions(block)
all_exprs.update(exprs)
all_vars.update(vars)
self.expressions = all_exprs
# 初始化gen和kill
for block_name, block in self.cfg.items():
block_exprs, block_vars = self._extract_expressions(block)
self.gen[block_name] = block_exprs
# kill集合包含所有引用了被修改变量的表达式
kill_set = set()
for expr in all_exprs:
if any(var in expr for var in block_vars):
kill_set.add(expr)
self.kill[block_name] = kill_set
# 初始化in和out集合
self.in_set[block_name] = set()
self.out_set[block_name] = set()
def analyze(self):
"""执行可用表达式分析"""
changed = True
while changed:
changed = False
for block_name in self.cfg:
# 计算新的out集合
new_out = self.gen[block_name].union(self.in_set[block_name] - self.kill[block_name])
if new_out != self.out_set[block_name]:
self.out_set[block_name] = new_out
changed = True
# 计算新的in集合(处理前驱)
predecessors = self._get_predecessors(block_name)
if not predecessors:
# 入口块
new_in = set()
else:
# 交汇运算:所有前驱的out集合的交集
new_in = self.out_set[predecessors[0]].copy()
for pred in predecessors[1:]:
new_in.intersection_update(self.out_set[pred])
if new_in != self.in_set[block_name]:
self.in_set[block_name] = new_in
changed = True
def _get_predecessors(self, block_name):
"""获取基本块的前驱"""
# 简化实现,实际需要根据CFG结构确定
# 这里假设每个基本块的前驱是前一个编号的块
predecessors = []
block_num = int(block_name[1:]) # 假设块名格式为B1, B2等
if block_num > 1:
predecessors.append(f"B{block_num - 1}")
return predecessors
# 测试示例
cfg = {
"B1": ["x = a + b", "y = c * d"],
"B2": ["z = a + b", "w = y + z"],
"B3": ["a = 5", "v = c * d"]
}
ae = AvailableExpressions(cfg)
ae.analyze()
print("可用表达式分析结果:")
for block_name in cfg:
print(f"\n{block_name}:")
print(f" IN: {ae.in_set[block_name]}")
print(f" OUT: {ae.out_set[block_name]}")
print(f" GEN: {ae.gen[block_name]}")
print(f" KILL: {ae.kill[block_name]}")全局公共子表达式消除实现
class GlobalCSE:
def __init__(self, cfg):
self.cfg = cfg
self.ae = AvailableExpressions(cfg)
self.ae.analyze()
self.temp_count = 0
def new_temp(self):
"""生成新的临时变量"""
self.temp_count += 1
return f"t{self.temp_count}"
def eliminate(self):
"""执行全局公共子表达式消除"""
optimized_cfg = {}
expr_to_temp = {} # 表达式到临时变量的映射
for block_name, block in self.cfg.items():
optimized_block = []
available = self.ae.in_set[block_name].copy()
for instr in block:
if '=' in instr:
left, right = instr.split('=', 1)
left = left.strip()
right = right.strip()
# 检查是否为表达式
if any(op in right for op in ['+', '-', '*', '/']):
# 检查是否是公共子表达式且可用
if right in available:
# 使用已有的临时变量
optimized_block.append(f"{left} = {expr_to_temp[right]}")
else:
# 生成新的临时变量
temp = self.new_temp()
expr_to_temp[right] = temp
optimized_block.append(f"{temp} = {right}")
optimized_block.append(f"{left} = {temp}")
# 将新表达式添加到可用集合
available.add(right)
else:
# 简单赋值,直接添加
optimized_block.append(instr)
# 如果是变量赋值,更新可用集合
if left in right:
# 杀死所有引用了left的表达式
to_remove = set()
for expr in available:
if left in expr:
to_remove.add(expr)
available -= to_remove
else:
# 非赋值语句,直接添加
optimized_block.append(instr)
optimized_cfg[block_name] = optimized_block
return optimized_cfg
# 测试示例
cfg = {
"B1": ["x = a + b", "y = c * d"],
"B2": ["z = a + b", "w = y + z"],
"B3": ["a = 5", "v = c * d"]
}
global_cse = GlobalCSE(cfg)
optimized_cfg = global_cse.eliminate()
print("全局公共子表达式消除结果:")
for block_name, block in optimized_cfg.items():
print(f"\n{block_name}:")
for line in block:
print(f" {line}")技术要点总结
公共子表达式消除是一种重要的代码优化技术,可以减少程序中的重复计算,提高执行效率。
局部CSE通过构建DAG来识别和消除单个基本块内的公共子表达式,实现简单且效果明显。
全局CSE需要使用可用表达式分析来识别跨基本块的公共子表达式,实现较为复杂但效果更全面。
可用表达式分析是一种数据流分析,通过计算每个基本块的IN和OUT集合,确定在程序每个点上哪些表达式是可用的。
实现考虑:
- 需要准确识别表达式,包括处理运算符优先级和括号
- 需要高效的哈希和比较机制来识别相同的表达式
- 需要合理选择插入临时变量的位置,避免引入新的依赖关系
优化效果:
- 减少计算次数,提高程序执行速度
- 可能会增加代码大小(由于引入临时变量),但通常这种权衡是值得的
- 对于计算密集型程序,效果尤为明显
通过公共子表达式消除技术,编译器可以自动识别和优化程序中的重复计算,使程序运行得更快、更高效。这是代码优化中一项基础而重要的技术,也是现代编译器必备的优化手段之一。