第178集:复写传播

核心知识点讲解

什么是复写传播?

复写传播(Copy Propagation)是一种代码优化技术,用于消除程序中的临时变量引用,将变量的使用替换为其值的直接引用。例如:

t = x;
y = t + 1;

通过复写传播,可以优化为:

y = x + 1;

这里,我们将对变量t的引用替换为其值x,因为t只是x的一个副本,没有其他用途。

复写传播的类型

  1. 局部复写传播:在单个基本块内进行的复写传播
  2. 全局复写传播:跨多个基本块进行的复写传播

复写传播的目的

  1. 消除不必要的变量:减少程序中的临时变量数量
  2. 为其他优化创造机会:如死代码消除、公共子表达式消除等
  3. 提高程序执行效率:减少内存访问和寄存器使用
  4. 简化代码:使代码更易于理解和分析

复写传播的工作原理

复写传播基于以下观察:当程序中存在形如t = x的赋值语句,且t在后续代码中没有被修改时,所有对t的引用都可以替换为对x的引用。

复写传播的实现通常包括以下步骤:

  1. 识别复写语句:找出程序中的复写语句(如t = x
  2. 跟踪复写关系:记录变量之间的复写关系
  3. 传播复写:将变量的引用替换为其值的引用
  4. 清理无用变量:删除不再使用的变量定义

复写传播的数据流分析

全局复写传播需要使用数据流分析来跟踪变量之间的复写关系。主要涉及:

  1. 到达定值分析:确定在程序的每个点上,哪些变量的定值是到达的
  2. 复写关系分析:跟踪变量之间的复写关系
  3. 活跃变量分析:确定哪些变量在后续代码中仍然被使用

实用案例分析

局部复写传播示例

原始代码

t1 = a + b;
t2 = t1;
t3 = t2 * c;
y = t3;

优化过程

  1. 识别复写语句t2 = t1
  2. 传播复写:将t3 = t2 * c中的t2替换为t1,得到t3 = t1 * c
  3. 进一步优化:识别t3 = t1 * cy = t3,将y = t3替换为y = t1 * c
  4. 清理无用变量:删除t2t3的定义

优化后代码

t1 = a + b;
y = t1 * c;

全局复写传播示例

原始代码

if (condition) {
    t = x;
    // 其他操作,不修改t和x
    y = t + 1;
} else {
    t = x;
    // 其他操作,不修改t和x
    y = t + 2;
}

优化过程

  1. 识别复写语句:两个分支中的t = x
  2. 传播复写:在两个分支中,将对t的引用替换为对x的引用
  3. 清理无用变量:删除两个分支中的t = x语句

优化后代码

if (condition) {
    // 其他操作,不修改x
    y = x + 1;
} else {
    // 其他操作,不修改x
    y = x + 2;
}

复写传播与其他优化的结合

与死代码消除结合

// 原始代码
x = 5;
t = x;
y = t + 1;

// 复写传播后
x = 5;
y = x + 1;

// 死代码消除后(如果x不再被使用)
y = 5 + 1;
y = 6;

与公共子表达式消除结合

// 原始代码
a = b + c;
t = a;
d = t + e;
e = b + c;
f = e + g;

// 公共子表达式消除后
a = b + c;
t = a;
d = t + e;
e = a;
f = e + g;

// 复写传播后
a = b + c;
d = a + e;
f = a + g;

代码实现

局部复写传播实现

class LocalCopyPropagation:
    def __init__(self):
        self.copy_map = {}  # 变量到其值的映射
    
    def propagate(self, basic_block):
        """对基本块执行局部复写传播"""
        optimized_code = []
        copy_map = {}
        used_vars = set()
        
        # 第一遍:收集所有使用的变量
        for instr in basic_block:
            if '=' in instr:
                left, right = instr.split('=', 1)
                left = left.strip()
                right = right.strip()
                
                # 分析右侧使用的变量
                for var in right.split():
                    var = var.strip('+*-/()')
                    if var.isidentifier():
                        used_vars.add(var)
            else:
                # 分析非赋值语句中的变量使用
                for var in instr.split():
                    var = var.strip('+*-/()')
                    if var.isidentifier():
                        used_vars.add(var)
        
        # 第二遍:执行复写传播
        for instr in basic_block:
            if '=' in instr:
                left, right = instr.split('=', 1)
                left = left.strip()
                right = right.strip()
                
                # 检查是否是复写语句(形如 t = x)
                if right.isidentifier():
                    # 记录复写关系
                    copy_map[left] = right
                    # 检查这个复写是否有用
                    if left not in used_vars:
                        # 如果变量没有被使用,跳过这个赋值
                        continue
                else:
                    # 不是复写语句,尝试传播复写
                    new_right = right
                    for var, val in copy_map.items():
                        # 替换变量引用
                        new_right = new_right.replace(var, val)
                    
                    # 检查右侧是否发生了变化
                    if new_right != right:
                        # 添加优化后的语句
                        optimized_code.append(f"{left} = {new_right}")
                    else:
                        # 添加原始语句
                        optimized_code.append(instr)
                    
                    # 如果左侧变量在copy_map中,删除该映射
                    if left in copy_map:
                        del copy_map[left]
            else:
                # 非赋值语句,尝试传播复写
                new_instr = instr
                for var, val in copy_map.items():
                    # 替换变量引用
                    new_instr = new_instr.replace(var, val)
                
                # 添加优化后的语句
                optimized_code.append(new_instr)
        
        return optimized_code

# 测试示例
basic_block = [
    "t1 = a + b",
    "t2 = t1",
    "t3 = t2 * c",
    "y = t3",
    "z = t1 + t3"
]

lcp = LocalCopyPropagation()
optimized = lcp.propagate(basic_block)
print("原始代码:")
for line in basic_block:
    print(f"  {line}")
print("\n优化后代码:")
for line in optimized:
    print(f"  {line}")

全局复写传播实现

class GlobalCopyPropagation:
    def __init__(self, cfg):
        self.cfg = cfg  # 控制流图
        self.copy_map = {}  # 每个基本块的复写映射
        self.in_copy = {}  # 进入基本块时的复写映射
        self.out_copy = {}  # 离开基本块时的复写映射
        self._initialize()
    
    def _initialize(self):
        """初始化复写映射"""
        for block_name in self.cfg:
            self.copy_map[block_name] = {}
            self.in_copy[block_name] = {}
            self.out_copy[block_name] = {}
    
    def _extract_copies(self, block):
        """从基本块中提取复写语句"""
        copies = {}
        kill = set()
        
        for instr in block:
            if '=' in instr:
                left, right = instr.split('=', 1)
                left = left.strip()
                right = right.strip()
                
                # 记录被修改的变量
                kill.add(left)
                
                # 检查是否是复写语句
                if right.isidentifier():
                    copies[left] = right
        
        return copies, kill
    
    def analyze(self):
        """执行全局复写传播分析"""
        changed = True
        
        while changed:
            changed = False
            
            for block_name in self.cfg:
                # 提取当前块的复写语句和杀死的变量
                block_copies, block_kill = self._extract_copies(self.cfg[block_name])
                
                # 计算新的in_copy:合并所有前驱的out_copy
                new_in_copy = {}
                predecessors = self._get_predecessors(block_name)
                
                if predecessors:
                    # 初始化新的in_copy为第一个前驱的out_copy
                    new_in_copy.update(self.out_copy[predecessors[0]])
                    # 与其他前驱的out_copy取交集
                    for pred in predecessors[1:]:
                        pred_out = self.out_copy[pred]
                        # 只保留在所有前驱中都存在的复写关系
                        common_vars = set(new_in_copy.keys()) & set(pred_out.keys())
                        temp_copy = {}
                        for var in common_vars:
                            if new_in_copy[var] == pred_out[var]:
                                temp_copy[var] = new_in_copy[var]
                        new_in_copy = temp_copy
                
                # 应用当前块的kill操作
                for var in block_kill:
                    if var in new_in_copy:
                        del new_in_copy[var]
                
                # 应用当前块的复写语句
                new_out_copy = new_in_copy.copy()
                new_out_copy.update(block_copies)
                
                # 检查是否发生变化
                if new_in_copy != self.in_copy[block_name]:
                    self.in_copy[block_name] = new_in_copy
                    changed = True
                
                if new_out_copy != self.out_copy[block_name]:
                    self.out_copy[block_name] = new_out_copy
                    changed = True
    
    def _get_predecessors(self, block_name):
        """获取基本块的前驱"""
        # 简化实现,实际需要根据CFG结构确定
        # 这里假设每个基本块的前驱是前一个编号的块
        predecessors = []
        block_num = int(block_name[1:])  # 假设块名格式为B1, B2等
        if block_num > 1:
            predecessors.append(f"B{block_num - 1}")
        return predecessors
    
    def propagate(self):
        """执行全局复写传播"""
        # 首先执行分析
        self.analyze()
        
        optimized_cfg = {}
        
        for block_name, block in self.cfg.items():
            optimized_block = []
            current_copy = self.in_copy[block_name].copy()
            
            for instr in block:
                if '=' in instr:
                    left, right = instr.split('=', 1)
                    left = left.strip()
                    right = right.strip()
                    
                    # 尝试传播复写
                    new_right = right
                    for var, val in current_copy.items():
                        new_right = new_right.replace(var, val)
                    
                    # 检查是否是复写语句
                    if new_right.isidentifier():
                        # 更新当前复写映射
                        current_copy[left] = new_right
                        # 检查这个复写是否有用(简化实现,实际需要更复杂的分析)
                        optimized_block.append(f"{left} = {new_right}")
                    else:
                        # 添加优化后的语句
                        optimized_block.append(f"{left} = {new_right}")
                        # 如果左侧变量在current_copy中,删除该映射
                        if left in current_copy:
                            del current_copy[left]
                else:
                    # 非赋值语句,尝试传播复写
                    new_instr = instr
                    for var, val in current_copy.items():
                        new_instr = new_instr.replace(var, val)
                    
                    # 添加优化后的语句
                    optimized_block.append(new_instr)
            
            optimized_cfg[block_name] = optimized_block
        
        return optimized_cfg

# 测试示例
cfg = {
    "B1": ["t = x", "y = t + 1"],
    "B2": ["t = x", "y = t + 2"],
    "B3": ["z = y", "w = t + z"]
}

gcp = GlobalCopyPropagation(cfg)
optimized_cfg = gcp.propagate()

print("全局复写传播结果:")
for block_name, block in optimized_cfg.items():
    print(f"\n{block_name}:")
    for line in block:
        print(f"  {line}")

复写传播与死代码消除结合

class CopyPropagationWithDCE:
    def __init__(self):
        self.lcp = LocalCopyPropagation()
    
    def _find_used_vars(self, code):
        """找出代码中使用的变量"""
        used_vars = set()
        
        for instr in code:
            if '=' in instr:
                left, right = instr.split('=', 1)
                right = right.strip()
                
                # 分析右侧使用的变量
                for var in right.split():
                    var = var.strip('+*-/()')
                    if var.isidentifier():
                        used_vars.add(var)
            else:
                # 分析非赋值语句中的变量使用
                for var in instr.split():
                    var = var.strip('+*-/()')
                    if var.isidentifier():
                        used_vars.add(var)
        
        return used_vars
    
    def _eliminate_dead_code(self, code):
        """消除死代码"""
        used_vars = self._find_used_vars(code)
        optimized_code = []
        
        for instr in code:
            if '=' in instr:
                left, _ = instr.split('=', 1)
                left = left.strip()
                
                # 检查变量是否被使用
                if left in used_vars:
                    optimized_code.append(instr)
                else:
                    # 变量未被使用,删除这个赋值
                    pass
            else:
                optimized_code.append(instr)
        
        return optimized_code
    
    def optimize(self, code):
        """执行复写传播和死代码消除"""
        # 首先执行复写传播
        code_after_cp = self.lcp.propagate(code)
        # 然后执行死代码消除
        code_after_dce = self._eliminate_dead_code(code_after_cp)
        # 再次执行复写传播(可能会有新的机会)
        final_code = self.lcp.propagate(code_after_dce)
        
        return final_code

# 测试示例
code = [
    "x = 5",
    "t = x",
    "y = t + 1",
    "z = y",
    "w = z * 2"
]

cpdce = CopyPropagationWithDCE()
optimized = cpdce.optimize(code)
print("原始代码:")
for line in code:
    print(f"  {line}")
print("\n优化后代码:")
for line in optimized:
    print(f"  {line}")

技术要点总结

  1. 复写传播是一种重要的代码优化技术,可以消除程序中的临时变量引用,提高程序执行效率。

  2. 复写传播的核心是识别形如t = x的复写语句,并将对t的引用替换为对x的引用。

  3. 复写传播的类型包括局部复写传播和全局复写传播,后者需要使用数据流分析技术。

  4. 复写传播与其他优化的结合

    • 与死代码消除结合:可以删除不再使用的临时变量定义
    • 与公共子表达式消除结合:可以创造更多的公共子表达式消除机会
    • 与常量折叠结合:可以将复写的常量值直接传播到使用点
  5. 实现考虑

    • 需要准确识别复写语句
    • 需要跟踪变量之间的复写关系
    • 需要处理变量被修改的情况
    • 需要考虑控制流的影响(全局复写传播)
  6. 优化效果

    • 减少程序中的临时变量数量
    • 提高程序执行速度
    • 为其他优化创造机会
    • 简化代码结构
  7. 局限性

    • 复写传播可能会增加代码大小(如果传播的表达式较长)
    • 全局复写传播的实现较为复杂
    • 对于复杂的程序,复写传播的效果可能有限

通过复写传播技术,编译器可以自动识别和优化程序中的临时变量引用,使程序运行得更快、更高效。这是代码优化中一项基础而重要的技术,也是现代编译器必备的优化手段之一。

« 上一篇 公共子表达式消除 下一篇 » 循环优化概述