import json
import os
import tarfile
import tempfile
import warnings
from itertools import zip_longest
from pathlib import Path

import pandas as pd

import woodwork as ww
from woodwork.exceptions import OutdatedSchemaWarning, UpgradeSchemaWarning
from woodwork.s3_utils import get_transport_params, use_smartopen
from woodwork.serialize import FORMATS, SCHEMA_VERSION
from woodwork.utils import _is_s3, _is_url, import_or_raise


def _typing_information_to_woodwork_table(table_typing_info, validate, **kwargs):
"""Deserialize Woodwork table from table description.
Args:
table_typing_info (dict) : Woodwork typing information. Likely generated using :meth:`.serialize.typing_info_to_dict`
validate (bool): Whether parameter and data validation should occur during table initialization
kwargs (keywords): Additional keyword arguments to pass as keywords arguments to the underlying deserialization method.
Returns:
DataFrame: DataFrame with Woodwork typing information initialized.
"""
_check_schema_version(table_typing_info['schema_version'])
path = table_typing_info['path']
loading_info = table_typing_info['loading_info']
file = os.path.join(path, loading_info['location'])
load_format = loading_info['type']
assert load_format in FORMATS
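    # FORMATS (from woodwork.serialize) should correspond to the readers
    # dispatched below: csv, pickle, parquet, arrow/feather, and orc.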
    # Merge the parameters stored at serialization time with any caller-supplied
    # kwargs so explicit overrides are not silently discarded.
    kwargs = {**loading_info.get('params', {}), **kwargs}
table_type = loading_info.get('table_type', 'pandas')
logical_types = {}
semantic_tags = {}
column_descriptions = {}
column_origins = {}
column_metadata = {}
use_standard_tags = {}
column_dtypes = {}
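    # Gather per-column typing details; these dictionaries are handed to
    # ``dataframe.ww.init`` once the data itself has been read.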
for col in table_typing_info['column_typing_info']:
col_name = col['name']
ltype_metadata = col['logical_type']
ltype = ww.type_system.str_to_logical_type(ltype_metadata['type'], params=ltype_metadata['parameters'])
tags = col['semantic_tags']
if 'index' in tags:
tags.remove('index')
elif 'time_index' in tags:
tags.remove('time_index')
logical_types[col_name] = ltype
semantic_tags[col_name] = tags
column_descriptions[col_name] = col['description']
column_origins[col_name] = col['origin']
column_metadata[col_name] = col['metadata']
use_standard_tags[col_name] = col['use_standard_tags']
col_type = col['physical_type']['type']
if col_type == 'category':
# Make sure categories are recreated properly
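            # Illustrative example (hypothetical values): cat_values=['a', 'b']
            # with cat_dtype='object' rebuilds
            # CategoricalDtype(categories=['a', 'b'], ordered=False).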
cat_values = col['physical_type']['cat_values']
cat_dtype = col['physical_type']['cat_dtype']
if table_type == 'pandas':
cat_object = pd.CategoricalDtype(pd.Index(cat_values, dtype=cat_dtype))
else:
cat_object = pd.CategoricalDtype(pd.Series(cat_values))
col_type = cat_object
elif table_type == 'koalas' and col_type == 'object':
col_type = 'string'
column_dtypes[col_name] = col_type
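    # Select the dataframe library that matches the serialized table type,
    # raising a helpful error if the optional dependency is not installed.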
if table_type == 'dask':
DASK_ERR_MSG = (
'Cannot load Dask DataFrame - unable to import Dask.\n\n'
'Please install with pip or conda:\n\n'
'python -m pip install "woodwork[dask]"\n\n'
'conda install dask'
)
lib = import_or_raise('dask.dataframe', DASK_ERR_MSG)
elif table_type == 'koalas':
KOALAS_ERR_MSG = (
'Cannot load Koalas DataFrame - unable to import Koalas.\n\n'
'Please install with pip or conda:\n\n'
'python -m pip install "woodwork[koalas]"\n\n'
'conda install koalas\n\n'
'conda install pyspark'
)
lib = import_or_raise('databricks.koalas', KOALAS_ERR_MSG)
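        # Koalas' CSV reader appears to require ``compression`` as a plain
        # string, so coerce any non-string value before reading.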
        if 'compression' in kwargs:
            kwargs['compression'] = str(kwargs['compression'])
else:
lib = pd
        if 'index' in kwargs:
            del kwargs['index']
if load_format == 'csv':
dataframe = lib.read_csv(
file,
dtype=column_dtypes,
**kwargs
)
elif load_format == 'pickle':
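        # Pickle always goes through pandas here; presumably Woodwork only
        # serializes pandas DataFrames to pickle.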
dataframe = pd.read_pickle(file, **kwargs)
elif load_format == 'parquet':
dataframe = lib.read_parquet(file, engine=kwargs['engine'])
elif load_format in ['arrow', 'feather']:
dataframe = lib.read_feather(file)
elif load_format == 'orc':
dataframe = lib.read_orc(file)
dataframe.ww.init(
name=table_typing_info.get('name'),
index=table_typing_info.get('index'),
time_index=table_typing_info.get('time_index'),
logical_types=logical_types,
semantic_tags=semantic_tags,
use_standard_tags=use_standard_tags,
table_metadata=table_typing_info.get('table_metadata'),
column_metadata=column_metadata,
column_descriptions=column_descriptions,
column_origins=column_origins,
validate=validate)
return dataframe


def read_woodwork_table(path, profile_name=None, validate=False, **kwargs):
"""Read Woodwork table from disk, S3 path, or URL.
Args:
path (str): Directory on disk, S3 path, or URL to read `woodwork_typing_info.json`.
profile_name (str, bool): The AWS profile specified to write to S3. Will default to None and search for AWS credentials.
Set to False to use an anonymous profile.
validate (bool, optional): Whether parameter and data validation should occur when initializing Woodwork dataframe
during deserialization. Defaults to False. Note: If serialized data was modified outside of Woodwork and you
are unsure of the validity of the data or typing information, `validate` should be set to True.
kwargs (keywords): Additional keyword arguments to pass as keyword arguments to the underlying deserialization method.
Returns:
DataFrame: DataFrame with Woodwork typing information initialized.
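
    Example:
        A minimal, illustrative sketch; ``./my_table`` is a hypothetical
        directory that already contains ``woodwork_typing_info.json`` and the
        serialized data::

            from woodwork.deserialize import read_woodwork_table

            df = read_woodwork_table('./my_table', validate=True)
            df.ww.schema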
"""
if _is_url(path) or _is_s3(path):
with tempfile.TemporaryDirectory() as tmpdir:
file_name = Path(path).name
file_path = os.path.join(tmpdir, file_name)
transport_params = None
if _is_s3(path):
transport_params = get_transport_params(profile_name)
use_smartopen(file_path, path, transport_params)
with tarfile.open(str(file_path)) as tar:
tar.extractall(path=tmpdir)
table_typing_info = read_table_typing_information(tmpdir)
return _typing_information_to_woodwork_table(table_typing_info, validate, **kwargs)
else:
table_typing_info = read_table_typing_information(path)
return _typing_information_to_woodwork_table(table_typing_info, validate, **kwargs)


def _check_schema_version(saved_version_str):
"""Warns users if the schema used to save their data is greater than the latest
supported schema or if it is an outdated schema that is no longer supported."""
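    # Illustrative outcomes (hypothetical version strings): if SCHEMA_VERSION
    # were '5.0.0', a saved version of '6.0.0' would emit UpgradeSchemaWarning,
    # while '4.1.0' (an older major version) would emit OutdatedSchemaWarning.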
saved = saved_version_str.split('.')
current = SCHEMA_VERSION.split('.')
    for c_num, s_num in zip_longest(current, saved, fillvalue='0'):
if int(c_num) > int(s_num):
break
elif int(c_num) < int(s_num):
warnings.warn(UpgradeSchemaWarning().get_warning_message(saved_version_str, SCHEMA_VERSION),
UpgradeSchemaWarning)
break
# Check if saved has older major version.
if int(current[0]) > int(saved[0]):
warnings.warn(OutdatedSchemaWarning().get_warning_message(saved_version_str),
OutdatedSchemaWarning)