summaryrefslogtreecommitdiff
path: root/yaksh/file_utils.py
blob: 4b8640e8e711fabd53a9c6431c2b791d09175349 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import shutil
import os
import zipfile
import tempfile
import csv

def copy_files(file_paths):
    """Copy files into the current working directory.

    :param file_paths: iterable of ``(file_path, extract)`` pairs. When
        ``extract`` is truthy, the copied file is also unpacked in place
        with ``extract_files``.
    :returns: list of names now present in the working directory. For each
        input this is the copied file's basename, followed by the archive
        member names when the file was extracted.
    """
    cwd = os.getcwd()
    files = []
    for file_path, extract in file_paths:
        file_name = os.path.basename(file_path)
        files.append(file_name)
        shutil.copy(file_path, cwd)
        if extract:
            # extract_files returns None when the file is not a zip archive;
            # unpacking that directly would raise TypeError, so skip it.
            extracted = extract_files(file_name, cwd)
            if extracted is not None:
                member_names, _ = extracted
                files.extend(member_names)
    return files


def delete_files(files, file_path=None):
    """Delete files and directories.

    :param files: iterable of names to delete.
    :param file_path: optional directory the names are relative to; when
        omitted, the names are used as given.

    Names that do not exist are skipped. Real directories are removed
    recursively. Anything else is unlinked: regular files, special files,
    and symlinks, including symlinks to directories and dangling symlinks.
    """
    for file_name in files:
        target = os.path.join(file_path, file_name) if file_path else file_name
        # lexists (not exists) so that dangling symlinks are cleaned up too.
        if not os.path.lexists(target):
            continue
        # shutil.rmtree refuses symlinks, so only recurse into real
        # directories; a symlink to a directory is unlinked and its target
        # is left alone.
        if os.path.isdir(target) and not os.path.islink(target):
            shutil.rmtree(target)
        else:
            os.remove(target)


def extract_files(zip_file, path=None):
    """Extract every member of a zip archive.

    :param zip_file: path (or file-like object) of the archive.
    :param path: directory to extract into; when falsy, the system temp
        directory (``tempfile.gettempdir()``) is used.
    :returns: ``(member_names, extract_path)``, where ``member_names`` is the
        archive's namelist. Returns ``None`` if ``zip_file`` is not a zip
        archive.
    """
    if not zipfile.is_zipfile(zip_file):
        return None
    extract_path = path or tempfile.gettempdir()
    # The context manager closes the archive even if extraction fails.
    with zipfile.ZipFile(zip_file, 'r') as archive:
        member_names = archive.namelist()
        archive.extractall(extract_path)
    return member_names, extract_path


def is_csv(document):
    """Return True if ``document`` looks like a UTF-8 CSV file.

    The first 1024 bytes/characters are read and passed to ``csv.Sniffer``.
    Both binary streams (``read`` returns bytes, decoded as UTF-8) and text
    streams (``read`` returns str) are accepted. The stream is always
    rewound to position 0 afterwards, so the caller can read it from the
    start.

    :param document: seekable file-like object.
    :returns: False if the sample is not valid UTF-8 or the sniffer cannot
        detect a CSV dialect, otherwise True.
    """
    import codecs

    try:
        sample = document.read(1024)
        if not isinstance(sample, str):
            # An incremental decoder (final=False) buffers a multibyte
            # character cut off at the 1024-byte boundary instead of raising,
            # so valid UTF-8 is not rejected. Invalid bytes still raise.
            sample = codecs.getincrementaldecoder('utf-8')().decode(sample)
        csv.Sniffer().sniff(sample)
    except (csv.Error, UnicodeDecodeError):
        return False
    finally:
        document.seek(0)
    return True


def headers_present(dict_reader, headers):
    """Check that every expected header appears in a CSV header row.

    Column names are compared after stripping surrounding whitespace. As a
    side effect, each expected header that is found has its value in
    ``headers`` set to the column name exactly as it appears in the file,
    so callers can index rows with it.

    :param dict_reader: a ``csv.DictReader`` (anything with ``fieldnames``).
    :param headers: dict keyed by the expected (stripped) header names;
        mutated as described above.
    :returns: True if every key of ``headers`` was found, otherwise False.
    """
    # DictReader.fieldnames is None for an empty input; treat that as
    # "no columns" instead of failing on iteration.
    fields = dict_reader.fieldnames or []
    found = set()
    for field in fields:
        name = field.strip()
        if name in headers:
            headers[name] = field
            found.add(name)
    return found == set(headers)