Knowledge Distillation


Distillation was introduced by Hinton, Vinyals, and Dean in 2015, another masterpiece from Google.

The fundamental idea is that training and inference have different requirements, so it makes sense to train a large model (or ensemble) and then compress its knowledge into a smaller model for deployment. Model compression, the earlier work by Dr. Rich Caruana and colleagues, is the key prior work the paper builds on.

The implementation lets the student learn from the teacher's logits (softened into soft targets with a temperature), on top of learning from the ground-truth labels.
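In symbols, the combined objective commonly takes the following form (a sketch, not necessarily the exact weighting used in the paper; $z_s, z_t$ are the student and teacher logits, $T$ the temperature, and $\lambda$ a mixing weight):

$\mathcal{L} = (1-\lambda)\,\mathrm{CE}\big(y,\ \mathrm{softmax}(z_s)\big) + \lambda\, T^{2}\,\mathrm{KL}\big(\mathrm{softmax}(z_t/T)\ \|\ \mathrm{softmax}(z_s/T)\big)$

The $T^{2}$ factor keeps the gradients from the soft targets on the same scale as those from the hard labels.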

1 Softmax with Temperature

The softened softmax converts the logits $z_i$ into probabilities $q_i = \exp(z_i/T) \,/\, \sum_j \exp(z_j/T)$, where $T$ is the temperature: $T = 1$ recovers the ordinary softmax, and larger $T$ produces a softer distribution over classes.
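A minimal NumPy sketch of the softened softmax (the function name and example logits are illustrative, not from the post):

import numpy as np

def softmax_with_temperature(logits, T=1.0):
    # Divide the logits by T before the softmax: T > 1 softens the
    # distribution and exposes the small "dark knowledge" probabilities.
    z = np.asarray(logits, dtype=np.float64) / T
    z -= z.max()              # subtract the max for numerical stability
    e = np.exp(z)
    return e / e.sum()

print(softmax_with_temperature([3.0, 1.0, 0.2], T=1.0))  # peaked
print(softmax_with_temperature([3.0, 1.0, 0.2], T=5.0))  # much softer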

2 Important terms

  • Trajectory: a sequence of states and actions
  • Episode: the trajectory from the start state to the end state
  • Return: rewards summed (with discounting) over an entire trajectory (see the sketch at the end of this list)

  • Online vs. Offline RL. Online: data is acquired interactively. Offline: learning uses a pre-recorded dataset.

  • On- vs. Off-Policy. This is only meaningful in the context of online RL; offline RL always employs an off-policy learning scheme (the training dataset is used to learn an optimal policy irrespective of the policy that generated the data).

    Target policy: the policy our agent is aiming to learn.
    Behavior policy: the policy the agent uses to select actions as it interacts with the environment.

    On-policy: target policy == behavior policy (SARSA)
    Off-policy: target policy != behavior policy (Q-Learning)

    On-policy may choose a safer path (assuming the cliff-walking agent will at times jump over the cliff if it travels too close to the edge) and tends to collect higher rewards during training. Off-policy may choose the optimal path (the greedy approach, which assumes the agent acts optimally in the future and does NOT jump off the cliff).

Reference here

  • Model-based vs. Model-free

    Model-based: the agent tries to understand its environment and creates a model of it based on its interactions with this environment. Here, planning with the learned model takes priority over directly experiencing the consequences of the actions.

    Model-free algorithms instead seek to learn the consequences of their actions through experience, via algorithms such as Policy Gradient, Q-Learning, etc. The agent carries out an action multiple times and adjusts the policy for optimal rewards based on the observed outcomes.

    Reference here
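As referenced above, a small sketch of the (discounted) return of one trajectory, in plain Python; the discount factor and the rewards are illustrative:

def discounted_return(rewards, gamma=0.99):
    # G = r_0 + gamma * r_1 + gamma^2 * r_2 + ...
    G = 0.0
    for r in reversed(rewards):
        G = r + gamma * G
    return G

print(discounted_return([1.0, 0.0, 0.0, 5.0]))  # return of one episode's reward sequence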

3 Q-Learning

  • Bellman equation (Q-learning update): $Q(s,a) \leftarrow (1-\alpha)\,Q(s,a) + \alpha\,\big(r + \gamma \max_{a'} Q(s',a')\big)$
    A Python implementation is below.
    # Epsilon-greedy action selection over the Q table (assumes numpy as np,
    # a gym-style env, and the hyperparameters epsilon, lr, gamma defined elsewhere)
    def choose_action(state):
        if np.random.uniform(0, 1) < epsilon:
            action = env.action_space.sample()   # explore: random action
        else:
            action = np.argmax(Q[state, :])      # exploit: greedy action
        return action

    ## In the learning loop
    ## The action here is NOT always the greedy action that was used in the last round of the Q update
    action = choose_action(current_state)

    Q[current_state, action] = (1 - lr) * Q[current_state, action] + lr * (reward + gamma * np.max(Q[next_state, :]))
    

    Due to the $\epsilon$-greedy policy, we may not always take the greedy action that the last Q update assumed, so Q-learning is an off-policy algorithm. But in SARSA, action2, which is used to update Q, is also the action passed into the next loop iteration to get the new state.
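    To make the fragment above concrete, here is a minimal end-to-end sketch. It assumes the Gymnasium FrozenLake-v1 environment and illustrative hyperparameters, which are not from the post:

    import gymnasium as gym
    import numpy as np

    env = gym.make("FrozenLake-v1", is_slippery=False)
    Q = np.zeros((env.observation_space.n, env.action_space.n))
    epsilon, lr, gamma = 0.1, 0.8, 0.95

    def choose_action(state):
        if np.random.uniform(0, 1) < epsilon:
            return env.action_space.sample()
        return np.argmax(Q[state, :])

    for episode in range(2000):
        state, _ = env.reset()
        done = False
        while not done:
            action = choose_action(state)
            next_state, reward, terminated, truncated, _ = env.step(action)
            done = terminated or truncated
            # Off-policy target: max over next-state actions, regardless of the action taken next
            Q[state, action] = (1 - lr) * Q[state, action] + lr * (reward + gamma * np.max(Q[next_state, :]))
            state = next_state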

  • SARSA

    # Function to learn the Q-value (SARSA target uses the action actually taken next)
    def update(state, state2, reward, action, action2):
        predict = Q[state, action]
        target = reward + gamma * Q[state2, action2]
        Q[state, action] = Q[state, action] + alpha * (target - predict)

    ### In the learning loop
    # Getting the next state. The action1 here is EXACTLY the action2 used in the last round of the Q update
    state2, reward, done, info = env.step(action1)

    # Choosing the next action
    action2 = choose_action(state2)

    # Learning the Q-value
    update(state1, state2, reward, action1, action2)

    state1 = state2
    action1 = action2
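The only difference between the two algorithms is the target term in the update. Side by side, in standard notation (not from the post):

Q-Learning: $Q(s,a) \leftarrow Q(s,a) + \alpha\,\big(r + \gamma \max_{a'} Q(s',a') - Q(s,a)\big)$

SARSA: $Q(s,a) \leftarrow Q(s,a) + \alpha\,\big(r + \gamma\,Q(s',a') - Q(s,a)\big)$, where $a'$ is the action actually chosen by the behavior policy for the next step.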

4 Deep Q-Learning (DQL)


  • Experience replay: the process of gathering interaction data (state, action, reward, next-state transitions) into a buffer and sampling from it to train the Q-network (see the sketch at the end of this section).
  • Q-network: keep updating the Q-network by backpropagating the loss function. In vanilla Q-learning we also have two Q values: $Q(s_t, a_t)$ for the current state-action pair and $\max_{a} Q(s_{t+1}, a)$ for the best state-action pair in the next state.

  • Knowledge distillation
    using a separate network to produce a training target for another network.

    While the student/online network is updated at every training iteration, the teacher/target network is only updated every several iterations, to avoid creating a “moving target” (see the sketch at the end of this section).

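A minimal sketch of the two mechanisms above, experience replay plus a periodically synced target network, assuming PyTorch; the network sizes, hyperparameters, and the randomly generated transitions are illustrative stand-ins, not from the post:

import random
from collections import deque

import torch
import torch.nn as nn

STATE_DIM, N_ACTIONS, GAMMA = 4, 2, 0.99

q_net = nn.Sequential(nn.Linear(STATE_DIM, 32), nn.ReLU(), nn.Linear(32, N_ACTIONS))
target_net = nn.Sequential(nn.Linear(STATE_DIM, 32), nn.ReLU(), nn.Linear(32, N_ACTIONS))
target_net.load_state_dict(q_net.state_dict())   # start from the same weights
optimizer = torch.optim.Adam(q_net.parameters(), lr=1e-3)

replay_buffer = deque(maxlen=10_000)              # experience replay storage

def train_step(batch_size=32):
    batch = random.sample(replay_buffer, batch_size)
    s, a, r, s2, done = map(torch.tensor, zip(*batch))
    s, s2 = s.float(), s2.float()
    # Q(s, a) from the online ("student") network
    q_sa = q_net(s).gather(1, a.long().unsqueeze(1)).squeeze(1)
    # max_a' Q_target(s', a') from the frozen "teacher"/target network
    with torch.no_grad():
        target = r.float() + GAMMA * target_net(s2).max(dim=1).values * (1 - done.float())
    loss = nn.functional.mse_loss(q_sa, target)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

# Fill the buffer with random transitions (stand-ins for real environment interaction).
for _ in range(1000):
    s = [random.random() for _ in range(STATE_DIM)]
    s2 = [random.random() for _ in range(STATE_DIM)]
    replay_buffer.append((s, random.randrange(N_ACTIONS), random.random(), s2, float(random.random() < 0.1)))

for step in range(500):
    train_step()
    if step % 100 == 0:                           # sync the target network only every several iterations
        target_net.load_state_dict(q_net.state_dict())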
