Converted from theory.ipynb for web reading.

Proof Techniques — From Intuition to Iron-Clad Reasoning

A proof is a finite sequence of logically valid steps that transforms hypotheses into conclusions. Without proof techniques, theoretical AI is inaccessible — every convergence guarantee, every generalisation bound, every correctness argument rests on the methods in this notebook.

This notebook is the interactive companion to notes.md. It demonstrates every proof technique with runnable Python code that verifies mathematical statements computationally, making the abstract concrete.

| Section | Topic | What You'll Build |
| --- | --- | --- |
| 1 | Intuition & Landscape | Proof structure parser, statement analyser, technique selector |
| 2 | Direct Proof | Direct proof verifier, continuity checker, convex hull demo |
| 3 | Proof by Construction | Bijection builder, ReLU approximation constructor |
| 4 | Proof by Contrapositive | Contrapositive generator, Lipschitz-robustness demo |
| 5 | Proof by Contradiction | Irrationality verifier, infinitely many primes demo |
| 6 | Proof by Cases | Case exhaustion engine, ReLU gradient case analysis |
| 7 | Mathematical Induction | Sum formula prover, induction chain visualiser |
| 8 | Strong Induction | Prime factorisation prover, Fibonacci bound demo |
| 9 | Structural Induction | Tree property prover, formula length verifier |
| 10 | Probabilistic Method | Random graph existence demo, expectation argument |
| 11 | Counting Arguments | Double counting verifier, pigeonhole demo |
| 12 | Epsilon-Delta & Analytic | Continuity prover, convergence rate demo, fixed point iteration |
| 13 | ML Theory Proof Patterns | Union bound, concentration inequalities, PAC learning |
| 14 | Common Mistakes | Mistake detector, invalid proof identifier |
| 15 | Exercises | 8 comprehensive proof exercises with full solutions |
| 16 | Why This Matters | Master reference, 2026 AI perspective |

Prerequisites: Python, NumPy. Basic logic and set theory. Style: Every proof is first stated mathematically, then verified computationally with raw implementations.

Code cell 2

import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl

try:
    import seaborn as sns
    sns.set_theme(style="whitegrid", palette="colorblind")
    HAS_SNS = True
except ImportError:
    plt.style.use("seaborn-v0_8-whitegrid")
    HAS_SNS = False

mpl.rcParams.update({
    "figure.figsize":    (10, 6),
    "figure.dpi":         120,
    "font.size":           13,
    "axes.titlesize":      15,
    "axes.labelsize":      13,
    "xtick.labelsize":     11,
    "ytick.labelsize":     11,
    "legend.fontsize":     11,
    "legend.framealpha":   0.85,
    "lines.linewidth":      2.0,
    "axes.spines.top":     False,
    "axes.spines.right":   False,
    "savefig.bbox":       "tight",
    "savefig.dpi":         150,
})
np.random.seed(42)
print("Plot setup complete.")

Code cell 3

import numpy as np
import math
import itertools
from fractions import Fraction
from typing import List, Tuple, Dict, Optional, Callable
import random

np.random.seed(42)
random.seed(42)
np.set_printoptions(precision=8, suppress=True)

print('Proof Techniques — Interactive Theory Notebook')
print('=' * 55)
print(f'NumPy {np.__version__}')
print(f'Python math, itertools, fractions loaded')
print(f'All implementations are raw — no proof libraries')

1. Intuition

1.1 What Are Proof Techniques?

A proof technique is a general strategy for establishing that a mathematical statement is true beyond any possible doubt.

  • Unlike empirical evidence (which can always be overturned by a new counterexample) or intuition (which is frequently wrong), a mathematical proof is an absolute guarantee — valid for all cases, for all time
  • Proof techniques are the algorithms of mathematical reasoning: given a goal, which technique do you apply? Given a structure, which argument do you use?
  • Mastering proof techniques means mastering the ability to move from "I believe this is true" to "I can prove this is true and explain exactly why"

For AI: proofs establish correctness of algorithms, validity of bounds, convergence of optimisers, and soundness of theoretical guarantees. Without proof techniques, theoretical ML is inaccessible.

1.2 Why Proof Techniques Matter for AI

| AI Domain | Proof Technique Used | What It Establishes |
| --- | --- | --- |
| Convergence proofs | Telescoping + induction | Gradient descent converges to a stationary point |
| Generalisation bounds | Union bound + concentration | PAC learning, VC dimension, Rademacher complexity |
| Algorithm correctness | Direct proof, induction | BPE terminates, softmax is well-defined |
| Complexity theory | Reduction (contradiction) | Neural net verification is NP-hard |
| LLM reasoning | All techniques | Evaluating whether LLM "proofs" are valid |
| Formal verification | Machine-checked proofs | Lean, Coq, Isabelle verify AI system properties |

1.3 The Landscape of Proof Techniques

PROOF TECHNIQUES
├── Direct Methods
│   ├── Direct proof (assume hypothesis → derive conclusion)
│   ├── Proof by construction (exhibit the object explicitly)
│   └── Proof by exhaustion (check all cases)
├── Indirect Methods
│   ├── Proof by contrapositive (prove ¬Q → ¬P instead of P → Q)
│   ├── Proof by contradiction (assume ¬P → derive ⊥)
│   └── Proof by cases (partition into exhaustive sub-cases)
├── Inductive Methods
│   ├── Mathematical induction (base + step)
│   ├── Strong induction (use all previous cases)
│   └── Structural induction (induct on recursive structure)
├── Probabilistic Methods
│   ├── Probabilistic method (show Pr[property] > 0)
│   ├── Expectation argument (E[X] implies existence)
│   └── Lovász Local Lemma (avoid rare bad events simultaneously)
├── Algebraic Methods
│   ├── Counting arguments (double counting, bijection)
│   ├── Pigeonhole principle (n+1 items in n bins)
│   └── Inclusion-exclusion (count unions precisely)
└── Analytic Methods
    ├── ε-δ arguments (limits, continuity, convergence)
    ├── Compactness arguments (Heine-Borel, sequential)
    └── Fixed point arguments (Banach, Brouwer, Kakutani)

Code cell 5

# ══════════════════════════════════════════════════════════════════
# 1.1–1.3  Proof Technique Landscape — Interactive Classifier
# ══════════════════════════════════════════════════════════════════

# Build a complete taxonomy of proof techniques with metadata
PROOF_TECHNIQUES = {
    'Direct Proof': {
        'category': 'Direct Methods',
        'strategy': 'Assume P, derive Q through logical steps',
        'when_to_use': 'Hypothesis gives concrete objects; conclusion has clear definition',
        'ai_example': 'Attention output lies in convex hull of values',
        'strength': 'Most natural; try first',
        'weakness': 'May get stuck when conclusion is hard to reach directly'
    },
    'Proof by Construction': {
        'category': 'Direct Methods',
        'strategy': 'Exhibit specific object satisfying property',
        'when_to_use': 'Need to prove existence (∃x P(x))',
        'ai_example': 'Construct neural net approximating target function',
        'strength': 'Provides algorithm; most informative existence proof',
        'weakness': 'Finding the construction can be very hard'
    },
    'Proof by Contrapositive': {
        'category': 'Indirect Methods',
        'strategy': 'Prove ¬Q → ¬P (equivalent to P → Q)',
        'when_to_use': '¬Q gives more useful info than P; direct proof stalls',
        'ai_example': 'If robust to adversarial examples then Lipschitz',
        'strength': 'Logically equivalent; opens new proof paths',
        'weakness': 'Must correctly negate both P and Q'
    },
    'Proof by Contradiction': {
        'category': 'Indirect Methods',
        'strategy': 'Assume ¬P, derive contradiction ⊥',
        'when_to_use': 'Impossibility results; ¬P has powerful consequences',
        'ai_example': '√2 is irrational; infinitely many primes',
        'strength': 'Very flexible; works for existence and impossibility',
        'weakness': 'Non-constructive; does not build the object'
    },
    'Proof by Cases': {
        'category': 'Indirect Methods',
        'strategy': 'Partition into exhaustive cases; prove each',
        'when_to_use': 'Natural case split (signs, parity, piecewise functions)',
        'ai_example': 'ReLU subgradient at x=0; piecewise linear networks',
        'strength': 'Handles piecewise behaviour naturally',
        'weakness': 'Number of cases can explode'
    },
    'Mathematical Induction': {
        'category': 'Inductive Methods',
        'strategy': 'Base case + (P(k) → P(k+1))',
        'when_to_use': 'Universal statement ∀n P(n) over natural numbers',
        'ai_example': 'Sum formulas; network depth properties',
        'strength': 'Proves infinitely many cases via finite argument',
        'weakness': 'Must identify right P(n); base case often forgotten'
    },
    'Strong Induction': {
        'category': 'Inductive Methods',
        'strategy': 'Assume P(0)...P(k) all hold; prove P(k+1)',
        'when_to_use': 'P(k+1) depends on multiple previous values',
        'ai_example': 'Prime factorisation; recursive algorithm analysis',
        'strength': 'Stronger hypothesis available in inductive step',
        'weakness': 'May need multiple base cases'
    },
    'Structural Induction': {
        'category': 'Inductive Methods',
        'strategy': 'Base case for atoms; step for composite structures',
        'when_to_use': 'Recursively defined structures (trees, formulas, lists)',
        'ai_example': 'Parse tree properties; recursive tokenisation correctness',
        'strength': 'Matches recursive definition of the structure',
        'weakness': 'Must correctly identify base and recursive cases'
    },
    'Probabilistic Method': {
        'category': 'Probabilistic Methods',
        'strategy': 'Show Pr[property] > 0; conclude object exists',
        'when_to_use': 'Existence proof where construction is unknown',
        'ai_example': 'Random matrices satisfy RIP; JL embeddings exist',
        'strength': 'Proves existence of "good" objects in large spaces',
        'weakness': 'Non-constructive; does not tell you WHICH object'
    },
    'Counting Arguments': {
        'category': 'Algebraic Methods',
        'strategy': 'Count same set two ways; equate counts',
        'when_to_use': 'Combinatorial identities; graph properties',
        'ai_example': 'Sum of degrees = 2|E|; dataset overlap analysis',
        'strength': 'Elegant; reveals hidden structure',
        'weakness': 'Requires finding the right thing to count'
    },
    'Epsilon-Delta': {
        'category': 'Analytic Methods',
        'strategy': 'For every ε>0, find δ>0 such that ...',
        'when_to_use': 'Limits, continuity, convergence',
        'ai_example': 'Gradient descent convergence rate; loss continuity',
        'strength': 'Rigorous treatment of infinite processes',
        'weakness': 'Finding the right δ (or N) can be tricky'
    },
    'Fixed Point Theorems': {
        'category': 'Analytic Methods',
        'strategy': 'Show map is contraction / continuous on compact set',
        'when_to_use': 'Iterative algorithms; equilibrium existence',
        'ai_example': 'Value iteration converges (Bellman is contraction)',
        'strength': 'Guarantees convergence AND uniqueness (Banach)',
        'weakness': 'Contraction constant must be < 1; may be hard to verify'
    }
}

print('PROOF TECHNIQUE TAXONOMY')
print('=' * 75)

# Group by category
categories = {}
for name, info in PROOF_TECHNIQUES.items():
    cat = info['category']
    if cat not in categories:
        categories[cat] = []
    categories[cat].append(name)

for cat, techniques in categories.items():
    print(f'\n┌─ {cat} ({len(techniques)} techniques)')
    for i, tech in enumerate(techniques):
        info = PROOF_TECHNIQUES[tech]
        connector = '└──' if i == len(techniques) - 1 else '├──'
        print(f'│  {connector} {tech}')
        print(f'│      Strategy: {info["strategy"]}')
        print(f'│      AI use:   {info["ai_example"]}')

# Technique selector — given a goal, which technique to try
print('\n' + '=' * 75)
print('TECHNIQUE SELECTOR — Which proof method to try first?')
print('=' * 75)

selection_rules = [
    ('Goal: Prove P → Q',          'Direct Proof',          'Try direct first; switch if stuck'),
    ('Goal: Prove P → Q (stuck)',   'Proof by Contrapositive', 'Negate conclusion; often unlocks path'),
    ('Goal: Prove ∃x P(x)',        'Proof by Construction',  'Build the object explicitly'),
    ('Goal: Prove impossibility',   'Proof by Contradiction', 'Assume possible → derive ⊥'),
    ('Goal: ∀n ∈ ℕ P(n)',          'Mathematical Induction', 'Base case + step'),
    ('Goal: P(n) uses P(j<n)',      'Strong Induction',      'Assume all prior cases'),
    ('Goal: Property of tree/list', 'Structural Induction',  'Match recursive definition'),
    ('Goal: Object exists (large space)', 'Probabilistic Method', 'Random choice has Pr > 0'),
    ('Goal: Combinatorial identity', 'Counting Arguments',   'Count same set two ways'),
    ('Goal: Limit or convergence',  'Epsilon-Delta',         'Given ε, find δ or N'),
    ('Goal: Piecewise behaviour',   'Proof by Cases',        'Partition and handle each'),
    ('Goal: Iteration converges',   'Fixed Point Theorems',  'Show contraction property'),
]

for goal, technique, hint in selection_rules:
    print(f'  {goal:<35}{technique:<25} ({hint})')

print(f'\nTotal techniques catalogued: {len(PROOF_TECHNIQUES)}')

1.4 The Structure of a Mathematical Statement

Every theorem has the form: "If [hypotheses], then [conclusion]"

$$H_1 \wedge H_2 \wedge \ldots \wedge H_n \implies C$$
  • Hypotheses: conditions assumed to hold; the "given" part
  • Conclusion: what must be shown to follow; the "prove" part
  • A proof is a finite sequence of logically valid steps from hypotheses to conclusion
  • Each step follows from: axioms, definitions, previously proved results, or hypotheses
  • Key skill: identifying what you are allowed to assume and what you must derive

1.5 Reading and Writing Proofs

Reading a proof: identify hypotheses → identify conclusion → trace each step → check each logical inference → verify completeness

Writing a proof: state what you are proving → state your strategy → execute strategy → end with explicit conclusion

Common proof structure markers:

| Marker | Purpose |
| --- | --- |
| "Assume…" / "Suppose…" | Introduce a hypothesis or an assumption for an indirect argument |
| "Let…" | Introduce a variable or object |
| "Since…" / "Because…" / "By…" | Justify a step |
| "Therefore…" / "Thus…" / "Hence…" | Conclude a step |
| "We have shown…" / "This completes the proof" | Explicit closure |
| "QED" / "□" / "∎" | End of proof |

1.6 Historical Timeline

| Year | Mathematician | Contribution |
| --- | --- | --- |
| ~300 BCE | Euclid | Elements: first systematic proofs; 467 propositions from 5 axioms |
| ~250 BCE | Archimedes | Method of exhaustion; proto-integration |
| 9th–13th c. | Arab mathematicians | Algebraic proofs; al-Khwarizmi; completing the square |
| 1670s | Newton, Leibniz | Calculus; proofs informal by modern standards |
| 1821 | Cauchy | Rigorous ε-δ definitions of limits; modern analysis begins |
| 1860s | Weierstrass | Formal ε-δ proofs; eliminated intuitive gaps |
| 1874 | Cantor | Diagonal argument; ℝ is uncountable |
| 1900 | Hilbert | Programme to formalise all mathematics |
| 1931 | Gödel | Incompleteness theorems; limits of formal systems |
| 1976 | Appel & Haken | Four-colour theorem; computer-assisted proof |
| 1995 | Wiles | Fermat's Last Theorem; 129-page proof |
| 2000s+ | Coq, Lean, Isabelle | Machine-verified proofs |
| 2024 | AlphaProof (DeepMind) | LLM + Lean; silver medal at IMO 2024 |

Code cell 7

# ══════════════════════════════════════════════════════════════════
# 1.4–1.6  Statement Analyser & Proof Structure Parser
# ══════════════════════════════════════════════════════════════════

def analyse_statement(statement: str, hypotheses: list, conclusion: str) -> dict:
    """Analyse the logical structure of a mathematical statement."""
    result = {
        'statement': statement,
        'form': f'({" ∧ ".join(f"H{i+1}" for i in range(len(hypotheses)))}) → C',
        'hypotheses': hypotheses,
        'conclusion': conclusion,
        'negated_conclusion': f'¬({conclusion})',
        'contrapositive': f'¬({conclusion}) → ¬({" ∧ ".join(hypotheses)})',
        'converse': f'({conclusion}) → ({" ∧ ".join(hypotheses)})',
    }
    return result

# Demonstrate with real mathematical statements
print('STATEMENT STRUCTURE ANALYSIS')
print('=' * 70)

statements = [
    {
        'statement': 'If n² is even, then n is even',
        'hypotheses': ['n² is even'],
        'conclusion': 'n is even'
    },
    {
        'statement': 'If f and g are continuous at a, then f+g is continuous at a',
        'hypotheses': ['f is continuous at a', 'g is continuous at a'],
        'conclusion': 'f+g is continuous at a'
    },
    {
        'statement': 'If f_θ is robust to adversarial examples, then f_θ is Lipschitz',
        'hypotheses': ['f_θ is robust to adversarial examples'],
        'conclusion': 'f_θ is Lipschitz'
    },
    {
        'statement': 'If m and n are even integers, then mn is even',
        'hypotheses': ['m is even', 'n is even'],
        'conclusion': 'mn is even'
    },
]

for s in statements:
    analysis = analyse_statement(**s)
    print(f'\nStatement: {analysis["statement"]}')
    print(f'  Logical form:     {analysis["form"]}')
    for i, h in enumerate(analysis['hypotheses']):
        print(f'    H{i+1}: {h}')
    print(f'    C:  {analysis["conclusion"]}')
    print(f'  Contrapositive:   {analysis["contrapositive"]}')
    print(f'  Converse:         {analysis["converse"]}')
    print(f'  ⚠ Converse ≠ Original (converse is NOT logically equivalent)')

# Historical timeline verification
print('\n' + '=' * 70)
print('PROOF TECHNIQUE HISTORICAL TIMELINE')
print('=' * 70)

timeline = [
    (-300,  'Euclid',        'Axiomatic method; Elements; systematic deduction'),
    (-250,  'Archimedes',    'Method of exhaustion; proto-integration'),
    (820,   'al-Khwarizmi',  'Algebraic proofs; completing the square'),
    (1670,  'Newton/Leibniz','Calculus; informal proofs by modern standards'),
    (1821,  'Cauchy',        'Rigorous ε-δ definitions; modern analysis'),
    (1860,  'Weierstrass',   'Formal ε-δ proofs; eliminated intuitive gaps'),
    (1874,  'Cantor',        'Diagonal argument; ℝ uncountable; set theory'),
    (1900,  'Hilbert',       'Formalisation programme; 23 open problems'),
    (1901,  'Russell',       'Set-theoretic paradox; naive proof inconsistency'),
    (1931,  'Gödel',         'Incompleteness theorems; limits of formal proofs'),
    (1976,  'Appel & Haken', 'Four-colour theorem; computer-assisted proof'),
    (1995,  'Wiles',         "Fermat's Last Theorem; 129-page proof"),
    (2005,  'Lean/Coq',      'Machine-verified proof assistants mature'),
    (2024,  'AlphaProof',    'LLM + Lean achieves IMO silver medal level'),
]

for year, person, contribution in timeline:
    year_str = f'{year} BCE' if year < 0 else f'{year} CE'
    bar_len = max(1, (year + 400) // 80)  # rough visual scale
    print(f'  {year_str:>9}{"█" * min(bar_len, 35):35} {person}: {contribution}')

# Proof reading exercise — identify the structure
print('\n' + '=' * 70)
print('PROOF READING EXERCISE — Identify the markers')
print('=' * 70)

sample_proof = [
    ('Assume',     'n is even.',                          'hypothesis'),
    ('Let',        'n = 2k for some integer k.',          'definition'),
    ('Then',       'n² = (2k)² = 4k² = 2(2k²).',        'algebraic step'),
    ('Since',      '2k² is an integer,',                  'justification'),
    ('Therefore',  'n² is of the form 2m.',               'intermediate conclusion'),
    ('Hence',      'n² is even.',                         'final conclusion'),
    ('□',          '',                                     'end of proof'),
]

print('\nSample proof: "If n is even, then n² is even"\n')
for marker, content, role in sample_proof:
    print(f'  [{role:>24}]  {marker} {content}')

print('\n✓ Every step is justified; every marker serves a purpose')
print('✓ Proof structure: hypothesis → definition → algebra → conclusion')

2. Direct Proof

2.1 The Strategy

To prove $P \implies Q$:

  1. Assume $P$ is true
  2. Through a sequence of logical steps, each justified by definitions, axioms, or previously proved results, derive $Q$
  3. Conclude that $Q$ follows from $P$

Most natural proof technique — try it first for any implication. Works best when the hypothesis gives you something concrete to work with and the conclusion has a clear definition to aim for.

2.2 Template

Proof:
  Assume [P].
  [Step 1: apply definition/theorem to get intermediate result]
  [Step 2: algebraic or logical manipulation]
  ...
  [Final step: arrive at Q]
  Therefore [Q]. □
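
A minimal sketch of this template in code (a toy theorem with a hypothetical `direct_proof_sum_of_odds` helper, not one of the notebook's cells): to prove "odd + odd is even" directly, extract the witnesses j and k supplied by the hypotheses, then exhibit the integer that certifies the conclusion.

def direct_proof_sum_of_odds(m: int, n: int) -> int:
    """Direct proof trace: if m and n are odd, then m + n is even.
    Returns the witness w with m + n = 2w."""
    assert m % 2 == 1 and n % 2 == 1       # Assume P (the hypotheses)
    j, k = (m - 1) // 2, (n - 1) // 2      # definitions: m = 2j+1, n = 2k+1
    w = j + k + 1                          # m + n = (2j+1) + (2k+1) = 2(j+k+1)
    assert m + n == 2 * w                  # arrive at Q: m + n is even
    return w

print(direct_proof_sum_of_odds(3, 7))      # 5, since 3 + 7 = 2×5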

2.3 Worked Example — Even Times Even Is Even

Theorem: If $m$ and $n$ are both even integers, then $mn$ is even.

Proof:

  • Assume $m$ and $n$ are both even
  • By definition of even: $\exists j \in \mathbb{Z}: m = 2j$ and $\exists k \in \mathbb{Z}: n = 2k$
  • Then $mn = (2j)(2k) = 4jk = 2(2jk)$
  • Since $2jk \in \mathbb{Z}$, $mn$ is of the form $2 \times (\text{integer})$
  • Therefore $mn$ is even $\square$

Code cell 9

# ══════════════════════════════════════════════════════════════════
# 2.1–2.3  Direct Proof — Even × Even = Even (Computational Verification)
# ══════════════════════════════════════════════════════════════════

def is_even(n: int) -> bool:
    """Check if integer is even by definition: n = 2k for some integer k."""
    return n % 2 == 0

def direct_proof_even_times_even(m: int, n: int) -> dict:
    """
    Verify the direct proof that even × even = even.
    
    Proof structure:
      H1: m is even  →  m = 2j for some j ∈ ℤ
      H2: n is even  →  n = 2k for some k ∈ ℤ
      Then: mn = (2j)(2k) = 4jk = 2(2jk)
      Since 2jk ∈ ℤ, mn is even  □
    """
    # Step 1: Verify hypotheses
    assert is_even(m), f'Hypothesis violated: {m} is not even'
    assert is_even(n), f'Hypothesis violated: {n} is not even'
    
    # Step 2: Extract witnesses (the j and k from the definition)
    j = m // 2  # m = 2j
    k = n // 2  # n = 2k
    
    # Step 3: Compute product and verify structure
    product = m * n
    product_via_formula = 4 * j * k  # (2j)(2k) = 4jk
    product_even_form = 2 * (2 * j * k)  # 4jk = 2(2jk)
    
    # Step 4: Verify conclusion
    witness = 2 * j * k  # The integer such that mn = 2 × witness
    
    return {
        'm': m, 'n': n, 'j': j, 'k': k,
        'product': product,
        'product_via_formula': product_via_formula,
        'even_form': product_even_form,
        'witness': witness,
        'conclusion_holds': is_even(product)
    }

print('DIRECT PROOF: Even × Even = Even')
print('=' * 65)
print(f'{"m":>6} {"n":>6}{"j":>4} {"k":>4}{"mn":>8} = 2×{"(2jk)":>8} │ Even?')
print('─' * 65)

test_pairs = [(2, 4), (6, 8), (10, 12), (100, 200), (-4, 6), (0, 42), (14, 22)]
all_pass = True
for m, n in test_pairs:
    result = direct_proof_even_times_even(m, n)
    status = '✓' if result['conclusion_holds'] else '✗'
    print(f'{m:>6} {n:>6}{result["j"]:>4} {result["k"]:>4} │ '
          f'{result["product"]:>8} = 2×{result["witness"]:>8}{status}')
    all_pass = all_pass and result['conclusion_holds']

print(f'\n{"✓ All cases verified" if all_pass else "✗ Some cases failed"}')

# Exhaustive verification for small range
print('\n--- Exhaustive verification for even integers in [-20, 20] ---')
count = 0
failures = 0
for m in range(-20, 21, 2):  # even numbers only
    for n in range(-20, 21, 2):
        count += 1
        if not is_even(m * n):
            failures += 1
            print(f'  COUNTEREXAMPLE: {m} × {n} = {m*n} is not even!')

print(f'  Tested {count} even×even products: {failures} failures')
print(f'  ✓ Empirical verification supports the proof')
print(f'  ⚠ But this does NOT prove it — the proof above does (for ALL integers)')

2.4 When Direct Proof Gets Stuck — n² Even ⟹ n Even

Theorem: If $n^2$ is even then $n$ is even.

Attempted direct proof: assume $n^2$ is even; then $n^2 = 2k$ for some $k$; so $n = \sqrt{2k}$ — stuck: nothing about $\sqrt{2k}$ makes it obvious that $n$ is even, or even an integer.

Lesson: when direct proof stalls, switch to contrapositive (Section 4) or contradiction (Section 5).

2.5 Worked Example — Sum of Continuous Functions is Continuous

Theorem: If $f$ and $g$ are continuous at $a$, then $f + g$ is continuous at $a$.

Proof (ε-δ):

  • Let $\varepsilon > 0$; we need $\delta > 0$ such that $|x-a| < \delta \implies |(f+g)(x) - (f+g)(a)| < \varepsilon$
  • By continuity of $f$: $\exists \delta_1 > 0: |x-a| < \delta_1 \implies |f(x)-f(a)| < \varepsilon/2$
  • By continuity of $g$: $\exists \delta_2 > 0: |x-a| < \delta_2 \implies |g(x)-g(a)| < \varepsilon/2$
  • Let $\delta = \min(\delta_1, \delta_2)$
  • Then $|x-a| < \delta$ implies:

$$|(f+g)(x)-(f+g)(a)| = |f(x)+g(x)-f(a)-g(a)| \leq |f(x)-f(a)| + |g(x)-g(a)| < \frac{\varepsilon}{2} + \frac{\varepsilon}{2} = \varepsilon$$

Therefore $f+g$ is continuous at $a$. $\square$

2.6 Direct Proof in AI Contexts

Attention output lies in the convex hull of the values:

  • $\alpha_{ij} \geq 0$ for all $j$ (softmax outputs are non-negative)
  • $\sum_j \alpha_{ij} = 1$ (softmax outputs sum to 1)
  • $O_i = \sum_j \alpha_{ij} V_j$ (weighted sum)
  • By definition of convex hull: $O_i \in \text{conv}(\{V_j\})$ $\square$

Cross-entropy loss is convex in the logits:

  • $L(\mathbf{z}) = -z_y + \log \sum_v \exp(z_v)$
  • $-z_y$ is linear (hence convex); $\log \sum_v \exp(z_v)$ is log-sum-exp (convex)
  • A sum of convex functions is convex; therefore $L$ is convex $\square$

Code cell 11

# ══════════════════════════════════════════════════════════════════
# 2.4–2.6  Direct Proof in AI — Continuity & Convex Hull Verification
# ══════════════════════════════════════════════════════════════════

# --- 2.4: Show that direct proof of "n² even → n even" gets stuck ---
print('2.4  DIRECT PROOF LIMITATION: n² even → n even')
print('=' * 65)
print('Attempting direct proof...')
print('  Assume n² is even: n² = 2k for some integer k')
print('  Then n = √(2k) ... but √(2k) is not obviously even or integer!')
print('  ✗ STUCK — direct proof does not work here')
print('  → Solution: use contrapositive (Section 4): "n odd → n² odd"')

# Empirical evidence that the theorem IS true
print('\nEmpirical check (not a proof!):')
stuck_count = 0
for n in range(-50, 51):
    if is_even(n * n) and not is_even(n):
        print(f'  Counterexample found: n={n}, n²={n*n}')
        stuck_count += 1
print(f'  Checked n ∈ [-50, 50]: {stuck_count} counterexamples')
print(f'  Theorem appears true, but direct proof fails → need different technique')

# --- 2.5: Sum of continuous functions is continuous (numerical demo) ---
print('\n' + '=' * 65)
print('2.5  SUM OF CONTINUOUS FUNCTIONS IS CONTINUOUS')
print('=' * 65)

def f(x):
    """f(x) = sin(x) — continuous everywhere."""
    return np.sin(x)

def g(x):
    """g(x) = x² — continuous everywhere."""
    return x ** 2

def verify_continuity_at_point(func, a, epsilon, func_name='f'):
    """
    Find δ such that |x-a| < δ → |func(x) - func(a)| < ε
    Returns the δ and verification status.
    """
    # Try decreasing δ values until we find one that works
    for delta_exp in range(1, 20):
        delta = 10.0 ** (-delta_exp)
        # Test many points in (a-δ, a+δ)
        test_points = np.linspace(a - delta, a + delta, 1000)
        deviations = np.abs(func(test_points) - func(a))
        max_deviation = np.max(deviations)
        if max_deviation < epsilon:
            return delta, max_deviation, True
    return None, None, False

a = 1.0  # point of continuity

print(f'\nVerify at a = {a}:')
print(f'  f(x) = sin(x),  g(x) = x²,  (f+g)(x) = sin(x) + x²')
print(f'{"ε":>10}{"δ_f":>12} {"δ_g":>12} {"δ=min":>12}{"max|Δ(f+g)|":>14} {"< ε?":>6}')
print('─' * 75)

for epsilon in [1.0, 0.1, 0.01, 0.001, 0.0001]:
    delta_f, dev_f, ok_f = verify_continuity_at_point(f, a, epsilon / 2, 'f')
    delta_g, dev_g, ok_g = verify_continuity_at_point(g, a, epsilon / 2, 'g')
    
    if ok_f and ok_g:
        delta = min(delta_f, delta_g)
        # Verify f+g with this delta
        test_x = np.linspace(a - delta, a + delta, 1000)
        fg_dev = np.max(np.abs((f(test_x) + g(test_x)) - (f(a) + g(a))))
        status = '✓' if fg_dev < epsilon else '✗'
        print(f'{epsilon:>10.4f}{delta_f:>12.2e} {delta_g:>12.2e} {delta:>12.2e}{fg_dev:>14.2e} {status:>6}')

print('\n✓ Triangle inequality: |Δ(f+g)| ≤ |Δf| + |Δg| < ε/2 + ε/2 = ε')

# --- 2.6: Attention output lies in convex hull of values ---
print('\n' + '=' * 65)
print('2.6  DIRECT PROOF: Attention Output ∈ Convex Hull of Values')
print('=' * 65)

# Build attention mechanism from scratch
seq_len = 5
d_model = 4

# Random value vectors
V = np.random.randn(seq_len, d_model)

# Random attention weights (must be non-negative, sum to 1)
raw_scores = np.random.randn(seq_len)
alpha = np.exp(raw_scores) / np.sum(np.exp(raw_scores))  # softmax

print(f'\nValue vectors V (shape {V.shape}):')
for i in range(seq_len):
    print(f'  V[{i}] = [{", ".join(f"{v:.3f}" for v in V[i])}]')

print(f'\nAttention weights α (softmax output):')
print(f'  α = [{", ".join(f"{a:.4f}" for a in alpha)}]')
print(f'  Property 1: all αᵢ ≥ 0? {np.all(alpha >= 0)} ✓')
print(f'  Property 2: Σαᵢ = {np.sum(alpha):.10f} ≈ 1? {np.isclose(np.sum(alpha), 1.0)} ✓')

# Compute attention output
O = np.zeros(d_model)
for j in range(seq_len):
    O += alpha[j] * V[j]

print(f'\nAttention output O = Σⱼ αⱼ Vⱼ:')
print(f'  O = [{", ".join(f"{o:.3f}" for o in O)}]')

# Verify O is in convex hull: O should be componentwise between min and max of V
print(f'\nConvex hull verification (componentwise bounds):')
for d in range(d_model):
    v_min = np.min(V[:, d])
    v_max = np.max(V[:, d])
    in_hull = v_min <= O[d] <= v_max
    print(f'  dim {d}: V_min={v_min:>7.3f}, O={O[d]:>7.3f}, V_max={v_max:>7.3f}  '
          f'{"✓ in range" if in_hull else "✗ OUT OF RANGE"}')

print('\n✓ By definition: non-negative weights summing to 1 → convex combination')
print('✓ Attention output is ALWAYS in convex hull of value vectors')

# --- Cross-entropy convexity verification ---
print('\n' + '=' * 65)
print('2.6  DIRECT PROOF: Cross-Entropy Loss is Convex')
print('=' * 65)

def cross_entropy_loss(z, y):
    """L(z) = -z_y + log Σ exp(z_v)"""
    return -z[y] + np.log(np.sum(np.exp(z - np.max(z)))) + np.max(z)  # numerically stable

# Verify convexity: L(λz₁ + (1-λ)z₂) ≤ λL(z₁) + (1-λ)L(z₂) for all λ ∈ [0,1]
n_classes = 5
y = 2  # true class
z1 = np.random.randn(n_classes)
z2 = np.random.randn(n_classes)

print(f'\nConvexity check: L(λz₁ + (1-λ)z₂) ≤ λL(z₁) + (1-λ)L(z₂)?')
print(f'  z₁ = [{", ".join(f"{v:.2f}" for v in z1)}]')
print(f'  z₂ = [{", ".join(f"{v:.2f}" for v in z2)}]')
print(f'  True class y = {y}')
print(f'{"λ":>5}{"L(mix)":>10} {"≤":>3} {"λL(z₁)+(1-λ)L(z₂)":>22} │ Convex?')
print('─' * 55)

convex_holds = True
for lam in np.linspace(0, 1, 11):
    z_mix = lam * z1 + (1 - lam) * z2
    L_mix = cross_entropy_loss(z_mix, y)
    L_bound = lam * cross_entropy_loss(z1, y) + (1 - lam) * cross_entropy_loss(z2, y)
    ok = L_mix <= L_bound + 1e-10  # small numerical tolerance
    convex_holds = convex_holds and ok
    print(f'{lam:>5.2f}{L_mix:>10.4f} {"≤":>3} {L_bound:>22.4f}{"✓" if ok else "✗"}')

print(f'\n{"✓ Cross-entropy IS convex in logits" if convex_holds else "✗ Convexity violated!"}')

3. Proof by Construction

3.1 The Strategy

To prove $\exists x\, P(x)$: exhibit a specific $x$ and verify that $P(x)$ holds.

  • Constructive proof: provides an algorithm; the proof itself tells you how to find the object
  • Non-constructive proof: proves existence without exhibiting the object (see Section 5)
  • Constructive proofs are stronger: they tell you not just that something exists, but how to find it

3.2 Template

Proof:
  [Describe or define the object x explicitly]
  We claim x satisfies P.
  [Verify each required property of x]
  Therefore ∃x satisfying P. □
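
A quick instance of the template (an illustrative sketch with a hypothetical `between` helper, not from the notebook): between any two distinct rationals there is another rational; the proof constructs the midpoint explicitly, then verifies the required property.

from fractions import Fraction

def between(a: Fraction, b: Fraction) -> Fraction:
    """Constructive existence proof: exhibit x with a < x < b (given a < b)."""
    x = (a + b) / 2     # the explicit object
    assert a < x < b    # verify the required property
    return x

print(between(Fraction(1, 3), Fraction(2, 5)))  # 11/30: a concrete witness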

3.3 Worked Example — Constructing a Bijection ℤ → ℕ

Theorem: $|\mathbb{Z}| = |\mathbb{N}|$ (the integers and the natural numbers have the same cardinality).

Proof by construction: exhibit an explicit bijection $f: \mathbb{N} \to \mathbb{Z}$:

  • $f(0) = 0$
  • $f(2k-1) = k$ for $k \geq 1$ (odd naturals → positive integers)
  • $f(2k) = -k$ for $k \geq 1$ (even positive naturals → negative integers)
  • Verify: injective (different naturals map to different integers) and surjective (every integer is hit)
  • Therefore a bijection exists; $|\mathbb{Z}| = |\mathbb{N}|$ $\square$

3.4 Worked Example — Neural Network Approximation

Theorem: For the target function $f(x) = \mathbb{1}[x > 0.5]$, there exists a 1-hidden-layer ReLU network approximating $f$ to within $\varepsilon = 0.01$ on $[0, 0.49] \cup [0.51, 1]$.

Proof by construction:

  • Let $h(x) = \text{ReLU}(M(x - 1/2)) - \text{ReLU}(M(x - 1/2) - 1)$ for large $M$ (a ramp of width $1/M$ starting at $x = 0.5$)
  • $h(x) = 0$ for $x \leq 0.5$ and $h(x) = 1$ for $x \geq 0.5 + 1/M$; as $M \to \infty$, $h \to \mathbb{1}[x > 0.5]$ pointwise except at the jump
  • For $M = 100$: $|h(x) - f(x)| = 0 < 0.01$ for $x \in [0, 0.49] \cup [0.51, 1]$
  • Explicit construction: $W_1 = [M, M]^\top,\; b_1 = [-M/2,\, -M/2 - 1],\; W_2 = [1, -1],\; b_2 = 0$
  • This constructs a network with the desired approximation property $\square$

Code cell 13

# ══════════════════════════════════════════════════════════════════
# 3.1–3.4  Proof by Construction — Bijection & Neural Net Approximator
# ══════════════════════════════════════════════════════════════════

# --- 3.3: Constructive bijection ℤ → ℕ ---
print('3.3  CONSTRUCTIVE PROOF: Bijection f: ℕ → ℤ')
print('=' * 65)

def nat_to_int(n: int) -> int:
    """Explicit bijection f: ℕ → ℤ constructed in the proof."""
    if n == 0:
        return 0
    elif n % 2 == 1:  # odd: f(2k-1) = k
        k = (n + 1) // 2
        return k
    else:             # even > 0: f(2k) = -k
        k = n // 2
        return -k

def int_to_nat(z: int) -> int:
    """Inverse bijection f⁻¹: ℤ → ℕ (needed to prove surjectivity)."""
    if z == 0:
        return 0
    elif z > 0:
        return 2 * z - 1  # f(2k-1) = k → f⁻¹(k) = 2k-1
    else:
        return -2 * z      # f(2k) = -k → f⁻¹(-k) = 2k

# Show the mapping
print(f'{"n (ℕ)":>8}{"f(n) (ℤ)":>10}     {"z (ℤ)":>8}{"f⁻¹(z) (ℕ)":>12}')
print('─' * 55)
for n in range(13):
    z = nat_to_int(n)
    # Also show reverse
    z_test = n - 6  # test integers from -6 to 6
    n_back = int_to_nat(z_test)
    print(f'{n:>8}{z:>10}     {z_test:>8}{n_back:>12}')

# Verify bijectivity
print('\nVerifying injectivity (no collisions):')
N = 100
images = set()
injective = True
for n in range(N):
    z = nat_to_int(n)
    if z in images:
        print(f'  ✗ Collision: f({n}) = {z} already seen!')
        injective = False
    images.add(z)
print(f'  Mapped ℕ[0..{N-1}] to {len(images)} distinct integers: '
      f'{"✓ injective" if injective else "✗ NOT injective"}')

print('\nVerifying surjectivity (every integer is hit):')
surjective = True
for z in range(-50, 51):
    n = int_to_nat(z)
    if nat_to_int(n) != z:
        print(f'  ✗ f(f⁻¹({z})) = {nat_to_int(n)}{z}')
        surjective = False
print(f'  All integers in [-50, 50] have preimage: '
      f'{"✓ surjective" if surjective else "✗ NOT surjective"}')

print(f'\n✓ f is {"bijective" if injective and surjective else "NOT bijective"} → |ℤ| = |ℕ|')

# --- 3.4: Neural network approximation by construction ---
print('\n' + '=' * 65)
print('3.4  CONSTRUCTIVE PROOF: 1-Layer ReLU Network Approximates Step Function')
print('=' * 65)

def relu(x):
    """ReLU activation: max(0, x)"""
    return np.maximum(0, x)

def step_function(x):
    """Target: f(x) = 1[x > 0.5]"""
    return (x > 0.5).astype(float)

def constructed_network(x, M):
    """
    Constructive proof: h(x) = ReLU(M(x-1/2)) - ReLU(M(x-1/2) - 1)

    Network architecture (two hidden units):
      W1 = [M, M],  b1 = [-M/2, -M/2 - 1],  W2 = [1, -1],  b2 = 0
      h(x) = W2 · ReLU(W1 · x + b1) + b2
    A ramp of width 1/M: h = 0 for x ≤ 1/2, h = 1 for x ≥ 1/2 + 1/M.
    """
    W1 = np.array([M, M], dtype=float)
    b1 = np.array([-M / 2, -M / 2 - 1])
    W2 = np.array([1.0, -1.0])
    b2 = 0.0
    hidden = relu(np.outer(x, W1) + b1)  # shape (n_points, 2)
    output = hidden @ W2 + b2
    return output

x = np.linspace(0, 1, 1001)
target = step_function(x)

print(f'\nTarget: f(x) = 1[x > 0.5]')
print(f'Network: h(x) = ReLU(M(x-1/2)) - ReLU(M(x-1/2) - 1)')
print(f'\n{"M":>6}{"max error":>10} {"max err (excl ±0.01)":>22}{"≤ 0.01?":>8}')
print('─' * 55)

for M in [1, 5, 10, 50, 100, 500, 1000]:
    h = constructed_network(x, M)
    error = np.abs(h - target)
    max_err = np.max(error)
    
    # Error excluding the transition region [0.49, 0.51]
    mask = (x <= 0.49) | (x >= 0.51)
    max_err_excl = np.max(error[mask]) if np.any(mask) else 0.0
    
    status = '✓' if max_err_excl <= 0.01 else '✗'
    print(f'{M:>6}{max_err:>10.6f} {max_err_excl:>22.6f}{status:>8}')

print(f'\n✓ For M=100: max error < 0.01 outside transition region')
print(f'  Explicit construction: W₁=[100,100], b₁=[-50,-51], W₂=[1,-1], b₂=0')
print(f'  This IS the proof — we built the network, verified the bound')

# --- 3.5–3.6: Constructive vs Non-Constructive ---
print('\n' + '=' * 65)
print('3.5–3.6  CONSTRUCTIVE vs NON-CONSTRUCTIVE EXISTENCE')
print('=' * 65)

print('''
┌──────────────────┬──────────────────────────────────────────────────┐
│ Constructive     │ "Here is the object; here is why it works"      │
│                  │ → Provides algorithm; enables implementation    │
│                  │ → Example: bijection f(n) above; network h(x)  │
├──────────────────┼──────────────────────────────────────────────────┤
│ Non-Constructive │ "Such an object must exist (but I won't show    │
│                  │   you which one)"                               │
│                  │ → Often shorter; may be only available proof    │
│                  │ → Example: IVT proves ∃c: f(c)=0 without       │
│                  │   constructing c                                │
├──────────────────┼──────────────────────────────────────────────────┤
│ AI Relevance     │ Constructive = algorithm you can implement      │
│                  │ Non-constructive = bound you know holds         │
│                  │ Both are valid proofs; constructive is stronger │
└──────────────────┴──────────────────────────────────────────────────┘
''')

# Demonstrate with Intermediate Value Theorem
def f_ivt(x):
    return x**3 - x - 1  # continuous; f(1) = -1 < 0, f(2) = 5 > 0

print('Non-constructive existence (IVT): x³ - x - 1 = 0 has root in [1,2]')
print(f'  f(1) = {f_ivt(1)} < 0')
print(f'  f(2) = {f_ivt(2)} > 0')
print(f'  By IVT: ∃c ∈ (1,2): f(c) = 0  ← existence proved, c not constructed')
print(f'  (Bisection can FIND it: c ≈ {1.3247179572:.10f}, but IVT doesn\'t give this)')

4. Proof by Contrapositive

4.1 The Strategy

Logical equivalence: $(P \implies Q) \equiv (\neg Q \implies \neg P)$

To prove $P \implies Q$, equivalently prove $\neg Q \implies \neg P$.

When to use: when the negation of $Q$ gives more useful information than $P$ itself.

Critical distinction from contradiction: contrapositive proves $\neg Q \implies \neg P$; contradiction assumes $\neg P$ and derives any false statement $\bot$.

4.2 Template

Proof (by contrapositive):
  We prove the contrapositive: assume ¬Q.
  [Derive ¬P through logical steps]
  Therefore ¬P.
  By contrapositive equivalence, P → Q. □
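
The equivalence behind this template can itself be checked exhaustively over truth values. The small truth table below (an added sketch, not one of the notebook's cells) confirms that P → Q and ¬Q → ¬P agree on every row, while the converse Q → P does not.

import itertools

def implies(a: bool, b: bool) -> bool:
    return (not a) or b

print(f'{"P":>6}{"Q":>7}{"P→Q":>7}{"¬Q→¬P":>8}{"Q→P":>6}')
for P, Q in itertools.product([False, True], repeat=2):
    print(f'{str(P):>6}{str(Q):>7}{str(implies(P, Q)):>7}'
          f'{str(implies(not Q, not P)):>8}{str(implies(Q, P)):>6}')
# Columns P→Q and ¬Q→¬P match on all four rows (logical equivalence);
# Q→P differs on the row P=False, Q=True (the converse is NOT equivalent).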

4.3 Worked Example — If n² is Even then n is Even

Theorem: For $n \in \mathbb{Z}$, if $n^2$ is even then $n$ is even.

Contrapositive: if $n$ is odd then $n^2$ is odd.

Proof (by contrapositive):

  • Assume $n$ is odd: $\exists k \in \mathbb{Z}: n = 2k + 1$
  • Then $n^2 = (2k+1)^2 = 4k^2 + 4k + 1 = 2(2k^2 + 2k) + 1$
  • Since $2k^2 + 2k \in \mathbb{Z}$, $n^2$ is of the form $2m+1$; therefore $n^2$ is odd
  • By contrapositive: if $n^2$ is even then $n$ is even $\square$

4.4–4.5 Contrapositive in AI Contexts

Theorem: If a neural network $f_\theta$ is not Lipschitz, then it is not robust to adversarial examples.

Contrapositive (easier to prove): If $f_\theta$ is robust to adversarial examples (every perturbation $\|\delta\| \leq \delta_0$ changes the output by $< \varepsilon$), then $f_\theta$ is Lipschitz with constant $L = \varepsilon / \delta_0$.

Theorem: If $\nabla L(\theta) \neq 0$, then $\theta$ is not a local minimum.

Contrapositive: If $\theta$ is a local minimum, then $\nabla L(\theta) = 0$ (the first-order necessary condition for optimality); a numerical sketch follows below.
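
A numerical companion to the second contrapositive (an illustrative sketch using a hypothetical quadratic loss, not one of the notebook's cells): wherever the finite-difference gradient is nonzero, a small gradient step strictly decreases the loss, so that point cannot be a local minimum.

L = lambda t: (t - 2.0) ** 2 + 1.0                        # smooth loss, minimum at θ = 2
grad = lambda t, h=1e-6: (L(t + h) - L(t - h)) / (2 * h)  # central finite difference

for theta in [2.0, 0.5]:
    g = grad(theta)
    if abs(g) > 1e-4:                     # ∇L(θ) ≠ 0: exhibit a strictly better point
        step = theta - 0.1 * g
        print(f'θ={theta}: ∇L≈{g:+.3f} → L({step:.2f})={L(step):.4f} '
              f'< L(θ)={L(theta):.4f} → not a local minimum')
    else:                                 # contrapositive: local minimum → ∇L(θ) = 0
        print(f'θ={theta}: ∇L≈{g:+.3f} → gradient vanishes, as a local minimum requires')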

4.6 Recognising When to Use Contrapositive

| Sign | Example |
| --- | --- |
| Conclusion $Q$ involves a negation | "x is not…", "cannot be…" |
| Hypothesis $P$ involves existence | Negating $Q$ gives a concrete object to work with |
| Direct proof gets stuck | Reformulating as $\neg Q \implies \neg P$ opens new paths |
| Test: write out $\neg P$ and $\neg Q$ | If $\neg Q \implies \neg P$ seems more natural, use contrapositive |

Code cell 15

# ══════════════════════════════════════════════════════════════════
# 4.1–4.6  Proof by Contrapositive — Complete Verification
# ══════════════════════════════════════════════════════════════════

# --- 4.3: If n² is even then n is even (via contrapositive) ---
print('4.3  CONTRAPOSITIVE PROOF: n² even → n even')
print('=' * 65)
print('Original:       P → Q  :  n² even → n even')
print('Contrapositive:  ¬Q → ¬P:  n odd → n² odd')
print()

# Prove the contrapositive computationally
print('Verifying contrapositive: if n is odd, then n² is odd')
print(f'{"n":>6} {"n=2k+1":>10} {"n²":>10} {"n²=2m+1":>12} {"n² odd?":>8}')
print('─' * 50)

all_ok = True
for n in range(-11, 12, 2):  # odd numbers
    k = (n - 1) // 2  # n = 2k + 1 (holds for negative odd n as well)
    n_sq = n * n
    m = (n_sq - 1) // 2
    odd = n_sq % 2 == 1
    all_ok = all_ok and odd
    if abs(n) <= 9:
        print(f'{n:>6} {"= 2("+str(k)+")+1":>10} {n_sq:>10} {"= 2("+str(m)+")+1":>12} {"✓" if odd else "✗":>8}')

print(f'\nAll odd n in [-11,11]: n² odd? {"✓ yes" if all_ok else "✗ no"}')
print(f'Contrapositive verified → original theorem holds')

# Algebraic verification
print(f'\nAlgebraic proof trace:')
print(f'  Let n = 2k + 1 (n is odd)')
print(f'  n² = (2k+1)² = 4k² + 4k + 1 = 2(2k² + 2k) + 1')
for k in range(5):
    n = 2 * k + 1
    n_sq = n * n
    witness = 2 * k * k + 2 * k
    print(f'    k={k}: n={n}, n²={n_sq}, 2({witness})+1 = {2*witness+1} ✓')

# --- 4.4-4.5: Contrapositive in AI — Lipschitz & Robustness ---
print('\n' + '=' * 65)
print('4.4  CONTRAPOSITIVE IN AI: Robustness → Lipschitz')
print('=' * 65)

def random_network_1d(x, weights, biases):
    """Simple 1-hidden-layer network."""
    h = np.maximum(0, weights[0] * x + biases[0])  # ReLU
    return float((weights[1] * h + biases[1]).item())

# Lipschitz network (bounded weights)
W_lip = [np.array([2.0]), np.array([1.0])]
b_lip = [np.array([0.0]), np.array([0.0])]

# Non-Lipschitz-like behaviour (very large weights)
W_big = [np.array([1000.0]), np.array([1000.0])]
b_big = [np.array([0.0]), np.array([0.0])]

print(f'\nContrapositive: if f_θ is robust → f_θ is Lipschitz')
print(f'Equivalently:   if f_θ not Lipschitz → f_θ not robust')
print()

delta_0 = 0.01  # perturbation budget
x0 = 0.5

for name, W, b in [('Small weights (L≈2)', W_lip, b_lip), 
                     ('Large weights (L≈10⁶)', W_big, b_big)]:
    f_x0 = random_network_1d(x0, W, b)
    f_perturbed = random_network_1d(x0 + delta_0, W, b)
    output_change = abs(f_perturbed - f_x0)
    
    # Estimate Lipschitz constant
    xs = np.linspace(0, 1, 10000)
    ys = np.array([random_network_1d(x, W, b) for x in xs])
    diffs = np.abs(np.diff(ys)) / np.abs(np.diff(xs))
    L_est = np.max(diffs)
    
    robust = output_change < 0.1  # robustness threshold
    print(f'{name}:')
    print(f'  |f(x+δ) - f(x)| = {output_change:.6f}  (δ={delta_0})')
    print(f'  Estimated Lipschitz L ≈ {L_est:.1f}')
    print(f'  Robust? {"✓ yes" if robust else "✗ no"}  |  Lipschitz bounded? {"✓" if L_est < 100 else "✗"}')
    print()

# --- 4.6: Contrapositive generator ---
print('=' * 65)
print('4.6  CONTRAPOSITIVE GENERATOR')
print('=' * 65)

statements = [
    ('n² is even', 'n is even', 'n is odd', 'n² is odd'),
    ('f is differentiable at a', 'f is continuous at a', 
     'f is not continuous at a', 'f is not differentiable at a'),
    ('∇L(θ) ≠ 0', 'θ is not a local minimum',
     'θ is a local minimum', '∇L(θ) = 0'),
    ('AB = 0 and A is invertible', 'B = 0',
     'B ≠ 0', 'AB ≠ 0 or A is not invertible'),
]

print(f'\n{"Original P → Q":>45}{"Contrapositive ¬Q → ¬P":>45}')
print('─' * 95)
for P, Q, neg_Q, neg_P in statements:
    print(f'{"If " + P + " then " + Q:>45}{"If " + neg_Q + " then " + neg_P:>45}')

print(f'\n✓ Original and contrapositive are LOGICALLY EQUIVALENT')
print(f'✗ Converse (Q → P) is NOT equivalent — common mistake!')

5. Proof by Contradiction

5.1 The Strategy

To prove $P$: assume $\neg P$; derive a contradiction (a statement known to be false, or a statement $C$ and its negation $\neg C$ simultaneously).

Valid by the law of excluded middle: $P \vee \neg P$; if $\neg P$ leads to $\bot$, then $P$ must hold.

When to use: when $P$ has the form "X does not exist" or "X is impossible"; when no positive construction is available; when the assumption $\neg P$ has powerful consequences.

5.2 Template

Proof (by contradiction):
  Assume for contradiction that ¬P.
  [Derive consequences of ¬P]
  [Arrive at a statement C and its negation ¬C]
  This is a contradiction.
  Therefore P must hold. □
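
A toy instance of the template (a sketch with a hypothetical helper, not from the notebook): there is no smallest positive rational, because if q were the smallest then q/2 would be a positive rational strictly below it.

from fractions import Fraction

def refute_smallest_positive_rational(q: Fraction) -> None:
    """Assume for contradiction that q is the smallest positive rational."""
    assert q > 0                  # part of the assumption
    smaller = q / 2               # derive a consequence of the assumption
    assert 0 < smaller < q        # C and ¬C: q was minimal, yet smaller < q
    print(f'claimed minimum {q}, but {smaller} < {q} → contradiction ⊥')

for q in [Fraction(1, 10**6), Fraction(1, 10**18)]:
    refute_smallest_positive_rational(q)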

5.3 Worked Example — √2 is Irrational

Theorem: $\sqrt{2} \notin \mathbb{Q}$

Proof (by contradiction):

  • Assume for contradiction $\sqrt{2} \in \mathbb{Q}$
  • Then $\sqrt{2} = p/q$ for some $p, q \in \mathbb{Z}$, $q \neq 0$, with $\gcd(p,q) = 1$
  • Then $2 = p^2/q^2$; so $p^2 = 2q^2$
  • $p^2$ is even; by Section 4's result, $p$ is even; so $p = 2k$
  • Then $(2k)^2 = 2q^2$; $4k^2 = 2q^2$; $q^2 = 2k^2$
  • $q^2$ is even; by the same result, $q$ is even
  • But $p$ and $q$ both even contradicts $\gcd(p,q) = 1$
  • Contradiction; therefore $\sqrt{2} \notin \mathbb{Q}$ $\square$

5.4 Worked Example — Infinitely Many Primes (Euclid)

Theorem: There are infinitely many primes.

Proof (by contradiction):

  • Assume for contradiction: only finitely many primes $p_1, p_2, \ldots, p_n$
  • Construct $N = p_1 \times p_2 \times \cdots \times p_n + 1$
  • $N > 1$; every integer $> 1$ has a prime factor; let $p$ be a prime factor of $N$
  • $p$ must be one of $p_1, \ldots, p_n$ (these are all the primes, by assumption)
  • But $p \mid N$ and $p \mid p_1 p_2 \cdots p_n$; so $p \mid (N - p_1 p_2 \cdots p_n) = 1$
  • No prime divides 1; contradiction
  • Therefore there are infinitely many primes $\square$

5.5 Worked Example — No Largest Real Number

Theorem: There is no largest real number.

Proof (by contradiction):

  • Assume $\exists M \in \mathbb{R}: M \geq x$ for all $x \in \mathbb{R}$
  • Consider $M + 1 \in \mathbb{R}$; $M + 1 > M$; this contradicts $M$ being largest
  • Contradiction; therefore no largest real number exists $\square$

Code cell 17

# ══════════════════════════════════════════════════════════════════
# 5.1–5.5  Proof by Contradiction — √2 Irrational + Infinitely Many Primes
# ══════════════════════════════════════════════════════════════════

# --- 5.3: √2 is irrational — trace the contradiction ---
print('5.3  PROOF BY CONTRADICTION: √2 is Irrational')
print('=' * 65)

def trace_sqrt2_contradiction(max_q=100):
    """
    Demonstrate the contradiction in assuming √2 = p/q in lowest terms.
    For each candidate p/q ≈ √2, show that gcd(p,q) ≠ 1 when p²=2q².
    """
    print('If √2 = p/q with gcd(p,q)=1, then p² = 2q²')
    print(f'\n{"p":>6} {"q":>6}{"p²":>8} {"2q²":>8} {"p²=2q²?":>8}{"gcd(p,q)":>8} {"=1?":>4}')
    print('─' * 60)
    
    found = False
    for q in range(1, max_q + 1):
        p_sq = 2 * q * q
        p = int(math.isqrt(p_sq))
        # Check if p² = 2q² exactly
        if p * p == p_sq:
            g = math.gcd(p, q)
            eq = '✓'
            gcd_ok = '✓' if g == 1 else '✗'
            print(f'{p:>6} {q:>6}{p*p:>8} {2*q*q:>8} {eq:>8}{g:>8} {gcd_ok:>4}')
            if g != 1:
                print(f'         └── p={p}=2×{p//2} (even), q={q}=2×{q//2} (even)')
                print(f'             gcd({p},{q})={g} ≠ 1 → CONTRADICTION with lowest terms')
                found = True
    
    if not found:
        print(f'  No integer solution p²=2q² found for q ∈ [1,{max_q}]')
        print(f'  This is expected: √2 is irrational, no such p/q exists!')
    
    return found

trace_sqrt2_contradiction(50)

# Show that √2 cannot be expressed as Fraction
print(f'\nNumerical verification:')
print(f'  √2 = {math.sqrt(2):.15f}')
print(f'  Best rational approximations (continued fraction convergents):')

# Compute convergents of √2 = [1; 2, 2, 2, ...]
p_prev, p_curr = 0, 1
q_prev, q_curr = 1, 0
cf_terms = [1] + [2] * 15  # continued fraction of √2

for i, a in enumerate(cf_terms):
    p_new = a * p_curr + p_prev
    q_new = a * q_curr + q_prev
    p_prev, p_curr = p_curr, p_new
    q_prev, q_curr = q_curr, q_new
    
    if i < 10:
        approx = p_curr / q_curr
        error = abs(approx - math.sqrt(2))
        print(f'    p/q = {p_curr}/{q_curr} = {approx:.12f}, '
              f'error = {error:.2e}, gcd = {math.gcd(p_curr, q_curr)}')

print(f'  → Every fraction is an approximation, never exact')
print(f'  → The proof shows no EXACT rational representation can exist')

# --- 5.4: Infinitely many primes ---
print('\n' + '=' * 65)
print('5.4  PROOF BY CONTRADICTION: Infinitely Many Primes')
print('=' * 65)

def is_prime(n):
    """Check primality by trial division."""
    if n < 2:
        return False
    if n == 2:
        return True
    if n % 2 == 0:
        return False
    for i in range(3, int(math.isqrt(n)) + 1, 2):
        if n % i == 0:
            return False
    return True

def euclid_contradiction_demo(assumed_primes):
    """
    Demonstrate Euclid's proof by contradiction.
    Assume these are ALL primes; construct N = product + 1;
    show N has a prime factor NOT in the list.
    """
    product = 1
    for p in assumed_primes:
        product *= p
    N = product + 1
    
    print(f'  Assumed ALL primes: {assumed_primes}')
    print(f'  N = {"×".join(str(p) for p in assumed_primes)} + 1 = {product} + 1 = {N}')
    
    # Find prime factors of N
    factors = []
    temp = N
    for p in range(2, min(N + 1, 10**6)):
        while temp % p == 0:
            factors.append(p)
            temp //= p
        if temp == 1:
            break
    
    new_primes = [f for f in set(factors) if f not in assumed_primes]
    print(f'  Prime factorisation of N: {" × ".join(str(f) for f in factors)}')
    print(f'  New primes found (not in assumed list): {new_primes}')
    print(f'  CONTRADICTION: our list was not complete! □')
    return new_primes

print('\nEuclid\'s argument — trace the contradiction:\n')

# Start with small assumed prime lists
for primes in [[2], [2, 3], [2, 3, 5], [2, 3, 5, 7], [2, 3, 5, 7, 11, 13]]:
    euclid_contradiction_demo(primes)
    print()

# Count primes up to N to show they keep growing
print('Prime counting function π(N) — primes never stop:')
for N in [10, 100, 1000, 10000, 100000]:
    count = sum(1 for n in range(2, N + 1) if is_prime(n))
    print(f'  π({N:>6}) = {count:>5}  (primes up to {N})')

# --- 5.5: No largest real number ---
print('\n' + '=' * 65)
print('5.5  PROOF BY CONTRADICTION: No Largest Real Number')
print('=' * 65)

print('\nAssume M is the largest real number.')
print('Construct M+1:')
for M in [1, 100, 10**10, 10**100, float('inf')]:  # integers avoid float rounding (1e100 + 1 == 1e100)
    if M == float('inf'):
        print(f'  M = ∞ is not a real number (∞ ∉ ℝ)')
    else:
        print(f'  M = {M:.0e} → M+1 = {M+1:.0e} > M → CONTRADICTION')
print('Therefore no largest real number exists □')

5.6 Non-Constructive Existence via Contradiction

One can prove $\exists x\, P(x)$ by contradiction: assume $\forall x\, \neg P(x)$ and derive a contradiction. This proves existence without constructing the object.

AI example: the existence of optimal parameters $\theta^*$ minimising a loss can be proved via compactness (a continuous loss on a compact set attains its minimum); the proof supplies no algorithm for finding $\theta^*$ (see the sketch below).
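
A small numerical companion (an illustrative sketch, not one of the notebook's cells): a continuous loss on the compact set [0, 1] attains its minimum by the extreme value theorem, but that guarantee names no minimiser; the grid refinement below only ever approximates one.

import numpy as np

loss = lambda t: np.sin(5 * t) + t ** 2    # continuous on the compact set [0, 1]

# Existence of θ* is guaranteed non-constructively; search only approximates it
for n in [10, 100, 1000, 100000]:
    grid = np.linspace(0, 1, n)
    i = int(np.argmin(loss(grid)))
    print(f'n={n:>6}: θ ≈ {grid[i]:.6f}, loss ≈ {loss(grid[i]):.6f}')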

5.7 Contradiction vs Contrapositive

| | Contrapositive | Contradiction |
| --- | --- | --- |
| Goal | Prove $P \implies Q$ | Prove $P$ (or $P \implies Q$) |
| Assume | $\neg Q$ | $\neg P$ |
| Derive | $\neg P$ | Any contradiction $\bot$ |
| Conclusion | $P \implies Q$ (logical equivalence) | $P$ (since $\neg P$ led to $\bot$) |
| Best when | $\neg Q \implies \neg P$ is natural | $\neg P$ has powerful consequences |

5.8 Contradiction in AI Contexts

Cross-entropy loss is strictly positive, and its infimum 0 is never achieved:

  • $L(\mathbf{z}, y) = -z_y + \log \sum_v \exp(z_v) = \log\bigl(1 + \sum_{v \neq y} \exp(z_v - z_y)\bigr) > 0$ for every finite $\mathbf{z}$
  • Setting $z_y = T$ and all other $z_v = 0$ gives $L = \log(1 + (|V|-1)e^{-T}) \to 0$ as $T \to \infty$, so $\inf L = 0$
  • Assume for contradiction $\exists \mathbf{z}^*: L(\mathbf{z}^*, y) = 0$; then $\sum_v \exp(z^*_v) = \exp(z^*_y)$, so $\sum_{v \neq y} \exp(z^*_v) = 0$
  • But $\exp(z^*_v) > 0$ for every $v$; contradiction
  • Therefore the infimum is approached but never achieved $\square$

Code cell 19

# ══════════════════════════════════════════════════════════════════
# 5.6–5.8  Non-Constructive Existence & Contradiction in AI
# ══════════════════════════════════════════════════════════════════

# --- 5.7: Comparing contrapositive vs contradiction ---
print('5.7  CONTRAPOSITIVE vs CONTRADICTION — Side by Side')
print('=' * 65)

print('''
Example: Prove "if n² is even, then n is even"

METHOD 1 — CONTRAPOSITIVE:
  Goal:     Prove P → Q  (n² even → n even)
  Strategy: Prove ¬Q → ¬P  (n odd → n² odd)
  Assume:   n is odd (n = 2k+1)
  Derive:   n² = (2k+1)² = 2(2k²+2k)+1 is odd = ¬P
  Done:     ¬Q → ¬P proved; by equivalence, P → Q  □

METHOD 2 — CONTRADICTION:
  Goal:     Prove P → Q  (n² even → n even)
  Strategy: Assume P ∧ ¬Q; derive ⊥
  Assume:   n² is even AND n is odd
  Derive:   n odd → n² odd (by algebra above)
            But n² is even (hypothesis)
            n² even AND n² odd → CONTRADICTION  ⊥
  Done:     Assumption led to ⊥; therefore P → Q  □

Both work! Contrapositive is CLEANER for implications.
Contradiction is MORE FLEXIBLE (works for any statement, not just implications).
''')

# --- 5.8: Cross-entropy loss is unbounded below ---
print('5.8  CONTRADICTION IN AI: Cross-Entropy Infimum Is Never Achieved')
print('=' * 65)

def cross_entropy(z, y):
    """Cross-entropy loss with numerical stability."""
    z_shifted = z - np.max(z)
    log_sum_exp = np.log(np.sum(np.exp(z_shifted))) + np.max(z)
    return -z[y] + log_sum_exp

n_classes = 10
y = 0  # true class

print(f'\nClaim: L(z,y) > 0 for every finite z; the infimum is 0.')
print(f'Set z_y = T, all other z_v = 0:')
print(f'  L = -T + log(exp(T) + {n_classes - 1}) = log(1 + {n_classes - 1}·exp(-T))')
print(f'\n{"T":>10}{"L(z,y)":>15}')
print('─' * 30)

for T in [1, 5, 10, 20, 30]:
    z = np.zeros(n_classes)
    z[y] = T
    L = cross_entropy(z, y)
    print(f'{T:>10}{L:>15.6e}')

print(f'\nAs T → ∞: L = log(1 + {n_classes-1}·exp(-T)) → 0 from above')
print(f'The INFIMUM is 0, and it is NOT achieved (no finite z gives L = 0)')
print(f'')
print(f'Proof by contradiction that the infimum is not achieved:')
print(f'  Assume ∃z*: L(z*,y) = 0')
print(f'  Then -z*_y + log Σ exp(z*_v) = 0')
print(f'  → log Σ exp(z*_v) = z*_y')
print(f'  → Σ exp(z*_v) = exp(z*_y)')
print(f'  → exp(z*_y) + Σ_{{v≠y}} exp(z*_v) = exp(z*_y)')
print(f'  → Σ_{{v≠y}} exp(z*_v) = 0')
print(f'  But exp(z*_v) > 0 for all v → Σ > 0 → CONTRADICTION ⊥')
print(f'  Therefore the infimum 0 is never achieved □')

# Verify numerically (T kept small enough that L does not underflow to exactly 0)
print(f'\nNumerical verification — L approaches 0 but never reaches it:')
for T in [1, 5, 10, 20, 30]:
    z = np.full(n_classes, -float(T))  # push all non-target logits to -T
    z[y] = 0.0
    L = cross_entropy(z, y)
    print(f'  z_y=0, others=-{T}: L = {L:.2e} > 0 {"✓" if L > 0 else "✗"}')

6. Proof by Cases (Case Analysis)

6.1 The Strategy

  • Partition the set of possibilities into exhaustive, mutually exclusive cases
  • Prove the conclusion holds in each case separately
  • Since cases are exhaustive, conclusion holds universally
  • When to use: piecewise functions, absolute values, sign analysis, parity arguments

6.2 Template

Proof (by cases):
  Case 1: [condition C₁]  →  [prove conclusion]
  Case 2: [condition C₂]  →  [prove conclusion]
  ...
  Since C₁, C₂, … are exhaustive, conclusion holds in all cases. □
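
A quick instance of the template (a toy sketch, not from the notebook): every perfect square is congruent to 0 or 1 mod 4, proved by two exhaustive parity cases and checked below.

def square_mod4_case(n: int) -> str:
    """Proof by cases: n even → n² ≡ 0 (mod 4); n odd → n² ≡ 1 (mod 4)."""
    if n % 2 == 0:                 # Case 1: n = 2k → n² = 4k²
        case, expected = 'even', 0
    else:                          # Case 2: n = 2k+1 → n² = 4(k²+k) + 1
        case, expected = 'odd', 1
    assert n * n % 4 == expected   # the conclusion holds in this case
    return f'n={n:>3} ({case}): n² mod 4 = {expected}'

for n in [-7, -2, 0, 3, 10, 11]:
    print(square_mod4_case(n))
# The two cases are exhaustive, so n² mod 4 ∈ {0, 1} for every integer n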

6.3 Worked Example — |xy| = |x||y|

Theorem: For all x,yRx, y \in \mathbb{R}: xy=xy|xy| = |x||y|.

Proof (by cases on signs):

  • Case 1: $x \geq 0, y \geq 0$: $xy \geq 0$; $|xy| = xy = |x||y|$
  • Case 2: $x \geq 0, y < 0$: $xy \leq 0$; $|xy| = -xy$; $|x||y| = x(-y) = -xy$
  • Case 3: $x < 0, y \geq 0$: $xy \leq 0$; $|xy| = -xy$; $|x||y| = (-x)y = -xy$
  • Case 4: $x < 0, y < 0$: $xy > 0$; $|xy| = xy$; $|x||y| = (-x)(-y) = xy$

All four cases are exhaustive (they cover all sign combinations); the result holds $\square$

6.4 Worked Example — ReLU Subgradient

Theorem: The subgradient of $\text{ReLU}(x) = \max(0, x)$ at $x = 0$ can be any value in $[0, 1]$.

A subgradient $g$ satisfies: $\text{ReLU}(y) \geq \text{ReLU}(0) + g(y - 0)$ for all $y$.

Proof (by cases):

  • Case 1 ($y > 0$): $\text{ReLU}(y) = y$; need $y \geq gy$ → need $g \leq 1$
  • Case 2 ($y < 0$): $\text{ReLU}(y) = 0$; need $0 \geq gy$; since $y < 0$, need $g \geq 0$
  • Case 3 ($y = 0$): $0 \geq 0$ trivially ✓

Combining: $g \in [0, 1]$ $\square$

Code cell 21

# ══════════════════════════════════════════════════════════════════
# 6.1–6.4  Proof by Cases — |xy|=|x||y|, ReLU Subgradient, Parity
# ══════════════════════════════════════════════════════════════════

# --- 6.3: |xy| = |x||y| by case analysis on signs ---
print('6.3  PROOF BY CASES: |xy| = |x|·|y|')
print('=' * 65)

def case_classify(x, y):
    """Determine which sign case (x,y) falls into."""
    if x >= 0 and y >= 0:
        return 'Case 1: x≥0, y≥0'
    elif x >= 0 and y < 0:
        return 'Case 2: x≥0, y<0'
    elif x < 0 and y >= 0:
        return 'Case 3: x<0, y≥0'
    else:
        return 'Case 4: x<0, y<0'

print(f'{"x":>6} {"y":>6}{"Case":>20}{"|xy|":>8} {"=":>2} {"|x|·|y|":>8}{"Equal?":>6}')
print('─' * 65)

test_vals = [3.0, -3.0, 0.0, 1.5, -2.7, 100.0, -0.01]
all_pass = True
for x in test_vals:
    for y in test_vals[:4]:
        lhs = abs(x * y)
        rhs = abs(x) * abs(y)
        case = case_classify(x, y)
        eq = np.isclose(lhs, rhs)
        all_pass = all_pass and eq
        if abs(x) <= 10 and abs(y) <= 10:
            print(f'{x:>6.2f} {y:>6.2f}{case:>20}{lhs:>8.4f} {"=":>2} {rhs:>8.4f}{"✓" if eq else "✗":>6}')

print(f'\n✓ All {len(test_vals) * 4} test cases verified: |xy| = |x|·|y|')
print(f'  4 sign cases are EXHAUSTIVE (cover all real number pairs)')

# --- 6.4: ReLU subgradient at x=0 ---
print('\n' + '=' * 65)
print('6.4  PROOF BY CASES: ReLU Subgradient at x=0')
print('=' * 65)

def relu_val(x):
    return max(0.0, x)

def is_subgradient(g, x0=0.0):
    """Check if g is a subgradient of ReLU at x0.
    Requires: ReLU(y) ≥ ReLU(x0) + g(y - x0) for all y.
    """
    test_ys = np.linspace(-5, 5, 1001)
    for y in test_ys:
        if relu_val(y) < relu_val(x0) + g * (y - x0) - 1e-10:
            return False, y
    return True, None

print(f'\nSubgradient g of ReLU at x=0: ReLU(y) ≥ 0 + g·y for all y')
print(f'\n{"g":>6}{"Valid?":>6}{"Reason":>40}')
print('─' * 60)

for g in [-0.5, -0.1, 0.0, 0.25, 0.5, 0.75, 1.0, 1.1, 1.5]:
    valid, counterexample = is_subgradient(g)
    valid_label = '✓' if valid else '✗'
    if valid:
        reason = '✓ ReLU(y) ≥ g·y for all tested y'
    else:
        reason = f'✗ Fails at y={counterexample:.2f}: ReLU({counterexample:.2f})={relu_val(counterexample):.2f} < {g*counterexample:.2f}'
    print(f'{g:>6.2f}{valid_label:>6}{reason:>40}')

print(f'\nSubgradient set at x=0: [0, 1]')
print(f'  Case 1 (y>0): need g ≤ 1')
print(f'  Case 2 (y<0): need g ≥ 0')
print(f'  Case 3 (y=0): trivially satisfied')
print(f'  Intersection: g ∈ [0, 1] □')

# --- 6.5: Parity argument n(n+1) is even ---
print('\n' + '=' * 65)
print('6.5  PROOF BY CASES: n(n+1) is Even for All Integers')
print('=' * 65)

print(f'\n{"n":>6}{"Parity":>6}{"n(n+1)":>8} {"Even?":>6} │ Proof Trace')
print('─' * 65)

for n in range(-6, 7):
    product = n * (n + 1)
    parity = 'even' if n % 2 == 0 else 'odd'
    even = product % 2 == 0
    
    if n % 2 == 0:
        k = n // 2
        trace = f'n=2({k}), n(n+1) = 2({k})·({n+1}) = 2·{k*(n+1)}'
    else:
        m = (n + 1) // 2
        trace = f'n+1=2({m}), n(n+1) = {n}·2({m}) = 2·{n*m}'
    
    print(f'{n:>6}{parity:>6}{product:>8} {"✓" if even else "✗":>6}{trace}')

print(f'\n✓ In every case, one of n or n+1 is even → product is even')
print(f'  Case 1 (n even): n = 2k → n(n+1) = 2k(n+1) = 2×[k(n+1)]')
print(f'  Case 2 (n odd):  n+1 = 2m → n(n+1) = n·2m = 2×[nm]')
print(f'  Exhaustive (every integer is even or odd) □')

# --- 6.6-6.7: Case analysis in AI ---
print('\n' + '=' * 65)
print('6.6–6.7  CASE ANALYSIS IN AI: Piecewise Linear Networks')
print('=' * 65)

# A ReLU network divides input space into linear regions
# In each region, the network is affine — prove property for each region

W1 = np.array([[2.0, -1.0], [-1.0, 3.0]])
b1 = np.array([0.5, -1.0])
W2 = np.array([[1.0, -2.0]])
b2 = np.array([0.0])

def relu_network(x):
    """2D input → 1D output ReLU network."""
    h = np.maximum(0, W1 @ x + b1)
    return (W2 @ h + b2)[0]

# Identify activation patterns (which ReLUs are active)
print(f'\nReLU network: x ∈ ℝ² → h = ReLU(W₁x + b₁) → y = W₂h + b₂')
print(f'Each activation pattern defines a LINEAR REGION')
print(f'\nSampling input space to identify linear regions:')

patterns = {}
for _ in range(10000):
    x = np.random.randn(2) * 2
    h_pre = W1 @ x + b1
    pattern = tuple(int(h > 0) for h in h_pre)
    if pattern not in patterns:
        patterns[pattern] = {'count': 0, 'example': x.copy()}
    patterns[pattern]['count'] += 1

print(f'  Found {len(patterns)} distinct activation patterns (linear regions):')
for pattern, info in sorted(patterns.items()):
    x = info['example']
    h_pre = W1 @ x + b1
    h = np.maximum(0, h_pre)
    y = (W2 @ h + b2)[0]
    
    # In this region, network is affine: y = W2 @ diag(pattern) @ W1 @ x + ...
    D = np.diag(pattern)
    W_eff = W2 @ D @ W1
    b_eff = (W2 @ D @ b1 + b2)[0]
    
    print(f'    Pattern {pattern}: {info["count"]:>4} samples | '
          f'Effective affine: y = [{W_eff[0,0]:.1f}, {W_eff[0,1]:.1f}]·x + {b_eff:.1f}')

print(f'\n✓ In each region, network is AFFINE (linear + bias)')
print(f'✓ Properties provable per-region then combined via case analysis')
print(f'✓ This is exactly proof by cases applied to neural networks!')

7. Mathematical Induction

7.1 The Strategy

To prove $\forall n \in \mathbb{N}\; P(n)$:

  1. Base case: prove $P(n_0)$ (smallest relevant $n$)
  2. Inductive step: assume $P(k)$ (inductive hypothesis); prove $P(k+1)$
  3. Conclusion: by the principle of mathematical induction, $P(n)$ holds for all $n \geq n_0$

Logical foundation: the well-ordering principle — every non-empty subset of $\mathbb{N}$ has a minimum. Induction is equivalent to well-ordering.

7.2 Why Induction Works

Proof that induction is valid (by well-ordering):

  • Suppose the base case and inductive step both hold; suppose for contradiction $\exists n: \neg P(n)$
  • Let $S = \{n \in \mathbb{N} \mid \neg P(n)\}$; $S$ is non-empty; by well-ordering, $S$ has a minimum $m$
  • $m \neq n_0$ (base case holds); so $m > n_0$; then $m-1 < m$; $m-1 \notin S$; $P(m-1)$ holds
  • By the inductive step: $P(m-1) \implies P(m)$; so $P(m)$ holds; this contradicts $m \in S$
  • Contradiction; $S$ must be empty; $P(n)$ holds for all $n$ $\square$

7.3 Template

Proof (by mathematical induction):
  Base case (n = n₀):
    [Verify P(n₀) directly]

  Inductive step:
    Let k ≥ n₀; assume P(k) (inductive hypothesis).
    [Derive P(k+1) using P(k)]

  By the principle of mathematical induction, P(n) holds for all n ≥ n₀. □
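The template translates into a small harness (a sketch with a hypothetical helper name; checking a finite range is evidence for the chain, not a proof):

def check_induction(P, n0, n_max):
    """Check the base case P(n0) and the links P(k) → P(k+1) for k = n0..n_max-1."""
    assert P(n0), f'base case P({n0}) fails'
    for k in range(n0, n_max):
        if P(k) and not P(k + 1):
            return False, k + 1  # the chain breaks here
    return True, None

# Example: P(n) = "Σᵢ₌₁ⁿ i equals n(n+1)/2"
ok, broken_at = check_induction(lambda n: sum(range(1, n + 1)) == n * (n + 1) // 2, 1, 50)
print('Induction spot-check passed:', ok)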

7.4 Worked Example — Sum Formula

Theorem: $\forall n \geq 1: \sum_{i=1}^{n} i = \frac{n(n+1)}{2}$

Proof (by induction):

  • Base case ($n=1$): $\sum_{i=1}^{1} i = 1 = \frac{1 \times 2}{2}$ ✓
  • Inductive step: assume $\sum_{i=1}^{k} i = \frac{k(k+1)}{2}$; prove for $k+1$:

$$\sum_{i=1}^{k+1} i = \left(\sum_{i=1}^{k} i\right) + (k+1) = \frac{k(k+1)}{2} + (k+1) = \frac{(k+1)(k+2)}{2}$$

By induction, the result holds for all $n \geq 1$ $\square$

Code cell 23

# ══════════════════════════════════════════════════════════════════
# 7.1–7.4  Mathematical Induction — Sum Formula + Chain Visualiser
# ══════════════════════════════════════════════════════════════════

def induction_proof_sum_formula(n_max=15):
    """
    Prove by induction: Σᵢ₌₁ⁿ i = n(n+1)/2
    
    Base case: n=1 → 1 = 1×2/2 = 1 ✓
    Inductive step: assume Σᵢ₌₁ᵏ i = k(k+1)/2
                    prove Σᵢ₌₁ᵏ⁺¹ i = (k+1)(k+2)/2
    """
    print('7.4  INDUCTION PROOF: Σᵢ₌₁ⁿ i = n(n+1)/2')
    print('=' * 70)
    
    # Base case
    n = 1
    lhs = sum(range(1, n + 1))
    rhs = n * (n + 1) // 2
    print(f'\nBASE CASE (n=1):')
    print(f'  LHS: Σᵢ₌₁¹ i = {lhs}')
    print(f'  RHS: 1×2/2 = {rhs}')
    print(f'  Equal? {"✓ BASE CASE HOLDS" if lhs == rhs else "✗ BASE CASE FAILS"}')
    
    # Inductive step — trace for each k
    print(f'\nINDUCTIVE STEP: Assume P(k), prove P(k+1)')
    print(f'{"k":>4}{"P(k) holds?":>12}{"Σᵢ₌₁ᵏ⁺¹":>10} = {"k(k+1)/2":>10} + {"(k+1)":>6} = {"(k+1)(k+2)/2":>14}{"P(k+1)?":>8}')
    print('─' * 85)
    
    all_ok = True
    for k in range(1, n_max + 1):
        # Verify P(k) — inductive hypothesis
        sum_k = sum(range(1, k + 1))
        formula_k = k * (k + 1) // 2
        pk_holds = sum_k == formula_k
        
        # Prove P(k+1) using P(k)
        sum_k1 = sum_k + (k + 1)  # Using inductive hypothesis
        formula_k1 = (k + 1) * (k + 2) // 2
        pk1_holds = sum_k1 == formula_k1
        
        all_ok = all_ok and pk1_holds
        pk_label = '✓' if pk_holds else '✗'
        pk1_label = '✓' if pk1_holds else '✗'
        print(f'{k:>4}{pk_label:>12}{sum_k1:>10} = {formula_k:>10} + {k+1:>6} = {formula_k1:>14}{pk1_label:>8}')
    
    final_label = '✓ Induction complete' if all_ok else '✗ Induction failed'
    print(f'\n{final_label}: '
          f'P(n) holds for all n ≥ 1 (verified up to n={n_max+1})')
    return all_ok

induction_proof_sum_formula(15)

# --- Induction chain visualisation ---
print('\n' + '=' * 70)
print('7.2  WHY INDUCTION WORKS — The Domino Chain')
print('=' * 70)

print('''
The induction principle creates an infinite chain of implications:

  P(1) ──────────→ P(2) ──────────→ P(3) ──────────→ P(4) ──────────→ ...
  ↑ base case       ↑ step(1→2)      ↑ step(2→3)      ↑ step(3→4)

  Base case grounds the chain:  P(1) is TRUE
  Step propagates truth:        P(k) → P(k+1) for all k
  Together:                     P(n) holds for ALL n ≥ 1

Without base case:  ??? → P(2) → P(3) → ...  (no anchor; chain floats)
Without step:       P(1) ... P(2)?             (no propagation; stuck at base)
''')

# Demonstrate: induction WITHOUT base case can "prove" nonsense
print('⚠ DANGER: Induction without base case "proves" false statements')
print('─' * 70)
print('"Theorem": All positive integers are > 100')
print('  "Inductive step": Assume k > 100. Then k+1 > k > 100. ✓')
print('  But P(1): 1 > 100? NO! Base case fails.')
print('  The inductive step is vacuously true (P(k) is always false for k ≤ 100)')
print('  ✗ Without a valid base case, the "proof" is invalid')

# --- Additional induction examples ---
print('\n' + '=' * 70)
print('ADDITIONAL INDUCTION EXAMPLES')
print('=' * 70)

# Geometric series
print('\n7.5  Geometric Series: Σᵢ₌₀ⁿ rⁱ = (1-rⁿ⁺¹)/(1-r) for r≠1')
print('─' * 70)

r = 0.5
print(f'r = {r}')
print(f'{"n":>4}{"Σᵢ₌₀ⁿ rⁱ (computed)":>22} {"(1-rⁿ⁺¹)/(1-r)":>22}{"Match?":>6}')
print('─' * 60)

for n in range(8):
    lhs = sum(r**i for i in range(n + 1))
    rhs = (1 - r**(n + 1)) / (1 - r)
    match = np.isclose(lhs, rhs)
    print(f'{n:>4}{lhs:>22.10f} {rhs:>22.10f}{"✓" if match else "✗":>6}')

# Power of 2 bound
print(f'\n7.6  Power of 2 Bound: 2ⁿ > n for all n ≥ 1')
print('─' * 70)
print(f'{"n":>4}{"2ⁿ":>12} {"> n?":>6} │ Inductive step: 2ⁿ⁺¹ = 2·2ⁿ > 2n ≥ n+1 (since n≥1)')
print('─' * 60)

for n in range(1, 16):
    power = 2**n
    holds = power > n
    step_detail = f'2·{2**(n-1)} = {power} > {2*(n-1)} ≥ {n}' if n > 1 else 'base case'
    print(f'{n:>4}{power:>12} {"✓" if holds else "✗":>6}{step_detail}')

print(f'\n✓ All verified: base case (2¹=2>1) + step (2ⁿ⁺¹ = 2·2ⁿ > 2n ≥ n+1 for n≥1)')

7.7 Common Mistakes in Induction

| Mistake | Why It's Wrong | Fix |
|---|---|---|
| Forgetting the base case | The inductive step alone proves nothing; it must anchor the chain | Always verify $P(n_0)$ explicitly |
| Circular reasoning | Using $P(k+1)$ in the proof of $P(k+1)$ | The inductive hypothesis is $P(k)$, not $P(k+1)$ |
| Wrong base case | Proving $P(0)$ when the theorem requires $P(1)$ | Match the base case to the theorem's domain |
| Insufficient base cases | Some proofs need $P(0)$ AND $P(1)$ | Verify the step works from the smallest values |
| Not using the hypothesis | If the proof of $P(k+1)$ never uses $P(k)$, it's not induction | Rewrite as a direct proof or verify you need induction |

The "all horses are the same colour" fallacy is the classic example: the inductive step fails at $n=1 \to n=2$ because the overlap argument requires $n \geq 2$, but the base case only establishes $n=1$.

Code cell 25

# ══════════════════════════════════════════════════════════════════
# 7.7  Common Mistakes in Induction — The Horse Fallacy
# ══════════════════════════════════════════════════════════════════

print('7.7  THE "ALL HORSES ARE THE SAME COLOUR" FALLACY')
print('=' * 70)

print('''
"Theorem" (FALSE): All horses are the same colour.

"Proof" by induction on n (number of horses):
  Base case (n=1): A single horse is trivially the same colour as itself. ✓
  
  Inductive step: Assume any group of k horses are all the same colour.
    Consider a group of k+1 horses: {h₁, h₂, ..., h_k, h_{k+1}}
    - First k horses: {h₁, ..., h_k} → all same colour (by IH)
    - Last k horses:  {h₂, ..., h_{k+1}} → all same colour (by IH)
    - Overlap: {h₂, ..., h_k} are in BOTH groups
    - Therefore h₁ and h_{k+1} are the same colour as the overlap
    - All k+1 horses are the same colour ✓
  
  By induction, all horses are the same colour □

WHERE IS THE BUG?
''')

# Trace the failure at n=1 → n=2
print('TRACING THE BUG:')
print('─' * 70)

for n in range(1, 6):
    k = n
    first_k = list(range(1, k + 1))
    last_k = list(range(2, k + 2))
    overlap = [x for x in first_k if x in last_k]
    
    print(f'  n={k} → n+1={k+1}:')
    print(f'    First {k}: {first_k}')
    print(f'    Last  {k}: {last_k}')
    print(f'    Overlap:  {overlap}')
    
    if len(overlap) == 0:
        print(f'    ✗ OVERLAP IS EMPTY! Cannot conclude h₁ and h_{k+1} share colour')
        print(f'    ✗ THE INDUCTIVE STEP FAILS AT n=1 → n=2')
    else:
        print(f'    ✓ Overlap non-empty — step would work IF base case covered n≥2')

print(f'\n✗ The step requires overlap of size ≥ 1, which needs k ≥ 2')
print(f'  But base case only proves P(1). Gap between P(1) and P(2)!')
print(f'  Lesson: ALWAYS check that the inductive step works from the base case')

# Another common mistake: not using the inductive hypothesis
print('\n' + '=' * 70)
print('MISTAKE: "Induction" proof that doesn\'t use the hypothesis')
print('=' * 70)
print('''
"Prove": n + 1 > n for all n ≥ 1

"Proof by induction":
  Base case: 1 + 1 = 2 > 1 ✓
  Inductive step: (k+1) + 1 = k + 2 > k + 1  ✓

This is CORRECT but it's NOT really induction — the step never uses P(k).
The direct proof works: for any n, n+1 = n + 1 > n since 1 > 0.
If the inductive hypothesis is never used, it's a direct proof in disguise.
''')

# Verify the sum of cubes formula by induction
print('=' * 70)
print('BONUS: Σᵢ₌₁ⁿ i³ = [n(n+1)/2]² (induction)')
print('=' * 70)
print(f'{"n":>4}{"Σi³ (raw sum)":>14} {"[n(n+1)/2]²":>14}{"P(k)→P(k+1)":>30}')
print('─' * 70)

for n in range(1, 13):
    lhs = sum(i**3 for i in range(1, n + 1))
    rhs = (n * (n + 1) // 2) ** 2
    
    # Show the inductive step algebra
    if n >= 2:
        prev_rhs = ((n-1) * n // 2) ** 2
        step = f'{prev_rhs} + {n**3} = {prev_rhs + n**3}'
    else:
        step = 'base case'
    
    print(f'{n:>4}{lhs:>14} {rhs:>14}{step:>30} {"✓" if lhs == rhs else "✗"}')

8. Strong Induction

8.1 The Strategy

  • Ordinary induction: assume $P(k)$; prove $P(k+1)$
  • Strong induction: assume $P(n_0), P(n_0+1), \ldots, P(k)$ all hold; prove $P(k+1)$
  • Equivalent in power but more convenient when $P(k+1)$ depends on multiple previous values
  • Same well-ordering justification as ordinary induction

8.2 Template

Proof (by strong induction):
  Base case(s): verify P(n₀), ..., P(n₁) directly.
  Inductive step: Let k ≥ n₁; assume P(j) for all n₀ ≤ j ≤ k.
    [Derive P(k+1) using any of P(n₀), ..., P(k)]
  By strong induction, P(n) holds for all n ≥ n₀. □
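The same harness idea works here, except the step may consult the entire history P(n₀), …, P(k) (again a finite spot-check with hypothetical helper names, not a proof):

def check_strong_induction(P, n0, n_max):
    """Check P(n0), then that P(k+1) holds whenever P(n0..k) all hold."""
    assert P(n0), f'base case P({n0}) fails'
    history_ok = True
    for k in range(n0, n_max):
        history_ok = history_ok and P(k)
        if history_ok and not P(k + 1):
            return False, k + 1
    return True, None

# Example: the Fibonacci bound of 8.4, Fₙ < 2ⁿ
fib_cache = {1: 1, 2: 1}
def fib(n):
    if n not in fib_cache:
        fib_cache[n] = fib(n - 1) + fib(n - 2)
    return fib_cache[n]

ok, fail_at = check_strong_induction(lambda n: fib(n) < 2**n, 1, 30)
print('Strong-induction spot-check (Fₙ < 2ⁿ, n = 1..30):', ok)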

8.3 Worked Example — Prime Factorisation

Theorem: Every integer $n \geq 2$ is either prime or a product of primes.

Proof (by strong induction):

  • Base case ($n=2$): 2 is prime ✓
  • Inductive step: let $k \geq 2$; assume every integer $j$ with $2 \leq j \leq k$ is prime or a product of primes
    • Case 1: $k+1$ is prime → done ✓
    • Case 2: $k+1$ is not prime → $\exists a, b$ with $1 < a, b < k+1$ and $ab = k+1$
    • Since $2 \leq a \leq k$ and $2 \leq b \leq k$, by the strong IH: $a$ and $b$ are primes or products of primes
    • Therefore $ab = k+1$ is a product of primes ✓
  • By strong induction, the result holds for all $n \geq 2$ $\square$

8.4 Worked Example — Fibonacci Bound

Fibonacci: $F_1 = 1, F_2 = 1, F_n = F_{n-1} + F_{n-2}$ for $n \geq 3$.

Theorem: $F_n < 2^n$ for all $n \geq 1$.

Proof (by strong induction):

  • Base cases: $F_1 = 1 < 2$ ✓; $F_2 = 1 < 4$ ✓
  • Inductive step: assume $F_j < 2^j$ for all $j \leq k$; prove $F_{k+1} < 2^{k+1}$:

$$F_{k+1} = F_k + F_{k-1} < 2^k + 2^{k-1} < 2^k + 2^k = 2^{k+1}$$

By strong induction, $F_n < 2^n$ for all $n \geq 1$ $\square$

8.5 Strong Induction in AI Analysis

  • Proving properties of recursive algorithms via strong induction on input size
  • BPE tokenisation: tokens increase in length at each merge → induction on merge count
  • Transformer layer analysis: a property at layer $L$ depends on layers $1$ through $L-1$ → strong induction on depth
  • Binary tree depth: depth ≤ $n-1$ for $n$ nodes → induction using both subtrees (spot-checked below)
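A quick numeric spot-check of the last bullet (a sketch; nested tuples stand in for nodes, since the BinaryTree class only appears in Section 9):

import random
_tree_rng = random.Random(0)  # local RNG; leaves the notebook's global seed alone

def random_tree(n):
    """Random binary tree with n nodes, encoded as nested (left, right) tuples."""
    if n == 0:
        return None
    n_left = _tree_rng.randint(0, n - 1)  # size of the left subtree
    return (random_tree(n_left), random_tree(n - 1 - n_left))

def tree_depth(t):  # depth in edges; a single node has depth 0
    return -1 if t is None else 1 + max(tree_depth(t[0]), tree_depth(t[1]))

def tree_nodes(t):
    return 0 if t is None else 1 + tree_nodes(t[0]) + tree_nodes(t[1])

for n in [1, 2, 5, 10, 25, 100]:
    t = random_tree(n)
    assert tree_depth(t) <= tree_nodes(t) - 1
print('✓ depth ≤ n-1 on random trees (strong induction over both subtrees)')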

Code cell 27

# ══════════════════════════════════════════════════════════════════
# 8.1–8.5  Strong Induction — Prime Factorisation + Fibonacci Bound
# ══════════════════════════════════════════════════════════════════

# --- 8.3: Prime factorisation by strong induction ---
print('8.3  STRONG INDUCTION: Every n ≥ 2 is Prime or Product of Primes')
print('=' * 70)

def prime_factorise(n):
    """Factorise n into primes (the constructive part of the proof)."""
    factors = []
    d = 2
    temp = n
    while d * d <= temp:
        while temp % d == 0:
            factors.append(d)
            temp //= d
        d += 1
    if temp > 1:
        factors.append(temp)
    return factors

print(f'{"n":>4}{"Prime?":>6}{"Factorisation":>30}{"Strong IH used?":>20}')
print('─' * 70)

for n in range(2, 26):
    factors = prime_factorise(n)
    is_p = len(factors) == 1
    factors_str = ' × '.join(str(f) for f in factors)
    
    if is_p:
        ih_used = 'Case 1: prime (done)'
    else:
        # In Case 2, we factor n = a × b, and use IH on a and b
        a = factors[0]
        b = n // a
        ih_used = f'Case 2: P({a})∧P({b})→P({n})'
    
    prime_label = '✓' if is_p else '✗'
    print(f'{n:>4}{prime_label:>6}{factors_str:>30}{ih_used:>20}')

print(f'\n✓ Every number factored into primes')
print(f'  Key: strong IH lets us use P(a) and P(b) for ANY a,b < n')
print(f'  Ordinary induction only gives P(n-1) — not enough here!')

# Show why ordinary induction fails
print('\n  Why ordinary induction is insufficient:')
print('    For n=12: 12 = 3 × 4')
print('    Need P(3) and P(4), but ordinary IH only gives P(11)')
print('    P(11) tells us 11 is prime — useless for factoring 12!')
print('    Strong IH gives P(2), P(3), ..., P(11) — we can use P(3) and P(4)')

# --- 8.4: Fibonacci bound by strong induction ---
print('\n' + '=' * 70)
print('8.4  STRONG INDUCTION: Fₙ < 2ⁿ for all n ≥ 1')
print('=' * 70)

# Build Fibonacci sequence from scratch
fib = {1: 1, 2: 1}
for n in range(3, 31):
    fib[n] = fib[n - 1] + fib[n - 2]

print(f'\n{"n":>4}{"Fₙ":>12} {"2ⁿ":>12} {"Fₙ < 2ⁿ?":>10}{"Proof trace":>40}')
print('─' * 85)

for n in range(1, 21):
    fn = fib[n]
    bound = 2**n
    holds = fn < bound
    
    if n <= 2:
        trace = f'Base case: F_{n}={fn} < {bound}'
    else:
        trace = f'F_{n-1}+F_{n-2} = {fib[n-1]}+{fib[n-2]} < {2**(n-1)}+{2**(n-2)} < {bound}'
    
    print(f'{n:>4}{fn:>12} {bound:>12} {"✓" if holds else "✗":>10}{trace:>40}')

print(f'\n✓ Verified for n=1..20: Fₙ < 2ⁿ')
print('  Strong IH critical: Fₙ = Fₙ₋₁ + Fₙ₋₂ uses TWO previous values')

# --- 8.5: Strong induction in AI — Transformer layer properties ---
print('\n' + '=' * 70)
print('8.5  STRONG INDUCTION IN AI: Layer-wise Bound Propagation')
print('=' * 70)

# Demonstrate: if each layer amplifies by at most L, output bounded by L^depth
n_layers = 8
d = 4
x = np.random.randn(d)
x = x / np.linalg.norm(x)  # unit vector input

print(f'\nTransformer-like network: {n_layers} layers, each with Lipschitz constant ≤ L')
print(f'Strong induction: ‖output at layer k‖ ≤ L^k · ‖input‖')
print(f'\nInput ‖x‖ = {np.linalg.norm(x):.4f}')

L_max = 1.5  # max Lipschitz constant per layer
print(f'{"Layer":>6}{"‖h_k‖":>10} {"L^k·‖x‖":>12} {"Bounded?":>8} │ Note')
print('─' * 60)

h = x.copy()
for k in range(1, n_layers + 1):
    # Random linear layer (scaled to have operator norm ≈ L_max)
    W = np.random.randn(d, d)
    W = W / np.linalg.norm(W, ord=2) * (L_max * 0.9)  # ensure ≤ L_max
    h = np.maximum(0, W @ h)  # ReLU(Wx)
    
    actual_norm = np.linalg.norm(h)
    bound = L_max**k * np.linalg.norm(x)
    bounded = actual_norm <= bound + 1e-10
    
    note = 'Base' if k == 1 else f'Uses P(1)..P({k-1})'
    print(f'{k:>6}{actual_norm:>10.4f} {bound:>12.4f} {"✓" if bounded else "✗":>8}{note}')

print(f'\n✓ At each layer, bound holds because we can use ALL previous layer bounds')
print(f'  This is the essence of strong induction in deep network analysis')

9. Structural Induction

9.1 The Strategy

Induction over recursively defined structures: trees, lists, formulas, parse trees.

  • Base case: prove property for atomic/base structures
  • Inductive step: assume property holds for all sub-structures; prove it for the composite
  • Justified because every structure is built from smaller structures in finitely many steps

9.2 Common Structures for Structural Induction

| Structure | Base Case | Recursive Step |
|---|---|---|
| Lists | Empty list `[]` | `cons(head, tail)` where `tail` is smaller |
| Binary trees | Leaf | `node(left, value, right)` |
| Propositional formulas | Atomic proposition $p$ | $\neg\varphi$, $\varphi \wedge \psi$, $\varphi \vee \psi$ |
| Natural numbers | $0$ | $\text{succ}(n)$ — ordinary induction! |
| Token sequences | Empty | $\text{prepend}(\text{token}, \text{sequence})$ |
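A minimal sketch for the first row of the table (hypothetical cons/append helpers): structural induction on xs proves len(append(xs, ys)) = len(xs) + len(ys), with the empty list as base case and one cons peeled per step.

def cons(head, tail):
    return ('cons', head, tail)

def list_len(xs):
    return 0 if xs == () else 1 + list_len(xs[2])

def append(xs, ys):
    # Base: append((), ys) = ys.  Step: append(cons(h, t), ys) = cons(h, append(t, ys))
    return ys if xs == () else cons(xs[1], append(xs[2], ys))

xs = cons(1, cons(2, cons(3, ())))
ys = cons(4, cons(5, ()))
assert list_len(append(xs, ys)) == list_len(xs) + list_len(ys)
print('len(append(xs, ys)) = len(xs) + len(ys) ✓ (structural induction on xs)')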

9.3 Worked Example — Full Binary Tree: Leaves = (Nodes+1)/2

Theorem: In a full binary tree $T$ (every internal node has exactly 2 children): $L(T) = \frac{N(T) + 1}{2}$

Proof (by structural induction):

  • Base ($T$ = single leaf): $N = 1$, $L = 1$; $(1+1)/2 = 1$ ✓
  • Step: $T = \text{node}(T_\ell, T_r)$
    • By IH: $L(T_\ell) = (N(T_\ell)+1)/2$ and $L(T_r) = (N(T_r)+1)/2$
    • $N(T) = N(T_\ell) + N(T_r) + 1$
    • $L(T) = L(T_\ell) + L(T_r) = \frac{N(T_\ell)+1}{2} + \frac{N(T_r)+1}{2} = \frac{N(T_\ell)+N(T_r)+2}{2} = \frac{N(T)+1}{2}$

By structural induction, the result holds for all full binary trees $\square$

9.4 Worked Example — Formula Length

Define the length of a propositional formula:

  • $\text{len}(p) = 1$ for atomic $p$
  • $\text{len}(\neg\varphi) = \text{len}(\varphi) + 1$
  • $\text{len}(\varphi \wedge \psi) = \text{len}(\varphi) + \text{len}(\psi) + 1$

Theorem: Every formula has positive length.

Proof (by structural induction):

  • Base: $\text{len}(p) = 1 > 0$ ✓
  • Step $\neg\varphi$: by IH $\text{len}(\varphi) > 0$; $\text{len}(\neg\varphi) = \text{len}(\varphi)+1 > 0$ ✓
  • Step $\varphi \wedge \psi$: by IH both $> 0$; sum $+ 1 > 0$ ✓

By structural induction, all formulas have positive length $\square$

9.5 Structural Induction in AI

  • Proving properties of parse trees in language models
  • Proving recursive tokenisation correctness: base = single character; step = apply merge
  • Proving recursive neural network computations are well-defined by induction on tree depth
  • Chain of thought: reasoning traces have recursive structure; proofs about CoT correctness use structural induction on derivation trees

Code cell 29

# ══════════════════════════════════════════════════════════════════
# 9.1–9.5  Structural Induction — Trees, Formulas, Token Sequences
# ══════════════════════════════════════════════════════════════════

# --- 9.3: Full binary tree — leaves = (nodes + 1) / 2 ---
print('9.3  STRUCTURAL INDUCTION: Full Binary Tree Leaf Count')
print('=' * 70)

class BinaryTree:
    """Full binary tree: either a leaf or a node with exactly 2 children."""
    def __init__(self, left=None, right=None, value=None):
        self.left = left
        self.right = right
        self.value = value
        self.is_leaf = (left is None and right is None)
    
    def count_nodes(self):
        if self.is_leaf:
            return 1
        return 1 + self.left.count_nodes() + self.right.count_nodes()
    
    def count_leaves(self):
        if self.is_leaf:
            return 1
        return self.left.count_leaves() + self.right.count_leaves()
    
    def depth(self):
        if self.is_leaf:
            return 0
        return 1 + max(self.left.depth(), self.right.depth())
    
    def verify_leaf_formula(self):
        """Verify L(T) = (N(T) + 1) / 2 by structural induction."""
        N = self.count_nodes()
        L = self.count_leaves()
        formula = (N + 1) / 2
        return L, N, formula, abs(L - formula) < 1e-10

def make_leaf(val='•'):
    return BinaryTree(value=val)

def make_node(left, right, val='○'):
    return BinaryTree(left=left, right=right, value=val)

# Build several full binary trees of increasing size
trees = [
    ('Single leaf', make_leaf()),
    ('1 internal + 2 leaves', 
     make_node(make_leaf(), make_leaf())),
    ('3 internals + 4 leaves',
     make_node(
         make_node(make_leaf(), make_leaf()),
         make_node(make_leaf(), make_leaf())
     )),
    ('7 internals + 8 leaves',
     make_node(
         make_node(
             make_node(make_leaf(), make_leaf()),
             make_node(make_leaf(), make_leaf())
         ),
         make_node(
             make_node(make_leaf(), make_leaf()),
             make_node(make_leaf(), make_leaf())
         )
     )),
    ('Unbalanced (5 internals + 6 leaves)',
     make_node(
         make_leaf(),
         make_node(
             make_leaf(),
             make_node(
                 make_leaf(),
                 make_node(
                     make_leaf(),
                     make_node(make_leaf(), make_leaf())
                 )
             )
         )
     )),
]

print(f'{"Tree":>35}{"N":>4} {"L":>4} {"(N+1)/2":>8}{"L=(N+1)/2?":>11}')
print('─' * 70)

all_ok = True
for name, tree in trees:
    L, N, formula, ok = tree.verify_leaf_formula()
    all_ok = all_ok and ok
    depth = tree.depth()
    print(f'{name:>35}{N:>4} {L:>4} {formula:>8.1f}{"✓" if ok else "✗":>11}  (depth={depth})')

print(f'\n{"✓ All trees satisfy L = (N+1)/2" if all_ok else "✗ Formula violated!"}')
print(f'  Structural induction follows the RECURSIVE definition of the tree')

# --- 9.4: Formula length is always positive ---
print('\n' + '=' * 70)
print('9.4  STRUCTURAL INDUCTION: Formula Length is Positive')
print('=' * 70)

class Formula:
    """Propositional formula: atomic, negation, or binary connective."""
    pass

class Atom(Formula):
    def __init__(self, name):
        self.name = name
    def length(self):
        return 1
    def __str__(self):
        return self.name

class Not(Formula):
    def __init__(self, sub):
        self.sub = sub
    def length(self):
        return self.sub.length() + 1
    def __str__(self):
        return f'¬{self.sub}'

class And(Formula):
    def __init__(self, left, right):
        self.left = left
        self.right = right
    def length(self):
        return self.left.length() + self.right.length() + 1
    def __str__(self):
        return f'({self.left} ∧ {self.right})'

class Or(Formula):
    def __init__(self, left, right):
        self.left = left
        self.right = right
    def length(self):
        return self.left.length() + self.right.length() + 1
    def __str__(self):
        return f'({self.left} ∨ {self.right})'

# Build formulas of increasing complexity
p, q, r = Atom('p'), Atom('q'), Atom('r')
formulas = [
    p,                              # atomic
    Not(p),                         # ¬p
    And(p, q),                      # p ∧ q
    Or(Not(p), q),                  # ¬p ∨ q  (material implication)
    And(Or(p, q), Not(r)),          # (p ∨ q) ∧ ¬r
    Not(And(Not(p), Not(q))),       # ¬(¬p ∧ ¬q)  (De Morgan)
    Or(And(p, q), And(Not(p), r)),  # (p∧q) ∨ (¬p∧r)
]

print(f'\n{"Formula":>35}{"Length":>6} {"Positive?":>10}{"Induction type":>20}')
print('─' * 80)

for f in formulas:
    length = f.length()
    positive = length > 0
    
    if isinstance(f, Atom):
        ind_type = 'Base case'
    elif isinstance(f, Not):
        ind_type = f'Step: ¬φ (len={f.sub.length()}+1)'
    else:
        ind_type = f'Step: φ○ψ ({f.left.length()}+{f.right.length()}+1)'
    
    print(f'{str(f):>35}{length:>6} {"✓" if positive else "✗":>10}{ind_type:>20}')

print(f'\n✓ Every formula has positive length (structural induction verified)')

# --- 9.5: Structural induction on token sequences ---
print('\n' + '=' * 70)
print('9.5  STRUCTURAL INDUCTION IN AI: Token Sequence Properties')
print('=' * 70)

# Prove: length of BPE tokenisation ≤ length of character sequence
def bpe_tokenise(text, merges):
    """Simple BPE tokeniser — structural induction on merges."""
    tokens = list(text)  # Base case: character-level
    
    for merge_pair, new_token in merges:
        new_tokens = []
        i = 0
        while i < len(tokens):
            if i < len(tokens) - 1 and tokens[i] == merge_pair[0] and tokens[i+1] == merge_pair[1]:
                new_tokens.append(new_token)
                i += 2  # Merge reduces token count by 1
            else:
                new_tokens.append(tokens[i])
                i += 1
        tokens = new_tokens
    
    return tokens

# Define merge rules
merges = [
    (('h', 'e'), 'he'),
    (('l', 'l'), 'll'),
    (('he', 'll'), 'hell'),
    (('o', ' '), 'o '),
    (('hell', 'o '), 'hello '),
]

text = "hello world"
print(f'\nText: "{text}" (length {len(text)} characters)')
print(f'\nStructural induction on merge applications:')
print(f'  Base: character sequence has length {len(text)}')

tokens = list(text)
for i, (merge_pair, new_token) in enumerate(merges):
    old_len = len(tokens)
    # Re-tokenise from scratch with the first i+1 merge rules applied
    tokens = bpe_tokenise(text, merges[:i + 1])
    new_len = len(tokens)
    print(f'  Merge {i+1}: {merge_pair} → "{new_token}" │ '
          f'tokens={tokens} │ len: {old_len} → {new_len} '
          f'{"(decreased ✓)" if new_len <= old_len else "(INCREASED ✗)"}')

print(f'\n✓ Each merge reduces or maintains token count')
print(f'  By structural induction on merges: |tokenised| ≤ |characters|')

10. The Probabilistic Method

10.1 The Strategy

Prove that an object with desired property exists by showing that a randomly chosen object has the property with positive probability.

$$\Pr[\text{object has property}] > 0 \implies \text{at least one such object exists}$$

Introduced by Erdős; one of the most powerful existence proof techniques in combinatorics. Non-constructive: proves existence without exhibiting the object.

10.2 Template

Proof (probabilistic method):
  Define a probability space over candidate objects.
  Compute (or bound) Pr[object satisfies P].
  Show Pr[P] > 0.
  Therefore at least one object satisfying P exists. □
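Read computationally (a sketch with hypothetical helper names), the template is a sampling search: any positive empirical success rate comes with a concrete witness.

import random
_pm_rng = random.Random(7)

def probabilistic_method(sample, has_property, n_trials=1000):
    """Estimate Pr[P] by sampling; if it is positive we also get a witness."""
    hits, witness = 0, None
    for _ in range(n_trials):
        obj = sample(_pm_rng)
        if has_property(obj):
            hits += 1
            if witness is None:
                witness = obj
    return hits / n_trials, witness

# Toy claim: a ±1 string of length 10 summing to 0 exists
p_hat, w = probabilistic_method(
    lambda rng: [rng.choice([-1, 1]) for _ in range(10)],
    lambda s: sum(s) == 0)
print(f'empirical Pr ≈ {p_hat:.3f} > 0 → witness: {w}')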

10.3–10.4 Worked Examples

Bipartite subgraph (10.4): Every graph $G = (V, E)$ has a bipartite subgraph with $\geq |E|/2$ edges.

Proof: Randomly assign each vertex to $L$ or $R$ with probability $1/2$ each. For each edge $(u,v)$: $\Pr[\text{edge crosses cut}] = 1/2$. The expected number of crossing edges is therefore $|E|/2$, so some assignment achieves $\geq |E|/2$. $\square$

10.5 Lovász Local Lemma

Want to avoid $n$ "bad events" $A_1, \ldots, A_n$ simultaneously.

  • Each $\Pr(A_i) \leq p$; each depends on at most $d$ other events
  • LLL: if $ep(d+1) \leq 1$, then $\Pr(\cap_i \bar{A}_i) > 0$
  • All bad events can be simultaneously avoided

10.6 First Moment Method

If $\mathbb{E}[X] = \mu$, then:

  • $\Pr(X \geq \mu) > 0$ (some outcome achieves at least the mean)
  • $\Pr(X \leq \mu) > 0$ (some outcome achieves at most the mean)

Markov's inequality: $\Pr(X \geq t) \leq \frac{\mathbb{E}[X]}{t}$ for non-negative $X$; a numeric sanity check follows.
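A quick check of both claims on a toy non-negative distribution (a sketch; the exponential distribution is an arbitrary choice):

import random
_mk_rng = random.Random(3)
X = [_mk_rng.expovariate(1.0) for _ in range(100_000)]  # non-negative samples
mu = sum(X) / len(X)

# First moment method: both sides of the mean are hit with positive frequency
print(f'Pr(X ≥ μ) ≈ {sum(x >= mu for x in X) / len(X):.3f} > 0')
print(f'Pr(X ≤ μ) ≈ {sum(x <= mu for x in X) / len(X):.3f} > 0')

# Markov: Pr(X ≥ t) ≤ E[X]/t for every threshold t
for t in [1, 2, 4, 8]:
    p = sum(x >= t for x in X) / len(X)
    print(f't={t}: Pr(X ≥ t) = {p:.4f} ≤ E[X]/t = {mu / t:.4f} {"✓" if p <= mu / t else "✗"}')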

10.7 Probabilistic Method in AI Theory

  • Compressed sensing: random Gaussian matrices satisfy RIP with high probability
  • Johnson-Lindenstrauss: random projection preserves distances → good low-dim embedding exists
  • PAC learning: a random sample of size $n$ has low empirical error with high probability
  • Neural tangent kernel: random initialisations enable convergence guarantees

Code cell 31

# ══════════════════════════════════════════════════════════════════
# SECTION 10: THE PROBABILISTIC METHOD
# ══════════════════════════════════════════════════════════════════

import random
import math

random.seed(42)

# ── 10.3 Tournament existence (Ramsey-type) ──────────────────────
# Prove: there exists a tournament on n players where every set of
# k players has someone who beat all others in that set.
# Strategy: random tournament, bound probability of failure

def random_tournament_demo(n, k):
    """
    Demonstrate probabilistic method for tournament existence.
    Random tournament: each pair (i,j), i beats j with prob 1/2.
    Bad event for subset S of size k: no player in S beats all others in S.
    """
    print(f"Tournament existence: n={n} players, k={k}")
    print(f"{'='*55}")

    # Number of k-subsets
    num_subsets = 1
    for i in range(k):
        num_subsets = num_subsets * (n - i) // (i + 1)

    # For a fixed subset S of size k:
    # Pr[player i beats all k-1 others in S] = (1/2)^{k-1}
    # These k events are DISJOINT (two players in S cannot both beat everyone
    # else in S), so Pr[NO player in S beats all others] = 1 - k·(1/2)^{k-1}
    p_no_king_in_S = 1 - k * (1/2)**(k-1)

    # Union bound: Pr[some bad subset exists] <= C(n,k) * p_no_king
    union_bound = num_subsets * p_no_king_in_S

    print(f"  Number of {k}-subsets: C({n},{k}) = {num_subsets}")
    print(f"  Pr[no king in fixed subset] = (1 - 2^{{-{k-1}}})^{k}")
    print(f"                              = {p_no_king_in_S:.6f}")
    print(f"  Union bound on failure:  {num_subsets} × {p_no_king_in_S:.6f}")
    print(f"                         = {union_bound:.6f}")

    if union_bound < 1:
        print(f"  ✓ Union bound < 1 → good tournament EXISTS")
        print(f"    (Probabilistic method: Pr[success] ≥ {1-union_bound:.6f} > 0)")
    else:
        print(f"  ✗ Union bound ≥ 1 → bound inconclusive for these parameters")

    # Actually construct one by random trial
    print(f"\n  Searching for concrete example by random sampling...")
    for trial in range(1000):
        # Random tournament: beats[i][j] = True if i beats j
        beats = [[False]*n for _ in range(n)]
        for i in range(n):
            for j in range(i+1, n):
                if random.random() < 0.5:
                    beats[i][j] = True
                else:
                    beats[j][i] = True

        # Check all k-subsets (brute force for small n)
        from itertools import combinations
        all_good = True
        for subset in combinations(range(n), k):
            has_king = False
            for player in subset:
                others = [p for p in subset if p != player]
                if all(beats[player][o] for o in others):
                    has_king = True
                    break
            if not has_king:
                all_good = False
                break

        if all_good:
            print(f"  ✓ Found valid tournament on trial {trial+1}")
            return beats
    print(f"  (No example found in 1000 trials)")
    return None

random_tournament_demo(8, 3)

# ── 10.4 Bipartite subgraph: max-cut ────────────────────────────
print(f"\n{'='*60}")
print(f"BIPARTITE SUBGRAPH (First Moment Method)")
print(f"{'='*60}")

def max_cut_probabilistic(adj_list, n):
    """
    Prove: every graph has a cut with >= |E|/2 edges.
    Random partition: each vertex goes to L or R with prob 1/2.
    E[crossing edges] = |E|/2, so some partition achieves >= |E|/2.
    """
    # Count edges
    num_edges = sum(len(neighbors) for neighbors in adj_list) // 2

    print(f"\nGraph: {n} vertices, {num_edges} edges")
    print(f"Claim: ∃ bipartite subgraph with ≥ {num_edges}/2 = {num_edges/2} edges")
    print(f"\nProbabilistic argument:")
    print(f"  For each edge (u,v): Pr[u,v in different parts] = 1/2")
    print(f"  E[crossing edges] = {num_edges} × 1/2 = {num_edges/2}")
    print(f"  ∴ some partition achieves ≥ {num_edges/2} crossing edges")

    # Find one by random sampling
    best_cut = 0
    best_partition = None
    num_trials = 200

    for _ in range(num_trials):
        partition = [random.randint(0, 1) for _ in range(n)]
        crossing = 0
        for u in range(n):
            for v in adj_list[u]:
                if v > u and partition[u] != partition[v]:
                    crossing += 1
        if crossing > best_cut:
            best_cut = crossing
            best_partition = partition[:]

    L_vertices = [i for i in range(n) if best_partition[i] == 0]
    R_vertices = [i for i in range(n) if best_partition[i] == 1]
    print(f"\n  Best partition found ({num_trials} trials):")
    print(f"    L = {L_vertices}")
    print(f"    R = {R_vertices}")
    print(f"    Crossing edges: {best_cut}")
    ratio = best_cut / num_edges if num_edges > 0 else 0
    check = "✓" if best_cut >= num_edges / 2 else "✗"
    print(f"    Ratio: {best_cut}/{num_edges} = {ratio:.3f} {check} (need ≥ 0.5)")

# Example: 6-cycle plus its three long diagonals (this is K₃,₃)
n = 6
adj = [[] for _ in range(n)]
edges = [(0,1),(1,2),(2,3),(3,4),(4,5),(5,0),(0,3),(1,4),(2,5)]
for u, v in edges:
    adj[u].append(v)
    adj[v].append(u)
max_cut_probabilistic(adj, n)

# ── 10.5 Lovász Local Lemma illustration ─────────────────────────
print(f"\n{'='*60}")
print(f"LOVÁSZ LOCAL LEMMA (LLL)")
print(f"{'='*60}")

def lll_check(p, d, n_events):
    """
    Symmetric LLL: if ep(d+1) <= 1, then Pr[avoid all bad events] > 0.
    p = max probability of each bad event
    d = max dependency degree
    """
    e = math.e
    lll_condition = e * p * (d + 1)
    satisfied = lll_condition <= 1

    print(f"\n  n = {n_events} bad events")
    print(f"  p ≤ {p:.6f} (max Pr of each)")
    print(f"  d = {d} (max dependency degree)")
    print(f"  e·p·(d+1) = {e:.4f} × {p:.6f} × {d+1} = {lll_condition:.6f}")
    if satisfied:
        print(f"  ✓ e·p·(d+1) ≤ 1 → can avoid ALL bad events simultaneously")
    else:
        print(f"  ✗ e·p·(d+1) > 1 → LLL condition not met")
    return satisfied

# Example: k-SAT satisfiability
# n variables, m clauses of width k
# Each clause independently false with prob 2^{-k}
# Each clause shares variables with at most d others
print(f"\nApplication: Random k-SAT")
for k in [3, 4, 5, 7]:
    p = 2**(-k)           # Pr[clause unsatisfied by random assignment]
    d = 10 * k            # each clause shares var with ~O(k) others
    print(f"\n  k={k}-SAT, clause failure prob = 2^{-k} = {p:.6f}, d = {d}:")
    lll_check(p, d, 100)

# ── 10.6 First Moment Method: Ramsey bound ──────────────────────
print(f"\n{'='*60}")
print(f"FIRST MOMENT METHOD")
print(f"{'='*60}")

def ramsey_lower_bound(s, max_n=100):
    """
    Prove: R(s,s) > n for some n, using first moment method.
    Random 2-colouring of K_n edges.
    E[monochromatic K_s] = C(n,s) * 2^{1-C(s,2)}
    If E < 1, some colouring has NO monochromatic K_s.
    """
    c_s2 = s * (s - 1) // 2  # C(s,2)

    print(f"\n  Ramsey R({s},{s}) lower bound:")
    print(f"  E[mono K_{s}] = C(n,{s}) · 2^{{1-C({s},2)}} = C(n,{s}) · 2^{{1-{c_s2}}}")

    best_n = 0
    for n in range(s, max_n + 1):
        # C(n, s)
        c_ns = 1
        for i in range(s):
            c_ns = c_ns * (n - i) // (i + 1)
        expected = c_ns * 2**(1 - c_s2)
        if expected < 1:
            best_n = n

    if best_n > 0:
        # Recompute for best_n
        c_ns = 1
        for i in range(s):
            c_ns = c_ns * (best_n - i) // (i + 1)
        expected = c_ns * 2**(1 - c_s2)
        print(f"  For n = {best_n}: E = C({best_n},{s}) · 2^{{1-{c_s2}}} = {expected:.4f} < 1")
        print(f"  ✓ R({s},{s}) > {best_n} (some 2-colouring of K_{best_n} avoids mono K_{s})")
    else:
        print(f"  (bound not useful for n ≤ {max_n})")

for s in [3, 4, 5, 6]:
    ramsey_lower_bound(s, max_n=200)

# ── 10.7 AI connections: Johnson-Lindenstrauss ───────────────────
print(f"\n{'='*60}")
print(f"JL LEMMA: RANDOM PROJECTIONS PRESERVE DISTANCES")
print(f"{'='*60}")

def jl_demo(n_points, d_original, epsilon):
    """
    Johnson-Lindenstrauss: project d-dimensional points to
    k = O(log(n)/ε²) dimensions while preserving pairwise distances.
    Random Gaussian matrix works with high probability.
    """
    k = int(8 * math.log(n_points) / (epsilon**2)) + 1  # target dim

    print(f"\n  n = {n_points} points in R^{d_original}")
    print(f"  ε = {epsilon} (distortion tolerance)")
    print(f"  Target dimension: k = O(log(n)/ε²) = {k}")
    print(f"  Compression ratio: {d_original}/{k} = {d_original/k:.1f}x")

    # Generate random points
    points = []
    for i in range(n_points):
        p = [random.gauss(0, 1) for _ in range(d_original)]
        points.append(p)

    # Random Gaussian projection matrix (k × d)
    proj = []
    for i in range(k):
        row = [random.gauss(0, 1) / math.sqrt(k) for _ in range(d_original)]
        proj.append(row)

    # Project points
    projected = []
    for p in points:
        proj_p = []
        for i in range(k):
            val = sum(proj[i][j] * p[j] for j in range(d_original))
            proj_p.append(val)
        projected.append(proj_p)

    # Check pairwise distance preservation
    max_distortion = 0
    n_pairs = 0
    n_preserved = 0

    for i in range(n_points):
        for j in range(i+1, n_points):
            # Original distance
            d_orig = math.sqrt(sum((points[i][l] - points[j][l])**2
                                   for l in range(d_original)))
            # Projected distance
            d_proj = math.sqrt(sum((projected[i][l] - projected[j][l])**2
                                   for l in range(k)))

            if d_orig > 0:
                ratio = d_proj / d_orig
                distortion = abs(ratio - 1)
                max_distortion = max(max_distortion, distortion)
                n_pairs += 1
                if distortion <= epsilon:
                    n_preserved += 1

    pct = 100 * n_preserved / n_pairs if n_pairs > 0 else 0
    check = "✓" if pct > 90 else "~"
    print(f"\n  Pairwise distance preservation:")
    print(f"    Pairs checked: {n_pairs}")
    print(f"    Pairs within (1±ε): {n_preserved}/{n_pairs} ({pct:.1f}%) {check}")
    print(f"    Max distortion: {max_distortion:.4f}")
    print(f"    Target ε: {epsilon}")
    print(f"  {'✓' if pct > 90 else '~'} Random projection preserves geometry")

jl_demo(n_points=20, d_original=200, epsilon=0.5)

11. Counting Arguments & Combinatorial Proofs

11.1 The Strategy

Prove identities or bounds by counting the same set in two different ways (double counting), or by establishing bijections between sets.

$$|A| = |B| \iff \exists \text{ bijection } f: A \to B$$

11.2 Double Counting (Handshaking Lemma)

Theorem: In any graph $G = (V,E)$: $\sum_{v \in V} \deg(v) = 2|E|$

Proof: Count incidences (vertex, edge) where the vertex is an endpoint of the edge.

  • Counting by vertex: each vertex $v$ contributes $\deg(v)$ incidences → total $= \sum_v \deg(v)$
  • Counting by edge: each edge contributes 2 incidences (one per endpoint) → total $= 2|E|$
  • Both count the same set, so $\sum_v \deg(v) = 2|E|$ $\square$

11.3 Vandermonde's Identity

$$\binom{m+n}{r} = \sum_{k=0}^{r} \binom{m}{k}\binom{n}{r-k}$$

Combinatorial proof: Choose $r$ items from $m$ "red" and $n$ "blue" items. The LHS counts directly. The RHS: pick $k$ red and $r-k$ blue for each valid $k$.

11.4 Pigeonhole Principle

If $n+1$ objects are placed in $n$ boxes, at least one box contains $\geq 2$ objects. Generalised form, for a map $f: A \to B$:

$$|A| > k \cdot |B| \implies \exists b \in B: |f^{-1}(b)| > k$$

AI application: A neural network with $n$ weights quantised to $k$ levels has only $k^n$ distinct weight configurations; if $k^n < |\text{functions to represent}|$, some functions share the same weight configuration.

11.5 Inclusion-Exclusion

$$\left|\bigcup_{i=1}^n A_i\right| = \sum_{i}|A_i| - \sum_{i<j}|A_i \cap A_j| + \sum_{i<j<k}|A_i \cap A_j \cap A_k| - \cdots$$

Used to compute exact probabilities when events overlap, e.g., derangement counting; a direct check on random sets follows.
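A direct check of the three-set instance on random sets (a small sketch, independent of the derangement code in the next cell):

import random
_ie_rng = random.Random(11)
A = set(_ie_rng.sample(range(30), 12))
B = set(_ie_rng.sample(range(30), 12))
C = set(_ie_rng.sample(range(30), 12))

lhs = len(A | B | C)
rhs = (len(A) + len(B) + len(C)
       - len(A & B) - len(A & C) - len(B & C)
       + len(A & B & C))
print(f'|A∪B∪C| = {lhs}, inclusion-exclusion = {rhs} {"✓" if lhs == rhs else "✗"}')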

11.6 AI Application: VC Dimension via Shattering

A hypothesis class $\mathcal{H}$ shatters a set $S$ of $d$ points if for every labelling of $S$ (all $2^d$ labellings), some $h \in \mathcal{H}$ achieves it.

$$\text{VC-dim}(\mathcal{H}) = \max\{d : \exists S, |S|=d, \mathcal{H} \text{ shatters } S\}$$

The Sauer-Shelah lemma (a counting argument):

$$|\{h|_S : h \in \mathcal{H}\}| \leq \sum_{i=0}^{d} \binom{|S|}{i}$$

Code cell 33

# ══════════════════════════════════════════════════════════════════
# SECTION 11: COUNTING ARGUMENTS & COMBINATORIAL PROOFS
# ══════════════════════════════════════════════════════════════════

# ── 11.2 Double counting: Handshaking lemma ──────────────────────
print("DOUBLE COUNTING: HANDSHAKING LEMMA")
print("="*55)

def verify_handshaking(n, edge_list):
    """Verify sum of degrees = 2|E| by counting incidences two ways."""
    # Build adjacency
    deg = [0] * n
    for u, v in edge_list:
        deg[u] += 1
        deg[v] += 1

    sum_deg = sum(deg)
    twice_edges = 2 * len(edge_list)

    print(f"\n  Graph: {n} vertices, {len(edge_list)} edges")
    print(f"  Degrees: {deg}")
    print(f"  Sum of degrees   = {sum_deg}")
    print(f"  2 × |E|          = {twice_edges}")
    check = "✓" if sum_deg == twice_edges else "✗"
    print(f"  {check} Handshaking lemma verified: Σdeg(v) = 2|E|")

    # Corollary: number of odd-degree vertices is even
    odd_count = sum(1 for d in deg if d % 2 == 1)
    check2 = "✓" if odd_count % 2 == 0 else "✗"
    print(f"  Odd-degree vertices: {odd_count} (must be even) {check2}")

# Test on several graphs
# Triangle
verify_handshaking(3, [(0,1),(1,2),(0,2)])
# Path graph P4
verify_handshaking(4, [(0,1),(1,2),(2,3)])
# Complete K5
k5_edges = [(i,j) for i in range(5) for j in range(i+1,5)]
verify_handshaking(5, k5_edges)

# ── 11.3 Vandermonde's identity ──────────────────────────────────
print(f"\n{'='*55}")
print(f"VANDERMONDE'S IDENTITY: C(m+n, r) = Σ C(m,k)·C(n,r-k)")
print(f"{'='*55}")

def comb(n, k):
    """Compute C(n, k) using raw multiplication."""
    if k < 0 or k > n:
        return 0
    if k == 0 or k == n:
        return 1
    result = 1
    for i in range(min(k, n - k)):
        result = result * (n - i) // (i + 1)
    return result

def verify_vandermonde(m, n, r):
    """Verify C(m+n, r) = sum_{k=0}^{r} C(m,k) * C(n, r-k)."""
    lhs = comb(m + n, r)

    # RHS: sum over k
    terms = []
    rhs = 0
    for k in range(r + 1):
        term = comb(m, k) * comb(n, r - k)
        if term > 0:
            terms.append(f"C({m},{k})·C({n},{r-k})={term}")
        rhs += term

    print(f"\n  C({m}+{n}, {r}) = C({m+n}, {r}) = {lhs}")
    print(f"  Σ C({m},k)·C({n},{r}-k) = {' + '.join(terms[:6])}")
    if len(terms) > 6:
        print(f"    ... ({len(terms)} terms total)")
    print(f"  = {rhs}")
    check = "✓" if lhs == rhs else "✗"
    print(f"  {check} Vandermonde verified")

verify_vandermonde(5, 4, 3)
verify_vandermonde(10, 8, 6)
verify_vandermonde(7, 7, 7)

# ── 11.4 Pigeonhole principle ────────────────────────────────────
print(f"\n{'='*55}")
print(f"PIGEONHOLE PRINCIPLE")
print(f"{'='*55}")

def pigeonhole_demo():
    """Demonstrate pigeonhole: among n+1 numbers mod n, two share a remainder."""
    for n in [5, 10, 20]:
        numbers = [random.randint(1, 1000) for _ in range(n + 1)]
        remainders = [x % n for x in numbers]

        # Find collision
        seen = {}
        collision = None
        for i, r in enumerate(remainders):
            if r in seen:
                collision = (seen[r], i, r)
                break
            seen[r] = i

        print(f"\n  n = {n}: {n+1} numbers, {n} possible remainders mod {n}")
        print(f"  Numbers:    {numbers[:8]}{'...' if len(numbers)>8 else ''}")
        print(f"  Remainders: {remainders[:8]}{'...' if len(remainders)>8 else ''}")
        if collision:
            i, j, r = collision
            print(f"  ✓ Collision: numbers[{i}]={numbers[i]} and numbers[{j}]={numbers[j]}"
                  f" both ≡ {r} (mod {n})")

pigeonhole_demo()

# Generalized pigeonhole: quantisation
print(f"\n  AI Application: Weight Quantisation")
bits = 8
levels = 2**bits
n_weights = 1000
n_possible_configs = levels ** n_weights
print(f"  {n_weights} weights quantised to {bits} bits ({levels} levels)")
print(f"  Possible configurations: {levels}^{n_weights} = 2^{bits*n_weights}")
print(f"  If representing > 2^{bits*n_weights} functions,")
print(f"  ✓ Pigeonhole: some functions MUST share weights")

# ── 11.5 Inclusion-Exclusion: Derangements ──────────────────────
print(f"\n{'='*55}")
print(f"INCLUSION-EXCLUSION: DERANGEMENTS")
print(f"{'='*55}")

def count_derangements(n):
    """
    Count permutations with no fixed points using inclusion-exclusion.
    D_n = n! Σ_{k=0}^{n} (-1)^k / k!
    A_i = {perms fixing i}, |A_i| = (n-1)!, |A_i ∩ A_j| = (n-2)!, etc.
    """
    # factorial
    fact = [1] * (n + 1)
    for i in range(1, n + 1):
        fact[i] = fact[i-1] * i

    # Inclusion-exclusion
    d_n = 0
    print(f"\n  D_{n} via inclusion-exclusion:")
    terms = []
    for k in range(n + 1):
        sign = (-1)**k
        # C(n,k) * (n-k)! = n!/k!
        term = sign * fact[n] // fact[k]
        d_n += term
        terms.append(f"({'+' if sign > 0 else '-'}){fact[n]//fact[k]}")

    print(f"  D_{n} = {' '.join(terms[:8])}")
    print(f"       = {d_n}")

    # Verify by brute force for small n
    if n <= 8:
        from itertools import permutations
        brute = sum(1 for p in permutations(range(n))
                    if all(p[i] != i for i in range(n)))
        check = "✓" if brute == d_n else "✗"
        print(f"  Brute force: {brute} {check}")

    # Ratio to n!
    ratio = d_n / fact[n]
    print(f"  D_{n}/n! = {ratio:.6f} ≈ 1/e = {1/math.e:.6f}")

    return d_n

for n in [3, 5, 8, 10]:
    count_derangements(n)

# ── 11.6 VC dimension and shattering ─────────────────────────────
print(f"\n{'='*55}")
print(f"VC DIMENSION: SHATTERING BY LINEAR CLASSIFIERS")
print(f"{'='*55}")

def check_shattering_2d(points):
    """
    Check if 2D linear classifiers can shatter given points.
    Linear classifier: sign(w1*x1 + w2*x2 + b)
    Try all 2^n labellings and see if each is achievable.
    """
    n = len(points)
    total_labellings = 2**n
    achieved = 0

    print(f"\n  Points: {points}")
    print(f"  Total labellings to achieve: {total_labellings}")

    for labelling_int in range(total_labellings):
        labels = [(labelling_int >> i) & 1 for i in range(n)]
        # Convert to +1/-1
        y = [2*l - 1 for l in labels]

        # Try to find separating line by checking many directions
        found = False
        for angle_deg in range(0, 360, 1):
            angle = angle_deg * math.pi / 180
            w1, w2 = math.cos(angle), math.sin(angle)

            # Project points onto normal direction
            projections = [w1 * p[0] + w2 * p[1] for p in points]

            # Try different thresholds
            sorted_proj = sorted(set(projections))
            thresholds = [sorted_proj[0] - 1]
            for i in range(len(sorted_proj) - 1):
                thresholds.append((sorted_proj[i] + sorted_proj[i+1]) / 2)
            thresholds.append(sorted_proj[-1] + 1)

            for b in thresholds:
                predictions = [1 if (w1*p[0] + w2*p[1] + b) >= 0 else -1
                              for p in points]
                if predictions == y:
                    found = True
                    break
            if found:
                break

        if found:
            achieved += 1

    shattered = achieved == total_labellings
    check = "✓" if shattered else "✗"
    print(f"  Achievable labellings: {achieved}/{total_labellings}")
    print(f"  {check} {'Shattered' if shattered else 'NOT shattered'} by linear classifiers")
    return shattered

# VC-dim of linear classifiers in R² is 3
print(f"\n  Testing VC-dim of linear classifiers in R²:")

# 3 non-collinear points: should shatter
print(f"\n  --- 3 non-collinear points ---")
check_shattering_2d([(0,0), (1,0), (0,1)])

# 4 points: should NOT shatter
print(f"\n  --- 4 points (XOR configuration) ---")
check_shattering_2d([(0,0), (1,0), (0,1), (1,1)])

# Sauer-Shelah bound
print(f"\n  Sauer-Shelah Lemma:")
print(f"  For VC-dim d, max dichotomies on m points ≤ Σ_{{i=0}}^{{d}} C(m,i)")
d_vc = 3
for m in [3, 5, 10, 20, 50]:
    bound = sum(comb(m, i) for i in range(d_vc + 1))
    print(f"    m={m:3d}, d={d_vc}: bound = {bound:8d} vs 2^m = {2**m:15d}"
          f"  {'polynomial' if bound < 2**m else 'all'}")

12. Epsilon-Delta & Analytic Proofs

12.1 The ε-δ Framework

The language of analysis — make "closeness" and "limit" rigorous:

$$\lim_{x \to a} f(x) = L \iff \forall \varepsilon > 0, \exists \delta > 0: 0 < |x - a| < \delta \implies |f(x) - L| < \varepsilon$$

Key idea: the adversary picks $\varepsilon$ (how close to $L$ the output must be); you must find $\delta$ (how close $x$ must be to $a$).

12.2 Continuity of $f(x) = x^2$

Claim: $f(x) = x^2$ is continuous at $x = a$, i.e., $\lim_{x \to a} x^2 = a^2$.

Proof: Let $\varepsilon > 0$. We need $|x^2 - a^2| < \varepsilon$ whenever $|x - a| < \delta$.

$$|x^2 - a^2| = |x - a||x + a| \leq |x-a|(|x-a| + 2|a|) \leq \delta(\delta + 2|a|)$$

Choose $\delta = \min\left(1, \frac{\varepsilon}{1 + 2|a|}\right)$.

Then $|x^2 - a^2| \leq \delta(\delta + 2|a|) \leq \delta(1 + 2|a|) \leq \varepsilon$. $\square$

12.3 Sequence Convergence

$$a_n \to L \iff \forall \varepsilon > 0, \exists N \in \mathbb{N}: n \geq N \implies |a_n - L| < \varepsilon$$

Example: $a_n = 1/n \to 0$. Given $\varepsilon > 0$, choose $N = \lceil 1/\varepsilon \rceil + 1$. Then $n \geq N \implies 1/n \leq 1/N < \varepsilon$. $\square$

12.4 Gradient Descent Convergence

For $L$-smooth convex $f$, gradient descent with step size $\eta = 1/L$ satisfies:

$$f(x_T) - f(x^*) \leq \frac{L\|x_0 - x^*\|^2}{2T}$$

Proof sketch: At each step, smoothness gives $f(x_{t+1}) \leq f(x_t) - \frac{1}{2L}\|\nabla f(x_t)\|^2$, and convexity gives $f(x_t) - f(x^*) \leq \nabla f(x_t)^\top(x_t - x^*)$. Telescope the resulting per-step bound over $T$ steps.

12.5 Compactness Arguments

A set $S \subset \mathbb{R}^n$ is compact if and only if it is closed and bounded (the Heine–Borel theorem).

Extreme value theorem: A continuous function on a compact set attains its max and min.

Bolzano-Weierstrass: Every bounded sequence has a convergent subsequence.

AI application: Compact parameter spaces guarantee that optimisation achieves its optimum.

12.6 Fixed Point Theorems

Banach (contraction mapping): If T:XXT: X \to X is a contraction (TxTycxy,c<1\|Tx - Ty\| \leq c\|x-y\|, c<1), then TT has a unique fixed point, and iteration xn+1=T(xn)x_{n+1} = T(x_n) converges to it.

Brouwer: Every continuous map $f: B^n \to B^n$ from a compact convex set to itself has a fixed point.
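In one dimension, Brouwer reduces to the intermediate value theorem: for continuous f: [0,1] → [0,1], g(x) = f(x) - x has g(0) ≥ 0 and g(1) ≤ 0, so bisection locates a fixed point. A minimal sketch (the particular f is an illustrative choice):

import math

f = lambda x: math.cos(x) ** 2      # continuous map from [0,1] into [0,1]
lo, hi = 0.0, 1.0                   # g(lo) = f(lo)-lo ≥ 0, g(hi) = f(hi)-hi ≤ 0
for _ in range(60):                 # bisection on g(x) = f(x) - x
    mid = (lo + hi) / 2
    if f(mid) - mid >= 0:
        lo = mid
    else:
        hi = mid
print(f"fixed point x* ≈ {lo:.12f}, residual f(x*) - x* = {f(lo) - lo:.2e}")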

12.7 AI Application: Convergence Guarantees

  • SGD convergence: requires bounding $\mathbb{E}[\|x_t - x^*\|^2]$ via ε-δ-style analysis
  • GAN equilibrium: Nash equilibrium existence via Brouwer/Kakutani fixed point theorems
  • Self-attention as contraction: if the Lipschitz constant is $< 1$, iterated attention converges (see the sketch below)
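The sketch below illustrates the last bullet with a linear stand-in for the inference step: iterating an affine map whose matrix has spectral norm below 1 (the 2×2 matrix is an illustrative assumption, not an actual attention layer).

import numpy as np

A = np.array([[0.5, 0.2], [0.1, 0.4]])      # spectral norm ≈ 0.61 < 1
b = np.array([1.0, -0.5])
T = lambda x: A @ x + b                     # affine contraction, Lip = ||A||₂
print(f"Lipschitz constant ||A||₂ = {np.linalg.norm(A, 2):.4f}")

x = np.zeros(2)
for _ in range(40):                         # iterate T: geometric convergence
    x = T(x)
x_star = np.linalg.solve(np.eye(2) - A, b)  # exact fixed point of x = Ax + b
print(f"iterate after 40 steps: {x}")
print(f"exact fixed point:      {x_star}")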

Code cell 35

# ══════════════════════════════════════════════════════════════════
# SECTION 12: EPSILON-DELTA & ANALYTIC PROOFS
# ══════════════════════════════════════════════════════════════════

import math

# ── 12.2 Continuity of x² via ε-δ ───────────────────────────────
print("ε-δ PROOF: CONTINUITY OF f(x) = x²")
print("="*55)

def verify_x_squared_continuity(a, epsilons):
    """
    For f(x)=x², verify that our δ formula works.
    δ = min(1, ε/(1+2|a|)) guarantees |x²-a²| < ε when |x-a| < δ.
    """
    print(f"\n  f(x) = x² at a = {a}")
    print(f"  {'ε':>12s} {'δ':>12s} {'|x²-a²| max':>14s} {'< ε?':>6s}")
    print(f"  {'-'*48}")

    for eps in epsilons:
        delta = min(1.0, eps / (1 + 2*abs(a)))

        # Sample many x in (a-δ, a+δ) and find max |x²-a²|
        max_diff = 0
        n_samples = 10000
        for i in range(n_samples):
            # x = a + t where |t| < δ
            t = delta * (2 * i / (n_samples - 1) - 1) * 0.999  # stay strictly inside
            x = a + t
            diff = abs(x**2 - a**2)
            if diff > max_diff:
                max_diff = diff

        check = "✓" if max_diff < eps else "✗"
        print(f"  {eps:12.6f} {delta:12.6f} {max_diff:14.6f} {check:>6s}")

    # Show the δ formula derivation
    print(f"\n  Derivation:")
    print(f"  |x²-a²| = |x-a|·|x+a| ≤ |x-a|·(|x-a|+2|a|)")
    print(f"  If δ ≤ 1: ≤ δ(1+2|a|)")
    print(f"  Set δ = ε/(1+2|a|): |x²-a²| ≤ ε  ✓")

verify_x_squared_continuity(3.0, [1.0, 0.1, 0.01, 0.001, 0.0001])
verify_x_squared_continuity(0.0, [1.0, 0.1, 0.01, 0.001])

# ── 12.3 Sequence convergence: 1/n → 0 ──────────────────────────
print(f"\n{'='*55}")
print(f"SEQUENCE CONVERGENCE: 1/n → 0")
print(f"{'='*55}")

def verify_sequence_convergence():
    """
    Prove a_n = 1/n → 0:
    Given ε, find N = ceil(1/ε) + 1, then n ≥ N ⟹ |1/n| < ε.
    """
    print(f"\n  Claim: a_n = 1/n → 0 as n → ∞")
    print(f"  {'ε':>10s} {'N=⌈1/ε⌉+1':>12s} {'a_N = 1/N':>12s} {'< ε?':>6s}")
    print(f"  {'-'*44}")

    for eps in [1.0, 0.5, 0.1, 0.01, 0.001, 0.0001]:
        N = math.ceil(1/eps) + 1
        a_N = 1 / N
        check = "✓" if a_N < eps else "✗"
        print(f"  {eps:10.4f} {N:12d} {a_N:12.6f} {check:>6s}")

    # More interesting: (2n+1)/(3n+5) → 2/3
    print(f"\n  Claim: b_n = (2n+1)/(3n+5) → 2/3")
    L = 2/3
    print(f"  {'ε':>10s} {'N':>6s} {'b_N':>12s} {'|b_N - 2/3|':>14s} {'< ε?':>6s}")
    print(f"  {'-'*52}")

    for eps in [0.1, 0.01, 0.001, 0.0001]:
        # |b_n - 2/3| = |(2n+1)/(3n+5) - 2/3| = |3(2n+1) - 2(3n+5)| / |3(3n+5)|
        #             = |6n+3-6n-10| / |9n+15| = 7/(9n+15)
        # Need 7/(9n+15) < ε, so n > (7/ε - 15)/9
        N = int((7/eps - 15) / 9) + 2
        if N < 1:
            N = 1
        b_N = (2*N + 1) / (3*N + 5)
        diff = abs(b_N - L)
        check = "✓" if diff < eps else "✗"
        print(f"  {eps:10.4f} {N:6d} {b_N:12.8f} {diff:14.8f} {check:>6s}")

verify_sequence_convergence()

# ── 12.4 Gradient descent convergence ────────────────────────────
print(f"\n{'='*55}")
print(f"GRADIENT DESCENT CONVERGENCE (L-smooth convex)")
print(f"{'='*55}")

def gd_convergence_demo():
    """
    Run GD on f(x) = (L/2)x² (L-smooth, convex) and verify:
    f(x_T) - f(x*) ≤ L||x_0 - x*||²/(2T)
    """
    L = 2.0                  # smoothness constant
    eta = 1.0 / L            # step size = 1/L
    x0 = 10.0                # initial point
    x_star = 0.0             # optimum
    f_star = 0.0             # f(x*)

    print(f"\n  f(x) = {L/2:.1f}x² (L={L:.1f}-smooth convex)")
    print(f"  η = 1/L = {eta:.2f}, x₀ = {x0}")
    print(f"  Bound: f(x_T) - f* ≤ L||x₀-x*||²/(2T)")
    print(f"         = {L}·{x0}²/(2T) = {L*x0**2:.1f}/T")
    print(f"\n  {'T':>6s} {'f(x_T)':>14s} {'f(x_T)-f*':>14s} {'Bound':>14s} {'OK?':>5s}")
    print(f"  {'-'*55}")

    x = x0
    for T in [1, 2, 5, 10, 20, 50, 100, 500]:
        # Reset and run T steps
        x = x0
        for _ in range(T):
            grad = L * x            # ∇f = Lx
            x = x - eta * grad      # x_{t+1} = x_t - η∇f(x_t)

        f_xT = (L/2) * x**2
        gap = f_xT - f_star
        bound = L * (x0 - x_star)**2 / (2 * T)
        check = "✓" if gap <= bound * 1.001 else "✗"  # tiny tolerance
        print(f"  {T:6d} {f_xT:14.8f} {gap:14.8f} {bound:14.4f} {check:>5s}")

    # For this quadratic with η = 1/L, GD converges in a single step:
    # x_1 = x_0 - (1/L)·(L·x_0) = 0, so the O(1/T) bound is very loose here.
    print(f"\n  Note: for quadratic, actual convergence is geometric:")
    print(f"  x_T = x₀(1 - ηL)^T = {x0}·(1-{eta}·{L})^T = {x0}·0^T = 0 after T=1")
    print(f"  The bound L||x₀||²/(2T) = O(1/T) is tight for worst-case non-quadratic")

gd_convergence_demo()

# ── 12.6 Banach contraction: fixed point iteration ──────────────
print(f"\n{'='*55}")
print(f"BANACH CONTRACTION MAPPING THEOREM")
print(f"{'='*55}")

def contraction_fixed_point():
    """
    T(x) = cos(x) is a contraction on [0, 1] (|T'(x)| = |sin(x)| ≤ sin(1) ≈ 0.84).
    Banach guarantees unique fixed point and convergence of iteration.
    """
    print(f"\n  T(x) = cos(x)")
    print(f"  |T'(x)| = |sin(x)| ≤ sin(1) ≈ {math.sin(1):.4f} < 1 → contraction")
    print(f"\n  Fixed point iteration: x_{'{n+1}'} = cos(x_n)")

    x = 0.0  # starting point
    print(f"\n  {'n':>4s} {'x_n':>16s} {'|x_n - x_{n-1}|':>18s} {'Ratio':>10s}")
    print(f"  {'-'*52}")

    prev_diff = None
    for n in range(20):
        x_new = math.cos(x)
        diff = abs(x_new - x)

        ratio_str = ""
        if prev_diff is not None and prev_diff > 0:
            ratio = diff / prev_diff
            ratio_str = f"{ratio:.6f}"

        if n < 12 or n >= 18:
            print(f"  {n:4d} {x:16.12f} {diff:18.12f} {ratio_str:>10s}")
        elif n == 12:
            print(f"  {'...':>4s}")

        prev_diff = diff
        x = x_new

    # Verify it's a fixed point
    residual = abs(x - math.cos(x))
    print(f"\n  Fixed point: x* ≈ {x:.12f}")
    print(f"  |x* - cos(x*)| = {residual:.2e}")
    print(f"  ✓ Converged (contraction ratio ≈ sin(1) = {math.sin(1):.4f})")

    # AI analogy: iterative inference
    print(f"\n  AI Analogy: Iterative decoding / equilibrium models")
    print(f"  If inference step T is contractive (Lip < 1),")
    print(f"  then repeated application T(T(...T(x)...)) → fixed point")

contraction_fixed_point()

13. Proof Patterns in ML Theory

13.1 Union Bound (Boole's Inequality)

$$\Pr\left(\bigcup_{i=1}^n A_i\right) \leq \sum_{i=1}^n \Pr(A_i)$$

The simplest but most widely used tool in ML theory:

  • Generalisation bounds: bound probability that any hypothesis deviates
  • Multiple testing: control family-wise error rate
  • Convergence: bound probability that any coordinate fails

13.2 Concentration Inequalities

| Inequality | Conditions | Bound |
|---|---|---|
| Markov | $X \geq 0$ | $\Pr(X \geq t) \leq \frac{\mathbb{E}[X]}{t}$ |
| Chebyshev | $\text{Var}(X) < \infty$ | $\Pr(\|X - \mu\| \geq t) \leq \frac{\text{Var}(X)}{t^2}$ |
| Hoeffding | $X_i \in [a_i, b_i]$ indep. | $\Pr(\|\bar{X} - \mu\| \geq t) \leq 2\exp\!\left(\frac{-2n^2t^2}{\sum_i (b_i-a_i)^2}\right)$ |
| Bernstein | bounded variance | tighter than Hoeffding when the variance is small |

13.3 PAC Learning Framework

A concept class $\mathcal{C}$ is PAC-learnable if there exists an algorithm $A$ such that for every distribution $D$ and all $\varepsilon, \delta > 0$:

$$n \geq \frac{1}{\varepsilon}\left(\ln|\mathcal{H}| + \ln\frac{1}{\delta}\right) \implies \Pr[\text{err}(h) \leq \varepsilon] \geq 1 - \delta$$

Proof strategy: Fix one hypothesis → Hoeffding bound → union bound over the $|\mathcal{H}|$ hypotheses.

13.4 Rademacher Complexity

$$\mathfrak{R}_n(\mathcal{F}) = \mathbb{E}_\sigma\left[\sup_{f \in \mathcal{F}} \frac{1}{n}\sum_{i=1}^n \sigma_i f(x_i)\right]$$

Measures how well $\mathcal{F}$ can fit random noise. Key bound:

$$\mathbb{E}\left[\sup_{f \in \mathcal{F}} |R_n(f) - R(f)|\right] \leq 2\mathfrak{R}_n(\mathcal{F})$$

13.5 Information-Theoretic Arguments

  • Fano's inequality: lower bounds on error via mutual information
  • Data processing inequality: $I(X;Z) \leq I(X;Y)$ when $X \to Y \to Z$ (checked numerically below)
  • Minimax lower bounds: use KL-divergence to bound the hardness of estimation
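A minimal numerical check of the data processing inequality on a toy binary Markov chain (the source distribution and the two channels below are illustrative assumptions):

import math

def mutual_information(joint):
    """I(A;B) in nats from a joint distribution {(a, b): prob}."""
    pa, pb = {}, {}
    for (a, b), p in joint.items():
        pa[a] = pa.get(a, 0.0) + p
        pb[b] = pb.get(b, 0.0) + p
    return sum(p * math.log(p / (pa[a] * pb[b]))
               for (a, b), p in joint.items() if p > 0)

px   = {0: 0.4, 1: 0.6}                                # source X
py_x = {0: {0: 0.9, 1: 0.1}, 1: {0: 0.2, 1: 0.8}}      # channel X → Y
pz_y = {0: {0: 0.7, 1: 0.3}, 1: {0: 0.25, 1: 0.75}}    # channel Y → Z

joint_xy = {(x, y): px[x] * py_x[x][y] for x in px for y in (0, 1)}
joint_xz = {(x, z): sum(px[x] * py_x[x][y] * pz_y[y][z] for y in (0, 1))
            for x in px for z in (0, 1)}               # marginalise out Y

I_xy = mutual_information(joint_xy)
I_xz = mutual_information(joint_xz)
print(f"I(X;Y) = {I_xy:.4f} nats, I(X;Z) = {I_xz:.4f} nats")
print(f"DPI holds (I(X;Z) ≤ I(X;Y)): {I_xz <= I_xy + 1e-12}")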

13.6 Interchange of Limits

Many ML proofs require justifying:

  • $\frac{d}{d\theta}\int f(x,\theta)\,dx = \int \frac{\partial f}{\partial\theta}\,dx$ (Leibniz rule; see the sanity check below)
  • $\lim_{n\to\infty} \mathbb{E}[X_n] = \mathbb{E}[\lim X_n]$ (dominated convergence)
  • $\nabla_\theta \mathbb{E}_{p_\theta}[f] = \mathbb{E}_{p_\theta}[f \nabla_\theta \log p_\theta]$ (REINFORCE / score function)
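As a sanity check of the Leibniz rule, the sketch below compares both sides numerically for ∫₀¹ e^{θx} dx, where each side equals (θe^θ - e^θ + 1)/θ² in closed form (the integrand is an illustrative choice):

import math

def midpoint_integral(g, a, b, n=100000):
    """Midpoint Riemann sum of g over [a, b]."""
    h = (b - a) / n
    return sum(g(a + (i + 0.5) * h) for i in range(n)) * h

theta, eps = 1.5, 1e-5
# LHS: d/dθ ∫₀¹ e^{θx} dx via central finite difference of the integral
lhs = (midpoint_integral(lambda x: math.exp((theta + eps) * x), 0, 1)
       - midpoint_integral(lambda x: math.exp((theta - eps) * x), 0, 1)) / (2 * eps)
# RHS: ∫₀¹ ∂/∂θ e^{θx} dx = ∫₀¹ x e^{θx} dx
rhs = midpoint_integral(lambda x: x * math.exp(theta * x), 0, 1)
exact = (theta * math.exp(theta) - math.exp(theta) + 1) / theta**2
print(f"d/dθ ∫ f dx ≈ {lhs:.6f}, ∫ ∂f/∂θ dx ≈ {rhs:.6f}, exact = {exact:.6f}")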

13.7 Reduction Proofs

Prove that problem $A$ is hard by reducing a known-hard problem $B$ to $A$:

$$B \text{ is hard} \;\wedge\; (B \leq_p A) \implies A \text{ is hard}$$

Used in: computational hardness of learning, NP-hardness of network training
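A toy instance of the pattern, reducing PARTITION to SUBSET-SUM (a sketch with brute-force solvers; this classic pairing is an illustrative choice and not a hardness proof by itself):

from itertools import combinations

def subset_sum(nums, target):
    """Brute-force SUBSET-SUM decision: does some subset of nums sum to target?"""
    return any(sum(c) == target
               for r in range(len(nums) + 1)
               for c in combinations(nums, r))

def partition_via_reduction(nums):
    """PARTITION ≤ SUBSET-SUM: a set splits into two equal-sum halves
    iff its total is even and some subset hits total/2."""
    total = sum(nums)
    return total % 2 == 0 and subset_sum(nums, total // 2)

for instance in [[1, 5, 11, 5], [1, 2, 3, 5]]:
    print(f"{instance} partitionable? {partition_via_reduction(instance)}")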

Code cell 37

# ══════════════════════════════════════════════════════════════════
# SECTION 13: PROOF PATTERNS IN ML THEORY
# ══════════════════════════════════════════════════════════════════

import math
import random

random.seed(42)

# ── 13.1 Union Bound ────────────────────────────────────────────
print("UNION BOUND IN ACTION")
print("="*55)

def union_bound_demo():
    """
    Demonstrate: if Pr(bad_i) ≤ p for each of n events,
    then Pr(any bad) ≤ n·p by union bound.
    Verify empirically that actual failure rate ≤ union bound.
    """
    print(f"\n  Scenario: n classifiers, each fails with prob p independently")
    print(f"  Union bound: Pr(any fails) ≤ n·p")
    print(f"  Exact: Pr(any fails) = 1 - (1-p)^n")
    print(f"\n  {'n':>5s} {'p':>8s} {'Union np':>10s} {'Exact':>10s} {'Simulated':>10s}")
    print(f"  {'-'*47}")

    n_trials = 50000
    for n, p in [(5, 0.01), (10, 0.01), (50, 0.01), (100, 0.005), (10, 0.1)]:
        union = min(n * p, 1.0)
        exact = 1 - (1 - p)**n

        # Simulate
        failures = 0
        for _ in range(n_trials):
            any_fail = any(random.random() < p for _ in range(n))
            if any_fail:
                failures += 1
        simulated = failures / n_trials

        check = "✓" if simulated <= union + 0.02 else "~"
        print(f"  {n:5d} {p:8.3f} {union:10.4f} {exact:10.4f} {simulated:10.4f} {check}")

    print(f"\n  ✓ Union bound always ≥ exact probability (loose but universal)")

union_bound_demo()

# ── 13.2 Concentration inequalities ─────────────────────────────
print(f"\n{'='*55}")
print(f"CONCENTRATION INEQUALITIES: PROGRESSIVELY TIGHTER")
print(f"{'='*55}")

def concentration_comparison():
    """
    Compare Markov, Chebyshev, Hoeffding bounds for sample mean.
    X_i ~ Bernoulli(p), n samples, bound Pr(|mean - p| ≥ t).
    """
    p_true = 0.5
    n = 100
    print(f"\n  X_i ~ Bernoulli({p_true}), n = {n}")
    print(f"  μ = {p_true}, σ² = p(1-p) = {p_true*(1-p_true)}")
    print(f"  Bounding Pr(|X̄ - μ| ≥ t)")
    print(f"\n  {'t':>6s} {'Markov':>10s} {'Chebyshev':>10s} {'Hoeffding':>10s} {'Simulated':>10s}")
    print(f"  {'-'*50}")

    n_trials = 100000
    for t in [0.05, 0.1, 0.15, 0.2, 0.3]:
        # Markov: Pr(|X̄-μ| ≥ t) ... Markov applies to non-negative
        # Use Pr(X̄ ≥ μ+t) ≤ E[X̄]/(μ+t) = μ/(μ+t) -- one-sided only
        # For two-sided, use Chebyshev instead; Markov is one-sided
        markov_one_sided = min(p_true / (p_true + t), 1.0)

        # Chebyshev: Pr(|X̄-μ| ≥ t) ≤ Var(X̄)/t² = p(1-p)/(n·t²)
        var_mean = p_true * (1 - p_true) / n
        chebyshev = min(var_mean / t**2, 1.0)

        # Hoeffding: Pr(|X̄-μ| ≥ t) ≤ 2·exp(-2n·t²)  [since X_i ∈ [0,1]]
        hoeffding = min(2 * math.exp(-2 * n * t**2), 1.0)

        # Simulate
        violations = 0
        for _ in range(n_trials):
            sample_mean = sum(1 for _ in range(n) if random.random() < p_true) / n
            if abs(sample_mean - p_true) >= t:
                violations += 1
        simulated = violations / n_trials

        print(f"  {t:6.2f} {markov_one_sided:10.4f} {chebyshev:10.4f}"
              f" {hoeffding:10.6f} {simulated:10.4f}")

    print(f"\n  ✓ Hoeffding ≤ Chebyshev ≤ Markov (progressively tighter)")
    print(f"  ✓ All are valid upper bounds on simulated probability")

concentration_comparison()

# ── 13.3 PAC learning: finite hypothesis class ──────────────────
print(f"\n{'='*55}")
print(f"PAC LEARNING: SAMPLE COMPLEXITY")
print(f"{'='*55}")

def pac_sample_complexity():
    """
    For finite hypothesis class |H|:
    n ≥ (1/ε)(ln|H| + ln(1/δ)) guarantees PAC learning.
    """
    print(f"\n  Sample complexity: n ≥ (1/ε)(ln|H| + ln(1/δ))")
    print(f"\n  {'|H|':>10s} {'ε':>6s} {'δ':>6s} {'n needed':>10s}")
    print(f"  {'-'*36}")

    configs = [
        (10,       0.1, 0.05),
        (100,      0.1, 0.05),
        (1000,     0.1, 0.05),
        (1000000,  0.1, 0.05),
        (1000000,  0.01, 0.01),
        (10**9,    0.01, 0.01),
    ]

    for H_size, eps, delta in configs:
        n_needed = math.ceil((1/eps) * (math.log(H_size) + math.log(1/delta)))
        print(f"  {H_size:10d} {eps:6.2f} {delta:6.2f} {n_needed:10d}")

    # Demonstrate: ERM on finite H with enough samples
    print(f"\n  Simulation: learn threshold classifier on [0,1]")
    print(f"  H = {{h_t: predict 1 if x > t}} with 100 discrete thresholds")

    H_size = 100
    thresholds = [i / H_size for i in range(H_size)]
    true_t = 0.37  # true threshold

    for eps in [0.1, 0.05, 0.02]:
        delta = 0.05
        n = int((1/eps) * (math.log(H_size) + math.log(1/delta))) + 1

        # Generate training data
        n_success = 0
        n_reps = 500
        for _ in range(n_reps):
            xs = [random.random() for _ in range(n)]
            ys = [1 if x > true_t else 0 for x in xs]

            # ERM: find best threshold
            best_t = 0
            best_err = n + 1
            for t in thresholds:
                err = sum(1 for x, y in zip(xs, ys) if (1 if x > t else 0) != y)
                if err < best_err:
                    best_err = err
                    best_t = t

            # True error of best_t
            # Error = |best_t - true_t| (for uniform distribution on [0,1])
            true_err = abs(best_t - true_t)
            if true_err <= eps:
                n_success += 1

        pct = 100 * n_success / n_reps
        check = "✓" if pct >= (1 - delta)*100 - 2 else "~"
        print(f"    ε={eps:.2f}, n={n:4d}: success rate {pct:.1f}%"
              f" (need ≥ {(1-delta)*100:.0f}%) {check}")

pac_sample_complexity()

# ── 13.4 Rademacher complexity ───────────────────────────────────
print(f"\n{'='*55}")
print(f"RADEMACHER COMPLEXITY")
print(f"{'='*55}")

def empirical_rademacher(n_points, d, n_rademacher=500):
    """
    Estimate empirical Rademacher complexity of linear classifiers
    in R^d on n random points.
    R_n(F) = E_σ[sup_{||w||≤1} (1/n)|Σ σ_i <w, x_i>|]
           = E_σ[(1/n) ||Σ σ_i x_i||]  (for unit-ball linear class)
    """
    print(f"\n  Linear classifiers in R^{d}, n={n_points} points")

    # Generate random data points
    X = [[random.gauss(0, 1) for _ in range(d)] for _ in range(n_points)]

    total = 0.0
    for _ in range(n_rademacher):
        # Random Rademacher variables σ_i ∈ {-1, +1}
        sigma = [random.choice([-1, 1]) for _ in range(n_points)]

        # Compute Σ σ_i x_i (a d-dimensional vector)
        sum_vec = [0.0] * d
        for i in range(n_points):
            for j in range(d):
                sum_vec[j] += sigma[i] * X[i][j]

        # sup over ||w|| ≤ 1 of <w, sum_vec> = ||sum_vec||
        norm = math.sqrt(sum(v**2 for v in sum_vec))
        total += norm / n_points

    rademacher = total / n_rademacher

    # Theoretical bound for linear class: R_n ~ sqrt(d/n)
    theoretical = math.sqrt(d / n_points)

    print(f"  Empirical Rademacher: {rademacher:.4f}")
    print(f"  Theoretical O(√(d/n)): {theoretical:.4f}")
    print(f"  Ratio: {rademacher/theoretical:.3f}")
    return rademacher

print(f"\n  Rademacher vs dimension and sample size:")
for d in [5, 20, 100]:
    for n in [50, 200]:
        empirical_rademacher(n, d)

# ── 13.6 Score function / REINFORCE ──────────────────────────────
print(f"\n{'='*55}")
print(f"SCORE FUNCTION ESTIMATOR (REINFORCE)")
print(f"{'='*55}")

def reinforce_demo():
    """
    Verify the identity: ∇_θ E_{p_θ}[f(x)] = E_{p_θ}[f(x) ∇_θ log p_θ(x)]
    For Gaussian p_θ = N(θ, 1), f(x) = x².
    ∇_θ E[X²] where X ~ N(θ,1): E[X²] = θ²+1, so ∇_θ = 2θ.
    Score function: ∇_θ log p_θ(x) = x - θ.
    REINFORCE estimate: (1/N) Σ f(x_i)(x_i - θ).
    """
    theta = 3.0
    n_samples = 50000

    # Analytical gradient
    # E[X²] = θ² + 1 → d/dθ = 2θ
    analytical_grad = 2 * theta

    # REINFORCE estimate: E[X² · (X - θ)] where X ~ N(θ, 1)
    samples = [random.gauss(theta, 1.0) for _ in range(n_samples)]
    reinforce_est = sum(x**2 * (x - theta) for x in samples) / n_samples

    # Finite difference check
    h = 0.0001
    samples_plus = [random.gauss(theta + h, 1.0) for _ in range(n_samples)]
    samples_minus = [random.gauss(theta - h, 1.0) for _ in range(n_samples)]
    fd_grad = (sum(x**2 for x in samples_plus)/n_samples
               - sum(x**2 for x in samples_minus)/n_samples) / (2*h)

    print(f"\n  ∇_θ E_{{N(θ,1)}}[X²] at θ = {theta}")
    print(f"  Analytical:    2θ = {analytical_grad:.4f}")
    print(f"  REINFORCE:     {reinforce_est:.4f}")
    print(f"  Finite diff:   {fd_grad:.4f}")

    err_reinforce = abs(reinforce_est - analytical_grad)
    err_fd = abs(fd_grad - analytical_grad)
    print(f"\n  Error (REINFORCE): {err_reinforce:.4f}")
    print(f"  Error (finite diff): {err_fd:.4f}")
    print(f"  ✓ Score function identity verified numerically")
    print(f"\n  This identity underlies: REINFORCE, policy gradients,")
    print(f"  variational inference (ELBO gradient), wake-sleep algorithm")

reinforce_demo()

14. Common Proof Mistakes & Logical Fallacies

Catalogue of Errors

| # | Mistake | Example | Why It Fails |
|---|---------|---------|--------------|
| 1 | Affirming the consequent | "If rain → wet. Wet → rain" | $P \Rightarrow Q$ does not mean $Q \Rightarrow P$ |
| 2 | Circular reasoning | Assuming what you're proving | No new information gained |
| 3 | Wrong direction | Proving $Q \Rightarrow P$ instead of $P \Rightarrow Q$ | Implication is not symmetric |
| 4 | Division by zero | "Let $a=b$, then $a^2-b^2=a(a-b)$, divide by $(a-b)$" | $a-b=0$, cannot divide |
| 5 | Induction: wrong base | Proving the step but not the base case | Chain has no starting link |
| 6 | Induction: wrong step | "All horses same colour": the $n=1$ to $n=2$ step fails | Overlap assumption breaks |
| 7 | Proof by example | "Works for $n=1,2,3$, therefore $\forall n$" | Examples $\neq$ proof |
| 8 | Assuming existence | "Let $x$ be the largest prime..." | No largest prime exists |
| 9 | Confusing $\forall$/$\exists$ order | "$\exists \delta\, \forall \varepsilon$" vs "$\forall \varepsilon\, \exists \delta$" | Quantifier order matters |
| 10 | Ignoring edge cases | Forgetting $n=0$, the empty set, $x=0$ | Proof must cover ALL cases |
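Before the detector below, a four-row truth table makes mistake #1 concrete: the premises P→Q and Q can both hold while P is false, so the inference is invalid.

from itertools import product

print("P Q | P→Q | (P→Q)∧Q | P")
for P, Q in product([0, 1], repeat=2):
    implies = (not P) or Q               # material implication P→Q
    premise = bool(implies) and bool(Q)  # do both premises hold?
    print(f"{P} {Q} |  {int(bool(implies))}  |    {int(premise)}    | {P}")
# Row P=0, Q=1: premises true, conclusion P false → the inference is invalid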

Code cell 39

# ══════════════════════════════════════════════════════════════════
# SECTION 14: COMMON PROOF MISTAKES & FALLACIES
# ══════════════════════════════════════════════════════════════════

# ── Interactive mistake detector ─────────────────────────────────
print("PROOF MISTAKE DETECTOR")
print("="*55)

COMMON_MISTAKES = {
    "affirming_consequent": {
        "name": "Affirming the Consequent",
        "bad":  "If P→Q and Q is true, conclude P is true",
        "fix":  "P→Q and Q gives NO information about P. Need Q→P (converse) separately.",
        "example": "If training → low loss. Low loss → must be training? NO (could be memorising)"
    },
    "circular_reasoning": {
        "name": "Circular Reasoning (Begging the Question)",
        "bad":  "Assume A to prove B, then use B to justify A",
        "fix":  "Ensure each step follows from PREVIOUSLY established facts",
        "example": "Assume SGD converges, use convergence to prove step size OK, use step size to prove convergence"
    },
    "wrong_direction": {
        "name": "Proving Wrong Direction",
        "bad":  "Asked to prove P→Q, accidentally prove Q→P",
        "fix":  "Clearly state hypothesis and conclusion BEFORE starting",
        "example": "Need: 'differentiable → continuous'. Wrong: 'continuous → differentiable' (false!)"
    },
    "division_by_zero": {
        "name": "Division by Zero / Invalid Operation",
        "bad":  "Divide both sides by expression that could be zero",
        "fix":  "Always verify denominator ≠ 0 before dividing",
        "example": "a=b → a²=ab → a²-b²=ab-b² → (a-b)(a+b)=(a-b)b → a+b=b → 2b=b → 2=1 ???"
    },
    "proof_by_example": {
        "name": "Proof by Example (Incomplete Induction)",
        "bad":  "Check n=1,2,...,10 and conclude ∀n",
        "fix":  "Examples can DISPROVE (counterexample) but never PROVE universal statements",
        "example": "f(n) = n²+n+41 is prime for n=0..39, but f(40) = 40²+40+41 = 41² is not prime"
    },
    "quantifier_swap": {
        "name": "Swapping Quantifier Order",
        "bad":  "∀x∃y P(x,y) is NOT the same as ∃y∀x P(x,y)",
        "fix":  "∀∃ = 'for each x, some y works (possibly different)'. ∃∀ = 'one y works for ALL x'",
        "example": "∀ε∃δ (continuity) vs ∃δ∀ε (nonsensical: one δ for all ε?)"
    }
}

for key, mistake in COMMON_MISTAKES.items():
    print(f"\n  ❌ {mistake['name']}")
    print(f"     Bad:  {mistake['bad']}")
    print(f"     Fix:  {mistake['fix']}")
    print(f"     AI:   {mistake['example']}")

# ── Division by zero "proof" that 1 = 2 ─────────────────────────
print(f"\n{'='*55}")
print(f"CLASSIC FALLACY: 'PROOF' THAT 1 = 2")
print(f"{'='*55}")

def trace_division_by_zero_fallacy():
    """Walk through the classic 1=2 'proof' and identify the error."""
    a, b = 5, 5  # a = b

    print(f"\n  Let a = b = {a}")
    step = 1

    val_lhs = a**2
    val_rhs = a * b
    print(f"  Step {step}: a² = ab               → {val_lhs} = {val_rhs} ✓")
    step += 1

    val_lhs = a**2 - b**2
    val_rhs = a*b - b**2
    print(f"  Step {step}: a²-b² = ab-b²          → {val_lhs} = {val_rhs} ✓")
    step += 1

    print(f"  Step {step}: (a-b)(a+b) = (a-b)·b    → {(a-b)*(a+b)} = {(a-b)*b} ✓")
    step += 1

    print(f"  Step {step}: Divide both sides by (a-b) = {a-b}")
    print(f"         ⚠️  THIS IS THE ERROR: a-b = {a}-{b} = 0!")
    print(f"         Cannot divide by zero!")
    step += 1

    print(f"\n  If we (illegally) proceed:")
    print(f"  Step {step}: a+b = b  → {a}+{b} = {b}{a+b} = {b}  →  2b = b  →  2 = 1 ???")
    print(f"\n  ✓ Error identified: division by (a-b) = 0 in Step 4")

trace_division_by_zero_fallacy()

# ── Proof by example failure ─────────────────────────────────────
print(f"\n{'='*55}")
print(f"PROOF BY EXAMPLE FAILURE: n²+n+41")
print(f"{'='*55}")

def is_prime(n):
    if n < 2:
        return False
    if n < 4:
        return True
    if n % 2 == 0 or n % 3 == 0:
        return False
    i = 5
    while i * i <= n:
        if n % i == 0 or n % (i + 2) == 0:
            return False
        i += 6
    return True

print(f"\n  f(n) = n² + n + 41")
print(f"  Euler noticed f(0)..f(39) are ALL prime:")
print(f"\n  {'n':>4s} {'f(n)':>8s} {'Prime?':>8s}")
print(f"  {'-'*24}")

all_prime_so_far = True
for n in range(45):
    fn = n*n + n + 41
    prime = is_prime(fn)
    if not prime:
        all_prime_so_far = False
    marker = "✓" if prime else "✗ ← COMPOSITE!"
    if n <= 5 or n >= 38:
        print(f"  {n:4d} {fn:8d} {marker:>20s}")
    elif n == 6:
        print(f"  {'...':>4s}")

print(f"\n  f(40) = 40² + 40 + 41 = 1681 = 41²")
print(f"  ✓ 40 examples of primality do NOT constitute a proof!")
print(f"  Lesson: always need rigorous argument, not just examples")

# ── Quantifier order matters ─────────────────────────────────────
print(f"\n{'='*55}")
print(f"QUANTIFIER ORDER: ∀∃ vs ∃∀")
print(f"{'='*55}")

def quantifier_demo():
    """Show that ∀x∃y P(x,y) and ∃y∀x P(x,y) are very different."""
    # P(x,y) := "y > x"
    domain = list(range(1, 6))

    print(f"\n  P(x,y) = 'y > x', domain = {domain}")

    # ∀x ∃y: y>x — for each x, find some y > x
    print(f"\n  ∀x ∃y (y > x): for each x, find a y that beats it")
    forall_exists = True
    for x in domain:
        witnesses = [y for y in domain if y > x]
        if witnesses:
            print(f"    x={x}: y={witnesses[0]} works (y>x ✓)")
        else:
            print(f"    x={x}: NO y found! ✗")
            forall_exists = False

    # ∃y ∀x: y>x — find ONE y that beats ALL x
    print(f"\n  ∃y ∀x (y > x): find one y that beats ALL x")
    exists_forall = False
    for y in domain:
        all_beaten = all(y > x for x in domain)
        status = "✓ works for all!" if all_beaten else "✗ fails"
        if y >= max(domain) - 1:
            print(f"    y={y}: {status}")
        if all_beaten:
            exists_forall = True
            break

    if not exists_forall:
        print(f"    No single y > all x (max x = {max(domain)}, max y = {max(domain)})")

    print(f"\n  ∀x∃y(y>x): {forall_exists}")
    print(f"  ∃y∀x(y>x): {exists_forall}")
    print(f"  ✓ Same predicate, different truth values — quantifier order matters!")

quantifier_demo()

15. Practice Exercises

Exercise 1 — Direct Proof

Prove: For all integers $n$, if $n$ is odd, then $n^2$ is odd.

Hint: Write $n = 2k + 1$ and expand.

Exercise 2 — Proof by Contrapositive

Prove: If $n^2$ is even, then $n$ is even.

Hint: The contrapositive is "if $n$ is odd, then $n^2$ is odd" (Exercise 1!).

Exercise 3 — Proof by Contradiction

Prove: $\log_2 3$ is irrational.

Hint: Assume $\log_2 3 = p/q$; then $2^{p/q} = 3$, so $2^p = 3^q$. Why is this impossible?

Exercise 4 — Proof by Induction

Prove: $\sum_{i=1}^n i^2 = \frac{n(n+1)(2n+1)}{6}$ for all $n \geq 1$.

Exercise 5 — Proof by Cases

Prove: $\max(a, b) = \frac{a + b + |a - b|}{2}$ for all real $a, b$.

Exercise 6 — Counting Argument

Show that in any group of 6 people, there are either 3 mutual friends or 3 mutual strangers.

Hint: Pigeonhole on the edges from one vertex of $K_6$.

Exercise 7 — Construction

Construct a continuous function $f: [0,1] \to \mathbb{R}$ that is nowhere differentiable.

Hint: Use a Weierstrass-type series $f(x) = \sum_{n=0}^{\infty} a^n \cos(b^n \pi x)$ with $0 < a < 1$ and $ab$ large enough; truncate at finite $N$ for numerical experiments.

Exercise 8 — ε-δ

Prove: $\lim_{x \to 2} (3x + 1) = 7$ using the ε-δ definition.

Code cell 41

# ══════════════════════════════════════════════════════════════════
# SECTION 15: PRACTICE EXERCISES — SOLUTIONS & VERIFICATION
# ══════════════════════════════════════════════════════════════════

import math
import random

# ── Exercise 1: Direct proof — odd n → odd n² ───────────────────
print("EXERCISE 1: odd n → odd n²  (Direct Proof)")
print("="*55)

def exercise1_odd_squared():
    """
    Proof: Let n be odd. Then n = 2k+1 for some integer k.
    n² = (2k+1)² = 4k²+4k+1 = 2(2k²+2k)+1, which is odd. □
    """
    print(f"\n  Proof: n = 2k+1 → n² = 4k²+4k+1 = 2(2k²+2k)+1")
    print(f"\n  Verification:")
    all_pass = True
    for n in range(-15, 16):
        if n % 2 == 1:  # n is odd
            n_sq = n * n
            k = (n - 1) // 2
            reconstructed = 2 * (2*k*k + 2*k) + 1
            ok = (n_sq % 2 == 1) and (n_sq == reconstructed)
            if not ok:
                all_pass = False
            if abs(n) <= 7:
                print(f"    n={n:3d}, k={k:2d}: n²={n_sq:4d} = 2·{2*k*k+2*k}+1 ✓")
    print(f"  {'✓' if all_pass else '✗'} All odd n in [-15,15] verified")

exercise1_odd_squared()

# ── Exercise 2: Contrapositive — n² even → n even ───────────────
print(f"\n{'='*55}")
print(f"EXERCISE 2: n² even → n even  (Contrapositive)")
print(f"{'='*55}")

def exercise2_contrapositive():
    """
    Prove by contrapositive: n odd → n² odd (which is Exercise 1!).
    Contrapositive of "n² even → n even" is "n odd → n² odd".
    """
    print(f"\n  Contrapositive: 'n odd → n² odd'")
    print(f"  This is EXACTLY Exercise 1!")
    print(f"  Since Exercise 1 is proved, the contrapositive holds. □")
    print(f"\n  Double-check (forward direction):")
    for n in range(20):
        n_sq = n * n
        if n_sq % 2 == 0:
            check = "✓" if n % 2 == 0 else "✗"
            if n <= 10:
                print(f"    n²={n_sq:4d} is even → n={n} is {'even' if n%2==0 else 'odd'} {check}")

exercise2_contrapositive()

# ── Exercise 3: Contradiction — log₂(3) irrational ──────────────
print(f"\n{'='*55}")
print(f"EXERCISE 3: log₂(3) is irrational  (Contradiction)")
print(f"{'='*55}")

def exercise3_log2_3():
    """
    Proof by contradiction:
    Assume log₂(3) = p/q with p,q ∈ Z, q>0.
    Then 2^(p/q) = 3, so 2^p = 3^q.
    But 2^p is even and 3^q is odd. Contradiction! □
    """
    print(f"\n  Assume log₂(3) = p/q (rational)")
    print(f"  Then 2^p = 3^q")
    print(f"  But 2^p is always even (divisible by 2)")
    print(f"  And 3^q is always odd (product of odd numbers)")
    print(f"  Even ≠ Odd → Contradiction! □")

    print(f"\n  Numerical evidence (rational approximations):")
    log2_3 = math.log2(3)
    print(f"  log₂(3) = {log2_3:.15f}")

    # Best rational approximations via continued fraction
    # Convergents of continued fraction for log₂(3) ≈ [1; 1, 1, 2, 2, 3, 1, ...]
    convergents = [(1,1), (2,1), (3,2), (8,5), (19,12), (65,41), (84,53)]
    for p, q in convergents:
        approx = p / q
        error = abs(approx - log2_3)
        # Check 2^p vs 3^q
        two_p = 2**p
        three_q = 3**q
        print(f"    p/q = {p}/{q} = {approx:.8f}, |error| = {error:.2e},"
              f" 2^{p}={two_p}, 3^{q}={three_q}, equal? {'✗ never' if two_p != three_q else '??'}")

exercise3_log2_3()

# ── Exercise 4: Induction — sum of squares ───────────────────────
print(f"\n{'='*55}")
print(f"EXERCISE 4: Σi² = n(n+1)(2n+1)/6  (Induction)")
print(f"{'='*55}")

def exercise4_sum_of_squares():
    """
    Induction proof:
    Base: n=1: 1² = 1 = 1·2·3/6 ✓
    Step: Assume Σ_{i=1}^k i² = k(k+1)(2k+1)/6.
          Σ_{i=1}^{k+1} i² = k(k+1)(2k+1)/6 + (k+1)²
                            = (k+1)[k(2k+1)/6 + (k+1)]
                            = (k+1)[2k²+k+6k+6]/6
                            = (k+1)(2k²+7k+6)/6
                            = (k+1)(k+2)(2k+3)/6
                            = (k+1)((k+1)+1)(2(k+1)+1)/6 ✓
    """
    print(f"\n  Base case: n=1: 1 = 1·2·3/6 = 1 ✓")
    print(f"\n  Inductive step algebra:")
    print(f"  k(k+1)(2k+1)/6 + (k+1)²")
    print(f"  = (k+1)[k(2k+1) + 6(k+1)]/6")
    print(f"  = (k+1)(2k²+7k+6)/6")
    print(f"  = (k+1)(k+2)(2k+3)/6  ✓")

    print(f"\n  Verification:")
    print(f"  {'n':>4s} {'Σi²':>8s} {'n(n+1)(2n+1)/6':>16s} {'Match':>6s}")
    print(f"  {'-'*38}")

    for n in range(1, 16):
        actual_sum = sum(i*i for i in range(1, n+1))
        formula = n * (n + 1) * (2*n + 1) // 6
        check = "✓" if actual_sum == formula else "✗"
        print(f"  {n:4d} {actual_sum:8d} {formula:16d} {check:>6s}")

exercise4_sum_of_squares()

# ── Exercise 5: Cases — max formula ──────────────────────────────
print(f"\n{'='*55}")
print(f"EXERCISE 5: max(a,b) = (a+b+|a-b|)/2  (Cases)")
print(f"{'='*55}")

def exercise5_max_formula():
    """
    Case 1: a ≥ b. Then |a-b| = a-b. RHS = (a+b+a-b)/2 = 2a/2 = a = max(a,b) ✓
    Case 2: a < b. Then |a-b| = b-a. RHS = (a+b+b-a)/2 = 2b/2 = b = max(a,b) ✓
    """
    print(f"\n  Case 1 (a≥b): |a-b|=a-b → (a+b+a-b)/2 = a = max ✓")
    print(f"  Case 2 (a<b): |a-b|=b-a → (a+b+b-a)/2 = b = max ✓")

    print(f"\n  Verification with random pairs:")
    all_ok = True
    for _ in range(12):
        a = round(random.uniform(-10, 10), 2)
        b = round(random.uniform(-10, 10), 2)
        formula_val = (a + b + abs(a - b)) / 2
        max_val = max(a, b)
        ok = abs(formula_val - max_val) < 1e-10
        if not ok:
            all_ok = False
        print(f"    a={a:7.2f}, b={b:7.2f}: max={max_val:7.2f},"
              f" formula={formula_val:7.2f} {'✓' if ok else '✗'}")

    print(f"  {'✓' if all_ok else '✗'} Formula verified for all test cases")

exercise5_max_formula()

# ── Exercise 6: Ramsey R(3,3) ≤ 6 ───────────────────────────────
print(f"\n{'='*55}")
print(f"EXERCISE 6: R(3,3) ≤ 6  (Pigeonhole + Cases)")
print(f"{'='*55}")

def exercise6_ramsey():
    """
    Among 6 people, ∃ 3 mutual friends or 3 mutual strangers.
    Fix person v. By pigeonhole (5 others, 2 colours), v has ≥ 3 of same type.
    WLOG v has ≥ 3 friends {a,b,c}.
    If any pair in {a,b,c} are friends → triangle of friends with v.
    If no pair are friends → a,b,c are mutual strangers. □
    """
    print(f"\n  Proof sketch:")
    print(f"  1. Fix vertex v in K₆. It has 5 edges to others.")
    print(f"  2. Pigeonhole: ≥ ⌈5/2⌉ = 3 edges same colour (WLOG 'friend').")
    print(f"  3. Let v be friends with a, b, c.")
    print(f"  4. If any pair (a,b), (a,c), or (b,c) are friends → friend-triangle ✓")
    print(f"  5. If none are friends → a,b,c are mutual strangers → stranger-triangle ✓")
    print(f"  6. Exhaustive cases → always find monochromatic K₃. □")

    # Brute force verification: try all 2-colourings of K₆
    from itertools import combinations
    n = 6
    edges = list(combinations(range(n), 2))
    num_edges = len(edges)  # C(6,2) = 15
    print(f"\n  Exhaustive check: all 2^{num_edges} = {2**num_edges} colourings of K₆")

    triples = list(combinations(range(n), 3))
    counterexample_found = False

    for colouring_int in range(2**num_edges):
        # Decode colouring: bit i → colour of edge i
        colours = {}
        for i, (u, v) in enumerate(edges):
            colours[(u,v)] = (colouring_int >> i) & 1
            colours[(v,u)] = colours[(u,v)]

        # Check if some triple is monochromatic
        has_mono = False
        for triple in triples:
            a, b, c = triple
            c_ab = colours[(a,b)]
            c_ac = colours[(a,c)]
            c_bc = colours[(b,c)]
            if c_ab == c_ac == c_bc:
                has_mono = True
                break

        if not has_mono:
            counterexample_found = True
            print(f"  ✗ Counterexample found! (this should NEVER happen)")
            break

    if not counterexample_found:
        print(f"  ✓ All {2**num_edges} colourings checked: EVERY one has a mono K₃")
        print(f"  ✓ R(3,3) ≤ 6 verified by exhaustive search")

exercise6_ramsey()
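
# ── Exercise 7: Construction (Weierstrass-type function) ────────
# Added solution sketch: the parameters a=0.5, b=13, truncation N=12
# and the sample grid are illustrative assumptions.
print(f"\n{'='*55}")
print(f"EXERCISE 7: Weierstrass-type construction")
print(f"{'='*55}")

def exercise7_weierstrass():
    """
    f(x) = Σ_{n=0}^{N} aⁿ cos(bⁿπx), a truncated Weierstrass series.
    For the infinite series (0 < a < 1, ab large enough), f is continuous
    but nowhere differentiable. Numerically, the steepest difference
    quotient max_x |f(x+h)-f(x)|/h keeps growing as h shrinks; the
    truncated sum only smooths out below the scale b^(-N).
    """
    a, b, N = 0.5, 13, 12
    def f(x):
        return sum(a**n * math.cos(b**n * math.pi * x) for n in range(N + 1))

    print(f"\n  f(x) = Σ aⁿcos(bⁿπx) with a={a}, b={b}, N={N}")
    print(f"  {'h':>10s} {'max |f(x+h)-f(x)|/h':>22s}")
    xs = [i / 200 for i in range(200)]      # sample grid on [0, 1)
    for k in range(1, 9):
        h = 10.0 ** (-k)
        steepest = max(abs(f(x + h) - f(x)) / h for x in xs)
        print(f"  {h:10.0e} {steepest:22.2f}")
    print(f"  ✓ Difference quotients blow up as h → 0: no derivative (as N → ∞)")

exercise7_weierstrass()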

# ── Exercise 8: ε-δ for 3x+1 → 7 as x → 2 ─────────────────────
print(f"\n{'='*55}")
print(f"EXERCISE 8: lim(x→2) 3x+1 = 7  (ε-δ)")
print(f"{'='*55}")

def exercise8_epsilon_delta():
    """
    Need: |3x+1 - 7| < ε when |x-2| < δ.
    |3x+1-7| = |3x-6| = 3|x-2| < 3δ.
    Choose δ = ε/3. Then 3δ = ε. ✓
    """
    print(f"\n  |f(x) - L| = |3x+1 - 7| = 3|x-2|")
    print(f"  Choose δ = ε/3:")
    print(f"  |x-2| < ε/3 → 3|x-2| < ε ✓")

    print(f"\n  {'ε':>10s} {'δ = ε/3':>10s} {'max|f(x)-7|':>14s} {'< ε?':>6s}")
    print(f"  {'-'*44}")

    for eps in [1.0, 0.1, 0.01, 0.001, 0.0001]:
        delta = eps / 3
        # Maximum of |3x+1 - 7| for |x-2| < δ
        max_diff = 0
        for i in range(10000):
            t = delta * (2 * i / 9999 - 1) * 0.999
            x = 2 + t
            diff = abs(3*x + 1 - 7)
            if diff > max_diff:
                max_diff = diff
        check = "✓" if max_diff < eps else "✗"
        print(f"  {eps:10.4f} {delta:10.6f} {max_diff:14.8f} {check:>6s}")

    print(f"\n  ✓ δ = ε/3 works for all ε > 0")

exercise8_epsilon_delta()

16. Why Proof Techniques Matter for AI/ML (2026 Perspective)

The Bridge from Theory to Practice

| Proof Technique | Where It Appears in Modern AI |
|---|---|
| Direct proof | Verifying gradient correctness, showing convexity of loss functions |
| Construction | Designing approximation architectures (universal approximation), building adversarial examples |
| Contrapositive | "If not robust → not Lipschitz": diagnosing model failures |
| Contradiction | Proving impossibility results (no-free-lunch theorems), irrationality in numerical analysis |
| Proof by cases | Analysing piecewise-linear networks (ReLU), handling edge cases in algorithms |
| Induction | Layer-wise analysis of deep networks, recurrence relations in RNNs, autoregressive generation |
| Strong induction | Proving properties of recursive architectures (tree-LSTMs), dynamic programming correctness |
| Structural induction | BPE tokenisation correctness, AST-based code generation, recursive neural networks |
| Probabilistic method | Compressed sensing (RIP), random initialisation guarantees, JL embeddings |
| Counting arguments | VC dimension, Rademacher complexity, PAC bounds, model capacity |
| ε-δ / analytic | Convergence of optimisers (SGD, Adam), stability bounds, generalisation gaps |
| Concentration inequalities | High-probability generalisation bounds, sample complexity, confidence intervals |

Key Takeaway

You cannot build reliable AI systems without understanding why they work. Proofs are not academic exercises — they are the foundation of every guarantee we make about convergence, generalisation, safety, and correctness.


Notebook: 06-Proof-Techniques | Part of: Mathematics for AI/ML

Code cell 43

# ══════════════════════════════════════════════════════════════════
# SECTION 16: PROOF TECHNIQUE → AI APPLICATION SUMMARY
# ══════════════════════════════════════════════════════════════════

print("PROOF TECHNIQUES FOR AI/ML: COMPLETE REFERENCE")
print("="*65)

TECHNIQUE_MAP = [
    ("Direct Proof",       "Gradient derivations, loss convexity, forward-pass correctness"),
    ("Construction",       "Universal approximation, weight initialisation, network design"),
    ("Contrapositive",     "Robustness ↔ Lipschitz, failure diagnosis, impossibility"),
    ("Contradiction",      "No-free-lunch, lower bounds, irrationality in numerics"),
    ("Proof by Cases",     "ReLU analysis, piecewise-linear nets, activation regions"),
    ("Induction",          "Layer-wise bounds, recurrence in RNNs, sequence generation"),
    ("Strong Induction",   "Recursive architectures, tree models, DP correctness"),
    ("Structural Induction","Tokenisation, AST code gen, recursive data structures"),
    ("Probabilistic Method","Random init, compressed sensing, JL embeddings"),
    ("Counting / Combinatorial","VC dim, Rademacher, PAC bounds, model capacity"),
    ("Epsilon-Delta",      "SGD convergence, stability, fixed-point iteration"),
    ("Concentration Ineq.","Generalisation bounds, sample complexity, confidence"),
    ("Union Bound",        "Multi-hypothesis testing, uniform convergence"),
    ("REINFORCE / Score",  "Policy gradient, variational inference, ELBO gradient"),
]

print(f"\n  {'Technique':<25s}{'AI/ML Application'}")
print(f"  {'─'*25}{'─'*45}")
for technique, application in TECHNIQUE_MAP:
    print(f"  {technique:<25s}{application}")

print(f"\n{'='*65}")
print(f"  Total sections covered: 16")
print(f"  Proof techniques explored: {len(TECHNIQUE_MAP)}")
print(f"  Key message: Rigorous proof → Reliable AI")
print(f"{'='*65}")
print(f"\n  ✓ Notebook complete: 06-Proof-Techniques theory.ipynb")