1
0
Fork 0
mirror of https://github.com/ytdl-org/youtube-dl.git synced 2024-05-12 08:09:31 +00:00

[utils] Add {expected_type} and Iterable support to traverse_obj()

This commit is contained in:
dirkf 2023-07-06 15:46:22 +01:00
parent b6dff4073d
commit f47fdb9564
2 changed files with 265 additions and 99 deletions

View file

@ -79,10 +79,12 @@ from youtube_dl.utils import (
rot47, rot47,
shell_quote, shell_quote,
smuggle_url, smuggle_url,
str_or_none,
str_to_int, str_to_int,
strip_jsonp, strip_jsonp,
strip_or_none, strip_or_none,
subtitles_filename, subtitles_filename,
T,
timeconvert, timeconvert,
traverse_obj, traverse_obj,
try_call, try_call,
@ -1566,6 +1568,7 @@ Line 1
self.assertEqual(variadic('spam', allowed_types=[dict]), 'spam') self.assertEqual(variadic('spam', allowed_types=[dict]), 'spam')
def test_traverse_obj(self): def test_traverse_obj(self):
str = compat_str
_TEST_DATA = { _TEST_DATA = {
100: 100, 100: 100,
1.2: 1.2, 1.2: 1.2,
@ -1598,8 +1601,8 @@ Line 1
# Test Ellipsis behavior # Test Ellipsis behavior
self.assertCountEqual(traverse_obj(_TEST_DATA, Ellipsis), self.assertCountEqual(traverse_obj(_TEST_DATA, Ellipsis),
(item for item in _TEST_DATA.values() if item is not None), (item for item in _TEST_DATA.values() if item not in (None, {})),
msg='`...` should give all values except `None`') msg='`...` should give all non discarded values')
self.assertCountEqual(traverse_obj(_TEST_DATA, ('urls', 0, Ellipsis)), _TEST_DATA['urls'][0].values(), self.assertCountEqual(traverse_obj(_TEST_DATA, ('urls', 0, Ellipsis)), _TEST_DATA['urls'][0].values(),
msg='`...` selection for dicts should select all values') msg='`...` selection for dicts should select all values')
self.assertEqual(traverse_obj(_TEST_DATA, (Ellipsis, Ellipsis, 'url')), self.assertEqual(traverse_obj(_TEST_DATA, (Ellipsis, Ellipsis, 'url')),
@ -1607,13 +1610,51 @@ Line 1
msg='nested `...` queries should work') msg='nested `...` queries should work')
self.assertCountEqual(traverse_obj(_TEST_DATA, (Ellipsis, Ellipsis, 'index')), range(4), self.assertCountEqual(traverse_obj(_TEST_DATA, (Ellipsis, Ellipsis, 'index')), range(4),
msg='`...` query result should be flattened') msg='`...` query result should be flattened')
self.assertEqual(traverse_obj(iter(range(4)), Ellipsis), list(range(4)),
msg='`...` should accept iterables')
# Test function as key # Test function as key
self.assertEqual(traverse_obj(_TEST_DATA, lambda x, y: x == 'urls' and isinstance(y, list)), self.assertEqual(traverse_obj(_TEST_DATA, lambda x, y: x == 'urls' and isinstance(y, list)),
[_TEST_DATA['urls']], [_TEST_DATA['urls']],
msg='function as query key should perform a filter based on (key, value)') msg='function as query key should perform a filter based on (key, value)')
self.assertCountEqual(traverse_obj(_TEST_DATA, lambda _, x: isinstance(x[0], compat_str)), ('str',), self.assertCountEqual(traverse_obj(_TEST_DATA, lambda _, x: isinstance(x[0], str)), {'str'},
msg='exceptions in the query function should be caught') msg='exceptions in the query function should be catched')
self.assertEqual(traverse_obj(iter(range(4)), lambda _, x: x % 2 == 0), [0, 2],
msg='function key should accept iterables')
if __debug__:
with self.assertRaises(Exception, msg='Wrong function signature should raise in debug'):
traverse_obj(_TEST_DATA, lambda a: Ellipsis)
with self.assertRaises(Exception, msg='Wrong function signature should raise in debug'):
traverse_obj(_TEST_DATA, lambda a, b, c: Ellipsis)
# Test set as key (transformation/type, like `expected_type`)
self.assertEqual(traverse_obj(_TEST_DATA, (Ellipsis, T(str.upper), )), ['STR'],
msg='Function in set should be a transformation')
self.assertEqual(traverse_obj(_TEST_DATA, (Ellipsis, T(str))), ['str'],
msg='Type in set should be a type filter')
self.assertEqual(traverse_obj(_TEST_DATA, T(dict)), _TEST_DATA,
msg='A single set should be wrapped into a path')
self.assertEqual(traverse_obj(_TEST_DATA, (Ellipsis, T(str.upper))), ['STR'],
msg='Transformation function should not raise')
self.assertEqual(traverse_obj(_TEST_DATA, (Ellipsis, T(str_or_none))),
[item for item in map(str_or_none, _TEST_DATA.values()) if item is not None],
msg='Function in set should be a transformation')
if __debug__:
with self.assertRaises(Exception, msg='Sets with length != 1 should raise in debug'):
traverse_obj(_TEST_DATA, set())
with self.assertRaises(Exception, msg='Sets with length != 1 should raise in debug'):
traverse_obj(_TEST_DATA, {str.upper, str})
# Test `slice` as a key
_SLICE_DATA = [0, 1, 2, 3, 4]
self.assertEqual(traverse_obj(_TEST_DATA, ('dict', slice(1))), None,
msg='slice on a dictionary should not throw')
self.assertEqual(traverse_obj(_SLICE_DATA, slice(1)), _SLICE_DATA[:1],
msg='slice key should apply slice to sequence')
self.assertEqual(traverse_obj(_SLICE_DATA, slice(1, 2)), _SLICE_DATA[1:2],
msg='slice key should apply slice to sequence')
self.assertEqual(traverse_obj(_SLICE_DATA, slice(1, 4, 2)), _SLICE_DATA[1:4:2],
msg='slice key should apply slice to sequence')
# Test alternative paths # Test alternative paths
self.assertEqual(traverse_obj(_TEST_DATA, 'fail', 'str'), 'str', self.assertEqual(traverse_obj(_TEST_DATA, 'fail', 'str'), 'str',
@ -1659,15 +1700,23 @@ Line 1
{0: ['https://www.example.com/1', 'https://www.example.com/0']}, {0: ['https://www.example.com/1', 'https://www.example.com/0']},
msg='triple nesting in dict path should be treated as branches') msg='triple nesting in dict path should be treated as branches')
self.assertEqual(traverse_obj(_TEST_DATA, {0: 'fail'}), {}, self.assertEqual(traverse_obj(_TEST_DATA, {0: 'fail'}), {},
msg='remove `None` values when dict key') msg='remove `None` values when top level dict key fails')
self.assertEqual(traverse_obj(_TEST_DATA, {0: 'fail'}, default=Ellipsis), {0: Ellipsis}, self.assertEqual(traverse_obj(_TEST_DATA, {0: 'fail'}, default=Ellipsis), {0: Ellipsis},
msg='do not remove `None` values if `default`') msg='use `default` if key fails and `default`')
self.assertEqual(traverse_obj(_TEST_DATA, {0: 'dict'}), {0: {}}, self.assertEqual(traverse_obj(_TEST_DATA, {0: 'dict'}), {},
msg='do not remove empty values when dict key') msg='remove empty values when dict key')
self.assertEqual(traverse_obj(_TEST_DATA, {0: 'dict'}, default=Ellipsis), {0: {}}, self.assertEqual(traverse_obj(_TEST_DATA, {0: 'dict'}, default=Ellipsis), {0: Ellipsis},
msg='do not remove empty values when dict key and a default') msg='use `default` when dict key and `default`')
self.assertEqual(traverse_obj(_TEST_DATA, {0: ('dict', Ellipsis)}), {0: []}, self.assertEqual(traverse_obj(_TEST_DATA, {0: {0: 'fail'}}), {},
msg='if branch in dict key not successful, return `[]`') msg='remove empty values when nested dict key fails')
self.assertEqual(traverse_obj(None, {0: 'fail'}), {},
msg='default to dict if pruned')
self.assertEqual(traverse_obj(None, {0: 'fail'}, default=Ellipsis), {0: Ellipsis},
msg='default to dict if pruned and default is given')
self.assertEqual(traverse_obj(_TEST_DATA, {0: {0: 'fail'}}, default=Ellipsis), {0: {0: Ellipsis}},
msg='use nested `default` when nested dict key fails and `default`')
self.assertEqual(traverse_obj(_TEST_DATA, {0: ('dict', Ellipsis)}), {},
msg='remove key if branch in dict key not successful')
# Testing default parameter behavior # Testing default parameter behavior
_DEFAULT_DATA = {'None': None, 'int': 0, 'list': []} _DEFAULT_DATA = {'None': None, 'int': 0, 'list': []}
@ -1691,20 +1740,55 @@ Line 1
msg='if branched but not successful return `[]`, not `default`') msg='if branched but not successful return `[]`, not `default`')
self.assertEqual(traverse_obj(_DEFAULT_DATA, ('list', Ellipsis)), [], self.assertEqual(traverse_obj(_DEFAULT_DATA, ('list', Ellipsis)), [],
msg='if branched but object is empty return `[]`, not `default`') msg='if branched but object is empty return `[]`, not `default`')
self.assertEqual(traverse_obj(None, Ellipsis), [],
msg='if branched but object is `None` return `[]`, not `default`')
self.assertEqual(traverse_obj({0: None}, (0, Ellipsis)), [],
msg='if branched but state is `None` return `[]`, not `default`')
branching_paths = [
('fail', Ellipsis),
(Ellipsis, 'fail'),
100 * ('fail',) + (Ellipsis,),
(Ellipsis,) + 100 * ('fail',),
]
for branching_path in branching_paths:
self.assertEqual(traverse_obj({}, branching_path), [],
msg='if branched but state is `None`, return `[]` (not `default`)')
self.assertEqual(traverse_obj({}, 'fail', branching_path), [],
msg='if branching in last alternative and previous did not match, return `[]` (not `default`)')
self.assertEqual(traverse_obj({0: 'x'}, 0, branching_path), 'x',
msg='if branching in last alternative and previous did match, return single value')
self.assertEqual(traverse_obj({0: 'x'}, branching_path, 0), 'x',
msg='if branching in first alternative and non-branching path does match, return single value')
self.assertEqual(traverse_obj({}, branching_path, 'fail'), None,
msg='if branching in first alternative and non-branching path does not match, return `default`')
# Testing expected_type behavior # Testing expected_type behavior
_EXPECTED_TYPE_DATA = {'str': 'str', 'int': 0} _EXPECTED_TYPE_DATA = {'str': 'str', 'int': 0}
self.assertEqual(traverse_obj(_EXPECTED_TYPE_DATA, 'str', expected_type=compat_str), 'str', self.assertEqual(traverse_obj(_EXPECTED_TYPE_DATA, 'str', expected_type=str),
msg='accept matching `expected_type` type') 'str', msg='accept matching `expected_type` type')
self.assertEqual(traverse_obj(_EXPECTED_TYPE_DATA, 'str', expected_type=int), None, self.assertEqual(traverse_obj(_EXPECTED_TYPE_DATA, 'str', expected_type=int),
msg='reject non matching `expected_type` type') None, msg='reject non matching `expected_type` type')
self.assertEqual(traverse_obj(_EXPECTED_TYPE_DATA, 'int', expected_type=lambda x: compat_str(x)), '0', self.assertEqual(traverse_obj(_EXPECTED_TYPE_DATA, 'int', expected_type=lambda x: str(x)),
msg='transform type using type function') '0', msg='transform type using type function')
self.assertEqual(traverse_obj(_EXPECTED_TYPE_DATA, 'str', self.assertEqual(traverse_obj(_EXPECTED_TYPE_DATA, 'str', expected_type=lambda _: 1 / 0),
expected_type=lambda _: 1 / 0), None, None, msg='wrap expected_type function in try_call')
msg='wrap expected_type function in try_call') self.assertEqual(traverse_obj(_EXPECTED_TYPE_DATA, Ellipsis, expected_type=str),
self.assertEqual(traverse_obj(_EXPECTED_TYPE_DATA, Ellipsis, expected_type=compat_str), ['str'], ['str'], msg='eliminate items that expected_type fails on')
msg='eliminate items that expected_type fails on') self.assertEqual(traverse_obj(_TEST_DATA, {0: 100, 1: 1.2}, expected_type=int),
{0: 100}, msg='type as expected_type should filter dict values')
self.assertEqual(traverse_obj(_TEST_DATA, {0: 100, 1: 1.2, 2: 'None'}, expected_type=str_or_none),
{0: '100', 1: '1.2'}, msg='function as expected_type should transform dict values')
self.assertEqual(traverse_obj(_TEST_DATA, ({0: 1.2}, 0, {int_or_none}), expected_type=int),
1, msg='expected_type should not filter non final dict values')
self.assertEqual(traverse_obj(_TEST_DATA, {0: {0: 100, 1: 'str'}}, expected_type=int),
{0: {0: 100}}, msg='expected_type should transform deep dict values')
self.assertEqual(traverse_obj(_TEST_DATA, [({0: '...'}, {0: '...'})], expected_type=type(Ellipsis)),
[{0: Ellipsis}, {0: Ellipsis}], msg='expected_type should transform branched dict values')
self.assertEqual(traverse_obj({1: {3: 4}}, [(1, 2), 3], expected_type=int),
[4], msg='expected_type regression for type matching in tuple branching')
self.assertEqual(traverse_obj(_TEST_DATA, ['data', Ellipsis], expected_type=int),
[], msg='expected_type regression for type matching in dict result')
# Test get_all behavior # Test get_all behavior
_GET_ALL_DATA = {'key': [0, 1, 2]} _GET_ALL_DATA = {'key': [0, 1, 2]}
@ -1749,14 +1833,23 @@ Line 1
_traverse_string=True), '.', _traverse_string=True), '.',
msg='traverse into converted data if `traverse_string`') msg='traverse into converted data if `traverse_string`')
self.assertEqual(traverse_obj(_TRAVERSE_STRING_DATA, ('str', Ellipsis), self.assertEqual(traverse_obj(_TRAVERSE_STRING_DATA, ('str', Ellipsis),
_traverse_string=True), list('str'), _traverse_string=True), 'str',
msg='`...` branching into string should result in list') msg='`...` should result in string (same value) if `traverse_string`')
self.assertEqual(traverse_obj(_TRAVERSE_STRING_DATA, ('str', slice(0, None, 2)),
_traverse_string=True), 'sr',
msg='`slice` should result in string if `traverse_string`')
self.assertEqual(traverse_obj(_TRAVERSE_STRING_DATA, ('str', lambda i, v: i or v == "s"),
_traverse_string=True), 'str',
msg='function should result in string if `traverse_string`')
self.assertEqual(traverse_obj(_TRAVERSE_STRING_DATA, ('str', (0, 2)), self.assertEqual(traverse_obj(_TRAVERSE_STRING_DATA, ('str', (0, 2)),
_traverse_string=True), ['s', 'r'], _traverse_string=True), ['s', 'r'],
msg='branching into string should result in list') msg='branching should result in list if `traverse_string`')
self.assertEqual(traverse_obj(_TRAVERSE_STRING_DATA, ('str', lambda _, x: x), self.assertEqual(traverse_obj({}, (0, Ellipsis), _traverse_string=True), [],
_traverse_string=True), list('str'), msg='branching should result in list if `traverse_string`')
msg='function branching into string should result in list') self.assertEqual(traverse_obj({}, (0, lambda x, y: True), _traverse_string=True), [],
msg='branching should result in list if `traverse_string`')
self.assertEqual(traverse_obj({}, (0, slice(1)), _traverse_string=True), [],
msg='branching should result in list if `traverse_string`')
# Test is_user_input behavior # Test is_user_input behavior
_IS_USER_INPUT_DATA = {'range8': list(range(8))} _IS_USER_INPUT_DATA = {'range8': list(range(8))}
@ -1793,6 +1886,8 @@ Line 1
msg='failing str key on a `re.Match` should return `default`') msg='failing str key on a `re.Match` should return `default`')
self.assertEqual(traverse_obj(mobj, 8), None, self.assertEqual(traverse_obj(mobj, 8), None,
msg='failing int key on a `re.Match` should return `default`') msg='failing int key on a `re.Match` should return `default`')
self.assertEqual(traverse_obj(mobj, lambda k, _: k in (0, 'group')), ['0123', '3'],
msg='function on a `re.Match` should give group name as well')
def test_get_first(self): def test_get_first(self):
self.assertEqual(get_first([{'a': None}, {'a': 'spam'}], 'a'), 'spam') self.assertEqual(get_first([{'a': None}, {'a': 'spam'}], 'a'), 'spam')

View file

@ -16,6 +16,7 @@ import email.header
import errno import errno
import functools import functools
import gzip import gzip
import inspect
import io import io
import itertools import itertools
import json import json
@ -3881,7 +3882,7 @@ def detect_exe_version(output, version_re=None, unrecognized='present'):
return unrecognized return unrecognized
class LazyList(compat_collections_abc.Sequence): class LazyList(compat_collections_abc.Iterable):
"""Lazy immutable list from an iterable """Lazy immutable list from an iterable
Note that slices of a LazyList are lists and not LazyList""" Note that slices of a LazyList are lists and not LazyList"""
@ -4223,10 +4224,16 @@ def multipart_encode(data, boundary=None):
return out, content_type return out, content_type
def variadic(x, allowed_types=(compat_str, bytes, dict)): def is_iterable_like(x, allowed_types=compat_collections_abc.Iterable, blocked_types=NO_DEFAULT):
if not isinstance(allowed_types, tuple) and isinstance(allowed_types, compat_collections_abc.Iterable): if blocked_types is NO_DEFAULT:
blocked_types = (compat_str, bytes, compat_collections_abc.Mapping)
return isinstance(x, allowed_types) and not isinstance(x, blocked_types)
def variadic(x, allowed_types=NO_DEFAULT):
if isinstance(allowed_types, compat_collections_abc.Iterable):
allowed_types = tuple(allowed_types) allowed_types = tuple(allowed_types)
return x if isinstance(x, compat_collections_abc.Iterable) and not isinstance(x, allowed_types) else (x,) return x if is_iterable_like(x, blocked_types=allowed_types) else (x,)
def dict_get(d, key_or_keys, default=None, skip_false_values=True): def dict_get(d, key_or_keys, default=None, skip_false_values=True):
@ -5993,7 +6000,7 @@ def clean_podcast_url(url):
def traverse_obj(obj, *paths, **kwargs): def traverse_obj(obj, *paths, **kwargs):
""" """
Safely traverse nested `dict`s and `Sequence`s Safely traverse nested `dict`s and `Iterable`s
>>> obj = [{}, {"key": "value"}] >>> obj = [{}, {"key": "value"}]
>>> traverse_obj(obj, (1, "key")) >>> traverse_obj(obj, (1, "key"))
@ -6001,14 +6008,17 @@ def traverse_obj(obj, *paths, **kwargs):
Each of the provided `paths` is tested and the first producing a valid result will be returned. Each of the provided `paths` is tested and the first producing a valid result will be returned.
The next path will also be tested if the path branched but no results could be found. The next path will also be tested if the path branched but no results could be found.
Supported values for traversal are `Mapping`, `Sequence` and `re.Match`. Supported values for traversal are `Mapping`, `Iterable` and `re.Match`.
A value of None is treated as the absence of a value. Unhelpful values (`{}`, `None`) are treated as the absence of a value and discarded.
The paths will be wrapped in `variadic`, so that `'key'` is conveniently the same as `('key', )`. The paths will be wrapped in `variadic`, so that `'key'` is conveniently the same as `('key', )`.
The keys in the path can be one of: The keys in the path can be one of:
- `None`: Return the current object. - `None`: Return the current object.
- `str`/`int`: Return `obj[key]`. For `re.Match, return `obj.group(key)`. - `set`: Requires the only item in the set to be a type or function,
like `{type}`/`{func}`. If a `type`, returns only values
of this type. If a function, returns `func(obj)`.
- `str`/`int`: Return `obj[key]`. For `re.Match`, return `obj.group(key)`.
- `slice`: Branch out and return all values in `obj[key]`. - `slice`: Branch out and return all values in `obj[key]`.
- `Ellipsis`: Branch out and return a list of all values. - `Ellipsis`: Branch out and return a list of all values.
- `tuple`/`list`: Branch out and return a list of all matching values. - `tuple`/`list`: Branch out and return a list of all matching values.
@ -6016,6 +6026,9 @@ def traverse_obj(obj, *paths, **kwargs):
- `function`: Branch out and return values filtered by the function. - `function`: Branch out and return values filtered by the function.
Read as: `[value for key, value in obj if function(key, value)]`. Read as: `[value for key, value in obj if function(key, value)]`.
For `Sequence`s, `key` is the index of the value. For `Sequence`s, `key` is the index of the value.
For `Iterable`s, `key` is the enumeration count of the value.
For `re.Match`es, `key` is the group number (0 = full match)
as well as additionally any group names, if given.
- `dict` Transform the current object and return a matching dict. - `dict` Transform the current object and return a matching dict.
Read as: `{key: traverse_obj(obj, path) for key, path in dct.items()}`. Read as: `{key: traverse_obj(obj, path) for key, path in dct.items()}`.
@ -6024,8 +6037,12 @@ def traverse_obj(obj, *paths, **kwargs):
@params paths Paths which to traverse by. @params paths Paths which to traverse by.
Keyword arguments: Keyword arguments:
@param default Value to return if the paths do not match. @param default Value to return if the paths do not match.
If the last key in the path is a `dict`, it will apply to each value inside
the dict instead, depth first. Try to avoid if using nested `dict` keys.
@param expected_type If a `type`, only accept final values of this type. @param expected_type If a `type`, only accept final values of this type.
If any other callable, try to call the function on each result. If any other callable, try to call the function on each result.
If the last key in the path is a `dict`, it will apply to each value inside
the dict instead, recursively. This does respect branching paths.
@param get_all If `False`, return the first matching result, otherwise all matching ones. @param get_all If `False`, return the first matching result, otherwise all matching ones.
@param casesense If `False`, consider string dictionary keys as case insensitive. @param casesense If `False`, consider string dictionary keys as case insensitive.
@ -6036,12 +6053,15 @@ def traverse_obj(obj, *paths, **kwargs):
@param _traverse_string Whether to traverse into objects as strings. @param _traverse_string Whether to traverse into objects as strings.
If `True`, any non-compatible object will first be If `True`, any non-compatible object will first be
converted into a string and then traversed into. converted into a string and then traversed into.
The return value of that path will be a string instead,
not respecting any further branching.
@returns The result of the object traversal. @returns The result of the object traversal.
If successful, `get_all=True`, and the path branches at least once, If successful, `get_all=True`, and the path branches at least once,
then a list of results is returned instead. then a list of results is returned instead.
A list is always returned if the last path branches and no `default` is given. A list is always returned if the last path branches and no `default` is given.
If a path ends on a `dict` that result will always be a `dict`.
""" """
# parameter defaults # parameter defaults
@ -6055,7 +6075,6 @@ def traverse_obj(obj, *paths, **kwargs):
# instant compat # instant compat
str = compat_str str = compat_str
is_sequence = lambda x: isinstance(x, compat_collections_abc.Sequence) and not isinstance(x, (str, bytes))
casefold = lambda k: compat_casefold(k) if isinstance(k, str) else k casefold = lambda k: compat_casefold(k) if isinstance(k, str) else k
if isinstance(expected_type, type): if isinstance(expected_type, type):
@ -6063,128 +6082,180 @@ def traverse_obj(obj, *paths, **kwargs):
else: else:
type_test = lambda val: try_call(expected_type or IDENTITY, args=(val,)) type_test = lambda val: try_call(expected_type or IDENTITY, args=(val,))
def lookup_or_none(v, k, getter=None):
try:
return getter(v, k) if getter else v[k]
except IndexError:
return None
def from_iterable(iterables): def from_iterable(iterables):
# chain.from_iterable(['ABC', 'DEF']) --> A B C D E F # chain.from_iterable(['ABC', 'DEF']) --> A B C D E F
for it in iterables: for it in iterables:
for item in it: for item in it:
yield item yield item
def apply_key(key, obj): def apply_key(key, obj, is_last):
if obj is None: branching = False
return
if obj is None and _traverse_string:
if key is Ellipsis or callable(key) or isinstance(key, slice):
branching = True
result = ()
else:
result = None
elif key is None: elif key is None:
yield obj result = obj
elif isinstance(key, set):
assert len(key) == 1, 'Set should only be used to wrap a single item'
item = next(iter(key))
if isinstance(item, type):
result = obj if isinstance(obj, item) else None
else:
result = try_call(item, args=(obj,))
elif isinstance(key, (list, tuple)): elif isinstance(key, (list, tuple)):
for branch in key: branching = True
_, result = apply_path(obj, branch) result = from_iterable(
for item in result: apply_path(obj, branch, is_last)[0] for branch in key)
yield item
elif key is Ellipsis: elif key is Ellipsis:
result = [] branching = True
if isinstance(obj, compat_collections_abc.Mapping): if isinstance(obj, compat_collections_abc.Mapping):
result = obj.values() result = obj.values()
elif is_sequence(obj): elif is_iterable_like(obj):
result = obj result = obj
elif isinstance(obj, compat_re_Match): elif isinstance(obj, compat_re_Match):
result = obj.groups() result = obj.groups()
elif _traverse_string: elif _traverse_string:
branching = False
result = str(obj) result = str(obj)
for item in result: else:
yield item result = ()
elif callable(key): elif callable(key):
if is_sequence(obj): branching = True
iter_obj = enumerate(obj) if isinstance(obj, compat_collections_abc.Mapping):
elif isinstance(obj, compat_collections_abc.Mapping):
iter_obj = obj.items() iter_obj = obj.items()
elif is_iterable_like(obj):
iter_obj = enumerate(obj)
elif isinstance(obj, compat_re_Match): elif isinstance(obj, compat_re_Match):
iter_obj = enumerate(itertools.chain([obj.group()], obj.groups())) iter_obj = itertools.chain(
enumerate(itertools.chain((obj.group(),), obj.groups())),
obj.groupdict().items())
elif _traverse_string: elif _traverse_string:
branching = False
iter_obj = enumerate(str(obj)) iter_obj = enumerate(str(obj))
else: else:
return iter_obj = ()
for item in (v for k, v in iter_obj if try_call(key, args=(k, v))):
yield item result = (v for k, v in iter_obj if try_call(key, args=(k, v)))
if not branching: # string traversal
result = ''.join(result)
elif isinstance(key, dict): elif isinstance(key, dict):
iter_obj = ((k, _traverse_obj(obj, v)) for k, v in key.items()) iter_obj = ((k, _traverse_obj(obj, v, False, is_last)) for k, v in key.items())
yield dict((k, v if v is not None else default) for k, v in iter_obj result = dict((k, v if v is not None else default) for k, v in iter_obj
if v is not None or default is not NO_DEFAULT) if v is not None or default is not NO_DEFAULT) or None
elif isinstance(obj, compat_collections_abc.Mapping): elif isinstance(obj, compat_collections_abc.Mapping):
yield (obj.get(key) if casesense or (key in obj) result = (try_call(obj.get, args=(key,))
else next((v for k, v in obj.items() if casefold(k) == key), None)) if casesense or try_call(obj.__contains__, args=(key,))
else next((v for k, v in obj.items() if casefold(k) == key), None))
elif isinstance(obj, compat_re_Match): elif isinstance(obj, compat_re_Match):
result = None
if isinstance(key, int) or casesense: if isinstance(key, int) or casesense:
try: result = lookup_or_none(obj, key, getter=compat_re_Match.group)
yield obj.group(key)
return
except IndexError:
pass
if not isinstance(key, str):
return
yield next((v for k, v in obj.groupdict().items() if casefold(k) == key), None) elif isinstance(key, str):
result = next((v for k, v in obj.groupdict().items()
if casefold(k) == key), None)
else: else:
if _is_user_input: result = None
key = (int_or_none(key) if ':' not in key if isinstance(key, (int, slice)):
else slice(*map(int_or_none, key.split(':')))) if is_iterable_like(obj, compat_collections_abc.Sequence):
branching = isinstance(key, slice)
result = lookup_or_none(obj, key)
elif _traverse_string:
result = lookup_or_none(str(obj), key)
if not isinstance(key, (int, slice)): return branching, result if branching else (result,)
return
if not is_sequence(obj): def lazy_last(iterable):
if not _traverse_string: iterator = iter(iterable)
return prev = next(iterator, NO_DEFAULT)
obj = str(obj) if prev is NO_DEFAULT:
return
try: for item in iterator:
yield obj[key] yield False, prev
except IndexError: prev = item
pass
def apply_path(start_obj, path): yield True, prev
def apply_path(start_obj, path, test_type):
objs = (start_obj,) objs = (start_obj,)
has_branched = False has_branched = False
for key in variadic(path): key = None
if _is_user_input and key == ':': for last, key in lazy_last(variadic(path, (str, bytes, dict, set))):
key = Ellipsis if _is_user_input and isinstance(key, str):
if key == ':':
key = Ellipsis
elif ':' in key:
key = slice(*map(int_or_none, key.split(':')))
elif int_or_none(key) is not None:
key = int(key)
if not casesense and isinstance(key, str): if not casesense and isinstance(key, str):
key = compat_casefold(key) key = compat_casefold(key)
if key is Ellipsis or isinstance(key, (list, tuple)) or callable(key): if __debug__ and callable(key):
has_branched = True # Verify function signature
inspect.getcallargs(key, None, None)
key_func = functools.partial(apply_key, key) new_objs = []
objs = from_iterable(map(key_func, objs)) for obj in objs:
branching, results = apply_key(key, obj, last)
has_branched |= branching
new_objs.append(results)
return has_branched, objs objs = from_iterable(new_objs)
def _traverse_obj(obj, path, use_list=True): if test_type and not isinstance(key, (dict, list, tuple)):
has_branched, results = apply_path(obj, path) objs = map(type_test, objs)
results = LazyList(x for x in map(type_test, results) if x is not None)
return objs, has_branched, isinstance(key, dict)
def _traverse_obj(obj, path, allow_empty, test_type):
results, has_branched, is_dict = apply_path(obj, path, test_type)
results = LazyList(x for x in results if x not in (None, {}))
if get_all and has_branched: if get_all and has_branched:
return results.exhaust() if results or use_list else None if results:
return results.exhaust()
if allow_empty:
return [] if default is NO_DEFAULT else default
return None
return results[0] if results else None return results[0] if results else {} if allow_empty and is_dict else None
for index, path in enumerate(paths, 1): for index, path in enumerate(paths, 1):
use_list = default is NO_DEFAULT and index == len(paths) result = _traverse_obj(obj, path, index == len(paths), True)
result = _traverse_obj(obj, path, use_list)
if result is not None: if result is not None:
return result return result
return None if default is NO_DEFAULT else default return None if default is NO_DEFAULT else default
def T(x):
""" For use in yt-dl instead of {type} or set((type,)) """
return set((x,))
def get_first(obj, keys, **kwargs): def get_first(obj, keys, **kwargs):
return traverse_obj(obj, (Ellipsis,) + tuple(variadic(keys)), get_all=False, **kwargs) return traverse_obj(obj, (Ellipsis,) + tuple(variadic(keys)), get_all=False, **kwargs)