from dataclasses import dataclass
from typing import Optional, Tuple
from copy import deepcopy

import torch
import torch.nn as nn
from transformers import AutoModelForVision2Seq, AutoTokenizer
from transformers.utils import ModelOutput


def use_default(value, default):
    """Utility: return value if not None, else default."""
    return value if value is not None else default


# Prompt templates for different models and tasks
PROMPT_TEMPLATE_ENCODE = (
    "<|start_header_id|>system<|end_header_id|>\n\nDescribe the image by detailing the color, shape, size, texture, "
    "quantity, text, spatial relationships of the objects and background:<|eot_id|>"
    "<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>"
)

PROMPT_TEMPLATE_ENCODE_V2 = (
    "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, "
    "quantity, text, spatial relationships of the objects and background:<|im_end|>\n"
    "<|im_start|>user\n{}<|im_end|>"
)

NEGATIVE_PROMPT = (
    "Aerial view, aerial view, overexposed, low quality, deformation, a poor composition, "
    "bad hands, bad teeth, bad eyes, bad limbs, distortion"
)

PROMPT_TEMPLATE = {
    "dit-llm-encode": {
        "template": PROMPT_TEMPLATE_ENCODE,
        "crop_start": 36,
    },
    "dit-llm-encode-v2": {
        "template": PROMPT_TEMPLATE_ENCODE_V2,
        "crop_start": 34,
    },
}
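

# The `crop_start` values above are the lengths, in tokens, of each template's
# instruction prefix; `TextEncoder.encode` drops that many leading hidden states
# so only the user-prompt tokens survive. A minimal sketch of how such an offset
# could be estimated, assuming a Hugging Face tokenizer (this helper is
# illustrative and not part of the original module; BOS handling and token
# merges at the placeholder boundary can shift the exact count):
def _estimate_crop_start(tokenizer, template: str) -> int:
    """Count the tokens before the `{}` placeholder (approximate)."""
    prefix = template.split("{}")[0]
    return len(tokenizer(prefix)["input_ids"])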
def load_text_encoder(
    text_encoder_type,
    text_encoder_precision=None,
    text_encoder_path=None,
    infer_mode="encoder",
    logger=None,
    device=None,
):
    """
    Load a text encoder model from pretrained weights.

    Args:
        text_encoder_type (str): Type of text encoder.
        text_encoder_precision (str, optional): Precision for model weights.
        text_encoder_path (str, optional): Path to pretrained weights.
        infer_mode (str): "encoder" or "decoder".
        logger (logging.Logger, optional): Logger for info.
        device (torch.device, optional): Device to move the model to.

    Returns:
        model (nn.Module): Loaded text encoder.
        model_path (str): Path to the model.
    """
    if logger is not None:
        logger.info(f"Loading text encoder model ({text_encoder_type}) from: {text_encoder_path}")
    if text_encoder_type == "llm":
        text_encoder = AutoModelForVision2Seq.from_pretrained(
            text_encoder_path,
            torch_dtype="auto",
        )
    else:
        raise ValueError(f"Unsupported text encoder type: {text_encoder_type}")
    text_encoder.requires_grad_(False)
    if logger is not None:
        logger.info(f"Text encoder to dtype: {text_encoder.dtype}")
    if device is not None:
        text_encoder = text_encoder.to(device)
    return text_encoder, text_encoder_path


def load_tokenizer(
    tokenizer_type,
    tokenizer_path=None,
    padding_side="right",
    logger=None,
):
    """
    Load a tokenizer from pretrained weights.

    Args:
        tokenizer_type (str): Type of tokenizer.
        tokenizer_path (str, optional): Path to the pretrained tokenizer.
        padding_side (str): Padding side for the tokenizer.
        logger (logging.Logger, optional): Logger for info.

    Returns:
        tokenizer: Loaded tokenizer.
        tokenizer_path (str): Path to the tokenizer.
    """
    if logger is not None:
        logger.info(f"Loading tokenizer ({tokenizer_type}) from: {tokenizer_path}")
    if tokenizer_type == "llm":
        tokenizer = AutoTokenizer.from_pretrained(
            tokenizer_path, use_fast=False, padding_side=padding_side, trust_remote_code=True
        )
    else:
        raise ValueError(f"Unsupported tokenizer type: {tokenizer_type}")
    return tokenizer, tokenizer_path
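

# A minimal usage sketch, assuming a local Vision2Seq checkpoint (the path is a
# placeholder, not a dependency of this module):
def _example_load(checkpoint="/path/to/vision2seq-checkpoint"):
    model, _ = load_text_encoder(
        text_encoder_type="llm",
        text_encoder_path=checkpoint,
        device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
    )
    tokenizer, _ = load_tokenizer(
        tokenizer_type="llm",
        tokenizer_path=checkpoint,
        padding_side="right",  # encoder mode; decoder mode uses "left"
    )
    return model, tokenizer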
@dataclass
class TextEncoderModelOutput(ModelOutput):
    """
    Output for text encoder models.

    Args:
        hidden_state (torch.FloatTensor): Output hidden states of the last layer.
        attention_mask (torch.LongTensor, optional): Attention mask for valid tokens.
        hidden_states_list (tuple(torch.FloatTensor), optional): All hidden states if requested.
        text_outputs (list, optional): Decoded texts if requested.
    """

    hidden_state: torch.FloatTensor = None
    attention_mask: Optional[torch.LongTensor] = None
    hidden_states_list: Optional[Tuple[torch.FloatTensor, ...]] = None
    text_outputs: Optional[list] = None
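
# Note: `ModelOutput` subclasses support both attribute and key access
# (`out.hidden_state` or `out["hidden_state"]`), and fields left as None are
# omitted from the mapping view.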
class TextEncoder(nn.Module):
    """
    TextEncoder wraps a pretrained text encoder and tokenizer for flexible text encoding.

    Args:
        text_encoder_type (str): Type of text encoder.
        max_length (int): Maximum sequence length.
        text_encoder_precision (str, optional): Precision for model weights.
        text_encoder_path (str, optional): Path to pretrained weights.
        tokenizer_type (str, optional): Type of tokenizer.
        tokenizer_path (str, optional): Path to the pretrained tokenizer.
        output_key (str, optional): Key used to pull the hidden state from the model output.
        use_attention_mask (bool): Whether to use an attention mask.
        infer_mode (str): "encoder" or "decoder".
        input_max_length (int, optional): Maximum input length.
        prompt_template (dict, optional): Prompt template for image prompts.
        prompt_template_video (dict, optional): Prompt template for video prompts.
        hidden_state_skip_layer (int, optional): Number of layers to skip from the end when selecting the hidden state.
        apply_final_norm (bool): Whether to apply the final layer norm.
        reproduce (bool): If True, produce deterministic output.
        logger (logging.Logger, optional): Logger for info.
        device (torch.device, optional): Device to move the model to.
    """

    def __init__(
        self,
        text_encoder_type: str,
        max_length: int,
        text_encoder_precision: Optional[str] = None,
        text_encoder_path: Optional[str] = None,
        tokenizer_type: Optional[str] = None,
        tokenizer_path: Optional[str] = None,
        output_key: Optional[str] = None,
        use_attention_mask: bool = True,
        infer_mode: str = "encoder",
        input_max_length: Optional[int] = None,
        prompt_template: Optional[dict] = None,
        prompt_template_video: Optional[dict] = None,
        hidden_state_skip_layer: Optional[int] = None,
        apply_final_norm: bool = False,
        reproduce: bool = False,
        logger=None,
        device=None,
    ):
        super().__init__()
        self.text_encoder_type = text_encoder_type
        self.max_length = max_length
        self.precision = text_encoder_precision
        self.model_path = text_encoder_path
        self.tokenizer_type = tokenizer_type if tokenizer_type is not None else text_encoder_type
        self.tokenizer_path = tokenizer_path if tokenizer_path is not None else text_encoder_path
        self.use_attention_mask = use_attention_mask
        self.input_max_length = input_max_length if input_max_length is not None else max_length
        self.prompt_template = dict(prompt_template) if prompt_template is not None else None
        self.prompt_template_video = dict(prompt_template_video) if prompt_template_video is not None else None
        self.hidden_state_skip_layer = hidden_state_skip_layer
        self.apply_final_norm = apply_final_norm
        self.infer_mode = infer_mode
        self.reproduce = reproduce
        self.logger = logger

        self.use_template = self.prompt_template is not None
        if self.use_template:
            assert isinstance(self.prompt_template, dict) and "template" in self.prompt_template, (
                f"`prompt_template` must be a dictionary with a key 'template', got {self.prompt_template}"
            )
            if self.prompt_template_video is not None:
                assert isinstance(self.prompt_template_video, dict) and "template" in self.prompt_template_video, (
                    f"`prompt_template_video` must be a dictionary with a key 'template', "
                    f"got {self.prompt_template_video}"
                )
            assert "{}" in str(self.prompt_template["template"]), (
                "`prompt_template['template']` must contain a placeholder `{}` for the input text, "
                f"got {self.prompt_template['template']}"
            )

        if infer_mode == "decoder":
            assert text_encoder_type in ["llm"], (
                f"Unsupported text encoder type for infer_mode='decoder': {text_encoder_type}"
            )
            assert self.prompt_template is not None and hidden_state_skip_layer is not None, (
                f"`prompt_template` and `hidden_state_skip_layer` must be provided for infer_mode='decoder', "
                f"got prompt_template={self.prompt_template}, hidden_state_skip_layer={self.hidden_state_skip_layer}"
            )

        if "t5" in text_encoder_type:
            self.output_key = output_key or "last_hidden_state"
        elif "clip" in text_encoder_type:
            self.output_key = output_key or "pooler_output"
        elif any(x in text_encoder_type for x in ["llm"]):
            self.output_key = output_key or ("last_hidden_state" if infer_mode == "encoder" else None)
        else:
            raise ValueError(f"Unsupported text encoder type: {text_encoder_type}")

        self.model, self.model_path = load_text_encoder(
            text_encoder_type=self.text_encoder_type,
            text_encoder_precision=self.precision,
            text_encoder_path=self.model_path,
            infer_mode=self.infer_mode,
            logger=self.logger,
            device=device,
        )
        self.dtype = self.model.dtype
        self.device = self.model.device

        # Decoder-mode generation left-pads so every prompt ends flush with the first generated token.
        padding_side = "right" if self.infer_mode == "encoder" else "left"
        self.tokenizer, self.tokenizer_path = load_tokenizer(
            tokenizer_type=self.tokenizer_type,
            tokenizer_path=self.tokenizer_path,
            padding_side=padding_side,
            logger=self.logger,
        )

    def __repr__(self):
        return f"{self.text_encoder_type} ({self.precision} - {self.model_path})"
    def apply_text_to_template(self, text, template, prevent_empty_text=True):
        """
        Apply text to a prompt template.

        Args:
            text (str): Input text.
            template (str or list): Template string or list of chat messages.
            prevent_empty_text (bool): If True, replace empty user text with a single space.

        Returns:
            str or list: Text with the template applied.
        """
        if isinstance(template, str):
            return template.format(text)
        elif isinstance(template, list):
            conversation = deepcopy(template)
            for message in conversation:
                if "{}" in message.get("content", ""):
                    filled_text = message["content"].format(text)
                    if prevent_empty_text and len(filled_text) == 0:
                        filled_text = " "
                    message["content"] = filled_text
                    break  # Only one placeholder per conversation
            return conversation
        else:
            raise TypeError(f"Unsupported template type: {type(template)}")
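
    # Behavior sketch: a string template substitutes the placeholder directly,
    # while a chat-style list fills only the first message containing `{}`, e.g.
    #   self.apply_text_to_template("a cat", "user: {}")
    #       -> "user: a cat"
    #   self.apply_text_to_template("", [{"role": "user", "content": "{}"}])
    #       -> [{"role": "user", "content": " "}]   (empty text padded to a space)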
    def text2tokens(self, text, data_type="image"):
        """
        Tokenize the input text, optionally applying a prompt template.

        Args:
            text (str or list): Input text.
            data_type (str): 'image' or 'video'.

        Returns:
            dict: Tokenized input.
        """
        tokenize_input_type = "str"
        if self.use_template:
            if data_type == "image":
                prompt_template = self.prompt_template["template"]
            elif data_type == "video":
                prompt_template = self.prompt_template_video["template"]
            else:
                raise ValueError(f"Unsupported data type: {data_type}")
            if isinstance(text, (list, tuple)):
                text = [self.apply_text_to_template(one_text, prompt_template) for one_text in text]
                if isinstance(text[0], list):
                    tokenize_input_type = "list"
            elif isinstance(text, str):
                text = self.apply_text_to_template(text, prompt_template)
                if isinstance(text, list):
                    tokenize_input_type = "list"
            else:
                raise TypeError(f"Unsupported text type: {type(text)}")

        kwargs = dict(truncation=True, max_length=self.max_length, padding="max_length", return_tensors="pt")
        if tokenize_input_type == "str":
            return self.tokenizer(
                text,
                return_length=False,
                return_overflowing_tokens=False,
                return_attention_mask=True,
                **kwargs,
            )
        elif tokenize_input_type == "list":
            return self.tokenizer.apply_chat_template(
                text,
                add_generation_prompt=True,
                tokenize=True,
                return_dict=True,
                **kwargs,
            )
        else:
            raise ValueError(f"Unsupported tokenize_input_type: {tokenize_input_type}")
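
    # Shape note: with the settings above (padding="max_length", truncation=True,
    # return_tensors="pt"), the returned `input_ids` and `attention_mask` are both
    # (batch, max_length) LongTensors, so downstream code can assume a fixed width.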
    def encode(
        self,
        batch_encoding,
        use_attention_mask=None,
        output_hidden_states=False,
        do_sample=None,
        hidden_state_skip_layer=None,
        return_texts=False,
        data_type="image",
        device=None,
    ):
        """
        Encode tokenized input using the text encoder.

        Args:
            batch_encoding (dict): Batch encoding from the tokenizer.
            use_attention_mask (bool, optional): Whether to use an attention mask.
            output_hidden_states (bool): Whether to output all hidden states.
            do_sample (bool, optional): Whether to sample from the model (for decoder-only LLMs).
            hidden_state_skip_layer (int, optional): Number of layers to skip from last for hidden state.
            return_texts (bool): Whether to return decoded texts.
            data_type (str): 'image' or 'video'.
            device (torch.device, optional): Device to use.

        Returns:
            TextEncoderModelOutput: Encoded output.
        """
        use_attention_mask = use_default(use_attention_mask, self.use_attention_mask)
        hidden_state_skip_layer = use_default(hidden_state_skip_layer, self.hidden_state_skip_layer)
        do_sample = use_default(do_sample, not self.reproduce)

        if self.infer_mode == "encoder":
            attention_mask = batch_encoding["attention_mask"].to(self.model.device) if use_attention_mask else None
            if "Gemma2" in self.text_encoder_type:
                input_ids = batch_encoding["input_ids"].to(self.model.device)
                _, inputs_embeds, labels, attention_mask = self.model.merge_multimodal(
                    text_input_ids=input_ids,
                    text_attention_masks=attention_mask,
                    text_labels=None,
                    pixel_values=[None],
                )
                outputs = self.model.llm(inputs_embeds=inputs_embeds, labels=labels, attention_mask=attention_mask)
            else:
                outputs = self.model(
                    input_ids=batch_encoding["input_ids"].to(self.model.device),
                    attention_mask=attention_mask,
                    output_hidden_states=output_hidden_states or hidden_state_skip_layer is not None,
                )
            if hidden_state_skip_layer is not None:
                last_hidden_state = outputs.hidden_states[-(hidden_state_skip_layer + 1)]
                # Apply final norm for intermediate layers if requested
                if hidden_state_skip_layer > 0 and self.apply_final_norm:
                    last_hidden_state = self.model.final_layer_norm(last_hidden_state)
            else:
                last_hidden_state = outputs[self.output_key]

            # Remove hidden states of instruction tokens, only keep prompt tokens.
            if self.use_template:
                if data_type == "image":
                    crop_start = self.prompt_template.get("crop_start", -1)
                elif data_type == "video":
                    crop_start = self.prompt_template_video.get("crop_start", -1)
                else:
                    raise ValueError(f"Unsupported data type: {data_type}")
                if crop_start > 0:
                    last_hidden_state = last_hidden_state[:, crop_start:]
                    attention_mask = attention_mask[:, crop_start:] if use_attention_mask else None

            if output_hidden_states:
                return TextEncoderModelOutput(last_hidden_state, attention_mask, outputs.hidden_states)
            return TextEncoderModelOutput(last_hidden_state, attention_mask)
        elif self.infer_mode == "decoder":
            # Remove leading padding tokens shared by the whole batch.
            input_max_valid_tokens = batch_encoding["attention_mask"].sum(dim=1).max().item()
            if input_max_valid_tokens < batch_encoding["attention_mask"].shape[1]:
                batch_encoding = {
                    "input_ids": batch_encoding["input_ids"][:, -input_max_valid_tokens:],
                    "attention_mask": batch_encoding["attention_mask"][:, -input_max_valid_tokens:],
                }
            # Generate text from the model.
            outputs = self.model.generate(
                input_ids=batch_encoding["input_ids"].to(self.model.device),
                attention_mask=batch_encoding["attention_mask"].to(self.model.device) if use_attention_mask else None,
                max_new_tokens=self.max_length,
                do_sample=do_sample,
                return_dict_in_generate=True,
                output_hidden_states=True,
                stop_strings="<|eot_id|>",
                tokenizer=self.tokenizer,
                pad_token_id=self.tokenizer.eos_token_id,
            )
            # Concatenate hidden states from all generated tokens.
            hidden_states = torch.cat(
                [
                    per_token_hidden_states[-(hidden_state_skip_layer + 1)]
                    for per_token_hidden_states in outputs.hidden_states[1:]
                ],
                dim=1,
            )
            if self.apply_final_norm:
                hidden_states = self.model.final_layer_norm(hidden_states)
            # Make sequence mask from output sequences.
            output_max_valid_tokens = hidden_states.shape[1]
            attention_mask = (
                outputs.sequences[:, -output_max_valid_tokens - 1 : -1] != self.tokenizer.eos_token_id
            ).long()
            if return_texts:
                text_outputs = self.tokenizer.batch_decode(outputs.sequences, skip_special_tokens=False)
                return TextEncoderModelOutput(hidden_states, attention_mask, None, text_outputs)
            else:
                return TextEncoderModelOutput(hidden_states, attention_mask)
        else:
            raise ValueError(f"Unsupported text encoder infer mode: {self.infer_mode}")
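
    # Indexing note for `hidden_state_skip_layer`: `outputs.hidden_states` runs
    # from the embedding layer to the final layer, so skip_layer=0 selects the
    # last layer (hidden_states[-1]) and skip_layer=2 selects hidden_states[-3],
    # two layers before the end.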
    def forward(
        self,
        text,
        use_attention_mask=None,
        output_hidden_states=False,
        do_sample=False,
        hidden_state_skip_layer=None,
        return_texts=False,
    ):
        """
        Forward pass: encode text to hidden states.

        Args:
            text (str or list): Input text.
            use_attention_mask (bool, optional): Whether to use an attention mask.
            output_hidden_states (bool): Whether to output all hidden states.
            do_sample (bool): Whether to sample from the model (for decoder-only LLMs).
            hidden_state_skip_layer (int, optional): Number of layers to skip from last for hidden state.
            return_texts (bool): Whether to return decoded texts.

        Returns:
            TextEncoderModelOutput: Encoded output.
        """
        batch_encoding = self.text2tokens(text)
        return self.encode(
            batch_encoding,
            use_attention_mask=use_attention_mask,
            output_hidden_states=output_hidden_states,
            do_sample=do_sample,
            hidden_state_skip_layer=hidden_state_skip_layer,
            return_texts=return_texts,
        )
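

# End-to-end sketch, guarded so importing this module has no side effects. The
# checkpoint path and skip-layer value are placeholders; substitute a real
# Vision2Seq checkpoint before running.
if __name__ == "__main__":
    encoder = TextEncoder(
        text_encoder_type="llm",
        max_length=256,
        text_encoder_path="/path/to/vision2seq-checkpoint",  # hypothetical path
        prompt_template=PROMPT_TEMPLATE["dit-llm-encode"],
        hidden_state_skip_layer=2,
    )
    out = encoder("a red cube on a wooden table")
    print(out.hidden_state.shape, out.attention_mask.shape)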