Skip to content

Phase 15: Function Definitions

Compilerbook step 15. Add function definitions and parameters.

Source

armcc.h

h
#ifndef ARMCC_H
#define ARMCC_H

#include <stdbool.h>
#include <stddef.h>

// Compilerbook step implemented by this build. The sources compare against
// this value in `#if PHASE_LEVEL ...` checks to enable or reject features.
#define PHASE_LEVEL 15

// Categories of lexical tokens produced by tokenize().
typedef enum {
  TK_RESERVED, // punctuators, and identifiers recognised as keywords
  TK_IDENT,    // identifiers
  TK_NUM,      // integer literals
  TK_EOF,      // end-of-input marker appended after the last real token
} TokenKind;

// One token of the singly linked list returned by tokenize().
typedef struct Token Token;
struct Token {
  TokenKind kind; // token category
  Token *next;    // following token in the list
  long val;       // numeric value; filled in only for TK_NUM
  char *loc;      // start of the token text inside the input buffer
  int len;        // length of the token text in bytes
};

// A local variable or parameter belonging to one Function.
typedef struct Obj Obj;
struct Obj {
  Obj *next;       // next entry in the owning Function's `locals` list
  Obj *param_next; // next parameter in the Function's `params` list
  char *name;      // NUL-terminated copy of the identifier
  int offset;      // distance in bytes below the frame pointer (x29)
};

// Kinds of AST nodes built by the parser and consumed by the code generator.
typedef enum {
  ND_ADD,       // lhs + rhs
  ND_SUB,       // lhs - rhs
  ND_MUL,       // lhs * rhs
  ND_DIV,       // lhs / rhs
  ND_NEG,       // unary minus applied to lhs
  ND_EQ,        // lhs == rhs
  ND_NE,        // lhs != rhs
  ND_LT,        // lhs < rhs (the parser also encodes `>` with swapped operands)
  ND_LE,        // lhs <= rhs (the parser also encodes `>=` with swapped operands)
  ND_ASSIGN,    // lhs = rhs
  ND_LVAR,      // reference to a local variable
  ND_NUM,       // integer constant
  ND_RETURN,    // "return" statement; the value is in lhs
  ND_EXPR_STMT, // expression statement; the expression is in lhs
  ND_BLOCK,     // { ... }; statements are chained from body
  ND_IF,        // "if" statement
  ND_FOR,       // "for" statement; the parser also uses it for "while"
  ND_FUNCALL,   // function call
} NodeKind;

// AST node. Which fields are meaningful depends on `kind`.
typedef struct Node Node;
struct Node {
  NodeKind kind;  // node type
  Node *next;     // next statement in a block, or next argument in a call
  Node *lhs;      // left operand; sole operand of unary/return/expr-stmt nodes
  Node *rhs;      // right operand of binary nodes
  Node *body;     // ND_BLOCK: first statement of the block
  Node *cond;     // ND_IF / ND_FOR: condition expression (may be NULL in "for")
  Node *then;     // ND_IF: taken branch; ND_FOR: loop body
  Node *els;      // ND_IF: optional "else" branch
  Node *init;     // ND_FOR: optional initialisation expression
  Node *inc;      // ND_FOR: optional increment expression
  Node *args;     // ND_FUNCALL: first argument, chained through `next`
  Obj *var;       // ND_LVAR: the referenced variable
  char *funcname; // ND_FUNCALL: name of the callee
  long val;       // ND_NUM: the constant value
};

// A parsed function definition.
typedef struct Function Function;
struct Function {
  Function *next; // next function in the program
  char *name;     // function name
  Obj *params;    // parameters, chained through Obj.param_next
  Node *body;     // ND_BLOCK holding the function body
  Obj *locals;    // every local (parameters included), chained through Obj.next
  int stack_size; // bytes reserved for locals, rounded up to a multiple of 16
};

// Source text being compiled; set by tokenize() and read by error_at().
extern char *current_input;
// Token list of the program; assigned in main().
extern Token *token;

// Report a printf-style message on stderr and exit with status 1.
void error(char *fmt, ...);
// Like error(), but first prints the source line containing `loc` with a
// caret under that position.
void error_at(char *loc, char *fmt, ...);
// True if `tok` is a TK_RESERVED token whose text is exactly `op`.
bool equal(Token *tok, char *op);
// Return the token after `tok`, or report an error if `tok` is not `op`.
Token *skip(Token *tok, char *op);
// Return a newly allocated, NUL-terminated copy of at most `len` bytes of `p`.
char *strndup2(char *p, int len);
// Split `input` into a linked list of tokens terminated by a TK_EOF token.
Token *tokenize(char *input);
// Parse a token list into a list of functions.
Function *parse(Token *tok);
// Print AArch64 assembly for every function in `prog` to stdout.
void codegen(Function *prog);

#endif

main.c

c
#include "armcc.h"

#include <errno.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

// Read the whole file at `path` into a freshly allocated, NUL-terminated
// buffer that always ends with a newline.
//
// Returns NULL if the file cannot be opened (the caller then treats its
// argument as program text). Any other failure is reported through error(),
// which terminates the process. The caller owns the returned buffer.
static char *read_file(char *path) {
  FILE *fp = fopen(path, "r");
  if (!fp) {
    return NULL;
  }

  if (fseek(fp, 0, SEEK_END) == -1) {
    error("cannot seek %s: %s", path, strerror(errno));
  }
  long size = ftell(fp);
  if (size == -1) {
    error("cannot tell size of %s: %s", path, strerror(errno));
  }
  rewind(fp);

  // Two extra bytes: one for a newline we may append, one for the NUL.
  char *buf = calloc((size_t)size + 2, 1);
  if (!buf) {
    error("cannot allocate %ld bytes for %s", size + 2, path);
  }

  // fread may legitimately return fewer bytes than ftell reported (for
  // example with text-mode newline translation), so use the real count
  // from here on instead of `size`.
  size_t nread = fread(buf, 1, (size_t)size, fp);
  if (ferror(fp)) {
    error("cannot read %s: %s", path, strerror(errno));
  }
  fclose(fp);

  if (nread == 0 || buf[nread - 1] != '\n') {
    buf[nread++] = '\n';
  }
  buf[nread] = '\0';
  return buf;
}

// Entry point. The single argument is either a path to a source file or,
// when no such file can be opened, the program text itself. The generated
// assembly is written to stdout.
int main(int argc, char **argv) {
  if (argc == 2) {
    char *file = read_file(argv[1]);

    // Fall back to the raw argument when it does not name a readable file.
    char *input = file;
    if (!input) {
      input = argv[1];
    }

    token = tokenize(input);
    codegen(parse(token));
    free(file);
    return 0;
  }

  fprintf(stderr, "usage: %s <source-file-or-program>\n", argv[0]);
  return 1;
}

tokenize.c

c
#include "armcc.h"

#include <ctype.h>
#include <stdarg.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

// Start of the source text being tokenized; error_at() uses it to find the
// beginning of the offending line.
char *current_input;
// Token list of the program; assigned by main().
Token *token;

// Print a printf-style message followed by a newline on stderr, then exit
// with status 1. This function does not return.
void error(char *fmt, ...) {
  va_list ap;
  va_start(ap, fmt);
  vfprintf(stderr, fmt, ap);
  va_end(ap); // every va_start must be paired with va_end
  fprintf(stderr, "\n");
  exit(1);
}

// Report an error at position `loc` inside current_input: print the source
// line containing `loc`, a caret under that position, and the printf-style
// message, then exit with status 1. This function does not return.
void error_at(char *loc, char *fmt, ...) {
  // Walk back to the first character of the line containing `loc`.
  char *line = loc;
  while (current_input < line && line[-1] != '\n') {
    line--;
  }

  // Walk forward to the end of that line (newline or end of input).
  char *end = loc;
  while (*end && *end != '\n') {
    end++;
  }

  int indent = (int)(loc - line);
  fprintf(stderr, "%.*s\n", (int)(end - line), line);
  fprintf(stderr, "%*s", indent, ""); // pad so the caret lines up with `loc`
  fprintf(stderr, "^ ");

  va_list ap;
  va_start(ap, fmt);
  vfprintf(stderr, fmt, ap);
  va_end(ap); // every va_start must be paired with va_end
  fprintf(stderr, "\n");
  exit(1);
}

// True if the string `p` begins with the string `q`.
static bool startswith(char *p, char *q) {
  size_t prefix_len = strlen(q);
  int diff = strncmp(p, q, prefix_len);
  return diff == 0;
}

// True if `c` may start an identifier: a letter or an underscore.
static bool is_ident1(char c) {
  if (c == '_') {
    return true;
  }
  return isalpha((unsigned char)c) != 0;
}

// True if `c` may continue an identifier: a letter, a digit or an underscore.
static bool is_ident2(char c) {
  if (c == '_') {
    return true;
  }
  return isalnum((unsigned char)c) != 0;
}

// True if the text of `tok` is one of the keywords enabled for this
// PHASE_LEVEL. Before phase 11 there are no keywords at all.
static bool is_keyword(Token *tok) {
#if PHASE_LEVEL < 11
  (void)tok;
  return false;
#else
  static char *keywords[] = {
      "return",
#if PHASE_LEVEL >= 12
      "if",     "else", "while", "for",
#endif
  };
  int count = (int)(sizeof(keywords) / sizeof(keywords[0]));

  for (int k = 0; k < count; k++) {
    char *word = keywords[k];
    // Lengths must match first so that a keyword prefix is not accepted.
    if ((int)strlen(word) != tok->len) {
      continue;
    }
    if (strncmp(tok->loc, word, tok->len) == 0) {
      return true;
    }
  }
  return false;
#endif
}

// Allocate a zero-initialised token of the given kind covering the input
// bytes [start, end).
static Token *new_token(TokenKind kind, char *start, char *end) {
  Token *fresh = calloc(1, sizeof *fresh);
  fresh->len = (int)(end - start);
  fresh->loc = start;
  fresh->kind = kind;
  return fresh;
}

// True if `tok` is a TK_RESERVED token whose text is exactly `op`.
bool equal(Token *tok, char *op) {
  if (tok->kind != TK_RESERVED) {
    return false;
  }
  if (strlen(op) != (size_t)tok->len) {
    return false;
  }
  return strncmp(tok->loc, op, tok->len) == 0;
}

// Require `tok` to be the reserved token `op` and return the token after it.
// A mismatch is reported through error_at(), which terminates the process.
Token *skip(Token *tok, char *op) {
  bool matched = equal(tok, op);
  if (!matched) {
    error_at(tok->loc, "expected '%s'", op);
  }
  return tok->next;
}

// Return a newly allocated, NUL-terminated copy of at most `len` bytes of
// `p`. Copying stops early at a NUL byte in `p`. The caller owns the result.
char *strndup2(char *p, int len) {
  size_t limit = (size_t)len;
  char *copy = calloc(limit + 1, 1);

  // The buffer is already zero-filled, so only the payload needs copying.
  for (size_t i = 0; i < limit && p[i] != '\0'; i++) {
    copy[i] = p[i];
  }
  return copy;
}

// Split `input` into a linked list of tokens and return its first element.
// The list always ends with a TK_EOF token. `input` is remembered in
// current_input so that error_at() can print the offending line. A character
// that starts no token is reported through error_at().
Token *tokenize(char *input) {
  current_input = input;

  Token head = {0};
  Token *tail = &head;
  char *p = input;

  while (*p) {
    unsigned char ch = (unsigned char)*p;

    // Whitespace separates tokens but produces none.
    if (isspace(ch)) {
      p++;
      continue;
    }

    // Two-character operators must be tried before the one-character ones.
    bool two_char = startswith(p, "==") || startswith(p, "!=") ||
                    startswith(p, "<=") || startswith(p, ">=");
    if (two_char) {
      tail->next = new_token(TK_RESERVED, p, p + 2);
      tail = tail->next;
      p += 2;
      continue;
    }

    if (strchr("+-*/(){}<>=;!,", *p)) {
      tail->next = new_token(TK_RESERVED, p, p + 1);
      tail = tail->next;
      p++;
      continue;
    }

    // Integer literal: strtol both converts it and advances `p` past it.
    if (isdigit(ch)) {
      char *begin = p;
      Token *num = new_token(TK_NUM, begin, begin);
      num->val = strtol(begin, &p, 10);
      num->len = (int)(p - begin);
      tail->next = num;
      tail = num;
      continue;
    }

    // Identifier, reclassified as TK_RESERVED when it spells a keyword.
    if (is_ident1(*p)) {
      char *begin = p;
      do {
        p++;
      } while (is_ident2(*p));

      Token *ident = new_token(TK_IDENT, begin, p);
      if (is_keyword(ident)) {
        ident->kind = TK_RESERVED;
      }
      tail->next = ident;
      tail = ident;
      continue;
    }

    error_at(p, "invalid token");
  }

  tail->next = new_token(TK_EOF, p, p);
  return head.next;
}

parse.c

c
#include "armcc.h"

#include <stdlib.h>
#include <string.h>

// Function currently being parsed; find_var() and new_lvar() operate on its
// `locals` list.
static Function *current_fn;

// Forward declarations: these are called by parsing functions defined
// before them in this file.
static Node *expr(Token **rest, Token *tok);
static Node *stmt(Token **rest, Token *tok);

// Allocate a zero-initialised AST node of the given kind.
static Node *new_node(NodeKind kind) {
  Node *fresh = calloc(1, sizeof *fresh);
  fresh->kind = kind;
  return fresh;
}

// Create a node of `kind` with the two given operands.
static Node *new_binary(NodeKind kind, Node *lhs, Node *rhs) {
  Node *parent = new_node(kind);
  parent->rhs = rhs;
  parent->lhs = lhs;
  return parent;
}

// Create a node of `kind` whose single operand is stored in `lhs`.
static Node *new_unary(NodeKind kind, Node *expr) {
  Node *parent = new_node(kind);
  parent->lhs = expr;
  return parent;
}

// Create an ND_NUM node holding the constant `val`.
static Node *new_num(long val) {
  Node *literal = new_node(ND_NUM);
  literal->val = val;
  return literal;
}

// If `tok` is the reserved token `str`, store the following token in `*rest`
// and return true; otherwise store `tok` itself in `*rest` and return false.
static bool consume(Token **rest, Token *tok, char *str) {
  bool matched = equal(tok, str);
  *rest = matched ? tok->next : tok;
  return matched;
}

// Look up a local of the current function whose name equals the text of
// `tok`. Returns NULL if there is none.
static Obj *find_var(Token *tok) {
  size_t wanted = (size_t)tok->len;
  Obj *var = current_fn->locals;

  while (var) {
    bool same_len = strlen(var->name) == wanted;
    if (same_len && strncmp(tok->loc, var->name, tok->len) == 0) {
      return var;
    }
    var = var->next;
  }
  return NULL;
}

// Return the local variable named by `tok`, creating it and prepending it to
// the current function's `locals` list if it does not exist yet. Earlier
// phases reject locals entirely (< 9) or multi-letter names (< 10).
static Obj *new_lvar(Token *tok) {
#if PHASE_LEVEL < 9
  error_at(tok->loc, "local variables are not supported in this phase");
#endif
#if PHASE_LEVEL < 10
  if (tok->len != 1) {
    error_at(tok->loc,
             "multi-letter local variables are not supported in this phase");
  }
#endif
  Obj *existing = find_var(tok);
  if (existing) {
    return existing;
  }

  Obj *fresh = calloc(1, sizeof *fresh);
  fresh->name = strndup2(tok->loc, tok->len);
  fresh->next = current_fn->locals;
  current_fn->locals = fresh;
  return fresh;
}

// primary = num | "(" expr ")" | ident "(" (expr ("," expr)*)? ")" | ident
//
// Parses one primary expression starting at `tok` and stores the token that
// follows it in `*rest`. Function calls are accepted from phase 14 on.
static Node *primary(Token **rest, Token *tok) {
  if (tok->kind == TK_NUM) {
    Node *literal = new_num(tok->val);
    *rest = tok->next;
    return literal;
  }

  if (equal(tok, "(")) {
    Node *inner = expr(&tok, tok->next);
    *rest = skip(tok, ")");
    return inner;
  }

  if (tok->kind != TK_IDENT) {
    error_at(tok->loc, "expected expression");
    return NULL;
  }

#if PHASE_LEVEL >= 14
  if (equal(tok->next, "(")) {
    Node *call = new_node(ND_FUNCALL);
    call->funcname = strndup2(tok->loc, tok->len);
    tok = tok->next->next;

    // Collect the comma-separated arguments in source order.
    Node head = {0};
    Node *tail = &head;
    if (!equal(tok, ")")) {
      for (;;) {
        tail->next = expr(&tok, tok);
        tail = tail->next;
        if (!consume(&tok, tok, ",")) {
          break;
        }
      }
    }
    call->args = head.next;
    *rest = skip(tok, ")");
    return call;
  }
#else
  if (equal(tok->next, "(")) {
    error_at(tok->loc, "function calls are not supported in this phase");
  }
#endif

  // A bare identifier names a local variable, created on first use.
  Node *ref = new_node(ND_LVAR);
  ref->var = new_lvar(tok);
  *rest = tok->next;
  return ref;
}

// unary = ("+" | "-") unary | primary
//
// Unary plus yields its operand unchanged; unary minus wraps it in ND_NEG.
// Both operators are enabled from phase 6 on.
static Node *unary(Token **rest, Token *tok) {
#if PHASE_LEVEL < 6
  return primary(rest, tok);
#else
  if (equal(tok, "+")) {
    return unary(rest, tok->next);
  }
  if (equal(tok, "-")) {
    Node *operand = unary(rest, tok->next);
    return new_unary(ND_NEG, operand);
  }
  return primary(rest, tok);
#endif
}

// mul = unary (("*" | "/") unary)*
//
// Builds a left-associative chain; the operators are enabled from phase 5 on.
static Node *mul(Token **rest, Token *tok) {
  Node *result = unary(&tok, tok);

#if PHASE_LEVEL >= 5
  for (;;) {
    NodeKind kind;
    if (equal(tok, "*")) {
      kind = ND_MUL;
    } else if (equal(tok, "/")) {
      kind = ND_DIV;
    } else {
      break;
    }
    Node *operand = unary(&tok, tok->next);
    result = new_binary(kind, result, operand);
  }
#endif

  *rest = tok;
  return result;
}

// add = mul (("+" | "-") mul)*
//
// Builds a left-associative chain; the operators are enabled from phase 2 on.
static Node *add(Token **rest, Token *tok) {
  Node *result = mul(&tok, tok);

#if PHASE_LEVEL >= 2
  for (;;) {
    NodeKind kind;
    if (equal(tok, "+")) {
      kind = ND_ADD;
    } else if (equal(tok, "-")) {
      kind = ND_SUB;
    } else {
      break;
    }
    Node *operand = mul(&tok, tok->next);
    result = new_binary(kind, result, operand);
  }
#endif

  *rest = tok;
  return result;
}

// relational = add (("<" | "<=" | ">" | ">=") add)*
//
// `a > b` and `a >= b` are represented as `b < a` and `b <= a`, so only
// ND_LT and ND_LE nodes are produced. Enabled from phase 7 on.
static Node *relational(Token **rest, Token *tok) {
  Node *result = add(&tok, tok);

#if PHASE_LEVEL >= 7
  for (;;) {
    NodeKind kind;
    bool swapped;
    if (equal(tok, "<")) {
      kind = ND_LT;
      swapped = false;
    } else if (equal(tok, "<=")) {
      kind = ND_LE;
      swapped = false;
    } else if (equal(tok, ">")) {
      kind = ND_LT;
      swapped = true;
    } else if (equal(tok, ">=")) {
      kind = ND_LE;
      swapped = true;
    } else {
      break;
    }

    Node *operand = add(&tok, tok->next);
    if (swapped) {
      result = new_binary(kind, operand, result);
    } else {
      result = new_binary(kind, result, operand);
    }
  }
#endif

  *rest = tok;
  return result;
}

// equality = relational (("==" | "!=") relational)*
//
// Builds a left-associative chain; the operators are enabled from phase 7 on.
static Node *equality(Token **rest, Token *tok) {
  Node *result = relational(&tok, tok);

#if PHASE_LEVEL >= 7
  for (;;) {
    NodeKind kind;
    if (equal(tok, "==")) {
      kind = ND_EQ;
    } else if (equal(tok, "!=")) {
      kind = ND_NE;
    } else {
      break;
    }
    Node *operand = relational(&tok, tok->next);
    result = new_binary(kind, result, operand);
  }
#endif

  *rest = tok;
  return result;
}

// assign = equality ("=" assign)?
//
// Assignment is right-associative, hence the recursive call for the value.
// Enabled from phase 9 on.
static Node *assign(Token **rest, Token *tok) {
  Node *target = equality(&tok, tok);
#if PHASE_LEVEL >= 9
  if (equal(tok, "=")) {
    Node *value = assign(&tok, tok->next);
    target = new_binary(ND_ASSIGN, target, value);
  }
#endif
  *rest = tok;
  return target;
}

// expr = assign
static Node *expr(Token **rest, Token *tok) {
  Node *node = assign(rest, tok);
  return node;
}

// compound_stmt = stmt* "}"
//
// `tok` must point just past the opening "{". Returns an ND_BLOCK whose
// `body` chains the parsed statements and stores the token after the closing
// "}" in `*rest`.
static Node *compound_stmt(Token **rest, Token *tok) {
  Node head = {0};
  Node *tail = &head;

  for (; !equal(tok, "}"); tail = tail->next) {
    tail->next = stmt(&tok, tok);
  }

  Node *block = new_node(ND_BLOCK);
  block->body = head.next;
  *rest = tok->next;
  return block;
}

// stmt = "return" expr ";"
//      | "if" "(" expr ")" stmt ("else" stmt)?
//      | "while" "(" expr ")" stmt
//      | "for" "(" expr? ";" expr? ";" expr? ")" stmt
//      | "{" compound_stmt
//      | expr ";"
//
// Which alternatives are available depends on PHASE_LEVEL. The token after
// the statement is stored in `*rest`.
static Node *stmt(Token **rest, Token *tok) {
#if PHASE_LEVEL >= 11
  if (equal(tok, "return")) {
    Node *value = expr(&tok, tok->next);
    Node *ret = new_unary(ND_RETURN, value);
    *rest = skip(tok, ";");
    return ret;
  }
#endif

#if PHASE_LEVEL >= 12
  if (equal(tok, "if")) {
    Node *branch = new_node(ND_IF);
    tok = skip(tok->next, "(");
    branch->cond = expr(&tok, tok);
    tok = skip(tok, ")");
    branch->then = stmt(&tok, tok);
    if (equal(tok, "else")) {
      branch->els = stmt(&tok, tok->next);
    }
    *rest = tok;
    return branch;
  }

  // "while" is an ND_FOR with only a condition and a body.
  if (equal(tok, "while")) {
    Node *loop = new_node(ND_FOR);
    tok = skip(tok->next, "(");
    loop->cond = expr(&tok, tok);
    tok = skip(tok, ")");
    loop->then = stmt(rest, tok);
    return loop;
  }

  if (equal(tok, "for")) {
    Node *loop = new_node(ND_FOR);
    tok = skip(tok->next, "(");

    // Each of the three clauses may be empty.
    if (!equal(tok, ";")) {
      loop->init = expr(&tok, tok);
    }
    tok = skip(tok, ";");

    if (!equal(tok, ";")) {
      loop->cond = expr(&tok, tok);
    }
    tok = skip(tok, ";");

    if (!equal(tok, ")")) {
      loop->inc = expr(&tok, tok);
    }
    tok = skip(tok, ")");

    loop->then = stmt(rest, tok);
    return loop;
  }
#endif

#if PHASE_LEVEL >= 13
  if (equal(tok, "{")) {
    return compound_stmt(rest, tok->next);
  }
#endif

#if PHASE_LEVEL >= 9
  Node *value = expr(&tok, tok);
  Node *wrapped = new_unary(ND_EXPR_STMT, value);
  *rest = skip(tok, ";");
  return wrapped;
#else
  error_at(tok->loc, "statements are not supported in this phase");
  return NULL;
#endif
}

// param = ident
//
// Parse one parameter name, register it as a local of the current function
// and store the following token in `*rest`. Reports an error if the token is
// not an identifier or if the name was already used by an earlier parameter.
static Obj *param(Token **rest, Token *tok) {
#if PHASE_LEVEL < 15
  error_at(tok->loc, "function parameters are not supported in this phase");
#endif
  if (tok->kind != TK_IDENT) {
    error_at(tok->loc, "expected parameter name");
  }
  // new_lvar() returns the existing Obj for a repeated name. Linking that
  // Obj into the parameter list a second time would make the list point back
  // at itself, so reject duplicates here.
  if (find_var(tok)) {
    error_at(tok->loc, "duplicate parameter name");
  }
  Obj *var = new_lvar(tok);
  *rest = tok->next;
  return var;
}

// function = ident "(" (param ("," param)*)? ")" "{" compound_stmt
//
// Parse one function definition, make it the current function, assign a
// stack offset to each local and compute the frame size. The token after the
// closing "}" is stored in `*rest`.
static Function *function(Token **rest, Token *tok) {
  Function *fn = calloc(1, sizeof *fn);
  current_fn = fn;

  if (tok->kind != TK_IDENT) {
    error_at(tok->loc, "expected function name");
  }
  fn->name = strndup2(tok->loc, tok->len);
  tok = skip(tok->next, "(");

  // Parameters are chained through param_next in declaration order.
  Obj params = {0};
  Obj *last = &params;
  if (!equal(tok, ")")) {
#if PHASE_LEVEL < 15
    error_at(tok->loc, "function parameters are not supported in this phase");
#endif
    for (;;) {
      Obj *p = param(&tok, tok);
      last->param_next = p;
      last = p;
      if (!consume(&tok, tok, ",")) {
        break;
      }
    }
  }
  fn->params = params.param_next;

  tok = skip(tok, ")");
  tok = skip(tok, "{");
  fn->body = compound_stmt(&tok, tok);

  // Give every local an 8-byte slot, then round the frame up to 16 bytes.
  int used = 0;
  for (Obj *var = fn->locals; var; var = var->next) {
    used += 8;
    var->offset = used;
  }
  fn->stack_size = (used + 15) / 16 * 16;

  *rest = tok;
  return fn;
}

// Parse the whole token list.
//
// From phase 15 on the program is a sequence of function definitions and the
// result is the list of those functions (NULL for empty input). In earlier
// phases the input is wrapped in a single implicit function named "main".
Function *parse(Token *tok) {
#if PHASE_LEVEL < 15
  Function *fn = calloc(1, sizeof(Function));
  current_fn = fn;
  fn->name = "main";

  if (PHASE_LEVEL < 9) {
    // A program is a single expression whose value is returned.
    fn->body = new_unary(ND_RETURN, expr(&tok, tok));
    if (tok->kind != TK_EOF) {
      error_at(tok->loc, "extra token");
    }
  } else {
    // A program is a list of statements forming the body of main.
    Node *block = new_node(ND_BLOCK);
    Node body = {0};
    Node *body_cur = &body;
    while (tok->kind != TK_EOF) {
      body_cur = body_cur->next = stmt(&tok, tok);
    }
    block->body = body.next;
    fn->body = block;
  }

  // Give every local an 8-byte slot, then round the frame up to 16 bytes.
  int offset = 0;
  for (Obj *var = fn->locals; var; var = var->next) {
    offset += 8;
    var->offset = offset;
  }
  fn->stack_size = (offset + 15) / 16 * 16;
  return fn;
#else
  // Declared in this branch only, so earlier phases do not carry unused
  // locals.
  Function head = {0};
  Function *cur = &head;

  while (tok->kind != TK_EOF) {
    cur = cur->next = function(&tok, tok);
  }
  return head.next;
#endif
}

codegen.c

c
#include "armcc.h"

#include <stdio.h>

// Number of values pushed by push() and not yet removed by pop().
static int depth;
// Counter used to give each if/loop its own set of label names.
static int labelseq;
// Registers used for arguments, indexed by argument position.
static char *argreg[] = {"x0", "x1", "x2", "x3", "x4", "x5", "x6", "x7"};
// Name of the function being emitted; used in its return label.
static char *current_fn_name;

// Forward declarations for the code generation functions defined below.
static void gen_expr(Node *node);
static void gen_stmt(Node *node);

// Emit code that pushes x0. Each push moves sp by 16 bytes.
static void push(void) {
  printf("    sub sp, sp, #16\n"
         "    str x0, [sp]\n");
  depth += 1;
}

// Emit code that pops the value on top of the stack into register `reg`.
static void pop(char *reg) {
  printf("    ldr %s, [sp]\n"
         "    add sp, sp, #16\n",
         reg);
  depth -= 1;
}

// Emit code that leaves the address of the local variable `node` in x0.
// Any node kind other than ND_LVAR is reported through error(), which
// terminates the process.
static void gen_addr(Node *node) {
  if (node->kind == ND_LVAR) {
    // Locals live below the frame pointer.
    printf("    sub x0, x29, #%d\n", node->var->offset);
    return;
  }
  error("not an lvalue");
}

// Emit code that loads the constant `val` into x0.
//
// A single `mov x0, #imm` only assembles when the value fits one MOVZ/MOVN
// (or bitmask) encoding, so any other constant is built 16 bits at a time
// with MOV followed by MOVK for each non-zero upper chunk.
static void gen_num(long val) {
  if (val >= -0x10000L && val <= 0xffffL) {
    printf("    mov x0, #%ld\n", val);
    return;
  }

  unsigned long long bits = (unsigned long long)val;
  printf("    mov x0, #%llu\n", bits & 0xffff);
  for (int shift = 16; shift < 64; shift += 16) {
    unsigned long long chunk = (bits >> shift) & 0xffff;
    if (chunk != 0) {
      printf("    movk x0, #%llu, lsl #%d\n", chunk, shift);
    }
  }
}

// Emit code that evaluates the expression `node` and leaves its value in x0.
// Unknown node kinds are reported through error().
static void gen_expr(Node *node) {
  switch (node->kind) {
  case ND_NUM:
    gen_num(node->val);
    return;
  case ND_NEG:
    gen_expr(node->lhs);
    printf("    neg x0, x0\n");
    return;
  case ND_LVAR:
    gen_addr(node);
    printf("    ldr x0, [x0]\n");
    return;
  case ND_ASSIGN:
    // Keep the destination address on the stack while the value is computed.
    gen_addr(node->lhs);
    push();
    gen_expr(node->rhs);
    pop("x1");
    printf("    str x0, [x1]\n");
    return;
  case ND_FUNCALL: {
    // Evaluate arguments left to right onto the stack, then pop them into
    // the argument registers in reverse order.
    int nargs = 0;
    for (Node *arg = node->args; arg; arg = arg->next) {
      gen_expr(arg);
      push();
      nargs++;
    }
    if (nargs > 8) {
      error("function call with more than 8 arguments is not supported yet");
    }
    for (int i = nargs - 1; i >= 0; i--) {
      pop(argreg[i]);
    }
    printf("    bl _%s\n", node->funcname);
    return;
  }
  default:
    break;
  }

  // Binary operators: rhs is evaluated first and saved, so that afterwards
  // x0 holds lhs and x1 holds rhs.
  gen_expr(node->rhs);
  push();
  gen_expr(node->lhs);
  pop("x1");

  switch (node->kind) {
  case ND_ADD:
    printf("    add x0, x0, x1\n");
    return;
  case ND_SUB:
    printf("    sub x0, x0, x1\n");
    return;
  case ND_MUL:
    printf("    mul x0, x0, x1\n");
    return;
  case ND_DIV:
    printf("    sdiv x0, x0, x1\n");
    return;
  case ND_EQ:
    printf("    cmp x0, x1\n");
    printf("    cset x0, eq\n");
    return;
  case ND_NE:
    printf("    cmp x0, x1\n");
    printf("    cset x0, ne\n");
    return;
  case ND_LT:
    printf("    cmp x0, x1\n");
    printf("    cset x0, lt\n");
    return;
  case ND_LE:
    printf("    cmp x0, x1\n");
    printf("    cset x0, le\n");
    return;
  default:
    error("invalid expression");
  }
}

// Emit code for the statement `node`. Unknown node kinds are reported
// through error().
static void gen_stmt(Node *node) {
  if (node->kind == ND_RETURN) {
    // The return value is left in x0; jump to the function's epilogue label.
    gen_expr(node->lhs);
    printf("    b .L.return.%s\n", current_fn_name);
    return;
  }

  if (node->kind == ND_EXPR_STMT) {
    gen_expr(node->lhs);
    return;
  }

  if (node->kind == ND_BLOCK) {
    Node *child = node->body;
    while (child) {
      gen_stmt(child);
      child = child->next;
    }
    return;
  }

  if (node->kind == ND_IF) {
    int id = labelseq++;
    gen_expr(node->cond);
    printf("    cmp x0, #0\n");
    printf("    b.eq .L.else.%d\n", id);
    gen_stmt(node->then);
    printf("    b .L.end.%d\n", id);
    printf(".L.else.%d:\n", id);
    if (node->els) {
      gen_stmt(node->els);
    }
    printf(".L.end.%d:\n", id);
    return;
  }

  if (node->kind == ND_FOR) {
    int id = labelseq++;
    if (node->init) {
      gen_expr(node->init);
    }
    printf(".L.begin.%d:\n", id);
    // Without a condition the loop never exits through this test.
    if (node->cond) {
      gen_expr(node->cond);
      printf("    cmp x0, #0\n");
      printf("    b.eq .L.end.%d\n", id);
    }
    gen_stmt(node->then);
    if (node->inc) {
      gen_expr(node->inc);
    }
    printf("    b .L.begin.%d\n", id);
    printf(".L.end.%d:\n", id);
    return;
  }

  error("invalid statement");
}

// Emit code that copies each incoming argument register into the stack slot
// of the corresponding parameter of `fn`. Only as many parameters as there
// are entries in argreg are supported; more are reported through error()
// instead of indexing past the end of the table.
static void assign_param_offsets(Function *fn) {
  int max_regs = (int)(sizeof(argreg) / sizeof(*argreg));
  int i = 0;
  for (Obj *var = fn->params; var; var = var->param_next) {
    if (i >= max_regs) {
      error("function '%s' has more than %d parameters, which is not "
            "supported yet",
            fn->name, max_regs);
    }
    printf("    str %s, [x29, #-%d]\n", argreg[i++], var->offset);
  }
}

// Print assembly for every function in `prog` to stdout. Each function gets
// a global `_name` symbol, a prologue that saves x29/x30 and reserves the
// frame, its body, and an epilogue at `.L.return.<name>`.
void codegen(Function *prog) {
  printf(".text\n");

  Function *fn = prog;
  while (fn) {
    char *name = fn->name;
    current_fn_name = name;

    printf(".globl _%s\n"
           "_%s:\n",
           name, name);

    // Prologue: save the frame pointer and link register, set up x29.
    printf("    stp x29, x30, [sp, #-16]!\n"
           "    mov x29, sp\n");
    if (fn->stack_size != 0) {
      printf("    sub sp, sp, #%d\n", fn->stack_size);
    }

    assign_param_offsets(fn);
    gen_stmt(fn->body);

    // Epilogue: release the frame and return; x0 holds the result.
    printf(".L.return.%s:\n", name);
    printf("    mov sp, x29\n"
           "    ldp x29, x30, [sp], #16\n"
           "    ret\n");

    fn = fn->next;
  }
}

test.sh

sh
#!/bin/sh
set -eu

# try EXPECTED PROGRAM
# Compile PROGRAM with ./armcc, assemble and link it with clang, run it and
# compare its exit status with EXPECTED. Exits the script on a mismatch.
try() {
  expected="$1"
  input="$2"

  ./armcc "$input" > tmp.s
  clang tmp.s -o tmp

  # The exit status is the test result, so a non-zero status must not trip
  # `set -e`; a command on the left of `||` is exempt from it.
  actual=0
  ./tmp || actual="$?"

  if [ "$actual" != "$expected" ]; then
    echo "$input => expected $expected, got $actual"
    exit 1
  fi
  echo "$input => $actual"
}

# Function definitions with parameters: two parameters, one parameter, and a
# function that uses extra locals and a loop. The first argument of `try` is
# the expected exit status of the compiled program.
try 7 'add(x,y){ return x+y; } main(){ return add(3,4); }'
try 21 'triple(x){ return x*3; } main(){ return triple(7); }'
try 55 'sum_to(n){ sum=0; for(i=0; i<=n; i=i+1) sum=sum+i; return sum; } main(){ return sum_to(10); }'

# Reached only if every case above passed (`try` exits on the first failure).
echo OK