(defpackage #:compiler
  (:use #:cl))

(in-package #:compiler)

;; Counter backing GENLABEL.  GENLABEL increments it before formatting, so
;; with the initial value of 0 the first generated label carries the number 1.
(defvar *label-counter* 0)

(defun genlabel (&optional (prefix "L"))
  "Return a fresh label name as a string: PREFIX followed by the next value
of *LABEL-COUNTER*.  The counter is incremented as a side effect, so each
call yields a different number."
  (let ((serial (incf *label-counter*)))
    (format nil "~A~D" prefix serial)))

(defmacro format-inst (destination control-string &rest format-arguments)
  "Emit one instruction line to DESTINATION: a tab character, the text
produced by applying CONTROL-STRING to FORMAT-ARGUMENTS, then a newline.
DESTINATION is passed straight to FORMAT, so NIL returns the line as a string."
  ;; The instruction text is rendered to a string first, then wrapped.
  (let ((instruction-text `(format nil ,control-string ,@format-arguments)))
    `(format ,destination "~C~A~%" #\Tab ,instruction-text)))

;; Base class for call operands.  The subclasses in this file each provide a
;; DEREFERENCE method that emits the instructions loading the operand with LDA.
(defclass reference () ())

;; A reference to a literal value; DEREFERENCE emits it as an immediate
;; operand ("LDA #value").
(defclass reference-constant (reference)
  ((%value :initarg :value
           :accessor ref-value)))

(defmethod print-object ((object reference-constant) stream)
  "Print OBJECT unreadably as its type name followed by its value."
  (with-accessors ((value ref-value)) object
    (print-unreadable-object (object stream :type t)
      (format stream "~D" value))))

(defmethod dereference ((ref reference-constant))
  "Emit an immediate load of REF's value (LDA #value) to standard output."
  (let ((immediate (ref-value ref)))
    (format-inst t "LDA #~D" immediate)))

;; A reference to a variable, identified by an index; DEREFERENCE emits an
;; indexed load from VARVEC using that index.
(defclass reference-variable (reference)
  ((%index :initarg :index
           :accessor ref-index)))

(defmethod print-object ((object reference-variable) stream)
  "Print OBJECT unreadably as its type name followed by @ and its index."
  (with-accessors ((slot-number ref-index)) object
    (print-unreadable-object (object stream :type t)
      (format stream "@~D" slot-number))))

(defmethod dereference ((ref reference-variable))
  "Emit the two-instruction sequence that loads the variable REF refers to:
the index goes into Y, then the accumulator is loaded from VARVEC indexed by Y."
  (let ((slot-number (ref-index ref)))
    ;; Y is overwritten by this sequence.
    (format-inst t "LDY #~D" slot-number)
    (format-inst t "LDA VARVEC,Y")))


;; Base class for nodes of the program being compiled.  Each node holds a
;; link to its successor (NIL by default), reachable through either of two
;; accessor names, NEXT and NORMAL-NEXT.  No initarg is defined for the slot,
;; so the link is set through the accessors.
(defclass node ()
  ((%next :initform nil
          :accessor next
          :accessor normal-next)))

(defmethod generate-code :before ((node node))
  "Before any node's code is emitted, write a comment line showing the
printed representation of NODE."
  (format *standard-output* ";; ~A~%" node))

(defmethod generate-code :after ((node node))
  "After any node's code is emitted, write a newline to standard output to
separate it from what follows."
  (terpri *standard-output*))

;; A subroutine call.  CALLEE is printed as the operand of JSR by
;; GENERATE-CODE; ARGUMENTS is a list of objects that GENERATE-CODE passes to
;; DEREFERENCE, one per argument slot.
(defclass node-call (node)
  ((%callee :initarg :callee
            :accessor callee)
   (%arguments :initarg :arguments
               :accessor arguments)))

(defmethod print-object ((object node-call) stream)
  "Print OBJECT unreadably, with type and identity, showing the callee
immediately followed by the argument list."
  (with-accessors ((target callee) (operands arguments)) object
    (print-unreadable-object (object stream :type t :identity t)
      (format stream "~A~A" target operands))))

(defmethod generate-code ((node node-call))
  "Emit code for a call: each argument is loaded via DEREFERENCE and stored
to ARGVEC at its zero-based position, then a JSR to the callee is emitted."
  (let ((position 0))
    (dolist (argument (arguments node))
      (dereference argument)
      (format-inst t "STA ARGVEC+~D" position)
      (incf position)))
  (format-inst t "JSR ~A" (callee node)))

;; A conditional node.  BRANCH-NEXT holds the node whose code GENERATE-CODE
;; emits between the conditional branch and the skip label.  The slot has no
;; initform, so it must be supplied via :BRANCH-NEXT or set before use.
(defclass node-branch (node)
  ((%branch-next :initarg :branch-next
                 :accessor branch-next)))

(defmethod generate-code ((node node-branch))
  "Emit a conditional: load RESULT, branch with BNE to a fresh ELSE label,
emit the code of the BRANCH-NEXT node, then emit the label itself."
  (let ((skip-target (genlabel "ELSE")))
    (flet ((emit-label (name)
             (format t "~%~A:~%" name)))
      (format-inst t "LDA RESULT")
      ;; NOTE(review): BNE skips the guarded code when RESULT is non-zero, so
      ;; the guarded code runs only when RESULT is zero -- confirm this is the
      ;; intended truth convention.
      (format-inst t "BNE ~A" skip-target)
      ;; Guarded code: only the single BRANCH-NEXT node is generated here.
      (generate-code (branch-next node))
      ;; Execution rejoins here whether or not the branch was taken.
      (emit-label skip-target))))

;; A counted loop.  STOP-REF is dereferenced by GENERATE-CODE to obtain the
;; iteration count; LOOPEE-NODE is the node whose code forms the loop body.
(defclass node-dotimes (node)
  ((%stop-ref :initarg :stop-ref
              :accessor stop-ref
              :documentation "A reference giving a value of how many times to run the loop.")
   (%loopee-node :initarg :loopee-node
                 :accessor loopee-node)))

(defmethod generate-code ((node node-dotimes))
  "Emit a counted loop that runs the LOOPEE-NODE code STOP-REF times, using
the X register as a down-counter.  The previous value of X is pushed on the
stack before the loop and restored afterwards.

A stop count of zero runs the body zero times: the emitted code tests the
count right after loading it and jumps past the loop.  Without that guard the
body would run once, DEX would wrap X from 0 to $FF, and the loop would
execute 256 times."
  (let ((loop-label (genlabel "LOOPBACK"))
        (done-label (genlabel "LOOPDONE")))
    ;; Save the caller's X (via A) so that enclosing loops keep their counter.
    (format-inst t "TXA")
    (format-inst t "PHA")

    ;; Load the iteration count into X.  TAX sets the Z flag from the value
    ;; transferred, so BEQ skips the whole loop when the count is zero.
    (dereference (stop-ref node))
    (format-inst t "TAX")
    (format-inst t "BEQ ~A" done-label)

    (format t "~%~A:~%" loop-label)
    ;; NOTE(review): the body is assumed to leave X intact (a called
    ;; subroutine that changes X would corrupt the counter) -- confirm.
    (generate-code (loopee-node node))
    (format-inst t "DEX")
    (format-inst t "BNE ~A" loop-label)

    (format t "~%~A:~%" done-label)
    ;; Restore the caller's X.
    (format-inst t "PLA")
    (format-inst t "TAX")))

(defmethod compile-starting-at ((node node))
  "Generate code for NODE and then, recursively, for every node reachable by
following NEXT links until a node whose NEXT is NIL."
  (generate-code node)
  (when (next node)
    (compile-starting-at (next node))))

(defun make-call (callee args)
  "Build a NODE-CALL for CALLEE.  ARGS is a list of (CONSTP VALUE) entries:
when CONSTP is true the argument becomes a REFERENCE-CONSTANT holding VALUE;
otherwise it becomes a REFERENCE-VARIABLE whose index is its zero-based
position among the non-constant arguments."
  ;; NOTE(review): VALUE is ignored for non-constant entries; variable indices
  ;; are assigned purely by order of appearance -- confirm this is intended.
  (let ((last-variable-index -1))
    (flet ((build-reference (spec)
             (let ((constp (first spec))
                   (value (second spec)))
               (if constp
                   (make-instance 'reference-constant :value value)
                   (make-instance 'reference-variable
                                  :index (incf last-variable-index))))))
      (make-instance 'node-call
                     :callee callee
                     :arguments (mapcar #'build-reference args)))))