第89集:构造抽象语法树(AST)

核心知识点讲解

为什么需要 AST?

在编译器前端,抽象语法树(AST)是一种重要的数据结构,它具有以下优点:

  1. 结构清晰:AST 以树状结构表示程序的语法结构,比原始的语法分析树更简洁,去除了无关的语法细节
  2. 易于遍历:树结构便于进行各种遍历操作,如前序、中序、后序遍历
  3. 语义分析:为语义分析、类型检查等提供基础
  4. 代码生成:为后续的中间代码生成和优化提供结构化的程序表示
  5. 语言无关:AST 可以表示不同编程语言的程序结构

AST 节点设计

一个典型的 AST 节点设计包含以下信息:

  1. 节点类型:表示节点的语法类别,如表达式、语句、声明等
  2. 节点数据:存储节点的具体数据,如数字值、标识符名称等
  3. 子节点指针:指向子节点的指针,如二元操作的左右操作数

示例节点结构:

/* Tag identifying which syntactic construct an AstNode represents; it also
   tells readers which member of the node's `data` union is in use. */
typedef enum {
    AST_ADD,         /* binary '+' */
    AST_SUB,         /* binary '-' */
    AST_MUL,         /* binary '*' */
    AST_DIV,         /* binary '/' */
    AST_NUMBER,      /* integer literal */
    AST_IDENTIFIER,  /* identifier used as an expression */
    AST_ASSIGN,      /* assignment statement */
    AST_DECLARATION, /* variable declaration */
    AST_IF,          /* if statement */
    AST_WHILE,       /* while loop */
    AST_PROGRAM,     /* program root; has no dedicated payload in AstNode */
    AST_STMT_LIST    /* statement list */
} AstType;

/* One node of the abstract syntax tree: a type tag plus a union whose active
   member is selected by that tag. */
typedef struct ast_node {
    AstType type;                 /* selects the active member of `data` */
    union {
        int number_val;           /* AST_NUMBER: literal value */
        char* identifier_val;     /* AST_IDENTIFIER: identifier name */
        /* AST_ADD / AST_SUB / AST_MUL / AST_DIV: the two operands */
        struct {
            struct ast_node* left;
            struct ast_node* right;
        } binary_op;
        /* AST_ASSIGN: target name and the expression assigned to it */
        struct {
            char* identifier;
            struct ast_node* expression;
        } assign;
        /* AST_DECLARATION: type name and declared identifier */
        struct {
            char* type;
            char* identifier;
        } declaration;
        /* AST_IF: condition and the two branches */
        struct {
            struct ast_node* condition;
            struct ast_node* then_body;
            struct ast_node* else_body;
        } if_stmt;
        /* AST_WHILE: loop condition and loop body */
        struct {
            struct ast_node* condition;
            struct ast_node* body;
        } while_stmt;
        /* AST_STMT_LIST: two child pointers */
        struct {
            struct ast_node* head;
            struct ast_node* tail;
        } list;
    } data;
} AstNode;

在动作中构建 AST

在 Yacc 中,我们可以在语义动作中构建 AST 节点:

/* Expression rules: every action builds the AST node for the construct it
   matched and stores it in $$; a parenthesised expression simply reuses the
   inner node.
   NOTE(review): this fragment shows no %left/%right declarations for the
   operators; the `expr OP expr` form is ambiguous without them - confirm
   they are declared in the full grammar. */
expr: NUMBER { $$ = new_number($1); }
    | IDENTIFIER { $$ = new_identifier($1); }
    | expr '+' expr { $$ = new_binary_op(AST_ADD, $1, $3); }
    | expr '-' expr { $$ = new_binary_op(AST_SUB, $1, $3); }
    | expr '*' expr { $$ = new_binary_op(AST_MUL, $1, $3); }
    | expr '/' expr { $$ = new_binary_op(AST_DIV, $1, $3); }
    | '(' expr ')' { $$ = $2; }
    ;

/* Statement rules: assignment, declaration, if/else and while. Each action
   passes the relevant sub-results to a node constructor and stores the new
   node in $$. The IF alternative shown here always requires an ELSE branch.
   NOTE(review): new_if_stmt() and new_while_stmt() are not defined in the
   listings of this article - presumably they fill the if_stmt / while_stmt
   members of AstNode; confirm before use. */
stmt: IDENTIFIER ASSIGN expr SEMI { $$ = new_assign($1, $3); }
    | TYPE IDENTIFIER SEMI { $$ = new_declaration($1, $2); }
    | IF '(' expr ')' stmt ELSE stmt { $$ = new_if_stmt($3, $5, $7); }
    | WHILE '(' expr ')' stmt { $$ = new_while_stmt($3, $5); }
    ;

遍历 AST

AST 遍历是处理 AST 的核心操作,常见的遍历方式包括:

  1. 前序遍历:先访问节点,再访问子节点
  2. 中序遍历:先访问左子节点,再访问节点,最后访问右子节点(仅适用于恰好有左右两个子节点的节点,例如二元运算)
  3. 后序遍历:先访问子节点,再访问节点
  4. 广度优先遍历:按层次访问节点

前序遍历示例:

/*
 * Pre-order walk of an AST: visit_node() is called on `node` first, then the
 * children are walked from left to right. A NULL `node` is ignored.
 */
void preorder_traverse(AstNode* node) {
    if (node == NULL) {
        return;
    }

    /* Parent before children. */
    visit_node(node);

    const AstType kind = node->type;

    if (kind == AST_ADD || kind == AST_SUB || kind == AST_MUL || kind == AST_DIV) {
        preorder_traverse(node->data.binary_op.left);
        preorder_traverse(node->data.binary_op.right);
    } else if (kind == AST_ASSIGN) {
        preorder_traverse(node->data.assign.expression);
    } else if (kind == AST_IF) {
        preorder_traverse(node->data.if_stmt.condition);
        preorder_traverse(node->data.if_stmt.then_body);
        preorder_traverse(node->data.if_stmt.else_body);
    } else if (kind == AST_WHILE) {
        preorder_traverse(node->data.while_stmt.condition);
        preorder_traverse(node->data.while_stmt.body);
    } else if (kind == AST_STMT_LIST) {
        preorder_traverse(node->data.list.head);
        preorder_traverse(node->data.list.tail);
    }
    /* Every other kind is treated as a leaf: nothing further to walk. */
}

后序遍历示例:

/*
 * Post-order walk of an AST: the children of `node` are walked from left to
 * right first, then visit_node() is called on `node` itself. A NULL `node`
 * is ignored.
 */
void postorder_traverse(AstNode* node) {
    if (node == NULL) {
        return;
    }

    const AstType kind = node->type;

    if (kind == AST_ADD || kind == AST_SUB || kind == AST_MUL || kind == AST_DIV) {
        postorder_traverse(node->data.binary_op.left);
        postorder_traverse(node->data.binary_op.right);
    } else if (kind == AST_ASSIGN) {
        postorder_traverse(node->data.assign.expression);
    } else if (kind == AST_IF) {
        postorder_traverse(node->data.if_stmt.condition);
        postorder_traverse(node->data.if_stmt.then_body);
        postorder_traverse(node->data.if_stmt.else_body);
    } else if (kind == AST_WHILE) {
        postorder_traverse(node->data.while_stmt.condition);
        postorder_traverse(node->data.while_stmt.body);
    } else if (kind == AST_STMT_LIST) {
        postorder_traverse(node->data.list.head);
        postorder_traverse(node->data.list.tail);
    }
    /* Every other kind is treated as a leaf and has no children to walk. */

    /* Children before parent. */
    visit_node(node);
}

实用案例分析

案例:简单表达式的 AST 构建

让我们创建一个简单的表达式解析器,构建并打印 AST:

AST 定义和操作:

/* ast.h */
#ifndef AST_H
#define AST_H

/* Node kinds of the expression-only AST used in this example. */
typedef enum {
    AST_ADD,        /* binary '+' */
    AST_SUB,        /* binary '-' */
    AST_MUL,        /* binary '*' */
    AST_DIV,        /* binary '/' */
    AST_NUMBER,     /* integer literal */
    AST_IDENTIFIER  /* identifier */
} AstType;

/* Expression AST node: a type tag plus the payload that belongs to it. */
typedef struct ast_node {
    AstType type;                 /* selects the active member of `data` */
    union {
        int number_val;           /* AST_NUMBER: literal value */
        char* identifier_val;     /* AST_IDENTIFIER: heap copy of the name */
        /* AST_ADD / AST_SUB / AST_MUL / AST_DIV: the two operands */
        struct {
            struct ast_node* left;
            struct ast_node* right;
        } binary_op;
    } data;
} AstNode;

/* Allocates a node and sets only its type tag. */
AstNode* new_ast_node(AstType type);
/* Creates an operator node of kind `type` that stores the given operand pointers. */
AstNode* new_binary_op(AstType type, AstNode* left, AstNode* right);
/* Creates an AST_NUMBER node holding `value`. */
AstNode* new_number(int value);
/* Creates an AST_IDENTIFIER node holding its own copy of `name`. */
AstNode* new_identifier(const char* name);
/* Releases `node`, its operand subtrees and its identifier string; NULL is ignored. */
void free_ast(AstNode* node);
/* Prints the tree rooted at `node`, one node per line, two spaces per `indent` level. */
void print_ast(AstNode* node, int indent);
/* Evaluates the expression tree with int arithmetic; identifiers and NULL evaluate to 0. */
int evaluate_ast(AstNode* node);

#endif
/* ast.c */
#include "ast.h"
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

/*
 * Allocate a zero-filled AST node and set its type tag.
 *
 * type: tag stored in the node's `type` field.
 *
 * Returns the new node, owned by the caller (release it with free_ast()).
 * If the allocation fails, a message is written to stderr and the process
 * exits, so the returned pointer is never NULL.
 */
AstNode* new_ast_node(AstType type) {
    /* calloc zero-fills the payload union, so a node whose constructor sets
       no payload never carries indeterminate bytes. */
    AstNode* node = calloc(1, sizeof *node);
    if (node == NULL) {
        fprintf(stderr, "错误: 内存不足,无法分配 AST 节点\n");
        exit(EXIT_FAILURE);
    }

    node->type = type;
    return node;
}

/*
 * Build an operator node of kind `type`. The node stores the `left` and
 * `right` pointers as given; the operands are not copied.
 */
AstNode* new_binary_op(AstType type, AstNode* left, AstNode* right) {
    AstNode* op = new_ast_node(type);

    op->data.binary_op.right = right;
    op->data.binary_op.left = left;

    return op;
}

/* Build an AST_NUMBER leaf that carries the integer `value`. */
AstNode* new_number(int value) {
    AstNode* literal = new_ast_node(AST_NUMBER);

    literal->data.number_val = value;

    return literal;
}

/*
 * Build an AST_IDENTIFIER leaf that owns a private copy of `name`.
 *
 * name: identifier text; the caller keeps ownership of its own buffer.
 *       A NULL pointer is stored as an empty name.
 *
 * Returns the new node. If the name cannot be copied, a message is written
 * to stderr and the process exits.
 */
AstNode* new_identifier(const char* name) {
    /* Passing NULL to strdup() is undefined behaviour, so substitute "". */
    char* copy = strdup(name != NULL ? name : "");
    if (copy == NULL) {
        fprintf(stderr, "错误: 内存不足,无法复制标识符\n");
        exit(EXIT_FAILURE);
    }

    AstNode* node = new_ast_node(AST_IDENTIFIER);
    node->data.identifier_val = copy;
    return node;
}

/*
 * Release the tree rooted at `node`: operator nodes release both operand
 * subtrees, identifier nodes release their name string, and finally the node
 * itself is freed. A NULL `node` is ignored.
 */
void free_ast(AstNode* node) {
    if (node == NULL) {
        return;
    }

    const AstType kind = node->type;

    if (kind == AST_IDENTIFIER) {
        free(node->data.identifier_val);
    } else if (kind == AST_ADD || kind == AST_SUB || kind == AST_MUL || kind == AST_DIV) {
        free_ast(node->data.binary_op.left);
        free_ast(node->data.binary_op.right);
    }
    /* Other kinds own nothing besides the node itself. */

    free(node);
}

/*
 * Print the tree rooted at `node` to stdout, one node per line.
 *
 * indent: nesting depth of `node`; each level is rendered as two spaces and
 *         children are printed one level deeper than their parent.
 *
 * A NULL `node` prints nothing.
 */
void print_ast(AstNode* node, int indent) {
    if (node == NULL) {
        return;
    }

    for (int level = indent; level > 0; level--) {
        printf("  ");
    }

    /* Print this node's own line; leaves and unknown kinds are finished
       after that and return straight away. */
    switch (node->type) {
        case AST_ADD:
            printf("ADD\n");
            break;
        case AST_SUB:
            printf("SUB\n");
            break;
        case AST_MUL:
            printf("MUL\n");
            break;
        case AST_DIV:
            printf("DIV\n");
            break;
        case AST_NUMBER:
            printf("NUMBER: %d\n", node->data.number_val);
            return;
        case AST_IDENTIFIER:
            printf("IDENTIFIER: %s\n", node->data.identifier_val);
            return;
        default:
            printf("UNKNOWN\n");
            return;
    }

    /* Only the four operator kinds reach this point. */
    print_ast(node->data.binary_op.left, indent + 1);
    print_ast(node->data.binary_op.right, indent + 1);
}

/*
 * Evaluate an expression tree with int arithmetic.
 *
 * node: root of the expression; NULL evaluates to 0.
 *
 * Returns the value of the expression. Identifiers are not looked up and
 * evaluate to 0. Dividing by zero is undefined behaviour in C, so it is
 * reported on stderr and the division yields 0 instead.
 */
int evaluate_ast(AstNode* node) {
    if (!node) return 0;

    switch (node->type) {
        case AST_ADD:
        case AST_SUB:
        case AST_MUL:
        case AST_DIV: {
            /* Evaluate the operands in a fixed left-to-right order. */
            int lhs = evaluate_ast(node->data.binary_op.left);
            int rhs = evaluate_ast(node->data.binary_op.right);

            if (node->type == AST_ADD) return lhs + rhs;
            if (node->type == AST_SUB) return lhs - rhs;
            if (node->type == AST_MUL) return lhs * rhs;

            if (rhs == 0) {
                fprintf(stderr, "错误: 除数为零\n");
                return 0;
            }
            return lhs / rhs;
        }
        case AST_NUMBER:
            return node->data.number_val;
        case AST_IDENTIFIER:
            /* No symbol table yet: identifiers evaluate to 0. */
            return 0;
        default:
            return 0;
    }
}

Yacc 文件:

/* parser.y */
%{
#include <stdio.h>
#include <stdlib.h>   /* free() */
#include "ast.h"

int yylex(void);
void yyerror(const char* s);
%}

/* Semantic values: integer literals, identifier text allocated by the lexer,
   and AST nodes built by the actions below. */
%union {
    int int_val;
    char* string_val;
    AstNode* ast_val;
}

%token <int_val> NUMBER
%token <string_val> IDENTIFIER
%token SEMI
%token EOL

/* Operator precedence, lowest first; all four operators associate to the
   left. Without these declarations the `expr OP expr` rules are ambiguous
   and yacc resolves every conflict by shifting, which groups everything to
   the right (2 * 3 + 4 would be built as 2 * (3 + 4)). */
%left '+' '-'
%left '*' '/'

%type <ast_val> expr

%%

/* One expression per input line: print its tree, evaluate it, release it. */
calc: /* empty */
    | calc expr EOL {
        printf("AST:\n");
        print_ast($2, 0);
        printf("计算结果: %d\n", evaluate_ast($2));
        free_ast($2);
    }
    | calc EOL
    ;

expr: NUMBER { $$ = new_number($1); }
    | IDENTIFIER {
        /* new_identifier() keeps its own copy of the name, so the string
           allocated by the lexer has to be released here. */
        $$ = new_identifier($1);
        free($1);
    }
    | expr '+' expr { $$ = new_binary_op(AST_ADD, $1, $3); }
    | expr '-' expr { $$ = new_binary_op(AST_SUB, $1, $3); }
    | expr '*' expr { $$ = new_binary_op(AST_MUL, $1, $3); }
    | expr '/' expr { $$ = new_binary_op(AST_DIV, $1, $3); }
    | '(' expr ')' { $$ = $2; }
    ;

%%

/* Called by the generated parser with a description of a syntax error. */
void yyerror(const char* s) {
    fprintf(stderr, "错误: %s\n", s);
}

/* Prints a short usage banner, then parses standard input until EOF. */
int main(void) {
    printf("表达式 AST 构建器\n");
    printf("输入表达式,例如: 5 + 3 * (4 - 2)\n");
    return yyparse();
}

Lex 文件:

/* lexer.l */
%{
#include <stdio.h>
#include <stdlib.h>   /* atoi() */
#include <string.h>   /* strdup() */
#include "ast.h"      /* must precede y.tab.h: YYSTYPE has an AstNode* member */
#include "y.tab.h"
%}

%%

[0-9]+      { yylval.int_val = atoi(yytext); return NUMBER; }
[a-zA-Z_][a-zA-Z0-9_]* { /* heap copy; the parser is responsible for freeing it */
                         yylval.string_val = strdup(yytext); return IDENTIFIER; }
"+"         { return '+'; }
"-"         { return '-'; }
"*"         { return '*'; }
"/"         { return '/'; }
"("         { return '('; }
")"         { return ')'; }
"\n"        { return EOL; }
[ \t]       { /* skip blanks */ }
.           { /* report unknown input instead of dropping it silently */
              fprintf(stderr, "错误: 无法识别的字符 '%s'\n", yytext); }

%%

/* Tells the scanner that no further input follows at end of file. */
int yywrap(void) {
    return 1;
}

案例:完整语句的 AST 构建

让我们扩展上面的例子,支持变量声明和赋值语句:

扩展的 AST 定义:

/* ast.h */
#ifndef AST_H
#define AST_H

/* Node kinds of the statement-level AST used in this example. */
typedef enum {
    AST_ADD,         /* binary '+' */
    AST_SUB,         /* binary '-' */
    AST_MUL,         /* binary '*' */
    AST_DIV,         /* binary '/' */
    AST_NUMBER,      /* integer literal */
    AST_IDENTIFIER,  /* identifier used as an expression */
    AST_ASSIGN,      /* assignment statement */
    AST_DECLARATION, /* variable declaration */
    AST_PROGRAM,     /* program root; has no dedicated payload in AstNode */
    AST_STMT_LIST    /* statement list */
} AstType;

/* AST node: a type tag plus the payload that belongs to that tag. */
typedef struct ast_node {
    AstType type;                 /* selects the active member of `data` */
    union {
        int number_val;           /* AST_NUMBER: literal value */
        char* identifier_val;     /* AST_IDENTIFIER: heap copy of the name */
        /* AST_ADD / AST_SUB / AST_MUL / AST_DIV: the two operands */
        struct {
            struct ast_node* left;
            struct ast_node* right;
        } binary_op;
        /* AST_ASSIGN: heap copy of the target name and the assigned expression */
        struct {
            char* identifier;
            struct ast_node* expression;
        } assign;
        /* AST_DECLARATION: heap copies of the type name and the identifier */
        struct {
            char* type;
            char* identifier;
        } declaration;
        /* AST_STMT_LIST: two child pointers */
        struct {
            struct ast_node* head;
            struct ast_node* tail;
        } list;
    } data;
} AstNode;

/* Allocates a node and sets its type tag. */
AstNode* new_ast_node(AstType type);
/* Creates an operator node of kind `type` that stores the given operand pointers. */
AstNode* new_binary_op(AstType type, AstNode* left, AstNode* right);
/* Creates an AST_NUMBER node holding `value`. */
AstNode* new_number(int value);
/* Creates an AST_IDENTIFIER node holding its own copy of `name`. */
AstNode* new_identifier(const char* name);
/* Creates an AST_ASSIGN node holding its own copy of `identifier` and the
   given expression pointer. */
AstNode* new_assign(const char* identifier, AstNode* expression);
/* Creates an AST_DECLARATION node holding its own copies of both strings. */
AstNode* new_declaration(const char* type, const char* identifier);
/* Creates an AST_PROGRAM node. `(void)` makes this a real prototype, so a
   call that passes arguments is rejected at compile time. */
AstNode* new_program(void);
/* Creates an AST_STMT_LIST node that stores the given head/tail pointers. */
AstNode* new_stmt_list(AstNode* head, AstNode* tail);
/* Adds `stmt` to the AST_STMT_LIST node `list`. */
void add_stmt_to_list(AstNode* list, AstNode* stmt);
/* Releases `node` together with the strings and child nodes it owns; NULL is ignored. */
void free_ast(AstNode* node);
/* Prints the tree rooted at `node`, one node per line, two spaces per `indent` level. */
void print_ast(AstNode* node, int indent);

#endif
/* ast.c */
#include "ast.h"
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

/*
 * Allocate a zero-filled AST node and set its type tag.
 *
 * type: tag stored in the node's `type` field.
 *
 * Returns the new node, owned by the caller (release it with free_ast()).
 * If the allocation fails, a message is written to stderr and the process
 * exits, so the returned pointer is never NULL.
 */
AstNode* new_ast_node(AstType type) {
    /* calloc zero-fills the payload union; nodes created without a payload
       (such as the one new_program() returns) therefore never carry
       indeterminate bytes. */
    AstNode* node = calloc(1, sizeof *node);
    if (node == NULL) {
        fprintf(stderr, "错误: 内存不足,无法分配 AST 节点\n");
        exit(EXIT_FAILURE);
    }

    node->type = type;
    return node;
}

/*
 * Build an operator node of kind `type`. The node stores the `left` and
 * `right` pointers as given; the operands are not copied.
 */
AstNode* new_binary_op(AstType type, AstNode* left, AstNode* right) {
    AstNode* op = new_ast_node(type);

    op->data.binary_op.right = right;
    op->data.binary_op.left = left;

    return op;
}

/* Build an AST_NUMBER leaf that carries the integer `value`. */
AstNode* new_number(int value) {
    AstNode* literal = new_ast_node(AST_NUMBER);

    literal->data.number_val = value;

    return literal;
}

/*
 * Build an AST_IDENTIFIER leaf that owns a private copy of `name`.
 *
 * name: identifier text; the caller keeps ownership of its own buffer.
 *       A NULL pointer is stored as an empty name.
 *
 * Returns the new node. If the name cannot be copied, a message is written
 * to stderr and the process exits.
 */
AstNode* new_identifier(const char* name) {
    /* Passing NULL to strdup() is undefined behaviour, so substitute "". */
    char* copy = strdup(name != NULL ? name : "");
    if (copy == NULL) {
        fprintf(stderr, "错误: 内存不足,无法复制标识符\n");
        exit(EXIT_FAILURE);
    }

    AstNode* node = new_ast_node(AST_IDENTIFIER);
    node->data.identifier_val = copy;
    return node;
}

/*
 * Build an AST_ASSIGN node.
 *
 * identifier: name of the assignment target; copied, so the caller keeps
 *             ownership of its own buffer. NULL is stored as an empty name.
 * expression: right-hand side; the pointer is stored as given.
 *
 * Returns the new node. If the name cannot be copied, a message is written
 * to stderr and the process exits.
 */
AstNode* new_assign(const char* identifier, AstNode* expression) {
    /* Passing NULL to strdup() is undefined behaviour, so substitute "". */
    char* target = strdup(identifier != NULL ? identifier : "");
    if (target == NULL) {
        fprintf(stderr, "错误: 内存不足,无法复制标识符\n");
        exit(EXIT_FAILURE);
    }

    AstNode* node = new_ast_node(AST_ASSIGN);
    node->data.assign.identifier = target;
    node->data.assign.expression = expression;
    return node;
}

/*
 * Build an AST_DECLARATION node.
 *
 * type:       name of the declared type.
 * identifier: name of the declared variable.
 *
 * Both strings are copied, so the caller keeps ownership of its own buffers;
 * a NULL argument is stored as an empty string.
 *
 * Returns the new node. If a string cannot be copied, a message is written
 * to stderr and the process exits.
 */
AstNode* new_declaration(const char* type, const char* identifier) {
    /* Passing NULL to strdup() is undefined behaviour, so substitute "". */
    char* type_copy = strdup(type != NULL ? type : "");
    char* name_copy = strdup(identifier != NULL ? identifier : "");
    if (type_copy == NULL || name_copy == NULL) {
        fprintf(stderr, "错误: 内存不足,无法复制声明\n");
        exit(EXIT_FAILURE);
    }

    AstNode* node = new_ast_node(AST_DECLARATION);
    node->data.declaration.type = type_copy;
    node->data.declaration.identifier = name_copy;
    return node;
}

/*
 * Build an AST_PROGRAM node. Only the type tag is set; the node carries no
 * payload.
 *
 * The explicit `void` parameter list makes this a prototype-style
 * definition, so calls that pass arguments are diagnosed.
 */
AstNode* new_program(void) {
    return new_ast_node(AST_PROGRAM);
}

/*
 * Build an AST_STMT_LIST node that stores the `head` and `tail` pointers as
 * given (either may be NULL).
 */
AstNode* new_stmt_list(AstNode* head, AstNode* tail) {
    AstNode* cell = new_ast_node(AST_STMT_LIST);

    cell->data.list.tail = tail;
    cell->data.list.head = head;

    return cell;
}

/*
 * Append `stmt` to the statement list that starts at `list`.
 *
 * The list is kept as a chain of AST_STMT_LIST cells: `head` holds one
 * statement and `tail` points to the next cell (NULL at the end). Each
 * statement is therefore referenced exactly once, so no statement is lost
 * when a third one is added and none is released twice by free_ast().
 *
 * list: first cell of the chain, created with new_stmt_list(); an empty list
 *       is a cell whose head and tail are both NULL.
 * stmt: statement to append; ownership passes to the list.
 *
 * A NULL `list` or `stmt` leaves the list unchanged.
 */
void add_stmt_to_list(AstNode* list, AstNode* stmt) {
    if (list == NULL || stmt == NULL) {
        return;
    }

    /* Walk to the last cell of the chain. */
    AstNode* cell = list;
    while (cell->data.list.tail != NULL &&
           cell->data.list.tail->type == AST_STMT_LIST) {
        cell = cell->data.list.tail;
    }

    if (cell->data.list.head == NULL) {
        /* Empty cell: the statement goes straight into it. */
        cell->data.list.head = stmt;
    } else if (cell->data.list.tail == NULL) {
        cell->data.list.tail = new_stmt_list(stmt, NULL);
    } else {
        /* tail holds a bare statement (a list built with
           new_stmt_list(a, b)); wrap it in a cell so it is kept. */
        cell->data.list.tail =
            new_stmt_list(cell->data.list.tail, new_stmt_list(stmt, NULL));
    }
}

/*
 * Release the tree rooted at `node`: child nodes are released recursively,
 * owned strings are freed, and finally the node itself is freed.
 * A NULL `node` is ignored.
 */
void free_ast(AstNode* node) {
    if (!node) return;

    switch (node->type) {
        case AST_ADD:
        case AST_SUB:
        case AST_MUL:
        case AST_DIV:
            free_ast(node->data.binary_op.left);
            free_ast(node->data.binary_op.right);
            break;
        case AST_IDENTIFIER:
            free(node->data.identifier_val);
            break;
        case AST_ASSIGN:
            free(node->data.assign.identifier);
            free_ast(node->data.assign.expression);
            break;
        case AST_DECLARATION:
            free(node->data.declaration.type);
            free(node->data.declaration.identifier);
            break;
        case AST_STMT_LIST: {
            /* head and tail may refer to the same node (a one-element list);
               read both pointers before freeing anything and release a
               shared node only once. */
            AstNode* head = node->data.list.head;
            AstNode* tail = node->data.list.tail;

            free_ast(head);
            if (tail != head) {
                free_ast(tail);
            }
            break;
        }
        default:
            /* Remaining kinds own nothing besides the node itself. */
            break;
    }

    free(node);
}

/*
 * Print the tree rooted at `node` to stdout, one node per line.
 *
 * indent: nesting depth of `node`; each level is rendered as two spaces and
 *         children are printed one level deeper than their parent.
 *
 * A NULL `node` prints nothing.
 *
 * Note: the unmatched `#endif` that used to follow this function was
 * removed; ast.c opens no conditional block, so it did not compile.
 */
void print_ast(AstNode* node, int indent) {
    if (!node) return;

    for (int i = 0; i < indent; i++) {
        printf("  ");
    }

    /* Print this node's own line. Everything except the four operator kinds
       is completed inside the switch and returns straight away. */
    switch (node->type) {
        case AST_ADD:
            printf("ADD\n");
            break;
        case AST_SUB:
            printf("SUB\n");
            break;
        case AST_MUL:
            printf("MUL\n");
            break;
        case AST_DIV:
            printf("DIV\n");
            break;
        case AST_NUMBER:
            printf("NUMBER: %d\n", node->data.number_val);
            return;
        case AST_IDENTIFIER:
            printf("IDENTIFIER: %s\n", node->data.identifier_val);
            return;
        case AST_ASSIGN:
            printf("ASSIGN: %s\n", node->data.assign.identifier);
            print_ast(node->data.assign.expression, indent + 1);
            return;
        case AST_DECLARATION:
            printf("DECLARATION: %s %s\n", node->data.declaration.type, node->data.declaration.identifier);
            return;
        case AST_PROGRAM:
            printf("PROGRAM\n");
            return;
        case AST_STMT_LIST:
            printf("STMT_LIST\n");
            print_ast(node->data.list.head, indent + 1);
            print_ast(node->data.list.tail, indent + 1);
            return;
        default:
            printf("UNKNOWN\n");
            return;
    }

    /* Operator nodes: print both operands one level deeper. */
    print_ast(node->data.binary_op.left, indent + 1);
    print_ast(node->data.binary_op.right, indent + 1);
}

扩展的 Yacc 文件:

/* parser.y */
%{
#include <stdio.h>
#include <stdlib.h>   /* free() */
#include "ast.h"

int yylex(void);
void yyerror(const char* s);
%}

/* Semantic values: integer literals, text allocated by the lexer for
   identifiers and type names, and AST nodes built by the actions below. */
%union {
    int int_val;
    char* string_val;
    AstNode* ast_val;
}

%token <int_val> NUMBER
%token <string_val> IDENTIFIER
%token <string_val> TYPE
%token ASSIGN
%token SEMI
%token EOL

/* Operator precedence, lowest first; all four operators associate to the
   left. Without these declarations the `expr OP expr` rules are ambiguous
   and yacc resolves every conflict by shifting, which groups everything to
   the right (2 * 3 + 4 would be built as 2 * (3 + 4)). */
%left '+' '-'
%left '*' '/'

%type <ast_val> program stmt expr declaration assignment

%start input

%%

/* Wrapper start rule: releases the program node once parsing has finished. */
input: program { free_ast($1); }
     ;

program: /* empty */ { $$ = new_program(); }
       | program stmt {
           /* The AST_PROGRAM node has no field that could hold child
              statements, so each statement is printed and released here
              instead of being dropped (and leaked). */
           $$ = $1;
           printf("AST:\n");
           print_ast($2, 0);
           free_ast($2);
       }
       | program EOL {
           /* The lexer returns EOL for every newline; accept it between
              statements so multi-line input is not a syntax error. */
           $$ = $1;
       }
       ;

stmt: declaration
    | assignment
    | expr SEMI { $$ = $1; }
    ;

declaration: TYPE IDENTIFIER SEMI {
    $$ = new_declaration($1, $2);
    printf("声明变量: %s %s\n", $1, $2);
    /* new_declaration() keeps its own copies; release the lexer's strings. */
    free($1);
    free($2);
}
;

assignment: IDENTIFIER ASSIGN expr SEMI {
    $$ = new_assign($1, $3);
    printf("赋值: %s = <expression>\n", $1);
    /* new_assign() keeps its own copy; release the lexer's string. */
    free($1);
}
;

expr: NUMBER { $$ = new_number($1); }
    | IDENTIFIER {
        /* new_identifier() keeps its own copy; release the lexer's string. */
        $$ = new_identifier($1);
        free($1);
    }
    | expr '+' expr { $$ = new_binary_op(AST_ADD, $1, $3); }
    | expr '-' expr { $$ = new_binary_op(AST_SUB, $1, $3); }
    | expr '*' expr { $$ = new_binary_op(AST_MUL, $1, $3); }
    | expr '/' expr { $$ = new_binary_op(AST_DIV, $1, $3); }
    | '(' expr ')' { $$ = $2; }
    ;

%%

/* Called by the generated parser with a description of a syntax error. */
void yyerror(const char* s) {
    fprintf(stderr, "错误: %s\n", s);
}

/* Prints a short usage banner, then parses standard input until EOF. */
int main(void) {
    printf("语句 AST 构建器\n");
    printf("支持变量声明和赋值,例如: int x; x = 5 + 3;\n");
    return yyparse();
}

扩展的 Lex 文件:

/* lexer.l */
%{
#include <stdio.h>
#include <stdlib.h>   /* atoi() */
#include <string.h>   /* strdup() */
#include "ast.h"      /* must precede y.tab.h: YYSTYPE has an AstNode* member */
#include "y.tab.h"
%}

%%

[0-9]+      { yylval.int_val = atoi(yytext); return NUMBER; }
"int"       { yylval.string_val = strdup(yytext); return TYPE; }
"double"    { yylval.string_val = strdup(yytext); return TYPE; }
"float"     { yylval.string_val = strdup(yytext); return TYPE; }
"char"      { yylval.string_val = strdup(yytext); return TYPE; }
[a-zA-Z_][a-zA-Z0-9_]* { /* heap copy; the parser is responsible for freeing it */
                         yylval.string_val = strdup(yytext); return IDENTIFIER; }
"="         { return ASSIGN; }
";"         { return SEMI; }
"+"         { return '+'; }
"-"         { return '-'; }
"*"         { return '*'; }
"/"         { return '/'; }
"("         { return '('; }
")"         { return ')'; }
"\n"        { return EOL; }
[ \t]       { /* skip blanks */ }
.           { /* report unknown input instead of dropping it silently */
              fprintf(stderr, "错误: 无法识别的字符 '%s'\n", yytext); }

%%

/* Tells the scanner that no further input follows at end of file. */
int yywrap(void) {
    return 1;
}

代码优化建议

  1. 内存管理优化

    • 使用内存池分配 AST 节点,减少内存分配开销
    • 实现引用计数,避免内存泄漏
    • 对于字符串常量,使用字符串池避免重复分配
  2. AST 遍历优化

    • 使用迭代遍历代替递归遍历,避免栈溢出
    • 实现惰性遍历,按需处理节点
    • 缓存遍历结果,避免重复计算
  3. 节点结构优化

    • 使用标签联合(tagged union)或多态设计,提高代码可读性
    • 为节点添加行号和列号信息,便于错误定位
    • 考虑使用不可变节点,简化某些优化操作
  4. 构建过程优化

    • 使用工厂方法创建节点,统一节点初始化逻辑
    • 实现节点复用,减少内存使用
    • 考虑增量构建,支持大型程序的处理
  5. 类型系统集成

    • 为 AST 节点添加类型信息,便于类型检查
    • 实现类型推导,在构建过程中推断表达式类型
    • 集成符号表,支持变量查找和作用域管理

总结

本集我们深入学习了抽象语法树(AST)的构造和使用,包括:

  1. AST 的基本概念和优点
  2. AST 节点的设计和实现
  3. 在 Yacc 语义动作中构建 AST
  4. AST 的遍历技术(前序、后序遍历)
  5. 实际案例:表达式和完整语句的 AST 构建
  6. AST 相关的代码优化建议

AST 是编译器前端的核心数据结构,它为后续的语义分析、中间代码生成和优化提供了结构化的程序表示。通过掌握 AST 的构建和处理技术,你已经具备了构建编译器前端的重要基础。在后续的课程中,我们将学习如何基于 AST 进行语义分析和中间代码生成。

« 上一篇 构造抽象语法树(AST) 下一篇 » 语法错误处理策略