diff --git a/scripts/update_copyright_headers.py b/scripts/update_copyright_headers.py index 1691b126c..95032f547 100644 --- a/scripts/update_copyright_headers.py +++ b/scripts/update_copyright_headers.py @@ -5,7 +5,7 @@ # LICENSE file in the root directory of this source tree. """ -Script to insert legal header to files across the repo +Script to insert legal header in files in a folder """ import glob @@ -15,7 +15,7 @@ from typing import List from typing import Optional -REPO_ROOT_PATH = "../" +FOLDER_PATH = "../" EXCLUDE_DIR_NAMES = [ "data", @@ -24,11 +24,12 @@ "__pycache__", ] -COPYRIGHT_LINES = [ - "Copyright (c) Meta Platforms and its affiliates.", - "This source code is licensed under the MIT license found in the", - "LICENSE file in the root directory of this source tree.", -] +# Note: it must not contain blank lines +COPYRIGHT_TEXT = """ +Copyright (c) Meta Platforms and its affiliates. +This source code is licensed under the MIT license found in the +LICENSE file in the root directory of this source tree. +""" class ProcessingError(Exception): @@ -51,14 +52,17 @@ def _make_copyright_lines(ext: str) -> List[str]: Insert copyright as comment lines specific to file extension """ lines = [] + copyright_lines = COPYRIGHT_TEXT.split("\n") + # Copyright notice must not contain blank lines + copyright_lines = [line for line in copyright_lines if line] ext = ext.lstrip(".") if ext in ["py", "sh", "yml", "yaml"]: - lines = _add_prefix_suffix("# ", COPYRIGHT_LINES) + lines = _add_prefix_suffix("# ", copyright_lines) elif ext in ["js", "jsx", "ts", "tsx", "css", "scss"]: - lines = ["/*"] + _add_prefix_suffix(" * ", COPYRIGHT_LINES) + [" */"] + lines = ["/*"] + _add_prefix_suffix(" * ", copyright_lines) + [" */"] elif ext in ["md", "html"]: - lines = [""] + lines = [""] else: raise UnsupportedFile(f"Unsupported file extension `{ext}`") @@ -80,16 +84,18 @@ def _update_copyright_header(file_path: str, replace_existing: bool = False): raise UnsupportedFile("File has fewer than one line") # Check copyright presence at the top of the file - anchor_line_number = -2 + anchor_line_number = None + likelihood_score = 0 for i, line in enumerate(lines[:EXAMINED_LINES]): text = line[:50].lower() if "copyright" in text: - anchor_line_number = -1 - if anchor_line_number == -1 and ("meta" in text or "facebook" in text): + likelihood_score += 1 + if "meta" in text or "facebook" in text: + likelihood_score += 1 anchor_line_number = i new_lines = None - if anchor_line_number < 0: + if likelihood_score < 2: # Insert a new copyright notice print("Inserting new notice") @@ -120,7 +126,6 @@ def _update_copyright_header(file_path: str, replace_existing: bool = False): first_line_number = i + 1 else: last_line_number = i - 1 - break if last_line_number is None: raise ProcessingError( @@ -128,19 +133,10 @@ def _update_copyright_header(file_path: str, replace_existing: bool = False): "(empty line missing right after?)" ) - # Prevent removing code lines that are "glued" to the bottom of a copyright notice - had_empty_line_after_comment = True - is_comment_last_line = lambda line: bool(set(["/", "#", ">", "*"]) & set(line[:5])) - while last_line_number and not is_comment_last_line(lines[last_line_number]): - last_line_number -= 1 - had_empty_line_after_comment = False - - if last_line_number < anchor_line_number: - raise ProcessingError("Could not confirm last line of copyright note comment") - # Note that we're also replacing an empty line after copyright notice lines_before_copyright = lines[:first_line_number] - lines_after_copyright = lines[last_line_number + 1 + had_empty_line_after_comment :] + lines_after_copyright = lines[last_line_number + 2 :] + new_lines = ( lines_before_copyright + _make_copyright_lines(ext) + ["\n"] + lines_after_copyright ) @@ -160,8 +156,6 @@ def run( :param extension: When specified, we"ll only process files with this extension :param replace_existing: When True, we will replace existing copyright notice if found """ - assert "" not in COPYRIGHT_LINES, "Copyright notice must not contain empty lines" - # Filter files with glob glob_pattern = start_path.rstrip("/") + "/**" if extension: @@ -201,16 +195,16 @@ def run( failed.append(path) print( - f"\nFinished processing. " + f"\nProcessed {n_files} files in total. " f"Updated {len(updated)}, skipped {len(skipped)}, failed {len(failed)} files." ) if failed: - failed_files = "\n\t" + "\n\t".join(failed) + failed_files = "\n".join(failed) print(f"\nThe following files failed: {failed_files}") if __name__ == "__main__": run( - REPO_ROOT_PATH, + FOLDER_PATH, replace_existing=True, )