#!/usr/bin/python2.7
'''
" urpm-downloader for URPM-based linux
" A tool for downloading RPMs from URPM-based linux repositories.
"
" Copyright (C) 2011 ROSA Laboratory.
" Written by Anton Kirilenko <anton.kirilenko@rosalab.ru>
"
" PLATFORMS
" =========
"  Linux
"
" REQUIREMENTS
" ============
"  - python 2.7
"  - python-rpm 5.3
"  - urpmi 6.68
"
" This program is free software: you can redistribute it and/or modify
" it under the terms of the GNU General Public License or the GNU Lesser
" General Public License as published by the Free Software Foundation,
" either version 2 of the Licenses, or (at your option) any later version.
"
" This program is distributed in the hope that it will be useful,
" but WITHOUT ANY WARRANTY; without even the implied warranty of
" MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
" GNU General Public License for more details.
"
" You should have received a copy of the GNU General Public License
" and the GNU Lesser General Public License along with this program.
" If not, see <http://www.gnu.org/licenses/>.
'''

import argparse
import sys
import subprocess
import os
import re
from urllib import urlretrieve
import rpm 
from urllib2 import urlopen, HTTPError, URLError
import shutil

def vprint(text):
    '''Write *text* to stdout, but only when the --verbose flag is set.

    Reads the module-level ``command_line_arguments`` namespace, so
    parse_command_line() must have been called first.
    '''
    if not command_line_arguments.verbose:
        return
    print(text)

def qprint(text):
    '''Write *text* to stdout unless the --quiet flag is set.

    Reads the module-level ``command_line_arguments`` namespace, so
    parse_command_line() must have been called first.
    '''
    if command_line_arguments.quiet:
        return
    print(text)
        
        
def eprint(text, fatal=False, code=1):
    '''Write *text* followed by a newline to stderr.

    text  -- the message to report
    fatal -- when true, terminate the program after printing
    code  -- process exit status used when *fatal* is true

    Raises SystemExit(code) when *fatal* is true.
    '''
    sys.stderr.write("%s\n" % (text,))
    if fatal:
        # sys.exit is used instead of the bare exit() helper: the latter is
        # injected by the 'site' module and is missing under 'python -S'.
        sys.exit(code)


def url_exists(url):
    '''Return True if the given url or local path exists, otherwise False.

    Local resources are recognised by a leading "/" or a "file://" scheme and
    are checked with os.path.isfile. Anything else is probed by opening it
    with urlopen; HTTPError and URLError are treated as "does not exist".
    '''
    scheme = "file://"
    if url.startswith(scheme):
        # The scheme has to be removed first: os.path.isfile() expects a
        # filesystem path and would never find "file:///some/path".
        return os.path.isfile(url[len(scheme):])
    if url.startswith("/"):
        return os.path.isfile(url)

    # try to open the remote resource
    try:
        response = urlopen(url)
    except (HTTPError, URLError):
        return False
    # Only reachability matters, so release the connection right away.
    response.close()
    return True


def run_self_test():
    '''Crude smoke test: run the downloader for the 'psi' package (binary,
    then source) inside a scratch "test_data" directory and report an error
    if a walked directory contains no files afterwards.

    The working directory and sys.argv are restored after every run, even
    when the run raises or calls exit.
    '''
    def clear():
        # start every run from an empty scratch directory
        if os.path.exists("test_data"):
            shutil.rmtree("test_data")
        os.mkdir("test_data")

    def run_test(cmd):
        clear()
        saved_path = os.getcwd()
        saved_argv = sys.argv
        os.chdir("test_data")
        try:
            # parse_command_line() reads sys.argv, so fake the command line
            sys.argv = [saved_argv[0], '-q'] + cmd
            parse_command_line()
            Main()

            for _root, _dirs, files in os.walk("."):
                if not files:
                    eprint("ERROR: nothing downloaded\ncmd = " + str(cmd))
        finally:
            sys.argv = saved_argv
            os.chdir(saved_path)

    run_test(['psi'])
    run_test(['-s', 'psi'])
    
      
def parse_command_line():
    '''Parse sys.argv into the global ``command_line_arguments`` namespace,
    then normalise dependent flags and warn about questionable combinations.

    Post-processing applied after parsing:
      * -D implies -d;
      * if neither -d nor -s is given, -b is turned on;
      * -a implies -r;
      * exclude_packages defaults to an empty list;
      * --verbose is switched off when combined with --quiet;
      * dest_dir defaults to the current directory and is created
        (including missing parent directories) when it does not exist.

    Raises OSError if the destination directory can not be created.
    '''
    global command_line_arguments
    arg_parser = argparse.ArgumentParser(description='A tool for downloading RPMs and SRPMs from URPM-based linux repositories',
                epilog="If none of the options -b, -s, -d turned on, it will be treated as -b")

    arg_parser.add_argument('packages', action='store',nargs = '+', help="Package name(s) to download. It can contain not only package names, but (S)RPM files too. In this case package name extracted from this file will be used")
    arg_parser.add_argument('-u', '--urls', action='store_true', help="Instead of downloading files, list the URLs that would be processed")
    arg_parser.add_argument('-r', '--resolve', action='store_true', help="When downloading RPMs, resolve dependencies and also download the required packages, if they are not already installed")
    arg_parser.add_argument('-a', '--resolve-all', action='store_true', help="When downloading RPMs, resolve dependencies and also download the required packages, even if they are already installed")
    arg_parser.add_argument('-b', '--binary', action='store_true', help="Download binary RPMs")
    arg_parser.add_argument('-s', '--source', action='store_true', help="Download the source RPMs (SRPMs)")
    arg_parser.add_argument('-d', '--debug-info', action='store_true', help="Download debug RPMs")
    arg_parser.add_argument('-D', '--debug-info-install', action='store_true', help="Download debug RPMs and install")
    arg_parser.add_argument('--version', action='version', version=VERSION)
    arg_parser.add_argument('-v', '--verbose', action='store_true', help="Verbose (print additional info)")
    arg_parser.add_argument('-q', '--quiet', action='store_true', help="Quiet operation.")
    arg_parser.add_argument('--include-media', '--media', action='append',nargs = '+', help="Use only selected URPM media")
    arg_parser.add_argument('--exclude-media', action='append',nargs = '+', help="Do not use selected URPM media")
    arg_parser.add_argument('-x', '--exclude-packages', action='store',nargs = '+', help="Exclude package(s) by regex")
    arg_parser.add_argument('-i', '--ignore-errors', action='store_true', help="Try to continue when error occurs")
    arg_parser.add_argument('-o', '--overwrite', action='store_true', help="If the file already exists, download it again and overwrite the old one")
    arg_parser.add_argument('--all-alternatives', action='store_true', help="If package dependency can be satisfied by several packages, download all of them (by default, only the first one is downloaded)")
    arg_parser.add_argument('--all-versions', action='store_true', help="If different versions of package present in repository, process them all")
    #arg_parser.add_argument('--self-test', action='store_true', help="Test urpm-downloader end exit")
    arg_parser.add_argument('--dest-dir', action='store', help="Specify a destination directory for the download")

    command_line_arguments = arg_parser.parse_args(sys.argv[1:])

    # installing debug packages requires downloading them first
    if command_line_arguments.debug_info_install:
        command_line_arguments.debug_info = True

    # binary download is the default action
    if not command_line_arguments.debug_info and not command_line_arguments.source:
        command_line_arguments.binary = True

    if command_line_arguments.resolve_all:
        command_line_arguments.resolve = True

    if command_line_arguments.exclude_packages is None:
        command_line_arguments.exclude_packages = []

    if command_line_arguments.verbose and command_line_arguments.quiet:
        eprint("Use of --verbose with --quiet is senseless. Turning verbose mode off.")
        command_line_arguments.verbose = False

    if command_line_arguments.resolve and command_line_arguments.source and command_line_arguments.urls:
        eprint("Note that resolving of SRPM dependencies is not possible until SRPM downloaded. So, it will be done despite --urls")

    if command_line_arguments.dest_dir is None:
        command_line_arguments.dest_dir = os.getcwd()
    elif not os.path.isdir(command_line_arguments.dest_dir):
        # makedirs (rather than mkdir) so that a nested destination such as
        # "out/rpms" works even when "out" does not exist yet.
        os.makedirs(command_line_arguments.dest_dir)
    

def get_command_output(command, fatal_fails=True):
    '''Run *command* (an argument list) and return [stdout, stderr, returncode].

    command     -- list of program name and arguments, passed to subprocess.Popen
    fatal_fails -- when true, a non-zero return code (or a command that can
                   not be started at all) is reported on stderr and the
                   program exits with status 1. When false the caller has to
                   inspect the returned return code himself.

    If the command can not be started and *fatal_fails* is false, the result
    is ['', <error description>, 127].
    '''
    vprint("Executing command: " + str(command))
    try:
        process = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    except OSError as e:
        # Popen raises OSError when the executable is missing or not runnable;
        # report it the same way as a failed command instead of a traceback.
        message = "Can not execute '" + " ".join(command) + "': " + str(e)
        if fatal_fails:
            eprint(message)
            sys.exit(1)
        return ['', message, 127]

    # Both streams are pipes, so communicate() always returns two strings.
    stdout, stderr = process.communicate()
    if process.returncode != 0 and fatal_fails:
        eprint("Error calling command '" + " ".join(command) + "'")
        eprint("Error message: \n" + stdout.strip() + "\n" + stderr.strip())
        sys.exit(1)
    return [stdout, stderr, process.returncode]


def parse_packages(pkgs_list, toresolve):
    '''Flatten a list of package names in which some entries are alternatives
    of the form 'pkg1|pkg2' into a list of plain names.

    With --all-alternatives every alternative is kept. Otherwise, for each
    entry, nothing is added when one of its alternatives is already in
    *toresolve* or in the result; else the alphabetically first alternative
    is added.
    '''
    result = []
    for entry in pkgs_list:
        alternatives = entry.split("|")
        has_choice = len(alternatives) > 1
        if has_choice:
            vprint("Aternatives found: " + str(alternatives))

        if command_line_arguments.all_alternatives:
            # keep every alternative
            result.extend(alternatives)
            continue

        # keep only the first alternative (in alphabetical order), and only
        # if none of the alternatives has been chosen before
        first = sorted(alternatives)[0]
        already_chosen = any(alt in toresolve or alt in result
                             for alt in alternatives)
        if not already_chosen:
            result.append(first)
        if has_choice:
            vprint("Selected: " + first)
    return result


def get_installed_packages():
    '''Fill the global ``installed_packages`` from the RPM database, once.

    The resulting mapping looks like
    {pkg_name: [[version1, release1], [version2, release2], ...], ...}.
    Later calls return immediately because ``installed_loaded`` is set.
    '''
    global installed_packages, installed_loaded
    if installed_loaded:
        return
    installed_loaded = True
    installed_packages = {}

    transaction = rpm.TransactionSet()
    for header in transaction.dbMatch():
        versions = installed_packages.setdefault(header['name'], [])
        versions.append([header['version'], header['release']])
    vprint("The list of installed packages loaded")
    
def check_what_to_skip(package_names):
    '''Return the names from *package_names* that do not have to be downloaded.

    A package is skipped when
      * its name matches one of the --exclude-packages regular expressions, or
      * (unless --resolve-all is given) every version offered by the
        repository for it is already installed.

    The input list is not modified.
    '''

    def should_be_excluded(pkg):
        for pattern in command_line_arguments.exclude_packages:
            if re.search(pattern, pkg) is not None:
                return True
        return False

    vprint("Check package to skip...")
    pkgs = []
    to_skip = []
    # split off packages that have to be excluded due to command line arguments
    for pkg in package_names:
        if should_be_excluded(pkg):
            to_skip.append(pkg)
        else:
            pkgs.append(pkg)

    if command_line_arguments.resolve_all:
        return to_skip

    # Skip packages that are already installed and have the same version
    get_installed_packages()

    # only installed packages are candidates for skipping
    pkgs = [pkg for pkg in pkgs if pkg in installed_packages]
    if not pkgs:
        # Nothing is installed, so nothing more can be skipped. Returning here
        # also avoids calling urpmq without any package name.
        return to_skip

    vprint("Retrieving possible downloading package versions...")
    res = get_command_output(cmd + ['--sources'] + pkgs)
    urls = res[0].strip().split('\n')
    vprint("A list of urls retrieved: " + str(urls))

    # collect data: rpms[name] = [[version-release, suffix, path], ...]
    rpms = {}
    for url in urls:
        if not url.strip():
            continue  # blank output line, not a package
        fields = get_package_fields(url)
        rpms.setdefault(fields[0], []).append(fields[1:4])

    if not command_line_arguments.all_versions:
        vprint("Removing urls of the older versions...")
        for candidates in rpms.values():
            # repeatedly drop the older one of the first two entries
            while len(candidates) > 1:
                if rpm.evrCompare(candidates[0][0], candidates[1][0]) == 1:
                    del candidates[1]
                else:
                    del candidates[0]

    # regroup data: to_download[pkg_name] = [ver-rel1, ver-rel2, ...]
    to_download = {}
    for pkg in rpms:
        to_download[pkg] = [item[0] for item in rpms[pkg]]

    vprint("Checking what to skip...")

    for pkg in pkgs:
        available = to_download.get(pkg)
        if available is None:
            # Either the name was listed twice and has been handled already,
            # or no URL could be matched to it: there is nothing to compare.
            continue
        installed_versions = ['-'.join(i) for i in installed_packages[pkg]]
        vprint("Name: " + pkg + "; Versions intalled: " + str(installed_versions) + "; Versions to download: " + str(available))
        remaining = [ver for ver in available if ver not in installed_versions]
        if remaining:
            to_download[pkg] = remaining
        else:
            # every offered version is installed already
            to_download.pop(pkg)
            to_skip.append(pkg)
            vprint("Skipping " + pkg)
    return to_skip
        
        
def resolve_packages(package_names):
    '''Return the list of packages obtained by repeatedly resolving the
    requirements of *package_names* until no new package shows up.

    A package is not resolved further when it has been resolved before, or
    when check_what_to_skip() marks it and it either was not requested
    explicitly or build dependencies are being resolved (``resolve_source``).
    '''
    resolved_packages = []
    pending = package_names

    while True:
        toresolve = []
        candidates = parse_packages(pending, toresolve)
        to_skip = check_what_to_skip(candidates)
        for pkg in candidates:
            skipped = pkg in to_skip and (pkg not in package_names or resolve_source)
            if pkg in resolved_packages or skipped:
                # don't resolve its dependencies.
                continue
            resolved_packages.append(pkg)
            toresolve.append(pkg)

        if len(toresolve) == 0:
            break
        vprint ("resolving " + str(toresolve))
        # NOTE(review): this query uses plain 'urpmq' rather than the
        # module-level 'cmd', so --media / --excludemedia selections are not
        # applied here - confirm whether that is intended.
        query = ['urpmq', "--requires-recursive"] + toresolve
        pending = get_command_output(query)[0].strip().split("\n")

    return resolved_packages

def get_srpm_names(pkgs):
    '''Ask urpmq (--sourcerpm) for the source RPM names of the given packages.

    Every non-empty output line is split at its first colon into a package
    name and an SRPM name. Returns a dictionary
    {package_name_1: [srpm_name_1, srpm_name_2, ...], ...}.
    '''
    listing = get_command_output(cmd + ['--sourcerpm'] + pkgs)[0]

    srpms = {}
    for raw_line in listing.split("\n"):
        entry = raw_line.strip()
        if not entry:
            continue
        package, _, source = entry.partition(":")
        srpms.setdefault(package.strip(), []).append(source.strip())
    return srpms


def get_srpm_url(url):
    '''Derive the SRPM location from a media url.

    Local locations (starting with "file://" or "/") are returned unchanged.
    For any other url the fourth path component from the end is replaced by
    "SRPMS" and the third one from the end is dropped, e.g.
    ".../x86_64/media/contrib/release" becomes ".../SRPMS/contrib/release".

    Raises IndexError if the url has fewer than four "/"-separated parts.
    '''
    if url.startswith(("file://", "/")):
        return url
    pieces = url.split("/")
    pieces[-4] = "SRPMS"
    pieces.pop(-3)
    return "/".join(pieces)


def list_srpm_urls():
    '''Return the list of SRPM repository urls, loading it on first use.

    The list is built from the output of "urpmq --list-url" (with the media
    options stored in ``cmd``): the last space-separated field of every line
    is taken, one trailing "/" is removed, and fields containing a "/" are
    converted with get_srpm_url(). The result is cached in the module-level
    ``srpm_urls`` so urpmq is queried only once.
    '''
    # Both names must be declared global: a plain local 'srpm_urls' would make
    # the cached branch raise UnboundLocalError and force a reload each call.
    global srpm_urls, srpm_urls_loaded
    if globals().get('srpm_urls_loaded'):
        return srpm_urls

    vprint("Loading list of SRPM URLs...")
    lines = get_command_output(cmd + ["--list-url"])[0].strip().split("\n")
    loaded = []
    for line in lines:
        location = line.split(" ")[-1]
        if location.endswith("/"):
            location = location[:-1]
        if "/" in location:
            loaded.append(get_srpm_url(location))

    srpm_urls = loaded
    srpm_urls_loaded = True
    return srpm_urls

def try_download(url):
    '''Fetch *url* into the destination directory.

    The target file name is the basename of *url*. An existing target is left
    alone unless --overwrite is given. Urls starting with "/" are copied as
    local files, everything else is opened with urlopen.

    Returns None on success (or when the file was skipped), otherwise the
    exception (IOError or OSError) that made the download fail.
    '''
    path = os.path.join(command_line_arguments.dest_dir, os.path.basename(url))
    vprint("Trying to download file " + url)

    if os.path.exists(path) and not command_line_arguments.overwrite:
        qprint ("File exists, skipping: " + url)
        return None

    # Data is written to a temporary name and renamed only when complete, so
    # an interrupted transfer neither destroys an existing file nor leaves a
    # truncated one that a later run would report as "File exists".
    partial = path + ".part"
    try:
        if url.startswith('/'):  # local file
            shutil.copyfile(url, partial)
        else:
            remote = urlopen(url)
            try:
                # binary mode, streamed in chunks instead of one big read()
                with open(partial, 'wb') as target:
                    shutil.copyfileobj(remote, target)
            finally:
                remote.close()
        os.rename(partial, path)
    except (IOError, OSError) as e:
        if os.path.exists(partial):
            try:
                os.remove(partial)
            except OSError:
                pass  # best effort clean-up; the original error is reported
        return e

    qprint ("Downloaded: " +  url)
    return None

def get_package_fields(rpmname):
    ''' Return [name, version, suffix, path(prefix)] for given rpm file or package name.

    path    -- directory part of *rpmname* with a trailing "/", or "" if none
    suffix  -- everything that follows the version: architecture and ".rpm"
               for files, ".src[.rpm]" for sources, and additionally a
               leading "-<last section>" when the last dash-separated section
               does not look like a Mandriva-style release (digits followed
               by mdv/mdk/mnb)
    version -- "version-release" taken from the dash-separated sections
    name    -- the remaining leading sections joined with "-"
    '''
    directory = os.path.dirname(rpmname)
    prefix = directory + "/" if directory else directory

    base = os.path.basename(rpmname)
    is_file = base.endswith(".rpm")
    suffix = ""
    if is_file:
        suffix = ".rpm"
        base = base[:-4]

    if base.endswith(".src"):
        # source package: the last two sections are version and release
        fields = base[:-4].split("-")
        return ["-".join(fields[:-2]), "-".join(fields[-2:]), ".src" + suffix, prefix]

    release_re = re.compile(r"(\.)?((alpha)|(cvs)|(svn)|(r))?\d+((mdv)|(mdk)|(mnb))")
    if is_file:
        # remove the architecture part (text after the last dot)
        base, _, arch = base.rpartition('.')
        suffix = "." + arch + suffix

    sections = base.split("-")
    if release_re.search(sections[-1]) is None:
        # the last section is not a release: treat it as part of the suffix
        name_parts = sections[:-3]
        version_parts = sections[-3:-1]
        suffix = "-" + sections[-1] + suffix
    else:
        name_parts = sections[:-2]
        version_parts = sections[-2:]
    return ["-".join(name_parts), "-".join(version_parts), suffix, prefix]


#url = 'ftp://ftp.sunet.se/pub/Linux/distributions/mandrakelinux/official/2011/x86_64/media/contrib/release/lib64oil0.3_0-0.3.17-2mdv2011.0.x86_64.rpm'
#url = 'ftp://ftp.sunet.se/pub/Linux/distributions/mandrakelinux/official/2011/x86_64/media/contrib/release/liboil-tools-0.3.17-2mdv2011.0.x86_64.rpm'
#res = get_package_fields(url)
#print res
#exit()
    

def filter_versions(rpm_list):
    ''' Keep only the newest version of every package in *rpm_list*.

    With --all-versions the input list is returned untouched. Otherwise the
    entries are grouped by package name (see get_package_fields), versions are
    compared with rpm.evrCompare, and one rebuilt entry
    (path + name + "-" + version + suffix) is returned per package.
    '''
    if command_line_arguments.all_versions:
        return rpm_list

    vprint("Filtering input: " + str(rpm_list))
    # group: by_name[name] = [[version, suffix, path], ...]
    by_name = {}
    for entry in rpm_list:
        fields = get_package_fields(entry)
        by_name.setdefault(fields[0], []).append(fields[1:4])

    output = []
    for pkg in by_name:
        candidates = by_name[pkg]
        newest = candidates[0]
        for challenger in candidates[1:]:
            # the current entry survives only if it is strictly newer
            if rpm.evrCompare(newest[0], challenger[0]) != 1:
                newest = challenger
        version, suffix, path = newest
        output.append(path + pkg + "-" + version + suffix)
    vprint ("Filtering output: " + str(output))
    return output

def download_srpm(package, srpms):
    '''Download the source RPM(s) of *package*.

    package -- binary package name
    srpms   -- dictionary {package_name: [srpm_file_name, ...]}

    Every SRPM (newest version only, unless --all-versions) is searched in
    each SRPM repository url in turn until one succeeds. With --urls the
    first existing url is printed; it is additionally downloaded only when
    --resolve is given. Returns the list of downloaded file paths.

    Exits with status 2 on failure unless --ignore-errors is given.
    '''
    vprint("downloading srpm(s) for package " + package)

    if package not in srpms:
        # No SRPM name is known for this package; report it like any other
        # download failure instead of dying with a KeyError.
        eprint("Can not find SRPM name for package " + package)
        if not command_line_arguments.ignore_errors:
            sys.exit(2)
        return []

    srpm_urls = list_srpm_urls()
    downloaded = []
    for srpm in filter_versions(srpms[package]):
        found = False
        for srpm_url in srpm_urls:
            url = srpm_url + "/" + srpm
            if command_line_arguments.urls:  # a correct url have to be printed!
                if not url_exists(url):
                    continue
                qprint(url)
                if not command_line_arguments.resolve:
                    found = True
                    break
                # with --resolve the file is still needed to read its
                # build requirements, so fall through to the download

            if try_download(url) is None:
                found = True
                downloaded.append(os.path.join(command_line_arguments.dest_dir, os.path.basename(url)))
                break

        if not found:
            eprint("Can not download SRPM " + srpm +" for package " + package)
            if not command_line_arguments.ignore_errors:
                sys.exit(2)

    return downloaded


def _non_empty_lines(text):
    '''Split command output into stripped lines, dropping blank ones.'''
    return [line.strip() for line in text.split("\n") if line.strip()]


def download_rpm(pkgs_to_download):
    '''Download (or, with --urls, list) the RPMs of the given packages.

    Binary RPMs are processed when --binary is set or build dependencies are
    being resolved (``resolve_source``). With --debug-info the matching
    "<name>-debug" packages are looked up in the 'debug' media; packages
    urpmq reports as missing are silently dropped from that query. With
    --debug-info-install the downloaded debug packages are installed with
    "rpm -i".

    Exits with status 3 (binary) or 2 (debug) on a failed download unless
    --ignore-errors is given.
    '''
    global resolve_source, downloaded_debug_pkgs
    vprint("downloading packages " + ", ".join (pkgs_to_download))
    cmd_bin = cmd[:] + ['--sources'] + pkgs_to_download
    # blank lines are dropped so that empty output is not parsed as a package
    urls = filter_versions(_non_empty_lines(get_command_output(cmd_bin)[0]))

    if command_line_arguments.binary or resolve_source:
        for url in urls:
            if command_line_arguments.urls:
                qprint(url)
                continue

            res = try_download(url)
            if res is not None:
                eprint("Can not download RPM %s\n(%s)" % (url, res) )
                if not command_line_arguments.ignore_errors:
                    sys.exit(3)

    if not command_line_arguments.debug_info:
        return

    pkgs_to_download_debug = [p + "-debug" for p in pkgs_to_download]
    qprint("Resolving debug-info packages...")
    debug_query = ['urpmq', '--media', 'debug', '--sources']
    res = get_command_output(debug_query + pkgs_to_download_debug, fatal_fails=False)

    text = "No package named "
    vprint("Removing missed debug packages from query...")
    removed = []
    if res[2] != 0:   # return code is not 0
        for line in res[1].split("\n"):
            if line.startswith(text):
                pkg = line[len(text):].strip()
                if pkg in pkgs_to_download_debug:
                    pkgs_to_download_debug.remove(pkg)
                    removed.append(pkg)

    vprint("Removed " + str(len(removed)) + " packages")
    vprint(removed)

    if res[2] == 0:
        # the first query already succeeded; no need to run it again
        listing = res[0]
    elif pkgs_to_download_debug:
        listing = get_command_output(debug_query + pkgs_to_download_debug)[0]
    else:
        # every debug package is missing: do not call urpmq without names
        vprint("No debug packages left to download")
        listing = ""

    for url in filter_versions(_non_empty_lines(listing)):
        if command_line_arguments.urls:
            qprint(url)
            continue
        res = try_download(url)
        if res is not None:
            eprint("Can not download RPM " + os.path.basename(url) + "\n(%s)\nMaybe you need to update urpmi database (urpmi.update -a)?" % (res,))
            if not command_line_arguments.ignore_errors:
                sys.exit(2)
        else:
            path = os.path.join(command_line_arguments.dest_dir, os.path.basename(url))
            downloaded_debug_pkgs.append(path)

    if command_line_arguments.debug_info_install:
        for pkg in downloaded_debug_pkgs:
            qprint('Installing ' + os.path.basename(str(pkg)) + "...")
            command = ['rpm', '-i', pkg]
            res = get_command_output(command, fatal_fails=False)
            if res[2] != 0:  # rpm return code is not 0
                qprint('Error while calling command "' + ' '.join(command) + '":\n' + res[1].strip())


def filter_debug_rpm_urls(input_urls):
    '''Query the 'debug' media for "<pkg_name>-debug" and return the urls
    whose version-release matches an installed version of pkg_name (or all
    urls when --all-versions is given). Returns [] if urpmq fails.

    NOTE(review): 'pkg_name' is neither a parameter nor assigned anywhere in
    this function, and no module-level definition is visible in this file, so
    calling this function presumably raises NameError. The 'input_urls'
    parameter is never used. No caller is visible in this file - confirm
    whether this function is dead code before relying on it.
    '''
    command = ['urpmq', '--media', 'debug', '--sources', pkg_name + "-debug"]
    res = get_command_output(command, fatal_fails=False)
    if(res[2] != 0):  # return code is not 0
        qprint("Debug package for '" + pkg_name + "' not found")
        return []
    names = res[0].strip().split("\n")
    if(command_line_arguments.all_versions):
        return names
    
    get_installed_packages()
    #print names
    #print installed_packages[pkg_name]
    urls = []
    for n in names:
        res = get_package_fields(os.path.basename(n))
        # res[1] is "version-release"; keep at most its first two sections
        version = "-".join(res[1].split("-")[0:2] )
        if(pkg_name not in installed_packages):
            # nothing installed to compare against: stop with what we have
            break
        for inst_pkg in installed_packages[pkg_name]:
            # inst_pkg is a [version, release] pair
            if(version == inst_pkg[0] + "-" + inst_pkg[1]):
                urls.append(n)
                break
    return urls


def Main():
    '''Program entry point; expects parse_command_line() to have run.

    Builds the base urpmq command line (global ``cmd``) from the media
    options, replaces local .rpm files given on the command line by the
    package names stored in them, and then downloads source and/or binary
    (and debug) packages as requested.

    Exits with status 4 if a given .rpm file does not exist, unless
    --ignore-errors is given; in that case the missing files are dropped.
    '''
    global cmd, resolve_source
    resolve_source = False  # variable that makes download_rpm to download resolved build-deps
    cmd = ['urpmq']

    # --media / --exclude-media may be repeated, each with several values;
    # urpmq wants one comma-separated list.
    if command_line_arguments.include_media is not None:
        media = ",".join(m for group in command_line_arguments.include_media for m in group)
        cmd = cmd + ['--media', media]

    if command_line_arguments.exclude_media is not None:
        media = ",".join(m for group in command_line_arguments.exclude_media for m in group)
        cmd = cmd + ['--excludemedia', media]

    packages = command_line_arguments.packages
    missing_files = []
    for pkg in packages[:]:
        if not pkg.endswith(".rpm"):
            continue
        if not os.path.isfile(pkg):
            missing_files.append(pkg)
            continue
        # replace the file by the name of the package it contains
        name = get_rpm_tag_from_file("name", pkg)
        packages.remove(pkg)
        packages.append(name)

    if missing_files:
        eprint("Files that end with '.rpm' seem to be local files, but the folowing files do not exist: " + ", ".join(missing_files))
        if not command_line_arguments.ignore_errors:
            sys.exit(4)
        # Continuing: drop the missing files, otherwise they would be handed
        # to urpmq as package names and abort the run anyway.
        for pkg in missing_files:
            packages.remove(pkg)
        if not packages:
            qprint("Nothing to download")
            return

    if command_line_arguments.source:
        download(packages, True)

    if command_line_arguments.binary or (not command_line_arguments.source and command_line_arguments.debug_info):
        download(packages, False)
        
    
def get_rpm_tag_from_file(tag, file):
    '''Read the header of the RPM file *file* and return the value of *tag*
    (e.g. "name") as a stripped string.

    Errors from os.open or from reading the header propagate to the caller.
    '''
    rpm_ts = rpm.TransactionSet()
    fd = os.open(file, os.O_RDONLY)
    try:
        rpm_hdr = rpm_ts.hdrFromFdno(fd)
    finally:
        # close the descriptor even if the file is not a readable RPM
        os.close(fd)
    return rpm_hdr.sprintf("%{" + tag + "}").strip()

        
def download(packages, src):
    '''Download the given packages.

    packages -- list of package names
    src      -- True to fetch source RPMs, False to fetch binary RPMs

    For binary packages the list is first expanded with its dependencies when
    --resolve is given. For source packages, --resolve additionally downloads
    the build requirements of the fetched SRPMs as binary packages, with the
    global ``resolve_source`` set for the duration of that nested call.
    '''
    global resolve_source

    if not src:
        pkgs_to_download = packages
        if command_line_arguments.resolve:
            if resolve_source:
                qprint("Resolving build dependencies...")
            else:
                qprint("Resolving dependencies...")
            pkgs_to_download = resolve_packages(packages)
            qprint ("Resolved " + str(len(pkgs_to_download)) + " packages")
        if len(pkgs_to_download) == 0:
            qprint("Nothing to download")
            return
        download_rpm(pkgs_to_download)
        return

    if command_line_arguments.urls:
        qprint("Searching src.rpm file(s) in repository...")
    else:
        qprint("Downloading src.rpm file(s)...")

    # NOTE: packages that are absent from the urpmq --sourcerpm output are
    # not filtered out here.
    srpms = get_srpm_names(packages)

    srpms_list = []
    for package in packages:
        srpms_list.extend(download_srpm(package, srpms))

    if len(srpms_list) == 0:
        return

    if command_line_arguments.resolve:
        resolve_source = True
        requirements = get_command_output(cmd + ['--requires-recursive'] + srpms_list)[0].strip().split("\n")
        download(parse_packages(requirements, []), False)
        resolve_source = False


# Paths of the debug RPMs downloaded so far (appended to by download_rpm)
downloaded_debug_pkgs = []
# Set to True by get_installed_packages() once the RPM database has been read
installed_loaded=False
# Text printed by the --version option
VERSION = "urpm-downloader 2.2.4"
# Run as a script: parse the command line first, because Main() and the
# printing helpers read the global command_line_arguments it creates.
if __name__ == '__main__':
    parse_command_line()
    Main()
