#!/usr/bin/env python
#
# Copyright (C) 2007 Brendan Cully <[email protected]>
#
# This software may be used and distributed according to the terms of the
# GNU General Public License version 2, incorporated herein by reference.

import os, sys

class ConfigError(Exception):
    """Raised when the command line does not name two usable repositories."""

def usage():
    """Print a two-line usage summary to standard output."""
    sys.stdout.write("relink <source> <destination>\n"
                     "   Recreate hard links between source and destination repositories\n")

class Config:
    """Source and destination repository paths taken from the command line.

    args is the full argument vector, program name first.  Exactly two
    further arguments are required, and each must contain a '.hg' entry.
    The paths are stored in absolute form as self.src and self.dst.
    Raises ConfigError otherwise.
    """
    def __init__(self, args):
        if len(args) != 3:
            raise ConfigError("wrong number of arguments")
        _prog, source, dest = args
        self.src, self.dst = [os.path.abspath(p) for p in (source, dest)]
        # Source is validated before destination, so its error wins.
        for repo in (self.src, self.dst):
            if os.path.exists(os.path.join(repo, '.hg')):
                continue
            raise ConfigError("%s: not a mercurial repository" % repo)

def collect(src):
    """Find every file whose name ends in '.i' beneath src.

    Returns a list of (path relative to src, os.stat result) pairs, in the
    order os.walk() visits them.
    """
    # Skip the "src" + separator prefix to get a path relative to src.
    prefix = len(src) + len(os.path.sep)
    found = []
    for top, _subdirs, names in os.walk(src):
        rel = top[prefix:]
        found.extend([(os.path.join(rel, name), os.stat(os.path.join(top, name)))
                      for name in names if name.endswith('.i')])
    return found

def prune(candidates, dst):
    """Select the destination files that are candidates for relinking.

    candidates: (relative path, source os.stat result) pairs, e.g. as
        returned by collect().
    dst: root directory the relative paths are resolved against.

    A '*.i' candidate is kept when its counterpart under dst exists, is not
    already the same file (same device and inode), and has the same size.
    For each kept '*.i' file, the sibling '*.d' file is also added if it
    exists under dst; its size and contents are not checked here.

    Returns a list of (relative path, size) pairs.
    Raises Exception if a destination file lives on a different device than
    its source, since hard links cannot cross devices.
    """
    def getdatafile(path):
        # Map 'name.i' to 'name.d' and stat it; (None, None) if absent.
        if not path.endswith('.i'):
            return None, None
        df = path[:-1] + 'd'
        try:
            st = os.stat(df)
        except OSError:
            return None, None
        return df, st

    def linkfilter(dst, st):
        # Return the source stat st if dst is a distinct, same-sized file on
        # the same device; otherwise False.
        try:
            ts = os.stat(dst)
        except OSError:
            # Destination doesn't have this file?
            return False
        # Check the device before the inode: inode numbers are only unique
        # per device, so equal st_ino on different devices is not the same
        # file and must not be mistaken for an existing hard link.
        if st.st_dev != ts.st_dev:
            # No point in continuing
            raise Exception('Source and destination are on different devices')
        if st.st_ino == ts.st_ino:
            # Already hard-linked.
            return False
        if st.st_size != ts.st_size:
            # TODO: compare revlog heads
            return False
        return st

    targets = []
    for fn, st in candidates:
        tgt = os.path.join(dst, fn)
        if not linkfilter(tgt, st):
            continue
        targets.append((fn, st.st_size))
        df, dfst = getdatafile(tgt)
        if df:
            targets.append((fn[:-1] + 'd', dfst.st_size))

    return targets

def relink(src, dst, files):
    """Replace files under dst with hard links to identical files under src.

    src, dst: roots that the relative paths in files are resolved against.
    files: sequence of (relative path, size) pairs; size is only used for
        the "bytes reclaimed" total.

    A file is relinked only when both copies hold exactly the same bytes
    (same length included).  Files missing or unreadable on either side are
    reported and skipped.  Prints one line per relinked file and a summary.
    """
    CHUNKLEN = 65536

    def relinkfile(src, dst):
        # Move the original aside so it can be restored if linking fails.
        bak = dst + '.bak'
        os.rename(dst, bak)
        try:
            os.link(src, dst)
        except OSError:
            os.rename(bak, dst)
            raise
        os.remove(bak)

    def samecontents(a, b):
        # True only if both files contain identical bytes AND end together;
        # a source that is merely a prefix of the target is not a match.
        # Binary mode avoids newline/EOF translation on some platforms, and
        # both handles are closed before any rename happens.
        afp = open(a, 'rb')
        try:
            bfp = open(b, 'rb')
            try:
                while True:
                    achunk = afp.read(CHUNKLEN)
                    bchunk = bfp.read(CHUNKLEN)
                    if achunk != bchunk:
                        return False
                    if not achunk:
                        return True
            finally:
                bfp.close()
        finally:
            afp.close()

    relinked = 0
    savedbytes = 0

    for f, sz in files:
        source = os.path.join(src, f)
        tgt = os.path.join(dst, f)
        try:
            if not samecontents(source, tgt):
                continue
        except IOError:
            # e.g. a data file that exists under dst but not under src
            inst = sys.exc_info()[1]
            print('%s: %s' % (source, str(inst)))
            continue
        try:
            relinkfile(source, tgt)
        except OSError:
            inst = sys.exc_info()[1]
            print('%s: %s' % (tgt, str(inst)))
            continue
        print('Relinked %s' % f)
        relinked += 1
        savedbytes += sz

    print('Relinked %d files (%d bytes reclaimed)' % (relinked, savedbytes))

# Parse the command line; on failure show the error and usage, then exit 1.
try:
    cfg = Config(sys.argv)
except ConfigError:
    inst = sys.exc_info()[1]
    print(str(inst))
    usage()
    sys.exit(1)

# Work on the '.hg' directories of both repositories.
src, dst = os.path.join(cfg.src, '.hg'), os.path.join(cfg.dst, '.hg')
candidates = collect(src)
targets = prune(candidates, dst)
relink(src, dst, targets)