第89集:构造抽象语法树(AST)
核心知识点讲解
为什么需要 AST?
在编译器前端,抽象语法树(AST)是一种重要的数据结构,它具有以下优点:
- 结构清晰:AST 以树状结构表示程序的语法结构,比原始的语法分析树更简洁,去除了无关的语法细节
- 易于遍历:树结构便于进行各种遍历操作,如前序、中序、后序遍历
- 语义分析:为语义分析、类型检查等提供基础
- 代码生成:为后续的中间代码生成和优化提供结构化的程序表示
- 语言无关:AST 这种表示方式不依赖于具体的编程语言,不同语言的程序结构都可以用 AST 来表示
AST 节点设计
一个典型的 AST 节点设计包含以下信息:
- 节点类型:表示节点的语法类别,如表达式、语句、声明等
- 节点数据:存储节点的具体数据,如数字值、标识符名称等
- 子节点指针:指向子节点的指针,如二元操作的左右操作数
示例节点结构:
/* Syntactic category of an AST node. */
typedef enum {
    /* binary arithmetic operators */
    AST_ADD = 0,
    AST_SUB,
    AST_MUL,
    AST_DIV,

    /* leaves */
    AST_NUMBER,
    AST_IDENTIFIER,

    /* statements */
    AST_ASSIGN,
    AST_DECLARATION,
    AST_IF,
    AST_WHILE,

    /* containers */
    AST_PROGRAM,
    AST_STMT_LIST
} AstType;
/*
 * One AST node: a type tag plus a union holding the payload for that tag.
 * Only the union member that matches `type` is meaningful.
 */
typedef struct ast_node {
    AstType type;                          /* selects the active member of `data` */
    union {
        int   number_val;                  /* numeric literal */
        char *identifier_val;              /* identifier name */

        struct {                           /* binary operators */
            struct ast_node *left;
            struct ast_node *right;
        } binary_op;

        struct {                           /* assignment: identifier = expression */
            char            *identifier;
            struct ast_node *expression;
        } assign;

        struct {                           /* declaration: type identifier */
            char *type;
            char *identifier;
        } declaration;

        struct {                           /* if / else */
            struct ast_node *condition;
            struct ast_node *then_body;
            struct ast_node *else_body;
        } if_stmt;

        struct {                           /* while loop */
            struct ast_node *condition;
            struct ast_node *body;
        } while_stmt;

        struct {                           /* statement list */
            struct ast_node *head;
            struct ast_node *tail;
        } list;
    } data;
} AstNode;
在动作中构建 AST
在 Yacc 中,我们可以在语义动作中构建 AST 节点:
/* Each action builds the AST node for its rule: $$ is the value of the
 * left-hand side, $1..$n are the values of the right-hand-side symbols.
 * NOTE: rules of the form `expr OP expr` are ambiguous on their own; the
 * precedence/associativity declarations are not shown in this fragment. */
expr: NUMBER { $$ = new_number($1); }
| IDENTIFIER { $$ = new_identifier($1); }
| expr '+' expr { $$ = new_binary_op(AST_ADD, $1, $3); }
| expr '-' expr { $$ = new_binary_op(AST_SUB, $1, $3); }
| expr '*' expr { $$ = new_binary_op(AST_MUL, $1, $3); }
| expr '/' expr { $$ = new_binary_op(AST_DIV, $1, $3); }
/* Parentheses create no node of their own; the inner tree is passed up. */
| '(' expr ')' { $$ = $2; }
;
/* Statements: assignment, declaration, if/else and while. In this fragment
 * the `if` form always carries an ELSE branch. */
stmt: IDENTIFIER ASSIGN expr SEMI { $$ = new_assign($1, $3); }
| TYPE IDENTIFIER SEMI { $$ = new_declaration($1, $2); }
| IF '(' expr ')' stmt ELSE stmt { $$ = new_if_stmt($3, $5, $7); }
| WHILE '(' expr ')' stmt { $$ = new_while_stmt($3, $5); }
;
遍历 AST
AST 遍历是处理 AST 的核心操作,常见的遍历方式包括:
- 前序遍历:先访问节点,再访问子节点
- 中序遍历:先访问左子节点,再访问节点,最后访问右子节点
- 后序遍历:先访问子节点,再访问节点
- 广度优先遍历:按层次访问节点
前序遍历示例:
/*
 * Pre-order walk: visit `node` itself first, then recurse into each of its
 * children from left to right. A NULL node is silently skipped, and node
 * types without children (leaves) stop the recursion.
 */
void preorder_traverse(AstNode* node) {
    if (node == NULL) {
        return;
    }

    visit_node(node);

    const AstType kind = node->type;

    if (kind == AST_ADD || kind == AST_SUB || kind == AST_MUL || kind == AST_DIV) {
        preorder_traverse(node->data.binary_op.left);
        preorder_traverse(node->data.binary_op.right);
    } else if (kind == AST_ASSIGN) {
        preorder_traverse(node->data.assign.expression);
    } else if (kind == AST_IF) {
        preorder_traverse(node->data.if_stmt.condition);
        preorder_traverse(node->data.if_stmt.then_body);
        preorder_traverse(node->data.if_stmt.else_body);
    } else if (kind == AST_WHILE) {
        preorder_traverse(node->data.while_stmt.condition);
        preorder_traverse(node->data.while_stmt.body);
    } else if (kind == AST_STMT_LIST) {
        preorder_traverse(node->data.list.head);
        preorder_traverse(node->data.list.tail);
    }
    /* every other type is a leaf: nothing further to walk */
}
后序遍历示例:
/*
 * Post-order walk: recurse into the children of `node` from left to right,
 * then visit `node` itself. A NULL node is silently skipped; leaves are
 * visited directly because they have no children.
 */
void postorder_traverse(AstNode* node) {
    if (node == NULL) {
        return;
    }

    const AstType kind = node->type;

    if (kind == AST_ADD || kind == AST_SUB || kind == AST_MUL || kind == AST_DIV) {
        postorder_traverse(node->data.binary_op.left);
        postorder_traverse(node->data.binary_op.right);
    } else if (kind == AST_ASSIGN) {
        postorder_traverse(node->data.assign.expression);
    } else if (kind == AST_IF) {
        postorder_traverse(node->data.if_stmt.condition);
        postorder_traverse(node->data.if_stmt.then_body);
        postorder_traverse(node->data.if_stmt.else_body);
    } else if (kind == AST_WHILE) {
        postorder_traverse(node->data.while_stmt.condition);
        postorder_traverse(node->data.while_stmt.body);
    } else if (kind == AST_STMT_LIST) {
        postorder_traverse(node->data.list.head);
        postorder_traverse(node->data.list.tail);
    }

    /* children are done (or there were none): now handle this node */
    visit_node(node);
}
实用案例分析
案例:简单表达式的 AST 构建
让我们创建一个简单的表达式解析器,构建并打印 AST:
AST 定义和操作:
/* ast.h */
#ifndef AST_H
#define AST_H

/* Node categories of the expression-only AST. */
typedef enum {
    AST_ADD,
    AST_SUB,
    AST_MUL,
    AST_DIV,
    AST_NUMBER,
    AST_IDENTIFIER
} AstType;

typedef struct ast_node AstNode;

/* A node is a type tag plus the payload that belongs to that tag. */
struct ast_node {
    AstType type;
    union {
        int   number_val;          /* AST_NUMBER */
        char *identifier_val;      /* AST_IDENTIFIER */
        struct {                   /* AST_ADD, AST_SUB, AST_MUL, AST_DIV */
            struct ast_node *left;
            struct ast_node *right;
        } binary_op;
    } data;
};

/* Allocate a node carrying the given type tag. */
AstNode *new_ast_node(AstType type);

/* Create an operator node with the given operands as children. */
AstNode *new_binary_op(AstType type, AstNode *left, AstNode *right);

/* Create a numeric-literal leaf. */
AstNode *new_number(int value);

/* Create an identifier leaf; the node stores its own copy of `name`. */
AstNode *new_identifier(const char *name);

/* Recursively release a tree; NULL is accepted. */
void free_ast(AstNode *node);

/* Print a tree, one node per line, nested according to `indent`. */
void print_ast(AstNode *node, int indent);

/* Evaluate an expression tree with integer arithmetic; identifiers count as 0. */
int evaluate_ast(AstNode *node);

#endif
/* ast.c */
#include "ast.h"
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
/*
 * Allocate a node with the given type tag.
 *
 * The payload is zero-filled so that no union member is ever read
 * uninitialised. Allocation failure is fatal: callers in this file
 * dereference the result without checking, so the process exits with a
 * diagnostic instead of returning NULL.
 *
 * Ownership: the caller owns the returned node and releases it with free_ast().
 */
AstNode* new_ast_node(AstType type) {
    AstNode *node = calloc(1, sizeof *node);
    if (node == NULL) {
        perror("new_ast_node");
        exit(EXIT_FAILURE);
    }
    node->type = type;
    return node;
}
/*
 * Build an operator node of the given type whose children are `left` and
 * `right`. The new node takes the operand pointers as given (no copy).
 */
AstNode* new_binary_op(AstType type, AstNode* left, AstNode* right) {
    AstNode *op = new_ast_node(type);

    op->data.binary_op.right = right;
    op->data.binary_op.left = left;

    return op;
}
/* Build an AST_NUMBER leaf holding `value`. */
AstNode* new_number(int value) {
    AstNode *literal = new_ast_node(AST_NUMBER);

    literal->data.number_val = value;

    return literal;
}
/*
 * Build an AST_IDENTIFIER leaf.
 *
 * The node stores its own heap copy of `name`, so the caller keeps
 * ownership of the string it passed in. `name` must not be NULL.
 * Running out of memory while copying is fatal (diagnostic + exit),
 * rather than leaving a node with a NULL name behind.
 */
AstNode* new_identifier(const char* name) {
    char *copy = strdup(name);
    if (copy == NULL) {
        perror("new_identifier");
        exit(EXIT_FAILURE);
    }

    AstNode *node = new_ast_node(AST_IDENTIFIER);
    node->data.identifier_val = copy;
    return node;
}
/*
 * Release a whole tree. Operator nodes free both operands first,
 * identifier nodes free their name string; finally the node itself is
 * freed. Passing NULL is allowed and does nothing.
 */
void free_ast(AstNode* node) {
    if (node == NULL) {
        return;
    }

    const AstType kind = node->type;

    if (kind == AST_IDENTIFIER) {
        free(node->data.identifier_val);
    } else if (kind == AST_ADD || kind == AST_SUB ||
               kind == AST_MUL || kind == AST_DIV) {
        free_ast(node->data.binary_op.left);
        free_ast(node->data.binary_op.right);
    }
    /* numbers (and any other type) own nothing besides the node itself */

    free(node);
}
/*
 * Print the tree rooted at `node` to stdout, one node per line.
 * Each line is prefixed with `indent` spaces; the operands of an
 * operator are printed one level deeper. NULL prints nothing.
 */
void print_ast(AstNode* node, int indent) {
    const char *op_line = NULL;   /* label line for the four operators */
    int level = 0;

    if (node == NULL) {
        return;
    }

    while (level < indent) {
        printf(" ");
        level++;
    }

    switch (node->type) {
    case AST_NUMBER:
        printf("NUMBER: %d\n", node->data.number_val);
        return;
    case AST_IDENTIFIER:
        printf("IDENTIFIER: %s\n", node->data.identifier_val);
        return;
    case AST_ADD:
        op_line = "ADD\n";
        break;
    case AST_SUB:
        op_line = "SUB\n";
        break;
    case AST_MUL:
        op_line = "MUL\n";
        break;
    case AST_DIV:
        op_line = "DIV\n";
        break;
    default:
        printf("UNKNOWN\n");
        return;
    }

    /* All operators share the same layout: label, then both operands. */
    fputs(op_line, stdout);
    print_ast(node->data.binary_op.left, indent + 1);
    print_ast(node->data.binary_op.right, indent + 1);
}
/*
 * Evaluate an expression tree with integer arithmetic.
 *
 * Returns the value of the expression. A NULL node, an identifier (there
 * is no symbol table here) and any unknown node type evaluate to 0.
 * Division by zero is reported on stderr and yields 0 instead of
 * triggering undefined behaviour. Arithmetic overflow is not checked.
 */
int evaluate_ast(AstNode* node) {
    if (!node) {
        return 0;
    }

    switch (node->type) {
    case AST_NUMBER:
        return node->data.number_val;
    case AST_IDENTIFIER:
        /* simplified handling: identifiers have no value yet */
        return 0;
    case AST_ADD:
    case AST_SUB:
    case AST_MUL:
    case AST_DIV:
        break;
    default:
        return 0;
    }

    /* Binary operator: evaluate the left operand, then the right one. */
    int lhs = evaluate_ast(node->data.binary_op.left);
    int rhs = evaluate_ast(node->data.binary_op.right);

    switch (node->type) {
    case AST_ADD:
        return lhs + rhs;
    case AST_SUB:
        return lhs - rhs;
    case AST_MUL:
        return lhs * rhs;
    default: /* AST_DIV */
        if (rhs == 0) {
            fprintf(stderr, "错误: 除数为零\n");
            return 0;
        }
        return lhs / rhs;
    }
}
Yacc 文件:
/* parser.y */
/* Expression parser: builds an AST for each input line, prints it,
 * evaluates it and releases it again. */
%{
#include <stdio.h>
#include <stdlib.h>
#include "ast.h"
int yylex(void);
void yyerror(const char* s);
%}
%union {
    int int_val;
    char* string_val;
    AstNode* ast_val;
}
%token <int_val> NUMBER
%token <string_val> IDENTIFIER
%token SEMI
%token EOL
%type <ast_val> expr
/* Without these declarations the `expr OP expr` rules are ambiguous.
 * Both levels are left-associative; '*' and '/' bind tighter than
 * '+' and '-' because they are declared later. */
%left '+' '-'
%left '*' '/'
%%
calc: /* empty rule */
    | calc expr EOL {
          printf("AST:\n");
          print_ast($2, 0);
          printf("计算结果: %d\n", evaluate_ast($2));
          free_ast($2);
      }
    | calc EOL
    ;
expr: NUMBER { $$ = new_number($1); }
    | IDENTIFIER {
          /* new_identifier() stores its own copy of the name, so the
           * string allocated by the lexer must be released here. */
          $$ = new_identifier($1);
          free($1);
      }
    | expr '+' expr { $$ = new_binary_op(AST_ADD, $1, $3); }
    | expr '-' expr { $$ = new_binary_op(AST_SUB, $1, $3); }
    | expr '*' expr { $$ = new_binary_op(AST_MUL, $1, $3); }
    | expr '/' expr { $$ = new_binary_op(AST_DIV, $1, $3); }
    | '(' expr ')' { $$ = $2; }
    ;
%%
/* Called by the generated parser on a syntax error. */
void yyerror(const char* s) {
    fprintf(stderr, "错误: %s\n", s);
}
int main() {
    printf("表达式 AST 构建器\n");
    printf("输入表达式,例如: 5 + 3 * (4 - 2)\n");
    return yyparse();
}
Lex 文件:
/* lexer.l */
%{
/* ast.h must come before y.tab.h: the YYSTYPE union declared there has an
 * AstNode* member. stdlib.h declares atoi(). */
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include "ast.h"
#include "y.tab.h"
%}
%%
[0-9]+ { yylval.int_val = atoi(yytext); return NUMBER; }
[a-zA-Z_][a-zA-Z0-9_]* { /* heap copy; the parser releases it */ yylval.string_val = strdup(yytext); return IDENTIFIER; }
"+" { return '+'; }
"-" { return '-'; }
"*" { return '*'; }
"/" { return '/'; }
"(" { return '('; }
")" { return ')'; }
"\n" { return EOL; }
[ \t] { /* skip blanks */ }
. { /* any other character is reported and skipped */ fprintf(stderr, "忽略无法识别的字符: %s\n", yytext); }
%%
int yywrap() {
    return 1;
}
案例:完整语句的 AST 构建
让我们扩展上面的例子,支持变量声明和赋值语句:
扩展的 AST 定义:
/* ast.h */
#ifndef AST_H
#define AST_H

/* Node categories of the statement-level AST. */
typedef enum {
    AST_ADD,
    AST_SUB,
    AST_MUL,
    AST_DIV,
    AST_NUMBER,
    AST_IDENTIFIER,
    AST_ASSIGN,
    AST_DECLARATION,
    AST_PROGRAM,
    AST_STMT_LIST
} AstType;

typedef struct ast_node AstNode;

/* A node is a type tag plus the payload that belongs to that tag. */
struct ast_node {
    AstType type;
    union {
        int   number_val;              /* AST_NUMBER */
        char *identifier_val;          /* AST_IDENTIFIER */
        struct {                       /* AST_ADD, AST_SUB, AST_MUL, AST_DIV */
            struct ast_node *left;
            struct ast_node *right;
        } binary_op;
        struct {                       /* AST_ASSIGN */
            char            *identifier;
            struct ast_node *expression;
        } assign;
        struct {                       /* AST_DECLARATION */
            char *type;
            char *identifier;
        } declaration;
        struct {                       /* AST_STMT_LIST */
            struct ast_node *head;
            struct ast_node *tail;
        } list;
    } data;
};

/* Allocate a node carrying the given type tag. */
AstNode *new_ast_node(AstType type);

/* Create an operator node with the given operands as children. */
AstNode *new_binary_op(AstType type, AstNode *left, AstNode *right);

/* Create a numeric-literal leaf. */
AstNode *new_number(int value);

/* Create an identifier leaf; the node stores its own copy of `name`. */
AstNode *new_identifier(const char *name);

/* Create an assignment node; `identifier` is copied, `expression` is adopted. */
AstNode *new_assign(const char *identifier, AstNode *expression);

/* Create a declaration node; both strings are copied. */
AstNode *new_declaration(const char *type, const char *identifier);

/* Create an AST_PROGRAM node. */
AstNode *new_program();

/* Create an AST_STMT_LIST node with the given head and tail. */
AstNode *new_stmt_list(AstNode *head, AstNode *tail);

/* Add `stmt` to the statement list `list`. */
void add_stmt_to_list(AstNode *list, AstNode *stmt);

/* Recursively release a tree; NULL is accepted. */
void free_ast(AstNode *node);

/* Print a tree, one node per line, nested according to `indent`. */
void print_ast(AstNode *node, int indent);

#endif
/* ast.c */
#include "ast.h"
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
/*
 * Allocate a node with the given type tag.
 *
 * The payload is zero-filled so that no union member is ever read
 * uninitialised (new_program() sets no payload of its own). Allocation
 * failure is fatal: callers in this file dereference the result without
 * checking, so the process exits with a diagnostic instead of returning NULL.
 *
 * Ownership: the caller owns the returned node and releases it with free_ast().
 */
AstNode* new_ast_node(AstType type) {
    AstNode *node = calloc(1, sizeof *node);
    if (node == NULL) {
        perror("new_ast_node");
        exit(EXIT_FAILURE);
    }
    node->type = type;
    return node;
}
/*
 * Build an operator node of the given type whose children are `left` and
 * `right`. The new node takes the operand pointers as given (no copy).
 */
AstNode* new_binary_op(AstType type, AstNode* left, AstNode* right) {
    AstNode *op = new_ast_node(type);

    op->data.binary_op.right = right;
    op->data.binary_op.left = left;

    return op;
}
/* Build an AST_NUMBER leaf holding `value`. */
AstNode* new_number(int value) {
    AstNode *literal = new_ast_node(AST_NUMBER);

    literal->data.number_val = value;

    return literal;
}
/*
 * Build an AST_IDENTIFIER leaf.
 *
 * The node stores its own heap copy of `name`, so the caller keeps
 * ownership of the string it passed in. `name` must not be NULL.
 * Running out of memory while copying is fatal (diagnostic + exit),
 * rather than leaving a node with a NULL name behind.
 */
AstNode* new_identifier(const char* name) {
    char *copy = strdup(name);
    if (copy == NULL) {
        perror("new_identifier");
        exit(EXIT_FAILURE);
    }

    AstNode *node = new_ast_node(AST_IDENTIFIER);
    node->data.identifier_val = copy;
    return node;
}
/*
 * Build an AST_ASSIGN node for `identifier = expression`.
 *
 * The target name is copied (the caller keeps its string); the
 * expression subtree is adopted as-is and later released by free_ast().
 * `identifier` must not be NULL. Running out of memory while copying the
 * name is fatal (diagnostic + exit).
 */
AstNode* new_assign(const char* identifier, AstNode* expression) {
    char *target = strdup(identifier);
    if (target == NULL) {
        perror("new_assign");
        exit(EXIT_FAILURE);
    }

    AstNode *node = new_ast_node(AST_ASSIGN);
    node->data.assign.identifier = target;
    node->data.assign.expression = expression;
    return node;
}
/*
 * Build an AST_DECLARATION node for `type identifier;`.
 *
 * Both strings are copied, so the caller keeps ownership of its
 * arguments. Neither argument may be NULL. Running out of memory while
 * copying is fatal (diagnostic + exit).
 */
AstNode* new_declaration(const char* type, const char* identifier) {
    char *type_copy = strdup(type);
    char *name_copy = strdup(identifier);
    if (type_copy == NULL || name_copy == NULL) {
        perror("new_declaration");
        exit(EXIT_FAILURE);
    }

    AstNode *node = new_ast_node(AST_DECLARATION);
    node->data.declaration.type = type_copy;
    node->data.declaration.identifier = name_copy;
    return node;
}
AstNode* new_program() {
return new_ast_node(AST_PROGRAM);
}
/* Build an AST_STMT_LIST node whose payload is the given head and tail. */
AstNode* new_stmt_list(AstNode* head, AstNode* tail) {
    AstNode *cell = new_ast_node(AST_STMT_LIST);

    cell->data.list.tail = tail;
    cell->data.list.head = head;

    return cell;
}
/*
 * Append `stmt` to the statement list `list`.
 *
 * The list is kept as a chain of AST_STMT_LIST cells: `head` holds one
 * statement, `tail` points to the next cell (NULL at the end). This
 * keeps every statement reachable exactly once, so print_ast() shows all
 * of them and free_ast() releases each one a single time.
 *
 * Precondition: `list` is either empty (head and tail both NULL, e.g.
 * from new_stmt_list(NULL, NULL)) or was filled only through this
 * function. A NULL `list` or `stmt` is ignored.
 */
void add_stmt_to_list(AstNode* list, AstNode* stmt) {
    if (list == NULL || stmt == NULL) {
        return;
    }

    /* Empty list: the first cell itself takes the statement. */
    if (list->data.list.head == NULL) {
        list->data.list.head = stmt;
        return;
    }

    /* Otherwise walk to the last cell and hang a new cell behind it. */
    AstNode *cell = list;
    while (cell->data.list.tail != NULL) {
        cell = cell->data.list.tail;
    }
    cell->data.list.tail = new_stmt_list(stmt, NULL);
}
/*
 * Release a whole tree. Depending on the node type this first frees the
 * operand subtrees, the owned strings, or the head and tail of a
 * statement list; finally the node itself is freed. Passing NULL is
 * allowed and does nothing.
 */
void free_ast(AstNode* node) {
    if (node == NULL) {
        return;
    }

    const AstType kind = node->type;

    if (kind == AST_ADD || kind == AST_SUB ||
        kind == AST_MUL || kind == AST_DIV) {
        free_ast(node->data.binary_op.left);
        free_ast(node->data.binary_op.right);
    } else if (kind == AST_IDENTIFIER) {
        free(node->data.identifier_val);
    } else if (kind == AST_ASSIGN) {
        free(node->data.assign.identifier);
        free_ast(node->data.assign.expression);
    } else if (kind == AST_DECLARATION) {
        free(node->data.declaration.type);
        free(node->data.declaration.identifier);
    } else if (kind == AST_STMT_LIST) {
        free_ast(node->data.list.head);
        free_ast(node->data.list.tail);
    }
    /* every other type owns nothing besides the node itself */

    free(node);
}
void print_ast(AstNode* node, int indent) {
if (!node) return;
for (int i = 0; i < indent; i++) {
printf(" ");
}
switch (node->type) {
case AST_ADD:
printf("ADD\n");
print_ast(node->data.binary_op.left, indent + 1);
print_ast(node->data.binary_op.right, indent + 1);
break;
case AST_SUB:
printf("SUB\n");
print_ast(node->data.binary_op.left, indent + 1);
print_ast(node->data.binary_op.right, indent + 1);
break;
case AST_MUL:
printf("MUL\n");
print_ast(node->data.binary_op.left, indent + 1);
print_ast(node->data.binary_op.right, indent + 1);
break;
case AST_DIV:
printf("DIV\n");
print_ast(node->data.binary_op.left, indent + 1);
print_ast(node->data.binary_op.right, indent + 1);
break;
case AST_NUMBER:
printf("NUMBER: %d\n", node->data.number_val);
break;
case AST_IDENTIFIER:
printf("IDENTIFIER: %s\n", node->data.identifier_val);
break;
case AST_ASSIGN:
printf("ASSIGN: %s\n", node->data.assign.identifier);
print_ast(node->data.assign.expression, indent + 1);
break;
case AST_DECLARATION:
printf("DECLARATION: %s %s\n", node->data.declaration.type, node->data.declaration.identifier);
break;
case AST_PROGRAM:
printf("PROGRAM\n");
break;
case AST_STMT_LIST:
printf("STMT_LIST\n");
print_ast(node->data.list.head, indent + 1);
print_ast(node->data.list.tail, indent + 1);
break;
default:
printf("UNKNOWN\n");
break;
}
}
#endif扩展的 Yacc 文件:
/* parser.y */
/* Statement parser: collects declarations, assignments and expression
 * statements into one tree, prints it at end of input and releases it. */
%{
#include <stdio.h>
#include <stdlib.h>
#include "ast.h"
int yylex(void);
void yyerror(const char* s);
%}
%union {
    int int_val;
    char* string_val;
    AstNode* ast_val;
}
%token <int_val> NUMBER
%token <string_val> IDENTIFIER
%token <string_val> TYPE
%token ASSIGN
%token SEMI
%token EOL
%type <ast_val> program stmt expr declaration assignment
/* Without these declarations the `expr OP expr` rules are ambiguous.
 * Both levels are left-associative; '*' and '/' bind tighter than
 * '+' and '-' because they are declared later. */
%left '+' '-'
%left '*' '/'
%start input
%%
/* Whole input: print the finished tree, then release it. */
input: program {
           print_ast($1, 0);
           free_ast($1);
       }
     ;
/* Each statement wraps the tree built so far: the new STMT_LIST node has
 * the previous tree as head and the new statement as tail. The chain
 * starts at the PROGRAM node, and head and tail are always distinct. */
program: /* empty rule */ { $$ = new_program(); }
       | program stmt { $$ = new_stmt_list($1, $2); }
       | program EOL { /* line breaks between statements carry no meaning */ $$ = $1; }
       ;
stmt: declaration { $$ = $1; }
    | assignment { $$ = $1; }
    | expr SEMI { $$ = $1; }
    ;
declaration: TYPE IDENTIFIER SEMI {
                 $$ = new_declaration($1, $2);
                 printf("声明变量: %s %s\n", $1, $2);
                 /* the node holds its own copies; release the lexer's strings */
                 free($1);
                 free($2);
             }
           ;
assignment: IDENTIFIER ASSIGN expr SEMI {
                $$ = new_assign($1, $3);
                printf("赋值: %s = <expression>\n", $1);
                /* the node holds its own copy; release the lexer's string */
                free($1);
            }
          ;
expr: NUMBER { $$ = new_number($1); }
    | IDENTIFIER {
          /* the node holds its own copy; release the lexer's string */
          $$ = new_identifier($1);
          free($1);
      }
    | expr '+' expr { $$ = new_binary_op(AST_ADD, $1, $3); }
    | expr '-' expr { $$ = new_binary_op(AST_SUB, $1, $3); }
    | expr '*' expr { $$ = new_binary_op(AST_MUL, $1, $3); }
    | expr '/' expr { $$ = new_binary_op(AST_DIV, $1, $3); }
    | '(' expr ')' { $$ = $2; }
    ;
%%
/* Called by the generated parser on a syntax error. */
void yyerror(const char* s) {
    fprintf(stderr, "错误: %s\n", s);
}
int main() {
    printf("语句 AST 构建器\n");
    printf("支持变量声明和赋值,例如: int x; x = 5 + 3;\n");
    return yyparse();
}
扩展的 Lex 文件:
/* lexer.l */
%{
/* ast.h must come before y.tab.h: the YYSTYPE union declared there has an
 * AstNode* member. stdlib.h declares atoi(). The type keywords are listed
 * before the identifier pattern so that they win on equal-length matches. */
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include "ast.h"
#include "y.tab.h"
%}
%%
[0-9]+ { yylval.int_val = atoi(yytext); return NUMBER; }
"int" { yylval.string_val = strdup(yytext); return TYPE; }
"double" { yylval.string_val = strdup(yytext); return TYPE; }
"float" { yylval.string_val = strdup(yytext); return TYPE; }
"char" { yylval.string_val = strdup(yytext); return TYPE; }
[a-zA-Z_][a-zA-Z0-9_]* { yylval.string_val = strdup(yytext); return IDENTIFIER; }
"=" { return ASSIGN; }
";" { return SEMI; }
"+" { return '+'; }
"-" { return '-'; }
"*" { return '*'; }
"/" { return '/'; }
"(" { return '('; }
")" { return ')'; }
"\n" { return EOL; }
[ \t] { /* skip blanks */ }
. { /* any other character is reported and skipped */ fprintf(stderr, "忽略无法识别的字符: %s\n", yytext); }
%%
int yywrap() {
    return 1;
}
代码优化建议
内存管理优化:
- 使用内存池分配 AST 节点,减少内存分配开销
- 实现引用计数,避免内存泄漏
- 对于字符串常量,使用字符串池避免重复分配
AST 遍历优化:
- 使用迭代遍历代替递归遍历,避免栈溢出
- 实现惰性遍历,按需处理节点
- 缓存遍历结果,避免重复计算
节点结构优化:
- 使用标签联合(tagged union)或多态设计,提高代码可读性
- 为节点添加行号和列号信息,便于错误定位
- 考虑使用不可变节点,简化某些优化操作
构建过程优化:
- 使用工厂方法创建节点,统一节点初始化逻辑
- 实现节点复用,减少内存使用
- 考虑增量构建,支持大型程序的处理
类型系统集成:
- 为 AST 节点添加类型信息,便于类型检查
- 实现类型推导,在构建过程中推断表达式类型
- 集成符号表,支持变量查找和作用域管理
总结
本集我们深入学习了抽象语法树(AST)的构造和使用,包括:
- AST 的基本概念和优点
- AST 节点的设计和实现
- 在 Yacc 语义动作中构建 AST
- AST 的遍历技术(前序、后序遍历)
- 实际案例:表达式和完整语句的 AST 构建
- AST 相关的代码优化建议
AST 是编译器前端的核心数据结构,它为后续的语义分析、中间代码生成和优化提供了结构化的程序表示。通过掌握 AST 的构建和处理技术,你已经具备了构建编译器前端的重要基础。在后续的课程中,我们将学习如何基于 AST 进行语义分析和中间代码生成。