第89集:构造抽象语法树(AST)

核心知识点讲解

为什么需要 AST?

抽象语法树(Abstract Syntax Tree,简称AST)是编译器前端的重要数据结构,它在语法分析和语义分析之间起到了桥梁作用。与具体语法树相比,AST更加简洁,去除了语法分析中的冗余信息,更适合后续的语义分析和代码生成。

主要优势:

  • 结构清晰:更接近语言的语义结构
  • 易于遍历:便于进行语义分析和代码生成
  • 信息丰富:可以携带更多的语义信息
  • 模块化:使编译器前端各阶段职责分明

AST 节点设计

设计良好的AST节点结构是构建高效编译器的关键。一个典型的AST节点设计应考虑以下因素:

  1. 节点类型:表示不同的语法结构(如表达式、语句、声明等)
  2. 子节点引用:指向子节点的指针或引用
  3. 属性信息:节点的语义属性(如类型、作用域等)
  4. 位置信息:源代码中的位置,用于错误报告

常见的AST节点类型:

  • 表达式节点:二元表达式、一元表达式、常量表达式、标识符等
  • 语句节点:赋值语句、条件语句、循环语句、函数定义等
  • 声明节点:变量声明、函数声明、类型声明等

在动作中构建 AST

在Yacc/Bison中,可以在语义动作中构建AST。具体步骤如下:

  1. 定义节点结构:创建表示不同语法结构的节点类型
  2. 实现节点创建函数:用于创建不同类型的AST节点
  3. 在语义动作中构建:为每个产生式编写构建AST的语义动作
  4. 管理内存:确保节点的正确分配和释放

遍历 AST

AST构建完成后,需要对其进行遍历以执行各种操作:

  • 前序遍历:先访问父节点,再访问子节点
  • 中序遍历:先访问左子节点,再访问父节点,最后访问右子节点
  • 后序遍历:先访问子节点,再访问父节点
  • 层次遍历:按层次从上到下、从左到右访问节点

实用案例分析

完整的 AST 实现示例

下面是一个完整的AST实现示例,用于处理简单的表达式和语句:

// ast.h - AST节点定义
#ifndef AST_H
#define AST_H

#include <stdio.h>
#include <stdlib.h>
#include <string.h>

// 节点类型枚举
typedef enum {
    NODE_PROGRAM,
    NODE_EXPR,
    NODE_BINOP,
    NODE_UNOP,
    NODE_CONSTANT,
    NODE_IDENTIFIER,
    NODE_ASSIGN,
    NODE_IF,
    NODE_WHILE,
    NODE_PRINT
} NodeType;

// 二元运算符类型
typedef enum {
    OP_ADD,
    OP_SUB,
    OP_MUL,
    OP_DIV,
    OP_EQ,
    OP_NE,
    OP_LT,
    OP_LE,
    OP_GT,
    OP_GE
} BinOpType;

// 一元运算符类型
typedef enum {
    OP_NEG,
    OP_NOT
} UnOpType;

// 常量类型
typedef enum {
    CONST_INT,
    CONST_DOUBLE,
    CONST_STRING
} ConstType;

// 位置信息
typedef struct {
    int line;
    int column;
} Position;

// AST节点基类
typedef struct ASTNode {
    NodeType type;
    Position pos;
    struct ASTNode *next; // 用于链表结构
} ASTNode;

// 程序节点
typedef struct {
    ASTNode base;
    ASTNode *statements;
} ProgramNode;

// 表达式节点
typedef struct {
    ASTNode base;
} ExprNode;

// 二元表达式节点
typedef struct {
    ExprNode base;
    BinOpType op;
    ExprNode *left;
    ExprNode *right;
} BinOpNode;

// 一元表达式节点
typedef struct {
    ExprNode base;
    UnOpType op;
    ExprNode *expr;
} UnOpNode;

// 常量节点
typedef struct {
    ExprNode base;
    ConstType const_type;
    union {
        int int_val;
        double double_val;
        char *string_val;
    } value;
} ConstantNode;

// 标识符节点
typedef struct {
    ExprNode base;
    char *name;
} IdentifierNode;

// 赋值语句节点
typedef struct {
    ASTNode base;
    IdentifierNode *target;
    ExprNode *value;
} AssignNode;

// if语句节点
typedef struct {
    ASTNode base;
    ExprNode *condition;
    ASTNode *then_branch;
    ASTNode *else_branch;
} IfNode;

// while语句节点
typedef struct {
    ASTNode base;
    ExprNode *condition;
    ASTNode *body;
} WhileNode;

// print语句节点
typedef struct {
    ASTNode base;
    ExprNode *expr;
} PrintNode;

// 节点创建函数
ProgramNode *create_program();
BinOpNode *create_binop(BinOpType op, ExprNode *left, ExprNode *right);
UnOpNode *create_unop(UnOpType op, ExprNode *expr);
ConstantNode *create_constant_int(int value);
ConstantNode *create_constant_double(double value);
ConstantNode *create_constant_string(const char *value);
IdentifierNode *create_identifier(const char *name);
AssignNode *create_assign(IdentifierNode *target, ExprNode *value);
IfNode *create_if(ExprNode *condition, ASTNode *then_branch, ASTNode *else_branch);
WhileNode *create_while(ExprNode *condition, ASTNode *body);
PrintNode *create_print(ExprNode *expr);

// 节点添加函数
void add_statement(ProgramNode *program, ASTNode *statement);

// 遍历函数
void print_ast(ASTNode *node, int indent);
void free_ast(ASTNode *node);

#endif // AST_H
// ast.c - AST实现
#include "ast.h"

// 创建程序节点
ProgramNode *create_program() {
    ProgramNode *node = (ProgramNode *)malloc(sizeof(ProgramNode));
    node->base.type = NODE_PROGRAM;
    node->base.next = NULL;
    node->statements = NULL;
    return node;
}

// 创建二元表达式节点
BinOpNode *create_binop(BinOpType op, ExprNode *left, ExprNode *right) {
    BinOpNode *node = (BinOpNode *)malloc(sizeof(BinOpNode));
    node->base.base.type = NODE_EXPR;
    node->base.base.next = NULL;
    node->op = op;
    node->left = left;
    node->right = right;
    return node;
}

// 创建一元表达式节点
UnOpNode *create_unop(UnOpType op, ExprNode *expr) {
    UnOpNode *node = (UnOpNode *)malloc(sizeof(UnOpNode));
    node->base.base.type = NODE_EXPR;
    node->base.base.next = NULL;
    node->op = op;
    node->expr = expr;
    return node;
}

// 创建整型常量节点
ConstantNode *create_constant_int(int value) {
    ConstantNode *node = (ConstantNode *)malloc(sizeof(ConstantNode));
    node->base.base.type = NODE_EXPR;
    node->base.base.next = NULL;
    node->const_type = CONST_INT;
    node->value.int_val = value;
    return node;
}

// 创建浮点型常量节点
ConstantNode *create_constant_double(double value) {
    ConstantNode *node = (ConstantNode *)malloc(sizeof(ConstantNode));
    node->base.base.type = NODE_EXPR;
    node->base.base.next = NULL;
    node->const_type = CONST_DOUBLE;
    node->value.double_val = value;
    return node;
}

// 创建字符串常量节点
ConstantNode *create_constant_string(const char *value) {
    ConstantNode *node = (ConstantNode *)malloc(sizeof(ConstantNode));
    node->base.base.type = NODE_EXPR;
    node->base.base.next = NULL;
    node->const_type = CONST_STRING;
    node->value.string_val = strdup(value);
    return node;
}

// 创建标识符节点
IdentifierNode *create_identifier(const char *name) {
    IdentifierNode *node = (IdentifierNode *)malloc(sizeof(IdentifierNode));
    node->base.base.type = NODE_EXPR;
    node->base.base.next = NULL;
    node->name = strdup(name);
    return node;
}

// 创建赋值语句节点
AssignNode *create_assign(IdentifierNode *target, ExprNode *value) {
    AssignNode *node = (AssignNode *)malloc(sizeof(AssignNode));
    node->base.type = NODE_ASSIGN;
    node->base.next = NULL;
    node->target = target;
    node->value = value;
    return node;
}

// 创建if语句节点
IfNode *create_if(ExprNode *condition, ASTNode *then_branch, ASTNode *else_branch) {
    IfNode *node = (IfNode *)malloc(sizeof(IfNode));
    node->base.type = NODE_IF;
    node->base.next = NULL;
    node->condition = condition;
    node->then_branch = then_branch;
    node->else_branch = else_branch;
    return node;
}

// 创建while语句节点
WhileNode *create_while(ExprNode *condition, ASTNode *body) {
    WhileNode *node = (WhileNode *)malloc(sizeof(WhileNode));
    node->base.type = NODE_WHILE;
    node->base.next = NULL;
    node->condition = condition;
    node->body = body;
    return node;
}

// 创建print语句节点
PrintNode *create_print(ExprNode *expr) {
    PrintNode *node = (PrintNode *)malloc(sizeof(PrintNode));
    node->base.type = NODE_PRINT;
    node->base.next = NULL;
    node->expr = expr;
    return node;
}

// 添加语句到程序节点
void add_statement(ProgramNode *program, ASTNode *statement) {
    if (!program->statements) {
        program->statements = statement;
    } else {
        ASTNode *current = program->statements;
        while (current->next) {
            current = current->next;
        }
        current->next = statement;
    }
}

// 获取运算符字符串
const char *get_op_string(BinOpType op) {
    switch (op) {
        case OP_ADD: return "+";
        case OP_SUB: return "-";
        case OP_MUL: return "*";
        case OP_DIV: return "/";
        case OP_EQ: return "==";
        case OP_NE: return "!=";
        case OP_LT: return "<";
        case OP_LE: return "<=";
        case OP_GT: return ">";
        case OP_GE: return ">=";
        default: return "?";
    }
}

// 获取一元运算符字符串
const char *get_unop_string(UnOpType op) {
    switch (op) {
        case OP_NEG: return "-";
        case OP_NOT: return "!";
        default: return "?";
    }
}

// 打印AST
void print_ast(ASTNode *node, int indent) {
    if (!node) return;
    
    // 打印缩进
    for (int i = 0; i < indent; i++) {
        printf("  ");
    }
    
    switch (node->type) {
        case NODE_PROGRAM:
            {
                ProgramNode *prog = (ProgramNode *)node;
                printf("Program\n");
                print_ast(prog->statements, indent + 1);
            }
            break;
        
        case NODE_EXPR:
            {
                ExprNode *expr = (ExprNode *)node;
                // 进一步判断表达式类型
                if (((ASTNode *)expr)->type == NODE_EXPR) {
                    // 检查具体的表达式类型
                    if (((char *)expr)[sizeof(ASTNode)] == NODE_BINOP) {
                        BinOpNode *binop = (BinOpNode *)expr;
                        printf("BinOp: %s\n", get_op_string(binop->op));
                        print_ast((ASTNode *)binop->left, indent + 1);
                        print_ast((ASTNode *)binop->right, indent + 1);
                    } else if (((char *)expr)[sizeof(ASTNode)] == NODE_UNOP) {
                        UnOpNode *unop = (UnOpNode *)expr;
                        printf("UnOp: %s\n", get_unop_string(unop->op));
                        print_ast((ASTNode *)unop->expr, indent + 1);
                    } else if (((char *)expr)[sizeof(ASTNode)] == NODE_CONSTANT) {
                        ConstantNode *constant = (ConstantNode *)expr;
                        switch (constant->const_type) {
                            case CONST_INT:
                                printf("Constant: %d\n", constant->value.int_val);
                                break;
                            case CONST_DOUBLE:
                                printf("Constant: %g\n", constant->value.double_val);
                                break;
                            case CONST_STRING:
                                printf("Constant: \"%s\"\n", constant->value.string_val);
                                break;
                        }
                    } else if (((char *)expr)[sizeof(ASTNode)] == NODE_IDENTIFIER) {
                        IdentifierNode *ident = (IdentifierNode *)expr;
                        printf("Identifier: %s\n", ident->name);
                    }
                }
            }
            break;
        
        case NODE_ASSIGN:
            {
                AssignNode *assign = (AssignNode *)node;
                printf("Assign\n");
                print_ast((ASTNode *)assign->target, indent + 1);
                print_ast((ASTNode *)assign->value, indent + 1);
            }
            break;
        
        case NODE_IF:
            {
                IfNode *if_node = (IfNode *)node;
                printf("If\n");
                print_ast((ASTNode *)if_node->condition, indent + 1);
                print_ast(if_node->then_branch, indent + 1);
                if (if_node->else_branch) {
                    for (int i = 0; i < indent; i++) {
                        printf("  ");
                    }
                    printf("Else\n");
                    print_ast(if_node->else_branch, indent + 1);
                }
            }
            break;
        
        case NODE_WHILE:
            {
                WhileNode *while_node = (WhileNode *)node;
                printf("While\n");
                print_ast((ASTNode *)while_node->condition, indent + 1);
                print_ast(while_node->body, indent + 1);
            }
            break;
        
        case NODE_PRINT:
            {
                PrintNode *print_node = (PrintNode *)node;
                printf("Print\n");
                print_ast((ASTNode *)print_node->expr, indent + 1);
            }
            break;
        
        default:
            printf("Unknown node type\n");
            break;
    }
    
    // 打印链表中的下一个节点
    if (node->next) {
        print_ast(node->next, indent);
    }
}

// 释放AST
void free_ast(ASTNode *node) {
    if (!node) return;
    
    switch (node->type) {
        case NODE_PROGRAM:
            {
                ProgramNode *prog = (ProgramNode *)node;
                free_ast(prog->statements);
            }
            break;
        
        case NODE_EXPR:
            {
                ExprNode *expr = (ExprNode *)node;
                // 进一步判断表达式类型并释放
                if (((char *)expr)[sizeof(ASTNode)] == NODE_BINOP) {
                    BinOpNode *binop = (BinOpNode *)expr;
                    free_ast((ASTNode *)binop->left);
                    free_ast((ASTNode *)binop->right);
                } else if (((char *)expr)[sizeof(ASTNode)] == NODE_UNOP) {
                    UnOpNode *unop = (UnOpNode *)expr;
                    free_ast((ASTNode *)unop->expr);
                } else if (((char *)expr)[sizeof(ASTNode)] == NODE_CONSTANT) {
                    ConstantNode *constant = (ConstantNode *)expr;
                    if (constant->const_type == CONST_STRING) {
                        free(constant->value.string_val);
                    }
                } else if (((char *)expr)[sizeof(ASTNode)] == NODE_IDENTIFIER) {
                    IdentifierNode *ident = (IdentifierNode *)expr;
                    free(ident->name);
                }
            }
            break;
        
        case NODE_ASSIGN:
            {
                AssignNode *assign = (AssignNode *)node;
                free_ast((ASTNode *)assign->target);
                free_ast((ASTNode *)assign->value);
            }
            break;
        
        case NODE_IF:
            {
                IfNode *if_node = (IfNode *)node;
                free_ast((ASTNode *)if_node->condition);
                free_ast(if_node->then_branch);
                free_ast(if_node->else_branch);
            }
            break;
        
        case NODE_WHILE:
            {
                WhileNode *while_node = (WhileNode *)node;
                free_ast((ASTNode *)while_node->condition);
                free_ast(while_node->body);
            }
            break;
        
        case NODE_PRINT:
            {
                PrintNode *print_node = (PrintNode *)node;
                free_ast((ASTNode *)print_node->expr);
            }
            break;
    }
    
    // 释放链表中的下一个节点
    ASTNode *next = node->next;
    free(node);
    free_ast(next);
}

在 Yacc/Bison 中使用 AST

%{
#include <stdio.h>
#include <stdlib.h>
#include "ast.h"

ProgramNode *program;

int yylex();
void yyerror(const char *s);
%}

%union {
    int int_val;
    double double_val;
    char *string_val;
    ExprNode *expr;
    ASTNode *stmt;
    IdentifierNode *ident;
}

%token <int_val> INT
%token <double_val> DOUBLE
%token <string_val> STRING ID
%token ASSIGN
%token IF ELSE WHILE PRINT
%token LPAREN RPAREN LBRACE RBRACE
%token SEMICOLON
%token PLUS MINUS MULT DIV
%token EQ NE LT LE GT GE
%token NOT

%type <expr> expr term factor primary
%type <stmt> stmt assign_stmt if_stmt while_stmt print_stmt block

%%

program: /* 空 */ {
        program = create_program();
    }
    | program stmt {
        add_statement(program, $2);
    }
    ;

stmt: assign_stmt
    | if_stmt
    | while_stmt
    | print_stmt
    ;

assign_stmt: ID ASSIGN expr SEMICOLON {
        IdentifierNode *ident = create_identifier($1);
        $$ = (ASTNode *)create_assign(ident, $3);
        free($1);
    }
    ;

if_stmt: IF LPAREN expr RPAREN stmt {
        $$ = (ASTNode *)create_if($3, $5, NULL);
    }
    | IF LPAREN expr RPAREN stmt ELSE stmt {
        $$ = (ASTNode *)create_if($3, $5, $7);
    }
    ;

while_stmt: WHILE LPAREN expr RPAREN stmt {
        $$ = (ASTNode *)create_while($3, $5);
    }
    ;

print_stmt: PRINT LPAREN expr RPAREN SEMICOLON {
        $$ = (ASTNode *)create_print($3);
    }
    ;

block: LBRACE program RBRACE {
        $$ = program->statements;
    }
    ;

expr: expr PLUS term {
        $$ = (ExprNode *)create_binop(OP_ADD, $1, $3);
    }
    | expr MINUS term {
        $$ = (ExprNode *)create_binop(OP_SUB, $1, $3);
    }
    | term {
        $$ = $1;
    }
    ;

term: term MULT factor {
        $$ = (ExprNode *)create_binop(OP_MUL, $1, $3);
    }
    | term DIV factor {
        $$ = (ExprNode *)create_binop(OP_DIV, $1, $3);
    }
    | factor {
        $$ = $1;
    }
    ;

factor: primary
    | MINUS primary {
        $$ = (ExprNode *)create_unop(OP_NEG, $2);
    }
    | NOT primary {
        $$ = (ExprNode *)create_unop(OP_NOT, $2);
    }
    ;

primary: INT {
        $$ = (ExprNode *)create_constant_int($1);
    }
    | DOUBLE {
        $$ = (ExprNode *)create_constant_double($1);
    }
    | STRING {
        $$ = (ExprNode *)create_constant_string($1);
        free($1);
    }
    | ID {
        $$ = (ExprNode *)create_identifier($1);
        free($1);
    }
    | LPAREN expr RPAREN {
        $$ = $2;
    }
    ;

%%

int main() {
    printf("输入程序,按Ctrl+D结束\n");
    yyparse();
    
    printf("\n抽象语法树:\n");
    print_ast((ASTNode *)program, 0);
    
    free_ast((ASTNode *)program);
    return 0;
}

void yyerror(const char *s) {
    fprintf(stderr, "错误: %s\n", s);
}

技术要点总结

  1. AST 设计原则

    • 节点类型应覆盖所有语法结构
    • 节点结构应简洁明了
    • 应包含足够的语义信息
    • 应考虑内存管理和遍历效率
  2. AST 构建技巧

    • 在语义动作中逐步构建
    • 使用工厂函数创建节点
    • 正确处理节点间的引用关系
    • 管理好内存分配和释放
  3. AST 遍历策略

    • 前序遍历:适合代码生成
    • 后序遍历:适合表达式求值
    • 层次遍历:适合整体分析
    • 根据具体需求选择合适的遍历方式
  4. 常见问题与解决方案

    • 内存泄漏:实现完善的释放函数
    • 遍历效率:优化遍历算法
    • 节点类型过多:使用继承或联合体减少代码冗余
    • 错误定位:在节点中保存位置信息
  5. AST 在编译器中的应用

    • 语义分析:类型检查、作用域分析
    • 代码生成:生成中间代码或目标代码
    • 代码优化:进行各种优化变换
    • 语言工具:静态分析、代码格式化

代码优化建议

  1. 内存管理优化

    • 使用内存池减少内存分配开销
    • 实现引用计数或垃圾回收
    • 延迟释放策略,提高性能
  2. 节点设计优化

    • 使用 tagged union 减少内存占用
    • 考虑使用结构体继承(在支持的语言中)
    • 为频繁访问的属性提供直接访问方式
  3. 遍历优化

    • 实现迭代式遍历,避免递归栈溢出
    • 使用访问者模式,分离遍历逻辑和操作逻辑
    • 缓存遍历结果,避免重复计算
  4. 错误处理增强

    • 在节点中保存更详细的位置信息
    • 实现错误恢复机制,提高编译器健壮性
    • 提供更友好的错误提示
  5. 调试支持

    • 实现可视化工具,直观展示AST结构
    • 添加调试信息,便于定位问题
    • 提供序列化和反序列化功能,方便持久化

通过本集的学习,你已经掌握了抽象语法树的设计、构建和遍历技术。AST作为编译器前端的核心数据结构,将在后续的语义分析和代码生成阶段发挥重要作用。在实际编译器开发中,应根据具体语言的特点和编译器的需求,设计适合的AST结构,以提高编译器的性能和可维护性。

« 上一篇 Yacc 高级特性 下一篇 » 构造抽象语法树(AST)