// Malloc verification

if objtype != "mips" then
       error("malloc checker for MIPS\n");

defn go()
{
       new();

       xmalloc = malloc + 4;           // Delay slot filler may pull in first
       xfree = free + 4;               // instruction

       bpset(xmalloc);
       bpset(xfree);
       bpset(_exits+4);

       blocklist = {};                 // List of allocated blocks
       framelist = {};                 // Frames they were allocated from
       stacks = {};                    // The stacks
       nstack = 0;

       while 1 do {
               cont();
               pc = *PC;
               if pc == xmalloc then
                       gotmalloc();
               else
               if pc == xfree then
                       gotfree();
               else {
                       pstop(pid);
                       return {};
               }
       }
}

defn stkcmp(s1, s2)                     // return true if the stack traces are the same
{
       local f1;
       local f2;

       while s1 do {
               f1 = head s1;
               f2 = head s2;

               if f2 == {} || f1[0] != f2[0] || f1[1] != f2[1] then
                       return 0;

               s1 = tail s1;
               s2 = tail s2;
       }
       if s2 != {} then
               return 0;

       return 1;
}

defn stkindex(s)
{
       local slist;
       local n;

       slist = stacks;
       n = 0;
       while slist do {
               if stkcmp(head slist, s) then
                       return n;

               slist = tail slist;
               n = n+1;
       }

       stacks = append stacks, s;
       n = nstack;
       nstack = nstack+1;
       return n;
}

defn gotmalloc()
{
       local raddr;
       local stk;

       raddr = *R31;
       bpset(raddr);
       stk = strace(*PC, *SP);         // Save a copy of the stack up to the malloc
       cont();                         // Continue until malloc returns
       bpdel(raddr);

       blocklist = append blocklist, *R1;
       framelist = append framelist, stkindex(stk);
}

defn pentry(s)
{
       local sno;
       local blst;
       local gotone;
       local size;

       gotone = 0;
       blst = blocklist;
       flst = framelist;

       while blst do {
               sno = head flst;
               if sno == s then {
                       addr = fmt(head blst, 'X');
                       size = 1<<*(addr-12);
//                      print(addr, "of ", fmt(size, 'D'), "bytes\n");
                       gotone = gotone + size;
               }
               blst = tail blst;
               flst = tail flst;
       }
       return gotone;
}

defn leak()
{
       local n;
       local sz;

       n = 0;
       while n < nstack do {
               sz = pentry(n);
               if sz then {
                       print("Lost a total of ", fmt(sz, 'D'), "bytes from:\n");
                       mstack(stacks[n]);
               }
               n = n+1;
       }
}

defn scanmem(start, end)
{
       local n;

       start = fmt(start, 'X');
       while start < end do {
               n = match(*start, blocklist);
               if n >= 0 then {
                       blocklist = delete blocklist, n;
                       framelist = delete framelist, n;
               }
               start++;
       }
}

defn refs()                     // Delete blocks with references
{
       print("data...");
       scanmem(etext, edata);

       print("bss...");
       scanmem(edata, *bloc);

       print("stack...");
       scanmem(*SP, 0x80000000);

       print("\n");
       leak();
}

defn mstack(stk)
{
       while stk do {
               frame = head stk;
               print("\t", fmt(frame[0], 'a'), "() ");
               print(pcfile(frame[0]), ":", pcline(frame[0]));
               print("called from ", fmt(frame[1], 'a'), " ");
               pfl(frame[1]);
               stk = tail stk;
       }
}

defn gotfree()
{
       local n;
       local addr;

       addr = *R1;                     // We know the first parameter is register R1
       if addr == 0 then
               return {};

       n = match(addr, blocklist);
       if n < 0 then {
               print("free(", addr, ") called with bad pointer:\n");
               stk();
               return {};
       }
       blocklist = delete blocklist, n;
       framelist = delete framelist, n;
}

defn stopped(pid)
{
       return {};
}

print("/lib/acid/malloc");