Coverage for src \ nuremics \ core \ utils.py: 92%
74 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-02-13 18:48 +0100
« prev ^ index » next coverage.py v7.13.4, created at 2026-02-13 18:48 +0100
1import ast
2import inspect
3import textwrap
4from typing import Any, Callable, Optional, Type, Union
6import attrs
7import numpy as np
def convert_value(
    value: object,
) -> Optional[Union[bool, int, float, str, object]]:
    """Convert *value* to the matching built-in Python scalar type.

    The string ``"NA"`` is mapped to ``None``. Booleans, integers and floats
    -- including NumPy scalars of any width -- are converted to ``bool``,
    ``int`` and ``float`` respectively, and ``str`` subclasses (such as
    ``numpy.str_``) to plain ``str``. Anything else is returned unchanged.

    Args:
        value: The value to convert.

    Returns:
        ``None`` for ``"NA"``, a built-in ``bool``/``int``/``float``/``str``
        for recognised scalars, otherwise *value* itself.
    """
    # Only compare against "NA" when value is a string: comparing a NumPy
    # array with a string yields an array whose truth value is ambiguous.
    if isinstance(value, str) and value == "NA":
        return None
    # bool must be tested before int because bool is a subclass of int.
    if isinstance(value, (bool, np.bool_)):
        return bool(value)
    # np.integer covers every integer width (np.int32 is NumPy 1.x's default
    # integer on Windows), not only np.int64. np.timedelta64 subclasses
    # np.signedinteger but is a duration, so it is left untouched.
    if isinstance(value, (int, np.integer)) and not isinstance(value, np.timedelta64):
        return int(value)
    # np.floating likewise covers float16/float32/float64.
    if isinstance(value, (float, np.floating)):
        return float(value)
    if isinstance(value, str):
        return str(value)

    return value
def concat_lists_unique(
    list1: list,
    list2: list,
) -> list:
    """Concatenate two lists and drop duplicates, keeping first-seen order.

    Elements must be hashable. When two elements compare equal, the one that
    appears first (scanning *list1* then *list2*) is the one kept.
    """
    seen = {}
    for item in list1 + list2:
        # setdefault only inserts the key the first time it is encountered,
        # so the dict's insertion order is the first-occurrence order.
        seen.setdefault(item, None)
    return list(seen)
def get_self_method_calls(
    cls: Type,
    method_name: str = "__call__",
) -> list:
    """Return the names of ``self.<name>(...)`` calls made inside a method.

    The method's source is parsed and every call whose callee is an
    attribute looked up directly on the name ``self`` is collected, in AST
    traversal order and including duplicates (calls nested in arguments are
    found too).

    Args:
        cls: The class to inspect.
        method_name: Name of the method whose body is scanned.

    Returns:
        The called attribute names. Returns an empty list when *cls* has no
        attribute *method_name*, or when that attribute has no Python source
        -- e.g. the built-in ``type.__call__`` that ``getattr`` finds when
        *cls* does not define ``__call__`` itself.
    """
    method = getattr(cls, method_name, None)
    if method is None:
        return []

    try:
        source = inspect.getsource(method)
    except TypeError:
        # Built-in callables (slot wrappers, method-wrappers) have no Python
        # source; treat them the same as a missing method.
        return []
    tree = ast.parse(textwrap.dedent(source))

    called_methods: list[str] = []

    class SelfCallVisitor(ast.NodeVisitor):
        def visit_Call(
            self,
            node: ast.Call,
        ) -> None:
            func = node.func
            if (
                isinstance(func, ast.Attribute)
                and isinstance(func.value, ast.Name)
                and func.value.id == "self"
            ):
                called_methods.append(func.attr)
            # Keep descending so calls used as arguments are visited too.
            self.generic_visit(node)

    SelfCallVisitor().visit(tree)
    return called_methods
64# From ChatGPT
def only_function_calls(
    method: Callable[..., Any],
    allowed_methods: list[str],
) -> bool:
    """
    Check that the method's body consists solely of permitted calls.

    Every statement must be a bare call expression, and each call must be
    either ``super().__call__(...)`` or ``self.<name>(...)`` with ``<name>``
    in *allowed_methods*. A leading docstring is ignored.

    Args:
        method: Function or method whose source code is inspected.
        allowed_methods: Method names that may be called on ``self``.

    Returns:
        True if every statement is a permitted call; False otherwise,
        including when the parsed source does not start with a plain
        (non-async) ``def``.
    """
    # Get, dedent and parse the source code
    tree = ast.parse(textwrap.dedent(inspect.getsource(method)))

    # Expect a FunctionDef node at top level
    func_def = tree.body[0]
    if not isinstance(func_def, ast.FunctionDef):
        return False

    body = func_def.body
    # A docstring is an Expr(Constant) statement rather than a call; skip it
    # so that documenting the method does not make it fail the check.
    if ast.get_docstring(func_def, clean=False) is not None:
        body = body[1:]

    for stmt in body:
        # Each statement must be a simple expression (Expr) containing a Call
        if not isinstance(stmt, ast.Expr) or not isinstance(stmt.value, ast.Call):
            return False

        func = stmt.value.func

        # Allow super().__call__()
        if (
            isinstance(func, ast.Attribute)
            and isinstance(func.value, ast.Call)
            and isinstance(func.value.func, ast.Name)
            and func.value.func.id == "super"
            and func.attr == "__call__"
        ):
            continue

        # Allow self.<allowed_method>()
        if (
            isinstance(func, ast.Attribute)
            and isinstance(func.value, ast.Name)
            and func.value.id == "self"
            and func.attr in allowed_methods
        ):
            continue

        # If it's neither of the above, reject
        return False

    return True
def extract_inputs_and_types(
    obj: object,
) -> dict:
    """Map attrs fields flagged as inputs to their declared types.

    Scans ``attrs.fields`` of *obj*'s class and keeps every field whose
    metadata has a truthy ``"input"`` entry.

    Returns:
        A dict of field name -> ``field.type``, in the order reported by
        ``attrs.fields``.
    """
    return {
        fld.name: fld.type
        for fld in attrs.fields(obj.__class__)
        if fld.metadata.get("input", False)
    }
def extract_analysis(
    obj: object,
) -> list:
    """List the names of attrs fields flagged for analysis.

    Scans ``attrs.fields`` of *obj*'s class and keeps every field whose
    metadata has a truthy ``"analysis"`` entry, in the order reported by
    ``attrs.fields``.
    """
    return [
        fld.name
        for fld in attrs.fields(obj.__class__)
        if fld.metadata.get("analysis", False)
    ]
def extract_outputs(
    obj: object,
) -> list:
    """List the names of attrs fields flagged as outputs.

    Scans ``attrs.fields`` of *obj*'s class and keeps every field whose
    metadata has a truthy ``"output"`` entry, in the order reported by
    ``attrs.fields``.
    """
    return [
        fld.name
        for fld in attrs.fields(obj.__class__)
        if fld.metadata.get("output", False)
    ]