From f47fdb9564d3ca1c0fa70ed6031148ec908fdc7b Mon Sep 17 00:00:00 2001 From: dirkf Date: Thu, 6 Jul 2023 15:46:22 +0100 Subject: [PATCH] [utils] Add {expected_type} and Iterable support to traverse_obj() --- test/test_utils.py | 153 ++++++++++++++++++++++++++------ youtube_dl/utils.py | 211 +++++++++++++++++++++++++++++--------------- 2 files changed, 265 insertions(+), 99 deletions(-) diff --git a/test/test_utils.py b/test/test_utils.py index 5fab05f7c..1fc16ed05 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -79,10 +79,12 @@ from youtube_dl.utils import ( rot47, shell_quote, smuggle_url, + str_or_none, str_to_int, strip_jsonp, strip_or_none, subtitles_filename, + T, timeconvert, traverse_obj, try_call, @@ -1566,6 +1568,7 @@ Line 1 self.assertEqual(variadic('spam', allowed_types=[dict]), 'spam') def test_traverse_obj(self): + str = compat_str _TEST_DATA = { 100: 100, 1.2: 1.2, @@ -1598,8 +1601,8 @@ Line 1 # Test Ellipsis behavior self.assertCountEqual(traverse_obj(_TEST_DATA, Ellipsis), - (item for item in _TEST_DATA.values() if item is not None), - msg='`...` should give all values except `None`') + (item for item in _TEST_DATA.values() if item not in (None, {})), + msg='`...` should give all non discarded values') self.assertCountEqual(traverse_obj(_TEST_DATA, ('urls', 0, Ellipsis)), _TEST_DATA['urls'][0].values(), msg='`...` selection for dicts should select all values') self.assertEqual(traverse_obj(_TEST_DATA, (Ellipsis, Ellipsis, 'url')), @@ -1607,13 +1610,51 @@ Line 1 msg='nested `...` queries should work') self.assertCountEqual(traverse_obj(_TEST_DATA, (Ellipsis, Ellipsis, 'index')), range(4), msg='`...` query result should be flattened') + self.assertEqual(traverse_obj(iter(range(4)), Ellipsis), list(range(4)), + msg='`...` should accept iterables') # Test function as key self.assertEqual(traverse_obj(_TEST_DATA, lambda x, y: x == 'urls' and isinstance(y, list)), [_TEST_DATA['urls']], msg='function as query key should perform a filter based on (key, value)') - self.assertCountEqual(traverse_obj(_TEST_DATA, lambda _, x: isinstance(x[0], compat_str)), ('str',), - msg='exceptions in the query function should be caught') + self.assertCountEqual(traverse_obj(_TEST_DATA, lambda _, x: isinstance(x[0], str)), {'str'}, + msg='exceptions in the query function should be catched') + self.assertEqual(traverse_obj(iter(range(4)), lambda _, x: x % 2 == 0), [0, 2], + msg='function key should accept iterables') + if __debug__: + with self.assertRaises(Exception, msg='Wrong function signature should raise in debug'): + traverse_obj(_TEST_DATA, lambda a: Ellipsis) + with self.assertRaises(Exception, msg='Wrong function signature should raise in debug'): + traverse_obj(_TEST_DATA, lambda a, b, c: Ellipsis) + + # Test set as key (transformation/type, like `expected_type`) + self.assertEqual(traverse_obj(_TEST_DATA, (Ellipsis, T(str.upper), )), ['STR'], + msg='Function in set should be a transformation') + self.assertEqual(traverse_obj(_TEST_DATA, (Ellipsis, T(str))), ['str'], + msg='Type in set should be a type filter') + self.assertEqual(traverse_obj(_TEST_DATA, T(dict)), _TEST_DATA, + msg='A single set should be wrapped into a path') + self.assertEqual(traverse_obj(_TEST_DATA, (Ellipsis, T(str.upper))), ['STR'], + msg='Transformation function should not raise') + self.assertEqual(traverse_obj(_TEST_DATA, (Ellipsis, T(str_or_none))), + [item for item in map(str_or_none, _TEST_DATA.values()) if item is not None], + msg='Function in set should be a transformation') + if __debug__: + with self.assertRaises(Exception, msg='Sets with length != 1 should raise in debug'): + traverse_obj(_TEST_DATA, set()) + with self.assertRaises(Exception, msg='Sets with length != 1 should raise in debug'): + traverse_obj(_TEST_DATA, {str.upper, str}) + + # Test `slice` as a key + _SLICE_DATA = [0, 1, 2, 3, 4] + self.assertEqual(traverse_obj(_TEST_DATA, ('dict', slice(1))), None, + msg='slice on a dictionary should not throw') + self.assertEqual(traverse_obj(_SLICE_DATA, slice(1)), _SLICE_DATA[:1], + msg='slice key should apply slice to sequence') + self.assertEqual(traverse_obj(_SLICE_DATA, slice(1, 2)), _SLICE_DATA[1:2], + msg='slice key should apply slice to sequence') + self.assertEqual(traverse_obj(_SLICE_DATA, slice(1, 4, 2)), _SLICE_DATA[1:4:2], + msg='slice key should apply slice to sequence') # Test alternative paths self.assertEqual(traverse_obj(_TEST_DATA, 'fail', 'str'), 'str', @@ -1659,15 +1700,23 @@ Line 1 {0: ['https://www.example.com/1', 'https://www.example.com/0']}, msg='triple nesting in dict path should be treated as branches') self.assertEqual(traverse_obj(_TEST_DATA, {0: 'fail'}), {}, - msg='remove `None` values when dict key') + msg='remove `None` values when top level dict key fails') self.assertEqual(traverse_obj(_TEST_DATA, {0: 'fail'}, default=Ellipsis), {0: Ellipsis}, - msg='do not remove `None` values if `default`') - self.assertEqual(traverse_obj(_TEST_DATA, {0: 'dict'}), {0: {}}, - msg='do not remove empty values when dict key') - self.assertEqual(traverse_obj(_TEST_DATA, {0: 'dict'}, default=Ellipsis), {0: {}}, - msg='do not remove empty values when dict key and a default') - self.assertEqual(traverse_obj(_TEST_DATA, {0: ('dict', Ellipsis)}), {0: []}, - msg='if branch in dict key not successful, return `[]`') + msg='use `default` if key fails and `default`') + self.assertEqual(traverse_obj(_TEST_DATA, {0: 'dict'}), {}, + msg='remove empty values when dict key') + self.assertEqual(traverse_obj(_TEST_DATA, {0: 'dict'}, default=Ellipsis), {0: Ellipsis}, + msg='use `default` when dict key and `default`') + self.assertEqual(traverse_obj(_TEST_DATA, {0: {0: 'fail'}}), {}, + msg='remove empty values when nested dict key fails') + self.assertEqual(traverse_obj(None, {0: 'fail'}), {}, + msg='default to dict if pruned') + self.assertEqual(traverse_obj(None, {0: 'fail'}, default=Ellipsis), {0: Ellipsis}, + msg='default to dict if pruned and default is given') + self.assertEqual(traverse_obj(_TEST_DATA, {0: {0: 'fail'}}, default=Ellipsis), {0: {0: Ellipsis}}, + msg='use nested `default` when nested dict key fails and `default`') + self.assertEqual(traverse_obj(_TEST_DATA, {0: ('dict', Ellipsis)}), {}, + msg='remove key if branch in dict key not successful') # Testing default parameter behavior _DEFAULT_DATA = {'None': None, 'int': 0, 'list': []} @@ -1691,20 +1740,55 @@ Line 1 msg='if branched but not successful return `[]`, not `default`') self.assertEqual(traverse_obj(_DEFAULT_DATA, ('list', Ellipsis)), [], msg='if branched but object is empty return `[]`, not `default`') + self.assertEqual(traverse_obj(None, Ellipsis), [], + msg='if branched but object is `None` return `[]`, not `default`') + self.assertEqual(traverse_obj({0: None}, (0, Ellipsis)), [], + msg='if branched but state is `None` return `[]`, not `default`') + + branching_paths = [ + ('fail', Ellipsis), + (Ellipsis, 'fail'), + 100 * ('fail',) + (Ellipsis,), + (Ellipsis,) + 100 * ('fail',), + ] + for branching_path in branching_paths: + self.assertEqual(traverse_obj({}, branching_path), [], + msg='if branched but state is `None`, return `[]` (not `default`)') + self.assertEqual(traverse_obj({}, 'fail', branching_path), [], + msg='if branching in last alternative and previous did not match, return `[]` (not `default`)') + self.assertEqual(traverse_obj({0: 'x'}, 0, branching_path), 'x', + msg='if branching in last alternative and previous did match, return single value') + self.assertEqual(traverse_obj({0: 'x'}, branching_path, 0), 'x', + msg='if branching in first alternative and non-branching path does match, return single value') + self.assertEqual(traverse_obj({}, branching_path, 'fail'), None, + msg='if branching in first alternative and non-branching path does not match, return `default`') # Testing expected_type behavior _EXPECTED_TYPE_DATA = {'str': 'str', 'int': 0} - self.assertEqual(traverse_obj(_EXPECTED_TYPE_DATA, 'str', expected_type=compat_str), 'str', - msg='accept matching `expected_type` type') - self.assertEqual(traverse_obj(_EXPECTED_TYPE_DATA, 'str', expected_type=int), None, - msg='reject non matching `expected_type` type') - self.assertEqual(traverse_obj(_EXPECTED_TYPE_DATA, 'int', expected_type=lambda x: compat_str(x)), '0', - msg='transform type using type function') - self.assertEqual(traverse_obj(_EXPECTED_TYPE_DATA, 'str', - expected_type=lambda _: 1 / 0), None, - msg='wrap expected_type function in try_call') - self.assertEqual(traverse_obj(_EXPECTED_TYPE_DATA, Ellipsis, expected_type=compat_str), ['str'], - msg='eliminate items that expected_type fails on') + self.assertEqual(traverse_obj(_EXPECTED_TYPE_DATA, 'str', expected_type=str), + 'str', msg='accept matching `expected_type` type') + self.assertEqual(traverse_obj(_EXPECTED_TYPE_DATA, 'str', expected_type=int), + None, msg='reject non matching `expected_type` type') + self.assertEqual(traverse_obj(_EXPECTED_TYPE_DATA, 'int', expected_type=lambda x: str(x)), + '0', msg='transform type using type function') + self.assertEqual(traverse_obj(_EXPECTED_TYPE_DATA, 'str', expected_type=lambda _: 1 / 0), + None, msg='wrap expected_type function in try_call') + self.assertEqual(traverse_obj(_EXPECTED_TYPE_DATA, Ellipsis, expected_type=str), + ['str'], msg='eliminate items that expected_type fails on') + self.assertEqual(traverse_obj(_TEST_DATA, {0: 100, 1: 1.2}, expected_type=int), + {0: 100}, msg='type as expected_type should filter dict values') + self.assertEqual(traverse_obj(_TEST_DATA, {0: 100, 1: 1.2, 2: 'None'}, expected_type=str_or_none), + {0: '100', 1: '1.2'}, msg='function as expected_type should transform dict values') + self.assertEqual(traverse_obj(_TEST_DATA, ({0: 1.2}, 0, {int_or_none}), expected_type=int), + 1, msg='expected_type should not filter non final dict values') + self.assertEqual(traverse_obj(_TEST_DATA, {0: {0: 100, 1: 'str'}}, expected_type=int), + {0: {0: 100}}, msg='expected_type should transform deep dict values') + self.assertEqual(traverse_obj(_TEST_DATA, [({0: '...'}, {0: '...'})], expected_type=type(Ellipsis)), + [{0: Ellipsis}, {0: Ellipsis}], msg='expected_type should transform branched dict values') + self.assertEqual(traverse_obj({1: {3: 4}}, [(1, 2), 3], expected_type=int), + [4], msg='expected_type regression for type matching in tuple branching') + self.assertEqual(traverse_obj(_TEST_DATA, ['data', Ellipsis], expected_type=int), + [], msg='expected_type regression for type matching in dict result') # Test get_all behavior _GET_ALL_DATA = {'key': [0, 1, 2]} @@ -1749,14 +1833,23 @@ Line 1 _traverse_string=True), '.', msg='traverse into converted data if `traverse_string`') self.assertEqual(traverse_obj(_TRAVERSE_STRING_DATA, ('str', Ellipsis), - _traverse_string=True), list('str'), - msg='`...` branching into string should result in list') + _traverse_string=True), 'str', + msg='`...` should result in string (same value) if `traverse_string`') + self.assertEqual(traverse_obj(_TRAVERSE_STRING_DATA, ('str', slice(0, None, 2)), + _traverse_string=True), 'sr', + msg='`slice` should result in string if `traverse_string`') + self.assertEqual(traverse_obj(_TRAVERSE_STRING_DATA, ('str', lambda i, v: i or v == "s"), + _traverse_string=True), 'str', + msg='function should result in string if `traverse_string`') self.assertEqual(traverse_obj(_TRAVERSE_STRING_DATA, ('str', (0, 2)), _traverse_string=True), ['s', 'r'], - msg='branching into string should result in list') - self.assertEqual(traverse_obj(_TRAVERSE_STRING_DATA, ('str', lambda _, x: x), - _traverse_string=True), list('str'), - msg='function branching into string should result in list') + msg='branching should result in list if `traverse_string`') + self.assertEqual(traverse_obj({}, (0, Ellipsis), _traverse_string=True), [], + msg='branching should result in list if `traverse_string`') + self.assertEqual(traverse_obj({}, (0, lambda x, y: True), _traverse_string=True), [], + msg='branching should result in list if `traverse_string`') + self.assertEqual(traverse_obj({}, (0, slice(1)), _traverse_string=True), [], + msg='branching should result in list if `traverse_string`') # Test is_user_input behavior _IS_USER_INPUT_DATA = {'range8': list(range(8))} @@ -1793,6 +1886,8 @@ Line 1 msg='failing str key on a `re.Match` should return `default`') self.assertEqual(traverse_obj(mobj, 8), None, msg='failing int key on a `re.Match` should return `default`') + self.assertEqual(traverse_obj(mobj, lambda k, _: k in (0, 'group')), ['0123', '3'], + msg='function on a `re.Match` should give group name as well') def test_get_first(self): self.assertEqual(get_first([{'a': None}, {'a': 'spam'}], 'a'), 'spam') diff --git a/youtube_dl/utils.py b/youtube_dl/utils.py index 83f67bd95..dbdbe5f59 100644 --- a/youtube_dl/utils.py +++ b/youtube_dl/utils.py @@ -16,6 +16,7 @@ import email.header import errno import functools import gzip +import inspect import io import itertools import json @@ -3881,7 +3882,7 @@ def detect_exe_version(output, version_re=None, unrecognized='present'): return unrecognized -class LazyList(compat_collections_abc.Sequence): +class LazyList(compat_collections_abc.Iterable): """Lazy immutable list from an iterable Note that slices of a LazyList are lists and not LazyList""" @@ -4223,10 +4224,16 @@ def multipart_encode(data, boundary=None): return out, content_type -def variadic(x, allowed_types=(compat_str, bytes, dict)): - if not isinstance(allowed_types, tuple) and isinstance(allowed_types, compat_collections_abc.Iterable): +def is_iterable_like(x, allowed_types=compat_collections_abc.Iterable, blocked_types=NO_DEFAULT): + if blocked_types is NO_DEFAULT: + blocked_types = (compat_str, bytes, compat_collections_abc.Mapping) + return isinstance(x, allowed_types) and not isinstance(x, blocked_types) + + +def variadic(x, allowed_types=NO_DEFAULT): + if isinstance(allowed_types, compat_collections_abc.Iterable): allowed_types = tuple(allowed_types) - return x if isinstance(x, compat_collections_abc.Iterable) and not isinstance(x, allowed_types) else (x,) + return x if is_iterable_like(x, blocked_types=allowed_types) else (x,) def dict_get(d, key_or_keys, default=None, skip_false_values=True): @@ -5993,7 +6000,7 @@ def clean_podcast_url(url): def traverse_obj(obj, *paths, **kwargs): """ - Safely traverse nested `dict`s and `Sequence`s + Safely traverse nested `dict`s and `Iterable`s >>> obj = [{}, {"key": "value"}] >>> traverse_obj(obj, (1, "key")) @@ -6001,14 +6008,17 @@ def traverse_obj(obj, *paths, **kwargs): Each of the provided `paths` is tested and the first producing a valid result will be returned. The next path will also be tested if the path branched but no results could be found. - Supported values for traversal are `Mapping`, `Sequence` and `re.Match`. - A value of None is treated as the absence of a value. + Supported values for traversal are `Mapping`, `Iterable` and `re.Match`. + Unhelpful values (`{}`, `None`) are treated as the absence of a value and discarded. The paths will be wrapped in `variadic`, so that `'key'` is conveniently the same as `('key', )`. The keys in the path can be one of: - `None`: Return the current object. - - `str`/`int`: Return `obj[key]`. For `re.Match, return `obj.group(key)`. + - `set`: Requires the only item in the set to be a type or function, + like `{type}`/`{func}`. If a `type`, returns only values + of this type. If a function, returns `func(obj)`. + - `str`/`int`: Return `obj[key]`. For `re.Match`, return `obj.group(key)`. - `slice`: Branch out and return all values in `obj[key]`. - `Ellipsis`: Branch out and return a list of all values. - `tuple`/`list`: Branch out and return a list of all matching values. @@ -6016,6 +6026,9 @@ def traverse_obj(obj, *paths, **kwargs): - `function`: Branch out and return values filtered by the function. Read as: `[value for key, value in obj if function(key, value)]`. For `Sequence`s, `key` is the index of the value. + For `Iterable`s, `key` is the enumeration count of the value. + For `re.Match`es, `key` is the group number (0 = full match) + as well as additionally any group names, if given. - `dict` Transform the current object and return a matching dict. Read as: `{key: traverse_obj(obj, path) for key, path in dct.items()}`. @@ -6024,8 +6037,12 @@ def traverse_obj(obj, *paths, **kwargs): @params paths Paths which to traverse by. Keyword arguments: @param default Value to return if the paths do not match. + If the last key in the path is a `dict`, it will apply to each value inside + the dict instead, depth first. Try to avoid if using nested `dict` keys. @param expected_type If a `type`, only accept final values of this type. If any other callable, try to call the function on each result. + If the last key in the path is a `dict`, it will apply to each value inside + the dict instead, recursively. This does respect branching paths. @param get_all If `False`, return the first matching result, otherwise all matching ones. @param casesense If `False`, consider string dictionary keys as case insensitive. @@ -6036,12 +6053,15 @@ def traverse_obj(obj, *paths, **kwargs): @param _traverse_string Whether to traverse into objects as strings. If `True`, any non-compatible object will first be converted into a string and then traversed into. + The return value of that path will be a string instead, + not respecting any further branching. @returns The result of the object traversal. If successful, `get_all=True`, and the path branches at least once, then a list of results is returned instead. A list is always returned if the last path branches and no `default` is given. + If a path ends on a `dict` that result will always be a `dict`. """ # parameter defaults @@ -6055,7 +6075,6 @@ def traverse_obj(obj, *paths, **kwargs): # instant compat str = compat_str - is_sequence = lambda x: isinstance(x, compat_collections_abc.Sequence) and not isinstance(x, (str, bytes)) casefold = lambda k: compat_casefold(k) if isinstance(k, str) else k if isinstance(expected_type, type): @@ -6063,128 +6082,180 @@ def traverse_obj(obj, *paths, **kwargs): else: type_test = lambda val: try_call(expected_type or IDENTITY, args=(val,)) + def lookup_or_none(v, k, getter=None): + try: + return getter(v, k) if getter else v[k] + except IndexError: + return None + def from_iterable(iterables): # chain.from_iterable(['ABC', 'DEF']) --> A B C D E F for it in iterables: for item in it: yield item - def apply_key(key, obj): - if obj is None: - return + def apply_key(key, obj, is_last): + branching = False + + if obj is None and _traverse_string: + if key is Ellipsis or callable(key) or isinstance(key, slice): + branching = True + result = () + else: + result = None elif key is None: - yield obj + result = obj + + elif isinstance(key, set): + assert len(key) == 1, 'Set should only be used to wrap a single item' + item = next(iter(key)) + if isinstance(item, type): + result = obj if isinstance(obj, item) else None + else: + result = try_call(item, args=(obj,)) elif isinstance(key, (list, tuple)): - for branch in key: - _, result = apply_path(obj, branch) - for item in result: - yield item + branching = True + result = from_iterable( + apply_path(obj, branch, is_last)[0] for branch in key) elif key is Ellipsis: - result = [] + branching = True if isinstance(obj, compat_collections_abc.Mapping): result = obj.values() - elif is_sequence(obj): + elif is_iterable_like(obj): result = obj elif isinstance(obj, compat_re_Match): result = obj.groups() elif _traverse_string: + branching = False result = str(obj) - for item in result: - yield item + else: + result = () elif callable(key): - if is_sequence(obj): - iter_obj = enumerate(obj) - elif isinstance(obj, compat_collections_abc.Mapping): + branching = True + if isinstance(obj, compat_collections_abc.Mapping): iter_obj = obj.items() + elif is_iterable_like(obj): + iter_obj = enumerate(obj) elif isinstance(obj, compat_re_Match): - iter_obj = enumerate(itertools.chain([obj.group()], obj.groups())) + iter_obj = itertools.chain( + enumerate(itertools.chain((obj.group(),), obj.groups())), + obj.groupdict().items()) elif _traverse_string: + branching = False iter_obj = enumerate(str(obj)) else: - return - for item in (v for k, v in iter_obj if try_call(key, args=(k, v))): - yield item + iter_obj = () + + result = (v for k, v in iter_obj if try_call(key, args=(k, v))) + if not branching: # string traversal + result = ''.join(result) elif isinstance(key, dict): - iter_obj = ((k, _traverse_obj(obj, v)) for k, v in key.items()) - yield dict((k, v if v is not None else default) for k, v in iter_obj - if v is not None or default is not NO_DEFAULT) + iter_obj = ((k, _traverse_obj(obj, v, False, is_last)) for k, v in key.items()) + result = dict((k, v if v is not None else default) for k, v in iter_obj + if v is not None or default is not NO_DEFAULT) or None elif isinstance(obj, compat_collections_abc.Mapping): - yield (obj.get(key) if casesense or (key in obj) - else next((v for k, v in obj.items() if casefold(k) == key), None)) + result = (try_call(obj.get, args=(key,)) + if casesense or try_call(obj.__contains__, args=(key,)) + else next((v for k, v in obj.items() if casefold(k) == key), None)) elif isinstance(obj, compat_re_Match): + result = None if isinstance(key, int) or casesense: - try: - yield obj.group(key) - return - except IndexError: - pass - if not isinstance(key, str): - return + result = lookup_or_none(obj, key, getter=compat_re_Match.group) - yield next((v for k, v in obj.groupdict().items() if casefold(k) == key), None) + elif isinstance(key, str): + result = next((v for k, v in obj.groupdict().items() + if casefold(k) == key), None) else: - if _is_user_input: - key = (int_or_none(key) if ':' not in key - else slice(*map(int_or_none, key.split(':')))) + result = None + if isinstance(key, (int, slice)): + if is_iterable_like(obj, compat_collections_abc.Sequence): + branching = isinstance(key, slice) + result = lookup_or_none(obj, key) + elif _traverse_string: + result = lookup_or_none(str(obj), key) - if not isinstance(key, (int, slice)): - return + return branching, result if branching else (result,) - if not is_sequence(obj): - if not _traverse_string: - return - obj = str(obj) + def lazy_last(iterable): + iterator = iter(iterable) + prev = next(iterator, NO_DEFAULT) + if prev is NO_DEFAULT: + return - try: - yield obj[key] - except IndexError: - pass + for item in iterator: + yield False, prev + prev = item - def apply_path(start_obj, path): + yield True, prev + + def apply_path(start_obj, path, test_type): objs = (start_obj,) has_branched = False - for key in variadic(path): - if _is_user_input and key == ':': - key = Ellipsis + key = None + for last, key in lazy_last(variadic(path, (str, bytes, dict, set))): + if _is_user_input and isinstance(key, str): + if key == ':': + key = Ellipsis + elif ':' in key: + key = slice(*map(int_or_none, key.split(':'))) + elif int_or_none(key) is not None: + key = int(key) if not casesense and isinstance(key, str): key = compat_casefold(key) - if key is Ellipsis or isinstance(key, (list, tuple)) or callable(key): - has_branched = True + if __debug__ and callable(key): + # Verify function signature + inspect.getcallargs(key, None, None) - key_func = functools.partial(apply_key, key) - objs = from_iterable(map(key_func, objs)) + new_objs = [] + for obj in objs: + branching, results = apply_key(key, obj, last) + has_branched |= branching + new_objs.append(results) - return has_branched, objs + objs = from_iterable(new_objs) - def _traverse_obj(obj, path, use_list=True): - has_branched, results = apply_path(obj, path) - results = LazyList(x for x in map(type_test, results) if x is not None) + if test_type and not isinstance(key, (dict, list, tuple)): + objs = map(type_test, objs) + + return objs, has_branched, isinstance(key, dict) + + def _traverse_obj(obj, path, allow_empty, test_type): + results, has_branched, is_dict = apply_path(obj, path, test_type) + results = LazyList(x for x in results if x not in (None, {})) if get_all and has_branched: - return results.exhaust() if results or use_list else None + if results: + return results.exhaust() + if allow_empty: + return [] if default is NO_DEFAULT else default + return None - return results[0] if results else None + return results[0] if results else {} if allow_empty and is_dict else None for index, path in enumerate(paths, 1): - use_list = default is NO_DEFAULT and index == len(paths) - result = _traverse_obj(obj, path, use_list) + result = _traverse_obj(obj, path, index == len(paths), True) if result is not None: return result return None if default is NO_DEFAULT else default +def T(x): + """ For use in yt-dl instead of {type} or set((type,)) """ + return set((x,)) + + def get_first(obj, keys, **kwargs): return traverse_obj(obj, (Ellipsis,) + tuple(variadic(keys)), get_all=False, **kwargs)