Coverage for src \ nuremics \ core \ utils.py: 92%

74 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-02-13 18:48 +0100

1import ast 

2import inspect 

3import textwrap 

4from typing import Any, Callable, Optional, Type, Union 

5 

6import attrs 

7import numpy as np 

8 

9 

def convert_value(
    value: object,
) -> Optional[Union[bool, int, float, str, object]]:
    """Normalize a raw (possibly NumPy) scalar into a native Python value.

    The string ``"NA"`` is treated as a missing value and mapped to ``None``.
    NumPy booleans, integers and floating-point scalars of any width are
    converted to the built-in ``bool`` / ``int`` / ``float``; strings
    (including ``np.str_``) are returned as plain ``str``. Anything else
    (arrays, lists, ``None``, ``np.timedelta64``, ...) is returned unchanged.

    Parameters
    ----------
    value : object
        Value to normalize.

    Returns
    -------
    bool | int | float | str | object | None
        The normalized value.
    """
    # Check the type before comparing: ``value == "NA"`` on an array-like
    # yields an element-wise result whose truth value is ambiguous.
    if isinstance(value, str) and value == "NA":
        return None
    # bool must be tested before int because bool is a subclass of int.
    if isinstance(value, (bool, np.bool_)):
        return bool(value)
    # np.timedelta64 subclasses np.signedinteger; keep it (and its unit)
    # intact instead of collapsing it to a bare int.
    if isinstance(value, np.timedelta64):
        return value
    if isinstance(value, (int, np.integer)):
        return int(value)
    if isinstance(value, (float, np.floating)):
        return float(value)
    if isinstance(value, str):
        return str(value)

    return value

26 

27 

def concat_lists_unique(
    list1: list,
    list2: list,
) -> list:
    """Concatenate two lists and drop duplicates, keeping first-seen order.

    Items must be hashable. When several items compare equal, the first one
    encountered (scanning ``list1`` and then ``list2``) is the one kept.
    """
    combined = list1 + list2
    first_seen = {item: None for item in combined}
    return [*first_seen]

34 

35 

def get_self_method_calls(
    cls: Type,
    method_name: str = "__call__",
) -> list:
    """Return the names of ``self.<name>(...)`` calls found in a method's source.

    The method ``method_name`` is looked up on ``cls`` and its source code is
    parsed. Every call whose callee is an attribute of the bare name ``self``
    is recorded, in the order the AST visitor reaches it: an outer call comes
    before the calls nested in its arguments. Duplicates are kept.

    Parameters
    ----------
    cls : type
        Class on which to look up the method.
    method_name : str, optional
        Name of the method to inspect. Defaults to ``"__call__"``.

    Returns
    -------
    list of str
        Called attribute names. Empty when the attribute does not exist or
        has no Python source (e.g. a built-in).
    """
    method = getattr(cls, method_name, None)
    if method is None:
        return []

    try:
        source = inspect.getsource(method)
    except TypeError:
        # Built-ins have no Python source. Notably, every class is callable,
        # so getattr(cls, "__call__") finds the metaclass' built-in
        # ``type.__call__`` wrapper when ``cls`` defines no ``__call__``.
        return []
    tree = ast.parse(textwrap.dedent(source))

    called_methods = []

    class SelfCallVisitor(ast.NodeVisitor):
        """Collect attribute names of calls made directly on ``self``."""

        def visit_Call(self, node: ast.Call) -> None:
            func = node.func
            if (
                isinstance(func, ast.Attribute)
                and isinstance(func.value, ast.Name)
                and func.value.id == "self"
            ):
                called_methods.append(func.attr)
            # Keep descending so calls nested in arguments are found too.
            self.generic_visit(node)

    SelfCallVisitor().visit(tree)
    return called_methods

62 

63 

64# From ChatGPT 

def only_function_calls(
    method: Callable[..., Any],
    allowed_methods: list[str],
) -> bool:
    """Check that a method body consists solely of permitted call statements.

    Every statement in the body of ``method`` must be a bare call expression
    of one of these forms:

    * ``super().__call__(...)``
    * ``self.<name>(...)`` where ``<name>`` is in ``allowed_methods``

    A leading docstring is ignored. Any other statement (assignment, control
    flow, a call to any other callable, ``pass``, ...) makes the check fail.

    Parameters
    ----------
    method : callable
        Function or method whose source code is inspected.
    allowed_methods : list of str
        Names that may be called on ``self``.

    Returns
    -------
    bool
        ``True`` if every statement is a permitted call, ``False`` otherwise,
        including when the source does not start with a plain ``def``
        (e.g. ``async def``).
    """
    tree = ast.parse(textwrap.dedent(inspect.getsource(method)))

    # getsource() returns the whole definition; decorators are attached to
    # the FunctionDef node, which is therefore the first top-level node.
    func_def = tree.body[0]
    if not isinstance(func_def, ast.FunctionDef):
        return False

    allowed = set(allowed_methods)

    body = func_def.body
    # A docstring is an Expr(Constant), not a call; do not count it.
    if ast.get_docstring(func_def, clean=False) is not None:
        body = body[1:]

    for stmt in body:
        # Each statement must be an expression statement wrapping a call.
        if not isinstance(stmt, ast.Expr) or not isinstance(stmt.value, ast.Call):
            return False

        func = stmt.value.func
        # Both permitted forms call an attribute (``<owner>.<attr>(...)``).
        if not isinstance(func, ast.Attribute):
            return False
        owner = func.value

        # super().__call__(...)
        if (
            isinstance(owner, ast.Call)
            and isinstance(owner.func, ast.Name)
            and owner.func.id == "super"
            and func.attr == "__call__"
        ):
            continue

        # self.<allowed_method>(...)
        if isinstance(owner, ast.Name) and owner.id == "self" and func.attr in allowed:
            continue

        return False

    return True

112 

113 

def extract_inputs_and_types(
    obj: object,
) -> dict:
    """Map the name of each field flagged as an input to its recorded type.

    ``attrs.fields`` is called on the class of ``obj`` (so ``obj`` is
    expected to be an instance of an attrs class). A field is included when
    its metadata has a truthy ``"input"`` entry; fields keep the order in
    which ``attrs.fields`` yields them, and the value is the field's ``type``.
    """
    return {
        fld.name: fld.type
        for fld in attrs.fields(obj.__class__)
        if fld.metadata.get("input", False)
    }

124 

125 

def extract_analysis(
    obj: object,
) -> list:
    """List the names of the fields flagged as analysis fields.

    ``attrs.fields`` is called on the class of ``obj`` (so ``obj`` is
    expected to be an instance of an attrs class). A field is included when
    its metadata has a truthy ``"analysis"`` entry, in the order in which
    ``attrs.fields`` yields them.
    """
    return [
        fld.name
        for fld in attrs.fields(obj.__class__)
        if fld.metadata.get("analysis", False)
    ]

136 

137 

def extract_outputs(
    obj: object,
) -> list:
    """List the names of the fields flagged as outputs.

    ``attrs.fields`` is called on the class of ``obj`` (so ``obj`` is
    expected to be an instance of an attrs class). A field is included when
    its metadata has a truthy ``"output"`` entry, in the order in which
    ``attrs.fields`` yields them.
    """
    return [
        fld.name
        for fld in attrs.fields(obj.__class__)
        if fld.metadata.get("output", False)
    ]