第178集:复写传播
核心知识点讲解
什么是复写传播?
复写传播(Copy Propagation)是一种代码优化技术,用于消除程序中的临时变量引用,将变量的使用替换为其值的直接引用。例如:
t = x;
y = t + 1;通过复写传播,可以优化为:
y = x + 1;这里,我们将对变量t的引用替换为其值x,因为t只是x的一个副本,没有其他用途。
复写传播的类型
- 局部复写传播:在单个基本块内进行的复写传播
- 全局复写传播:跨多个基本块进行的复写传播
复写传播的目的
- 消除不必要的变量:减少程序中的临时变量数量
- 为其他优化创造机会:如死代码消除、公共子表达式消除等
- 提高程序执行效率:减少内存访问和寄存器使用
- 简化代码:使代码更易于理解和分析
复写传播的工作原理
复写传播基于以下观察:当程序中存在形如t = x的赋值语句,且t在后续代码中没有被修改时,所有对t的引用都可以替换为对x的引用。
复写传播的实现通常包括以下步骤:
- 识别复写语句:找出程序中的复写语句(如
t = x) - 跟踪复写关系:记录变量之间的复写关系
- 传播复写:将变量的引用替换为其值的引用
- 清理无用变量:删除不再使用的变量定义
复写传播的数据流分析
全局复写传播需要使用数据流分析来跟踪变量之间的复写关系。主要涉及:
- 到达定值分析:确定在程序的每个点上,哪些变量的定值是到达的
- 复写关系分析:跟踪变量之间的复写关系
- 活跃变量分析:确定哪些变量在后续代码中仍然被使用
实用案例分析
局部复写传播示例
原始代码
t1 = a + b;
t2 = t1;
t3 = t2 * c;
y = t3;优化过程
- 识别复写语句:
t2 = t1 - 传播复写:将
t3 = t2 * c中的t2替换为t1,得到t3 = t1 * c - 进一步优化:识别
t3 = t1 * c和y = t3,将y = t3替换为y = t1 * c - 清理无用变量:删除
t2和t3的定义
优化后代码
t1 = a + b;
y = t1 * c;全局复写传播示例
原始代码
if (condition) {
t = x;
// 其他操作,不修改t和x
y = t + 1;
} else {
t = x;
// 其他操作,不修改t和x
y = t + 2;
}优化过程
- 识别复写语句:两个分支中的
t = x - 传播复写:在两个分支中,将对
t的引用替换为对x的引用 - 清理无用变量:删除两个分支中的
t = x语句
优化后代码
if (condition) {
// 其他操作,不修改x
y = x + 1;
} else {
// 其他操作,不修改x
y = x + 2;
}复写传播与其他优化的结合
与死代码消除结合
// 原始代码
x = 5;
t = x;
y = t + 1;
// 复写传播后
x = 5;
y = x + 1;
// 死代码消除后(如果x不再被使用)
y = 5 + 1;
y = 6;与公共子表达式消除结合
// 原始代码
a = b + c;
t = a;
d = t + e;
e = b + c;
f = e + g;
// 公共子表达式消除后
a = b + c;
t = a;
d = t + e;
e = a;
f = e + g;
// 复写传播后
a = b + c;
d = a + e;
f = a + g;代码实现
局部复写传播实现
class LocalCopyPropagation:
def __init__(self):
self.copy_map = {} # 变量到其值的映射
def propagate(self, basic_block):
"""对基本块执行局部复写传播"""
optimized_code = []
copy_map = {}
used_vars = set()
# 第一遍:收集所有使用的变量
for instr in basic_block:
if '=' in instr:
left, right = instr.split('=', 1)
left = left.strip()
right = right.strip()
# 分析右侧使用的变量
for var in right.split():
var = var.strip('+*-/()')
if var.isidentifier():
used_vars.add(var)
else:
# 分析非赋值语句中的变量使用
for var in instr.split():
var = var.strip('+*-/()')
if var.isidentifier():
used_vars.add(var)
# 第二遍:执行复写传播
for instr in basic_block:
if '=' in instr:
left, right = instr.split('=', 1)
left = left.strip()
right = right.strip()
# 检查是否是复写语句(形如 t = x)
if right.isidentifier():
# 记录复写关系
copy_map[left] = right
# 检查这个复写是否有用
if left not in used_vars:
# 如果变量没有被使用,跳过这个赋值
continue
else:
# 不是复写语句,尝试传播复写
new_right = right
for var, val in copy_map.items():
# 替换变量引用
new_right = new_right.replace(var, val)
# 检查右侧是否发生了变化
if new_right != right:
# 添加优化后的语句
optimized_code.append(f"{left} = {new_right}")
else:
# 添加原始语句
optimized_code.append(instr)
# 如果左侧变量在copy_map中,删除该映射
if left in copy_map:
del copy_map[left]
else:
# 非赋值语句,尝试传播复写
new_instr = instr
for var, val in copy_map.items():
# 替换变量引用
new_instr = new_instr.replace(var, val)
# 添加优化后的语句
optimized_code.append(new_instr)
return optimized_code
# 测试示例
basic_block = [
"t1 = a + b",
"t2 = t1",
"t3 = t2 * c",
"y = t3",
"z = t1 + t3"
]
lcp = LocalCopyPropagation()
optimized = lcp.propagate(basic_block)
print("原始代码:")
for line in basic_block:
print(f" {line}")
print("\n优化后代码:")
for line in optimized:
print(f" {line}")全局复写传播实现
class GlobalCopyPropagation:
def __init__(self, cfg):
self.cfg = cfg # 控制流图
self.copy_map = {} # 每个基本块的复写映射
self.in_copy = {} # 进入基本块时的复写映射
self.out_copy = {} # 离开基本块时的复写映射
self._initialize()
def _initialize(self):
"""初始化复写映射"""
for block_name in self.cfg:
self.copy_map[block_name] = {}
self.in_copy[block_name] = {}
self.out_copy[block_name] = {}
def _extract_copies(self, block):
"""从基本块中提取复写语句"""
copies = {}
kill = set()
for instr in block:
if '=' in instr:
left, right = instr.split('=', 1)
left = left.strip()
right = right.strip()
# 记录被修改的变量
kill.add(left)
# 检查是否是复写语句
if right.isidentifier():
copies[left] = right
return copies, kill
def analyze(self):
"""执行全局复写传播分析"""
changed = True
while changed:
changed = False
for block_name in self.cfg:
# 提取当前块的复写语句和杀死的变量
block_copies, block_kill = self._extract_copies(self.cfg[block_name])
# 计算新的in_copy:合并所有前驱的out_copy
new_in_copy = {}
predecessors = self._get_predecessors(block_name)
if predecessors:
# 初始化新的in_copy为第一个前驱的out_copy
new_in_copy.update(self.out_copy[predecessors[0]])
# 与其他前驱的out_copy取交集
for pred in predecessors[1:]:
pred_out = self.out_copy[pred]
# 只保留在所有前驱中都存在的复写关系
common_vars = set(new_in_copy.keys()) & set(pred_out.keys())
temp_copy = {}
for var in common_vars:
if new_in_copy[var] == pred_out[var]:
temp_copy[var] = new_in_copy[var]
new_in_copy = temp_copy
# 应用当前块的kill操作
for var in block_kill:
if var in new_in_copy:
del new_in_copy[var]
# 应用当前块的复写语句
new_out_copy = new_in_copy.copy()
new_out_copy.update(block_copies)
# 检查是否发生变化
if new_in_copy != self.in_copy[block_name]:
self.in_copy[block_name] = new_in_copy
changed = True
if new_out_copy != self.out_copy[block_name]:
self.out_copy[block_name] = new_out_copy
changed = True
def _get_predecessors(self, block_name):
"""获取基本块的前驱"""
# 简化实现,实际需要根据CFG结构确定
# 这里假设每个基本块的前驱是前一个编号的块
predecessors = []
block_num = int(block_name[1:]) # 假设块名格式为B1, B2等
if block_num > 1:
predecessors.append(f"B{block_num - 1}")
return predecessors
def propagate(self):
"""执行全局复写传播"""
# 首先执行分析
self.analyze()
optimized_cfg = {}
for block_name, block in self.cfg.items():
optimized_block = []
current_copy = self.in_copy[block_name].copy()
for instr in block:
if '=' in instr:
left, right = instr.split('=', 1)
left = left.strip()
right = right.strip()
# 尝试传播复写
new_right = right
for var, val in current_copy.items():
new_right = new_right.replace(var, val)
# 检查是否是复写语句
if new_right.isidentifier():
# 更新当前复写映射
current_copy[left] = new_right
# 检查这个复写是否有用(简化实现,实际需要更复杂的分析)
optimized_block.append(f"{left} = {new_right}")
else:
# 添加优化后的语句
optimized_block.append(f"{left} = {new_right}")
# 如果左侧变量在current_copy中,删除该映射
if left in current_copy:
del current_copy[left]
else:
# 非赋值语句,尝试传播复写
new_instr = instr
for var, val in current_copy.items():
new_instr = new_instr.replace(var, val)
# 添加优化后的语句
optimized_block.append(new_instr)
optimized_cfg[block_name] = optimized_block
return optimized_cfg
# 测试示例
cfg = {
"B1": ["t = x", "y = t + 1"],
"B2": ["t = x", "y = t + 2"],
"B3": ["z = y", "w = t + z"]
}
gcp = GlobalCopyPropagation(cfg)
optimized_cfg = gcp.propagate()
print("全局复写传播结果:")
for block_name, block in optimized_cfg.items():
print(f"\n{block_name}:")
for line in block:
print(f" {line}")复写传播与死代码消除结合
class CopyPropagationWithDCE:
def __init__(self):
self.lcp = LocalCopyPropagation()
def _find_used_vars(self, code):
"""找出代码中使用的变量"""
used_vars = set()
for instr in code:
if '=' in instr:
left, right = instr.split('=', 1)
right = right.strip()
# 分析右侧使用的变量
for var in right.split():
var = var.strip('+*-/()')
if var.isidentifier():
used_vars.add(var)
else:
# 分析非赋值语句中的变量使用
for var in instr.split():
var = var.strip('+*-/()')
if var.isidentifier():
used_vars.add(var)
return used_vars
def _eliminate_dead_code(self, code):
"""消除死代码"""
used_vars = self._find_used_vars(code)
optimized_code = []
for instr in code:
if '=' in instr:
left, _ = instr.split('=', 1)
left = left.strip()
# 检查变量是否被使用
if left in used_vars:
optimized_code.append(instr)
else:
# 变量未被使用,删除这个赋值
pass
else:
optimized_code.append(instr)
return optimized_code
def optimize(self, code):
"""执行复写传播和死代码消除"""
# 首先执行复写传播
code_after_cp = self.lcp.propagate(code)
# 然后执行死代码消除
code_after_dce = self._eliminate_dead_code(code_after_cp)
# 再次执行复写传播(可能会有新的机会)
final_code = self.lcp.propagate(code_after_dce)
return final_code
# 测试示例
code = [
"x = 5",
"t = x",
"y = t + 1",
"z = y",
"w = z * 2"
]
cpdce = CopyPropagationWithDCE()
optimized = cpdce.optimize(code)
print("原始代码:")
for line in code:
print(f" {line}")
print("\n优化后代码:")
for line in optimized:
print(f" {line}")技术要点总结
复写传播是一种重要的代码优化技术,可以消除程序中的临时变量引用,提高程序执行效率。
复写传播的核心是识别形如
t = x的复写语句,并将对t的引用替换为对x的引用。复写传播的类型包括局部复写传播和全局复写传播,后者需要使用数据流分析技术。
复写传播与其他优化的结合:
- 与死代码消除结合:可以删除不再使用的临时变量定义
- 与公共子表达式消除结合:可以创造更多的公共子表达式消除机会
- 与常量折叠结合:可以将复写的常量值直接传播到使用点
实现考虑:
- 需要准确识别复写语句
- 需要跟踪变量之间的复写关系
- 需要处理变量被修改的情况
- 需要考虑控制流的影响(全局复写传播)
优化效果:
- 减少程序中的临时变量数量
- 提高程序执行速度
- 为其他优化创造机会
- 简化代码结构
局限性:
- 复写传播可能会增加代码大小(如果传播的表达式较长)
- 全局复写传播的实现较为复杂
- 对于复杂的程序,复写传播的效果可能有限
通过复写传播技术,编译器可以自动识别和优化程序中的临时变量引用,使程序运行得更快、更高效。这是代码优化中一项基础而重要的技术,也是现代编译器必备的优化手段之一。