diff --git a/podman_compose.py b/podman_compose.py index ef21b5d..97fb60e 100755 --- a/podman_compose.py +++ b/podman_compose.py @@ -63,6 +63,8 @@ def is_list(list_object): def is_list_of_str(list_of_str_object): if is_list(list_of_str_object): + if len(list_of_str_object) == 0: + return False for element in list_of_str_object: if not is_str(element): return False @@ -1300,6 +1302,20 @@ def clone(value): return value.copy() if is_list(value) or is_dict(value) else value +def clone_shell_value(target, key, value): + if is_str(value): + target[key] = shlex.split(value) + else: + target[key] = clone(value) + + +def check_shell_value_type(key, value): + if not is_str(value) and not is_list_of_str(value): + raise ValueError( + f"can't merge value of [{key}]: must be a string or a list of strings" + ) + + def rec_merge_one(target, source): """ update target from source recursively @@ -1308,28 +1324,26 @@ def rec_merge_one(target, source): for key, value in source.items(): if key in target: continue - target[key] = clone(value) + if key in ("command", "entrypoint"): + check_shell_value_type(key, value) + clone_shell_value(target, key, value) + else: + target[key] = clone(value) done.add(key) for key, value in target.items(): if key in done: continue + if key in ("command", "entrypoint"): + if key not in source: + check_shell_value_type(key, value) + clone_shell_value(target, key, value) + else: + check_shell_value_type(key, source[key]) + clone_shell_value(target, key, source[key]) + continue if key not in source: continue value2 = source[key] - if key in ("command", "entrypoint"): - if not is_str(value) and not is_list_of_str(value): - raise ValueError( - f"can't merge value of [{key}]: must be a string or a list of strings" - ) - if not is_str(value2) and not is_list_of_str(value2): - raise ValueError( - f"can't merge value of [{key}]: must be a string or a list of strings" - ) - if is_str(value2): - target[key] = value2.split(" ") - else: - target[key] = clone(value2) - continue if not isinstance(value2, type(value)): value_type = type(value) value2_type = type(value2) diff --git a/pytests/test_is_list_of_str.py b/pytests/test_is_list_of_str.py new file mode 100644 index 0000000..e5b8d7c --- /dev/null +++ b/pytests/test_is_list_of_str.py @@ -0,0 +1,10 @@ +from podman_compose import is_list_of_str + + +def test_is_list_of_str(): + assert is_list_of_str(["foo", "bar"]) + assert not is_list_of_str(["foo", 1]) + assert not is_list_of_str("foo") + assert not is_list_of_str([]) + assert not is_list_of_str(1) + assert not is_list_of_str(None) diff --git a/pytests/test_rec_merge_one_cmd_ent.py b/pytests/test_rec_merge_one_cmd_ent.py new file mode 100644 index 0000000..0ef3717 --- /dev/null +++ b/pytests/test_rec_merge_one_cmd_ent.py @@ -0,0 +1,82 @@ +import copy + +import pytest +from podman_compose import rec_merge_one + + +test_keys = ["command", "entrypoint"] +test_cases = [ + ({}, {"$$$": "sh"}, {"$$$": ["sh"]}), + ({"$$$": "sh"}, {}, {"$$$": ["sh"]}), + ({"$$$": "sh-1"}, {"$$$": "sh-2"}, {"$$$": ["sh-2"]}), + ({"$$$": ["sh-1"]}, {"$$$": "sh-2"}, {"$$$": ["sh-2"]}), + ({"$$$": "sh-1"}, {"$$$": ["sh-2"]}, {"$$$": ["sh-2"]}), + ({"$$$": "sh-1"}, {"$$$": ["sh-2", "sh-3"]}, {"$$$": ["sh-2", "sh-3"]}), + ({"$$$": ["sh-1"]}, {"$$$": ["sh-2", "sh-3"]}, {"$$$": ["sh-2", "sh-3"]}), + ({"$$$": ["sh-1", "sh-2"]}, {"$$$": ["sh-3", "sh-4"]}, {"$$$": ["sh-3", "sh-4"]}), + ({}, {"$$$": ["sh-3", "sh 4"]}, {"$$$": ["sh-3", "sh 4"]}), + ({"$$$": "sleep infinity"}, {"$$$": "sh"}, {"$$$": ["sh"]}), + ({"$$$": "sh"}, {"$$$": "sleep infinity"}, {"$$$": ["sleep", "infinity"]}), + ( + {}, + {"$$$": "bash -c 'sleep infinity'"}, + {"$$$": ["bash", "-c", "sleep infinity"]}, + ), +] +test_cases_with_exceptions = [ + ({}, {"$$$": 1234}, ValueError), + ({"$$$": 1234}, {}, ValueError), + ({"$$$": 1234}, {"$$$": 1234}, ValueError), + ({"$$$": {}}, {}, ValueError), + ({}, {"$$$": {}}, ValueError), + ({"$$$": {}}, {"$$$": {}}, ValueError), + ({"$$$": []}, {}, ValueError), + ({}, {"$$$": []}, ValueError), + ({"$$$": []}, {"$$$": []}, ValueError), +] + + +def template_to_expression(base, override, expected, key): + base_copy = copy.deepcopy(base) + override_copy = copy.deepcopy(override) + expected_copy = copy.deepcopy(expected) + + expected_copy[key] = expected_copy.pop("$$$") + if "$$$" in base: + base_copy[key] = base_copy.pop("$$$") + if "$$$" in override: + override_copy[key] = override_copy.pop("$$$") + return base_copy, override_copy, expected_copy + + +def test_rec_merge_one_for_command_and_entrypoint(): + for base_template, override_template, expected_template in test_cases: + for key in test_keys: + base, override, expected = template_to_expression( + base_template, override_template, expected_template, key + ) + + base = rec_merge_one(base, override) + test_result = expected == base + if not test_result: + print("base_template: ", base_template) + print("override_template: ", override_template) + print("expected: ", expected) + print("actual: ", base) + assert test_result + + for ( + base_template, + override_template, + expected_exception, + ) in test_cases_with_exceptions: + for key in test_keys: + base, override, expected = template_to_expression( + base_template, override_template, {"$$$": ""}, key + ) + + with pytest.raises(expected_exception): + base = rec_merge_one(base, override) + print("base_template: ", base_template) + print("override_template: ", override_template) + print("expected: ", expected_exception)