中间代码生成实战(一)—— 表达式

核心知识点讲解

1. AST 到 TAC 的转换原理

将 AST 转换为三地址码的过程本质上是对 AST 的遍历过程。对于表达式,我们通常采用后序遍历的方式:

  1. 首先处理表达式的所有子节点
  2. 然后根据子节点的结果生成当前节点的三地址码

这种方法确保了在处理父节点时,所有子表达式已经被处理完毕,并且它们的结果已经存储在临时变量中。

2. 临时变量管理

在生成三地址码时,我们需要大量使用临时变量来存储中间结果。临时变量管理的关键策略:

  1. 自动生成:为每个中间结果自动生成唯一的临时变量名(如 t1, t2, ...)
  2. 重用策略:当临时变量不再被使用时,可以考虑重用它们以减少临时变量的数量
  3. 作用域管理:确保临时变量在适当的作用域内使用,避免命名冲突

3. 表达式的三地址码生成规则

不同类型的表达式有不同的三地址码生成规则:

3.1 基本表达式

  • 常量:直接使用常量值,不需要生成三地址码
  • 变量:直接使用变量名,不需要生成三地址码

3.2 二元表达式

对于二元表达式 a op b,生成规则如下:

  1. 生成 a 的三地址码,结果存储在 t1
  2. 生成 b 的三地址码,结果存储在 t2
  3. 生成 t3 = t1 op t2 的三地址码

3.3 一元表达式

对于一元表达式 op a,生成规则如下:

  1. 生成 a 的三地址码,结果存储在 t1
  2. 生成 t2 = op t1 的三地址码

4. 三地址码的表示方法

三地址码可以用不同的数据结构来表示:

  1. 四元式(op, arg1, arg2, result)
  2. 三元式(op, arg1, arg2),结果通过序号引用
  3. 间接三元式:由三元式表和执行序列表组成

在实战中,四元式因其清晰明了而被广泛使用。

实用案例分析

案例:实现一个简单的表达式编译器

我们将实现一个简单的表达式编译器,支持基本的算术运算和括号。

步骤 1:定义 AST 节点类型

首先,我们需要定义 AST 节点的类型,用于表示不同类型的表达式。

class ASTNode:
    """AST 节点基类"""
    pass

class Constant(ASTNode):
    """常量节点"""
    def __init__(self, value):
        self.value = value

class Variable(ASTNode):
    """变量节点"""
    def __init__(self, name):
        self.name = name

class BinaryOp(ASTNode):
    """二元操作符节点"""
    def __init__(self, op, left, right):
        self.op = op
        self.left = left
        self.right = right

class UnaryOp(ASTNode):
    """一元操作符节点"""
    def __init__(self, op, expr):
        self.op = op
        self.expr = expr

步骤 2:实现表达式解析器

我们需要一个简单的解析器来构建 AST。这里使用递归下降解析器。

class Parser:
    """简单表达式解析器"""
    def __init__(self, tokens):
        self.tokens = tokens
        self.pos = 0
    
    def peek(self):
        """查看当前 token"""
        if self.pos < len(self.tokens):
            return self.tokens[self.pos]
        return None
    
    def consume(self, expected_type=None):
        """消费当前 token"""
        if self.pos < len(self.tokens):
            token = self.tokens[self.pos]
            if expected_type and token[0] != expected_type:
                raise SyntaxError(f"Expected {expected_type}, got {token[0]}")
            self.pos += 1
            return token
        raise SyntaxError("Unexpected end of input")
    
    def parse(self):
        """解析表达式"""
        return self.parse_expr()
    
    def parse_expr(self):
        """解析表达式(处理加减)"""
        left = self.parse_term()
        while self.peek() and self.peek()[0] in ('+', '-'):
            op = self.consume()[0]
            right = self.parse_term()
            left = BinaryOp(op, left, right)
        return left
    
    def parse_term(self):
        """解析项(处理乘除)"""
        left = self.parse_factor()
        while self.peek() and self.peek()[0] in ('*', '/'):
            op = self.consume()[0]
            right = self.parse_factor()
            left = BinaryOp(op, left, right)
        return left
    
    def parse_factor(self):
        """解析因子(处理常量、变量和括号)"""
        token = self.peek()
        if token[0] == '(':
            self.consume('(')
            expr = self.parse_expr()
            self.consume(')')
            return expr
        elif token[0] == 'NUMBER':
            return Constant(self.consume('NUMBER')[1])
        elif token[0] == 'ID':
            return Variable(self.consume('ID')[1])
        elif token[0] in ('+', '-'):
            op = self.consume()[0]
            expr = self.parse_factor()
            return UnaryOp(op, expr)
        else:
            raise SyntaxError(f"Unexpected token: {token}")

步骤 3:实现三地址码生成器

现在,我们实现一个三地址码生成器,用于将 AST 转换为三地址码。

class TACGenerator:
    """三地址码生成器"""
    def __init__(self):
        self.temp_count = 0
        self.instructions = []
    
    def new_temp(self):
        """生成新的临时变量"""
        temp = f"t{self.temp_count}"
        self.temp_count += 1
        return temp
    
    def generate(self, node):
        """生成三地址码"""
        if isinstance(node, Constant):
            # 常量直接返回其值
            return node.value
        elif isinstance(node, Variable):
            # 变量直接返回其名
            return node.name
        elif isinstance(node, BinaryOp):
            # 处理二元操作
            left_val = self.generate(node.left)
            right_val = self.generate(node.right)
            temp = self.new_temp()
            # 生成三地址码指令
            self.instructions.append(f"{temp} = {left_val} {node.op} {right_val}")
            return temp
        elif isinstance(node, UnaryOp):
            # 处理一元操作
            expr_val = self.generate(node.expr)
            temp = self.new_temp()
            # 生成三地址码指令
            self.instructions.append(f"{temp} = {node.op} {expr_val}")
            return temp
        else:
            raise TypeError(f"Unknown AST node type: {type(node)}")
    
    def get_instructions(self):
        """获取生成的三地址码指令"""
        return self.instructions

步骤 4:测试表达式编译器

让我们测试一下我们的表达式编译器,看看它是否能正确生成三地址码。

# 简单的词法分析器
def tokenize(expr):
    """将表达式转换为 token 列表"""
    tokens = []
    i = 0
    while i < len(expr):
        c = expr[i]
        if c.isspace():
            i += 1
        elif c.isdigit():
            # 解析数字
            num = ""
            while i < len(expr) and expr[i].isdigit():
                num += expr[i]
                i += 1
            tokens.append(('NUMBER', int(num)))
        elif c.isalpha():
            # 解析标识符
            id = ""
            while i < len(expr) and expr[i].isalnum():
                id += expr[i]
                i += 1
            tokens.append(('ID', id))
        elif c in '+-*/()':
            # 解析操作符和括号
            tokens.append((c, c))
            i += 1
        else:
            raise SyntaxError(f"Unexpected character: {c}")
    return tokens

# 测试函数
def test_expression(expr):
    print(f"\nTesting expression: {expr}")
    # 词法分析
    tokens = tokenize(expr)
    print(f"Tokens: {tokens}")
    # 语法分析
    parser = Parser(tokens)
    ast = parser.parse()
    # 生成三地址码
    generator = TACGenerator()
    result = generator.generate(ast)
    instructions = generator.get_instructions()
    print("Generated TAC:")
    for i, instr in enumerate(instructions):
        print(f"  {i+1}. {instr}")
    print(f"Result: {result}")

# 测试案例
test_expression("a + b * c")
test_expression("(x + y) * (z - w)")
test_expression("-a + 5 * b")
test_expression("x * (y + z / 2)")

步骤 5:运行结果分析

运行上面的测试代码,我们会得到以下输出:

Testing expression: a + b * c
Tokens: [('ID', 'a'), ('+', '+'), ('ID', 'b'), ('*', '*'), ('ID', 'c')]
Generated TAC:
  1. t0 = b * c
  2. t1 = a + t0
Result: t1

Testing expression: (x + y) * (z - w)
Tokens: [('(', '('), ('ID', 'x'), ('+', '+'), ('ID', 'y'), (')', ')'), ('*', '*'), ('(', '('), ('ID', 'z'), ('-', '-'), ('ID', 'w'), (')', ')')]
Generated TAC:
  1. t0 = x + y
  2. t1 = z - w
  3. t2 = t0 * t1
Result: t2

Testing expression: -a + 5 * b
Tokens: [('-', '-'), ('ID', 'a'), ('+', '+'), ('NUMBER', 5), ('*', '*'), ('ID', 'b')]
Generated TAC:
  1. t0 = - a
  2. t1 = 5 * b
  3. t2 = t0 + t1
Result: t2

Testing expression: x * (y + z / 2)
Tokens: [('ID', 'x'), ('*', '*'), ('(', '('), ('ID', 'y'), ('+', '+'), ('ID', 'z'), ('/', '/'), ('NUMBER', 2), (')', ')')]
Generated TAC:
  1. t0 = z / 2
  2. t1 = y + t0
  3. t2 = x * t1
Result: t2

实用案例分析

案例:处理复杂表达式的优化

在生成三地址码时,我们可以进行一些简单的优化,以减少生成的指令数量和临时变量的使用。

优化 1:常量折叠

常量折叠是指在编译时计算常量表达式的值,而不是在运行时计算。

def optimize_constant_folding(node):
    """常量折叠优化"""
    if isinstance(node, Constant):
        return node
    elif isinstance(node, Variable):
        return node
    elif isinstance(node, BinaryOp):
        # 递归优化子节点
        left = optimize_constant_folding(node.left)
        right = optimize_constant_folding(node.right)
        # 如果左右都是常量,直接计算结果
        if isinstance(left, Constant) and isinstance(right, Constant):
            if node.op == '+':
                return Constant(left.value + right.value)
            elif node.op == '-':
                return Constant(left.value - right.value)
            elif node.op == '*':
                return Constant(left.value * right.value)
            elif node.op == '/':
                return Constant(left.value // right.value)
        return BinaryOp(node.op, left, right)
    elif isinstance(node, UnaryOp):
        # 递归优化子节点
        expr = optimize_constant_folding(node.expr)
        # 如果子节点是常量,直接计算结果
        if isinstance(expr, Constant):
            if node.op == '-':
                return Constant(-expr.value)
            elif node.op == '+':
                return expr
        return UnaryOp(node.op, expr)
    else:
        return node

优化 2:公共子表达式消除

公共子表达式消除是指识别并消除重复计算的子表达式。

def find_common_subexpressions(node, subexprs):
    """查找公共子表达式"""
    if isinstance(node, (Constant, Variable)):
        return
    elif isinstance(node, BinaryOp):
        # 递归查找子节点
        find_common_subexpressions(node.left, subexprs)
        find_common_subexpressions(node.right, subexprs)
        # 生成子表达式的字符串表示
        left_repr = repr(node.left)
        right_repr = repr(node.right)
        expr_repr = f"{left_repr} {node.op} {right_repr}"
        # 统计子表达式出现的次数
        if expr_repr in subexprs:
            subexprs[expr_repr] += 1
        else:
            subexprs[expr_repr] = 1
    elif isinstance(node, UnaryOp):
        # 递归查找子节点
        find_common_subexpressions(node.expr, subexprs)
        # 生成子表达式的字符串表示
        expr_repr = f"{node.op} {repr(node.expr)}"
        # 统计子表达式出现的次数
        if expr_repr in subexprs:
            subexprs[expr_repr] += 1
        else:
            subexprs[expr_repr] = 1

# 为了支持 repr 函数,我们需要为 AST 节点添加 __repr__ 方法
class Constant(ASTNode):
    # 之前的代码...
    def __repr__(self):
        return f"Constant({self.value})"

class Variable(ASTNode):
    # 之前的代码...
    def __repr__(self):
        return f"Variable('{self.name}')"

class BinaryOp(ASTNode):
    # 之前的代码...
    def __repr__(self):
        return f"BinaryOp('{self.op}', {repr(self.left)}, {repr(self.right)})"

class UnaryOp(ASTNode):
    # 之前的代码...
    def __repr__(self):
        return f"UnaryOp('{self.op}', {repr(self.expr)})"

案例:支持更复杂的表达式

我们可以扩展我们的编译器,使其支持更多类型的表达式,例如关系表达式和逻辑表达式。

扩展解析器以支持关系表达式

def parse_expr(self):
    """解析表达式(处理加减和关系运算)"""
    left = self.parse_term()
    while self.peek() and self.peek()[0] in ('+', '-', '<', '<=', '>', '>=', '==', '!='):
        op = self.consume()[0]
        right = self.parse_term()
        left = BinaryOp(op, left, right)
    return left

扩展三地址码生成器以支持关系表达式

def generate(self, node):
    """生成三地址码"""
    if isinstance(node, Constant):
        return node.value
    elif isinstance(node, Variable):
        return node.name
    elif isinstance(node, BinaryOp):
        left_val = self.generate(node.left)
        right_val = self.generate(node.right)
        temp = self.new_temp()
        # 生成三地址码指令
        self.instructions.append(f"{temp} = {left_val} {node.op} {right_val}")
        return temp
    elif isinstance(node, UnaryOp):
        expr_val = self.generate(node.expr)
        temp = self.new_temp()
        # 生成三地址码指令
        self.instructions.append(f"{temp} = {node.op} {expr_val}")
        return temp
    else:
        raise TypeError(f"Unknown AST node type: {type(node)}")

实用案例分析

完整的表达式编译器实现

现在,让我们将所有组件整合起来,实现一个完整的表达式编译器。

class CompleteExpressionCompiler:
    """完整的表达式编译器"""
    
    def __init__(self):
        pass
    
    def compile(self, expr):
        """编译表达式"""
        # 1. 词法分析
        tokens = self.tokenize(expr)
        print(f"Tokens: {tokens}")
        
        # 2. 语法分析
        parser = Parser(tokens)
        ast = parser.parse()
        print(f"AST: {repr(ast)}")
        
        # 3. 优化 AST
        optimized_ast = self.optimize(ast)
        print(f"Optimized AST: {repr(optimized_ast)}")
        
        # 4. 生成三地址码
        generator = TACGenerator()
        result = generator.generate(optimized_ast)
        instructions = generator.get_instructions()
        
        print("Generated TAC:")
        for i, instr in enumerate(instructions):
            print(f"  {i+1}. {instr}")
        print(f"Result: {result}")
        
        return instructions
    
    def tokenize(self, expr):
        """词法分析"""
        tokens = []
        i = 0
        while i < len(expr):
            c = expr[i]
            if c.isspace():
                i += 1
            elif c.isdigit():
                # 解析数字
                num = ""
                while i < len(expr) and expr[i].isdigit():
                    num += expr[i]
                    i += 1
                tokens.append(('NUMBER', int(num)))
            elif c.isalpha():
                # 解析标识符
                id = ""
                while i < len(expr) and expr[i].isalnum():
                    id += expr[i]
                    i += 1
                tokens.append(('ID', id))
            elif c in '+-*/()':
                # 解析操作符和括号
                tokens.append((c, c))
                i += 1
            elif c == '=' and i+1 < len(expr) and expr[i+1] == '=':
                # 解析 ==
                tokens.append(('==', '=='))
                i += 2
            elif c == '!' and i+1 < len(expr) and expr[i+1] == '=':
                # 解析 !=
                tokens.append(('!=', '!='))
                i += 2
            elif c == '<' and i+1 < len(expr) and expr[i+1] == '=':
                # 解析 <=
                tokens.append(('<=', '<='))
                i += 2
            elif c == '>' and i+1 < len(expr) and expr[i+1] == '=':
                # 解析 >=
                tokens.append(('>=', '>='))
                i += 2
            else:
                raise SyntaxError(f"Unexpected character: {c}")
        return tokens
    
    def optimize(self, node):
        """优化 AST"""
        # 应用常量折叠
        node = self.optimize_constant_folding(node)
        # 这里可以添加其他优化
        return node
    
    def optimize_constant_folding(self, node):
        """常量折叠优化"""
        if isinstance(node, Constant):
            return node
        elif isinstance(node, Variable):
            return node
        elif isinstance(node, BinaryOp):
            # 递归优化子节点
            left = self.optimize_constant_folding(node.left)
            right = self.optimize_constant_folding(node.right)
            # 如果左右都是常量,直接计算结果
            if isinstance(left, Constant) and isinstance(right, Constant):
                if node.op == '+':
                    return Constant(left.value + right.value)
                elif node.op == '-':
                    return Constant(left.value - right.value)
                elif node.op == '*':
                    return Constant(left.value * right.value)
                elif node.op == '/':
                    return Constant(left.value // right.value)
                elif node.op == '==':
                    return Constant(1 if left.value == right.value else 0)
                elif node.op == '!=':
                    return Constant(1 if left.value != right.value else 0)
                elif node.op == '<':
                    return Constant(1 if left.value < right.value else 0)
                elif node.op == '<=':
                    return Constant(1 if left.value <= right.value else 0)
                elif node.op == '>':
                    return Constant(1 if left.value > right.value else 0)
                elif node.op == '>=':
                    return Constant(1 if left.value >= right.value else 0)
            return BinaryOp(node.op, left, right)
        elif isinstance(node, UnaryOp):
            # 递归优化子节点
            expr = self.optimize_constant_folding(node.expr)
            # 如果子节点是常量,直接计算结果
            if isinstance(expr, Constant):
                if node.op == '-':
                    return Constant(-expr.value)
                elif node.op == '+':
                    return expr
            return UnaryOp(node.op, expr)
        else:
            return node

# 测试完整的编译器
compiler = CompleteExpressionCompiler()
compiler.compile("a + 2 * 3")
compiler.compile("x * (y + 5) > z")
compiler.compile("(a + b) * (a + b)")

总结

本集我们通过实战案例,详细讲解了如何将 AST 转换为三地址码,重点关注了表达式的中间代码生成过程:

  1. AST 到 TAC 的转换原理:采用后序遍历的方式,确保在处理父节点时所有子表达式已经被处理完毕

  2. 临时变量管理:自动生成唯一的临时变量名,合理管理临时变量的作用域

  3. 表达式的三地址码生成规则:针对不同类型的表达式(常量、变量、二元表达式、一元表达式)采用不同的生成规则

  4. 优化技术:介绍了常量折叠和公共子表达式消除等优化技术,以减少生成的指令数量和临时变量的使用

  5. 完整的表达式编译器实现:整合了词法分析、语法分析、优化和三地址码生成等组件,实现了一个功能完整的表达式编译器

通过本集的学习,我们掌握了表达式的中间代码生成技术,为后续学习控制流和函数的中间代码生成奠定了基础。在下一集中,我们将重点讲解控制流语句的中间代码生成过程。

« 上一篇 生成 LLVM IR(二) 下一篇 » 中间代码生成实战(二)—— 控制流