In this post, I mostly wanted to deeply understand the recent GSPO algorithm developed by the amazing Qwen team and see if it holds up. I also wanted to write a sort of “You Could Have Invented GSPO,” in the spirit of Gwern’s transformers edition. For further reading, I’d recommend Nathan Lambert’s RLHF Book.

When we query a language model, we first define some “stopping” token(s) (usually something like <|eot_id|>) and run a forward pass on our prompt to get a distribution over the next token. We then sample tokens one at a time, feeding each one back in, until we hit this end condition.
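Concretely, the sampling loop looks something like this. This is a minimal sketch using Hugging Face transformers; the model name, prompt, and length cap are just illustrative, and a real inference stack would reuse a KV cache instead of re-running the whole prefix each step:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Illustrative model choice; any causal LM works the same way.
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")

prompt_ids = tokenizer("Write a haiku about RL.", return_tensors="pt").input_ids
generated = prompt_ids

with torch.no_grad():
    for _ in range(128):  # hard cap on completion length
        logits = model(generated).logits[:, -1, :]             # next-token distribution
        probs = torch.softmax(logits, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1)   # sample one token
        generated = torch.cat([generated, next_token], dim=-1)
        if next_token.item() == tokenizer.eos_token_id:        # hit the stopping token
            break

print(tokenizer.decode(generated[0][prompt_ids.shape[-1]:]))
```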

We can define our model as \(\pi_{\theta}\) and a prompt (pulled from some dataset) as \(x \sim \mathcal{D}\). We can also define the likelihood of sampling some completion or sequence of tokens \(y\) given a prompt \(x\) as

\[\pi_\theta(y|x) = \prod_{t=1}^{|y|}\pi_\theta(y_t|x,y_{< t})\]
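In log space this product becomes a sum of per-token log-probabilities, which is how you would actually compute it. A rough sketch, assuming a standard causal LM (the function and variable names here are mine):

```python
import torch

def sequence_logprob(model, prompt_ids, completion_ids):
    """log pi_theta(y | x) = sum_t log pi_theta(y_t | x, y_<t)."""
    input_ids = torch.cat([prompt_ids, completion_ids]).unsqueeze(0)
    logits = model(input_ids).logits[0]              # (seq_len, vocab)
    log_probs = logits.log_softmax(dim=-1)

    # The logit at position i predicts token i+1, so the completion tokens
    # are scored by the positions starting at the last prompt token.
    start = prompt_ids.shape[0] - 1
    token_logps = log_probs[start:start + completion_ids.shape[0]] \
        .gather(-1, completion_ids.unsqueeze(-1)).squeeze(-1)
    return token_logps.sum()
```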

When doing RL, it is useful to think about our model in terms of before and after we apply some training step. Let’s zoom into the lowest level and compare the likelihood of a single token \(y_t\) under the model before and after one gradient step. This gives the per-token importance ratio

\[w_t(\theta) = \frac{\pi_\theta(y_t|x,y_{< t})}{\pi_{\theta_{old}}(y_t|x,y_{< t})}\]

GRPO samples a group of \(G\) completions \(y_1, \dots, y_G\) for each prompt, scores each with a scalar reward \(r_i\), and maximizes the clipped surrogate

\[\frac{1}{G} \sum_{i=1}^G \left(\frac{1}{|y_i|}\sum_{t=1}^{|y_i|}\min\left(w_{i,t}(\theta)A_i,\ \text{clip}\left(w_{i,t}(\theta),\, 1-\varepsilon,\, 1+\varepsilon\right)A_i\right) - \beta\,\mathbb{D}_{KL}(\pi_\theta\,\|\,\pi_{ref})\right)\]

where \(w_{i,t}(\theta)\) is the ratio above computed on the \(i\)-th completion, and the advantage is the group-normalized reward

\[A_i = \frac{r_i - \text{mean}(\{r_1, r_2, \cdots, r_G\})}{\text{std}(\{r_1, r_2, \cdots, r_G\})}\]

The clip term \(\text{clip}\left(w_{i,t}(\theta),\, 1-\varepsilon,\, 1+\varepsilon\right)A_i\) keeps any single update from pushing the policy too far from \(\pi_{\theta_{old}}\).
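If we already have per-token log-probs under \(\pi_\theta\) and \(\pi_{\theta_{old}}\), plus one scalar reward per completion, the whole objective is a few lines of tensor code. A sketch (all names are mine, and I drop the KL penalty for brevity):

```python
import torch

def grpo_style_loss(logps, old_logps, rewards, mask, eps=0.2):
    """Clipped surrogate over a group of G completions for one prompt.

    logps, old_logps, mask: (G, T) per-token tensors (mask is 1 on completion tokens)
    rewards: (G,) scalar rewards, one per completion
    The KL penalty against pi_ref is omitted here for brevity.
    """
    # Group-normalized advantage A_i = (r_i - mean) / std, broadcast to every token.
    advantages = (rewards - rewards.mean()) / (rewards.std() + 1e-8)
    advantages = advantages.unsqueeze(-1)                      # (G, 1)

    # Per-token importance ratio w_{i,t}.
    ratio = torch.exp(logps - old_logps)
    clipped = torch.clamp(ratio, 1 - eps, 1 + eps)

    # PPO-style pessimistic min between the unclipped and clipped terms.
    per_token = torch.min(ratio * advantages, clipped * advantages)
    per_seq = (per_token * mask).sum(-1) / mask.sum(-1)        # length-normalize
    return -per_seq.mean()                                     # maximize -> minimize
```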

GLM Version (TODO):

\[A_i = r(x, y_i) - \bar{r}(x), \qquad \bar{r}(x) = \frac{1}{G} \sum_{j=1}^{G} r(x, y_j)\]

i.e., the same group-mean baseline, but without dividing by the standard deviation.
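In code, this advantage is a one-liner (again just a sketch, with a name of my choosing):

```python
import torch

def mean_baseline_advantage(rewards):
    """A_i = r(x, y_i) - rbar(x): center each reward on the group mean,
    without the std normalization used above."""
    return rewards - rewards.mean()
```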