Appearance
Phase 04: Better Errors
Compilerbook step 4: keep the same accepted language as the tokenizer phase, but reject invalid input with an error message that points at its location in the source.
Source
armcc.h
h
#ifndef ARMCC_H
#define ARMCC_H
#include <stdbool.h>
#include <stddef.h>
// Compilerbook step implemented by this build. The .c files compare it in
// `#if PHASE_LEVEL ...` to enable a feature or reject it with an error.
#define PHASE_LEVEL 4
// Kinds of tokens produced by tokenize().
typedef enum {
TK_RESERVED, // punctuator, or a keyword once keywords exist (phase 11+)
TK_IDENT,    // identifier
TK_NUM,      // integer literal (value in Token.val)
TK_EOF,      // end-of-input marker that terminates the token list
} TokenKind;
// One token in the singly linked list built by tokenize().
typedef struct Token Token;
struct Token {
TokenKind kind;
Token *next; // following token; the list ends with a TK_EOF token
long val;    // numeric value of a TK_NUM token
char *loc;   // start of the token's text inside the input buffer
int len;     // length of the token's text in bytes
};
// A local variable (function parameters are locals too) of the function
// being compiled.
typedef struct Obj Obj;
struct Obj {
Obj *next;       // next local in Function.locals
Obj *param_next; // next parameter in Function.params
char *name;      // NUL-terminated copy of the identifier
int offset;      // the variable lives at x29 - offset
};
// Kinds of AST nodes built by the parser.
typedef enum {
ND_ADD,       // lhs + rhs
ND_SUB,       // lhs - rhs
ND_MUL,       // lhs * rhs
ND_DIV,       // lhs / rhs
ND_NEG,       // -lhs
ND_EQ,        // lhs == rhs
ND_NE,        // lhs != rhs
ND_LT,        // lhs < rhs  (">" is parsed with the operands swapped)
ND_LE,        // lhs <= rhs (">=" is parsed with the operands swapped)
ND_ASSIGN,    // lhs = rhs; lhs must be an ND_LVAR
ND_LVAR,      // reference to local variable `var`
ND_NUM,       // integer literal `val`
ND_RETURN,    // return lhs
ND_EXPR_STMT, // expression statement lhs
ND_BLOCK,     // statement list `body`
ND_IF,        // if (cond) then else els
ND_FOR,       // for (init; cond; inc) then -- "while" also uses this kind
ND_FUNCALL,   // call funcname(args...)
} NodeKind;
// AST node. Which fields are meaningful depends on `kind`.
typedef struct Node Node;
struct Node {
NodeKind kind;
Node *next;     // next statement of a block, or next argument of a call
Node *lhs;      // left operand; also the operand of ND_NEG/ND_RETURN/ND_EXPR_STMT
Node *rhs;      // right operand of a binary node
Node *body;     // statements of an ND_BLOCK
Node *cond;     // condition of ND_IF/ND_FOR (a NULL ND_FOR cond loops forever)
Node *then;     // taken branch of ND_IF, loop body of ND_FOR
Node *els;      // optional else branch of ND_IF
Node *init;     // optional initialiser expression of ND_FOR
Node *inc;      // optional step expression of ND_FOR
Node *args;     // arguments of ND_FUNCALL, linked through `next`
Obj *var;       // variable referenced by ND_LVAR
char *funcname; // callee name of ND_FUNCALL
long val;       // value of ND_NUM
};
// One function definition plus the stack frame layout computed by the parser.
typedef struct Function Function;
struct Function {
Function *next;  // next function of the program
char *name;      // symbol name (emitted with a leading '_')
Obj *params;     // parameters, linked through Obj.param_next
Node *body;      // ND_BLOCK (before phase 9: a single ND_RETURN)
Obj *locals;     // every local, parameters included, linked through Obj.next
int stack_size;  // bytes reserved below x29; a multiple of 16
};
extern char *current_input;
extern Token *token;
void error(char *fmt, ...);
void error_at(char *loc, char *fmt, ...);
bool equal(Token *tok, char *op);
Token *skip(Token *tok, char *op);
char *strndup2(char *p, int len);
Token *tokenize(char *input);
Function *parse(Token *tok);
void codegen(Function *prog);
#endifmain.c
c
#include "armcc.h"
#include <errno.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
/*
 * Read the whole file at `path` into a freshly allocated, NUL-terminated
 * buffer whose text always ends with '\n'.
 *
 * Returns NULL when the file cannot be opened; main() then treats the
 * argument itself as program text. Any later I/O or allocation failure is
 * fatal (reported through error()). The caller owns the buffer and frees it.
 */
static char *read_file(char *path) {
  FILE *fp = fopen(path, "r");
  if (!fp) {
    return NULL;
  }
  if (fseek(fp, 0, SEEK_END) == -1) {
    error("cannot seek %s: %s", path, strerror(errno));
  }
  long size = ftell(fp);
  if (size == -1) {
    error("cannot tell size of %s: %s", path, strerror(errno));
  }
  rewind(fp);
  // Two spare bytes: one for an appended '\n' and one for the final NUL.
  char *buf = calloc((size_t)size + 2, 1);
  if (!buf) {
    error("out of memory reading %s", path);
  }
  // Trust the byte count fread() reports, not ftell(). Fewer bytes may
  // arrive (e.g. the file shrank, or a read error occurred).
  size_t len = fread(buf, 1, (size_t)size, fp);
  if (ferror(fp)) {
    error("cannot read %s: %s", path, strerror(errno));
  }
  fclose(fp);
  // Guarantee a trailing newline; len <= size, so both writes stay in bounds.
  if (len == 0 || buf[len - 1] != '\n') {
    buf[len++] = '\n';
  }
  buf[len] = '\0';
  return buf;
}
/*
 * Entry point: `armcc <source-file-or-program>`.
 *
 * If argv[1] names a readable file, its contents are compiled. Otherwise the
 * argument itself is the program text. Assembly is written to stdout by
 * codegen().
 */
int main(int argc, char **argv) {
  if (argc == 2) {
    char *contents = read_file(argv[1]);
    char *source = contents != NULL ? contents : argv[1];
    token = tokenize(source);
    codegen(parse(token));
    free(contents);
    return 0;
  }
  fprintf(stderr, "usage: %s <source-file-or-program>\n", argv[0]);
  return 1;
}
tokenize.c
c
#include "armcc.h"
#include <ctype.h>
#include <stdarg.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
// Start of the source text being compiled. tokenize() sets it; error_at()
// uses it as the lower bound when searching back for the start of a line.
char *current_input;
// Head of the token list for the current input (assigned in main()).
Token *token;
// Print a printf-style message followed by a newline to stderr, then exit(1).
// Used for errors that have no position in the source text.
void error(char *fmt, ...) {
  va_list ap;
  va_start(ap, fmt);
  vfprintf(stderr, fmt, ap);
  // Every va_start must be paired with va_end in the same function.
  va_end(ap);
  fprintf(stderr, "\n");
  exit(1);
}
// Report an error at `loc`, a pointer into current_input, then exit(1).
//
// Prints the whole source line containing `loc`. The next line places a
// caret under the offending column, followed by the printf-style message:
//
//   5+foo
//     ^ <message>
void error_at(char *loc, char *fmt, ...) {
  // The EOF token points at the terminating NUL. If the input ends in
  // whitespace (read_file() always appends '\n'), that position lies on an
  // empty line after the program. Step back to just past the last
  // non-space character so the caret lands on real source text.
  if (*loc == '\0') {
    while (current_input < loc && isspace((unsigned char)loc[-1])) {
      loc--;
    }
  }

  char *line = loc;
  while (current_input < line && line[-1] != '\n') {
    line--;
  }
  char *end = loc;
  while (*end && *end != '\n') {
    end++;
  }
  fprintf(stderr, "%.*s\n", (int)(end - line), line);

  // Copy tabs from the source line into the padding so the caret still
  // lines up under the right column. Everything else becomes a space.
  for (char *c = line; c < loc; c++) {
    fputc(*c == '\t' ? '\t' : ' ', stderr);
  }
  fprintf(stderr, "^ ");

  va_list ap;
  va_start(ap, fmt);
  vfprintf(stderr, fmt, ap);
  va_end(ap);
  fprintf(stderr, "\n");
  exit(1);
}
// True when string `p` begins with the NUL-terminated prefix `q`.
static bool startswith(char *p, char *q) {
  size_t prefix_len = strlen(q);
  return !strncmp(p, q, prefix_len);
}
// True for characters that may start an identifier: letters and '_'.
static bool is_ident1(char c) {
  return c == '_' || isalpha((unsigned char)c) != 0;
}
// True for characters that may continue an identifier: letters, digits, '_'.
// (isalnum is defined as isalpha || isdigit, so this matches it exactly.)
static bool is_ident2(char c) {
  return is_ident1(c) || isdigit((unsigned char)c) != 0;
}
// True if the identifier token `tok` is spelled like a keyword of the current
// phase: "return" from phase 11, and if/else/while/for from phase 12.
// Earlier phases have no keywords.
static bool is_keyword(Token *tok) {
#if PHASE_LEVEL >= 11
  static char *const keywords[] = {
      "return",
#if PHASE_LEVEL >= 12
      "if", "else", "while", "for",
#endif
  };
  size_t count = sizeof keywords / sizeof keywords[0];
  for (size_t k = 0; k < count; k++) {
    size_t klen = strlen(keywords[k]);
    if ((size_t)tok->len == klen && strncmp(tok->loc, keywords[k], klen) == 0) {
      return true;
    }
  }
  return false;
#else
  (void)tok;
  return false;
#endif
}
// Allocate a token of `kind` spanning the source bytes [start, end). All
// other fields (next, val) start out zero.
static Token *new_token(TokenKind kind, char *start, char *end) {
  Token *t = calloc(1, sizeof *t);
  *t = (Token){.kind = kind, .loc = start, .len = (int)(end - start)};
  return t;
}
// True if `tok` is a reserved token (punctuator or keyword) spelled exactly
// like `op`.
bool equal(Token *tok, char *op) {
  if (tok->kind != TK_RESERVED) {
    return false;
  }
  size_t n = strlen(op);
  return n == (size_t)tok->len && strncmp(tok->loc, op, n) == 0;
}
// Require `tok` to be the reserved token `op` and return the token after it.
// Otherwise reports "expected '<op>'" at `tok` through error_at(), which exits.
Token *skip(Token *tok, char *op) {
  if (equal(tok, op)) {
    return tok->next;
  }
  error_at(tok->loc, "expected '%s'", op);
  return tok->next;
}
// Return a newly allocated, NUL-terminated copy of at most `len` bytes of `p`.
// Copying stops early at a NUL in `p`, like POSIX strndup(). A negative `len`
// yields an empty string instead of a huge strncpy. The caller owns the result.
char *strndup2(char *p, int len) {
  size_t cap = len > 0 ? (size_t)len : 0;
  char *s = calloc(cap + 1, 1);
  if (!s) {
    fprintf(stderr, "out of memory\n");
    exit(1);
  }
  // calloc zero-filled the buffer, so the copy is always NUL-terminated.
  for (size_t i = 0; i < cap && p[i] != '\0'; i++) {
    s[i] = p[i];
  }
  return s;
}
// Split `input` into a linked list of tokens ending with a TK_EOF token.
// Records `input` in current_input so error_at() can show source lines.
// Any character that starts no valid token is reported as "invalid token".
Token *tokenize(char *input) {
  current_input = input;
  Token head = {0};
  Token *tail = &head;
  char *cur = input;

  while (*cur != '\0') {
    // Whitespace only separates tokens.
    if (isspace((unsigned char)*cur)) {
      cur++;
      continue;
    }

    // Two-character comparison operators must be matched before their
    // one-character prefixes ('=', '!', '<', '>').
    if (startswith(cur, "==") || startswith(cur, "!=") ||
        startswith(cur, "<=") || startswith(cur, ">=")) {
      tail->next = new_token(TK_RESERVED, cur, cur + 2);
      tail = tail->next;
      cur += 2;
      continue;
    }

    // Single-character punctuators.
    if (strchr("+-*/(){}<>=;!,", *cur) != NULL) {
      tail->next = new_token(TK_RESERVED, cur, cur + 1);
      tail = tail->next;
      cur++;
      continue;
    }

    // Decimal integer literal; strtol advances `cur` past the digits.
    if (isdigit((unsigned char)*cur)) {
      char *digits = cur;
      tail->next = new_token(TK_NUM, digits, digits);
      tail = tail->next;
      tail->val = strtol(digits, &cur, 10);
      tail->len = (int)(cur - digits);
      continue;
    }

    // Identifier, reclassified as reserved when it spells a keyword.
    if (is_ident1(*cur)) {
      char *word = cur;
      do {
        cur++;
      } while (is_ident2(*cur));
      tail->next = new_token(TK_IDENT, word, cur);
      tail = tail->next;
      if (is_keyword(tail)) {
        tail->kind = TK_RESERVED;
      }
      continue;
    }

    error_at(cur, "invalid token");
  }

  tail->next = new_token(TK_EOF, cur, cur);
  return head.next;
}
parse.c
c
#include "armcc.h"
#include <stdlib.h>
#include <string.h>
// Function whose body is being parsed. new_lvar() and find_var() create and
// look up local variables in its `locals` list.
static Function *current_fn;
// Forward declarations: the grammar is mutually recursive (primary() calls
// expr(), compound_stmt() calls stmt()).
static Node *expr(Token **rest, Token *tok);
static Node *stmt(Token **rest, Token *tok);
// Allocate a zero-filled AST node of the given kind.
static Node *new_node(NodeKind kind) {
  Node *n = calloc(1, sizeof *n);
  n->kind = kind;
  return n;
}
// Build a binary node of `kind` over the operands `lhs` and `rhs`.
static Node *new_binary(NodeKind kind, Node *lhs, Node *rhs) {
  Node *bin = new_node(kind);
  bin->rhs = rhs;
  bin->lhs = lhs;
  return bin;
}
// Build a node of `kind` with a single operand, stored in `lhs`.
static Node *new_unary(NodeKind kind, Node *operand) {
  Node *un = new_node(kind);
  un->lhs = operand;
  return un;
}
// Build an integer-literal node holding `val`.
static Node *new_num(long val) {
  Node *lit = new_node(ND_NUM);
  lit->val = val;
  return lit;
}
// If `tok` is the reserved token `str`, step past it and return true.
// Otherwise keep the position and return false. Either way the resulting
// position is stored in *rest.
static bool consume(Token **rest, Token *tok, char *str) {
  bool matched = equal(tok, str);
  *rest = matched ? tok->next : tok;
  return matched;
}
// Find a local of current_fn whose name matches the spelling of `tok`.
// Returns NULL if there is none.
static Obj *find_var(Token *tok) {
  Obj *v = current_fn->locals;
  while (v != NULL) {
    if ((size_t)tok->len == strlen(v->name) &&
        strncmp(tok->loc, v->name, tok->len) == 0) {
      return v;
    }
    v = v->next;
  }
  return NULL;
}
// Return the local variable named by `tok`, creating it on first use.
// New locals are pushed onto the front of current_fn->locals.
// Before phase 9 variables are rejected outright; before phase 10 only
// one-letter names are allowed.
static Obj *new_lvar(Token *tok) {
#if PHASE_LEVEL < 9
  error_at(tok->loc, "local variables are not supported in this phase");
#endif
#if PHASE_LEVEL < 10
  if (tok->len != 1) {
    error_at(tok->loc,
             "multi-letter local variables are not supported in this phase");
  }
#endif
  Obj *existing = find_var(tok);
  if (existing != NULL) {
    return existing;
  }
  Obj *fresh = calloc(1, sizeof *fresh);
  fresh->name = strndup2(tok->loc, tok->len);
  fresh->next = current_fn->locals;
  current_fn->locals = fresh;
  return fresh;
}
// primary = "(" expr ")"                                  (phase 5+)
//         | ident "(" (expr ("," expr)*)? ")"            (phase 14+)
//         | ident
//         | num
//
// A construct that belongs to a later phase is rejected with an error at the
// offending token.
static Node *primary(Token **rest, Token *tok) {
  if (equal(tok, "(")) {
#if PHASE_LEVEL < 5
    // Parentheses are gated together with '*' and '/' (see mul()), so that
    // earlier phases accept only the tokenizer-phase language.
    error_at(tok->loc, "parentheses are not supported in this phase");
#endif
    Node *node = expr(&tok, tok->next);
    *rest = skip(tok, ")");
    return node;
  }
  if (tok->kind == TK_IDENT) {
#if PHASE_LEVEL >= 14
    if (equal(tok->next, "(")) {
      Node *node = new_node(ND_FUNCALL);
      node->funcname = strndup2(tok->loc, tok->len);
      tok = tok->next->next;
      Node head = {0};
      Node *cur = &head;
      if (!equal(tok, ")")) {
        do {
          cur = cur->next = expr(&tok, tok);
        } while (consume(&tok, tok, ","));
      }
      node->args = head.next;
      *rest = skip(tok, ")");
      return node;
    }
#else
    if (equal(tok->next, "(")) {
      error_at(tok->loc, "function calls are not supported in this phase");
    }
#endif
    Node *node = new_node(ND_LVAR);
    node->var = new_lvar(tok);
    *rest = tok->next;
    return node;
  }
  if (tok->kind == TK_NUM) {
    Node *node = new_num(tok->val);
    *rest = tok->next;
    return node;
  }
  error_at(tok->loc, "expected expression");
  return NULL;
}
// unary = ("+" | "-") unary | primary
// The unary operators exist from phase 6. Unary plus is dropped and minus
// becomes ND_NEG.
static Node *unary(Token **rest, Token *tok) {
#if PHASE_LEVEL >= 6
  if (equal(tok, "+")) {
    return unary(rest, tok->next);
  }
  if (equal(tok, "-")) {
    Node *operand = unary(rest, tok->next);
    return new_unary(ND_NEG, operand);
  }
#endif
  return primary(rest, tok);
}
// mul = unary ("*" unary | "/" unary)*     (operators from phase 5)
// Left-associative. Before phase 5 this is just unary.
static Node *mul(Token **rest, Token *tok) {
  Node *node = unary(&tok, tok);
#if PHASE_LEVEL >= 5
  while (equal(tok, "*") || equal(tok, "/")) {
    NodeKind kind = equal(tok, "*") ? ND_MUL : ND_DIV;
    node = new_binary(kind, node, unary(&tok, tok->next));
  }
#endif
  *rest = tok;
  return node;
}
// add = mul ("+" mul | "-" mul)*     (operators from phase 2)
// Left-associative.
static Node *add(Token **rest, Token *tok) {
  Node *node = mul(&tok, tok);
#if PHASE_LEVEL >= 2
  while (equal(tok, "+") || equal(tok, "-")) {
    NodeKind kind = equal(tok, "+") ? ND_ADD : ND_SUB;
    node = new_binary(kind, node, mul(&tok, tok->next));
  }
#endif
  *rest = tok;
  return node;
}
// relational = add ("<" add | "<=" add | ">" add | ">=" add)*   (phase 7+)
// ">" and ">=" are built as "<" and "<=" with the operands swapped.
static Node *relational(Token **rest, Token *tok) {
  Node *node = add(&tok, tok);
#if PHASE_LEVEL >= 7
  for (;;) {
    if (equal(tok, "<")) {
      node = new_binary(ND_LT, node, add(&tok, tok->next));
    } else if (equal(tok, "<=")) {
      node = new_binary(ND_LE, node, add(&tok, tok->next));
    } else if (equal(tok, ">")) {
      node = new_binary(ND_LT, add(&tok, tok->next), node);
    } else if (equal(tok, ">=")) {
      node = new_binary(ND_LE, add(&tok, tok->next), node);
    } else {
      break;
    }
  }
#endif
  *rest = tok;
  return node;
}
// equality = relational ("==" relational | "!=" relational)*   (phase 7+)
static Node *equality(Token **rest, Token *tok) {
  Node *node = relational(&tok, tok);
#if PHASE_LEVEL >= 7
  while (equal(tok, "==") || equal(tok, "!=")) {
    NodeKind kind = equal(tok, "==") ? ND_EQ : ND_NE;
    node = new_binary(kind, node, relational(&tok, tok->next));
  }
#endif
  *rest = tok;
  return node;
}
// assign = equality ("=" assign)?   (phase 9+; right-associative)
static Node *assign(Token **rest, Token *tok) {
  Node *target = equality(&tok, tok);
#if PHASE_LEVEL >= 9
  if (equal(tok, "=")) {
    Node *value = assign(&tok, tok->next);
    target = new_binary(ND_ASSIGN, target, value);
  }
#endif
  *rest = tok;
  return target;
}
// expr = assign
static Node *expr(Token **rest, Token *tok) {
  Node *node = assign(rest, tok);
  return node;
}
// compound-stmt = stmt* "}"
//
// Called with `tok` just past the opening "{". Returns an ND_BLOCK whose body
// lists the statements, and stores the token after "}" in *rest.
static Node *compound_stmt(Token **rest, Token *tok) {
  Node *node = new_node(ND_BLOCK);
  Node head = {0};
  Node *cur = &head;
  while (!equal(tok, "}")) {
    // An unterminated block would otherwise hand the EOF token to stmt(),
    // which reports a less helpful error (e.g. "expected expression").
    if (tok->kind == TK_EOF) {
      error_at(tok->loc, "expected '}'");
    }
    cur = cur->next = stmt(&tok, tok);
  }
  node->body = head.next;
  *rest = tok->next;
  return node;
}
// stmt = "return" expr ";"                               (phase 11+)
//      | "if" "(" expr ")" stmt ("else" stmt)?           (phase 12+)
//      | "while" "(" expr ")" stmt                       (phase 12+)
//      | "for" "(" expr? ";" expr? ";" expr? ")" stmt    (phase 12+)
//      | "{" compound-stmt                               (phase 13+)
//      | expr ";"                                        (phase 9+)
//
// Before phase 9 there are no statements, so any call reports an error.
static Node *stmt(Token **rest, Token *tok) {
#if PHASE_LEVEL >= 11
  if (equal(tok, "return")) {
    Node *value = expr(&tok, tok->next);
    *rest = skip(tok, ";");
    return new_unary(ND_RETURN, value);
  }
#endif
#if PHASE_LEVEL >= 12
  if (equal(tok, "if")) {
    Node *branch = new_node(ND_IF);
    tok = skip(tok->next, "(");
    branch->cond = expr(&tok, tok);
    tok = skip(tok, ")");
    branch->then = stmt(&tok, tok);
    if (equal(tok, "else")) {
      branch->els = stmt(&tok, tok->next);
    }
    *rest = tok;
    return branch;
  }
  // "while" is lowered to an ND_FOR node that has only a condition.
  bool is_while = equal(tok, "while");
  if (is_while || equal(tok, "for")) {
    Node *loop = new_node(ND_FOR);
    tok = skip(tok->next, "(");
    if (is_while) {
      loop->cond = expr(&tok, tok);
    } else {
      if (!equal(tok, ";")) {
        loop->init = expr(&tok, tok);
      }
      tok = skip(tok, ";");
      if (!equal(tok, ";")) {
        loop->cond = expr(&tok, tok);
      }
      tok = skip(tok, ";");
      if (!equal(tok, ")")) {
        loop->inc = expr(&tok, tok);
      }
    }
    tok = skip(tok, ")");
    loop->then = stmt(rest, tok);
    return loop;
  }
#endif
#if PHASE_LEVEL >= 13
  if (equal(tok, "{")) {
    return compound_stmt(rest, tok->next);
  }
#endif
#if PHASE_LEVEL >= 9
  Node *e = expr(&tok, tok);
  *rest = skip(tok, ";");
  return new_unary(ND_EXPR_STMT, e);
#else
  error_at(tok->loc, "statements are not supported in this phase");
  return NULL;
#endif
}
// param = ident
//
// Parse one parameter name and register it as a local of current_fn.
// Parameters exist from phase 15.
static Obj *param(Token **rest, Token *tok) {
#if PHASE_LEVEL < 15
  error_at(tok->loc, "function parameters are not supported in this phase");
#endif
  if (tok->kind != TK_IDENT) {
    error_at(tok->loc, "expected parameter name");
  }
  // new_lvar() returns the existing Obj for a repeated name. function()
  // would then link that same Obj into the param_next chain twice, making
  // the chain cyclic. Only earlier parameters are locals at this point, so
  // any hit here is a duplicate.
  if (find_var(tok)) {
    error_at(tok->loc, "duplicate parameter name");
  }
  Obj *var = new_lvar(tok);
  *rest = tok->next;
  return var;
}
// function = ident "(" (param ("," param)*)? ")" "{" compound-stmt
//
// Parse one function definition. Each local (parameters included) then gets
// an 8-byte slot below x29, and the frame is rounded up to a multiple of 16.
static Function *function(Token **rest, Token *tok) {
  Function *fn = calloc(1, sizeof *fn);
  current_fn = fn;

  if (tok->kind != TK_IDENT) {
    error_at(tok->loc, "expected function name");
  }
  fn->name = strndup2(tok->loc, tok->len);
  tok = skip(tok->next, "(");

  Obj params = {0};
  Obj *last = &params;
  if (!equal(tok, ")")) {
#if PHASE_LEVEL < 15
    error_at(tok->loc, "function parameters are not supported in this phase");
#endif
    for (;;) {
      last->param_next = param(&tok, tok);
      last = last->param_next;
      if (!consume(&tok, tok, ",")) {
        break;
      }
    }
  }
  fn->params = params.param_next;
  tok = skip(skip(tok, ")"), "{");
  // new_lvar() records locals directly in fn->locals (current_fn == fn).
  fn->body = compound_stmt(&tok, tok);

  int frame = 0;
  for (Obj *v = fn->locals; v; v = v->next) {
    frame += 8;
    v->offset = frame;
  }
  fn->stack_size = (frame + 15) / 16 * 16;
  *rest = tok;
  return fn;
}
// program = expr EOF          (before phase 9: body of an implicit main)
//         | stmt* EOF         (phases 9-14: body of an implicit main)
//         | function* EOF     (phase 15+)
//
// Returns the list of functions to compile. Before phase 15 it is a single
// function named "main". Its locals get 8-byte slots below x29, and the frame
// is rounded up to a multiple of 16 bytes.
Function *parse(Token *tok) {
#if PHASE_LEVEL < 15
  Function *fn = calloc(1, sizeof(Function));
  current_fn = fn;
  fn->name = "main";
  if (PHASE_LEVEL < 9) {
    fn->body = new_unary(ND_RETURN, expr(&tok, tok));
    if (tok->kind != TK_EOF) {
      error_at(tok->loc, "extra token");
    }
  } else {
    Node *block = new_node(ND_BLOCK);
    Node body = {0};
    Node *body_cur = &body;
    while (tok->kind != TK_EOF) {
      body_cur = body_cur->next = stmt(&tok, tok);
    }
    block->body = body.next;
    fn->body = block;
  }
  int offset = 0;
  for (Obj *var = fn->locals; var; var = var->next) {
    offset += 8;
    var->offset = offset;
  }
  fn->stack_size = (offset + 15) / 16 * 16;
  return fn;
#else
  // The function list head/tail are only needed in this branch. Declaring
  // them at the top of the function left them unused in earlier phases.
  Function head = {0};
  Function *cur = &head;
  while (tok->kind != TK_EOF) {
    cur = cur->next = function(&tok, tok);
  }
  return head.next;
#endif
}
codegen.c
c
#include "armcc.h"
#include <stdio.h>
// Number of 16-byte slots currently pushed by push()/pop(). It is only
// incremented and decremented; nothing in this file reads it.
static int depth;
// Counter that makes the .L.else/.L.end/.L.begin labels unique.
static int labelseq;
// Registers that carry integer arguments, in order (used for calls and for
// spilling incoming parameters).
static char *argreg[] = {"x0", "x1", "x2", "x3", "x4", "x5", "x6", "x7"};
// Name of the function being emitted; forms its .L.return.<name> label.
static char *current_fn_name;
// gen_expr() and gen_stmt() are mutually recursive.
static void gen_expr(Node *node);
static void gen_stmt(Node *node);
// Spill x0 onto the machine stack. Each slot is 16 bytes, which keeps sp
// 16-byte aligned.
static void push(void) {
  depth++;
  fputs(" sub sp, sp, #16\n str x0, [sp]\n", stdout);
}
// Reload the most recently pushed value into register `reg` and release its
// 16-byte slot.
static void pop(char *reg) {
  depth--;
  printf(" ldr %s, [sp]\n add sp, sp, #16\n", reg);
}
// Put the address of a local variable (x29 minus its offset) into x0.
// Only ND_LVAR nodes are addressable; anything else is "not an lvalue".
static void gen_addr(Node *node) {
  if (node->kind == ND_LVAR) {
    printf(" sub x0, x29, #%d\n", node->var->offset);
    return;
  }
  error("not an lvalue");
}
// Load the constant `val` into x0.
//
// "mov x0, #imm" only assembles when the immediate is encodable. Any 16-bit
// value is, but e.g. 70000 (0x11170) is not. So build the value from 16-bit
// chunks: a mov of the low chunk, then a movk for every other non-zero
// chunk. Values 0..65535 produce exactly the single mov emitted before.
static void gen_num(long val) {
  unsigned long long bits = (unsigned long long)val;
  printf(" mov x0, #%llu\n", bits & 0xffffULL);
  for (int shift = 16; shift < 64; shift += 16) {
    unsigned long long chunk = (bits >> shift) & 0xffffULL;
    if (chunk != 0) {
      printf(" movk x0, #%llu, lsl #%d\n", chunk, shift);
    }
  }
}

// Emit code that evaluates expression `node`, leaving the result in x0.
// Binary operators evaluate rhs first and park it on the stack, then
// evaluate lhs, so x0 = lhs and x1 = rhs when the operator is applied.
static void gen_expr(Node *node) {
  switch (node->kind) {
  case ND_NUM:
    gen_num(node->val);
    return;
  case ND_NEG:
    gen_expr(node->lhs);
    printf(" neg x0, x0\n");
    return;
  case ND_LVAR:
    gen_addr(node);
    printf(" ldr x0, [x0]\n");
    return;
  case ND_ASSIGN:
    gen_addr(node->lhs);
    push();
    gen_expr(node->rhs);
    pop("x1");
    printf(" str x0, [x1]\n");
    return;
  case ND_FUNCALL: {
    int nargs = 0;
    for (Node *arg = node->args; arg; arg = arg->next) {
      gen_expr(arg);
      push();
      nargs++;
    }
    if (nargs > 8) {
      error("function call with more than 8 arguments is not supported yet");
    }
    // Pop in reverse so argument i ends up in argreg[i].
    for (int i = nargs - 1; i >= 0; i--) {
      pop(argreg[i]);
    }
    printf(" bl _%s\n", node->funcname);
    return;
  }
  default:
    break;
  }
  gen_expr(node->rhs);
  push();
  gen_expr(node->lhs);
  pop("x1");
  switch (node->kind) {
  case ND_ADD:
    printf(" add x0, x0, x1\n");
    return;
  case ND_SUB:
    printf(" sub x0, x0, x1\n");
    return;
  case ND_MUL:
    printf(" mul x0, x0, x1\n");
    return;
  case ND_DIV:
    printf(" sdiv x0, x0, x1\n");
    return;
  case ND_EQ:
    printf(" cmp x0, x1\n");
    printf(" cset x0, eq\n");
    return;
  case ND_NE:
    printf(" cmp x0, x1\n");
    printf(" cset x0, ne\n");
    return;
  case ND_LT:
    printf(" cmp x0, x1\n");
    printf(" cset x0, lt\n");
    return;
  case ND_LE:
    printf(" cmp x0, x1\n");
    printf(" cset x0, le\n");
    return;
  default:
    error("invalid expression");
  }
}
// Emit code for one statement. Control flow uses local labels numbered from
// labelseq. A return branches to the function's .L.return.<name> epilogue.
static void gen_stmt(Node *node) {
  int seq;
  switch (node->kind) {
  case ND_EXPR_STMT:
    gen_expr(node->lhs);
    break;
  case ND_RETURN:
    gen_expr(node->lhs);
    printf(" b .L.return.%s\n", current_fn_name);
    break;
  case ND_BLOCK: {
    Node *s = node->body;
    while (s != NULL) {
      gen_stmt(s);
      s = s->next;
    }
    break;
  }
  case ND_IF:
    seq = labelseq++;
    gen_expr(node->cond);
    // A zero condition jumps to the else label (the else part may be empty).
    printf(" cmp x0, #0\n b.eq .L.else.%d\n", seq);
    gen_stmt(node->then);
    printf(" b .L.end.%d\n.L.else.%d:\n", seq, seq);
    if (node->els != NULL) {
      gen_stmt(node->els);
    }
    printf(".L.end.%d:\n", seq);
    break;
  case ND_FOR:
    seq = labelseq++;
    if (node->init != NULL) {
      gen_expr(node->init);
    }
    printf(".L.begin.%d:\n", seq);
    // Without a condition the loop only exits through a return.
    if (node->cond != NULL) {
      gen_expr(node->cond);
      printf(" cmp x0, #0\n b.eq .L.end.%d\n", seq);
    }
    gen_stmt(node->then);
    if (node->inc != NULL) {
      gen_expr(node->inc);
    }
    printf(" b .L.begin.%d\n.L.end.%d:\n", seq, seq);
    break;
  default:
    error("invalid statement");
  }
}
// Spill the incoming argument registers (argreg[0..]) into the stack slots
// the parser gave the function's parameters, in parameter order.
static void assign_param_offsets(Function *fn) {
  int i = 0;
  for (Obj *var = fn->params; var; var = var->param_next) {
    // Only argreg[0..7] carry arguments; going further would index past the
    // array. This mirrors the 8-argument limit gen_expr() enforces for calls.
    if (i >= (int)(sizeof(argreg) / sizeof(*argreg))) {
      error("function with more than 8 parameters is not supported yet");
    }
    // Form the address with "sub", as gen_addr() does. The "[x29, #-N]"
    // addressing form only encodes N <= 256, which a frame with more than
    // 32 locals exceeds. x9 is a scratch register not otherwise used here.
    printf(" sub x9, x29, #%d\n", var->offset);
    printf(" str %s, [x9]\n", argreg[i++]);
  }
}
// Emit assembly for every function in `prog` to stdout. Each function gets a
// global '_'-prefixed symbol (presumably Mach-O naming -- confirm target), a
// frame-pointer prologue, its body, and a shared return epilogue.
void codegen(Function *prog) {
  printf(".text\n");
  Function *fn = prog;
  while (fn != NULL) {
    current_fn_name = fn->name;
    // Prologue: save x29/x30, point x29 at the frame, reserve local slots.
    printf(".globl _%s\n_%s:\n", fn->name, fn->name);
    printf(" stp x29, x30, [sp, #-16]!\n mov x29, sp\n");
    if (fn->stack_size != 0) {
      printf(" sub sp, sp, #%d\n", fn->stack_size);
    }
    assign_param_offsets(fn);
    gen_stmt(fn->body);
    // Epilogue: every return statement branches to this label.
    printf(".L.return.%s:\n mov sp, x29\n ldp x29, x30, [sp], #16\n ret\n",
           fn->name);
    fn = fn->next;
  }
}
test.sh
sh
#!/bin/sh
set -eu
. ../test-lib.sh
# Accepted: the tokenizer-phase language (integers with + and -, whitespace).
try 21 '5+20-4'
try 41 ' 12 + 34 - 5 '
# Rejected: each input must fail with an error that points into the source.
reject '5+foo'   # identifiers are not supported yet
reject '1*2'     # '*' belongs to a later phase
reject '5+'      # missing operand at end of input
reject '1 2'     # extra token after the expression
reject '1 $ 2'   # character the tokenizer does not recognise
echo OK