python文件下载、解压及数据读取


文件下载

机器学习数据集下载

DATA_HUB = {}

def download(url, folder='../data', sha1_hash=None):
    """Download a file to folder and return the local filepath.

    Defined in :numref:`sec_utils`"""
    if not url.startswith('http'):
        # For back compatability
        url, sha1_hash = DATA_HUB[url]
    os.makedirs(folder, exist_ok=True)
    fname = os.path.join(folder, url.split('/')[-1])
    # Check if hit cache
    if os.path.exists(fname) and sha1_hash:
        sha1 = hashlib.sha1()
        with open(fname, 'rb') as f:
            while True:
                data = f.read(1048576)
                if not data:
                    break
                sha1.update(data)
        if sha1.hexdigest() == sha1_hash:
            return fname
    # Download
    print(f'Downloading {fname} from {url}...')
    r = requests.get(url, stream=True, verify=True)
    with open(fname, 'wb') as f:
        f.write(r.content)
    return fname

文件解压

zip/tar/gz的文件解压

def extract(filename, folder=None):
    """Extract a zip/tar file into folder.

    Defined in :numref:`sec_utils`"""
    base_dir = os.path.dirname(filename)
    _, ext = os.path.splitext(filename)
    assert ext in ('.zip', '.tar', '.gz'), 'Only support zip/tar files.'
    if ext == '.zip':
        fp = zipfile.ZipFile(filename, 'r')
    else:
        fp = tarfile.open(filename, 'r')
    if folder is None:
        folder = base_dir
    fp.extractall(folder)

文件读取

常用的数据读取

def read_data(file_path, **kwargs):
    """
    功能: 
    1. 数据读取函数,读取csv/excel/txt/json/pkl/feather文件
    2. 函数根据文件后缀名,使用相应的数据读取函数,如,csv: pd.read_csv()
    3. txt文件sep默认为\t,其它设置与pd.read_读取方式一致
    
    输入:
    file_path: str, 文件路径,如,"./data.csv"
    **kwargs: dict, 传递给pandas中,对应文件读取函数的关键字参数,如{header:None, sep:'\t'}
              基本参数包括: encoding、sep等。
    
    输出:
    dt: DataFrame, 读取完成的数据
    """
 
    # 输入合法性检查
    assert isinstance(file_path, str), 'file_path must be str type'
    # 获取文件后缀名
    tail_name = file_path.split('.')[-1]
    # 按照文件后缀名指定文件读取函数
    if tail_name == 'csv':
        func = pd.read_csv
        if 'sep' in kwargs and kwargs['sep'] is None:
            kwargs.pop('sep')
    elif tail_name == 'txt':
        func = pd.read_csv
        if 'sep' in kwargs and kwargs['sep'] is None:
            kwargs['sep'] = '\t'
    elif tail_name in ['xlsx','xls']:
        func = pd.read_excel
        if 'encoding' in kwargs:
            kwargs.pop('encoding')
        if 'sep' in kwargs:
            kwargs.pop('sep')
    elif tail_name == 'json':
        func = pd.read_json
        if 'sep' in kwargs:
            kwargs.pop('sep')
    elif tail_name == 'feather' or tail_name == 'ftr':
        func = pd.read_feather
        if 'encoding' in kwargs:
            kwargs.pop('encoding')
        if 'sep' in kwargs:
            kwargs.pop('sep')
    elif tail_name == 'pkl':
        with open(file_path, 'rb') as f:
            dt = pickle.load(f)
        return dt
    else:
        raise ValueError('The file_path must end with csv, txt, json, xlxs, xls, pkl')

    dt = func(file_path, **kwargs)

    return dt

文件保存

常用的文件保存

def save_data(dt, save_path, index=False, encoding='utf-8', sep='\t'):
    """
    功能: 
    1. DataFrame格式数据保存,根据文件后缀名保存为相应格式的文件
    2. 文件格式可保存为,[csv, txt, xlsx, xls, json, pkl]
    
    输入:
    dt: DataFrame, 数据集
    save_path: str, 文件保存路径
    index: bool, 该参数在保存文件格式为[csv, txt, xlsx, xls]时生效,默认为False
    encoding: str, 该参数在保存文件格式为csv时生效,默认为'utf-8'
    sep: str, 该参数在保存文件格式为txt时生效,默认为'\t'
    """
    # 输入合法性检查
    assert isinstance(dt, pd.DataFrame), 'df_data must be DataFrame type'
    assert isinstance(save_path, str), 'save_path must be str type'

    # 按照文件保存格式,保存数据,默认为csv
    if save_path:
        # 获取文件路径
        file_dir, _ = os.path.split(save_path)
        if file_dir != '' and not os.path.exists(file_dir):
            os.makedirs(file_dir)

        # 获取文件后缀名
        tail_name = save_path.split('.')[-1]

        if tail_name == 'csv':
            dt.to_csv(save_path, index=index, encoding=encoding)
        elif tail_name == 'txt':
            dt.to_csv(save_path, index=index, sep=sep)
        elif tail_name in ['xlsx', 'xls']:
            dt.to_excel(save_path, index=index)
        elif tail_name == 'json':
            dt.to_json(save_path)
        elif tail_name == 'pkl':
            with open(save_path, 'wb') as f:
                pickle.dump(dt, f)
        else:
            raise ValueError("save_path extension must be in ['csv', 'txt', 'json', 'xlsx', 'xls', 'pkl', 'feather', 'ftr']")

文章作者: lilso
版权声明: 本博客所有文章除特別声明外,均采用 CC BY 4.0 许可协议。转载请注明来源 lilso !
  目录