To improve scalability and make it easier to integrate new models, we introduce the VLM Module class. As shown in the figure above, the GRPO Trainer handles only abstract operations, such as "placing the question into the chat template" and "converting the image and prompt into input_ids". The concrete implementation is delegated to the VLM Module; the GRPO Trainer is responsible solely for calling the function interfaces that the VLM Module exposes.
To add a new model, you need to implement the following functions in the VLM Module:
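Conceptually, the VLM Module can be thought of as an abstract base class that the trainer programs against. The sketch below is illustrative only: the method names (`get_vlm_key`, `get_model_class`, `post_model_init`, `is_embeds_input`) and signatures are assumptions, not necessarily the project's actual API.

```python
from abc import ABC, abstractmethod


class VLMBaseModule(ABC):
    """Illustrative sketch of the VLM Module interface; the actual
    method names and signatures in the project may differ."""

    @abstractmethod
    def get_vlm_key(self) -> str:
        """Identifier of the model, e.g. "qwen" or "internvl"."""

    @abstractmethod
    def get_model_class(self, model_id: str):
        """Model class the GRPO Trainer uses for initialization."""

    def post_model_init(self, model, processing_class):
        """Optional hook called after model/processor initialization."""
        return model

    def is_embeds_input(self) -> bool:
        """Whether generate expects inputs_embeds instead of input_ids."""
        return False


class QwenModule(VLMBaseModule):
    """Minimal concrete example, for illustration only."""

    def get_vlm_key(self) -> str:
        return "qwen"

    def get_model_class(self, model_id: str):
        return object  # a real module would return the transformers class
```

The trainer never needs to know which concrete subclass it holds; it only calls the interface methods.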
Return the identifier of the model, such as "internvl" or "qwen".
Return the model class that the GRPO Trainer uses to initialize the model. For "qwen", the model class is Qwen2_5_VLForConditionalGeneration or Qwen2VLForConditionalGeneration; for "internvl", it is InternVLChatModel.
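A hedged sketch of how the selection might look. Strings stand in for the actual transformers classes so the example is self-contained; the name-matching rules are assumptions about how checkpoints are named.

```python
def get_model_class(model_id: str) -> str:
    """Illustrative: map a checkpoint name to the model class.
    A real implementation would return the class objects imported
    from transformers rather than their names as strings."""
    if "Qwen2.5-VL" in model_id:
        return "Qwen2_5_VLForConditionalGeneration"
    if "Qwen2-VL" in model_id:
        return "Qwen2VLForConditionalGeneration"
    if "InternVL" in model_id:
        return "InternVLChatModel"
    raise ValueError(f"Unsupported model: {model_id}")
```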
This function is called after the model and processor are initialized, so you can perform any post-initialization setup here. Taking "internvl" as an example, we record the conv_template and num_image_token for later use, and set the img_context_token_id on the model.
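A minimal sketch of such a hook for an InternVL-style model. The dummy classes and the token id used here are stand-ins so the example runs on its own; in a real module the id would come from the tokenizer vocabulary, and conv_template / num_image_token would be read from the model config.

```python
class DummyInternVLTokenizer:
    # Hypothetical token-id table; the real id comes from the tokenizer.
    _ids = {"<IMG_CONTEXT>": 92546}

    def convert_tokens_to_ids(self, token):
        return self._ids[token]


class DummyInternVLModel:
    img_context_token_id = None


def post_model_init(model, tokenizer):
    # Tell the model which token id marks image-context positions.
    # A real implementation would also record conv_template and
    # num_image_token here for later use.
    model.img_context_token_id = tokenizer.convert_tokens_to_ids("<IMG_CONTEXT>")
    return model
```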
Return whether the model's generate method accepts inputs_embeds rather than input_ids as its input.
Return the processing class of the model. For most models, this is AutoProcessor.
Return the name keywords of the model's vision modules. The GRPO Trainer uses these keywords to freeze the vision modules.
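The freezing step might look like the sketch below. `FakeParam` stands in for torch.nn.Parameter so the example runs without torch, and the keyword "visual" is an assumed example of a vision-module name fragment.

```python
class FakeParam:
    """Stand-in for torch.nn.Parameter, so the sketch runs without torch."""

    def __init__(self):
        self.requires_grad = True


def freeze_vision_modules(named_parameters, vision_keywords):
    """Freeze every parameter whose name contains one of the
    vision-module keywords returned by the VLM Module."""
    frozen = []
    for name, param in named_parameters:
        if any(keyword in name for keyword in vision_keywords):
            param.requires_grad = False
            frozen.append(name)
    return frozen
```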
Besides input_ids and attention_mask, each VLM accepts its own custom multimodal inputs in its forward method, such as pixel_values and image_grid_thw for "qwen", and pixel_values and image_flags for "internvl".
Some of the custom multimodal inputs are not used by the generate method, such as image_flags for "internvl"; return these in the get_non_generate_params function.
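The two functions above can be combined to filter the inputs passed to generate. This is a sketch under the assumption that the per-model keyword lists look as described; helper names other than get_non_generate_params are hypothetical.

```python
def get_custom_multimodal_keywords(vlm_key):
    # Illustrative per-model multimodal input names.
    return {
        "qwen": ["pixel_values", "image_grid_thw"],
        "internvl": ["pixel_values", "image_flags"],
    }[vlm_key]


def get_non_generate_params(vlm_key):
    # Inputs consumed by forward() but not accepted by generate().
    return {"qwen": [], "internvl": ["image_flags"]}[vlm_key]


def build_generate_kwargs(model_inputs, vlm_key):
    """Drop the keys that the model's generate method does not accept."""
    skip = set(get_non_generate_params(vlm_key))
    return {k: v for k, v in model_inputs.items() if k not in skip}
```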
Some models have model-specific parameters for the processing_class, such as max_pixels and min_pixels for "qwen", and max_anyres_num for "internvl"; return these in the get_custom_processing_keywords function.
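These keywords might then be used to keep only the processor kwargs a given model understands. The keyword lists mirror the examples above; the filtering helper is a hypothetical sketch, not the project's actual code.

```python
def get_custom_processing_keywords(vlm_key):
    # Illustrative: processor-specific kwargs per model.
    return {
        "qwen": ["max_pixels", "min_pixels"],
        "internvl": ["max_anyres_num"],
    }[vlm_key]


def filter_processing_kwargs(kwargs, vlm_key):
    """Keep only the kwargs this model's processing_class accepts."""
    allowed = set(get_custom_processing_keywords(vlm_key))
    return {k: v for k, v in kwargs.items() if k in allowed}
```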
This function places the prompt into the chat template. Because each model formats its conversation differently, implement it according to the model's own template.
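As a hedged illustration of how the formats diverge: a Qwen-style processor typically takes a structured message list, while an InternVL-style pipeline takes a flat string with an image placeholder. The exact structures below are assumptions about each model's convention, not the project's verbatim implementation.

```python
def prepare_prompt(question, vlm_key):
    """Illustrative: wrap a question in a model-specific chat format.
    A real implementation would typically delegate to the processor's
    chat-template utilities or the model's conversation helper."""
    if vlm_key == "qwen":
        return [{"role": "user",
                 "content": [{"type": "image"},
                             {"type": "text", "text": question}]}]
    if vlm_key == "internvl":
        return f"<image>\n{question}"
    raise ValueError(f"Unsupported model: {vlm_key}")
```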
This function processes the image and prompt into the format the model accepts. The return value should be a dict containing input_ids, attention_mask, and the custom multimodal inputs.
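A sketch of the returned dict's shape, assuming the per-model multimodal keys discussed above. The token ids and image features are placeholder values; a real implementation would obtain them from the processing_class.

```python
def prepare_model_inputs(prompt_ids, image_features, vlm_key):
    """Illustrative: assemble the dict the trainer passes to the model.
    prompt_ids and image_features stand in for real processor output."""
    inputs = {
        "input_ids": prompt_ids,
        "attention_mask": [1] * len(prompt_ids),
    }
    # Attach model-specific multimodal inputs (keys are assumptions).
    if vlm_key == "qwen":
        inputs["pixel_values"] = image_features
        inputs["image_grid_thw"] = [[1, 2, 2]]  # placeholder grid shape
    elif vlm_key == "internvl":
        inputs["pixel_values"] = image_features
        inputs["image_flags"] = [1]  # placeholder: one real image
    return inputs
```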
