#!/usr/bin/python
#
# Copyright (C) 2005  Alexandre Boeglin <alex@boeglin.org>
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA

import sys

# The listing to analyse is named by the first command-line argument; any
# further arguments are entry point addresses, parsed as hexadecimal below.
# Stop with a usage message instead of a bare IndexError when it is missing.
if len(sys.argv) < 2 :
  sys.exit('usage: %s <listing_file> [entrypoint_hex_address ...]' % sys.argv[0])
f = open(sys.argv[1])

# Short codes identifying the kind of a label (node) or of a branch (edge).
entrypoint_type = 'EP'
call_type = 'C'
return_type = 'R'
jump_type = 'J'
always_jump_type = 'AJ'
condition_jump_type = 'CJ'

# Suffix appended to a label name to build the name of its return node.
return_suffix = '_RET'

# Keywords looked for in the listing, and names of the special labels.
start_word = 'CSEG'
end_word = 'END'
start_label = '__start__'
dptr_jump_label = '@A+DPTR'

# Mnemonics treated as branch instructions, grouped by what they do.
call_instr_tuple = tuple('ACALL LCALL'.split())
return_instr_tuple = tuple('RET RETI'.split())
always_jump_instr_tuple = tuple('AJMP LJMP SJMP JMP'.split())
condition_jump_instr_tuple = tuple('JZ JNZ JC JNC JB JNB JBC CJNE DJNZ'.split())
jump_instr_tuple = always_jump_instr_tuple + condition_jump_instr_tuple
branch_instr_tuple = call_instr_tuple + return_instr_tuple + jump_instr_tuple

class Jump :
  """One branch edge: leaving `label`, of kind `type`, going to `to_label`."""

  def __init__(self, label='', type='', to_label='') :
    # `type` hides the builtin of the same name inside this method; the name
    # is kept because callers pass it by keyword.
    self.label, self.type, self.to_label = label, type, to_label

  def __repr__(self) :
    # Rendered as 'source -(kind)-> target'.
    fields = (self.label, self.type, self.to_label)
    return '%s -(%s)-> %s' % fields

# Maps each label name to its type code; l_clean adds the labels it finds.
label_type_dict = {}

# Maps an address, written as upper-case hex padded to 4 digits, to the name
# of the entry point label to insert there. Every command-line argument after
# the file name is read as a hexadecimal address.
entrypoint_dict = {}
for hex_arg in sys.argv[2:] :
  address = int(hex_arg, 16)
  entrypoint_dict['%04X' % address] = '%s%04X' % (entrypoint_type, address)

def l_clean(line) :
  """Reduce one listing line to what the later stages need.

  Everything from the first ';' on (a comment) and trailing whitespace are
  dropped first. Then:

  - an empty line yields None;
  - a line not starting with a space is taken as a label: it is recorded in
    label_type_dict with jump_type and returned without its last character
    (presumably the trailing ':' -- this is not checked);
  - a line starting with a space is returned (stripped as above) when its
    third whitespace-separated field is in branch_instr_tuple;
  - any other line yields None.
  """
  line = line.split(';', 1)[0].rstrip() # remove comment, then spaces and \n at end
  if not line :
    return None
  if line[0] != ' ' : # label
    label = line[:-1]
    label_type_dict[label] = jump_type
    return label
  field_list = line.split()
  # An indented line with fewer than three fields has no mnemonic at index 2;
  # ignore it instead of raising IndexError.
  if len(field_list) > 2 and field_list[2] in branch_instr_tuple :
    return line
  return None

# Stage 1 : parse lines, find labels and branch instruction

line_list = f.readlines()
f.close() # the whole listing is in memory now, do not leave the file open

# removes start_word line ('CSEG AT <ADDR>h') and set its entrypoint as start_label
# (an empty file or a blank first line simply has no start_word line)
item_list = []
if line_list :
  item_list = line_list[0].split()
if item_list and item_list[0] == start_word :
  # the last field is '<ADDR>h': drop its final character to get the address
  entrypoint_dict[item_list[-1][:-1]] = start_label
  line_list.pop(0)

# remove end_word line ('END')
if line_list and line_list[-1].strip() == end_word :
  line_list.pop()

# add entrypoints labels and dummy '@A+DPTR'
dptr_jump_found = 0
entrypoint_label_list = entrypoint_dict.values()
new_line_list = []
for line in line_list :
  item_list = line.split()
  if item_list :
    # a line whose first field is an entrypoint address gets a label line inserted before it
    if item_list[0] in entrypoint_dict :
      new_line_list.append(entrypoint_dict[item_list[0]] + ':')
    # remember whether any line ends with the indirect jump operand
    if item_list[-1] == dptr_jump_label :
      dptr_jump_found = 1
  new_line_list.append(line)
line_list = new_line_list

# the indirect jump target is not a real label: add a dummy one at the end
if dptr_jump_found :
  line_list.append('%s:' % dptr_jump_label)

# l_clean parses lines, finds labels and branch instruction
# (called once per line; lines for which it returns nothing are dropped)
cleaned_line_list = []
for line in line_list :
  cleaned_line = l_clean(line)
  if cleaned_line :
    cleaned_line_list.append(cleaned_line)
line_list = cleaned_line_list

# Mark entrypoints
for label in entrypoint_label_list :
  if label in label_type_dict :
    label_type_dict[label] = entrypoint_type

#print '\n'.join(line_list)
#print '\n'.join(label_type_dict.keys())
#sys.exit()

# Stage 2 : for each label, find all their calls and jumps, and associate them with the label's function

# find the type of all labels (they are all jump_type by default. see l_clean)
current_label = None # no label met yet
for line in line_list :
  item_list = line.split()
  if line[0] != ' ' :
    current_label = item_list[0]
    continue
  mnemonic = item_list[2]
  if mnemonic in call_instr_tuple :
    # the last field of a call is its target: mark that label as a call target
    label_type_dict[item_list[-1]] = call_type
  elif mnemonic in return_instr_tuple and current_label is not None :
    # a return gets its own '<label>_RET' node; a return met before any
    # label has no owner and is ignored
    label_type_dict['%s%s' % (current_label, return_suffix)] = return_type

#print '\n'.join(['%s : %s' % (k, v) for (k, v) in label_type_dict.items()])
#sys.exit()

# one (initially empty) list of Jump objects per known label
branch_dict = {}
for label in label_type_dict :
  branch_dict[label] = []

# find all the branch instructions and build the branch_dict
# Start again without a current label: otherwise the last label of the
# previous pass would be credited with instructions met before the first label.
current_label = None
always_jump_met = 1 # do not try to jump from before start of code
for line in line_list :
  item_list = line.split()
  if line[0] != ' ' :
    next_label = item_list[0]
    # if we met a always_jump_instr in the previous label, we do not link to this one as they are not related
    if not always_jump_met :
      branch_dict[current_label].append(Jump(label=current_label, type=always_jump_type, to_label=next_label))
    current_label = next_label
    always_jump_met = 0
    continue
  if current_label is None :
    # branch instruction before the first label: no node to attach it to
    continue
  mnemonic = item_list[2]
  to_label = item_list[-1]
  if mnemonic in call_instr_tuple :
    branch_dict[current_label].append(Jump(label=current_label, type=call_type, to_label=to_label))
  elif mnemonic in always_jump_instr_tuple :
    branch_dict[current_label].append(Jump(label=current_label, type=always_jump_type, to_label=to_label))
    always_jump_met = 1
  elif mnemonic in condition_jump_instr_tuple :
    branch_dict[current_label].append(Jump(label=current_label, type=condition_jump_type, to_label=to_label))
  elif mnemonic in return_instr_tuple :
    branch_dict[current_label].append(Jump(label=current_label, type=return_type, to_label='%s%s' % (current_label, return_suffix)))
    # execution does not continue into the next label after a return
    always_jump_met = 1

#print '\n'.join(['%s : %s' % (k, repr(v)) for (k, v) in branch_dict.items()])
#sys.exit()

# Stage 3 : generate an output in the graphviz.org dot language

# Each label becomes a node whose id is its position in label_list.
# label_index gives that position directly, replacing a linear
# label_list.index() search for every node and every edge.
label_list = list(label_type_dict.keys())
label_index = {}
for index, label in enumerate(label_list) :
  label_index[label] = index

output_list = ['digraph G{',
               '  node [fontname="Vera",fontsize=8,height=0.1,width=0.1]',
               '  edge [arrowsize=0.5]',
               '  graph [nodesep=0.1,ranksep=0.4]']

# dot attributes of a node, by label type
node_style_dict = {
  entrypoint_type : 'color=blue',
  call_type : 'color=red',
  jump_type : 'shape=box',
  return_type : 'color=grey',
}

for label in label_list :
  node_style = node_style_dict.get(label_type_dict[label])
  if node_style is not None :
    output_list.append('  %s [label="%s",%s];' % (label_index[label], label, node_style))

# Call targets are drawn as extra nodes, numbered after the label nodes.
call_node_number = len(label_list)
for label, jump_list in branch_dict.items() :
  # identical jumps leaving the same label are drawn only once
  printed_jump_set = set()
  for jump in jump_list :
    jump_repr = repr(jump)
    if jump_repr in printed_jump_set :
      continue
    printed_jump_set.add(jump_repr)
    from_index = label_index[jump.label]
    if jump.type == call_type :
      # each distinct call gets a fresh node carrying the name of the called label
      output_list.append('  %s -> %s [style=dashed,color=red];' % (from_index, call_node_number))
      output_list.append('  %s [label="%s",color=red];' % (call_node_number, jump.to_label))
      call_node_number += 1
    elif jump.type == always_jump_type :
      output_list.append('  %s -> %s;' % (from_index, label_index[jump.to_label]))
    elif jump.type == condition_jump_type :
      output_list.append('  %s -> %s [style=dashed];' % (from_index, label_index[jump.to_label]))
    elif jump.type == return_type :
      output_list.append('  %s -> %s [color=grey];' % (from_index, label_index[jump.to_label]))

output_list.append('}')
# parenthesised single argument: same output with the Python 2 print
# statement and the Python 3 print function
print('\n'.join(output_list))
