| import math |
| from typing import List, Optional, Tuple, Union, Any |
|
|
| import numpy as np |
| import torch |
| from PIL import Image |
| from transformers import ( |
| GenerationMixin, |
| PreTrainedModel, |
| PreTrainedTokenizer |
| ) |
|
|
| try: |
| from transformers import Qwen3ForCausalLM |
| except ImportError: |
| print('Please upgrade transformers to version 4.51.0 or higher') |
|
|
| try: |
| from transformers.models.qwen2_vl.image_processing_qwen2_vl import ( |
| Qwen2VLImageProcessor, |
| ) |
    from transformers.models.qwen2_vl.modeling_qwen2_vl import (
        PatchMerger,
        Qwen2VLCausalLMOutputWithPast,
    )
| except ImportError: |
| print('Please upgrade transformers to version 4.46.3 or higher') |
|
|
| from .configuration_points_gui import POINTSGUIConfig |
|
|
| try: |
| from wepoints.models import Qwen2VisionTransformerForNavitPOINTS |
| except ImportError: |
| print('Please install WePOINTS, and refer to https://github.com/WePOINTS/WePOINTS') |
|
|
|
|
class POINTSGUIModel(PreTrainedModel, GenerationMixin):
    """Chat model for POINTS-GUI.

    Args:
        config (POINTSGUIConfig): The model config.
    """

    config_class = POINTSGUIConfig
    _no_split_modules = []
    _supports_flash_attn_2 = True
    supports_gradient_checkpointing = True
|
|
| def __init__(self, config: POINTSGUIConfig, **kwargs) -> None: |
| super().__init__(config) |
| config.llm_config._attn_implementation = "flash_attention_2" |
| config._attn_implementation_autoset = False |
| self.llm = Qwen3ForCausalLM(config.llm_config) |
| self.vision_encoder = Qwen2VisionTransformerForNavitPOINTS._from_config( |
| config.vision_config, attn_implementation="flash_attention_2" |
| ) |
        # Merges 2x2 patch groups and projects them into the LLM
        # embedding space.
        self.vision_projector = PatchMerger(
            config.llm_config.hidden_size, context_dim=1280
        ).to(torch.bfloat16)
| |
| def process_images(self, images: torch.Tensor, |
| image_grid_thws: List[list]) -> torch.Tensor: |
| """Obtain image features from the vision encoder. |
| |
| Args: |
| images (torch.Tensor): The input images. |
            image_grid_thws (List[list]): The (t, h, w) grid shape of each image.
| |
| Returns: |
| torch.Tensor: The image features. |
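
        Shape note (illustrative): for a single image with grid
        (t, h, w) = (1, 34, 52), the encoder emits 34 * 52 = 1768 patch
        features and the PatchMerger fuses each 2x2 block, yielding
        1768 / 4 = 442 LLM tokens.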
| """ |
| image_features = self.vision_encoder(images, grid_thw=image_grid_thws) |
| image_features = self.vision_projector(image_features) |
| return image_features |
| |
    def construct_prompt(self, messages: List[dict],
                         image_processor: Qwen2VLImageProcessor
                         ) -> Tuple[str, List[np.ndarray], List[list]]:
| """Construct the prompt for the chat model. |
| |
        Args:
            messages (List[dict]): The input messages.
            image_processor (Qwen2VLImageProcessor): The image processor used
                to preprocess any images referenced in the messages.
| |
| Returns: |
            Tuple[str, List[np.ndarray], List[list]]:
                The prompt, the preprocessed image patches, and the
                (t, h, w) grid shape of each image.
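
        Example message layout (illustrative; the file path is a placeholder):
            >>> messages = [{
            ...     'role': 'user',
            ...     'content': [
            ...         {'type': 'image', 'image': 'screenshot.png',
            ...          'max_pixels': 1254400},
            ...         {'type': 'text', 'text': 'Describe the screenshot.'},
            ...     ],
            ... }]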
| """ |
| images = [] |
| image_grid_thws = [] |
| reconstructed_messages = [] |
| for message in messages: |
| role = message['role'] |
| content_from_role = '' |
| for item in message['content']: |
| if item['type'] == 'text': |
| content_from_role += item['text'] |
| elif item['type'] == 'image': |
| image_path = item['image'] |
                    max_pixels = item.get('max_pixels')
| image = Image.open(image_path).convert('RGB') |
                    if max_pixels is not None:
                        # Downscale so the image area does not exceed
                        # max_pixels, preserving the aspect ratio.
                        width, height = image.size
                        cur_image_pixels = width * height
                        if cur_image_pixels > max_pixels:
                            beta = math.sqrt((height * width) / max_pixels)
                            new_width = math.floor(width / beta)
                            new_height = math.floor(height / beta)
                            image = image.resize((new_width, new_height))
| image_data = image_processor(images=image) |
| pixel_values = image_data['pixel_values'] |
| image_grid_thw = image_data['image_grid_thw'] |
| images.extend(pixel_values) |
| image_grid_thws.append(image_grid_thw) |
                    # Each 2x2 patch block is merged into one LLM token,
                    # hence the division by 4.
                    seq_len = int(image_grid_thw[0][1] * image_grid_thw[0][2] / 4)
                    content_from_role += ('<|vision_start|>'
                                          + '<|image_pad|>' * seq_len
                                          + '<|vision_end|>' + '\n')
| reconstructed_messages.append({ |
| 'role': role, |
| 'content': content_from_role |
| }) |
| prompt = self.apply_chat_template(reconstructed_messages) |
| return prompt, images, image_grid_thws |
| |
| def apply_chat_template(self, messages: List[dict]) -> str: |
| """Apply the chat template to the input messages. |
| |
| Args: |
| messages (List[dict]): The input messages. |
| |
| Returns: |
| str: The prompt. |
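
        Example (illustrative; ``\\n`` denotes a newline in the output):
            >>> model.apply_chat_template(
            ...     [{'role': 'user', 'content': 'Hi'}])
            '<|im_start|>user\\nHi<|im_end|>\\n<|im_start|>assistant\\n'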
| """ |
| role_prefix_mapping = { |
| 'user': '<|im_start|>user\n', |
| 'assistant': '<|im_start|>assistant\n', |
| 'system': '<|im_start|>system\n' |
| } |
| role = 'user' |
| prompt = '' |
| for message in messages: |
| role = message['role'] |
| content = message['content'] |
| prompt += role_prefix_mapping[role] + content + '<|im_end|>\n' |
| if role == 'user': |
| prompt += '<|im_start|>assistant\n' |
| return prompt |
|
|
| @torch.no_grad() |
| def chat(self, |
| messages: List[dict], |
| tokenizer: PreTrainedTokenizer, |
| image_processor: object, |
             generation_config: Optional[dict] = None) -> str:
| """Generate a response to the input prompt. |
| |
| Args: |
| messages (List[dict]): The input messages. |
| tokenizer (PreTrainedTokenizer): The tokenizer to use. |
| image_processor (object): The image processor to use. |
| generation_config (dict, optional): The generation config. |
                Defaults to None.

        Returns:
| str: The generated response. |
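
        Example (a minimal sketch; the checkpoint path and loading flags are
        assumptions, not part of this file):
            >>> model = AutoModelForCausalLM.from_pretrained(
            ...     'path/to/checkpoint', trust_remote_code=True,
            ...     torch_dtype=torch.bfloat16, device_map='cuda')
            >>> tokenizer = AutoTokenizer.from_pretrained('path/to/checkpoint')
            >>> image_processor = Qwen2VLImageProcessor.from_pretrained(
            ...     'path/to/checkpoint')
            >>> messages = [{'role': 'user', 'content': [
            ...     {'type': 'image', 'image': 'screenshot.png'},
            ...     {'type': 'text', 'text': 'Locate the search box.'}]}]
            >>> print(model.chat(messages, tokenizer, image_processor,
            ...                  {'max_new_tokens': 512}))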
| """ |
        if generation_config is None:
            generation_config = {}
        prompt, images, image_grid_thws = self.construct_prompt(
            messages, image_processor
        )
        images = torch.from_numpy(np.array(images)).to(
            self.vision_encoder.device, self.vision_encoder.dtype
        )
        image_grid_thws = torch.from_numpy(
            np.concatenate(image_grid_thws, axis=0)
        ).to(self.vision_encoder.device).long()
        image_features = self.process_images(images, image_grid_thws)
| model_inputs = tokenizer(prompt, return_tensors='pt') |
| input_ids = model_inputs['input_ids'].to(self.device) |
| attention_mask = model_inputs['attention_mask'].to(self.device) |
| |
        # <|im_end|> terminates each turn; stop generation there.
        eos_token_id = tokenizer.convert_tokens_to_ids('<|im_end|>')
        image_token_id = tokenizer.convert_tokens_to_ids('<|image_pad|>')
        generation_config.update({'eos_token_id': eos_token_id})
| outputs = self.generate( |
| input_ids=input_ids, |
| image_grid_thws=image_grid_thws, |
| attention_mask=attention_mask, |
| image_features=[image_features], |
| image_token_id=image_token_id, |
| **generation_config |
| ) |
| response = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0] |
| return response |
| |
    def _split_input_ids(self, input_ids, special_token):
        """Find the half-open [start, end) spans of contiguous runs of
        `special_token` in a 1-D `input_ids` tensor."""
        special_pos = input_ids == special_token
        # Indices where the mask flips between runs; +1 turns each boundary
        # into the index of the first element of the new run.
        pos = (special_pos[:-1] != special_pos[1:]).nonzero() + 1
        if pos.shape[0] % 2 != 0:
            # Odd number of flips: the sequence begins with a run of special
            # tokens, so prepend index 0 to pair the boundaries up.
            pos = torch.cat([torch.tensor([[0]]).to(pos.device), pos])
        pos = pos.reshape(-1, 2).tolist()
        return pos
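
    # Worked example for _split_input_ids (illustrative): with
    # input_ids = torch.tensor([5, 7, 7, 7, 9]) and special_token = 7, the
    # method returns [[1, 4]], the half-open span covering the run of 7s.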
|
|
| def generate(self, |
| input_ids: torch.LongTensor, |
| image_grid_thws: torch.LongTensor, |
| attention_mask: torch.LongTensor, |
| image_features: List[torch.Tensor], |
| image_token_id: int, |
| generation_config: Optional[dict] = None, |
| output_hidden_states: Optional[bool] = None, |
| **generate_kwargs) -> torch.LongTensor: |
        input_embeddings = self.llm.model.embed_tokens(input_ids)
        batch_size = input_ids.shape[0]
        assert len(image_features) == batch_size
        for i in range(batch_size):
            # Spans of <|image_pad|> placeholders in this sequence,
            # one span per image.
            pos = self._split_input_ids(input_ids[i], image_token_id)
            assert len(pos) == len(image_grid_thws)
            # Number of merged vision tokens contributed by each image
            # (h * w patches, merged 2x2, hence the division by 4).
            image_pos = [
                int(image_grid_thw[1] * image_grid_thw[2] / 4)
                for image_grid_thw in image_grid_thws
            ]
            image_pos.insert(0, 0)
            image_pos = np.cumsum(image_pos)
            # Overwrite the placeholder embeddings with the image features.
            for j, (start, end) in enumerate(pos):
                input_embeddings[i, start:end] = \
                    image_features[i][image_pos[j]:image_pos[j + 1]]
| outputs = self.llm.generate( |
| inputs_embeds=input_embeddings, |
| attention_mask=attention_mask, |
| generation_config=generation_config, |
| output_hidden_states=output_hidden_states, |
| use_cache=True, |
| **generate_kwargs |
| ) |
| return outputs |
|
|
    def get_image_features(self, pixel_values: torch.FloatTensor,
                           image_grid_thw: Optional[torch.LongTensor] = None):
        """Encode images into continuous embeddings that can be forwarded to
        the language model.

        Args:
            pixel_values (`torch.FloatTensor`):
                The flattened patch tensors corresponding to the input images.
            image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
                The temporal, height and width of the feature grid of each image.

        Returns:
            Tuple[torch.Tensor, ...]: One embedding tensor per image.
        """
        pixel_values = pixel_values.type(self.vision_encoder.dtype)
        image_embeds = self.process_images(pixel_values, image_grid_thw)
        # After the 2x2 patch merge each image yields t * h * w / merge^2
        # tokens; fall back to merge size 2 (the Qwen2-VL default) if the
        # vision config does not expose it.
        merge_size = getattr(self.config.vision_config, 'spatial_merge_size', 2)
        split_sizes = (image_grid_thw.prod(-1) // merge_size**2).tolist()
        image_embeds = torch.split(image_embeds, split_sizes)
        return image_embeds
| |
| def forward( |
| self, |
| input_ids: torch.LongTensor = None, |
| attention_mask: Optional[torch.Tensor] = None, |
| position_ids: Optional[torch.LongTensor] = None, |
| past_key_values: Optional[List[torch.FloatTensor]] = None, |
| inputs_embeds: Optional[torch.FloatTensor] = None, |
| labels: Optional[torch.LongTensor] = None, |
| use_cache: Optional[bool] = None, |
| output_attentions: Optional[bool] = None, |
| output_hidden_states: Optional[bool] = None, |
| return_dict: Optional[bool] = None, |
| pixel_values: Optional[torch.Tensor] = None, |
| pixel_values_videos: Optional[torch.FloatTensor] = None, |
| image_grid_thw: Optional[torch.LongTensor] = None, |
| video_grid_thw: Optional[torch.LongTensor] = None, |
| rope_deltas: Optional[torch.LongTensor] = None, |
| cache_position: Optional[torch.LongTensor] = None, |
        **kwargs: Any,
| ) -> Union[Tuple, Qwen2VLCausalLMOutputWithPast]: |
|
|
| output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions |
| output_hidden_states = ( |
| output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states |
| ) |
| return_dict = return_dict if return_dict is not None else self.config.use_return_dict |
| if inputs_embeds is None: |
| inputs_embeds = self.llm.get_input_embeddings()(input_ids) |
| if pixel_values is not None: |
| image_embeds = self.process_images(pixel_values, image_grid_thw) |
| n_image_tokens = (input_ids == self.config.image_token_id).sum().item() |
| n_image_features = image_embeds.shape[0] |
| if n_image_tokens != n_image_features: |
| raise ValueError( |
| f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" |
| ) |
                # Boolean mask marking every image-placeholder position,
                # broadcast across the embedding dimension.
                image_mask = (
                    (input_ids == self.config.image_token_id)
                    .unsqueeze(-1)
                    .expand_as(inputs_embeds)
                    .to(inputs_embeds.device)
                )
                image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
                # Swap the placeholder embeddings for the vision features.
                inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
| |
| |
| outputs = self.llm.forward( |
| input_ids=None, |
| inputs_embeds=inputs_embeds, |
| attention_mask=attention_mask, |
| output_hidden_states=output_hidden_states, |
| position_ids=position_ids, |
| past_key_values=past_key_values, |
| labels=labels, |
            use_cache=use_cache,
| output_attentions=output_attentions, |
| cache_position=cache_position, |
| **kwargs, |
| ) |
|
|
| return Qwen2VLCausalLMOutputWithPast( |
| loss=outputs.loss, |
| logits=outputs.logits, |
| past_key_values=outputs.past_key_values, |
| hidden_states=outputs.hidden_states, |
| attentions=outputs.attentions |
| ) |
|
|