Newer
Older
zweic / sources / zweic / Code.scala
/*  zweic -- a compiler for zwei
 *
 *  Stephane Micheloud & LAMP
 *
 *  $Id$
 */

package zweic;

import java.io.PrintWriter;

/**
 * Code buffer for the zwei compiler: tracks register allocation and the
 * current stack frame size, accumulates emitted instructions, resolves
 * branch labels and prints the resulting program.
 */
class Code {
  import RISC._;
  import scala.collection.mutable.ArrayBuffer;

  //########################################################################
  // Constants

  /** Index of first general-use register. */
  private val RC_MIN =  1;
  /** Index of last general-use register. */
  private val RC_MAX = 29;

  //########################################################################
  // Instance variables

  /** Free registers: element r is true iff register R[r] may be allocated. */
  private val freeRegisters = new Array[Boolean](32);

  /**
   * Already emitted instructions, in program order: instruction i lives at
   * address WORD_SIZE * i. An indexed buffer keeps pc(), label patching and
   * printing constant-time per instruction (a List made them linear).
   */
  private val code = new ArrayBuffer[Instruction]();

  /** Stack frame size */
  private var frameSize = 0;

  //########################################################################
  // Initialization

  // Only R[RC_MIN]..R[RC_MAX] start out free; every other register stays
  // marked non-free, so neither getRegister overload will hand it out.
  Iterator.range(RC_MIN, RC_MAX+1) foreach { r => freeRegisters(r) = true };

  //########################################################################
  // Methods - register allocation

  /**
   * Allocate the lowest-numbered free general-use register.
   * @throws Error if every general-use register is in use
   */
  def getRegister(): Int =
    Iterator.range(RC_MIN, RC_MAX+1) find { r => freeRegisters(r) } match {
      case Some(r) => getRegister(r)
      case None    => throw new Error("no more free registers")
    }

  /**
   * Allocate the given register, which has to be free (otherwise an
   * exception is thrown).
   * @param r register number
   * @return r
   * @throws Error if r is not a register number or the register is not free
   */
  def getRegister(r: Int): Int = {
    // Bounds check first so a bad number yields a clear error rather than
    // an ArrayIndexOutOfBoundsException.
    if (r < 0 || r >= freeRegisters.length)
      throw new Error("attempt to allocate invalid register R" + r);
    if (! freeRegisters(r))
      throw new Error("attempt to allocate non-free register R" + r);
    freeRegisters(r) = false;
    r
  }

  /**
   * Free a general-use register.
   * @throws Error if r is not a general-use register or is already free
   */
  def freeRegister(r: Int): Unit = {
    // Range check first: indexing freeRegisters with r outside 0..31 would
    // otherwise throw ArrayIndexOutOfBoundsException.
    if (r < RC_MIN || r > RC_MAX)
      throw new Error("attempt to free special register R" + r);
    if (freeRegisters(r))
      throw new Error("attempt to free already-free register R" + r);
    freeRegisters(r) = true
  }

  /**
   * Return the set of used registers as an array. If element at index i is
   * true, register R[i] is a general-use register that is currently
   * allocated; registers outside RC_MIN..RC_MAX are always reported unused.
   */
  def usedRegisters(): Array[Boolean] = {
    val used = new Array[Boolean](freeRegisters.length);
    for (i <- Iterator.range(RC_MIN, RC_MAX+1))
      used(i) = !freeRegisters(i);
    used
  }

  //########################################################################
  // Methods - address handling

  /** Return address of next instruction. */
  def pc(): Int = WORD_SIZE * code.length;

  //########################################################################
  // Methods - frame handling

  /** Increment frame size by given value. */
  def incFrameSize(bytes: Int): Unit =
    frameSize = frameSize + bytes;

  /** Decrement frame size by given value; it must not become negative. */
  def decFrameSize(bytes: Int): Unit = {
    frameSize = frameSize - bytes;
    assert(frameSize >= 0);
  }

  /** Return current frame size. */
  def getFrameSize(): Int = frameSize;

  //########################################################################
  // Methods - code emitting
  //
  // Integer.MIN_VALUE marks an absent operand: Instruction.toString skips
  // such operands when printing.

  /** Emit an instruction without operands. */
  def emit(opcode: Int): Unit =
    emit(opcode, Integer.MIN_VALUE);

  /** Emit an instruction whose only operand is c. */
  def emit(opcode: Int, c: Int): Unit =
    emit(opcode, Integer.MIN_VALUE, c);

  /** Emit an instruction with operand c and comment s. */
  def emit(opcode: Int, c: Int, s: String): Unit =
    emit(opcode, Integer.MIN_VALUE, Integer.MIN_VALUE, c, s);

  /** Emit an instruction whose only operand is the word offset to label l. */
  def emit(opcode: Int, l: Label): Unit =
    emit(opcode, Integer.MIN_VALUE, l);

  /** Emit an instruction with operands a and c. */
  def emit(opcode: Int, a: Int, c: Int): Unit =
    emit(opcode, a, Integer.MIN_VALUE, c);

  /**
   * Emit an instruction with operand a and, as operand c, the offset in
   * words from this instruction to label l. If l is not anchored yet, c is
   * left absent and filled in when l gets anchored.
   */
  def emit(opcode: Int, a: Int, l: Label): Unit =
    if (l.isAnchored())
      emit(opcode, a, (l.getAnchor() - pc()) / WORD_SIZE);
    else {
      l.recordInstructionToPatch(pc());
      emit(opcode, a, Integer.MIN_VALUE, Integer.MIN_VALUE);
    };

  /** Symbolic name of system call number c, or c itself as a string. */
  private def syscall2string(c: Int): String = c match {
    case SYS_IO_RD_CHR => "SYS_IO_RD_CHR"
    case SYS_IO_RD_INT => "SYS_IO_RD_INT"
    case SYS_IO_WR_CHR => "SYS_IO_WR_CHR"
    case SYS_IO_WR_INT => "SYS_IO_WR_INT"
    case SYS_GC_INIT   => "SYS_GC_INIT"
    case SYS_GC_ALLOC  => "SYS_GC_ALLOC"
    case SYS_GET_TOTAL_MEM_SIZE => "SYS_GET_TOTAL_MEM_SIZE"
    case SYS_EXIT      => "SYS_EXIT"
    case _             => Integer.toString(c)
  };

  /**
   * Emit an instruction with operands a, b and c. SYSCALL instructions get
   * the name of their system call (operand c) attached as comment.
   */
  def emit(opcode: Int, a: Int, b: Int, c: Int): Unit = {
    val s = if (opcode == SYSCALL) syscall2string(c) else null;
    emit(opcode, a, b, c, s)
  }

  /** Append an instruction with operands a, b, c and comment s (may be null). */
  def emit(opcode: Int, a: Int, b: Int, c: Int, s: String): Unit =
    code += Instruction(opcode, a, b, c, s);

  //########################################################################
  // Methods - program printing.

  /** Print program, one instruction per line prefixed by its address. */
  def write(out: PrintWriter): Unit = {
    for ((instr, i) <- code.zipWithIndex) {
      // Zero-pad the address to at least four digits.
      var label = Integer.toString(WORD_SIZE * i);
      while (label.length() < 4) label = "0" + label;
      out.print("/* " + label + " */ ");
      out.print(instr);
      out.println();
    }
    out.flush()
  }

  //########################################################################
  // Labels

  /** Create a new, not yet anchored label. */
  def getLabel(): Label = new Label();

  /** Anchor label l at the address of the next instruction. */
  def anchorLabel(l: Label): Unit = l.setAnchor(pc());

  /**
   * A branch target. Instructions referring to the label before it is
   * anchored are recorded and patched once its address is known.
   */
  class Label {
    /** Addresses of emitted instructions waiting for this label's offset. */
    private var toPatch = List[Int]();
    /** Address of the label, or -1 while it is not anchored. */
    private var anchor = -1;

    /**
     * Anchor this label at address pc and store, as operand c of every
     * recorded instruction, its offset in words to pc.
     */
    def setAnchor(pc: Int): Unit = {
      anchor = pc;
      toPatch foreach { instrPC =>
        // The instruction at address instrPC sits at index instrPC / WORD_SIZE,
        // independently of where the label is anchored.
        val index = instrPC / WORD_SIZE;
        code(index) = code(index).copy(c = (pc - instrPC) / WORD_SIZE)
      }
    }

    /** Address of the label (-1 if not anchored). */
    def getAnchor(): Int = anchor;

    /** Whether setAnchor has been called with a non-negative address. */
    def isAnchored(): Boolean = anchor >= 0;

    /** Record that the instruction at address pc needs this label's offset. */
    def recordInstructionToPatch(pc: Int): Unit =
      toPatch = pc :: toPatch;
  }

  //########################################################################
  // Instructions

  /**
   * One emitted instruction. Operands equal to Integer.MIN_VALUE are absent
   * and not printed; s is an optional (nullable) comment.
   */
  private case class Instruction(opcode: Int, a: Int, b: Int, c: Int, s: String) {
    override def toString(): String = {
      val buffer = new java.lang.StringBuilder();
      buffer.append(mnemonics(opcode));
      if (a != Integer.MIN_VALUE) buffer.append(' ').append(a);
      if (b != Integer.MIN_VALUE) buffer.append(' ').append(b);
      if (c != Integer.MIN_VALUE) buffer.append(' ').append(c);
      if (s != null) {
        // Pad to at least 17 characters so comments line up in a column.
        while (buffer.length() < 17) buffer.append(' ');
        buffer.append("/* ").append(s).append(" */");
      }
      buffer.toString()
    }
  }

  //########################################################################
}