File download
Downloading machine learning datasets
import os
import hashlib
import requests

DATA_HUB = {}

def download(url, folder='../data', sha1_hash=None):
    """Download a file to folder and return the local filepath.

    Defined in :numref:`sec_utils`"""
    if not url.startswith('http'):
        # For backward compatibility: look up a registered dataset name in DATA_HUB
        url, sha1_hash = DATA_HUB[url]
    os.makedirs(folder, exist_ok=True)
    fname = os.path.join(folder, url.split('/')[-1])
    # Check whether a cached copy already matches the expected SHA-1 hash
    if os.path.exists(fname) and sha1_hash:
        sha1 = hashlib.sha1()
        with open(fname, 'rb') as f:
            while True:
                data = f.read(1048576)
                if not data:
                    break
                sha1.update(data)
        if sha1.hexdigest() == sha1_hash:
            return fname
    # Download the file
    print(f'Downloading {fname} from {url}...')
    r = requests.get(url, stream=True, verify=True)
    with open(fname, 'wb') as f:
        f.write(r.content)
    return fname
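A minimal usage sketch follows. The dataset name, URL, and SHA-1 hash are hypothetical placeholders; replace them with real values before running.
# The dataset name, URL, and SHA-1 hash below are hypothetical placeholders.
DATA_HUB['example_dataset'] = (
    'https://example.com/data/example_dataset.csv',   # hypothetical URL
    '0123456789abcdef0123456789abcdef01234567')       # hypothetical SHA-1

# Download by URL directly...
fname = download('https://example.com/data/example_dataset.csv')
# ...or by the name registered in DATA_HUB (the cached copy is reused if the hash matches)
fname = download('example_dataset')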
File extraction
Extracting zip/tar/gz files
import os
import zipfile
import tarfile

def extract(filename, folder=None):
    """Extract a zip/tar/gz file into folder.

    Defined in :numref:`sec_utils`"""
    base_dir = os.path.dirname(filename)
    _, ext = os.path.splitext(filename)
    assert ext in ('.zip', '.tar', '.gz'), 'Only zip/tar/gz files are supported.'
    if ext == '.zip':
        fp = zipfile.ZipFile(filename, 'r')
    else:
        fp = tarfile.open(filename, 'r')
    if folder is None:
        # Default to extracting next to the archive
        folder = base_dir
    fp.extractall(folder)
    fp.close()
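A minimal sketch chaining download and extract; the archive URL and target folder are hypothetical placeholders.
archive = download('https://example.com/data/example_archive.zip')   # hypothetical URL
extract(archive)                                      # extract next to the downloaded file
extract(archive, folder='../data/example_archive')   # or into an explicit folder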
File reading
Common data readers
import pickle
import pandas as pd

def read_data(file_path, **kwargs):
    """
    Purpose:
        1. Read data from a csv/excel/txt/json/pkl/feather file.
        2. The reader is chosen by the file extension, e.g. csv -> pd.read_csv().
        3. For txt files, sep defaults to '\t'; other options follow the pd.read_* readers.
    Args:
        file_path: str, path to the file, e.g. "./data.csv"
        **kwargs: dict, keyword arguments forwarded to the corresponding pandas reader,
            e.g. {'header': None, 'sep': '\t'}. Common options include encoding, sep, etc.
    Returns:
        dt: DataFrame, the loaded data
    """
    # Validate input
    assert isinstance(file_path, str), 'file_path must be str type'
    # Get the file extension
    tail_name = file_path.split('.')[-1]
    # Choose the reader according to the extension
    if tail_name == 'csv':
        func = pd.read_csv
        # Drop a None sep so pandas falls back to its default delimiter
        if 'sep' in kwargs and kwargs['sep'] is None:
            kwargs.pop('sep')
    elif tail_name == 'txt':
        func = pd.read_csv
        # Default txt files to tab-separated when sep is missing or None
        if kwargs.get('sep') is None:
            kwargs['sep'] = '\t'
    elif tail_name in ['xlsx', 'xls']:
        func = pd.read_excel
        # pd.read_excel does not accept encoding/sep
        kwargs.pop('encoding', None)
        kwargs.pop('sep', None)
    elif tail_name == 'json':
        func = pd.read_json
        kwargs.pop('sep', None)
    elif tail_name in ['feather', 'ftr']:
        func = pd.read_feather
        kwargs.pop('encoding', None)
        kwargs.pop('sep', None)
    elif tail_name == 'pkl':
        with open(file_path, 'rb') as f:
            dt = pickle.load(f)
        return dt
    else:
        raise ValueError(
            'file_path must end with csv, txt, json, xlsx, xls, feather, ftr or pkl')
    dt = func(file_path, **kwargs)
    return dt
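A minimal usage sketch; the file paths below are hypothetical placeholders.
df_csv = read_data('./data.csv', encoding='utf-8')   # csv via pd.read_csv
df_txt = read_data('./data.txt')                     # txt defaults to sep='\t'
df_xlsx = read_data('./data.xlsx', header=0)         # excel via pd.read_excel
df_pkl = read_data('./data.pkl')                     # pickled object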
File saving
Common file savers
import os
import pickle
import pandas as pd

def save_data(dt, save_path, index=False, encoding='utf-8', sep='\t'):
    """
    Purpose:
        1. Save a DataFrame to a file whose format is chosen by the extension.
        2. Supported formats: [csv, txt, xlsx, xls, json, pkl].
    Args:
        dt: DataFrame, the dataset to save
        save_path: str, path of the output file
        index: bool, only used for [csv, txt, xlsx, xls]; defaults to False
        encoding: str, only used for csv; defaults to 'utf-8'
        sep: str, only used for txt; defaults to '\t'
    """
    # Validate inputs
    assert isinstance(dt, pd.DataFrame), 'dt must be DataFrame type'
    assert isinstance(save_path, str), 'save_path must be str type'
    # Save the data according to the file extension
    if save_path:
        # Create the target directory if it does not exist
        file_dir, _ = os.path.split(save_path)
        if file_dir != '' and not os.path.exists(file_dir):
            os.makedirs(file_dir)
        # Get the file extension
        tail_name = save_path.split('.')[-1]
        if tail_name == 'csv':
            dt.to_csv(save_path, index=index, encoding=encoding)
        elif tail_name == 'txt':
            dt.to_csv(save_path, index=index, sep=sep)
        elif tail_name in ['xlsx', 'xls']:
            dt.to_excel(save_path, index=index)
        elif tail_name == 'json':
            dt.to_json(save_path)
        elif tail_name == 'pkl':
            with open(save_path, 'wb') as f:
                pickle.dump(dt, f)
        else:
            raise ValueError(
                "save_path extension must be in ['csv', 'txt', 'json', 'xlsx', 'xls', 'pkl']")
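A minimal usage sketch; the sample DataFrame and output paths are hypothetical placeholders.
df = pd.DataFrame({'a': [1, 2], 'b': [3, 4]})   # hypothetical sample data
save_data(df, './output/data.csv')              # csv with utf-8 encoding
save_data(df, './output/data.txt', sep='\t')    # tab-separated text
save_data(df, './output/data.pkl')              # pickle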