#!/usr/bin/env python3
"""
lean-bisect: Bisect Lean toolchain versions to find where behavior changes.

Usage:
    lean-bisect path/to/file.lean                        # auto-find regression
    lean-bisect path/to/file.lean ..nightly-2024-06-01   # bisect to a nightly
    lean-bisect path/to/file.lean nightly-2024-01-01..nightly-2024-06-01
    lean-bisect path/to/file.lean nightly-2024-01-01..nightly-2024-06-01 --timeout 30
    lean-bisect path/to/file.lean abc1234..def5678       # bisect commits
    lean-bisect --selftest
    lean-bisect --clear-cache

For SHA-based bisection, the script will first try to download pre-built CI artifacts
from GitHub Actions (fast: ~30s) before falling back to building from source (slow: 2-5min).
CI artifacts are cached in ~/.cache/lean_build_artifact/ for reuse.
"""

import argparse
import json
import os
import platform
import re
import shutil
import subprocess
import sys
import tempfile
import urllib.request
import urllib.error
from pathlib import Path
from dataclasses import dataclass
from typing import NoReturn, Optional, TextIO, Tuple, List

# Import shared artifact download functionality.
# The script's own directory is put first on sys.path so that the sibling
# `build_artifact` module is importable regardless of the current directory.
sys.path.insert(0, str(Path(__file__).parent))
import build_artifact

# Constants
NIGHTLY_PATTERN = re.compile(r'^nightly-(\d{4})-(\d{2})-(\d{2})$')
VERSION_PATTERN = re.compile(r'^v4\.(\d+)\.(\d+)(-rc\d+)?$')
# Accept short SHAs (7+ chars) - we'll resolve to full SHA later
SHA_PATTERN = re.compile(r'^[0-9a-f]{7,40}$')

GITHUB_API_BASE = "https://api.github.com"
NIGHTLY_REPO = "leanprover/lean4-nightly"
LEAN4_REPO = "leanprover/lean4"

# Re-export from build_artifact for local use
ARTIFACT_CACHE = build_artifact.ARTIFACT_CACHE
CI_FAILED = build_artifact.CI_FAILED


# ANSI colors for terminal output
class Colors:
    """ANSI escape sequences used to colorize terminal output."""
    RED = '\033[91m'
    GREEN = '\033[92m'
    YELLOW = '\033[93m'
    BLUE = '\033[94m'
    BOLD = '\033[1m'
    RESET = '\033[0m'


def color(text: str, c: str, stream: Optional[TextIO] = None) -> str:
    """
    Wrap `text` in the escape sequence `c` if the destination is a terminal.

    `stream` is the stream the text will be written to; it defaults to
    sys.stdout, which keeps existing two-argument calls unchanged. Callers
    that print to stderr should pass sys.stderr so that the tty check is made
    against the stream that actually receives the text.
    """
    target = sys.stdout if stream is None else stream
    if target.isatty():
        return f"{c}{text}{Colors.RESET}"
    return text


@dataclass
class BuildResult:
    """Outcome of running Lean on the test file with one toolchain."""
    exit_code: int   # process exit status
    stdout: str      # captured standard output
    stderr: str      # captured standard error
    timed_out: bool  # True when the run hit the timeout

    def signature(self, ignore_messages: bool) -> Tuple:
        """
        Return a comparable signature for this result.

        Two runs are considered to behave the same when their signatures are
        equal. With `ignore_messages` only the exit code is compared;
        otherwise stdout and stderr are part of the signature as well.
        """
        # Treat timeout as a distinct exit code
        code = -124 if self.timed_out else self.exit_code
        if ignore_messages:
            return (code,)
        return (code, self.stdout, self.stderr)


def error(msg: str) -> NoReturn:
    """Print error message to stderr and exit with status 1 (never returns)."""
    print(color(f"Error: {msg}", Colors.RED, sys.stderr), file=sys.stderr)
    sys.exit(1)


def warn(msg: str) -> None:
    """Print warning message to stderr."""
    print(color(f"Warning: {msg}", Colors.YELLOW, sys.stderr), file=sys.stderr)


def info(msg: str) -> None:
    """Print info message to stdout."""
    print(color(msg, Colors.BLUE))


def success(msg: str) -> None:
    """Print success message to stdout."""
    print(color(msg, Colors.GREEN))


# -----------------------------------------------------------------------------
# Import sanity check
# -----------------------------------------------------------------------------

def check_imports(file_path: Path) -> None:
    """
    Check that the file only imports from Lean.* or Std.*

    Every line of the form `import <module>` is inspected; if the first
    dotted component of `<module>` is neither `Lean` nor `Std`, the script
    exits through error().
    """
    allowed_prefixes = ('Lean', 'Std')

    # Read explicitly as UTF-8 rather than the platform default encoding, so
    # that non-ASCII source text does not fail to decode on some platforms.
    with open(file_path, 'r', encoding='utf-8') as f:
        content = f.read()

    import_pattern = re.compile(r'^\s*import\s+(\S+)', re.MULTILINE)

    for match in import_pattern.finditer(content):
        module = match.group(1)
        # Get the first component of the module path
        first_component = module.split('.')[0]

        if first_component not in allowed_prefixes:
            error(f"File imports '{module}' which is outside Lean.*/Std.*\n"
                  f"lean-bisect only supports files that import from Lean or Std.\n"
                  f"This is because we test against bare toolchains without lake dependencies.")
# -----------------------------------------------------------------------------
# Identifier type detection
# -----------------------------------------------------------------------------

def resolve_sha(short_sha: str) -> str:
    """
    Resolve a (possibly short) SHA to full 40-character SHA using git rev-parse.

    A 40-character input is returned unchanged without invoking git. Anything
    else is resolved with git in the current working directory; on any
    failure the script exits through error().
    """
    if len(short_sha) == 40:
        return short_sha
    try:
        # `--verify` makes git insist on exactly one valid object name and
        # `^{commit}` peels it to a commit, so a tree/blob id is rejected
        # rather than being passed on as if it were a commit.
        result = subprocess.run(
            ['git', 'rev-parse', '--verify', f'{short_sha}^{{commit}}'],
            capture_output=True,
            text=True,
            timeout=5
        )
        if result.returncode == 0:
            full_sha = result.stdout.strip()
            if len(full_sha) == 40:
                return full_sha
        error(f"Cannot resolve SHA '{short_sha}': {result.stderr.strip() or 'not found in repository'}")
    except subprocess.TimeoutExpired:
        error(f"Timeout resolving SHA '{short_sha}'")
    except FileNotFoundError:
        error("git not found - required for SHA resolution")


def parse_identifier(s: str) -> Tuple[str, str]:
    """
    Parse an identifier and return (type, value).
    Types: 'nightly', 'version', 'sha'
    For SHAs, resolves short SHAs to full 40-character SHAs.

    Exits through error() when `s` matches none of the known formats.
    """
    if NIGHTLY_PATTERN.match(s):
        return ('nightly', s)
    if VERSION_PATTERN.match(s):
        return ('version', s)
    if SHA_PATTERN.match(s):
        full_sha = resolve_sha(s)
        return ('sha', full_sha)
    error(f"Invalid identifier format: '{s}'\n"
          f"Expected one of:\n"
          f" - nightly-YYYY-MM-DD (e.g., nightly-2024-06-15)\n"
          f" - v4.X.Y or v4.X.Y-rcK (e.g., v4.8.0, v4.9.0-rc1)\n"
          f" - commit SHA (short or full)")


def parse_range(range_str: Optional[str]) -> Tuple[Optional[str], Optional[str]]:
    """
    Parse a range string and return (from_id, to_id).

    Syntax:
        FROM...TO or FROM..TO → (FROM, TO)
        FROM → (FROM, None)
        ...TO or ..TO → (None, TO)
        None or "" → (None, None)

    An empty side of the separator is reported as None.
    """
    if not range_str:
        return (None, None)

    # Try the longer separator first so that "A...B" is not split as
    # "A" / ".B" by the two-dot form.
    for separator in ('...', '..'):
        if separator in range_str:
            from_id, to_id = range_str.split(separator, 1)
            return (from_id or None, to_id or None)

    # Single identifier = FROM
    return (range_str, None)


# -----------------------------------------------------------------------------
# GitHub API helpers
# -----------------------------------------------------------------------------

def github_api_request(url: str) -> dict:
    """
    Make a GitHub API request and return the decoded JSON response.

    The decoded value is whatever the endpoint returns; list endpoints such
    as the tags listing used by fetch_nightly_tags() yield a list rather
    than a dict. A token from build_artifact.get_github_token() is sent when
    one is available. Any HTTP or network failure terminates the script
    through error().
    """
    headers = {
        'Accept': 'application/vnd.github.v3+json',
        'User-Agent': 'lean-bisect'
    }

    token = build_artifact.get_github_token()
    if token:
        headers['Authorization'] = f'token {token}'

    req = urllib.request.Request(url, headers=headers)
    try:
        with urllib.request.urlopen(req, timeout=30) as response:
            return json.loads(response.read().decode())
    except urllib.error.HTTPError as e:
        response_headers = e.headers or {}
        # A 403 is not necessarily a rate limit: only report it as one when
        # the rate-limit headers say so. 429 is always a rate limit.
        rate_limited = e.code == 429 or (
            e.code == 403 and (
                response_headers.get('X-RateLimit-Remaining') == '0'
                or response_headers.get('Retry-After') is not None
            )
        )
        if rate_limited:
            error(f"GitHub API rate limit exceeded. Set GITHUB_TOKEN environment variable to increase limit.")
        elif e.code == 403:
            error(f"GitHub API access forbidden (HTTP 403): {url}")
        elif e.code == 404:
            error(f"GitHub resource not found: {url}")
        else:
            error(f"GitHub API error: {e.code} {e.reason}")
    except urllib.error.URLError as e:
        error(f"Network error accessing GitHub API: {e.reason}")


def fetch_nightly_tags() -> List[str]:
    """Fetch all nightly tags from GitHub, sorted by date (oldest first)."""
    per_page = 100
    tags = []
    page = 1

    while True:
        url = f"{GITHUB_API_BASE}/repos/{NIGHTLY_REPO}/tags?per_page={per_page}&page={page}"
        data = github_api_request(url)

        # GitHub returns empty list when no more pages
        if not data:
            break

        # Tags that do not follow the nightly naming scheme are ignored.
        tags.extend(
            tag['name'] for tag in data if NIGHTLY_PATTERN.match(tag['name'])
        )

        page += 1

        # A short page means this was the last one; skip the extra request.
        if len(data) < per_page:
            break

    # Sort by date (nightly-YYYY-MM-DD format sorts lexicographically)
    tags.sort()
    return tags


def _resolve_tag_commit(repo: str, tag: str) -> str:
    """Return the SHA that tag `tag` in GitHub repository `repo` refers to."""
    url = f"{GITHUB_API_BASE}/repos/{repo}/git/refs/tags/{tag}"
    ref_object = github_api_request(url)['object']

    # The ref might point to a tag object or directly to a commit
    if ref_object['type'] == 'tag':
        # Need to dereference the tag object
        tag_url = f"{GITHUB_API_BASE}/repos/{repo}/git/tags/{ref_object['sha']}"
        return github_api_request(tag_url)['object']['sha']

    return ref_object['sha']


def get_commit_for_nightly(nightly: str) -> str:
    """Get the commit SHA for a nightly tag."""
    return _resolve_tag_commit(NIGHTLY_REPO, nightly)


def get_commit_for_version(version: str) -> str:
    """Get the commit SHA for a version tag in lean4 repo."""
    return _resolve_tag_commit(LEAN4_REPO, version)


# -----------------------------------------------------------------------------
# Build functions
# -----------------------------------------------------------------------------

def ensure_toolchain_installed(toolchain: str, work_dir: Path) -> None:
    """
    Ensure the toolchain is installed (let elan handle it).

    Writes a `lean-toolchain` file naming `leanprover/lean4:<toolchain>` into
    `work_dir` and runs `lake --version` there. Exits through error() on
    timeout or when `lake` is missing; a non-zero exit status only produces a
    warning so that bisection can carry on.
    """
    # Write lean-toolchain file
    (work_dir / 'lean-toolchain').write_text(f'leanprover/lean4:{toolchain}\n')

    # Run a simple command to trigger toolchain download
    # We use 'lake --version' as it's quick and triggers elan
    # Don't capture output so user can see download progress
    try:
        result = subprocess.run(
            ['lake', '--version'],
            cwd=work_dir,
            stdout=subprocess.DEVNULL,  # Hide "Lake version ..." output
            # stderr passes through to show elan download progress
            timeout=600  # 10 minutes for toolchain download
        )
    except subprocess.TimeoutExpired:
        error(f"Timeout waiting for toolchain {toolchain} to download")
    except FileNotFoundError:
        error("lake not found. Is elan installed and in PATH?")

    if result.returncode != 0:
        # The subsequent test run will most likely fail for the same reason,
        # which would otherwise look like a genuine behavior difference.
        warn(f"'lake --version' exited with status {result.returncode} for toolchain {toolchain}; "
             f"results for this toolchain may be unreliable")


def _run_lean(file_name: str, work_dir: Path, timeout: Optional[int]) -> BuildResult:
    """Run `lake env lean <file_name>` in `work_dir` and capture the outcome."""
    try:
        result = subprocess.run(
            ['lake', 'env', 'lean', file_name],
            cwd=work_dir,
            capture_output=True,
            timeout=timeout,
            text=True,
            # Decode as UTF-8 regardless of locale; never fail on odd bytes.
            encoding='utf-8',
            errors='replace'
        )
    except subprocess.TimeoutExpired:
        return BuildResult(
            exit_code=-124,
            stdout='',
            stderr=f'Process timed out after {timeout} seconds',
            timed_out=True
        )
    except FileNotFoundError:
        error("lake not found. Is elan installed and in PATH?")

    return BuildResult(
        exit_code=result.returncode,
        stdout=result.stdout,
        stderr=result.stderr,
        timed_out=False
    )


def build_with_toolchain(
    toolchain: str,
    file_path: Path,
    work_dir: Path,
    timeout: Optional[int] = None
) -> BuildResult:
    """
    Build file with given toolchain.
    Returns BuildResult with exit_code, stdout, stderr, and timed_out flag.

    The file is copied into `work_dir` and Lean is run on the copy. `timeout`
    (seconds) applies to the Lean run only.
    """
    # Ensure toolchain is installed first (not subject to timeout)
    ensure_toolchain_installed(toolchain, work_dir)

    # Copy file to work directory
    target_file = work_dir / file_path.name
    shutil.copy(file_path, target_file)

    # Run lean on the file
    return _run_lean(target_file.name, work_dir, timeout)


# -----------------------------------------------------------------------------
# Bisection
# -----------------------------------------------------------------------------

def format_result(result: BuildResult) -> str:
    """Format a build result for display (TIMEOUT / OK / FAIL with exit code)."""
    if result.timed_out:
        return color("TIMEOUT", Colors.YELLOW)
    elif result.exit_code == 0:
        return color("OK", Colors.GREEN)
    else:
        return color(f"FAIL (exit {result.exit_code})", Colors.RED)


def print_verbose_result(result: BuildResult) -> None:
    """Print detailed output from a build result; empty streams are skipped."""
    if result.stdout.strip():
        print(color(" stdout:", Colors.BOLD))
        for line in result.stdout.strip().split('\n'):
            print(f" {line}")
    if result.stderr.strip():
        print(color(" stderr:", Colors.BOLD))
        for line in result.stderr.strip().split('\n'):
            print(f" {line}")


def bisect_nightlies(
    file_path: Path,
    nightlies: List[str],
    work_dir: Path,
    timeout: Optional[int],
    ignore_messages: bool,
    verbose: bool = False,
    file_display: Optional[str] = None
) -> Tuple[Optional[str], Optional[str]]:
    """
    Bisect through nightlies to find where behavior changes.
    Returns the pair of adjacent nightlies where behavior changes, or
    (None, None) when the first and last nightly behave identically.

    `nightlies` is searched by index and is expected to be ordered oldest
    first. `file_display` is only used when printing the resume command.
    """
    results = {}  # Cache of results

    def test_nightly(nightly: str) -> BuildResult:
        if nightly not in results:
            print(f" Testing {nightly}... ", end='', flush=True)
            result = build_with_toolchain(nightly, file_path, work_dir, timeout)
            results[nightly] = result
            print(format_result(result))
            if verbose:
                print_verbose_result(result)
        return results[nightly]

    lo, hi = 0, len(nightlies) - 1

    # Test endpoints
    lo_result = test_nightly(nightlies[lo])
    hi_result = test_nightly(nightlies[hi])

    lo_sig = lo_result.signature(ignore_messages)
    hi_sig = hi_result.signature(ignore_messages)

    if lo_sig == hi_sig:
        return None, None  # No change detected

    # Binary search
    while hi - lo > 1:
        mid = (lo + hi) // 2
        mid_result = test_nightly(nightlies[mid])
        mid_sig = mid_result.signature(ignore_messages)

        # Update bounds first: anything that does not match the low end is
        # treated as belonging to the high side.
        if mid_sig == lo_sig:
            lo = mid
        else:
            hi = mid

        # Then print the updated range
        print(f"\n Current range: {nightlies[lo]} .. {nightlies[hi]} ({hi - lo} nightlies)")
        print(f" Resume command: {sys.argv[0]} {file_display or file_path} {nightlies[lo]}...{nightlies[hi]}", end='')
        if timeout:
            print(f" --timeout {timeout}", end='')
        if ignore_messages:
            print(" --ignore-messages", end='')
        print("\n")

    return nightlies[lo], nightlies[hi]


def find_regression_range(
    file_path: Path,
    to_nightly: str,
    all_nightlies: List[str],
    work_dir: Path,
    timeout: Optional[int],
    ignore_messages: bool
) -> Tuple[Optional[str], Optional[str]]:
    """
    Find a nightly range that brackets a behavior change using exponential search.
    Starts from to_nightly and goes back in time, doubling the step each iteration.
    Returns (from_nightly, to_nightly) bracketing the change, or (None, None) if not found.

    Exits through error() when `to_nightly` is not in `all_nightlies`.
    """
    results = {}  # Cache of results

    def test_nightly(nightly: str) -> BuildResult:
        if nightly not in results:
            print(f" Testing {nightly}... ", end='', flush=True)
            result = build_with_toolchain(nightly, file_path, work_dir, timeout)
            results[nightly] = result
            print(format_result(result))
        return results[nightly]

    # Find position of to_nightly in the list
    if to_nightly not in all_nightlies:
        error(f"Nightly {to_nightly} not found")
    to_idx = all_nightlies.index(to_nightly)

    # Test the target nightly first
    info(f"Testing target nightly {to_nightly}...")
    to_result = test_nightly(to_nightly)
    to_sig = to_result.signature(ignore_messages)

    # Exponential search backwards
    step = 1
    prev_idx = to_idx

    while True:
        # Calculate target index, going back by 'step' nightlies
        target_idx = to_idx - step

        if target_idx < 0:
            # We've gone past the beginning of available nightlies
            target_idx = 0
            if target_idx == prev_idx:
                warn("Reached earliest available nightly without finding behavior change")
                return None, None

        from_nightly = all_nightlies[target_idx]
        print()
        info(f"Testing {from_nightly} (step={step})...")
        from_result = test_nightly(from_nightly)
        from_sig = from_result.signature(ignore_messages)

        if from_sig != to_sig:
            # Found a difference! The regression is between target_idx and prev_idx
            success(f"Found behavior change between {from_nightly} and {all_nightlies[prev_idx]}")
            return from_nightly, all_nightlies[prev_idx]

        if target_idx == 0:
            warn("Reached earliest available nightly without finding behavior change")
            return None, None

        prev_idx = target_idx
        step *= 2


# -----------------------------------------------------------------------------
# Version tag handling
# -----------------------------------------------------------------------------

def find_closest_nightlies_for_version(version: str, all_nightlies: List[str]) -> Tuple[Optional[str], Optional[str]]:
    """
    Find the closest nightlies before and after a version's branch point from master.

    For v4.X.Y, finds where releases/v4.X.0 diverged from master, since nightlies
    are built from master. This means v4.25.0, v4.25.1, v4.25.2 all map to the
    same nightly range.

    Returns (nightly_before, nightly_after). Either element is None when no
    nightly exists on that side of the branch point.
    """
    # Parse version to get the release branch name (v4.X.Y -> releases/v4.X.0)
    # VERSION_PATTERN is r'^v4\.(\d+)\.(\d+)(-rc\d+)?$'
    # So group(1) is the minor version (X), group(2) is patch (Y)
    match = VERSION_PATTERN.match(version)
    if not match:
        error(f"Invalid version format: {version}")

    minor_version = match.group(1)  # e.g., "25" from v4.25.2
    release_branch = f"releases/v4.{minor_version}.0"

    info(f"Version {version} is on branch {release_branch}")

    # Find merge-base between release branch and master
    # This is where the release branch diverged from master
    try:
        url = f"{GITHUB_API_BASE}/repos/{LEAN4_REPO}/compare/master...{release_branch}"
        compare_data = github_api_request(url)
        merge_base_sha = compare_data['merge_base_commit']['sha']
        merge_base_date = compare_data['merge_base_commit']['commit']['committer']['date'][:10]

        info(f"Branch {release_branch} diverged from master at {merge_base_sha[:12]} ({merge_base_date})")
    except (Exception, SystemExit):
        # github_api_request() reports failures through error(), which raises
        # SystemExit rather than an Exception subclass; it has to be caught
        # explicitly or this fallback can never run for API failures.
        # Fallback: try to get the commit date of the version tag directly
        warn(f"Could not find merge base for {release_branch}, falling back to tag date")
        version_sha = get_commit_for_version(version)
        url = f"{GITHUB_API_BASE}/repos/{LEAN4_REPO}/commits/{version_sha}"
        commit_data = github_api_request(url)
        merge_base_date = commit_data['commit']['committer']['date'][:10]

    # Find nightlies around the merge base date
    nightly_before = None
    nightly_after = None

    for nightly in all_nightlies:
        nightly_match = NIGHTLY_PATTERN.match(nightly)
        if nightly_match:
            # YYYY-MM-DD strings compare correctly as plain text.
            nightly_date = f"{nightly_match.group(1)}-{nightly_match.group(2)}-{nightly_match.group(3)}"
            if nightly_date <= merge_base_date:
                nightly_before = nightly
            elif nightly_date > merge_base_date and nightly_after is None:
                nightly_after = nightly

    return nightly_before, nightly_after


def handle_version_endpoints(
    file_path: Path,
    from_id: str,
    to_id: str,
    from_type: str,
    to_type: str,
    timeout: Optional[int],
    ignore_messages: bool
) -> Tuple[str, str, List[str]]:
    """
    Handle the case where one or both endpoints are version tags.
    Converts version tags to equivalent nightlies.
    Returns (from_nightly, to_nightly, all_nightlies).

    `file_path`, `timeout` and `ignore_messages` are accepted but not used
    here. Exits through error() when no suitable nightly exists.
    """
    all_nightlies = fetch_nightly_tags()

    suggested_from = from_id
    suggested_to = to_id

    if from_type == 'version':
        before, _ = find_closest_nightlies_for_version(from_id, all_nightlies)
        if before:
            info(f"For FROM={from_id}, using nightly {before} (latest nightly before the version)")
            suggested_from = before
        else:
            error(f"Could not find a nightly before version {from_id}")

    if to_type == 'version':
        _, after = find_closest_nightlies_for_version(to_id, all_nightlies)
        if after:
            info(f"For TO={to_id}, using nightly {after} (earliest nightly after the version)")
            suggested_to = after
        else:
            error(f"Could not find a nightly after version {to_id}")

    return suggested_from, suggested_to, all_nightlies


# -----------------------------------------------------------------------------
# Commit SHA bisection
# -----------------------------------------------------------------------------

def find_lean4_repo() -> Optional[Path]:
    """
    Try to find a local lean4 repository to use as reference.

    Looks at the current directory (or its parent when the current directory
    is named `script`) and returns it when the URL of its `origin` remote
    contains "lean4". Returns None otherwise, including when git cannot be
    run.
    """
    cwd = Path.cwd()

    # If we're in a script/ directory, step up one level
    if cwd.name == 'script':
        cwd = cwd.parent

    # Check if this directory is a lean4 repo
    try:
        result = subprocess.run(
            ['git', 'remote', 'get-url', 'origin'],
            cwd=cwd,
            capture_output=True,
            text=True
        )
    except (OSError, subprocess.SubprocessError):
        # Best effort only: git missing or not runnable means "no local repo".
        # A bare `except:` here would also swallow KeyboardInterrupt.
        return None

    if result.returncode == 0 and 'lean4' in result.stdout.lower():
        return cwd

    return None


def bisect_commits(
    file_path: Path,
    from_sha: str,
    to_sha: str,
    work_dir: Path,
    timeout: Optional[int],
    ignore_messages: bool,
    file_display: Optional[str] = None
) -> None:
    """
    Bisect through commits on master.

    Clones the lean4 repository into a temporary directory (removed on exit),
    lists the commits in `from_sha..to_sha`, and binary-searches them for the
    first commit whose test result differs from that of `from_sha`. Toolchains
    come from CI artifacts when available and from a source build otherwise.
    The result is printed; nothing is returned.
    """
    build_dir = Path(tempfile.mkdtemp(prefix='lean-bisect-build-'))

    # Try to find a local lean4 repo to use as reference (speeds up clone significantly)
    local_repo = find_lean4_repo()

    try:
        if local_repo:
            info(f"Found local lean4 repo at {local_repo}, using as reference...")
            info(f"Cloning to {build_dir}...")
            subprocess.run(
                ['git', 'clone', '--reference', str(local_repo),
                 f'https://github.com/{LEAN4_REPO}.git', str(build_dir)],
                check=True
            )
        else:
            info(f"No local lean4 repo found, cloning from scratch to {build_dir}...")
            info("(Tip: run from within a lean4 checkout to speed this up)")
            subprocess.run(
                ['git', 'clone', f'https://github.com/{LEAN4_REPO}.git', str(build_dir)],
                check=True
            )

        # Fetch latest from origin
        subprocess.run(
            ['git', 'fetch', 'origin'],
            cwd=build_dir,
            check=True
        )

        # Verify commits are on master
        for sha in [from_sha, to_sha]:
            result = subprocess.run(
                ['git', 'merge-base', '--is-ancestor', sha, 'origin/master'],
                cwd=build_dir,
                capture_output=True
            )
            if result.returncode != 0:
                error(f"Commit {sha} is not an ancestor of master")

        # Get list of commits between from and to (need full SHAs for CI artifact lookup)
        result = subprocess.run(
            ['git', 'log', '--format=%H', '--reverse', f'{from_sha}..{to_sha}'],
            cwd=build_dir,
            capture_output=True,
            text=True,
            check=True
        )

        commits = [from_sha]  # Include the from commit
        for line in result.stdout.strip().split('\n'):
            if line:
                commits.append(line.strip())

        info(f"Found {len(commits)} commits to bisect")
        print()

        results = {}  # sha -> BuildResult
        failed_builds = set()  # commits that failed to build (locally or in CI)
        no_artifact = set()  # commits with no CI artifact available

        # Check if we can use CI artifacts
        use_artifacts = (build_artifact.check_gh_available() and
                         build_artifact.check_zstd_support() and
                         build_artifact.get_artifact_name() is not None)
        if use_artifacts:
            info("CI artifact download available (will try before building from source)")
        else:
            if not build_artifact.check_gh_available():
                warn("gh CLI not available or not authenticated; will build from source")
            elif not build_artifact.check_zstd_support():
                warn("tar does not support zstd; will build from source")
            elif build_artifact.get_artifact_name() is None:
                warn("No CI artifacts available for this platform; will build from source")
        print()

        def get_toolchain_for_commit(sha: str) -> Optional[Path]:
            """
            Get a toolchain path for testing a commit.
            Tries CI artifact first, falls back to source build.
            Returns toolchain path or None if unavailable.
            """
            # Try CI artifact first (unless we know there isn't one)
            if use_artifacts and sha not in no_artifact:
                print(f" Checking {sha[:12]}... ", end='', flush=True)

                # Check if already cached
                if build_artifact.is_cached(sha):
                    print(color("cached", Colors.GREEN))
                    return build_artifact.get_cache_path(sha)

                artifact_result = build_artifact.download_ci_artifact(sha)
                if artifact_result is CI_FAILED:
                    print(color("CI failed, skipping", Colors.YELLOW))
                    # Remember the failure so this commit is not queried again
                    # and is counted among the skipped commits.
                    failed_builds.add(sha)
                    return None  # Don't bother building locally
                elif artifact_result:
                    print(color("using CI artifact", Colors.GREEN))
                    return artifact_result
                else:
                    no_artifact.add(sha)
                    print("no artifact, ", end='', flush=True)
            else:
                print(f" Checking out {sha[:12]}... ", end='', flush=True)

            # Fall back to building from source
            return build_lean_at_commit_internal(sha)

        def build_lean_at_commit_internal(sha: str) -> Optional[Path]:
            """Build lean at a specific commit. Returns stage1 path if successful, None otherwise."""
            # The progress prefix for this commit was already printed by
            # get_toolchain_for_commit.
            subprocess.run(
                ['git', 'checkout', '-q', sha],
                cwd=build_dir,
                check=True
            )

            # Configure cmake if build/release doesn't exist
            build_release = build_dir / 'build' / 'release'
            if not build_release.exists():
                print("configuring... ", end='', flush=True)
                build_release.mkdir(parents=True, exist_ok=True)
                result = subprocess.run(
                    ['cmake', '../..', '-DCMAKE_BUILD_TYPE=Release'],
                    cwd=build_release,
                    capture_output=True,
                    text=True
                )
                if result.returncode != 0:
                    print(color("CMAKE FAILED", Colors.YELLOW))
                    failed_builds.add(sha)
                    return None

            print("building... ", end='', flush=True)

            # Build lean. A bare `-j` places no limit on the number of
            # parallel jobs, so bound it by the CPU count.
            result = subprocess.run(
                ['make', f'-j{os.cpu_count() or 1}', '-C', 'build/release'],
                cwd=build_dir,
                capture_output=True,
                text=True
            )

            if result.returncode != 0:
                print(color("BUILD FAILED", Colors.YELLOW))
                failed_builds.add(sha)
                return None

            print(color("built", Colors.GREEN))
            return build_dir / 'build' / 'release' / 'stage1'

        def test_commit(sha: str) -> Optional[BuildResult]:
            """Test the file against a commit. Returns None if toolchain unavailable."""
            if sha in results:
                return results[sha]

            if sha in failed_builds:
                return None

            toolchain_path = get_toolchain_for_commit(sha)
            if toolchain_path is None:
                return None

            # Link as elan toolchain
            toolchain_name = f'lean-bisect-{sha[:12]}'

            subprocess.run(
                ['elan', 'toolchain', 'link', toolchain_name, str(toolchain_path)],
                check=True,
                capture_output=True
            )

            try:
                # Write lean-toolchain and test
                (work_dir / 'lean-toolchain').write_text(f'{toolchain_name}\n')
                shutil.copy(file_path, work_dir / file_path.name)

                print(f" Testing {sha[:12]}... ", end='', flush=True)

                build_result = _run_lean(file_path.name, work_dir, timeout)

                results[sha] = build_result
                print(format_result(build_result))
                return build_result

            finally:
                # Unlink toolchain
                subprocess.run(
                    ['elan', 'toolchain', 'uninstall', toolchain_name],
                    capture_output=True
                )

        # Test endpoints
        lo, hi = 0, len(commits) - 1

        lo_result = test_commit(commits[lo])
        if lo_result is None:
            error(f"Could not build at starting commit {commits[lo][:12]}")

        hi_result = test_commit(commits[hi])
        if hi_result is None:
            error(f"Could not build at ending commit {commits[hi][:12]}")

        lo_sig = lo_result.signature(ignore_messages)
        hi_sig = hi_result.signature(ignore_messages)

        if lo_sig == hi_sig:
            info("No behavior change detected between the endpoints.")
            return

        # Binary search
        while hi - lo > 1:
            mid = (lo + hi) // 2
            mid_result = test_commit(commits[mid])

            if mid_result is None:
                # Build failed, try to find another commit nearby
                # Search outward from mid for a buildable commit
                found = False
                for offset in range(1, hi - lo):
                    for candidate in [mid - offset, mid + offset]:
                        # failed_builds holds SHAs, so look the index up first.
                        if lo < candidate < hi and commits[candidate] not in failed_builds:
                            mid_result = test_commit(commits[candidate])
                            if mid_result is not None:
                                mid = candidate
                                found = True
                                break
                    if found:
                        break

                if not found:
                    warn("No buildable commits found in range, cannot narrow further")
                    break

            mid_sig = mid_result.signature(ignore_messages)

            if mid_sig == lo_sig:
                lo = mid
            else:
                hi = mid

            print(f"\n Current range: {commits[lo][:12]} .. {commits[hi][:12]} ({hi - lo} commits)")
            print(f" Resume command: {sys.argv[0]} {file_display or file_path} {commits[lo]}...{commits[hi]}", end='')
            if timeout:
                print(f" --timeout {timeout}", end='')
            if ignore_messages:
                print(" --ignore-messages", end='')
            print("\n")

        # Report result
        print()
        if hi - lo == 1:
            success(f"Behavior change introduced in commit {commits[hi][:12]}")
            print()
            # Show commit info
            result = subprocess.run(
                ['git', 'log', '-1', '--oneline', commits[hi]],
                cwd=build_dir,
                capture_output=True,
                text=True
            )
            if result.returncode == 0:
                print(f" {result.stdout.strip()}")
            print()
            print(f" Full SHA: {commits[hi]}")
            print(f" View on GitHub: https://github.com/{LEAN4_REPO}/commit/{commits[hi]}")
        else:
            warn(f"Narrowed to range {commits[lo][:12]} .. {commits[hi][:12]} ({hi - lo} commits)")
            if failed_builds:
                warn(f"Note: {len(failed_builds)} commits failed to build and were skipped")

    finally:
        # Clean up
        shutil.rmtree(build_dir, ignore_errors=True)


# -----------------------------------------------------------------------------
# Selftest
# -----------------------------------------------------------------------------

def run_selftest() -> None:
    """
    Run the built-in selftest by shelling out to lean-bisect.

    Runs this script on `lean-bisect-test.lean` (located next to the script)
    over a fixed nightly range, echoes the child's output, and exits with
    status 0 if the expected "Behavior change detected" line appears among
    the last 50 lines of output, 1 otherwise.
    """
    script_path = Path(__file__).resolve()
    script_dir = script_path.parent
    test_file_abs = script_dir / 'lean-bisect-test.lean'

    if not test_file_abs.exists():
        error(f"Selftest file not found: {test_file_abs}")

    # Use relative paths for nicer output
    try:
        test_file = os.path.relpath(test_file_abs)
        script_rel = os.path.relpath(script_path)
    except ValueError:
        # On Windows, relpath can fail across drives
        test_file = str(test_file_abs)
        script_rel = str(script_path)

    print(color("Running lean-bisect selftest...", Colors.BOLD))
    print(f"Test file: {test_file}")
    print(f"Running: {script_rel} {test_file} nightly-2025-11-01...nightly-2025-11-15 --timeout 10 --ignore-messages -v")
    print()

    # Run lean-bisect as a subprocess, streaming output
    cmd = [
        sys.executable, str(script_path),
        str(test_file),
        'nightly-2025-11-01...nightly-2025-11-15',
        '--timeout', '10',
        '--ignore-messages',
        '-v'
    ]

    # Capture last N lines for validation
    output_lines = []
    max_lines = 50

    # The context manager closes the pipe and reaps the child on exit.
    with subprocess.Popen(
        cmd,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        text=True,
        bufsize=1
    ) as process:
        # Stream output to user and capture it
        for line in process.stdout:
            print(line, end='')
            output_lines.append(line.rstrip())
            if len(output_lines) > max_lines:
                output_lines.pop(0)

        process.wait()

    print()
    print(color("=" * 60, Colors.BOLD))
    print(color("Selftest validation:", Colors.BOLD))
    print()

    # Check for expected output
    output_text = '\n'.join(output_lines)

    expected_change = "Behavior change detected between nightly-2025-11-06 and nightly-2025-11-07"

    if expected_change in output_text:
        success("Selftest PASSED!")
        success(f"Found expected: {expected_change}")
        sys.exit(0)
    else:
        print(color("Selftest FAILED!", Colors.RED))
        print(f"Expected to find: {expected_change}")
        # Report the number of lines actually shown below.
        tail = output_lines[-20:]
        print(f"Last {len(tail)} lines of output:")
        for line in tail:
            print(f" {line}")
        sys.exit(1)


# -----------------------------------------------------------------------------
# Main
# -----------------------------------------------------------------------------

def main():
    parser = argparse.ArgumentParser(
        description='Bisect Lean toolchain versions to find where behavior changes.',
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Range Syntax:

  FROM..TO          Bisect between FROM and TO
  FROM              Start from FROM, bisect to latest nightly
  ..TO              Bisect to TO, search backwards for regression start

  If no range given, searches backwards from latest nightly to find regression.

Identifier Formats:

  nightly-YYYY-MM-DD    Nightly build date (e.g., nightly-2024-06-15)
                        Uses pre-built toolchains from leanprover/lean4-nightly.
                        Fast: downloads via elan (~30s each).

  v4.X.Y or v4.X.Y-rcN  Version tag (e.g., v4.8.0, v4.9.0-rc1)
                        Converts to equivalent nightly range.

  Commit SHA            Git commit hash (short or full, e.g., abc123def)
                        Bisects individual commits between two points.
                        Tries CI artifacts first (~30s), falls back to building (~2-5min).
                        Commits with failed CI builds are automatically skipped.
                        Artifacts cached in ~/.cache/lean_build_artifact/

Bisection Modes:

  Nightly mode:   Both endpoints are nightly dates.
                  Binary search through nightlies to find the day behavior changed.
                  Then automatically continues to bisect individual commits.
                  Use --nightly-only to stop after finding the nightly range.

  Version mode:   Either endpoint is a version tag.
                  Converts to equivalent nightly range and bisects.

  Commit mode:    Both endpoints are commit SHAs.
                  Binary search through individual commits on master.
                  Output: "Behavior change introduced in commit abc123"

Examples:

  # Simplest: just provide the file, finds the regression automatically
  lean-bisect test.lean

  # Specify an endpoint if you know roughly when it broke
  lean-bisect test.lean ..nightly-2024-06-01

  # Full manual control over the range
  lean-bisect test.lean nightly-2024-01-01..nightly-2024-06-01

  # Only find the nightly range, don't continue to commit bisection
  lean-bisect test.lean nightly-2024-01-01..nightly-2024-06-01 --nightly-only

  # Add a timeout (kills slow/hanging tests)
  lean-bisect test.lean --timeout 30

  # Bisect commits directly (if you already know the commit range)
  lean-bisect test.lean abc1234..def5678

  # Only compare exit codes, ignore output differences
  lean-bisect test.lean --ignore-messages

  # Clear downloaded CI artifacts to free disk space
  lean-bisect --clear-cache
"""
    )

    parser.add_argument('file', nargs='?', help='Lean file to test (must only import Lean.* or Std.*)')
    parser.add_argument('range', nargs='?', metavar='RANGE',
                        help='Range to bisect: FROM..TO, FROM, or ..TO')
    parser.add_argument('--timeout', type=int, metavar='SEC',
                        help='Timeout in seconds for each test run')
    parser.add_argument('--ignore-messages', action='store_true',
                        help='Compare only exit codes, ignore stdout/stderr differences')
    parser.add_argument('--verbose', '-v', action='store_true',
                        help='Show stdout/stderr from each test')
    parser.add_argument('--selftest', action='store_true',
                        help='Run built-in selftest to verify lean-bisect works')
    parser.add_argument('--clear-cache', action='store_true',
                        help='Clear CI artifact cache (~600MB per commit) and exit')
    parser.add_argument('--nightly-only', action='store_true',
                        help='Stop after finding nightly range (don\'t bisect individual commits)')

    args = parser.parse_args()

    # Show help if no arguments provided
    if len(sys.argv) == 1:
        parser.print_help()
        sys.exit(0)

    # Handle cache
clearing + if args.clear_cache: + if ARTIFACT_CACHE.exists(): + size = sum(f.stat().st_size for f in ARTIFACT_CACHE.rglob('*') if f.is_file()) + shutil.rmtree(ARTIFACT_CACHE) + info(f"Cleared cache at {ARTIFACT_CACHE} ({size / 1024 / 1024:.1f} MB)") + else: + info(f"Cache directory does not exist: {ARTIFACT_CACHE}") + return + + # Handle selftest + if args.selftest: + run_selftest() + return + + # Validate arguments + if not args.file: + parser.error("file is required (unless using --selftest)") + + file_arg = args.file # Preserve original for resume commands + file_path = Path(args.file).resolve() + if not file_path.exists(): + error(f"File not found: {file_path}") + + # Check imports + check_imports(file_path) + + # Parse range syntax + all_nightlies = None # Lazy load + from_id, to_id = parse_range(args.range) + + if to_id: + to_type, to_val = parse_identifier(to_id) + else: + # Default to most recent nightly + info("No endpoint specified, fetching latest nightly...") + all_nightlies = fetch_nightly_tags() + to_val = all_nightlies[-1] + to_type = 'nightly' + info(f"Using latest nightly: {to_val}") + + if from_id: + from_type, from_val = parse_identifier(from_id) + else: + # Will use exponential search to find regression range + from_type, from_val = None, None + + # Validate --nightly-only + if args.nightly_only: + if from_val is not None and from_type != 'nightly': + error("--nightly-only requires FROM to be a nightly identifier (nightly-YYYY-MM-DD)") + if to_type != 'nightly': + error("--nightly-only requires TO to be a nightly identifier (nightly-YYYY-MM-DD)") + + if from_val: + info(f"From: {from_val} ({from_type})") + else: + info("From: (will search backwards to find regression)") + info(f"To: {to_val} ({to_type})") + print() + + # Handle different combinations + if from_type == 'version' or to_type == 'version': + # Version tag handling - convert to nightlies and continue + from_val, to_val, all_nightlies = handle_version_endpoints( + file_path, from_val, 
to_val, + from_type, to_type, + args.timeout, args.ignore_messages + ) + from_type, to_type = 'nightly', 'nightly' + print() + # Fall through to nightly bisection below + + elif from_type == 'sha' and to_type == 'sha': + # Commit SHA bisection + work_dir = Path(tempfile.mkdtemp(prefix='lean-bisect-')) + try: + bisect_commits( + file_path, from_val, to_val, + work_dir, args.timeout, args.ignore_messages, + file_display=file_arg + ) + finally: + shutil.rmtree(work_dir, ignore_errors=True) + return + + # Nightly bisection (with optional exponential search for FROM) + if from_type is not None and from_type != 'nightly': + error("Mixed identifier types not supported. Both FROM and TO must be nightlies, versions, or SHAs.") + if to_type != 'nightly': + error("Mixed identifier types not supported. Both FROM and TO must be nightlies, versions, or SHAs.") + + # Fetch all nightly tags (if not already fetched) + if all_nightlies is None: + info("Fetching nightly tags from GitHub...") + all_nightlies = fetch_nightly_tags() + + # Validate to_val exists + if to_val not in all_nightlies: + error(f"Nightly {to_val} not found. Check https://github.com/{NIGHTLY_REPO}/tags") + + # Create temp directory + work_dir = Path(tempfile.mkdtemp(prefix='lean-bisect-')) + + try: + # If FROM not specified, use exponential search to find regression range + if from_val is None: + info("Searching backwards to find where behavior changed...") + print() + from_val, to_val = find_regression_range( + file_path, to_val, all_nightlies, + work_dir, args.timeout, args.ignore_messages + ) + if from_val is None: + info("Could not find a behavior change in available nightlies.") + return + print() + info(f"Narrowed to range: {from_val} .. {to_val}") + print() + + # Validate from_val exists + if from_val not in all_nightlies: + error(f"Nightly {from_val} not found. 
Check https://github.com/{NIGHTLY_REPO}/tags") + + # Get range + from_idx = all_nightlies.index(from_val) + to_idx = all_nightlies.index(to_val) + + if from_idx > to_idx: + error(f"FROM ({from_val}) must be before TO ({to_val})") + + nightlies = all_nightlies[from_idx:to_idx + 1] + info(f"Bisecting range: {from_val} .. {to_val} ({len(nightlies)} nightlies)") + print() + + lo, hi = bisect_nightlies( + file_path, nightlies, work_dir, + args.timeout, args.ignore_messages, + verbose=args.verbose, + file_display=file_arg + ) + + print() + + if lo is None: + info("No behavior change detected between the endpoints.") + info("Both toolchains produce the same output.") + return + + success(f"Behavior change detected between {lo} and {hi}") + print() + + # Get commit SHAs for the nightlies + try: + info("Fetching commit SHAs for these nightlies...") + lo_sha = get_commit_for_nightly(lo) + hi_sha = get_commit_for_nightly(hi) + print(f" {lo} -> {lo_sha[:12]}") + print(f" {hi} -> {hi_sha[:12]}") + print() + + if args.nightly_only: + # Just show the command to run manually + print("To identify the exact commit that introduced the change, run:") + cmd = f" {sys.argv[0]} {file_arg} {lo_sha}...{hi_sha}" + if args.timeout: + cmd += f" --timeout {args.timeout}" + if args.ignore_messages: + cmd += " --ignore-messages" + print(color(cmd, Colors.BOLD)) + else: + # Automatically continue to commit bisection + info("Continuing to bisect individual commits...") + print() + bisect_commits( + file_path, lo_sha, hi_sha, + work_dir, args.timeout, args.ignore_messages, + file_display=file_arg + ) + except Exception as e: + warn(f"Could not fetch commit SHAs: {e}") + + finally: + shutil.rmtree(work_dir, ignore_errors=True) + +if __name__ == '__main__': + main() diff --git a/script/lean-bisect-test.lean b/script/lean-bisect-test.lean new file mode 100644 index 0000000000..ca536f8a7c --- /dev/null +++ b/script/lean-bisect-test.lean @@ -0,0 +1,307 @@ +/- + Copyright Strata Contributors + + 
SPDX-License-Identifier: Apache-2.0 OR MIT
-/

namespace Strata
namespace Python

/-
Parser and translator for some basic regular expression patterns supported by
Python's `re` library
Ref.: https://docs.python.org/3/library/re.html

Also see
https://github.com/python/cpython/blob/759a048d4bea522fda2fe929be0fba1650c62b0e/Lib/re/_parser.py
for a reference implementation.
-/

-------------------------------------------------------------------------------

/--
Errors produced by the regex parser. Both constructors carry a message, the
full pattern being parsed, and the position the error refers to.
-/
inductive ParseError where
  /--
    `patternError` is raised when Python's `re.PatternError` exception is
    raised.
    [Reference: Python's re exceptions](https://docs.python.org/3/library/re.html#exceptions):

    "Exception raised when a string passed to one of the functions here is not a
    valid regular expression (for example, it might contain unmatched
    parentheses) or when some other error occurs during compilation or matching.
    It is never an error if a string contains no match for a pattern."
  -/
  | patternError (message : String) (pattern : String) (pos : String.Pos.Raw)
  /--
  `unimplemented` is raised whenever we don't support some regex operations
  (e.g., lookahead assertions).
  -/
  | unimplemented (message : String) (pattern : String) (pos : String.Pos.Raw)
  deriving Repr

/-- Render a `ParseError` as a one-line message that includes the byte offset
    of the error position and the offending pattern. -/
def ParseError.toString : ParseError → String
  | .patternError msg pat pos => s!"Pattern error at position {pos.byteIdx}: {msg} in pattern '{pat}'"
  | .unimplemented msg pat pos => s!"Unimplemented at position {pos.byteIdx}: {msg} in pattern '{pat}'"

instance : ToString ParseError where
  toString := ParseError.toString

-------------------------------------------------------------------------------

/--
Regular Expression Nodes
-/
inductive RegexAST where
  /-- Single literal character: `a` -/
  | char : Char → RegexAST
  /-- Character range: `[a-z]` -/
  | range : Char → Char → RegexAST
  /-- Alternation: `a|b` -/
  | union : RegexAST → RegexAST → RegexAST
  /-- Concatenation: `ab` -/
  | concat : RegexAST → RegexAST → RegexAST
  /-- Any character: `.` -/
  | anychar : RegexAST
  /-- Zero or more: `a*` -/
  | star : RegexAST → RegexAST
  /-- One or more: `a+` -/
  | plus : RegexAST → RegexAST
  /-- Zero or one: `a?` -/
  | optional : RegexAST → RegexAST
  /-- Bounded repetition: `a{n,m}` -/
  | loop : RegexAST → Nat → Nat → RegexAST
  /-- Start of string: `^` -/
  | anchor_start : RegexAST
  /-- End of string: `$` -/
  | anchor_end : RegexAST
  /-- Grouping: `(abc)` -/
  | group : RegexAST → RegexAST
  /-- Empty string: `()` or `""` -/
  | empty : RegexAST
  /-- Complement: `[^a-z]` -/
  | complement : RegexAST → RegexAST
  deriving Inhabited, Repr

-------------------------------------------------------------------------------

/-- Parse character class like [a-z], [0-9], etc. into union of ranges and
    chars. Note that this parses `|` as a character.

    `pos` must point at the opening `[`. On success, returns the AST (wrapped
    in `complement` when the class starts with `^`) together with the position
    one step past where the element loop stopped. -/
def parseCharClass (s : String) (pos : String.Pos.Raw) : Except ParseError (RegexAST × String.Pos.Raw) := do
  if pos.get? s != some '[' then throw (.patternError "Expected '[' at start of character class" s pos)
  let mut i := pos.next s

  -- Check for complement (negation) with leading ^
  let isComplement := !i.atEnd s && i.get? s == some '^'
  if isComplement then
    i := i.next s

  -- Accumulates the left-nested union of all elements seen so far.
  let mut result : Option RegexAST := none

  -- Process each element in the character class.
  while !i.atEnd s && i.get? s != some ']' do
    -- Uncommenting this makes the code stop
    --dbg_trace "Working" (pure ())
    let some c1 := i.get? s | throw (.patternError "Invalid character in class" s i)
    let i1 := i.next s
    -- Check for range pattern: c1-c2.
    if !i1.atEnd s && i1.get? s == some '-' then
      let i2 := i1.next s
      -- A `-` immediately before `]` (or at end of input) is not a range; in
      -- that case control falls through to the single-character case below.
      if !i2.atEnd s && i2.get? s != some ']' then
        let some c2 := i2.get? s | throw (.patternError "Invalid character in range" s i2)
        if c1 > c2 then
          throw (.patternError s!"Invalid character range [{c1}-{c2}]: \
            start character '{c1}' is greater than end character '{c2}'" s i)
        let r := RegexAST.range c1 c2
        -- Union with previous elements.
        result := some (match result with | none => r | some prev => RegexAST.union prev r)
        i := i2.next s
        continue
    -- Single character.
    let r := RegexAST.char c1
    result := some (match result with | none => r | some prev => RegexAST.union prev r)
    i := i.next s

  -- `result` is still `none` only when the loop body never ran (no elements).
  let some ast := result | throw (.patternError "Unterminated character set" s pos)
  let finalAst := if isComplement then RegexAST.complement ast else ast
  -- NOTE(review): the loop also stops at end of input, and the character at
  -- `i` is not checked to be `]` here, so a non-empty class with no closing
  -- bracket appears to be accepted. Confirm whether that is intended.
  pure (finalAst, i.next s)

-------------------------------------------------------------------------------

/-- Parse numeric repeats like `{10}` or `{1,10}` into min and max bounds.

    `pos` must point at the opening `{`. Returns `(min, max, position after
    the closing brace)`; `{n}` yields `min = max = n`. -/
def parseBounds (s : String) (pos : String.Pos.Raw) : Except ParseError (Nat × Nat × String.Pos.Raw) := do
  if pos.get? s != some '{' then throw (.patternError "Expected '{' at start of bounds" s pos)
  let mut i := pos.next s
  let mut numStr := ""

  -- Parse first number.
  while !i.atEnd s && (i.get? s).any Char.isDigit do
    numStr := numStr.push ((i.get? s).get!)
    i := i.next s

  -- An empty digit string (e.g. `{,5}` or `{}`) fails here.
  let some n := numStr.toNat? | throw (.patternError "Invalid minimum bound" s pos)

  -- Check for comma (range) or closing brace (exact count).
  match i.get? s with
  | some '}' => pure (n, n, i.next s) -- {n} means exactly n times.
  | some ',' =>
    i := i.next s
    -- Parse maximum bound
    numStr := ""
    while !i.atEnd s && (i.get? s).any Char.isDigit do
      numStr := numStr.push ((i.get? s).get!)
      i := i.next s
    -- An empty maximum (e.g. `{3,}`) fails here.
    let some max := numStr.toNat? | throw (.patternError "Invalid maximum bound" s i)
    if i.get? s != some '}' then throw (.patternError "Expected '}' at end of bounds" s i)
    -- Validate bounds order
    if max < n then
      throw (.patternError s!"Invalid repeat bounds \{{n},{max}}: \
        maximum {max} is less than minimum {n}" s pos)
    pure (n, max, i.next s)
  | _ => throw (.patternError "Invalid bounds syntax" s i)

-------------------------------------------------------------------------------

-- The three parsers below call each other (atom → explicit group → group →
-- atom), hence the `mutual` block; they are declared `partial`.
mutual
/--
Parse atom: single element (char, class, anchor, group) with optional
quantifier. Stops at the first `|`.

Returns the parsed node and the position just past it, including any
quantifier that was consumed.
-/
partial def parseAtom (s : String) (pos : String.Pos.Raw) : Except ParseError (RegexAST × String.Pos.Raw) := do
  if pos.atEnd s then throw (.patternError "Unexpected end of regex" s pos)

  let some c := pos.get? s | throw (.patternError "Invalid position" s pos)

  -- Detect invalid quantifier at start
  if c == '*' || c == '+' || c == '{' || c == '?' then
    throw (.patternError s!"Quantifier '{c}' at position {pos} has nothing to quantify" s pos)

  -- Detect unbalanced closing parenthesis
  if c == ')' then
    throw (.patternError "Unbalanced parenthesis" s pos)

  -- Parse base element (anchor, char class, group, anychar, escape, or single char).
  let (base, nextPos) ← match c with
    | '^' => pure (RegexAST.anchor_start, pos.next s)
    | '$' => pure (RegexAST.anchor_end, pos.next s)
    | '[' => parseCharClass s pos
    | '(' => parseExplicitGroup s pos
    | '.' => pure (RegexAST.anychar, pos.next s)
    | '\\' =>
      -- Handle escape sequence.
      -- Note: Python uses a single backslash as an escape character, but Lean
      -- strings need to escape that. After DDMification, we will see two
      -- backslashes in Strata for every Python backslash.
      let nextPos := pos.next s
      if nextPos.atEnd s then throw (.patternError "Incomplete escape sequence at end of regex" s pos)
      let some escapedChar := nextPos.get? s | throw (.patternError "Invalid escape position" s nextPos)
      -- Check for special sequences (unsupported right now).
      match escapedChar with
      | 'A' | 'b' | 'B' | 'd' | 'D' | 's' | 'S' | 'w' | 'W' | 'z' | 'Z' =>
        throw (.unimplemented s!"Special sequence \\{escapedChar} is not supported" s pos)
      | 'a' | 'f' | 'n' | 'N' | 'r' | 't' | 'u' | 'U' | 'v' | 'x' =>
        throw (.unimplemented s!"Escape sequence \\{escapedChar} is not supported" s pos)
      | c =>
        if c.isDigit then
          throw (.unimplemented s!"Backreference \\{c} is not supported" s pos)
        else
          -- Any other escaped character is taken as that literal character.
          pure (RegexAST.char escapedChar, nextPos.next s)
    | _ => pure (RegexAST.char c, pos.next s)

  -- Check for numeric repeat suffix on base element (but not on anchors)
  match base with
  | .anchor_start | .anchor_end => pure (base, nextPos)
  | _ =>
    if !nextPos.atEnd s then
      match nextPos.get? s with
      | some '{' =>
        let (min, max, finalPos) ← parseBounds s nextPos
        pure (RegexAST.loop base min max, finalPos)
      -- For `*`, `+` and `?`, peek one more character so that the
      -- non-greedy (`?`) and possessive (`+`) suffixes are rejected.
      | some '*' =>
        let afterStar := nextPos.next s
        if !afterStar.atEnd s then
          match afterStar.get? s with
          | some '?' => throw (.unimplemented "Non-greedy quantifier *? is not supported" s nextPos)
          | some '+' => throw (.unimplemented "Possessive quantifier *+ is not supported" s nextPos)
          | _ => pure (RegexAST.star base, afterStar)
        else pure (RegexAST.star base, afterStar)
      | some '+' =>
        let afterPlus := nextPos.next s
        if !afterPlus.atEnd s then
          match afterPlus.get? s with
          | some '?' => throw (.unimplemented "Non-greedy quantifier +? is not supported" s nextPos)
          | some '+' => throw (.unimplemented "Possessive quantifier ++ is not supported" s nextPos)
          | _ => pure (RegexAST.plus base, afterPlus)
        else pure (RegexAST.plus base, afterPlus)
      | some '?' =>
        let afterQuestion := nextPos.next s
        if !afterQuestion.atEnd s then
          match afterQuestion.get? s with
          | some '?' => throw (.unimplemented "Non-greedy quantifier ?? is not supported" s nextPos)
          | some '+' => throw (.unimplemented "Possessive quantifier ?+ is not supported" s nextPos)
          | _ => pure (RegexAST.optional base, afterQuestion)
        else pure (RegexAST.optional base, afterQuestion)
      | _ => pure (base, nextPos)
    else
      pure (base, nextPos)

/-- Parse explicit group with parentheses.

    `pos` must point at the opening `(`. The body is parsed by `parseGroup`
    with `)` as the end character, and the result is wrapped in `.group`. -/
partial def parseExplicitGroup (s : String) (pos : String.Pos.Raw) : Except ParseError (RegexAST × String.Pos.Raw) := do
  if pos.get? s != some '(' then throw (.patternError "Expected '(' at start of group" s pos)
  let mut i := pos.next s

  -- Check for extension notation (?...
  if !i.atEnd s && i.get? s == some '?' then
    let i1 := i.next s
    -- When `(?` is the very end of the input nothing is thrown here; control
    -- falls through to `parseGroup`, which then sees the `?`.
    if !i1.atEnd s then
      match i1.get? s with
      | some '=' => throw (.unimplemented "Positive lookahead (?=...) is not supported" s pos)
      | some '!' => throw (.unimplemented "Negative lookahead (?!...) is not supported" s pos)
      | _ => throw (.unimplemented "Extension notation (?...) is not supported" s pos)

  let (inner, finalPos) ← parseGroup s i (some ')')
  pure (.group inner, finalPos)

/-- Parse group: handles alternation and concatenation at current scope.

    Parsing runs until the end of the string or, when `endChar` is given,
    until that character, which is then required and consumed. -/
partial def parseGroup (s : String) (pos : String.Pos.Raw) (endChar : Option Char) :
    Except ParseError (RegexAST × String.Pos.Raw) := do
  -- Head of the outer list is the alternative currently being built. Atoms
  -- are consed onto it, so both levels are held in reverse order until the
  -- result is assembled below.
  let mut alternatives : List (List RegexAST) := [[]]
  let mut i := pos

  -- Parse until end of string or `endChar`.
  while !i.atEnd s && (endChar.isNone || i.get? s != endChar) do
    if i.get? s == some '|' then
      -- Push a new scope to `alternatives`.
      alternatives := [] :: alternatives
      i := i.next s
    else
      let (ast, nextPos) ← parseAtom s i
      alternatives := match alternatives with
        | [] => [[ast]]
        | head :: tail => (ast :: head) :: tail
      i := nextPos

  -- Check for expected end character.
  if let some ec := endChar then
    if i.get? s != some ec then
      throw (.patternError s!"Expected '{ec}'" s i)
    i := i.next s

  -- Build result: concatenate each alternative, then union them.
  -- Every branch of the `filterMap` function returns `some`, so no
  -- alternative is dropped; an empty alternative becomes `.empty`.
  let concatAlts := alternatives.reverse.filterMap fun alt =>
    match alt.reverse with
    | [] => -- Empty regex.
      some (.empty)
    | [single] => some single
    | head :: tail => some (tail.foldl RegexAST.concat head)

  match concatAlts with
  | [] => pure (.empty, i)
  | [single] => pure (single, i)
  | head :: tail => pure (tail.foldl RegexAST.union head, i)
end

/-- info: Except.ok (Strata.Python.RegexAST.range 'A' 'z', { byteIdx := 5 }) -/
#guard_msgs in
#eval parseCharClass "[A-z]" ⟨0⟩

-- Test code: Print done
#print "Done!"