Skip to content

Commit 56a6465

Browse files
authored
fix: Harden resolve_import with new changes (#803)
# Motivation <!-- Why is this change necessary? --> # Content <!-- Please include a summary of the change --> # Testing <!-- How was the change tested? --> # Please check the following before marking your PR as ready for review - [ ] I have added tests for my changes - [ ] I have updated the documentation or added new documentation as needed
1 parent f5ec923 commit 56a6465

File tree

2 files changed

+159
-151
lines changed

2 files changed

+159
-151
lines changed

src/codegen/sdk/python/import_resolution.py

Lines changed: 99 additions & 95 deletions
Original file line numberDiff line numberDiff line change
@@ -85,73 +85,87 @@ def imported_exports(self) -> list[Exportable]:
8585
@noapidoc
8686
@reader
8787
def resolve_import(self, base_path: str | None = None, *, add_module_name: str | None = None) -> ImportResolution[PyFile] | None:
88-
base_path = base_path or self.ctx.projects[0].base_path or ""
89-
module_source = self.module.source if self.module else ""
90-
symbol_name = self.symbol_name.source if self.symbol_name else ""
91-
if add_module_name:
92-
module_source += f".{symbol_name}"
93-
symbol_name = add_module_name
94-
# If import is relative, convert to absolute path
95-
if module_source.startswith("."):
96-
module_source = self._relative_to_absolute_import(module_source)
97-
98-
# =====[ Check if we are importing an entire file ]=====
99-
if self.is_module_import():
100-
# covers `import a.b.c` case and `from a.b.c import *` case
101-
filepath = os.path.join(base_path, module_source.replace(".", "/") + ".py")
102-
else:
103-
# This is the case where you do:
104-
# `from a.b.c import foo`
105-
filepath = os.path.join(
106-
base_path,
107-
module_source.replace(".", "/") + "/" + symbol_name + ".py",
108-
)
109-
110-
# =====[ Check if we are importing an entire file with custom resolve path or sys.path enabled ]=====
111-
if len(self.ctx.config.import_resolution_paths) > 0 or self.ctx.config.py_resolve_syspath:
112-
# Handle resolve overrides first if both is set
113-
resolve_paths: list[str] = self.ctx.config.import_resolution_paths + (sys.path if self.ctx.config.py_resolve_syspath else [])
114-
if file := self._file_by_custom_resolve_paths(resolve_paths, filepath):
88+
try:
89+
base_path = base_path or self.ctx.projects[0].base_path or ""
90+
module_source = self.module.source if self.module else ""
91+
symbol_name = self.symbol_name.source if self.symbol_name else ""
92+
if add_module_name:
93+
module_source += f".{symbol_name}"
94+
symbol_name = add_module_name
95+
# If import is relative, convert to absolute path
96+
if module_source.startswith("."):
97+
module_source = self._relative_to_absolute_import(module_source)
98+
99+
# =====[ Check if we are importing an entire file ]=====
100+
if self.is_module_import():
101+
# covers `import a.b.c` case and `from a.b.c import *` case
102+
filepath = os.path.join(base_path, module_source.replace(".", "/") + ".py")
103+
else:
104+
# This is the case where you do:
105+
# `from a.b.c import foo`
106+
filepath = os.path.join(
107+
base_path,
108+
module_source.replace(".", "/") + "/" + symbol_name + ".py",
109+
)
110+
111+
# =====[ Check if we are importing an entire file with custom resolve path or sys.path enabled ]=====
112+
if len(self.ctx.config.import_resolution_paths) > 0 or self.ctx.config.py_resolve_syspath:
113+
# Handle resolve overrides first if both is set
114+
resolve_paths: list[str] = self.ctx.config.import_resolution_paths + (sys.path if self.ctx.config.py_resolve_syspath else [])
115+
if file := self._file_by_custom_resolve_paths(resolve_paths, filepath):
116+
return ImportResolution(from_file=file, symbol=None, imports_file=True)
117+
118+
# =====[ Default path ]=====
119+
if file := self.ctx.get_file(filepath):
120+
return ImportResolution(from_file=file, symbol=None, imports_file=True)
121+
122+
filepath = filepath.replace(".py", "/__init__.py")
123+
if file := self.ctx.get_file(filepath):
124+
# TODO - I think this is another edge case, due to `dao/__init__.py` etc.
125+
# You can't do `from a.b.c import foo` => `foo.utils.x` right now since `foo` is just a file...
115126
return ImportResolution(from_file=file, symbol=None, imports_file=True)
116127

117-
# =====[ Default path ]=====
118-
if file := self.ctx.get_file(filepath):
119-
return ImportResolution(from_file=file, symbol=None, imports_file=True)
120-
121-
filepath = filepath.replace(".py", "/__init__.py")
122-
if file := self.ctx.get_file(filepath):
123-
# TODO - I think this is another edge case, due to `dao/__init__.py` etc.
124-
# You can't do `from a.b.c import foo` => `foo.utils.x` right now since `foo` is just a file...
125-
return ImportResolution(from_file=file, symbol=None, imports_file=True)
126-
127-
# =====[ Check if `module.py` file exists in the graph with custom resolve path or sys.path enabled ]=====
128-
filepath = module_source.replace(".", "/") + ".py"
129-
if len(self.ctx.config.import_resolution_paths) > 0 or self.ctx.config.py_resolve_syspath:
130-
# Handle resolve overrides first if both is set
131-
resolve_paths: list[str] = self.ctx.config.import_resolution_paths + (sys.path if self.ctx.config.py_resolve_syspath else [])
132-
if file := self._file_by_custom_resolve_paths(resolve_paths, filepath):
128+
# =====[ Check if `module.py` file exists in the graph with custom resolve path or sys.path enabled ]=====
129+
filepath = module_source.replace(".", "/") + ".py"
130+
if len(self.ctx.config.import_resolution_paths) > 0 or self.ctx.config.py_resolve_syspath:
131+
# Handle resolve overrides first if both is set
132+
resolve_paths: list[str] = self.ctx.config.import_resolution_paths + (sys.path if self.ctx.config.py_resolve_syspath else [])
133+
if file := self._file_by_custom_resolve_paths(resolve_paths, filepath):
134+
symbol = file.get_node_by_name(symbol_name)
135+
return ImportResolution(from_file=file, symbol=symbol)
136+
137+
# =====[ Check if `module.py` file exists in the graph ]=====
138+
filepath = os.path.join(base_path, filepath)
139+
if file := self.ctx.get_file(filepath):
133140
symbol = file.get_node_by_name(symbol_name)
134-
return ImportResolution(from_file=file, symbol=symbol)
135-
136-
# =====[ Check if `module.py` file exists in the graph ]=====
137-
filepath = os.path.join(base_path, filepath)
138-
if file := self.ctx.get_file(filepath):
139-
symbol = file.get_node_by_name(symbol_name)
140-
if symbol is None:
141-
if file.get_node_from_wildcard_chain(symbol_name):
142-
return ImportResolution(from_file=file, symbol=None, imports_file=True)
141+
if symbol is None:
142+
if file.get_node_from_wildcard_chain(symbol_name):
143+
return ImportResolution(from_file=file, symbol=None, imports_file=True)
144+
else:
145+
# This is most likely a broken import
146+
return ImportResolution(from_file=file, symbol=None)
143147
else:
144-
# This is most likely a broken import
145-
return ImportResolution(from_file=file, symbol=None)
146-
else:
147-
return ImportResolution(from_file=file, symbol=symbol)
148-
149-
# =====[ Check if `module/__init__.py` file exists in the graph with custom resolve path or sys.path enabled ]=====
150-
filepath = filepath.replace(".py", "/__init__.py")
151-
if len(self.ctx.config.import_resolution_paths) > 0 or self.ctx.config.py_resolve_syspath:
152-
# Handle resolve overrides first if both is set
153-
resolve_paths: list[str] = self.ctx.config.import_resolution_paths + (sys.path if self.ctx.config.py_resolve_syspath else [])
154-
if from_file := self._file_by_custom_resolve_paths(resolve_paths, filepath):
148+
return ImportResolution(from_file=file, symbol=symbol)
149+
150+
# =====[ Check if `module/__init__.py` file exists in the graph with custom resolve path or sys.path enabled ]=====
151+
filepath = filepath.replace(".py", "/__init__.py")
152+
if len(self.ctx.config.import_resolution_paths) > 0 or self.ctx.config.py_resolve_syspath:
153+
# Handle resolve overrides first if both is set
154+
resolve_paths: list[str] = self.ctx.config.import_resolution_paths + (sys.path if self.ctx.config.py_resolve_syspath else [])
155+
if from_file := self._file_by_custom_resolve_paths(resolve_paths, filepath):
156+
symbol = from_file.get_node_by_name(symbol_name)
157+
if symbol is None:
158+
if from_file.get_node_from_wildcard_chain(symbol_name):
159+
return ImportResolution(from_file=from_file, symbol=None, imports_file=True)
160+
else:
161+
# This is most likely a broken import
162+
return ImportResolution(from_file=from_file, symbol=None)
163+
164+
else:
165+
return ImportResolution(from_file=from_file, symbol=symbol)
166+
167+
# =====[ Check if `module/__init__.py` file exists in the graph ]=====
168+
if from_file := self.ctx.get_file(filepath):
155169
symbol = from_file.get_node_by_name(symbol_name)
156170
if symbol is None:
157171
if from_file.get_node_from_wildcard_chain(symbol_name):
@@ -163,40 +177,30 @@ def resolve_import(self, base_path: str | None = None, *, add_module_name: str |
163177
else:
164178
return ImportResolution(from_file=from_file, symbol=symbol)
165179

166-
# =====[ Check if `module/__init__.py` file exists in the graph ]=====
167-
if from_file := self.ctx.get_file(filepath):
168-
symbol = from_file.get_node_by_name(symbol_name)
169-
if symbol is None:
170-
if from_file.get_node_from_wildcard_chain(symbol_name):
171-
return ImportResolution(from_file=from_file, symbol=None, imports_file=True)
172-
else:
173-
# This is most likely a broken import
174-
return ImportResolution(from_file=from_file, symbol=None)
175-
176-
else:
177-
return ImportResolution(from_file=from_file, symbol=symbol)
178-
179-
# =====[ Case: Can't resolve the import ]=====
180-
if base_path == "":
181-
# Try to resolve with "src" as the base path
182-
return self.resolve_import(base_path="src", add_module_name=add_module_name)
183-
if base_path == "src":
184-
# Try "test" next
185-
return self.resolve_import(base_path="test", add_module_name=add_module_name)
180+
# =====[ Case: Can't resolve the import ]=====
181+
if base_path == "":
182+
# Try to resolve with "src" as the base path
183+
return self.resolve_import(base_path="src", add_module_name=add_module_name)
184+
if base_path == "src":
185+
# Try "test" next
186+
return self.resolve_import(base_path="test", add_module_name=add_module_name)
186187

187-
# if not G_override:
188-
# for resolver in ctx.import_resolvers:
189-
# if imp := resolver.resolve(self):
190-
# return imp
188+
# if not G_override:
189+
# for resolver in ctx.import_resolvers:
190+
# if imp := resolver.resolve(self):
191+
# return imp
191192

192-
return None
193-
# # =====[ Check if we are importing an external module in the graph ]=====
194-
# if ext := self.ctx.get_external_module(self.source, self._unique_node.source):
195-
# return ImportResolution(symbol=ext)
196-
# # Implies we are not importing the symbol from the current repo.
197-
# # In these cases, consider the import as an ExternalModule and add to graph
198-
# ext = ExternalModule.from_import(self)
199-
# return ImportResolution(symbol=ext)
193+
return None
194+
# # =====[ Check if we are importing an external module in the graph ]=====
195+
# if ext := self.ctx.get_external_module(self.source, self._unique_node.source):
196+
# return ImportResolution(symbol=ext)
197+
# # Implies we are not importing the symbol from the current repo.
198+
# # In these cases, consider the import as an ExternalModule and add to graph
199+
# ext = ExternalModule.from_import(self)
200+
# return ImportResolution(symbol=ext)
201+
except AssertionError:
202+
# Codebase is probably trying to import file from outside repo
203+
return None
200204

201205
@noapidoc
202206
@reader

src/codegen/sdk/typescript/import_resolution.py

Lines changed: 60 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -215,63 +215,67 @@ def resolve_import(self, base_path: str | None = None, *, add_module_name: str |
215215
- symbol: The specific symbol being imported (None for module imports)
216216
- imports_file: True if importing the entire file/module
217217
"""
218-
self.file: TSFile # Type cast ts_file
219-
base_path = base_path or self.ctx.projects[0].base_path or ""
220-
221-
# Get the import source path
222-
import_source = self.module.source.strip('"').strip("'") if self.module else ""
223-
224-
# Try to resolve the import using the tsconfig paths
225-
if self.file.ts_config:
226-
import_source = self.file.ts_config.translate_import_path(import_source)
227-
228-
# Check if need to resolve relative import path to absolute path
229-
relative_import = False
230-
if import_source.startswith("."):
231-
relative_import = True
232-
233-
# Insert base path
234-
# This has the happen before the relative path resolution
235-
if not import_source.startswith(base_path):
236-
import_source = os.path.join(base_path, import_source)
237-
238-
# If the import is relative, convert it to an absolute path
239-
if relative_import:
240-
import_source = self._relative_to_absolute_import(import_source)
241-
else:
242-
import_source = os.path.normpath(import_source)
243-
244-
# covers the case where the import is from a directory ex: "import { postExtract } from './post'"
245-
import_name = import_source.split("/")[-1]
246-
if "." not in import_name:
247-
possible_paths = ["index.ts", "index.js", "index.tsx", "index.jsx"]
248-
for p_path in possible_paths:
249-
if self.ctx.to_absolute(os.path.join(import_source, p_path)).exists():
250-
import_source = os.path.join(import_source, p_path)
251-
break
252-
253-
# Loop through all extensions and try to find the file
254-
extensions = ["", ".ts", ".d.ts", ".tsx", ".d.tsx", ".js", ".jsx"]
255-
# Try both filename with and without extension
256-
for import_source_base in (import_source, os.path.splitext(import_source)[0]):
257-
for extension in extensions:
258-
import_source_ext = import_source_base + extension
259-
if file := self.ctx.get_file(import_source_ext):
260-
if self.is_module_import():
261-
return ImportResolution(from_file=file, symbol=None, imports_file=True)
262-
else:
263-
# If the import is a named import, resolve to the named export in the file
264-
if self.symbol_name is None:
265-
return ImportResolution(from_file=file, symbol=None, imports_file=True)
266-
export_symbol = file.get_export(export_name=self.symbol_name.source)
267-
if export_symbol is None:
268-
# If the named export is not found, it is importing a module re-export.
269-
# In this case, resolve to the file itself and dynamically resolve the symbol later.
218+
try:
219+
self.file: TSFile # Type cast ts_file
220+
base_path = base_path or self.ctx.projects[0].base_path or ""
221+
222+
# Get the import source path
223+
import_source = self.module.source.strip('"').strip("'") if self.module else ""
224+
225+
# Try to resolve the import using the tsconfig paths
226+
if self.file.ts_config:
227+
import_source = self.file.ts_config.translate_import_path(import_source)
228+
229+
# Check if need to resolve relative import path to absolute path
230+
relative_import = False
231+
if import_source.startswith("."):
232+
relative_import = True
233+
234+
# Insert base path
235+
# This has the happen before the relative path resolution
236+
if not import_source.startswith(base_path):
237+
import_source = os.path.join(base_path, import_source)
238+
239+
# If the import is relative, convert it to an absolute path
240+
if relative_import:
241+
import_source = self._relative_to_absolute_import(import_source)
242+
else:
243+
import_source = os.path.normpath(import_source)
244+
245+
# covers the case where the import is from a directory ex: "import { postExtract } from './post'"
246+
import_name = import_source.split("/")[-1]
247+
if "." not in import_name:
248+
possible_paths = ["index.ts", "index.js", "index.tsx", "index.jsx"]
249+
for p_path in possible_paths:
250+
if self.ctx.to_absolute(os.path.join(import_source, p_path)).exists():
251+
import_source = os.path.join(import_source, p_path)
252+
break
253+
254+
# Loop through all extensions and try to find the file
255+
extensions = ["", ".ts", ".d.ts", ".tsx", ".d.tsx", ".js", ".jsx"]
256+
# Try both filename with and without extension
257+
for import_source_base in (import_source, os.path.splitext(import_source)[0]):
258+
for extension in extensions:
259+
import_source_ext = import_source_base + extension
260+
if file := self.ctx.get_file(import_source_ext):
261+
if self.is_module_import():
270262
return ImportResolution(from_file=file, symbol=None, imports_file=True)
271-
return ImportResolution(from_file=file, symbol=export_symbol)
272-
273-
# If the imported file is not found, treat it as an external module
274-
return None
263+
else:
264+
# If the import is a named import, resolve to the named export in the file
265+
if self.symbol_name is None:
266+
return ImportResolution(from_file=file, symbol=None, imports_file=True)
267+
export_symbol = file.get_export(export_name=self.symbol_name.source)
268+
if export_symbol is None:
269+
# If the named export is not found, it is importing a module re-export.
270+
# In this case, resolve to the file itself and dynamically resolve the symbol later.
271+
return ImportResolution(from_file=file, symbol=None, imports_file=True)
272+
return ImportResolution(from_file=file, symbol=export_symbol)
273+
274+
# If the imported file is not found, treat it as an external module
275+
return None
276+
except AssertionError:
277+
# Codebase is probably trying to import file from outside repo
278+
return None
275279

276280
@noapidoc
277281
@reader

0 commit comments

Comments
 (0)