Skip to content

Phase 14: Function Calls

Compilerbook step 14: add function calls. Function definitions are not yet part of this phase, so the tests link the generated assembly against small helper functions compiled with Apple clang.

Source

armcc.h

h
#ifndef ARMCC_H
#define ARMCC_H

#include <stdbool.h>
#include <stddef.h>

// Compilerbook step implemented by this build. parse.c and tokenize.c enable
// language features with `#if PHASE_LEVEL ...` checks against this value.
#define PHASE_LEVEL 14

// Token categories produced by tokenize().
typedef enum {
  TK_RESERVED, // punctuator, or keyword (keywords are retagged from TK_IDENT)
  TK_IDENT,    // identifier
  TK_NUM,      // decimal integer literal
  TK_EOF,      // end-of-input marker, always the last token
} TokenKind;

// One lexical token; tokens form a singly linked list.
typedef struct Token Token;
struct Token {
  TokenKind kind;
  Token *next; // following token
  long val;    // numeric value when kind == TK_NUM
  char *loc;   // start of the token's text inside the source buffer
  int len;     // length of the token's text in bytes
};

// A local variable. Function parameters are locals too and sit on both lists.
typedef struct Obj Obj;
struct Obj {
  Obj *next;       // next entry in Function::locals
  Obj *param_next; // next entry in Function::params
  char *name;      // heap-allocated, NUL-terminated identifier
  int offset;      // positive byte distance below x29 of the variable's slot
};

// AST node kinds.
typedef enum {
  ND_ADD,       // lhs + rhs
  ND_SUB,       // lhs - rhs
  ND_MUL,       // lhs * rhs
  ND_DIV,       // lhs / rhs (signed)
  ND_NEG,       // -lhs
  ND_EQ,        // lhs == rhs
  ND_NE,        // lhs != rhs
  ND_LT,        // lhs < rhs (`a > b` is parsed as `b < a`)
  ND_LE,        // lhs <= rhs (`a >= b` is parsed as `b <= a`)
  ND_ASSIGN,    // lhs = rhs
  ND_LVAR,      // local variable reference (var)
  ND_NUM,       // integer constant (val)
  ND_RETURN,    // return lhs
  ND_EXPR_STMT, // expression statement (lhs)
  ND_BLOCK,     // { body... }
  ND_IF,        // if (cond) then else els
  ND_FOR,       // for (init; cond; inc) then; also used for `while`
  ND_FUNCALL,   // funcname(args...)
} NodeKind;

// AST node. Which fields are meaningful depends on `kind`.
typedef struct Node Node;
struct Node {
  NodeKind kind;
  Node *next;     // next statement in a block, or next argument of a call
  Node *lhs;      // left operand; sole operand of ND_NEG/ND_RETURN/ND_EXPR_STMT
  Node *rhs;      // right operand of binary nodes and ND_ASSIGN
  Node *body;     // statement list of ND_BLOCK
  Node *cond;     // condition of ND_IF / ND_FOR (optional for ND_FOR)
  Node *then;     // taken branch of ND_IF, loop body of ND_FOR
  Node *els;      // optional else branch of ND_IF
  Node *init;     // optional initializer expression of ND_FOR
  Node *inc;      // optional increment expression of ND_FOR
  Node *args;     // argument list of ND_FUNCALL, linked through `next`
  Obj *var;       // variable referenced by ND_LVAR
  char *funcname; // callee name of ND_FUNCALL
  long val;       // value of ND_NUM
};

// A function to be emitted; parse() returns a linked list of these.
typedef struct Function Function;
struct Function {
  Function *next;  // next function in the program
  char *name;      // symbol name (codegen emits it with a leading '_')
  Obj *params;     // parameters, linked through Obj::param_next
  Node *body;      // ND_BLOCK (a lone ND_RETURN in phases before 9)
  Obj *locals;     // all locals including params, linked through Obj::next
  int stack_size;  // bytes reserved below x29, a multiple of 16
};

// Start of the source text being tokenized (used by error_at()).
extern char *current_input;
// Token list of the current input.
extern Token *token;

// Print a printf-style message plus newline to stderr, then exit(1).
void error(char *fmt, ...);
// Like error(), but first echo the source line containing `loc` and a caret
// under it.
void error_at(char *loc, char *fmt, ...);
// True if `tok` is a TK_RESERVED token spelled exactly `op`.
bool equal(Token *tok, char *op);
// Require `tok` to be `op` and return the token after it; reports an error
// at `tok` otherwise.
Token *skip(Token *tok, char *op);
// Heap-allocated, NUL-terminated copy of at most `len` bytes of `p`.
char *strndup2(char *p, int len);
// Split `input` into a token list ending in TK_EOF.
Token *tokenize(char *input);
// Build the Function list from a token list.
Function *parse(Token *tok);
// Write AArch64 assembly for `prog` to stdout.
void codegen(Function *prog);

#endif

main.c

c
#include "armcc.h"

#include <errno.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

// Reads the whole file at `path` into a freshly allocated, NUL-terminated
// buffer that always ends in '\n' (one is appended if missing).
//
// Returns NULL when the file cannot be opened, which main() uses to treat the
// argument as inline program text. Any later I/O or allocation failure is
// fatal (reported through error()).
static char *read_file(char *path) {
  FILE *fp = fopen(path, "r");
  if (!fp) {
    return NULL;
  }

  if (fseek(fp, 0, SEEK_END) == -1) {
    error("cannot seek %s: %s", path, strerror(errno));
  }
  long size = ftell(fp);
  if (size == -1) {
    error("cannot tell size of %s: %s", path, strerror(errno));
  }
  rewind(fp);

  // +2: room for a possibly appended '\n' plus the terminating NUL.
  char *buf = calloc((size_t)size + 2, 1);
  if (!buf) {
    error("out of memory reading %s", path);
  }

  // Use the number of bytes actually read. If fewer than `size` bytes arrive,
  // indexing by `size` would inspect (and keep) bytes that were never filled.
  size_t len = fread(buf, 1, (size_t)size, fp);
  if (ferror(fp)) {
    error("cannot read %s: %s", path, strerror(errno));
  }
  fclose(fp);

  if (len == 0 || buf[len - 1] != '\n') {
    buf[len++] = '\n';
  }
  buf[len] = '\0';
  return buf;
}

// Compiler driver. argv[1] is first tried as a source file path. If it cannot
// be opened, it is compiled directly as program text. The generated assembly
// goes to stdout.
int main(int argc, char **argv) {
  if (argc != 2) {
    fprintf(stderr, "usage: %s <source-file-or-program>\n", argv[0]);
    return 1;
  }

  char *contents = read_file(argv[1]);
  char *source = contents != NULL ? contents : argv[1];

  token = tokenize(source);
  codegen(parse(token));

  free(contents);
  return 0;
}

tokenize.c

c
#include "armcc.h"

#include <ctype.h>
#include <stdarg.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

// Start of the buffer handed to tokenize(). error_at() walks back from an
// error location to here to find the start of the offending line.
char *current_input;
// Token list for the current input; main() stores tokenize()'s result here.
Token *token;

// Reports a fatal, printf-style diagnostic on stderr and terminates with
// exit status 1. Never returns.
void error(char *fmt, ...) {
  va_list args;

  va_start(args, fmt);
  vfprintf(stderr, fmt, args);
  fprintf(stderr, "\n");

  exit(1);
}

// Reports a fatal diagnostic anchored at `loc` inside current_input.
//
// The source line containing `loc` is echoed, then a caret is printed under
// the offending column, followed by the printf-style message. Exits with
// status 1; never returns.
void error_at(char *loc, char *fmt, ...) {
  char *begin = loc;
  char *finish = loc;

  // Find the bounds of the line that contains `loc`.
  while (begin > current_input && begin[-1] != '\n') {
    --begin;
  }
  while (*finish != '\0' && *finish != '\n') {
    ++finish;
  }

  fprintf(stderr, "%.*s\n", (int)(finish - begin), begin);
  fprintf(stderr, "%*s", (int)(loc - begin), "");
  fprintf(stderr, "^ ");

  va_list args;
  va_start(args, fmt);
  vfprintf(stderr, fmt, args);
  fprintf(stderr, "\n");
  exit(1);
}

// True when the string at `p` begins with every character of `q`.
static bool startswith(char *p, char *q) {
  size_t prefix_len = strlen(q);
  int cmp = strncmp(p, q, prefix_len);
  return cmp == 0;
}

// True if `c` may start an identifier: an ASCII letter or '_'.
static bool is_ident1(char c) {
  if (c == '_') {
    return true;
  }
  return isalpha((unsigned char)c) != 0;
}

// True if `c` may continue an identifier: a letter, a digit, or '_'.
static bool is_ident2(char c) {
  if (c == '_') {
    return true;
  }
  return isalnum((unsigned char)c) != 0;
}

// True if the identifier token `tok` spells a keyword known in this phase.
// Before phase 11 there are no keywords at all; `return` arrives in phase 11,
// and `if`/`else`/`while`/`for` arrive in phase 12.
static bool is_keyword(Token *tok) {
#if PHASE_LEVEL < 11
  (void)tok;
  return false;
#else
  static char *const keywords[] = {
      "return",
#if PHASE_LEVEL >= 12
      "if",
      "else",
      "while",
      "for",
#endif
  };
  size_t count = sizeof(keywords) / sizeof(keywords[0]);

  for (size_t k = 0; k < count; k++) {
    char *word = keywords[k];
    if ((size_t)tok->len == strlen(word) &&
        strncmp(tok->loc, word, (size_t)tok->len) == 0) {
      return true;
    }
  }
  return false;
#endif
}

// Allocates a zeroed token of `kind` covering the source text [start, end).
static Token *new_token(TokenKind kind, char *start, char *end) {
  Token *t = calloc(1, sizeof(*t));
  t->loc = start;
  t->len = (int)(end - start);
  t->kind = kind;
  return t;
}

// True if `tok` is a reserved token (punctuator or keyword) whose text is
// exactly `op`. Identifiers and numbers never match.
bool equal(Token *tok, char *op) {
  if (tok->kind != TK_RESERVED) {
    return false;
  }
  if ((size_t)tok->len != strlen(op)) {
    return false;
  }
  return strncmp(tok->loc, op, tok->len) == 0;
}

// Consumes the mandatory token `op`: returns the token after `tok` if `tok`
// is `op`, otherwise reports "expected '<op>'" at `tok` and exits.
Token *skip(Token *tok, char *op) {
  if (equal(tok, op)) {
    return tok->next;
  }
  error_at(tok->loc, "expected '%s'", op);
  return tok->next; // not reached: error_at() calls exit()
}

// Returns a heap-allocated, NUL-terminated copy of at most `len` bytes of `p`.
// The copy stops early at a NUL inside `p`, as strncpy would. A non-positive
// `len` yields an empty string. Allocation failure is fatal (exit status 1).
// The caller owns the result and may free() it.
char *strndup2(char *p, int len) {
  // Previously a negative len wrapped to a huge size, calloc returned NULL,
  // and the copy then wrote through that NULL pointer.
  size_t n = len > 0 ? (size_t)len : 0;

  // memchr stops at the first match, so it never reads past a terminator
  // that lies within the first n bytes.
  char *nul = memchr(p, '\0', n);
  if (nul) {
    n = (size_t)(nul - p);
  }

  char *s = calloc(n + 1, 1);
  if (!s) {
    fprintf(stderr, "out of memory\n");
    exit(1);
  }
  memcpy(s, p, n); // s[n] is already '\0' thanks to calloc
  return s;
}

// Splits `input` into a linked list of tokens terminated by a TK_EOF token,
// and records `input` in current_input for error reporting.
// Reports "invalid token" (and exits) on any character it cannot classify.
Token *tokenize(char *input) {
  static char *const two_char_puncts[] = {"==", "!=", "<=", ">="};
  current_input = input;

  Token sentinel = {0};
  Token *tail = &sentinel;
  char *p = input;

  while (*p != '\0') {
    // Whitespace only separates tokens.
    if (isspace((unsigned char)*p)) {
      p++;
      continue;
    }

    // Two-character punctuators are tried before single characters so that
    // e.g. "<=" is not split into "<" and "=".
    bool is_two_char = false;
    for (size_t i = 0; i < sizeof(two_char_puncts) / sizeof(*two_char_puncts);
         i++) {
      if (startswith(p, two_char_puncts[i])) {
        is_two_char = true;
        break;
      }
    }
    if (is_two_char) {
      tail = tail->next = new_token(TK_RESERVED, p, p + 2);
      p += 2;
      continue;
    }

    if (strchr("+-*/(){}<>=;!,", *p) != NULL) {
      tail = tail->next = new_token(TK_RESERVED, p, p + 1);
      p++;
      continue;
    }

    // Decimal integer literal; strtol advances p past the digits.
    if (isdigit((unsigned char)*p)) {
      char *start = p;
      long value = strtol(start, &p, 10);
      tail = tail->next = new_token(TK_NUM, start, p);
      tail->val = value;
      continue;
    }

    // Identifier, retagged as reserved when it spells a keyword.
    if (is_ident1(*p)) {
      char *start = p;
      do {
        p++;
      } while (is_ident2(*p));
      tail = tail->next = new_token(TK_IDENT, start, p);
      if (is_keyword(tail)) {
        tail->kind = TK_RESERVED;
      }
      continue;
    }

    error_at(p, "invalid token");
  }

  tail->next = new_token(TK_EOF, p, p);
  return sentinel.next;
}

parse.c

c
#include "armcc.h"

#include <stdlib.h>
#include <string.h>

// Function whose body is currently being parsed. new_lvar() and find_var()
// read and extend its `locals` list.
static Function *current_fn;

// Forward declarations for the mutually recursive descent functions.
static Node *expr(Token **rest, Token *tok);
static Node *stmt(Token **rest, Token *tok);

// Allocates a zero-initialized AST node of the given kind.
static Node *new_node(NodeKind kind) {
  Node *n = calloc(1, sizeof(*n));
  n->kind = kind;
  return n;
}

// Builds a two-operand node `lhs <kind> rhs`.
static Node *new_binary(NodeKind kind, Node *lhs, Node *rhs) {
  Node *bin = new_node(kind);
  bin->rhs = rhs;
  bin->lhs = lhs;
  return bin;
}

// Builds a one-operand node; the operand is stored in `lhs`.
// (The parameter is named `operand` so it no longer shadows expr().)
static Node *new_unary(NodeKind kind, Node *operand) {
  Node *un = new_node(kind);
  un->lhs = operand;
  return un;
}

// Builds an integer-constant node holding `val`.
static Node *new_num(long val) {
  Node *lit = new_node(ND_NUM);
  lit->val = val;
  return lit;
}

// Optional-token helper: if `tok` is `str`, sets *rest to the next token and
// returns true; otherwise sets *rest to `tok` unchanged and returns false.
static bool consume(Token **rest, Token *tok, char *str) {
  bool matched = equal(tok, str);
  *rest = matched ? tok->next : tok;
  return matched;
}

// Looks up a local of the current function whose name equals the text of
// `tok`. Returns NULL if there is none.
static Obj *find_var(Token *tok) {
  Obj *var = current_fn->locals;
  while (var != NULL) {
    bool same_len = strlen(var->name) == (size_t)tok->len;
    if (same_len && strncmp(tok->loc, var->name, tok->len) == 0) {
      return var;
    }
    var = var->next;
  }
  return NULL;
}

// Returns the local named by `tok`, creating it (prepended to
// current_fn->locals) on first use. Earlier phases reject locals entirely
// (< 9) or multi-letter names (< 10).
static Obj *new_lvar(Token *tok) {
#if PHASE_LEVEL < 9
  error_at(tok->loc, "local variables are not supported in this phase");
#endif
#if PHASE_LEVEL < 10
  if (tok->len != 1) {
    error_at(tok->loc,
             "multi-letter local variables are not supported in this phase");
  }
#endif
  Obj *existing = find_var(tok);
  if (existing != NULL) {
    return existing;
  }

  Obj *fresh = calloc(1, sizeof(*fresh));
  fresh->name = strndup2(tok->loc, tok->len);
  fresh->next = current_fn->locals;
  current_fn->locals = fresh;
  return fresh;
}

#if PHASE_LEVEL >= 14
// funcall = ident "(" (expr ("," expr)*)? ")"
// `tok` points at the callee identifier, which is known to be followed by "(".
static Node *funcall(Token **rest, Token *tok) {
  Node *call = new_node(ND_FUNCALL);
  call->funcname = strndup2(tok->loc, tok->len);

  Token *t = tok->next->next; // skip the name and "("
  Node args = {0};
  Node *last = &args;
  if (!equal(t, ")")) {
    for (;;) {
      last = last->next = expr(&t, t);
      if (!consume(&t, t, ",")) {
        break;
      }
    }
  }
  call->args = args.next;
  *rest = skip(t, ")");
  return call;
}
#endif

// primary = num | ident | funcall | "(" expr ")"
static Node *primary(Token **rest, Token *tok) {
  if (tok->kind == TK_NUM) {
    *rest = tok->next;
    return new_num(tok->val);
  }

  if (tok->kind == TK_IDENT) {
    if (equal(tok->next, "(")) {
#if PHASE_LEVEL >= 14
      return funcall(rest, tok);
#else
      error_at(tok->loc, "function calls are not supported in this phase");
#endif
    }
    Node *var_ref = new_node(ND_LVAR);
    var_ref->var = new_lvar(tok);
    *rest = tok->next;
    return var_ref;
  }

  if (consume(&tok, tok, "(")) {
    Node *inner = expr(&tok, tok);
    *rest = skip(tok, ")");
    return inner;
  }

  error_at(tok->loc, "expected expression");
  return NULL;
}

// unary = ("+" | "-") unary | primary      (prefix operators from phase 6)
// Unary plus is dropped; unary minus becomes ND_NEG.
static Node *unary(Token **rest, Token *tok) {
#if PHASE_LEVEL >= 6
  if (equal(tok, "+")) {
    return unary(rest, tok->next);
  }
  if (equal(tok, "-")) {
    return new_unary(ND_NEG, unary(rest, tok->next));
  }
#endif
  return primary(rest, tok);
}

// mul = unary ("*" unary | "/" unary)*      (operators from phase 5)
// Left-associative: each new operator wraps everything parsed so far.
static Node *mul(Token **rest, Token *tok) {
  Node *node = unary(&tok, tok);
#if PHASE_LEVEL >= 5
  for (;;) {
    if (equal(tok, "*")) {
      node = new_binary(ND_MUL, node, unary(&tok, tok->next));
    } else if (equal(tok, "/")) {
      node = new_binary(ND_DIV, node, unary(&tok, tok->next));
    } else {
      break;
    }
  }
#endif
  *rest = tok;
  return node;
}

// add = mul ("+" mul | "-" mul)*      (operators from phase 2)
// Left-associative.
static Node *add(Token **rest, Token *tok) {
  Node *node = mul(&tok, tok);
#if PHASE_LEVEL >= 2
  for (;;) {
    if (equal(tok, "+")) {
      node = new_binary(ND_ADD, node, mul(&tok, tok->next));
    } else if (equal(tok, "-")) {
      node = new_binary(ND_SUB, node, mul(&tok, tok->next));
    } else {
      break;
    }
  }
#endif
  *rest = tok;
  return node;
}

// relational = add ("<" add | "<=" add | ">" add | ">=" add)*   (phase 7+)
// `a > b` and `a >= b` are normalized to `b < a` and `b <= a` so codegen only
// needs ND_LT and ND_LE.
static Node *relational(Token **rest, Token *tok) {
  Node *node = add(&tok, tok);
#if PHASE_LEVEL >= 7
  for (;;) {
    if (equal(tok, "<")) {
      node = new_binary(ND_LT, node, add(&tok, tok->next));
    } else if (equal(tok, "<=")) {
      node = new_binary(ND_LE, node, add(&tok, tok->next));
    } else if (equal(tok, ">")) {
      node = new_binary(ND_LT, add(&tok, tok->next), node);
    } else if (equal(tok, ">=")) {
      node = new_binary(ND_LE, add(&tok, tok->next), node);
    } else {
      break;
    }
  }
#endif
  *rest = tok;
  return node;
}

// equality = relational ("==" relational | "!=" relational)*   (phase 7+)
static Node *equality(Token **rest, Token *tok) {
  Node *node = relational(&tok, tok);
#if PHASE_LEVEL >= 7
  for (;;) {
    if (equal(tok, "==")) {
      node = new_binary(ND_EQ, node, relational(&tok, tok->next));
    } else if (equal(tok, "!=")) {
      node = new_binary(ND_NE, node, relational(&tok, tok->next));
    } else {
      break;
    }
  }
#endif
  *rest = tok;
  return node;
}

// assign = equality ("=" assign)?      (assignment from phase 9)
// Right-associative: `a = b = c` parses as `a = (b = c)`.
static Node *assign(Token **rest, Token *tok) {
  Node *target = equality(&tok, tok);
#if PHASE_LEVEL >= 9
  if (equal(tok, "=")) {
    Node *value = assign(&tok, tok->next);
    target = new_binary(ND_ASSIGN, target, value);
  }
#endif
  *rest = tok;
  return target;
}

// expr = assign
static Node *expr(Token **rest, Token *tok) {
  Node *result = assign(rest, tok);
  return result;
}

// compound-stmt = stmt* "}"
// Called with `tok` just past the opening "{". Returns an ND_BLOCK whose
// `body` links the statements through `next`, and sets *rest past the "}".
static Node *compound_stmt(Token **rest, Token *tok) {
  Node *node = new_node(ND_BLOCK);
  Node head = {0};
  Node *cur = &head;

  // Also stop at EOF, so an unterminated block reports "expected '}'" rather
  // than whatever error stmt() happens to raise on the EOF token.
  while (!equal(tok, "}") && tok->kind != TK_EOF) {
    cur = cur->next = stmt(&tok, tok);
  }

  node->body = head.next;
  *rest = skip(tok, "}");
  return node;
}

// Parses one statement and sets *rest to the token after it.
//
//   stmt = "return" expr ";"                                  (phase 11+)
//        | "if" "(" expr ")" stmt ("else" stmt)?              (phase 12+)
//        | "while" "(" expr ")" stmt                          (phase 12+)
//        | "for" "(" expr? ";" expr? ";" expr? ")" stmt       (phase 12+)
//        | "{" compound-stmt                                  (phase 13+)
//        | expr ";"                                           (phase 9+)
static Node *stmt(Token **rest, Token *tok) {
#if PHASE_LEVEL >= 11
  if (equal(tok, "return")) {
    Node *node = new_unary(ND_RETURN, expr(&tok, tok->next));
    *rest = skip(tok, ";");
    return node;
  }
#endif

#if PHASE_LEVEL >= 12
  if (equal(tok, "if")) {
    Node *node = new_node(ND_IF);
    tok = skip(tok->next, "(");
    node->cond = expr(&tok, tok);
    tok = skip(tok, ")");
    node->then = stmt(&tok, tok);
    // An inner `if` parsed by the recursive call above claims its `else`
    // first, so `else` binds to the nearest `if`.
    if (equal(tok, "else")) {
      node->els = stmt(&tok, tok->next);
    }
    *rest = tok;
    return node;
  }

  // `while (c) s` is represented as an ND_FOR with only a condition.
  if (equal(tok, "while")) {
    Node *node = new_node(ND_FOR);
    tok = skip(tok->next, "(");
    node->cond = expr(&tok, tok);
    tok = skip(tok, ")");
    node->then = stmt(rest, tok);
    return node;
  }

  // Each of the three `for` clauses is optional; an absent clause leaves the
  // corresponding field NULL.
  if (equal(tok, "for")) {
    Node *node = new_node(ND_FOR);
    tok = skip(tok->next, "(");
    if (!equal(tok, ";")) {
      node->init = expr(&tok, tok);
    }
    tok = skip(tok, ";");
    if (!equal(tok, ";")) {
      node->cond = expr(&tok, tok);
    }
    tok = skip(tok, ";");
    if (!equal(tok, ")")) {
      node->inc = expr(&tok, tok);
    }
    tok = skip(tok, ")");
    node->then = stmt(rest, tok);
    return node;
  }
#endif

#if PHASE_LEVEL >= 13
  if (equal(tok, "{")) {
    return compound_stmt(rest, tok->next);
  }
#endif

  // Fallback: an expression statement (phase 9+).
#if PHASE_LEVEL >= 9
  Node *node = new_unary(ND_EXPR_STMT, expr(&tok, tok));
  *rest = skip(tok, ";");
  return node;
#else
  error_at(tok->loc, "statements are not supported in this phase");
  return NULL;
#endif
}

// param = ident
// Declares a function parameter as a local of current_fn and returns it.
// Parameters arrive in phase 15; earlier phases reject them outright.
static Obj *param(Token **rest, Token *tok) {
#if PHASE_LEVEL < 15
  error_at(tok->loc, "function parameters are not supported in this phase");
#endif
  if (tok->kind != TK_IDENT) {
    error_at(tok->loc, "expected parameter name");
  }
  // new_lvar() returns an existing local of the same name. For `f(a, a)` that
  // would hand back an Obj already on the param_next chain, linking it to
  // itself and making every walk of fn->params loop forever.
  if (find_var(tok)) {
    error_at(tok->loc, "duplicate parameter name");
  }
  Obj *var = new_lvar(tok);
  *rest = tok->next;
  return var;
}

// function = ident "(" (param ("," param)*)? ")" "{" compound-stmt
//
// Parses one function definition, then lays out its frame. Each local
// (parameters included) gets an 8-byte slot at increasing distance below
// x29, and the frame size is rounded up to a multiple of 16.
static Function *function(Token **rest, Token *tok) {
  Function *fn = calloc(1, sizeof(*fn));
  current_fn = fn; // new_lvar() appends to fn->locals directly through this

  if (tok->kind != TK_IDENT) {
    error_at(tok->loc, "expected function name");
  }
  fn->name = strndup2(tok->loc, tok->len);
  tok = skip(tok->next, "(");

  Obj params = {0};
  Obj *tail = &params;
  if (!equal(tok, ")")) {
#if PHASE_LEVEL < 15
    error_at(tok->loc, "function parameters are not supported in this phase");
#endif
    for (;;) {
      tail = tail->param_next = param(&tok, tok);
      if (!consume(&tok, tok, ",")) {
        break;
      }
    }
  }
  fn->params = params.param_next;

  tok = skip(skip(tok, ")"), "{");
  fn->body = compound_stmt(&tok, tok);

  int frame = 0;
  for (Obj *v = fn->locals; v != NULL; v = v->next) {
    frame += 8;
    v->offset = frame;
  }
  fn->stack_size = (frame + 15) & ~15;

  *rest = tok;
  return fn;
}

// Parses the whole token stream into a list of Functions.
//
// Before phase 15 there are no function definitions. The entire input
// becomes the body of a synthesized `main`: a single returned expression
// before phase 9, and a sequence of statements from phase 9 on. From phase 15
// the input is a sequence of function definitions.
Function *parse(Token *tok) {
#if PHASE_LEVEL < 15
  Function *fn = calloc(1, sizeof(Function));
  current_fn = fn;
  fn->name = "main";

  // Selected at preprocessing time, like every other phase gate in this file.
#if PHASE_LEVEL < 9
  fn->body = new_unary(ND_RETURN, expr(&tok, tok));
  if (tok->kind != TK_EOF) {
    error_at(tok->loc, "extra token");
  }
#else
  Node *block = new_node(ND_BLOCK);
  Node body = {0};
  Node *body_cur = &body;
  while (tok->kind != TK_EOF) {
    body_cur = body_cur->next = stmt(&tok, tok);
  }
  block->body = body.next;
  fn->body = block;
#endif

  // One 8-byte slot per local below x29; the frame is rounded up to 16 bytes.
  int offset = 0;
  for (Obj *var = fn->locals; var; var = var->next) {
    offset += 8;
    var->offset = offset;
  }
  fn->stack_size = (offset + 15) / 16 * 16;
  return fn;
#else
  // The list head lives only in this configuration, so earlier-phase builds
  // no longer carry unused locals (-Wunused-variable).
  Function head = {0};
  Function *cur = &head;
  while (tok->kind != TK_EOF) {
    cur = cur->next = function(&tok, tok);
  }
  return head.next;
#endif
}

codegen.c

c
#include "armcc.h"

#include <stdio.h>

// Number of values push() has placed on the runtime stack and pop() has not
// yet removed.
static int depth;
// Counter that makes the .L.else/.L.end/.L.begin label suffixes unique.
static int labelseq;
// Registers that receive the first eight call arguments, in order.
static char *argreg[] = {"x0", "x1", "x2", "x3", "x4", "x5", "x6", "x7"};
// Name of the function being emitted; forms its .L.return.<name> label.
static char *current_fn_name;

static void gen_expr(Node *node);
static void gen_stmt(Node *node);

// Spills x0 onto the stack. Each value occupies a full 16-byte slot even
// though only 8 bytes are stored, so sp moves in multiples of 16.
static void push(void) {
  ++depth;
  printf("    sub sp, sp, #16\n");
  printf("    str x0, [sp]\n");
}

// Reloads the most recently pushed value into `reg` and releases its
// 16-byte slot; the inverse of push().
static void pop(char *reg) {
  --depth;
  printf("    ldr %s, [sp]\n", reg);
  printf("    add sp, sp, #16\n");
}

// Puts the address of an lvalue in x0. Only local variables are lvalues;
// their slot sits `offset` bytes below the frame pointer x29.
static void gen_addr(Node *node) {
  if (node->kind == ND_LVAR) {
    printf("    sub x0, x29, #%d\n", node->var->offset);
    return;
  }
  error("not an lvalue");
}

// Loads the integer constant `val` into x0.
//
// `mov x0, #imm` assembles for every value in 0..65535, but many larger
// constants (e.g. 65537) have no single-instruction encoding and are rejected
// by the assembler. Those are built 16 bits at a time: MOVZ writes the low
// chunk and clears the rest, and MOVK patches in each non-zero higher chunk.
static void load_imm(long val) {
  if (val >= 0 && val <= 0xffff) {
    printf("    mov x0, #%ld\n", val);
    return;
  }

  unsigned long long bits = (unsigned long long)val;
  printf("    movz x0, #%llu\n", bits & 0xffff);
  for (int shift = 16; shift < 64; shift += 16) {
    unsigned long long chunk = (bits >> shift) & 0xffff;
    if (chunk != 0) {
      printf("    movk x0, #%llu, lsl #%d\n", chunk, shift);
    }
  }
}

// Emits code that evaluates `node` and leaves its value in x0.
//
// Binary operators evaluate rhs first and park it on the stack, then evaluate
// lhs. The operation then sees lhs in x0 and rhs in x1.
static void gen_expr(Node *node) {
  switch (node->kind) {
  case ND_NUM:
    load_imm(node->val);
    return;
  case ND_NEG:
    gen_expr(node->lhs);
    printf("    neg x0, x0\n");
    return;
  case ND_LVAR:
    gen_addr(node);
    printf("    ldr x0, [x0]\n");
    return;
  case ND_ASSIGN:
    gen_addr(node->lhs);
    push();
    gen_expr(node->rhs);
    pop("x1");
    printf("    str x0, [x1]\n");
    return;
  case ND_FUNCALL: {
    // Reject unsupported arities before emitting any code for the call.
    int nargs = 0;
    for (Node *arg = node->args; arg; arg = arg->next) {
      nargs++;
    }
    if (nargs > (int)(sizeof(argreg) / sizeof(*argreg))) {
      error("function call with more than 8 arguments is not supported yet");
    }

    // Evaluate arguments left to right, spilling each one so that evaluating
    // a later argument (possibly another call) cannot clobber an earlier one.
    for (Node *arg = node->args; arg; arg = arg->next) {
      gen_expr(arg);
      push();
    }
    // The last argument is on top of the stack, so fill registers in reverse.
    for (int i = nargs - 1; i >= 0; i--) {
      pop(argreg[i]);
    }
    // NOTE(review): the result is taken from all of x0. For callees that
    // return a 32-bit `int` (like the test helpers), the upper half of x0 is
    // presumably not sign-extended for negative results. Revisit once types
    // exist.
    printf("    bl _%s\n", node->funcname);
    return;
  }
  default:
    break;
  }

  gen_expr(node->rhs);
  push();
  gen_expr(node->lhs);
  pop("x1");

  switch (node->kind) {
  case ND_ADD:
    printf("    add x0, x0, x1\n");
    return;
  case ND_SUB:
    printf("    sub x0, x0, x1\n");
    return;
  case ND_MUL:
    printf("    mul x0, x0, x1\n");
    return;
  case ND_DIV:
    printf("    sdiv x0, x0, x1\n");
    return;
  case ND_EQ:
    printf("    cmp x0, x1\n");
    printf("    cset x0, eq\n");
    return;
  case ND_NE:
    printf("    cmp x0, x1\n");
    printf("    cset x0, ne\n");
    return;
  case ND_LT:
    printf("    cmp x0, x1\n");
    printf("    cset x0, lt\n");
    return;
  case ND_LE:
    printf("    cmp x0, x1\n");
    printf("    cset x0, le\n");
    return;
  default:
    error("invalid expression");
  }
}

// Emits an if/else. Control falls into `then` when cond is non-zero and
// otherwise jumps to the else label; the else part may be empty. The label id
// is taken before the sub-statements are generated, so nested statements get
// later ids.
static void gen_if_stmt(Node *node) {
  int id = labelseq++;
  gen_expr(node->cond);
  printf("    cmp x0, #0\n");
  printf("    b.eq .L.else.%d\n", id);
  gen_stmt(node->then);
  printf("    b .L.end.%d\n", id);
  printf(".L.else.%d:\n", id);
  if (node->els) {
    gen_stmt(node->els);
  }
  printf(".L.end.%d:\n", id);
}

// Emits a for/while loop. Any of init, cond and inc may be absent; without a
// cond there is no exit test.
static void gen_for_stmt(Node *node) {
  int id = labelseq++;
  if (node->init) {
    gen_expr(node->init);
  }
  printf(".L.begin.%d:\n", id);
  if (node->cond) {
    gen_expr(node->cond);
    printf("    cmp x0, #0\n");
    printf("    b.eq .L.end.%d\n", id);
  }
  gen_stmt(node->then);
  if (node->inc) {
    gen_expr(node->inc);
  }
  printf("    b .L.begin.%d\n", id);
  printf(".L.end.%d:\n", id);
}

// Emits code for one statement. `return` leaves its value in x0 and jumps to
// the current function's shared epilogue label.
static void gen_stmt(Node *node) {
  switch (node->kind) {
  case ND_RETURN:
    gen_expr(node->lhs);
    printf("    b .L.return.%s\n", current_fn_name);
    return;
  case ND_EXPR_STMT:
    gen_expr(node->lhs);
    return;
  case ND_BLOCK: {
    Node *cur = node->body;
    while (cur) {
      gen_stmt(cur);
      cur = cur->next;
    }
    return;
  }
  case ND_IF:
    gen_if_stmt(node);
    return;
  case ND_FOR:
    gen_for_stmt(node);
    return;
  default:
    error("invalid statement");
  }
}

// Stores the incoming argument registers into the parameters' stack slots.
// The offsets themselves were already assigned by the parser.
static void assign_param_offsets(Function *fn) {
  int i = 0;
  for (Obj *var = fn->params; var; var = var->param_next) {
    // Only x0-x7 carry arguments; without this check argreg[] is read past
    // its end.
    if (i >= (int)(sizeof(argreg) / sizeof(*argreg))) {
      error("function with more than 8 parameters is not supported yet");
    }
    // Parameters are declared first, so they receive the largest offsets.
    // `str reg, [x29, #-off]` only reaches -256, so form the address the same
    // way gen_addr() does. x9 is a scratch register, not an argument register.
    printf("    sub x9, x29, #%d\n", var->offset);
    printf("    str %s, [x9]\n", argreg[i++]);
  }
}

// Writes AArch64 assembly for every function in `prog` to stdout.
//
// Each function gets an x29-based frame: the prologue saves x29/x30 and
// reserves stack_size bytes for locals. A shared .L.return.<name> epilogue is
// the target of every `return`. Symbols are emitted with a leading '_'.
void codegen(Function *prog) {
  printf(".text\n");

  for (Function *fn = prog; fn; fn = fn->next) {
    current_fn_name = fn->name;
    printf(".globl _%s\n", fn->name);
    printf("_%s:\n", fn->name);

    printf("    stp x29, x30, [sp, #-16]!\n");
    printf("    mov x29, sp\n");
    if (fn->stack_size) {
      printf("    sub sp, sp, #%d\n", fn->stack_size);
    }

    assign_param_offsets(fn);
    gen_stmt(fn->body);

    // Every push() emitted while generating the body must have been matched
    // by a pop(). `depth` was tracked but never checked before, so an
    // unbalanced code path would silently leave stale stack slots.
    if (depth != 0) {
      error("internal error: unbalanced stack depth %d in %s", depth,
            fn->name);
    }

    printf(".L.return.%s:\n", fn->name);
    printf("    mov sp, x29\n");
    printf("    ldp x29, x30, [sp], #16\n");
    printf("    ret\n");
  }
}

test.sh

sh
#!/bin/sh
# Phase-14 regression tests for armcc. Each snippet is compiled to assembly,
# linked with clang-compiled helper functions, and run; its exit status is
# compared with the expected value.
# Abort on any unexpected command failure or unset variable.
set -eu

# try EXPECTED PROGRAM
# Compiles PROGRAM, links it with the helpers below, runs it, and fails the
# whole script unless the exit status equals EXPECTED.
try() {
  expected="$1"
  input="$2"

  # Function definitions are not supported yet, so the callees come from C.
  cat > helper.c <<EOF
int ret3(void) { return 3; }
int add2(int x, int y) { return x + y; }
EOF

  ./armcc "$input" > tmp.s
  clang tmp.s helper.c -o tmp

  # Run inside `if` so a non-zero exit status does not trip `set -e`;
  # in the else branch $? is still the status of ./tmp.
  if ./tmp; then
    actual=0
  else
    actual="$?"
  fi

  if [ "$actual" != "$expected" ]; then
    echo "$input => expected $expected, got $actual"
    exit 1
  fi
  echo "$input => $actual"
}

# reject PROGRAM
# Fails the whole script if armcc accepts PROGRAM; diagnostics go to tmp.err.
reject() {
  input="$1"
  if ! ./armcc "$input" > tmp.s 2> tmp.err; then
    echo "$input => rejected"
    return 0
  fi
  echo "$input => expected rejection"
  exit 1
}

# Phase-14 cases: zero- and two-argument calls, a call used as an operand,
# nested calls (arguments must survive evaluation of later arguments), and
# locals passed as arguments.
try 3 'return ret3();'
try 7 'return add2(3,4);'
try 8 'return ret3()+5;'
try 10 'return add2(ret3(), add2(3,4));'
try 5 'a=2; return add2(a, 3);'

# Function definitions arrive in a later phase; malformed calls must fail.
reject 'add(x,y){ return x+y; }'
reject 'return add2(1,2;'

echo OK