From b76ea790a73ce9df0e5b94b813a31455d35845f4 Mon Sep 17 00:00:00 2001 From: janinezhao Date: Thu, 14 Dec 2023 16:14:12 +0800 Subject: [PATCH 01/18] add support for llava --- models/llava/7b_config.json | 54 ++++ preprocess.py | 2 +- pretrain.py | 2 + ...ava_from_huggingface_to_tencentpretrain.py | 77 +++++ scripts/generate_llava_json_deepspeed.py | 299 ++++++++++++++++++ tencentpretrain/embeddings/__init__.py | 5 +- .../embeddings/image_text_embedding.py | 69 ++++ tencentpretrain/models/model.py | 9 + tencentpretrain/opts.py | 4 +- tencentpretrain/trainer.py | 34 +- tencentpretrain/utils/dataloader.py | 63 ++++ tencentpretrain/utils/dataset.py | 116 +++++++ 12 files changed, 728 insertions(+), 6 deletions(-) create mode 100755 models/llava/7b_config.json create mode 100755 scripts/convert_llava_from_huggingface_to_tencentpretrain.py create mode 100755 scripts/generate_llava_json_deepspeed.py create mode 100755 tencentpretrain/embeddings/image_text_embedding.py diff --git a/models/llava/7b_config.json b/models/llava/7b_config.json new file mode 100755 index 00000000..dd44a519 --- /dev/null +++ b/models/llava/7b_config.json @@ -0,0 +1,54 @@ +{ + "embedding": ["image_text"], + "image_text_emb":{ + "vision_encoder":{ + "emb_size": 1024, + "feedforward_size": 4096, + "hidden_size": 1024, + "hidden_act": "gelu_fast", + "heads_num": 16, + "layers_num": 24, + "dropout": 0.1, + "max_seq_length": 577, + "embedding": ["patch", "pos"], + "patch_proj_bias": false, + "remove_embedding_layernorm": false, + "remove_transformer_bias": false, + "rotary_position_embedding": false, + "encoder": "transformer", + "feed_forward": "dense", + "mask": "fully_visible", + "layernorm_positioning": "pre", + "layernorm":"normal" + }, + "projection":{ + "mlp_hidden_size": 4096, + "num_mlp_layer": 2 + }, + "text":{ + "embedding": ["word"] + } + }, + "image_height": 336, + "image_width": 336, + "patch_size": 14, + "remove_embedding_combine_layernorm": true, + "emb_size": 4096, + "feedforward_size": 11008, + "hidden_size": 4096, + "hidden_act": "silu", + "heads_num": 32, + "layers_num": 32, + "dropout": 0.0, + "data_processor": "llava", + "max_seq_length": 2048, + "remove_transformer_bias": true, + "remove_embedding_layernorm": true, + "rotary_position_embedding": true, + "encoder": "transformer", + "feed_forward": "gated", + "mask": "causal", + "layernorm_positioning": "pre", + "layernorm": "rms", + "target": ["lm"] + } \ No newline at end of file diff --git a/preprocess.py b/preprocess.py index 7f61c54c..1d313a00 100644 --- a/preprocess.py +++ b/preprocess.py @@ -28,7 +28,7 @@ def main(): parser.add_argument("--data_processor", choices=["bert", "lm", "mlm", "bilm", "albert", "mt", "t5", "cls", "prefixlm", "gsg", "bart", "cls_mlm", "vit", "vilt", "clip", "s2t", "beit", "dalle", - "llm_pretrain", "llm_sft"], default="bert", + "llm_pretrain", "llm_sft", "llava"], default="bert", help="The data processor of the pretraining model.") parser.add_argument("--docs_buffer_size", type=int, default=100000, help="The buffer size of documents in memory, specific to targets that require negative sampling.") diff --git a/pretrain.py b/pretrain.py index 14c744c3..aa8bbc08 100644 --- a/pretrain.py +++ b/pretrain.py @@ -13,6 +13,8 @@ def main(): help="Path of the preprocessed dataset.") parser.add_argument("--pretrained_model_path", type=str, default=None, help="Path of the pretrained model.") + parser.add_argument("--vit_model_path", type=str, default=None, + help="Path of the Vit pretrained model.") 
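+    # Note (editorial): when --vit_model_path is given, trainer.init_model() loads the vision-tower
+    # weights from this checkpoint separately (non-strict load_state_dict), while the language-model
+    # weights still come from --pretrained_model_path. The checkpoint is assumed to already be in
+    # TencentPretrain key format, e.g. a converted CLIP ViT-L/14 336px model matching the
+    # image_height/image_width=336 and patch_size=14 settings in models/llava/7b_config.json.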
parser.add_argument("--output_model_path", type=str, required=True, help="Path of the output model.") parser.add_argument("--config_path", type=str, default="models/bert/base_config.json", diff --git a/scripts/convert_llava_from_huggingface_to_tencentpretrain.py b/scripts/convert_llava_from_huggingface_to_tencentpretrain.py new file mode 100755 index 00000000..a6c75292 --- /dev/null +++ b/scripts/convert_llava_from_huggingface_to_tencentpretrain.py @@ -0,0 +1,77 @@ +import argparse +import collections +import torch +import os +import json + + +parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) +parser.add_argument("--input_model_path", type=str, default="models/llava-v1.5-7b/", + help=".") +parser.add_argument("--output_model_path", type=str, default="models/llava-v1.5-7b.bin", + help=".") +parser.add_argument("--type", choices=["7B", "13B", "33B", "65B"], default="7B") + +args = parser.parse_args() + +model_config = {"7B" : [32, 4096, 32], + "13B": [40, 5120, 40], + "33B": [60, 6656, 52], + "65B": [80, 8192, 64] + } + +layers_num, dim, n_heads = model_config[args.type] + +files = os.listdir(args.input_model_path) +model_files = [f for f in files if f[-4:] == ".bin"] +input_models = {f: torch.load(os.path.join(args.input_model_path, f), map_location="cpu") for f in model_files} + +with open(os.path.join(args.input_model_path, "pytorch_model.bin.index.json")) as f: + model_index = json.load(f) + weight_map = model_index["weight_map"] + + +output_model = collections.OrderedDict() + +def get_weight_from_name(layer_name): + return input_models[weight_map[layer_name]][layer_name] + +def unpermute(w): + return w.reshape(n_heads, 2, dim // n_heads // 2, dim).transpose(2, 1).reshape(dim, dim) + +output_model["embedding.image_text.text_embedding.word.embedding.weight"] = get_weight_from_name("model.embed_tokens.weight") + +for i in range(layers_num): + + output_model["encoder.transformer." + str(i) + ".self_attn.linear_layers.0.weight"] = \ + unpermute(get_weight_from_name("model.layers." + str(i) + ".self_attn.q_proj.weight")) + output_model["encoder.transformer." + str(i) + ".self_attn.linear_layers.1.weight"] = \ + unpermute(get_weight_from_name("model.layers." + str(i) + ".self_attn.k_proj.weight")) + + output_model["encoder.transformer." + str(i) + ".self_attn.linear_layers.2.weight"] = \ + get_weight_from_name("model.layers." + str(i) + ".self_attn.v_proj.weight") + output_model["encoder.transformer." + str(i) + ".self_attn.final_linear.weight"] = \ + get_weight_from_name("model.layers." + str(i) + ".self_attn.o_proj.weight") + + output_model["encoder.transformer." + str(i) + ".layer_norm_1.weight"] = \ + get_weight_from_name("model.layers." + str(i) + ".input_layernorm.weight") + + output_model["encoder.transformer." + str(i) + ".feed_forward.linear_gate.weight"] = \ + get_weight_from_name("model.layers." + str(i) + ".mlp.gate_proj.weight") + output_model["encoder.transformer." + str(i) + ".feed_forward.linear_1.weight"] = \ + get_weight_from_name("model.layers." + str(i) + ".mlp.up_proj.weight") + output_model["encoder.transformer." + str(i) + ".feed_forward.linear_2.weight"] = \ + get_weight_from_name("model.layers." + str(i) + ".mlp.down_proj.weight") + + output_model["encoder.transformer." + str(i) + ".layer_norm_2.weight"] = \ + get_weight_from_name("model.layers." 
+ str(i) + ".post_attention_layernorm.weight") + +output_model["encoder.layer_norm.weight"] = get_weight_from_name("model.norm.weight") +output_model["target.lm.output_layer.weight"] = get_weight_from_name("lm_head.weight") + +output_model["embedding.image_text.projection.0.weight"] = get_weight_from_name("model.mm_projector.0.weight") +output_model["embedding.image_text.projection.0.bias"] = get_weight_from_name("model.mm_projector.0.bias") +output_model["embedding.image_text.projection.2.weight"] = get_weight_from_name("model.mm_projector.2.weight") +output_model["embedding.image_text.projection.2.bias"] = get_weight_from_name("model.mm_projector.2.bias") + +torch.save(output_model, args.output_model_path) diff --git a/scripts/generate_llava_json_deepspeed.py b/scripts/generate_llava_json_deepspeed.py new file mode 100755 index 00000000..80ae927f --- /dev/null +++ b/scripts/generate_llava_json_deepspeed.py @@ -0,0 +1,299 @@ +""" + This script provides an exmaple to wrap TencentPretrain for generation. + Given the beginning of a text, language model generates the rest. +""" +import sys +import os +import argparse +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.distributed as dist +from torchvision import transforms +from torchvision.io import read_image +from torchvision.io.image import ImageReadMode +import imghdr +import deepspeed + +tencentpretrain_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) +sys.path.append(tencentpretrain_dir) + +from tencentpretrain.embeddings import * +from tencentpretrain.encoders import * +from tencentpretrain.targets import * +from tencentpretrain.utils.constants import * +from tencentpretrain.utils import * +from tencentpretrain.utils.config import load_hyperparam +from tencentpretrain.opts import infer_opts, tokenizer_opts, log_opts, mp_opts +from tencentpretrain.opts import deepspeed_opts +from tencentpretrain.utils.logging import init_logger +from tencentpretrain.model_loader import _load_state_dict_into_model +from tencentpretrain.utils.misc import pooling, ZeroOneNormalize + + +class LLaVaGenerate(nn.Module): + def __init__(self, args): + super(LLaVaGenerate, self).__init__() + self.args = args + self.embedding = Embedding(args) + for embedding_name in args.embedding: + tmp_emb = str2embedding[embedding_name](args, len(args.tokenizer.vocab)) + self.embedding.update(tmp_emb, embedding_name) + + self.encoder = str2encoder[args.encoder](args) + self.pooling_type = args.pooling + + self.target = Target() + self.target.update(LmTarget(args, len(args.tokenizer.vocab)), "lm") + print("tokenizer vocab nums:", len(args.tokenizer.vocab)) + + self.num_image_tokens = int(args.image_width / args.patch_size) * int(args.image_height / args.patch_size) + + def forward(self, src_text, seg_text, src_image, seg_image, image_pos): + """ + Args: + src: [batch_size x seq_length] + tgt: [batch_size] + seg: [batch_size x seq_length] + """ + # Embedding. + src = src_text, src_image, seg_text, seg_image, image_pos + emb = self.embedding(src, None) + seg = torch.cat((seg_image, seg_text), 1) + # encoder + output = self.encoder(emb, seg) + # # Target. 
+ output = self.target.lm.output_layer(output) + return output + + +def top_k_top_p_filtering(logits, top_k, top_p): + top_k = min(top_k, logits.size(-1)) # Safety check + if top_k > 0: + # Remove all tokens with a probability less than the last token of the top-k + indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None] + logits[indices_to_remove] = -float("Inf") + + if top_p > 0.0: + sorted_logits, sorted_indices = torch.sort(logits, descending=True) + cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1) + + # Remove tokens with cumulative probability above the threshold + sorted_indices_to_remove = cumulative_probs > top_p + # Shift the indices to the right to keep also the first token above the threshold + sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() + sorted_indices_to_remove[..., 0] = 0 + + indices_to_remove = sorted_indices[sorted_indices_to_remove] + logits[indices_to_remove] = -float("Inf") + return logits + + +def load_or_initialize_parameters(args, model): + if args.pretrained_model_path is not None: + # Initialize with pretrained model. + args.logger.info("loading model from {0}".format(args.pretrained_model_path)) + keys_info = model.load_state_dict(torch.load(args.pretrained_model_path, map_location="cpu"), strict=False) + args.logger.info("missing_keys: {0}".format(keys_info.missing_keys)) + args.logger.info("unexpected_keys: {0}".format(keys_info.unexpected_keys)) + if args.vit_model_path is not None: + args.logger.info("loading model from {0}".format(args.vit_model_path)) + # model = _load_state_dict_into_model(model, args.vit_model_path, "embedding.image_text.vision_") + keys_info = model.load_state_dict(torch.load(args.vit_model_path, map_location="cpu"), strict=False) + args.logger.info("missing_keys: {0}".format(keys_info.missing_keys)) + args.logger.info("unexpected_keys: {0}".format(keys_info.unexpected_keys)) + else: + # Initialize with normal distribution. + for n, p in list(model.named_parameters()): + if "gamma" not in n and "beta" not in n: + p.data.normal_(0, 0.02) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) + + infer_opts(parser) + + parser.add_argument("--top_k", type=int, default=70) + parser.add_argument("--top_p", type=float, default=0) + parser.add_argument("--temperature", type=float, default=1.0) + parser.add_argument("--vit_model_path", type=str, default=None, + help="Pretrained model of Vit.") + parser.add_argument("--connector_model_path", type=str, default=None, + help="Pretrained model of Connector.") + parser.add_argument("--prompt_template", type=str, choices=["llama2", "vicuna"], + help="give the llm type to choose a prompt", default="llama2") + + tokenizer_opts(parser) + + deepspeed_opts(parser) + + log_opts(parser) + + mp_opts(parser) + + args = parser.parse_args() + + args.target = "lm" + args.batch_size = 1 + + args = load_hyperparam(args) + + args.tokenizer = str2tokenizer[args.tokenizer](args) + + args.logger = init_logger(args) + + args.pretrained_model_path = args.load_model_path + + # Load or initialize parameters. 
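+    # Note (editorial): under ZeRO-3 the parameters are partitioned across ranks as soon as each
+    # module is constructed, so the model has to be built inside deepspeed.zero.Init() and its
+    # checkpoints loaded shard-aware through _load_state_dict_into_model(). The plain
+    # load_state_dict() path in load_or_initialize_parameters() is only used by the
+    # non-ZeRO-3 branch below.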
+ if args.enable_zero3: + with deepspeed.zero.Init(config_dict_or_path=args.deepspeed_config): + model = LLaVaGenerate(args) + if args.pretrained_model_path: + model = _load_state_dict_into_model(model, args.pretrained_model_path) + if args.vit_model_path is not None: + model = _load_state_dict_into_model(model, args.vit_model_path, "embedding.image_text.vision_") + else: + model = LLaVaGenerate(args) + # Load or initialize parameters. + load_or_initialize_parameters(args, model) + + deepspeed.init_distributed() + model = deepspeed.initialize(model=model,config_params=args.deepspeed_config)[0] + + rank = dist.get_rank() + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + model.eval() + + transform = transforms.Compose([ + transforms.Resize((args.image_height, args.image_width)), + ZeroOneNormalize() + ]) + prompt_template = { + "llama2": "<>\nYou are a helpful language and vision assistant. You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language.\n<>\n\n", + "vicuna": "<>\nA chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.\n<>\n\n"} + num_image_tokens = int(args.image_width / args.patch_size) * int(args.image_height / args.patch_size) + 1 # 336/14-14 --> 576 dim + seq_text = args.seq_length - num_image_tokens # 576 + outf = open(args.prediction_path, mode="w", encoding="utf-8") + input_f = open(args.test_path, mode="r", encoding="utf-8") + datas = json.load(input_f) + try: + if args.prompt_template == "llama2": + prompt_overall = prompt_template["llama2"] + elif args.prompt_template == "vicuna": + prompt_overall = prompt_template["vicuna"] + except: + args.logger.info("unsupported prompt template!") + NotImplementedError + for line_id, item in enumerate(datas): + try: + id = item["id"] + image_path = "datasets/llava/" + item["image"] + if not os.path.isfile(image_path): + continue + if imghdr.what(image_path) != 'jpeg' and imghdr.what(image_path) != 'png': + continue + image = read_image(image_path, ImageReadMode.RGB) + image = image.to(device) + src_image = transform(image) + except: + print("sth wrong with item{}".format(item)) + continue + + prompt_before_image = prompt_overall + " USER:" + ground_truth = [] + text_combine_id = [] + if "conversations" in item: + conversations = item["conversations"] + for i, conv in enumerate(conversations): + # 1 round + if i > 1: + continue + if i == 0: + prompt = conv["value"] + if prompt.endswith(""): + prompt_before_image = prompt_before_image + prompt.replace("","") + prompt_after_image = "\nASSISTANT:" + elif prompt.startswith(""): + prompt_before_image = prompt_before_image + "" + prompt_after_image = prompt.replace("","") + "\nASSISTANT:" + prompt_before_image_id = args.tokenizer.convert_tokens_to_ids( + args.tokenizer.tokenize(prompt_before_image) + ) + prompt_after_image_id = args.tokenizer.convert_tokens_to_ids( + args.tokenizer.tokenize(prompt_after_image) + ) + seg_before_image = [1] * len(prompt_before_image_id) + seg_after_image = [1] * len(prompt_after_image_id) + if len(prompt_before_image_id) + len(prompt_after_image_id) > seq_text: + args.logger.info("promt too long, jump for now") + break + text_combine_id = [prompt_before_image_id + prompt_after_image_id] + text_combine_seg = [seg_before_image + seg_after_image] + elif i % 2 == 0: # human + prompt = conv["value"] + prompt_id = args.tokenizer.convert_tokens_to_ids( + 
args.tokenizer.tokenize(" USER:" + prompt + "\nASSISTANT:") + ) + if text_combine_id: + text_combine_id.append(prompt_id) + text_combine_seg.append(text_combine_seg + [1] * len(prompt_id)) + else: + args.logger.info("no prompt, or prompt too long, jumping") + break + else: # gpt + ground_truth.append(conv["value"]) + else: + prompt = item["instruction"] + prompt_before_image = prompt_before_image + prompt + "" + prompt_after_image = "\nASSISTANT:" + prompt_before_image_id = args.tokenizer.convert_tokens_to_ids( + args.tokenizer.tokenize(prompt_before_image) + ) + prompt_after_image_id = args.tokenizer.convert_tokens_to_ids( + args.tokenizer.tokenize(prompt_after_image) + ) + seg_before_image = [1] * len(prompt_before_image_id) + seg_after_image = [1] * len(prompt_after_image_id) + if len(prompt_before_image_id) + len(prompt_after_image_id) > seq_text: + args.logger.info("promt too long, jump for now") + break + text_combine_id = [prompt_before_image_id + prompt_after_image_id] + text_combine_seg = [seg_before_image + seg_after_image] + + image_pos = len(prompt_before_image_id) + + image_tensor = torch.unsqueeze(src_image, 0).half() + image_seg_tensor = torch.ones(1, num_image_tokens).to(device) + image_pos = torch.LongTensor([image_pos]).to(device) + SEP_ID = args.tokenizer.convert_tokens_to_ids([SEP_TOKEN]) + text_tensor = None + for i, prompt in enumerate(text_combine_id): + if text_tensor is None: + text_tensor, text_seg_tensor = torch.LongTensor([prompt]).to(device), torch.LongTensor([text_combine_seg[i]]).to(device) + else: + text_tensor = torch.cat([text_tensor, torch.LongTensor([prompt]).to(device)], dim=1) + text_seg_tensor = torch.cat([text_seg_tensor, torch.LongTensor([text_combine_seg[i]]).to(device)], dim=1) + + while text_tensor.shape[1] + num_image_tokens < args.seq_length: + output = model(text_tensor, text_seg_tensor, image_tensor, image_seg_tensor, image_pos) + next_token_logits = output[0][-1] / args.temperature + filtered_logits = top_k_top_p_filtering(next_token_logits, args.top_k, args.top_p) + next_token = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1) + + text_tensor = torch.cat([text_tensor, next_token.view(1, 1)], dim=1) + text_seg_tensor = torch.cat([text_seg_tensor, torch.tensor([[1]]).to(device)], dim=1) + # print("next_token:", next_token) + if next_token.cpu().tolist() == SEP_ID: + break + + if rank == 0 and text_tensor is not None: + # outf.write("\t".join(line)+"\n") + tokens = [token_id.item() for token_id in text_tensor[0]] + if args.tokenizer.sp_model is not None: + generated_sentence = args.tokenizer.sp_model.decode(tokens) + else: + generated_sentence = "".join(args.tokenizer.convert_ids_to_tokens(tokens)) + print(item) + print(generated_sentence) + outf.write(generated_sentence + "\n\n") diff --git a/tencentpretrain/embeddings/__init__.py b/tencentpretrain/embeddings/__init__.py index 3eff2e32..40167993 100644 --- a/tencentpretrain/embeddings/__init__.py +++ b/tencentpretrain/embeddings/__init__.py @@ -8,13 +8,14 @@ from tencentpretrain.embeddings.word_patch_embedding import WordPatchEmbedding from tencentpretrain.embeddings.speech_embedding import SpeechEmbedding from tencentpretrain.embeddings.masked_patch_embedding import MaskedPatchEmbedding +from tencentpretrain.embeddings.image_text_embedding import ImageTextEmbedding str2embedding = {"word": WordEmbedding, "pos": PosEmbedding, "seg": SegEmbedding, "sinusoidalpos": SinusoidalposEmbedding, "dual": DualEmbedding, "patch": PatchEmbedding, "word_patch": WordPatchEmbedding, 
"speech": SpeechEmbedding, - "masked_patch": MaskedPatchEmbedding} + "masked_patch": MaskedPatchEmbedding, "image_text": ImageTextEmbedding} __all__ = ["Embedding", "WordEmbedding", "PosEmbedding", "SegEmbedding", "SinusoidalposEmbedding", "DualEmbedding", "PatchEmbedding", "WordPatchEmbedding", "SpeechEmbedding", - "MaskedPatchEmbedding", "str2embedding"] + "MaskedPatchEmbedding", "str2embedding", "ImageTextEmbedding"] diff --git a/tencentpretrain/embeddings/image_text_embedding.py b/tencentpretrain/embeddings/image_text_embedding.py new file mode 100755 index 00000000..b5d378a9 --- /dev/null +++ b/tencentpretrain/embeddings/image_text_embedding.py @@ -0,0 +1,69 @@ +from argparse import Namespace +import torch +import torch.nn as nn +import copy + +from tencentpretrain.embeddings.embedding import Embedding +from tencentpretrain.embeddings.word_embedding import WordEmbedding +from tencentpretrain.embeddings.pos_embedding import PosEmbedding +from tencentpretrain.embeddings.patch_embedding import PatchEmbedding +from tencentpretrain.encoders import str2encoder,TransformerEncoder + +str2embedding = {"word": WordEmbedding, "pos": PosEmbedding, "patch": PatchEmbedding} + + +class ImageTextEmbedding(nn.Module): + ''' + an combination of a vision encoder and a text embedding + ''' + def __init__(self, args, vocab_size): + super(ImageTextEmbedding, self).__init__() + # Vit model for vision features + vision_encoder_args = copy.deepcopy(vars(args)) + vision_encoder_args.update(args.image_text_emb["vision_encoder"]) + vision_encoder_args = Namespace(**vision_encoder_args) + self.vision_embedding = Embedding(vision_encoder_args) + for embedding_name in vision_encoder_args.embedding: + tmp_emb = str2embedding[embedding_name](vision_encoder_args, None) + self.vision_embedding.update(tmp_emb, embedding_name) + self.vision_encoder = str2encoder[vision_encoder_args.encoder](vision_encoder_args) + + # map the output of ViT into the same space as the text features + projection_args = copy.deepcopy(vars(args)) + projection_args.update(args.image_text_emb["projection"]) + projection_args = Namespace(**projection_args) + projection_modules = [nn.Linear(vision_encoder_args.emb_size, projection_args.mlp_hidden_size)] + for _ in range(1, projection_args.num_mlp_layer): + projection_modules.append(nn.GELU()) + projection_modules.append(nn.Linear(projection_args.mlp_hidden_size, projection_args.mlp_hidden_size)) + self.projection = nn.Sequential(*projection_modules) + + # text embedding + text_args = copy.deepcopy(vars(args)) + text_args.update(args.image_text_emb["text"]) + text_args = Namespace(**text_args) + self.text_embedding = Embedding(text_args) + for embedding_name in text_args.embedding: + tmp_emb = str2embedding[embedding_name](text_args, len(args.tokenizer.vocab)) + self.text_embedding.update(tmp_emb, embedding_name) + + def forward(self, src, seg=None): + src_text, src_image, seg_text, seg_image, image_pos = src + # seg_text, seg_image, image_pos = seg + # image features + with torch.no_grad(): + emb = self.vision_embedding(src_image, seg_image) + image_emb = self.vision_encoder(emb, seg_image) + image_emb = self.projection(image_emb) + # text embedding + text_emb = self.text_embedding(src_text, seg_text) + # combine text and image + if text_emb.shape[0] == 1: + emb_combine = torch.cat((text_emb[:,:image_pos[0],:], image_emb, text_emb[:,image_pos[0]:,:]), 1) + else: + emb_combine = torch.cat((text_emb[0,:image_pos[0],:], image_emb[0], text_emb[0,image_pos[0]:,:]), 0).unsqueeze(0) + for i in range(1, 
text_emb.shape[0]): + tmp = torch.cat((text_emb[i,:image_pos[i],:], image_emb[i], text_emb[i,image_pos[i]:,:]), 0).unsqueeze(0) + emb_combine = torch.cat((emb_combine, tmp), 0) + # seg_combine = torch.cat((seg_image, seg_text), 1) + return emb_combine diff --git a/tencentpretrain/models/model.py b/tencentpretrain/models/model.py index f7b30d65..b7566e16 100755 --- a/tencentpretrain/models/model.py +++ b/tencentpretrain/models/model.py @@ -28,6 +28,15 @@ def __init__(self, args, embedding, encoder, tgt_embedding, decoder, target): if self.decoder is not None and args.share_embedding: self.tgt_embedding.word.embedding.weight = self.embedding.word.embedding.weight + # add for llava + if "freeze_encoder" in args and args.freeze_encoder: + for name, param in self.encoder.named_parameters(): + param.requires_grad = False + for name, param in self.embedding.named_parameters(): + if "image_text.projection" not in name: + param.requires_grad = False + for name, param in self.target.named_parameters(): + param.requires_grad = False def forward(self, src, tgt, seg, tgt_in=None, tgt_seg=None): emb = self.embedding(src, seg) diff --git a/tencentpretrain/opts.py b/tencentpretrain/opts.py index 57b03cde..6f1cd6d9 100755 --- a/tencentpretrain/opts.py +++ b/tencentpretrain/opts.py @@ -52,7 +52,9 @@ def model_opts(parser): help="whether use alibi position embedding.") parser.add_argument("--layer_number_scale", action="store_true", help="whether use layer number scaling.") - + # add for llava + parser.add_argument("--freeze_encoder", action="store_true", + help="whether freeze the encoder parameters.") vision_opts(parser) audio_opts(parser) diff --git a/tencentpretrain/trainer.py b/tencentpretrain/trainer.py index 5ee00ac3..f95c69df 100755 --- a/tencentpretrain/trainer.py +++ b/tencentpretrain/trainer.py @@ -1,6 +1,7 @@ import os import json import time +import torch import torch.distributed as dist import torch.multiprocessing as mp from torch.nn.parallel import DistributedDataParallel @@ -62,7 +63,9 @@ def init_model(args): for shard_file in shard_filenames: model_for_training = _load_state_dict_into_model(model_for_training, shard_file, "") else: + args.logger.info("loading: {}".format(args.pretrained_model_path)) model_for_training = _load_state_dict_into_model(model_for_training, args.pretrained_model_path, "") + args.logger.info("loaded: {}".format(args.pretrained_model_path)) if args.lora_pretrained_model_path is not None: model_for_training = _load_state_dict_into_model(model_for_training, args.lora_pretrained_model_path, "") elif args.deepspeed and args.use_mp: @@ -92,7 +95,14 @@ def init_model(args): model_for_dataloader = build_vqgan_model(args) else: model_for_dataloader = None - + # add for llava + if args.vit_model_path is not None: + args.logger.info("loading: {}".format(args.vit_model_path)) + # model_for_training = _load_state_dict_into_model(model_for_training, args.vit_model_path, "embedding.image_text.vision_") + keys_info = model_for_training.load_state_dict(torch.load(args.vit_model_path, map_location="cpu"), strict=False) + args.logger.info("loaded: {}".format(args.vit_model_path)) + args.logger.info("missing_keys: {0}".format(keys_info.missing_keys)) + args.logger.info("unexpected_keys: {0}".format(keys_info.unexpected_keys)) return model_for_training, model_for_dataloader @@ -649,12 +659,23 @@ class LlmSftTrainer(LmTrainer): pass +class LlavaTrainer(LmTrainer): + def forward_propagation(self, batch, model): + src_text, src_img, tgt, seg_text, seg_img, seg_tgt, image_pos = batch + 
seg = torch.cat((seg_img, seg_text), 1) + loss = model((src_text, src_img, seg_text, seg_img, image_pos), tgt, seg, tgt_seg=seg_tgt) + + self.total_loss += loss.item() + loss = loss / self.accumulation_steps + return loss + + str2trainer = {"bert": BertTrainer, "mlm": MlmTrainer, "lm": LmTrainer, "albert": AlbertTrainer, "bilm": BilmTrainer, "cls": ClsTrainer, "mt": MtTrainer, "t5": T5Trainer, "gsg": GsgTrainer, "bart": BartTrainer, "prefixlm": PrefixlmTrainer, "cls_mlm": ClsMlmTrainer, "vit": VitTrainer, "vilt": ViltTrainer, "clip": ClipTrainer, "s2t": S2tTrainer, - "beit": BeitTrainer, "dalle": DalleTrainer, "llm_sft": LlmSftTrainer} + "beit": BeitTrainer, "dalle": DalleTrainer, "llm_sft": LlmSftTrainer, "llava": LlavaTrainer} def worker(local_rank, gpu_ranks, args): @@ -676,6 +697,15 @@ def worker(local_rank, gpu_ranks, args): # Build model. model_for_training, model_for_dataloader = init_model(args) + # add for llava + if global_rank == 0: + args.logger.info("model: {}".format(model_for_training)) + # for name, param in model_for_training.named_parameters(): + # args.logger.info("name: {}, requires_grad: {}".format(name, param.requires_grad)) + # import pdb + # pdb.set_trace() + # add for llava end + # Build optimizer. custom_optimizer, custom_scheduler, optimizer_grouped_parameters = init_optimizer(args, model_for_training) diff --git a/tencentpretrain/utils/dataloader.py b/tencentpretrain/utils/dataloader.py index b7269518..376e72d1 100755 --- a/tencentpretrain/utils/dataloader.py +++ b/tencentpretrain/utils/dataloader.py @@ -963,3 +963,66 @@ def __iter__(self): yield torch.LongTensor(src), \ torch.LongTensor(tgt), \ torch.LongTensor(seg) + + +class LlavaDataloader(VisionDataloader): + + def __iter__(self): + """ + instances: ((src, tgt), (seg_src, seg_tgt), (src_image, image_pos)) + src, tgt: Tokens of the text sample + seg_src, seg_tgt: Segment of text sample + src_image: Path of the image sample + image_pos: Position of the image in the text sample + + Returns: + src_text: [batch_size x seq_length] + src_image: [batch_size x channel_size x width x hight] + tgt: [batch_size x seq_length] + seg_text: [batch_size x seq_length] + seg_image: [batch_size x (patch_num + 1)] + seg_tgt: [batch_size x seq_length] + image_pos: [batch_size] + + """ + from torchvision.io import read_image + from torchvision.io.image import ImageReadMode + while True: + while self._empty(): + self._fill_buf() + if self.start + self.batch_size >= self.end: + instances = self.buffer[self.start:] + else: + instances = self.buffer[self.start: self.start + self.batch_size] + + self.start += self.batch_size + + src_text = [] + src_image = [] + tgt = [] + seg_text = [] + seg_image = [] + seg_tgt = [] + image_pos = [] + for ins in instances: + ins_src, ins_tgt = ins[0] + ins_seg_src, ins_seg_tgt = ins[1] + ins_src_image, ins_image_pos = ins[2] + + src_text.append(ins_src) + tgt.append(ins_tgt) + seg_text.append(ins_seg_src) + seg_tgt.append(ins_seg_tgt) + image = read_image(ins_src_image, ImageReadMode.RGB) + image = image.cuda(self.local_rank) + src_image.append(self.transform(image)) + seg_image.append([1] * ((self.image_height // self.patch_size) * (self.image_width // self.patch_size) + 1)) + image_pos.append(ins_image_pos) + + yield torch.LongTensor(src_text), \ + torch.stack(src_image, 0).half(), \ + torch.LongTensor(tgt), \ + torch.LongTensor(seg_text), \ + torch.LongTensor(seg_image), \ + torch.LongTensor(seg_tgt), \ + image_pos diff --git a/tencentpretrain/utils/dataset.py b/tencentpretrain/utils/dataset.py 
index 9c7a682b..89786e5a 100755 --- a/tencentpretrain/utils/dataset.py +++ b/tencentpretrain/utils/dataset.py @@ -969,6 +969,118 @@ def worker(self, proc_id, start, end): dataset_writer.close() +class FileWithTextJsonlDataset(Dataset): + def worker(self, proc_id, start, end): + import json + num_image_tokens = 577 #int(args.image_width / args.patch_size) * int(args.image_height / args.patch_size) + 1 # 336/14-14 --> 576 dim + seq_text = self.seq_length - num_image_tokens # 576 + PAD_ID = self.tokenizer.convert_tokens_to_ids([PAD_TOKEN])[0] + prompt_template = { + "llama2": "<>\nYou are a helpful language and vision assistant. You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language.\n<>\n\n", + "vicuna": "<>\nA chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.\n<>\n\n" + } + # self.prompt_template 待添加 + # try: + # if self.prompt_template == "llama2": + # prompt_overall = prompt_template["llama2"] + # elif self.prompt_template == "vicuna": + # prompt_overall = prompt_template["vicuna"] + # except: + # print("unsupported prompt template!") + # NotImplementedError + prompt_overall = prompt_template["llama2"] + print("Worker %d is building dataset ... " % proc_id) + set_seed(self.seed) + dataset_writer = open("dataset-tmp-" + str(proc_id) + ".pt", "wb") + pos = 0 + skip_line = 0 + with open(self.corpus_path, mode="r", encoding="utf-8") as f: + while pos < start: + f.readline() + pos += 1 + while True: + line = json.loads(f.readline()) + pos += 1 + try: + path = line["image"] + if not os.path.isfile(path): + continue + except: + skip_line += 1 + continue + conversations = line["conversations"] + + prompt_before_image = prompt_overall + " USER:" + for i, conv in enumerate(conversations): + if i == 0: + prompt = conv["value"] + if prompt.endswith(""): + prompt_before_image = prompt_before_image + prompt.replace("","") + prompt_after_image = "\nASSISTANT:" + elif prompt.startswith(""): + prompt_before_image = prompt_before_image + "" + prompt_after_image = prompt.replace("","") + "\nASSISTANT:" + prompt_before_image_id = self.tokenizer.convert_tokens_to_ids( + self.tokenizer.tokenize(prompt_before_image) + ) + prompt_after_image_id = self.tokenizer.convert_tokens_to_ids( + self.tokenizer.tokenize(prompt_after_image) + ) + seg_before_image = [1] * len(prompt_before_image_id) + seg_after_image = [1] * len(prompt_after_image_id) + if len(prompt_before_image_id) + len(prompt_after_image_id) > seq_text: + print("promt too long, jump for now") + continue + text_combine_id = prompt_before_image_id + prompt_after_image_id + text_combine_seg = seg_before_image + seg_after_image + tgt_id = [PAD_ID] * (len(text_combine_id) + num_image_tokens - 1) + tgt_seg = [0] * len(tgt_id) + elif i % 2 == 0: # human + prompt = conv["value"] + prompt_id = self.tokenizer.convert_tokens_to_ids( + self.tokenizer.tokenize(" USER:" + prompt + "\nASSISTANT:") + ) + text_combine_id = text_combine_id + prompt_id + text_combine_seg = text_combine_seg + [1] * len(prompt_id) + tgt_id = tgt_id + [PAD_ID] * len(prompt_id) + tgt_seg = tgt_seg + [0] * len(prompt_id) + else: # gpt + answer = conv["value"] + answer_id = self.tokenizer.convert_tokens_to_ids( + self.tokenizer.tokenize(answer) + [SEP_TOKEN] + ) + text_combine_id = text_combine_id + answer_id + text_combine_seg = text_combine_seg + [1] * len(answer_id) + tgt_id = tgt_id + answer_id + tgt_seg = 
tgt_seg + [1] * len(answer_id) + + if len(tgt_id) > self.seq_length: + tgt_id = tgt_id[:self.seq_length] + tgt_seg = tgt_seg[:self.seq_length] + pad_num = self.seq_length - len(tgt_id) + tgt_id = tgt_id + [PAD_ID] * pad_num + tgt_seg = tgt_seg + [0] * pad_num + + if len(text_combine_id) > seq_text : + text_combine_id = text_combine_id[:seq_text] + text_combine_seg = text_combine_seg[:seq_text] + pad_num = seq_text - len(text_combine_id) + text_combine_id = text_combine_id + [PAD_ID] * pad_num + text_combine_seg = text_combine_seg + [0] * pad_num + + image_pos = len(prompt_before_image_id) + src = (text_combine_id, tgt_id) + seg = (text_combine_seg, tgt_seg) + image = (path, image_pos) + + pickle.dump((src, seg, image), dataset_writer) + + if pos >= end: + break + + dataset_writer.close() + + class FileWithLabelDataset(Dataset): def worker(self, proc_id, start, end): print("Worker %d is building dataset ... " % proc_id) @@ -1083,3 +1195,7 @@ def worker(self, proc_id, start, end): break dataset_writer.close() + + +class LlavaDataset(FileWithTextJsonlDataset): + pass From 1c83b300baeb9e90c10e734265de569260560e70 Mon Sep 17 00:00:00 2001 From: janinezhao Date: Tue, 19 Dec 2023 15:58:22 +0800 Subject: [PATCH 02/18] fix dataset & name & details in llava --- models/llava/7b_config.json | 2 +- preprocess.py | 7 ++ pretrain.py | 4 +- tencentpretrain/embeddings/__init__.py | 2 +- .../embeddings/image_text_embedding.py | 19 +++-- tencentpretrain/models/model.py | 21 ++--- tencentpretrain/opts.py | 7 +- tencentpretrain/trainer.py | 20 ++--- tencentpretrain/utils/dataloader.py | 11 ++- tencentpretrain/utils/dataset.py | 76 +++++++++---------- 10 files changed, 85 insertions(+), 84 deletions(-) diff --git a/models/llava/7b_config.json b/models/llava/7b_config.json index dd44a519..1ddf3729 100755 --- a/models/llava/7b_config.json +++ b/models/llava/7b_config.json @@ -51,4 +51,4 @@ "layernorm_positioning": "pre", "layernorm": "rms", "target": ["lm"] - } \ No newline at end of file + } diff --git a/preprocess.py b/preprocess.py index 1d313a00..dd7ec4a3 100644 --- a/preprocess.py +++ b/preprocess.py @@ -41,6 +41,13 @@ def main(): "The larger value, the higher probability of using short (truncated) sequence.") parser.add_argument("--full_sentences", action="store_true", help="Full sentences.") parser.add_argument("--seed", type=int, default=7, help="Random seed.") + parser.add_argument("--instruction_template", choices=["sys1", "sys2"], default="sys1", + help="The instruction type for training large language-vision model.") + + # Image options. + parser.add_argument("--image_width", type=int, default=336, help="The width of the input images.") + parser.add_argument("--image_height", type=int, default=336, help="The height of the input images.") + parser.add_argument("--patch_size", type=int, default=14, help="The patch size for the input images.") # Masking options. 
parser.add_argument("--dynamic_masking", action="store_true", help="Dynamic masking.") diff --git a/pretrain.py b/pretrain.py index aa8bbc08..f0dc180d 100644 --- a/pretrain.py +++ b/pretrain.py @@ -13,8 +13,8 @@ def main(): help="Path of the preprocessed dataset.") parser.add_argument("--pretrained_model_path", type=str, default=None, help="Path of the pretrained model.") - parser.add_argument("--vit_model_path", type=str, default=None, - help="Path of the Vit pretrained model.") + parser.add_argument("--vision_model_path", type=str, default=None, + help="Path of the vision pretrained model.") parser.add_argument("--output_model_path", type=str, required=True, help="Path of the output model.") parser.add_argument("--config_path", type=str, default="models/bert/base_config.json", diff --git a/tencentpretrain/embeddings/__init__.py b/tencentpretrain/embeddings/__init__.py index 40167993..cbf6f979 100644 --- a/tencentpretrain/embeddings/__init__.py +++ b/tencentpretrain/embeddings/__init__.py @@ -18,4 +18,4 @@ __all__ = ["Embedding", "WordEmbedding", "PosEmbedding", "SegEmbedding", "SinusoidalposEmbedding", "DualEmbedding", "PatchEmbedding", "WordPatchEmbedding", "SpeechEmbedding", - "MaskedPatchEmbedding", "str2embedding", "ImageTextEmbedding"] + "MaskedPatchEmbedding", "ImageTextEmbedding", "str2embedding"] diff --git a/tencentpretrain/embeddings/image_text_embedding.py b/tencentpretrain/embeddings/image_text_embedding.py index b5d378a9..7ad1e321 100755 --- a/tencentpretrain/embeddings/image_text_embedding.py +++ b/tencentpretrain/embeddings/image_text_embedding.py @@ -18,7 +18,7 @@ class ImageTextEmbedding(nn.Module): ''' def __init__(self, args, vocab_size): super(ImageTextEmbedding, self).__init__() - # Vit model for vision features + # vision model for vision features vision_encoder_args = copy.deepcopy(vars(args)) vision_encoder_args.update(args.image_text_emb["vision_encoder"]) vision_encoder_args = Namespace(**vision_encoder_args) @@ -28,7 +28,7 @@ def __init__(self, args, vocab_size): self.vision_embedding.update(tmp_emb, embedding_name) self.vision_encoder = str2encoder[vision_encoder_args.encoder](vision_encoder_args) - # map the output of ViT into the same space as the text features + # map the output of vision model into the same space as the text features projection_args = copy.deepcopy(vars(args)) projection_args.update(args.image_text_emb["projection"]) projection_args = Namespace(**projection_args) @@ -49,21 +49,20 @@ def __init__(self, args, vocab_size): def forward(self, src, seg=None): src_text, src_image, seg_text, seg_image, image_pos = src - # seg_text, seg_image, image_pos = seg # image features with torch.no_grad(): - emb = self.vision_embedding(src_image, seg_image) - image_emb = self.vision_encoder(emb, seg_image) + image_emb = self.vision_embedding(src_image, seg_image) + image_emb = self.vision_encoder(image_emb, seg_image) image_emb = self.projection(image_emb) # text embedding text_emb = self.text_embedding(src_text, seg_text) # combine text and image if text_emb.shape[0] == 1: - emb_combine = torch.cat((text_emb[:,:image_pos[0],:], image_emb, text_emb[:,image_pos[0]:,:]), 1) + emb = torch.cat((text_emb[:,:image_pos[0],:], image_emb, text_emb[:,image_pos[0]:,:]), 1) else: - emb_combine = torch.cat((text_emb[0,:image_pos[0],:], image_emb[0], text_emb[0,image_pos[0]:,:]), 0).unsqueeze(0) + emb = torch.cat((text_emb[0,:image_pos[0],:], image_emb[0], text_emb[0,image_pos[0]:,:]), 0).unsqueeze(0) for i in range(1, text_emb.shape[0]): tmp = 
torch.cat((text_emb[i,:image_pos[i],:], image_emb[i], text_emb[i,image_pos[i]:,:]), 0).unsqueeze(0) - emb_combine = torch.cat((emb_combine, tmp), 0) - # seg_combine = torch.cat((seg_image, seg_text), 1) - return emb_combine + emb = torch.cat((emb, tmp), 0) + + return emb diff --git a/tencentpretrain/models/model.py b/tencentpretrain/models/model.py index b7566e16..c46abe1b 100755 --- a/tencentpretrain/models/model.py +++ b/tencentpretrain/models/model.py @@ -28,15 +28,18 @@ def __init__(self, args, embedding, encoder, tgt_embedding, decoder, target): if self.decoder is not None and args.share_embedding: self.tgt_embedding.word.embedding.weight = self.embedding.word.embedding.weight - # add for llava - if "freeze_encoder" in args and args.freeze_encoder: - for name, param in self.encoder.named_parameters(): - param.requires_grad = False - for name, param in self.embedding.named_parameters(): - if "image_text.projection" not in name: - param.requires_grad = False - for name, param in self.target.named_parameters(): - param.requires_grad = False + + if args.freeze_parameters: + name_mapping = { + "embedding": self.embedding, "encoder": self.encoder, "tgt_embedding": self.tgt_embedding, + "decoder": self.decoder, "target": self.target + } + for freeze_name in args.freeze_parameters: + if name_mapping[freeze_name] is None: + continue + for name, param in name_mapping[freeze_name].named_parameters(): + if args.freeze_exclude_by_name == "" or args.freeze_exclude_by_name not in name: + param.requires_grad = False def forward(self, src, tgt, seg, tgt_in=None, tgt_seg=None): emb = self.embedding(src, seg) diff --git a/tencentpretrain/opts.py b/tencentpretrain/opts.py index 6f1cd6d9..6120ba18 100755 --- a/tencentpretrain/opts.py +++ b/tencentpretrain/opts.py @@ -52,9 +52,10 @@ def model_opts(parser): help="whether use alibi position embedding.") parser.add_argument("--layer_number_scale", action="store_true", help="whether use layer number scaling.") - # add for llava - parser.add_argument("--freeze_encoder", action="store_true", - help="whether freeze the encoder parameters.") + parser.add_argument("--freeze_parameters", choices=["embedding", "encoder", "tgt_embedding", "decoder", "target"], + default="", nargs='+', help="Which module to be frozen during training.") + parser.add_argument("--freeze_exclude_by_name", type=str, default="", + help="Exclude some modules with the specific string in the name when freezing parameters.") vision_opts(parser) audio_opts(parser) diff --git a/tencentpretrain/trainer.py b/tencentpretrain/trainer.py index f95c69df..2aeaac7c 100755 --- a/tencentpretrain/trainer.py +++ b/tencentpretrain/trainer.py @@ -63,9 +63,7 @@ def init_model(args): for shard_file in shard_filenames: model_for_training = _load_state_dict_into_model(model_for_training, shard_file, "") else: - args.logger.info("loading: {}".format(args.pretrained_model_path)) model_for_training = _load_state_dict_into_model(model_for_training, args.pretrained_model_path, "") - args.logger.info("loaded: {}".format(args.pretrained_model_path)) if args.lora_pretrained_model_path is not None: model_for_training = _load_state_dict_into_model(model_for_training, args.lora_pretrained_model_path, "") elif args.deepspeed and args.use_mp: @@ -95,12 +93,12 @@ def init_model(args): model_for_dataloader = build_vqgan_model(args) else: model_for_dataloader = None - # add for llava - if args.vit_model_path is not None: - args.logger.info("loading: {}".format(args.vit_model_path)) - # model_for_training = 
_load_state_dict_into_model(model_for_training, args.vit_model_path, "embedding.image_text.vision_") - keys_info = model_for_training.load_state_dict(torch.load(args.vit_model_path, map_location="cpu"), strict=False) - args.logger.info("loaded: {}".format(args.vit_model_path)) + + if args.vision_model_path is not None: + args.logger.info("loading: {}".format(args.vision_model_path)) + # model_for_training = _load_state_dict_into_model(model_for_training, args.vision_model_path, "embedding.image_text.vision_") + keys_info = model_for_training.load_state_dict(torch.load(args.vision_model_path, map_location="cpu"), strict=False) + args.logger.info("loaded: {}".format(args.vision_model_path)) args.logger.info("missing_keys: {0}".format(keys_info.missing_keys)) args.logger.info("unexpected_keys: {0}".format(keys_info.unexpected_keys)) return model_for_training, model_for_dataloader @@ -697,14 +695,8 @@ def worker(local_rank, gpu_ranks, args): # Build model. model_for_training, model_for_dataloader = init_model(args) - # add for llava if global_rank == 0: args.logger.info("model: {}".format(model_for_training)) - # for name, param in model_for_training.named_parameters(): - # args.logger.info("name: {}, requires_grad: {}".format(name, param.requires_grad)) - # import pdb - # pdb.set_trace() - # add for llava end # Build optimizer. custom_optimizer, custom_scheduler, optimizer_grouped_parameters = init_optimizer(args, model_for_training) diff --git a/tencentpretrain/utils/dataloader.py b/tencentpretrain/utils/dataloader.py index 376e72d1..e50a734d 100755 --- a/tencentpretrain/utils/dataloader.py +++ b/tencentpretrain/utils/dataloader.py @@ -971,7 +971,7 @@ def __iter__(self): """ instances: ((src, tgt), (seg_src, seg_tgt), (src_image, image_pos)) src, tgt: Tokens of the text sample - seg_src, seg_tgt: Segment of text sample + seg_src_nums, seg_tgt_nums: Number of the segment information of text sample src_image: Path of the image sample image_pos: Position of the image in the text sample @@ -987,6 +987,7 @@ def __iter__(self): """ from torchvision.io import read_image from torchvision.io.image import ImageReadMode + seg_num = (self.image_height // self.patch_size) * (self.image_width // self.patch_size) + 1 while True: while self._empty(): self._fill_buf() @@ -1006,17 +1007,21 @@ def __iter__(self): image_pos = [] for ins in instances: ins_src, ins_tgt = ins[0] - ins_seg_src, ins_seg_tgt = ins[1] + ins_seg_nums_src, ins_seg_nums_tgt = ins[1] ins_src_image, ins_image_pos = ins[2] src_text.append(ins_src) tgt.append(ins_tgt) + ins_seg_src = [1] * ins_seg_nums_src[0] + [0] * ins_seg_nums_src[1] + ins_seg_tgt = [] + for i, num in enumerate(ins_seg_nums_tgt): + ins_seg_tgt = ins_seg_tgt + [i % 2] * num seg_text.append(ins_seg_src) seg_tgt.append(ins_seg_tgt) image = read_image(ins_src_image, ImageReadMode.RGB) image = image.cuda(self.local_rank) src_image.append(self.transform(image)) - seg_image.append([1] * ((self.image_height // self.patch_size) * (self.image_width // self.patch_size) + 1)) + seg_image.append([1] * seg_num) image_pos.append(ins_image_pos) yield torch.LongTensor(src_text), \ diff --git a/tencentpretrain/utils/dataset.py b/tencentpretrain/utils/dataset.py index 89786e5a..395662ef 100755 --- a/tencentpretrain/utils/dataset.py +++ b/tencentpretrain/utils/dataset.py @@ -57,6 +57,7 @@ def __init__(self, args, vocab, tokenizer): self.span_max_length = args.span_max_length self.docs_buffer_size = args.docs_buffer_size self.dup_factor = args.dup_factor + self.args = args def 
build_and_save(self, workers_num): """ @@ -972,23 +973,22 @@ def worker(self, proc_id, start, end): class FileWithTextJsonlDataset(Dataset): def worker(self, proc_id, start, end): import json - num_image_tokens = 577 #int(args.image_width / args.patch_size) * int(args.image_height / args.patch_size) + 1 # 336/14-14 --> 576 dim - seq_text = self.seq_length - num_image_tokens # 576 + num_image_tokens = int(self.args.image_width / self.args.patch_size) * int(self.args.image_height / self.args.patch_size) + 1 + seq_text = self.seq_length - num_image_tokens PAD_ID = self.tokenizer.convert_tokens_to_ids([PAD_TOKEN])[0] - prompt_template = { - "llama2": "<>\nYou are a helpful language and vision assistant. You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language.\n<>\n\n", - "vicuna": "<>\nA chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.\n<>\n\n" + instruction_template = { + "sys1": "<>\nYou are a helpful language and vision assistant. You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language.\n<>\n\n", + "sys2": "<>\nA chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.\n<>\n\n" } - # self.prompt_template 待添加 - # try: - # if self.prompt_template == "llama2": - # prompt_overall = prompt_template["llama2"] - # elif self.prompt_template == "vicuna": - # prompt_overall = prompt_template["vicuna"] - # except: - # print("unsupported prompt template!") - # NotImplementedError - prompt_overall = prompt_template["llama2"] + try: + if self.args.instruction_template == "sys1": + instruction_overall = instruction_template["sys1"] + elif self.args.instruction_template == "sys2": + instruction_overall = instruction_template["sys2"] + except: + print("unsupported instruction template!") + NotImplementedError + print("Worker %d is building dataset ... 
" % proc_id) set_seed(self.seed) dataset_writer = open("dataset-tmp-" + str(proc_id) + ".pt", "wb") @@ -1010,7 +1010,8 @@ def worker(self, proc_id, start, end): continue conversations = line["conversations"] - prompt_before_image = prompt_overall + " USER:" + prompt_before_image = instruction_overall + " USER:" + text_combine_seg_nums, tgt_seg_nums = [], [] for i, conv in enumerate(conversations): if i == 0: prompt = conv["value"] @@ -1020,60 +1021,53 @@ def worker(self, proc_id, start, end): elif prompt.startswith(""): prompt_before_image = prompt_before_image + "" prompt_after_image = prompt.replace("","") + "\nASSISTANT:" - prompt_before_image_id = self.tokenizer.convert_tokens_to_ids( - self.tokenizer.tokenize(prompt_before_image) - ) - prompt_after_image_id = self.tokenizer.convert_tokens_to_ids( - self.tokenizer.tokenize(prompt_after_image) - ) + prompt_before_image_id = self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(prompt_before_image)) + prompt_after_image_id = self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(prompt_after_image)) seg_before_image = [1] * len(prompt_before_image_id) seg_after_image = [1] * len(prompt_after_image_id) if len(prompt_before_image_id) + len(prompt_after_image_id) > seq_text: print("promt too long, jump for now") continue text_combine_id = prompt_before_image_id + prompt_after_image_id - text_combine_seg = seg_before_image + seg_after_image tgt_id = [PAD_ID] * (len(text_combine_id) + num_image_tokens - 1) - tgt_seg = [0] * len(tgt_id) + tgt_seg_nums = [len(tgt_id)] elif i % 2 == 0: # human prompt = conv["value"] - prompt_id = self.tokenizer.convert_tokens_to_ids( - self.tokenizer.tokenize(" USER:" + prompt + "\nASSISTANT:") - ) + prompt_id = self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(" USER:" + prompt + "\nASSISTANT:")) text_combine_id = text_combine_id + prompt_id - text_combine_seg = text_combine_seg + [1] * len(prompt_id) tgt_id = tgt_id + [PAD_ID] * len(prompt_id) - tgt_seg = tgt_seg + [0] * len(prompt_id) + if len(tgt_seg_nums) == 1: + tgt_seg_nums[0] = tgt_seg_nums[0] + len(prompt_id) + else: + tgt_seg_nums = tgt_seg_nums + [len(prompt_id)] else: # gpt answer = conv["value"] - answer_id = self.tokenizer.convert_tokens_to_ids( - self.tokenizer.tokenize(answer) + [SEP_TOKEN] - ) + answer_id = self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(answer) + [SEP_TOKEN]) text_combine_id = text_combine_id + answer_id - text_combine_seg = text_combine_seg + [1] * len(answer_id) tgt_id = tgt_id + answer_id - tgt_seg = tgt_seg + [1] * len(answer_id) + tgt_seg_nums = tgt_seg_nums + [len(answer_id)] if len(tgt_id) > self.seq_length: tgt_id = tgt_id[:self.seq_length] - tgt_seg = tgt_seg[:self.seq_length] pad_num = self.seq_length - len(tgt_id) tgt_id = tgt_id + [PAD_ID] * pad_num - tgt_seg = tgt_seg + [0] * pad_num + while sum(tgt_seg_nums) > self.seq_length: + tgt_seg_nums = tgt_seg_nums[:-1] + pad_num = self.seq_length - sum(tgt_seg_nums) + tgt_seg_nums = tgt_seg_nums + [pad_num] if len(text_combine_id) > seq_text : text_combine_id = text_combine_id[:seq_text] - text_combine_seg = text_combine_seg[:seq_text] + pad_num = seq_text - len(text_combine_id) + text_combine_seg_nums = [len(text_combine_id), pad_num] text_combine_id = text_combine_id + [PAD_ID] * pad_num - text_combine_seg = text_combine_seg + [0] * pad_num image_pos = len(prompt_before_image_id) src = (text_combine_id, tgt_id) - seg = (text_combine_seg, tgt_seg) + seg_nums = (text_combine_seg_nums, tgt_seg_nums) image = (path, image_pos) - - 
pickle.dump((src, seg, image), dataset_writer) + pickle.dump((src, seg_nums, image), dataset_writer) if pos >= end: break From d1ca0ac04cac7dbac763b521711198a12b0dc1fb Mon Sep 17 00:00:00 2001 From: janinezhao Date: Thu, 21 Dec 2023 15:23:52 +0800 Subject: [PATCH 03/18] dataset support jsonl -> json --- tencentpretrain/utils/dataset.py | 20 +++++++++----------- tencentpretrain/utils/misc.py | 5 +++++ 2 files changed, 14 insertions(+), 11 deletions(-) diff --git a/tencentpretrain/utils/dataset.py b/tencentpretrain/utils/dataset.py index 395662ef..da25e113 100755 --- a/tencentpretrain/utils/dataset.py +++ b/tencentpretrain/utils/dataset.py @@ -970,7 +970,7 @@ def worker(self, proc_id, start, end): dataset_writer.close() -class FileWithTextJsonlDataset(Dataset): +class FileWithTextJsonDataset(Dataset): def worker(self, proc_id, start, end): import json num_image_tokens = int(self.args.image_width / self.args.patch_size) * int(self.args.image_height / self.args.patch_size) + 1 @@ -992,23 +992,21 @@ def worker(self, proc_id, start, end): print("Worker %d is building dataset ... " % proc_id) set_seed(self.seed) dataset_writer = open("dataset-tmp-" + str(proc_id) + ".pt", "wb") - pos = 0 - skip_line = 0 + pos = start + skip_item = 0 with open(self.corpus_path, mode="r", encoding="utf-8") as f: - while pos < start: - f.readline() - pos += 1 + datas = json.load(f) while True: - line = json.loads(f.readline()) + item = datas[pos] pos += 1 try: - path = line["image"] + path = item["image"] if not os.path.isfile(path): continue except: - skip_line += 1 + skip_item += 1 continue - conversations = line["conversations"] + conversations = item["conversations"] prompt_before_image = instruction_overall + " USER:" text_combine_seg_nums, tgt_seg_nums = [], [] @@ -1191,5 +1189,5 @@ def worker(self, proc_id, start, end): dataset_writer.close() -class LlavaDataset(FileWithTextJsonlDataset): +class LlavaDataset(FileWithTextJsonDataset): pass diff --git a/tencentpretrain/utils/misc.py b/tencentpretrain/utils/misc.py index 01545650..9cf56f06 100644 --- a/tencentpretrain/utils/misc.py +++ b/tencentpretrain/utils/misc.py @@ -4,6 +4,11 @@ def count_lines(file_path): lines_num = 0 + if file_path.endswith(".json"): + import json + with open(file_path, 'rb') as f: + data = json.load(f) + return len(data) with open(file_path, 'rb') as f: while True: data = f.read(2 ** 20) From b0ab452f6f42032cf033f27dc8eca8ce67f44b8a Mon Sep 17 00:00:00 2001 From: janinezhao Date: Fri, 22 Dec 2023 15:18:15 +0800 Subject: [PATCH 04/18] fix model loader --- tencentpretrain/model_loader.py | 9 ++++++++- tencentpretrain/trainer.py | 7 ++----- 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/tencentpretrain/model_loader.py b/tencentpretrain/model_loader.py index a8e0b175..00b17bd0 100644 --- a/tencentpretrain/model_loader.py +++ b/tencentpretrain/model_loader.py @@ -1,5 +1,6 @@ import os import torch +import collections from tencentpretrain import mpu @@ -18,7 +19,7 @@ def load_model(model, model_path, lora_pretrained_model_path=None): return model -def _load_state_dict_into_model(model_to_load, model_path, start_prefix=""): +def _load_state_dict_into_model(model_to_load, model_path, start_prefix="", missing_prefix=""): # Convert old format to new format if needed from a PyTorch state_dict # copy state_dict so _load_from_state_dict can modify it @@ -53,6 +54,12 @@ def load(module, state_dict, prefix=""): for name, child in module._modules.items(): if child is not None: load(child, state_dict, prefix + name + ".") + if 
missing_prefix != "": + state_dict_withprefix = collections.OrderedDict() + for k in state_dict.keys(): + state_dict_withprefix[missing_prefix + k] = state_dict[k] + del state_dict + state_dict = state_dict_withprefix load(model_to_load, state_dict, prefix=start_prefix) # Delete `state_dict` so it could be collected by GC earlier. Note that `state_dict` is a copy of the argument, so diff --git a/tencentpretrain/trainer.py b/tencentpretrain/trainer.py index 2aeaac7c..9a2dc5e0 100755 --- a/tencentpretrain/trainer.py +++ b/tencentpretrain/trainer.py @@ -96,11 +96,8 @@ def init_model(args): if args.vision_model_path is not None: args.logger.info("loading: {}".format(args.vision_model_path)) - # model_for_training = _load_state_dict_into_model(model_for_training, args.vision_model_path, "embedding.image_text.vision_") - keys_info = model_for_training.load_state_dict(torch.load(args.vision_model_path, map_location="cpu"), strict=False) - args.logger.info("loaded: {}".format(args.vision_model_path)) - args.logger.info("missing_keys: {0}".format(keys_info.missing_keys)) - args.logger.info("unexpected_keys: {0}".format(keys_info.unexpected_keys)) + model_for_training = _load_state_dict_into_model(model_for_training, args.vision_model_path, missing_prefix=args.vision_model_missing_prefix) + # model_for_training = load_model(model_for_training, args.vision_model_path) return model_for_training, model_for_dataloader From 2f8fcb4029aa11acba4dddab5132d5e5f159dac5 Mon Sep 17 00:00:00 2001 From: janinezhao Date: Mon, 25 Dec 2023 09:56:54 +0800 Subject: [PATCH 05/18] fix model laoder --- ...peed.py => generate_lm_llava_deepspeed.py} | 44 ++++++++----------- tencentpretrain/model_loader.py | 13 ++++-- 2 files changed, 28 insertions(+), 29 deletions(-) rename scripts/{generate_llava_json_deepspeed.py => generate_lm_llava_deepspeed.py} (86%) diff --git a/scripts/generate_llava_json_deepspeed.py b/scripts/generate_lm_llava_deepspeed.py similarity index 86% rename from scripts/generate_llava_json_deepspeed.py rename to scripts/generate_lm_llava_deepspeed.py index 80ae927f..f54ad519 100755 --- a/scripts/generate_llava_json_deepspeed.py +++ b/scripts/generate_lm_llava_deepspeed.py @@ -27,7 +27,7 @@ from tencentpretrain.opts import infer_opts, tokenizer_opts, log_opts, mp_opts from tencentpretrain.opts import deepspeed_opts from tencentpretrain.utils.logging import init_logger -from tencentpretrain.model_loader import _load_state_dict_into_model +from tencentpretrain.model_loader import _load_state_dict_into_model, load_model from tencentpretrain.utils.misc import pooling, ZeroOneNormalize @@ -96,12 +96,9 @@ def load_or_initialize_parameters(args, model): keys_info = model.load_state_dict(torch.load(args.pretrained_model_path, map_location="cpu"), strict=False) args.logger.info("missing_keys: {0}".format(keys_info.missing_keys)) args.logger.info("unexpected_keys: {0}".format(keys_info.unexpected_keys)) - if args.vit_model_path is not None: - args.logger.info("loading model from {0}".format(args.vit_model_path)) - # model = _load_state_dict_into_model(model, args.vit_model_path, "embedding.image_text.vision_") - keys_info = model.load_state_dict(torch.load(args.vit_model_path, map_location="cpu"), strict=False) - args.logger.info("missing_keys: {0}".format(keys_info.missing_keys)) - args.logger.info("unexpected_keys: {0}".format(keys_info.unexpected_keys)) + if args.vision_model_path is not None: + args.logger.info("loading model from {0}".format(args.vision_model_path)) + model = load_model(model, 
args.vision_model_path, missing_prefix="embedding.image_text.vision_") else: # Initialize with normal distribution. for n, p in list(model.named_parameters()): @@ -117,12 +114,10 @@ def load_or_initialize_parameters(args, model): parser.add_argument("--top_k", type=int, default=70) parser.add_argument("--top_p", type=float, default=0) parser.add_argument("--temperature", type=float, default=1.0) - parser.add_argument("--vit_model_path", type=str, default=None, - help="Pretrained model of Vit.") - parser.add_argument("--connector_model_path", type=str, default=None, - help="Pretrained model of Connector.") - parser.add_argument("--prompt_template", type=str, choices=["llama2", "vicuna"], - help="give the llm type to choose a prompt", default="llama2") + parser.add_argument("--vision_model_path", type=str, default=None, + help="Pretrained vision model.") + parser.add_argument("--instruction_template", type=str, choices=["sys1", "sys2"], + help="The instruction type for training large language-vision model.", default="sys3") tokenizer_opts(parser) @@ -147,15 +142,15 @@ def load_or_initialize_parameters(args, model): # Load or initialize parameters. if args.enable_zero3: + print("enable_zero3:", args.enable_zero3) with deepspeed.zero.Init(config_dict_or_path=args.deepspeed_config): model = LLaVaGenerate(args) if args.pretrained_model_path: model = _load_state_dict_into_model(model, args.pretrained_model_path) - if args.vit_model_path is not None: - model = _load_state_dict_into_model(model, args.vit_model_path, "embedding.image_text.vision_") + if args.vision_model_path is not None: + model = _load_state_dict_into_model(model, args.vision_model_path, missing_prefix="embedding.image_text.vision_") else: model = LLaVaGenerate(args) - # Load or initialize parameters. load_or_initialize_parameters(args, model) deepspeed.init_distributed() @@ -170,18 +165,17 @@ def load_or_initialize_parameters(args, model): ZeroOneNormalize() ]) prompt_template = { - "llama2": "<>\nYou are a helpful language and vision assistant. You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language.\n<>\n\n", - "vicuna": "<>\nA chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.\n<>\n\n"} - num_image_tokens = int(args.image_width / args.patch_size) * int(args.image_height / args.patch_size) + 1 # 336/14-14 --> 576 dim - seq_text = args.seq_length - num_image_tokens # 576 + "sys1": "<>\nYou are a helpful language and vision assistant. You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language.\n<>\n\n", + "sys2": "<>\nA chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.\n<>\n\n", + "sys3": " You are a helpful language and vision assistant. 
\n" + } + num_image_tokens = int(args.image_width / args.patch_size) * int(args.image_height / args.patch_size) + 1 # 336/14-14 --> 576 + 1 dim + seq_text = args.seq_length - num_image_tokens outf = open(args.prediction_path, mode="w", encoding="utf-8") input_f = open(args.test_path, mode="r", encoding="utf-8") datas = json.load(input_f) try: - if args.prompt_template == "llama2": - prompt_overall = prompt_template["llama2"] - elif args.prompt_template == "vicuna": - prompt_overall = prompt_template["vicuna"] + prompt_overall = prompt_template[args.instruction_template] except: args.logger.info("unsupported prompt template!") NotImplementedError @@ -283,12 +277,10 @@ def load_or_initialize_parameters(args, model): text_tensor = torch.cat([text_tensor, next_token.view(1, 1)], dim=1) text_seg_tensor = torch.cat([text_seg_tensor, torch.tensor([[1]]).to(device)], dim=1) - # print("next_token:", next_token) if next_token.cpu().tolist() == SEP_ID: break if rank == 0 and text_tensor is not None: - # outf.write("\t".join(line)+"\n") tokens = [token_id.item() for token_id in text_tensor[0]] if args.tokenizer.sp_model is not None: generated_sentence = args.tokenizer.sp_model.decode(tokens) diff --git a/tencentpretrain/model_loader.py b/tencentpretrain/model_loader.py index 00b17bd0..29db8ef0 100644 --- a/tencentpretrain/model_loader.py +++ b/tencentpretrain/model_loader.py @@ -4,16 +4,23 @@ from tencentpretrain import mpu -def load_model(model, model_path, lora_pretrained_model_path=None): +def load_model(model, model_path, lora_pretrained_model_path=None, missing_prefix=""): """ Load model from saved weights. """ + state_dict = torch.load(model_path, map_location="cpu") + if missing_prefix != "": + state_dict_withprefix = collections.OrderedDict() + for k in state_dict.keys(): + state_dict_withprefix[missing_prefix + k] = state_dict[k] + del state_dict + state_dict = state_dict_withprefix if hasattr(model, "module"): - model.module.load_state_dict(torch.load(model_path, map_location="cpu"), strict=False) + model.module.load_state_dict(state_dict, strict=False) if lora_pretrained_model_path is not None: model.module.load_state_dict(torch.load(lora_pretrained_model_path, map_location="cpu"), strict=False) else: - model.load_state_dict(torch.load(model_path, map_location="cpu"), strict=False) + model.load_state_dict(state_dict, strict=False) if lora_pretrained_model_path is not None: model.load_state_dict(torch.load(lora_pretrained_model_path, map_location="cpu"), strict=False) return model From 9a3bf31e919e058de85f1498424ca33fdb253c24 Mon Sep 17 00:00:00 2001 From: janinezhao Date: Tue, 26 Dec 2023 17:18:47 +0800 Subject: [PATCH 06/18] fix image read --- tencentpretrain/utils/dataloader.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/tencentpretrain/utils/dataloader.py b/tencentpretrain/utils/dataloader.py index e50a734d..e2e9bba5 100755 --- a/tencentpretrain/utils/dataloader.py +++ b/tencentpretrain/utils/dataloader.py @@ -1018,12 +1018,17 @@ def __iter__(self): ins_seg_tgt = ins_seg_tgt + [i % 2] * num seg_text.append(ins_seg_src) seg_tgt.append(ins_seg_tgt) - image = read_image(ins_src_image, ImageReadMode.RGB) + try: + image = read_image(ins_src_image, ImageReadMode.RGB) + except: + print("Something is wrong when reading {}, just skipped!".format(ins_src_image)) + continue image = image.cuda(self.local_rank) src_image.append(self.transform(image)) seg_image.append([1] * seg_num) image_pos.append(ins_image_pos) - + if len(src_image) == 0: + continue yield 
torch.LongTensor(src_text), \ torch.stack(src_image, 0).half(), \ torch.LongTensor(tgt), \ From c0ea9366b98313e9ed06ddc9ae95a5040675c106 Mon Sep 17 00:00:00 2001 From: janinezhao Date: Wed, 3 Jan 2024 15:45:38 +0800 Subject: [PATCH 07/18] fix vision preprocess --- models/llava/7b_config.json | 17 +- preprocess.py | 9 +- pretrain.py | 6 +- ...ava_from_huggingface_to_tencentpretrain.py | 10 +- scripts/generate_lm_llava_deepspeed.py | 81 ++++---- tencentpretrain/embeddings/__init__.py | 6 +- ...edding.py => vision_language_embedding.py} | 12 +- tencentpretrain/opts.py | 2 +- tencentpretrain/trainer.py | 8 +- tencentpretrain/utils/dataloader.py | 3 + tencentpretrain/utils/dataset.py | 194 ++++++++---------- 11 files changed, 172 insertions(+), 176 deletions(-) rename tencentpretrain/embeddings/{image_text_embedding.py => vision_language_embedding.py} (90%) diff --git a/models/llava/7b_config.json b/models/llava/7b_config.json index 1ddf3729..82622ece 100755 --- a/models/llava/7b_config.json +++ b/models/llava/7b_config.json @@ -1,7 +1,10 @@ { - "embedding": ["image_text"], - "image_text_emb":{ + "embedding": ["vision_language"], + "vision_language_emb":{ "vision_encoder":{ + "image_height": 336, + "image_width": 336, + "patch_size": 14, "emb_size": 1024, "feedforward_size": 4096, "hidden_size": 1024, @@ -19,7 +22,8 @@ "feed_forward": "dense", "mask": "fully_visible", "layernorm_positioning": "pre", - "layernorm":"normal" + "layernorm":"normal", + "has_cls": false }, "projection":{ "mlp_hidden_size": 4096, @@ -29,10 +33,6 @@ "embedding": ["word"] } }, - "image_height": 336, - "image_width": 336, - "patch_size": 14, - "remove_embedding_combine_layernorm": true, "emb_size": 4096, "feedforward_size": 11008, "hidden_size": 4096, @@ -50,5 +50,6 @@ "mask": "causal", "layernorm_positioning": "pre", "layernorm": "rms", + "layernorm_eps": 1e-5, "target": ["lm"] - } + } \ No newline at end of file diff --git a/preprocess.py b/preprocess.py index dd7ec4a3..8c0f1a00 100644 --- a/preprocess.py +++ b/preprocess.py @@ -34,6 +34,8 @@ def main(): help="The buffer size of documents in memory, specific to targets that require negative sampling.") parser.add_argument("--seq_length", type=int, default=128, help="Sequence length of instances.") parser.add_argument("--tgt_seq_length", type=int, default=128, help="Target sequence length of instances.") + parser.add_argument("--vision_seq_length_in_VL", type=int, default=576, + help="Number of image patches in the vision language model(LLaVa).") parser.add_argument("--dup_factor", type=int, default=5, help="Duplicate instances multiple times.") parser.add_argument("--short_seq_prob", type=float, default=0.1, @@ -41,13 +43,6 @@ def main(): "The larger value, the higher probability of using short (truncated) sequence.") parser.add_argument("--full_sentences", action="store_true", help="Full sentences.") parser.add_argument("--seed", type=int, default=7, help="Random seed.") - parser.add_argument("--instruction_template", choices=["sys1", "sys2"], default="sys1", - help="The instruction type for training large language-vision model.") - - # Image options. - parser.add_argument("--image_width", type=int, default=336, help="The width of the input images.") - parser.add_argument("--image_height", type=int, default=336, help="The height of the input images.") - parser.add_argument("--patch_size", type=int, default=14, help="The patch size for the input images.") # Masking options. 
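# A quick arithmetic check on the new --vision_seq_length_in_VL default above
# (illustrative only): with the CLIP ViT-L/14 setup used elsewhere in this patch
# (336x336 input, 14x14 patches), the vision branch contributes
# (336 // 14) * (336 // 14) = 24 * 24 = 576 patch tokens; the encoder's extra
# [CLS] position appears to be added and dropped inside the vision-language
# embedding, so it is not counted here.
assert (336 // 14) * (336 // 14) == 576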
parser.add_argument("--dynamic_masking", action="store_true", help="Dynamic masking.") diff --git a/pretrain.py b/pretrain.py index f0dc180d..e5cd5cfe 100644 --- a/pretrain.py +++ b/pretrain.py @@ -13,8 +13,8 @@ def main(): help="Path of the preprocessed dataset.") parser.add_argument("--pretrained_model_path", type=str, default=None, help="Path of the pretrained model.") - parser.add_argument("--vision_model_path", type=str, default=None, - help="Path of the vision pretrained model.") + parser.add_argument("--vision_model_in_VL_emb_path", type=str, default=None, + help="Path of the vision pretrained model in the vision language embedding.") parser.add_argument("--output_model_path", type=str, required=True, help="Path of the output model.") parser.add_argument("--config_path", type=str, default="models/bert/base_config.json", @@ -44,6 +44,8 @@ def main(): # Model options. model_opts(parser) + parser.add_argument("--vision_model_missing_prefix", type=str, required=False, default="embedding.vision_language.vision_", + help="Extra prefix when loading the vision pretrained model as the embedding of the whole model.") # Model parallelism options. mp_opts(parser) diff --git a/scripts/convert_llava_from_huggingface_to_tencentpretrain.py b/scripts/convert_llava_from_huggingface_to_tencentpretrain.py index a6c75292..606a0f9c 100755 --- a/scripts/convert_llava_from_huggingface_to_tencentpretrain.py +++ b/scripts/convert_llava_from_huggingface_to_tencentpretrain.py @@ -39,7 +39,7 @@ def get_weight_from_name(layer_name): def unpermute(w): return w.reshape(n_heads, 2, dim // n_heads // 2, dim).transpose(2, 1).reshape(dim, dim) -output_model["embedding.image_text.text_embedding.word.embedding.weight"] = get_weight_from_name("model.embed_tokens.weight") +output_model["embedding.vision_language.text_embedding.word.embedding.weight"] = get_weight_from_name("model.embed_tokens.weight") for i in range(layers_num): @@ -69,9 +69,9 @@ def unpermute(w): output_model["encoder.layer_norm.weight"] = get_weight_from_name("model.norm.weight") output_model["target.lm.output_layer.weight"] = get_weight_from_name("lm_head.weight") -output_model["embedding.image_text.projection.0.weight"] = get_weight_from_name("model.mm_projector.0.weight") -output_model["embedding.image_text.projection.0.bias"] = get_weight_from_name("model.mm_projector.0.bias") -output_model["embedding.image_text.projection.2.weight"] = get_weight_from_name("model.mm_projector.2.weight") -output_model["embedding.image_text.projection.2.bias"] = get_weight_from_name("model.mm_projector.2.bias") +output_model["embedding.vision_language.projection.0.weight"] = get_weight_from_name("model.mm_projector.0.weight") +output_model["embedding.vision_language.projection.0.bias"] = get_weight_from_name("model.mm_projector.0.bias") +output_model["embedding.vision_language.projection.2.weight"] = get_weight_from_name("model.mm_projector.2.weight") +output_model["embedding.vision_language.projection.2.bias"] = get_weight_from_name("model.mm_projector.2.bias") torch.save(output_model, args.output_model_path) diff --git a/scripts/generate_lm_llava_deepspeed.py b/scripts/generate_lm_llava_deepspeed.py index f54ad519..72d2c83d 100755 --- a/scripts/generate_lm_llava_deepspeed.py +++ b/scripts/generate_lm_llava_deepspeed.py @@ -47,8 +47,6 @@ def __init__(self, args): self.target.update(LmTarget(args, len(args.tokenizer.vocab)), "lm") print("tokenizer vocab nums:", len(args.tokenizer.vocab)) - self.num_image_tokens = int(args.image_width / args.patch_size) * 
int(args.image_height / args.patch_size) - def forward(self, src_text, seg_text, src_image, seg_image, image_pos): """ Args: @@ -59,7 +57,7 @@ def forward(self, src_text, seg_text, src_image, seg_image, image_pos): # Embedding. src = src_text, src_image, seg_text, seg_image, image_pos emb = self.embedding(src, None) - seg = torch.cat((seg_image, seg_text), 1) + seg = torch.cat((seg_image[:,1:], seg_text), 1) # encoder output = self.encoder(emb, seg) # # Target. @@ -96,9 +94,9 @@ def load_or_initialize_parameters(args, model): keys_info = model.load_state_dict(torch.load(args.pretrained_model_path, map_location="cpu"), strict=False) args.logger.info("missing_keys: {0}".format(keys_info.missing_keys)) args.logger.info("unexpected_keys: {0}".format(keys_info.unexpected_keys)) - if args.vision_model_path is not None: - args.logger.info("loading model from {0}".format(args.vision_model_path)) - model = load_model(model, args.vision_model_path, missing_prefix="embedding.image_text.vision_") + if args.vision_model_in_VL_emb_path is not None: + args.logger.info("loading model from {0}".format(args.vision_model_in_VL_emb_path)) + model = load_model(model, args.vision_model_in_VL_emb_path, missing_prefix="embedding.vision_language.vision_") else: # Initialize with normal distribution. for n, p in list(model.named_parameters()): @@ -114,10 +112,10 @@ def load_or_initialize_parameters(args, model): parser.add_argument("--top_k", type=int, default=70) parser.add_argument("--top_p", type=float, default=0) parser.add_argument("--temperature", type=float, default=1.0) - parser.add_argument("--vision_model_path", type=str, default=None, - help="Pretrained vision model.") - parser.add_argument("--instruction_template", type=str, choices=["sys1", "sys2"], - help="The instruction type for training large language-vision model.", default="sys3") + parser.add_argument("--vision_model_in_VL_emb_path", type=str, default=None, + help="Path of the vision pretrained model in the vision language embedding.") + parser.add_argument("--instruction_template", type=str, choices=["sys0", "sys1", "sys2", "sys3", "sys4"], + help="The instruction type for training large language-vision model.", default="sys0") tokenizer_opts(parser) @@ -147,8 +145,8 @@ def load_or_initialize_parameters(args, model): model = LLaVaGenerate(args) if args.pretrained_model_path: model = _load_state_dict_into_model(model, args.pretrained_model_path) - if args.vision_model_path is not None: - model = _load_state_dict_into_model(model, args.vision_model_path, missing_prefix="embedding.image_text.vision_") + if args.vision_model_in_VL_emb_path is not None: + model = _load_state_dict_into_model(model, args.vision_model_in_VL_emb_path, missing_prefix="embedding.vision_language.vision_") else: model = LLaVaGenerate(args) load_or_initialize_parameters(args, model) @@ -160,16 +158,24 @@ def load_or_initialize_parameters(args, model): device = torch.device("cuda" if torch.cuda.is_available() else "cpu") model.eval() + image_height = args.vision_language_emb["vision_encoder"]["image_height"] + image_width = args.vision_language_emb["vision_encoder"]["image_width"] + patch_size = args.vision_language_emb["vision_encoder"]["patch_size"] + transform = transforms.Compose([ - transforms.Resize((args.image_height, args.image_width)), - ZeroOneNormalize() + transforms.Resize(min(image_height, image_width)), + transforms.CenterCrop((image_height, image_width)), + ZeroOneNormalize(), + transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 
0.27577711)) ]) prompt_template = { + "sys0": "", "sys1": "<>\nYou are a helpful language and vision assistant. You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language.\n<>\n\n", "sys2": "<>\nA chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.\n<>\n\n", - "sys3": " You are a helpful language and vision assistant. \n" + "sys3": " You are a helpful language and vision assistant. \n", + "sys4": "[INST]<>\nYou are a helpful language and vision assistant. You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language.\n<>\n\n" } - num_image_tokens = int(args.image_width / args.patch_size) * int(args.image_height / args.patch_size) + 1 # 336/14-14 --> 576 + 1 dim + num_image_tokens = int(image_width / patch_size) * int(image_height / patch_size) + 1 # 336/14-14 --> 576 dim + 1 seq_text = args.seq_length - num_image_tokens outf = open(args.prediction_path, mode="w", encoding="utf-8") input_f = open(args.test_path, mode="r", encoding="utf-8") @@ -194,9 +200,9 @@ def load_or_initialize_parameters(args, model): print("sth wrong with item{}".format(item)) continue - prompt_before_image = prompt_overall + " USER:" + prompt_before_image = prompt_overall + " USER: " ground_truth = [] - text_combine_id = [] + prompt_answer_id = [] if "conversations" in item: conversations = item["conversations"] for i, conv in enumerate(conversations): @@ -207,10 +213,14 @@ def load_or_initialize_parameters(args, model): prompt = conv["value"] if prompt.endswith(""): prompt_before_image = prompt_before_image + prompt.replace("","") - prompt_after_image = "\nASSISTANT:" + prompt_after_image = "\nASSISTANT:" elif prompt.startswith(""): prompt_before_image = prompt_before_image + "" - prompt_after_image = prompt.replace("","") + "\nASSISTANT:" + prompt_after_image = prompt.replace("","") + "\nASSISTANT: " + else: + prompt_before_image = prompt_before_image + "" + prompt_after_image = "\n" + prompt + " ASSISTANT: " + prompt_before_image_id = args.tokenizer.convert_tokens_to_ids( args.tokenizer.tokenize(prompt_before_image) ) @@ -222,16 +232,16 @@ def load_or_initialize_parameters(args, model): if len(prompt_before_image_id) + len(prompt_after_image_id) > seq_text: args.logger.info("promt too long, jump for now") break - text_combine_id = [prompt_before_image_id + prompt_after_image_id] - text_combine_seg = [seg_before_image + seg_after_image] + prompt_answer_id = [prompt_before_image_id + prompt_after_image_id] + prompt_answer_seg = [seg_before_image + seg_after_image] elif i % 2 == 0: # human prompt = conv["value"] prompt_id = args.tokenizer.convert_tokens_to_ids( - args.tokenizer.tokenize(" USER:" + prompt + "\nASSISTANT:") + args.tokenizer.tokenize(" USER: " + prompt + " ASSISTANT: ") ) - if text_combine_id: - text_combine_id.append(prompt_id) - text_combine_seg.append(text_combine_seg + [1] * len(prompt_id)) + if prompt_answer_id: + prompt_answer_id.append(prompt_id) + prompt_answer_seg.append(prompt_answer_seg + [1] * len(prompt_id)) else: args.logger.info("no prompt, or prompt too long, jumping") break @@ -239,8 +249,8 @@ def load_or_initialize_parameters(args, model): ground_truth.append(conv["value"]) else: prompt = item["instruction"] - prompt_before_image = prompt_before_image + prompt + "" - prompt_after_image = "\nASSISTANT:" + prompt_before_image = 
prompt_before_image + "" + prompt_after_image = "\n" + prompt + "\nASSISTANT: " prompt_before_image_id = args.tokenizer.convert_tokens_to_ids( args.tokenizer.tokenize(prompt_before_image) ) @@ -252,8 +262,8 @@ def load_or_initialize_parameters(args, model): if len(prompt_before_image_id) + len(prompt_after_image_id) > seq_text: args.logger.info("promt too long, jump for now") break - text_combine_id = [prompt_before_image_id + prompt_after_image_id] - text_combine_seg = [seg_before_image + seg_after_image] + prompt_answer_id = [prompt_before_image_id + prompt_after_image_id] + prompt_answer_seg = [seg_before_image + seg_after_image] image_pos = len(prompt_before_image_id) @@ -262,14 +272,15 @@ def load_or_initialize_parameters(args, model): image_pos = torch.LongTensor([image_pos]).to(device) SEP_ID = args.tokenizer.convert_tokens_to_ids([SEP_TOKEN]) text_tensor = None - for i, prompt in enumerate(text_combine_id): + + for i, prompt in enumerate(prompt_answer_id): if text_tensor is None: - text_tensor, text_seg_tensor = torch.LongTensor([prompt]).to(device), torch.LongTensor([text_combine_seg[i]]).to(device) + text_tensor, text_seg_tensor = torch.LongTensor([prompt]).to(device), torch.LongTensor([prompt_answer_seg[i]]).to(device) else: text_tensor = torch.cat([text_tensor, torch.LongTensor([prompt]).to(device)], dim=1) - text_seg_tensor = torch.cat([text_seg_tensor, torch.LongTensor([text_combine_seg[i]]).to(device)], dim=1) + text_seg_tensor = torch.cat([text_seg_tensor, torch.LongTensor([prompt_answer_seg[i]]).to(device)], dim=1) - while text_tensor.shape[1] + num_image_tokens < args.seq_length: + while text_tensor.shape[1] + num_image_tokens <= args.seq_length: output = model(text_tensor, text_seg_tensor, image_tensor, image_seg_tensor, image_pos) next_token_logits = output[0][-1] / args.temperature filtered_logits = top_k_top_p_filtering(next_token_logits, args.top_k, args.top_p) @@ -288,4 +299,4 @@ def load_or_initialize_parameters(args, model): generated_sentence = "".join(args.tokenizer.convert_ids_to_tokens(tokens)) print(item) print(generated_sentence) - outf.write(generated_sentence + "\n\n") + print(generated_sentence+ "\n\n", file=outf) diff --git a/tencentpretrain/embeddings/__init__.py b/tencentpretrain/embeddings/__init__.py index cbf6f979..a90ae889 100644 --- a/tencentpretrain/embeddings/__init__.py +++ b/tencentpretrain/embeddings/__init__.py @@ -8,14 +8,14 @@ from tencentpretrain.embeddings.word_patch_embedding import WordPatchEmbedding from tencentpretrain.embeddings.speech_embedding import SpeechEmbedding from tencentpretrain.embeddings.masked_patch_embedding import MaskedPatchEmbedding -from tencentpretrain.embeddings.image_text_embedding import ImageTextEmbedding +from tencentpretrain.embeddings.vision_language_embedding import VisionLanguageEmbedding str2embedding = {"word": WordEmbedding, "pos": PosEmbedding, "seg": SegEmbedding, "sinusoidalpos": SinusoidalposEmbedding, "dual": DualEmbedding, "patch": PatchEmbedding, "word_patch": WordPatchEmbedding, "speech": SpeechEmbedding, - "masked_patch": MaskedPatchEmbedding, "image_text": ImageTextEmbedding} + "masked_patch": MaskedPatchEmbedding, "vision_language": VisionLanguageEmbedding} __all__ = ["Embedding", "WordEmbedding", "PosEmbedding", "SegEmbedding", "SinusoidalposEmbedding", "DualEmbedding", "PatchEmbedding", "WordPatchEmbedding", "SpeechEmbedding", - "MaskedPatchEmbedding", "ImageTextEmbedding", "str2embedding"] + "MaskedPatchEmbedding", "VisionLanguageEmbedding", "str2embedding"] diff --git 
a/tencentpretrain/embeddings/image_text_embedding.py b/tencentpretrain/embeddings/vision_language_embedding.py similarity index 90% rename from tencentpretrain/embeddings/image_text_embedding.py rename to tencentpretrain/embeddings/vision_language_embedding.py index 7ad1e321..9b7c6cd1 100755 --- a/tencentpretrain/embeddings/image_text_embedding.py +++ b/tencentpretrain/embeddings/vision_language_embedding.py @@ -12,15 +12,15 @@ str2embedding = {"word": WordEmbedding, "pos": PosEmbedding, "patch": PatchEmbedding} -class ImageTextEmbedding(nn.Module): +class VisionLanguageEmbedding(nn.Module): ''' an combination of a vision encoder and a text embedding ''' def __init__(self, args, vocab_size): - super(ImageTextEmbedding, self).__init__() + super(VisionLanguageEmbedding, self).__init__() # vision model for vision features vision_encoder_args = copy.deepcopy(vars(args)) - vision_encoder_args.update(args.image_text_emb["vision_encoder"]) + vision_encoder_args.update(args.vision_language_emb["vision_encoder"]) vision_encoder_args = Namespace(**vision_encoder_args) self.vision_embedding = Embedding(vision_encoder_args) for embedding_name in vision_encoder_args.embedding: @@ -30,7 +30,7 @@ def __init__(self, args, vocab_size): # map the output of vision model into the same space as the text features projection_args = copy.deepcopy(vars(args)) - projection_args.update(args.image_text_emb["projection"]) + projection_args.update(args.vision_language_emb["projection"]) projection_args = Namespace(**projection_args) projection_modules = [nn.Linear(vision_encoder_args.emb_size, projection_args.mlp_hidden_size)] for _ in range(1, projection_args.num_mlp_layer): @@ -40,7 +40,7 @@ def __init__(self, args, vocab_size): # text embedding text_args = copy.deepcopy(vars(args)) - text_args.update(args.image_text_emb["text"]) + text_args.update(args.vision_language_emb["text"]) text_args = Namespace(**text_args) self.text_embedding = Embedding(text_args) for embedding_name in text_args.embedding: @@ -52,7 +52,7 @@ def forward(self, src, seg=None): # image features with torch.no_grad(): image_emb = self.vision_embedding(src_image, seg_image) - image_emb = self.vision_encoder(image_emb, seg_image) + image_emb = self.vision_encoder(image_emb, seg_image)[:,1:,:] image_emb = self.projection(image_emb) # text embedding text_emb = self.text_embedding(src_text, seg_text) diff --git a/tencentpretrain/opts.py b/tencentpretrain/opts.py index 6120ba18..7085bcc7 100755 --- a/tencentpretrain/opts.py +++ b/tencentpretrain/opts.py @@ -70,7 +70,7 @@ def vision_opts(parser): parser.add_argument("--channels_num", type=int, default=3, help="Channels num.") parser.add_argument("--image_preprocess", type=str, default=["crop", "normalize"], nargs='+', - help="Preprocess and data augmentation methods. Choices: [\"crop\", \"horizontal_flip\", \"normalize\"]. ") + help="Preprocess and data augmentation methods. Choices: [\"crop\", \"center_crop\", \"horizontal_flip\", \"normalize\"]. 
") def audio_opts(parser): diff --git a/tencentpretrain/trainer.py b/tencentpretrain/trainer.py index 9a2dc5e0..7301239c 100755 --- a/tencentpretrain/trainer.py +++ b/tencentpretrain/trainer.py @@ -94,9 +94,9 @@ def init_model(args): else: model_for_dataloader = None - if args.vision_model_path is not None: - args.logger.info("loading: {}".format(args.vision_model_path)) - model_for_training = _load_state_dict_into_model(model_for_training, args.vision_model_path, missing_prefix=args.vision_model_missing_prefix) + if args.vision_model_in_VL_emb_path is not None: + args.logger.info("loading: {}".format(args.vision_model_in_VL_emb_path)) + model_for_training = _load_state_dict_into_model(model_for_training, args.vision_model_in_VL_emb_path, missing_prefix=args.vision_model_missing_prefix) # model_for_training = load_model(model_for_training, args.vision_model_path) return model_for_training, model_for_dataloader @@ -657,7 +657,7 @@ class LlmSftTrainer(LmTrainer): class LlavaTrainer(LmTrainer): def forward_propagation(self, batch, model): src_text, src_img, tgt, seg_text, seg_img, seg_tgt, image_pos = batch - seg = torch.cat((seg_img, seg_text), 1) + seg = torch.cat((seg_img[:,1:], seg_text), 1) loss = model((src_text, src_img, seg_text, seg_img, image_pos), tgt, seg, tgt_seg=seg_tgt) self.total_loss += loss.item() diff --git a/tencentpretrain/utils/dataloader.py b/tencentpretrain/utils/dataloader.py index e2e9bba5..03a2bcf8 100755 --- a/tencentpretrain/utils/dataloader.py +++ b/tencentpretrain/utils/dataloader.py @@ -553,6 +553,9 @@ def __init__(self, args, dataset_path, batch_size, global_rank, world_size, loca preprocess_pipeline = [] if "corp" in args.image_preprocess: preprocess_pipeline.append(transforms.RandomResizedCrop(max(self.image_height, self.image_width))) + elif "center_crop" in args.image_preprocess: + preprocess_pipeline.append(transforms.Resize(min(self.image_height, self.image_width))) + preprocess_pipeline.append(transforms.CenterCrop((self.image_height, self.image_width))) if "horizontal_flip" in args.image_preprocess: preprocess_pipeline.append(transforms.RandomHorizontalFlip()) preprocess_pipeline.append(transforms.Resize((self.image_height, self.image_width))) diff --git a/tencentpretrain/utils/dataset.py b/tencentpretrain/utils/dataset.py index da25e113..547e73fa 100755 --- a/tencentpretrain/utils/dataset.py +++ b/tencentpretrain/utils/dataset.py @@ -970,109 +970,6 @@ def worker(self, proc_id, start, end): dataset_writer.close() -class FileWithTextJsonDataset(Dataset): - def worker(self, proc_id, start, end): - import json - num_image_tokens = int(self.args.image_width / self.args.patch_size) * int(self.args.image_height / self.args.patch_size) + 1 - seq_text = self.seq_length - num_image_tokens - PAD_ID = self.tokenizer.convert_tokens_to_ids([PAD_TOKEN])[0] - instruction_template = { - "sys1": "<>\nYou are a helpful language and vision assistant. You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language.\n<>\n\n", - "sys2": "<>\nA chat between a curious user and an artificial intelligence assistant. 
The assistant gives helpful, detailed, and polite answers to the user's questions.\n<>\n\n" - } - try: - if self.args.instruction_template == "sys1": - instruction_overall = instruction_template["sys1"] - elif self.args.instruction_template == "sys2": - instruction_overall = instruction_template["sys2"] - except: - print("unsupported instruction template!") - NotImplementedError - - print("Worker %d is building dataset ... " % proc_id) - set_seed(self.seed) - dataset_writer = open("dataset-tmp-" + str(proc_id) + ".pt", "wb") - pos = start - skip_item = 0 - with open(self.corpus_path, mode="r", encoding="utf-8") as f: - datas = json.load(f) - while True: - item = datas[pos] - pos += 1 - try: - path = item["image"] - if not os.path.isfile(path): - continue - except: - skip_item += 1 - continue - conversations = item["conversations"] - - prompt_before_image = instruction_overall + " USER:" - text_combine_seg_nums, tgt_seg_nums = [], [] - for i, conv in enumerate(conversations): - if i == 0: - prompt = conv["value"] - if prompt.endswith(""): - prompt_before_image = prompt_before_image + prompt.replace("","") - prompt_after_image = "\nASSISTANT:" - elif prompt.startswith(""): - prompt_before_image = prompt_before_image + "" - prompt_after_image = prompt.replace("","") + "\nASSISTANT:" - prompt_before_image_id = self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(prompt_before_image)) - prompt_after_image_id = self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(prompt_after_image)) - seg_before_image = [1] * len(prompt_before_image_id) - seg_after_image = [1] * len(prompt_after_image_id) - if len(prompt_before_image_id) + len(prompt_after_image_id) > seq_text: - print("promt too long, jump for now") - continue - text_combine_id = prompt_before_image_id + prompt_after_image_id - tgt_id = [PAD_ID] * (len(text_combine_id) + num_image_tokens - 1) - tgt_seg_nums = [len(tgt_id)] - elif i % 2 == 0: # human - prompt = conv["value"] - prompt_id = self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(" USER:" + prompt + "\nASSISTANT:")) - text_combine_id = text_combine_id + prompt_id - tgt_id = tgt_id + [PAD_ID] * len(prompt_id) - if len(tgt_seg_nums) == 1: - tgt_seg_nums[0] = tgt_seg_nums[0] + len(prompt_id) - else: - tgt_seg_nums = tgt_seg_nums + [len(prompt_id)] - else: # gpt - answer = conv["value"] - answer_id = self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(answer) + [SEP_TOKEN]) - text_combine_id = text_combine_id + answer_id - tgt_id = tgt_id + answer_id - tgt_seg_nums = tgt_seg_nums + [len(answer_id)] - - if len(tgt_id) > self.seq_length: - tgt_id = tgt_id[:self.seq_length] - pad_num = self.seq_length - len(tgt_id) - tgt_id = tgt_id + [PAD_ID] * pad_num - while sum(tgt_seg_nums) > self.seq_length: - tgt_seg_nums = tgt_seg_nums[:-1] - pad_num = self.seq_length - sum(tgt_seg_nums) - tgt_seg_nums = tgt_seg_nums + [pad_num] - - if len(text_combine_id) > seq_text : - text_combine_id = text_combine_id[:seq_text] - - pad_num = seq_text - len(text_combine_id) - text_combine_seg_nums = [len(text_combine_id), pad_num] - text_combine_id = text_combine_id + [PAD_ID] * pad_num - - image_pos = len(prompt_before_image_id) - src = (text_combine_id, tgt_id) - seg_nums = (text_combine_seg_nums, tgt_seg_nums) - image = (path, image_pos) - pickle.dump((src, seg_nums, image), dataset_writer) - - if pos >= end: - break - - dataset_writer.close() - - class FileWithLabelDataset(Dataset): def worker(self, proc_id, start, end): print("Worker %d is building dataset ... 
" % proc_id) @@ -1189,5 +1086,92 @@ def worker(self, proc_id, start, end): dataset_writer.close() -class LlavaDataset(FileWithTextJsonDataset): - pass +class LlavaDataset(Dataset): + def worker(self, proc_id, start, end): + import json + num_image_tokens = self.args.vision_seq_length_in_VL + seq_text = self.seq_length - num_image_tokens + PAD_ID = self.tokenizer.convert_tokens_to_ids([PAD_TOKEN])[0] + + print("Worker %d is building dataset ... " % proc_id) + set_seed(self.seed) + dataset_writer = open("dataset-tmp-" + str(proc_id) + ".pt", "wb") + pos = start + skip_item = 0 + with open(self.corpus_path, mode="r", encoding="utf-8") as f: + datas = json.load(f) + while True: + item = datas[pos] + pos += 1 + try: + path = item["image"] + if not os.path.isfile(path): + continue + except: + skip_item += 1 + continue + conversations = item["conversations"] + + prompt_before_image = " USER: " + prompt_answer_seg_nums, tgt_seg_nums = [], [] + for i, conv in enumerate(conversations): + if i == 0: + prompt = conv["value"] + if prompt.endswith(""): + prompt_before_image = prompt_before_image + prompt.replace("","") + prompt_after_image = "\nASSISTANT: " + elif prompt.startswith(""): + prompt_before_image = prompt_before_image + "" + prompt_after_image = prompt.replace("","") + "\nASSISTANT: " + prompt_before_image_id = self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(prompt_before_image)) + prompt_after_image_id = self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(prompt_after_image)) + seg_before_image = [1] * len(prompt_before_image_id) + seg_after_image = [1] * len(prompt_after_image_id) + if len(prompt_before_image_id) + len(prompt_after_image_id) > seq_text: + print("promt too long, jumped") + continue + prompt_answer_id = prompt_before_image_id + prompt_after_image_id + tgt_id = [PAD_ID] * (len(prompt_answer_id) + num_image_tokens - 1) + tgt_seg_nums = [len(tgt_id)] + elif i % 2 == 0: # human + prompt = conv["value"] + prompt_id = self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(" USER: " + prompt + "\nASSISTANT: ")) + prompt_answer_id = prompt_answer_id + prompt_id + tgt_id = tgt_id + [PAD_ID] * len(prompt_id) + if len(tgt_seg_nums) == 1: + tgt_seg_nums[0] = tgt_seg_nums[0] + len(prompt_id) + else: + tgt_seg_nums = tgt_seg_nums + [len(prompt_id)] + else: # gpt + answer = conv["value"] + answer_id = self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(answer) + [SEP_TOKEN]) + prompt_answer_id = prompt_answer_id + answer_id + tgt_id = tgt_id + answer_id + tgt_seg_nums = tgt_seg_nums + [len(answer_id)] + + if len(tgt_id) > self.seq_length: + tgt_id = tgt_id[:self.seq_length] + pad_num = self.seq_length - len(tgt_id) + tgt_id = tgt_id + [PAD_ID] * pad_num + while sum(tgt_seg_nums) > self.seq_length: + tgt_seg_nums = tgt_seg_nums[:-1] + pad_num = self.seq_length - sum(tgt_seg_nums) + tgt_seg_nums = tgt_seg_nums + [pad_num] + + if len(prompt_answer_id) > seq_text : + prompt_answer_id = prompt_answer_id[:seq_text] + + pad_num = seq_text - len(prompt_answer_id) + prompt_answer_seg_nums = [len(prompt_answer_id), pad_num] + prompt_answer_id = prompt_answer_id + [PAD_ID] * pad_num + + image_pos = len(prompt_before_image_id) + src = (prompt_answer_id, tgt_id) + seg_nums = (prompt_answer_seg_nums, tgt_seg_nums) + image = (path, image_pos) + pickle.dump((src, seg_nums, image), dataset_writer) + + if pos >= end: + break + + dataset_writer.close() From 996296e305e52fb4725318bdf14ab67df36017cc Mon Sep 17 00:00:00 2001 From: janinezhao Date: Wed, 3 Jan 2024 15:48:39 
+0800 Subject: [PATCH 08/18] fix --- models/llava/7b_config.json | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/models/llava/7b_config.json b/models/llava/7b_config.json index 82622ece..5cb5e863 100755 --- a/models/llava/7b_config.json +++ b/models/llava/7b_config.json @@ -52,4 +52,4 @@ "layernorm": "rms", "layernorm_eps": 1e-5, "target": ["lm"] - } \ No newline at end of file + } From 959dc5500b9c967545922a451fe0bf91bc96d5be Mon Sep 17 00:00:00 2001 From: janinezhao Date: Wed, 3 Jan 2024 16:30:30 +0800 Subject: [PATCH 09/18] fix dataset --- tencentpretrain/utils/dataset.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/tencentpretrain/utils/dataset.py b/tencentpretrain/utils/dataset.py index 547e73fa..0566a69a 100755 --- a/tencentpretrain/utils/dataset.py +++ b/tencentpretrain/utils/dataset.py @@ -1089,10 +1089,11 @@ def worker(self, proc_id, start, end): class LlavaDataset(Dataset): def worker(self, proc_id, start, end): import json - num_image_tokens = self.args.vision_seq_length_in_VL + num_image_tokens = self.args.vision_seq_length_in_VL # 576 seq_text = self.seq_length - num_image_tokens PAD_ID = self.tokenizer.convert_tokens_to_ids([PAD_TOKEN])[0] - + role1, role2 = "USER", "ASSISTANT" + im_start, im_end = "", "" print("Worker %d is building dataset ... " % proc_id) set_seed(self.seed) dataset_writer = open("dataset-tmp-" + str(proc_id) + ".pt", "wb") @@ -1112,17 +1113,17 @@ def worker(self, proc_id, start, end): continue conversations = item["conversations"] - prompt_before_image = " USER: " + prompt_before_image = role1 + ": " prompt_answer_seg_nums, tgt_seg_nums = [], [] for i, conv in enumerate(conversations): if i == 0: prompt = conv["value"] if prompt.endswith(""): - prompt_before_image = prompt_before_image + prompt.replace("","") - prompt_after_image = "\nASSISTANT: " + prompt_before_image = prompt_before_image + prompt.replace("", im_start) + prompt_after_image = im_end + "\n" + role2 + ": " elif prompt.startswith(""): - prompt_before_image = prompt_before_image + "" - prompt_after_image = prompt.replace("","") + "\nASSISTANT: " + prompt_before_image = prompt_before_image + im_start + prompt_after_image = prompt.replace("", im_end) + "\n" + role2 + ": " prompt_before_image_id = self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(prompt_before_image)) prompt_after_image_id = self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(prompt_after_image)) seg_before_image = [1] * len(prompt_before_image_id) @@ -1135,7 +1136,7 @@ def worker(self, proc_id, start, end): tgt_seg_nums = [len(tgt_id)] elif i % 2 == 0: # human prompt = conv["value"] - prompt_id = self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(" USER: " + prompt + "\nASSISTANT: ")) + prompt_id = self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(role1 + ": " + prompt + "\n" + role2 + ": ")) prompt_answer_id = prompt_answer_id + prompt_id tgt_id = tgt_id + [PAD_ID] * len(prompt_id) if len(tgt_seg_nums) == 1: From b255e5b35d8ad3f855b96487b14636a559571c18 Mon Sep 17 00:00:00 2001 From: janinezhao Date: Wed, 3 Jan 2024 17:00:38 +0800 Subject: [PATCH 10/18] fix --- scripts/generate_lm_llava_deepspeed.py | 22 ++++++++++++---------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/scripts/generate_lm_llava_deepspeed.py b/scripts/generate_lm_llava_deepspeed.py index 72d2c83d..c7d2e3f6 100755 --- a/scripts/generate_lm_llava_deepspeed.py +++ b/scripts/generate_lm_llava_deepspeed.py @@ -175,6 +175,8 @@ def 
load_or_initialize_parameters(args, model): "sys3": " You are a helpful language and vision assistant. \n", "sys4": "[INST]<>\nYou are a helpful language and vision assistant. You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language.\n<>\n\n" } + role1, role2 = "USER", "ASSISTANT" + im_start, im_end = "", "" num_image_tokens = int(image_width / patch_size) * int(image_height / patch_size) + 1 # 336/14-14 --> 576 dim + 1 seq_text = args.seq_length - num_image_tokens outf = open(args.prediction_path, mode="w", encoding="utf-8") @@ -200,7 +202,7 @@ def load_or_initialize_parameters(args, model): print("sth wrong with item{}".format(item)) continue - prompt_before_image = prompt_overall + " USER: " + prompt_before_image = prompt_overall + " " + role1 + ": " ground_truth = [] prompt_answer_id = [] if "conversations" in item: @@ -212,14 +214,14 @@ def load_or_initialize_parameters(args, model): if i == 0: prompt = conv["value"] if prompt.endswith(""): - prompt_before_image = prompt_before_image + prompt.replace("","") - prompt_after_image = "\nASSISTANT:" + prompt_before_image = prompt_before_image + prompt.replace("", im_start) + prompt_after_image = im_end + "\n" + role2 + ": " elif prompt.startswith(""): - prompt_before_image = prompt_before_image + "" - prompt_after_image = prompt.replace("","") + "\nASSISTANT: " + prompt_before_image = prompt_before_image + im_start + prompt_after_image = prompt.replace("", im_end) + "\n" + role2 + ": " else: - prompt_before_image = prompt_before_image + "" - prompt_after_image = "\n" + prompt + " ASSISTANT: " + prompt_before_image = prompt_before_image + im_start + prompt_after_image = im_end + "\n" + prompt + " " + role2 + ": " prompt_before_image_id = args.tokenizer.convert_tokens_to_ids( args.tokenizer.tokenize(prompt_before_image) @@ -237,7 +239,7 @@ def load_or_initialize_parameters(args, model): elif i % 2 == 0: # human prompt = conv["value"] prompt_id = args.tokenizer.convert_tokens_to_ids( - args.tokenizer.tokenize(" USER: " + prompt + " ASSISTANT: ") + args.tokenizer.tokenize(role1 + ": " + prompt + " " + role2 + ": ") ) if prompt_answer_id: prompt_answer_id.append(prompt_id) @@ -249,8 +251,8 @@ def load_or_initialize_parameters(args, model): ground_truth.append(conv["value"]) else: prompt = item["instruction"] - prompt_before_image = prompt_before_image + "" - prompt_after_image = "\n" + prompt + "\nASSISTANT: " + prompt_before_image = prompt_before_image + im_start + prompt_after_image = im_end + "\n" + prompt + "\n" + role2 + ": " prompt_before_image_id = args.tokenizer.convert_tokens_to_ids( args.tokenizer.tokenize(prompt_before_image) ) From 845fef8e9566d37475e46bffe37c4315aa0cd059 Mon Sep 17 00:00:00 2001 From: janinezhao Date: Thu, 4 Jan 2024 19:04:13 +0800 Subject: [PATCH 11/18] fix dataset and dataloader --- pretrain.py | 4 ++-- scripts/convert_model_add_prefix.py | 28 ++++++++++++++++++++++++++ scripts/generate_lm_llava_deepspeed.py | 6 ++---- tencentpretrain/model_loader.py | 22 ++++---------------- tencentpretrain/trainer.py | 3 +-- tencentpretrain/utils/dataloader.py | 20 +++++++++++------- tencentpretrain/utils/dataset.py | 16 +++++++-------- 7 files changed, 57 insertions(+), 42 deletions(-) create mode 100755 scripts/convert_model_add_prefix.py diff --git a/pretrain.py b/pretrain.py index e5cd5cfe..6b56dc08 100644 --- a/pretrain.py +++ b/pretrain.py @@ -37,6 +37,8 @@ def main(): help="Number of prediction labels.") parser.add_argument("--dropout", 
type=float, default=0.1, help="Dropout value.") parser.add_argument("--seed", type=int, default=7, help="Random seed.") + parser.add_argument("--seq_length", type=int, default=128, + help="Sequence length.") # Preprocess options. tokenizer_opts(parser) @@ -44,8 +46,6 @@ def main(): # Model options. model_opts(parser) - parser.add_argument("--vision_model_missing_prefix", type=str, required=False, default="embedding.vision_language.vision_", - help="Extra prefix when loading the vision pretrained model as the embedding of the whole model.") # Model parallelism options. mp_opts(parser) diff --git a/scripts/convert_model_add_prefix.py b/scripts/convert_model_add_prefix.py new file mode 100755 index 00000000..f75ca66e --- /dev/null +++ b/scripts/convert_model_add_prefix.py @@ -0,0 +1,28 @@ +import argparse +import collections +import torch + + +def main(): + parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) + parser.add_argument("--input_model_path", type=str, default="models/input_model.bin", + help=".") + parser.add_argument("--output_model_path", type=str, default="models/output_model.bin", + help=".") + parser.add_argument("--prefix", type=str, default="", help="prefix to add") + + + args = parser.parse_args() + + input_model = torch.load(args.input_model_path, map_location="cpu") + + output_model = collections.OrderedDict() + + for k in input_model.keys(): + output_model[args.prefix + k] = input_model[k] + + torch.save(output_model, args.output_model_path) + + +if __name__ == "__main__": + main() diff --git a/scripts/generate_lm_llava_deepspeed.py b/scripts/generate_lm_llava_deepspeed.py index c7d2e3f6..157b9ac3 100755 --- a/scripts/generate_lm_llava_deepspeed.py +++ b/scripts/generate_lm_llava_deepspeed.py @@ -96,7 +96,7 @@ def load_or_initialize_parameters(args, model): args.logger.info("unexpected_keys: {0}".format(keys_info.unexpected_keys)) if args.vision_model_in_VL_emb_path is not None: args.logger.info("loading model from {0}".format(args.vision_model_in_VL_emb_path)) - model = load_model(model, args.vision_model_in_VL_emb_path, missing_prefix="embedding.vision_language.vision_") + model = load_model(model, args.vision_model_in_VL_emb_path) else: # Initialize with normal distribution. 
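# Illustrative usage of the new scripts/convert_model_add_prefix.py helper added in
# this patch (model paths are hypothetical): pre-prefixing the ViT checkpoint keys
# with "embedding.vision_language.vision_" lets the plain state-dict loaders find
# them inside the vision-language embedding without any load-time key remapping.
#
#   python3 scripts/convert_model_add_prefix.py \
#       --input_model_path models/clip-vit-large.bin \
#       --output_model_path models/clip-vit-for-llava.bin \
#       --prefix embedding.vision_language.vision_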
for n, p in list(model.named_parameters()): @@ -112,8 +112,6 @@ def load_or_initialize_parameters(args, model): parser.add_argument("--top_k", type=int, default=70) parser.add_argument("--top_p", type=float, default=0) parser.add_argument("--temperature", type=float, default=1.0) - parser.add_argument("--vision_model_in_VL_emb_path", type=str, default=None, - help="Path of the vision pretrained model in the vision language embedding.") parser.add_argument("--instruction_template", type=str, choices=["sys0", "sys1", "sys2", "sys3", "sys4"], help="The instruction type for training large language-vision model.", default="sys0") @@ -146,7 +144,7 @@ def load_or_initialize_parameters(args, model): if args.pretrained_model_path: model = _load_state_dict_into_model(model, args.pretrained_model_path) if args.vision_model_in_VL_emb_path is not None: - model = _load_state_dict_into_model(model, args.vision_model_in_VL_emb_path, missing_prefix="embedding.vision_language.vision_") + model = _load_state_dict_into_model(model, args.vision_model_in_VL_emb_path) else: model = LLaVaGenerate(args) load_or_initialize_parameters(args, model) diff --git a/tencentpretrain/model_loader.py b/tencentpretrain/model_loader.py index 29db8ef0..a8e0b175 100644 --- a/tencentpretrain/model_loader.py +++ b/tencentpretrain/model_loader.py @@ -1,32 +1,24 @@ import os import torch -import collections from tencentpretrain import mpu -def load_model(model, model_path, lora_pretrained_model_path=None, missing_prefix=""): +def load_model(model, model_path, lora_pretrained_model_path=None): """ Load model from saved weights. """ - state_dict = torch.load(model_path, map_location="cpu") - if missing_prefix != "": - state_dict_withprefix = collections.OrderedDict() - for k in state_dict.keys(): - state_dict_withprefix[missing_prefix + k] = state_dict[k] - del state_dict - state_dict = state_dict_withprefix if hasattr(model, "module"): - model.module.load_state_dict(state_dict, strict=False) + model.module.load_state_dict(torch.load(model_path, map_location="cpu"), strict=False) if lora_pretrained_model_path is not None: model.module.load_state_dict(torch.load(lora_pretrained_model_path, map_location="cpu"), strict=False) else: - model.load_state_dict(state_dict, strict=False) + model.load_state_dict(torch.load(model_path, map_location="cpu"), strict=False) if lora_pretrained_model_path is not None: model.load_state_dict(torch.load(lora_pretrained_model_path, map_location="cpu"), strict=False) return model -def _load_state_dict_into_model(model_to_load, model_path, start_prefix="", missing_prefix=""): +def _load_state_dict_into_model(model_to_load, model_path, start_prefix=""): # Convert old format to new format if needed from a PyTorch state_dict # copy state_dict so _load_from_state_dict can modify it @@ -61,12 +53,6 @@ def load(module, state_dict, prefix=""): for name, child in module._modules.items(): if child is not None: load(child, state_dict, prefix + name + ".") - if missing_prefix != "": - state_dict_withprefix = collections.OrderedDict() - for k in state_dict.keys(): - state_dict_withprefix[missing_prefix + k] = state_dict[k] - del state_dict - state_dict = state_dict_withprefix load(model_to_load, state_dict, prefix=start_prefix) # Delete `state_dict` so it could be collected by GC earlier. 
Note that `state_dict` is a copy of the argument, so diff --git a/tencentpretrain/trainer.py b/tencentpretrain/trainer.py index 7301239c..105ce750 100755 --- a/tencentpretrain/trainer.py +++ b/tencentpretrain/trainer.py @@ -96,8 +96,7 @@ def init_model(args): if args.vision_model_in_VL_emb_path is not None: args.logger.info("loading: {}".format(args.vision_model_in_VL_emb_path)) - model_for_training = _load_state_dict_into_model(model_for_training, args.vision_model_in_VL_emb_path, missing_prefix=args.vision_model_missing_prefix) - # model_for_training = load_model(model_for_training, args.vision_model_path) + model_for_training = _load_state_dict_into_model(model_for_training, args.vision_model_in_VL_emb_path) return model_for_training, model_for_dataloader diff --git a/tencentpretrain/utils/dataloader.py b/tencentpretrain/utils/dataloader.py index 03a2bcf8..8f1a4f55 100755 --- a/tencentpretrain/utils/dataloader.py +++ b/tencentpretrain/utils/dataloader.py @@ -546,6 +546,7 @@ def __init__(self, args, dataset_path, batch_size, global_rank, world_size, loca self.patch_size = args.patch_size self.image_height = args.image_height self.image_width = args.image_width + self.args = args from torchvision import transforms from tencentpretrain.utils.misc import ZeroOneNormalize @@ -990,7 +991,9 @@ def __iter__(self): """ from torchvision.io import read_image from torchvision.io.image import ImageReadMode - seg_num = (self.image_height // self.patch_size) * (self.image_width // self.patch_size) + 1 + + seg_image_num = (self.image_height // self.patch_size) * (self.image_width // self.patch_size) + text_seq_length = self.args.seq_length - seg_image_num while True: while self._empty(): self._fill_buf() @@ -1013,14 +1016,17 @@ def __iter__(self): ins_seg_nums_src, ins_seg_nums_tgt = ins[1] ins_src_image, ins_image_pos = ins[2] - src_text.append(ins_src) - tgt.append(ins_tgt) + src_text.append(ins_src[:text_seq_length]) ins_seg_src = [1] * ins_seg_nums_src[0] + [0] * ins_seg_nums_src[1] - ins_seg_tgt = [] + seg_text.append(ins_seg_src[:text_seq_length]) + + ins_tgt_new = [self.vocab.get(PAD_TOKEN)] * seg_image_num + ins_tgt + tgt.append(ins_tgt_new[:self.args.seq_length]) + ins_seg_tgt = [0] * seg_image_num for i, num in enumerate(ins_seg_nums_tgt): ins_seg_tgt = ins_seg_tgt + [i % 2] * num - seg_text.append(ins_seg_src) - seg_tgt.append(ins_seg_tgt) + seg_tgt.append(ins_seg_tgt[:self.args.seq_length]) + try: image = read_image(ins_src_image, ImageReadMode.RGB) except: @@ -1028,7 +1034,7 @@ def __iter__(self): continue image = image.cuda(self.local_rank) src_image.append(self.transform(image)) - seg_image.append([1] * seg_num) + seg_image.append([1] * (seg_image_num + 1)) image_pos.append(ins_image_pos) if len(src_image) == 0: continue diff --git a/tencentpretrain/utils/dataset.py b/tencentpretrain/utils/dataset.py index 0566a69a..c2f03f78 100755 --- a/tencentpretrain/utils/dataset.py +++ b/tencentpretrain/utils/dataset.py @@ -1089,8 +1089,6 @@ def worker(self, proc_id, start, end): class LlavaDataset(Dataset): def worker(self, proc_id, start, end): import json - num_image_tokens = self.args.vision_seq_length_in_VL # 576 - seq_text = self.seq_length - num_image_tokens PAD_ID = self.tokenizer.convert_tokens_to_ids([PAD_TOKEN])[0] role1, role2 = "USER", "ASSISTANT" im_start, im_end = "", "" @@ -1100,9 +1098,9 @@ def worker(self, proc_id, start, end): pos = start skip_item = 0 with open(self.corpus_path, mode="r", encoding="utf-8") as f: - datas = json.load(f) + data = json.load(f) while True: - item = 
datas[pos] + item = data[pos] pos += 1 try: path = item["image"] @@ -1128,11 +1126,11 @@ def worker(self, proc_id, start, end): prompt_after_image_id = self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(prompt_after_image)) seg_before_image = [1] * len(prompt_before_image_id) seg_after_image = [1] * len(prompt_after_image_id) - if len(prompt_before_image_id) + len(prompt_after_image_id) > seq_text: + if len(prompt_before_image_id) + len(prompt_after_image_id) > self.seq_length: print("promt too long, jumped") continue prompt_answer_id = prompt_before_image_id + prompt_after_image_id - tgt_id = [PAD_ID] * (len(prompt_answer_id) + num_image_tokens - 1) + tgt_id = [PAD_ID] * (len(prompt_answer_id) - 1) tgt_seg_nums = [len(tgt_id)] elif i % 2 == 0: # human prompt = conv["value"] @@ -1159,10 +1157,10 @@ def worker(self, proc_id, start, end): pad_num = self.seq_length - sum(tgt_seg_nums) tgt_seg_nums = tgt_seg_nums + [pad_num] - if len(prompt_answer_id) > seq_text : - prompt_answer_id = prompt_answer_id[:seq_text] + if len(prompt_answer_id) > self.seq_length : + prompt_answer_id = prompt_answer_id[:self.seq_length] - pad_num = seq_text - len(prompt_answer_id) + pad_num = self.seq_length - len(prompt_answer_id) prompt_answer_seg_nums = [len(prompt_answer_id), pad_num] prompt_answer_id = prompt_answer_id + [PAD_ID] * pad_num From c33ffb0fe4f1ca6cc7e75402b9702b665cfa7067 Mon Sep 17 00:00:00 2001 From: janinezhao Date: Tue, 9 Jan 2024 16:29:53 +0800 Subject: [PATCH 12/18] add pad in image preprocess --- scripts/generate_lm_llava_deepspeed.py | 56 +++++++++++++++++++++----- tencentpretrain/opts.py | 2 +- tencentpretrain/utils/dataloader.py | 46 ++++++++++++++++----- 3 files changed, 83 insertions(+), 21 deletions(-) diff --git a/scripts/generate_lm_llava_deepspeed.py b/scripts/generate_lm_llava_deepspeed.py index 157b9ac3..38d240c5 100755 --- a/scripts/generate_lm_llava_deepspeed.py +++ b/scripts/generate_lm_llava_deepspeed.py @@ -104,6 +104,22 @@ def load_or_initialize_parameters(args, model): p.data.normal_(0, 0.02) +def expand2square(img, background_color=(122, 116, 104)): + from PIL import Image + width, height = img.size + if img.mode != "RGB": + img = img.convert("RGB") + if width == height: + return img + elif width > height: + result = Image.new(img.mode, (width, width), background_color) + result.paste(img, (0, (width - height) // 2)) + return result + else: + result = Image.new(img.mode, (height, height), background_color) + result.paste(img, ((height - width) // 2, 0)) + return result + if __name__ == '__main__': parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) @@ -114,7 +130,8 @@ def load_or_initialize_parameters(args, model): parser.add_argument("--temperature", type=float, default=1.0) parser.add_argument("--instruction_template", type=str, choices=["sys0", "sys1", "sys2", "sys3", "sys4"], help="The instruction type for training large language-vision model.", default="sys0") - + parser.add_argument("--vision_model_in_VL_emb_path", type=str, default=None, + help="Path of the vision pretrained model in the vision language embedding.") tokenizer_opts(parser) deepspeed_opts(parser) @@ -123,6 +140,8 @@ def load_or_initialize_parameters(args, model): mp_opts(parser) + vision_opts(parser) + args = parser.parse_args() args.target = "lm" @@ -160,12 +179,20 @@ def load_or_initialize_parameters(args, model): image_width = args.vision_language_emb["vision_encoder"]["image_width"] patch_size = args.vision_language_emb["vision_encoder"]["patch_size"] 
- transform = transforms.Compose([ - transforms.Resize(min(image_height, image_width)), - transforms.CenterCrop((image_height, image_width)), - ZeroOneNormalize(), - transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)) - ]) + preprocess_pipeline = [] + if "corp" in args.image_preprocess: + preprocess_pipeline.append(transforms.RandomResizedCrop(max(image_height, image_width))) + elif "center_crop" in args.image_preprocess: + preprocess_pipeline.append(transforms.Resize(min(image_height, image_width))) + preprocess_pipeline.append(transforms.CenterCrop((image_height, image_width))) + if "horizontal_flip" in args.image_preprocess: + preprocess_pipeline.append(transforms.RandomHorizontalFlip()) + preprocess_pipeline.append(transforms.Resize((image_height, image_width))) + preprocess_pipeline.append(ZeroOneNormalize()) + if "normalize" in args.image_preprocess: + preprocess_pipeline.append(transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))) + transform = transforms.Compose(preprocess_pipeline) + prompt_template = { "sys0": "", "sys1": "<>\nYou are a helpful language and vision assistant. You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language.\n<>\n\n", @@ -177,7 +204,7 @@ def load_or_initialize_parameters(args, model): im_start, im_end = "", "" num_image_tokens = int(image_width / patch_size) * int(image_height / patch_size) + 1 # 336/14-14 --> 576 dim + 1 seq_text = args.seq_length - num_image_tokens - outf = open(args.prediction_path, mode="w", encoding="utf-8") + outf = open(args.prediction_path, mode="a", encoding="utf-8") input_f = open(args.test_path, mode="r", encoding="utf-8") datas = json.load(input_f) try: @@ -193,7 +220,16 @@ def load_or_initialize_parameters(args, model): continue if imghdr.what(image_path) != 'jpeg' and imghdr.what(image_path) != 'png': continue - image = read_image(image_path, ImageReadMode.RGB) + if "pad" in args.image_preprocess: + from PIL import Image + import numpy as np + import torchvision.transforms.functional as transform + image = Image.open(image_path) + image = expand2square(image) + image = torch.from_numpy((np.array(image).transpose(2,0,1))) + else: + image = read_image(image_path, ImageReadMode.RGB) + image = image.to(device) src_image = transform(image) except: @@ -299,4 +335,4 @@ def load_or_initialize_parameters(args, model): generated_sentence = "".join(args.tokenizer.convert_ids_to_tokens(tokens)) print(item) print(generated_sentence) - print(generated_sentence+ "\n\n", file=outf) + print(generated_sentence + "\n\n", file=outf) diff --git a/tencentpretrain/opts.py b/tencentpretrain/opts.py index 7085bcc7..88ab02da 100755 --- a/tencentpretrain/opts.py +++ b/tencentpretrain/opts.py @@ -70,7 +70,7 @@ def vision_opts(parser): parser.add_argument("--channels_num", type=int, default=3, help="Channels num.") parser.add_argument("--image_preprocess", type=str, default=["crop", "normalize"], nargs='+', - help="Preprocess and data augmentation methods. Choices: [\"crop\", \"center_crop\", \"horizontal_flip\", \"normalize\"]. ") + help="Preprocess and data augmentation methods. Choices: [\"crop\" or \"center_crop\" or \"pad\", \"horizontal_flip\", \"normalize\"]. 
") def audio_opts(parser): diff --git a/tencentpretrain/utils/dataloader.py b/tencentpretrain/utils/dataloader.py index 8f1a4f55..64a57049 100755 --- a/tencentpretrain/utils/dataloader.py +++ b/tencentpretrain/utils/dataloader.py @@ -556,7 +556,7 @@ def __init__(self, args, dataset_path, batch_size, global_rank, world_size, loca preprocess_pipeline.append(transforms.RandomResizedCrop(max(self.image_height, self.image_width))) elif "center_crop" in args.image_preprocess: preprocess_pipeline.append(transforms.Resize(min(self.image_height, self.image_width))) - preprocess_pipeline.append(transforms.CenterCrop((self.image_height, self.image_width))) + preprocess_pipeline.append(transforms.CenterCrop((self.image_height, self.image_width))) if "horizontal_flip" in args.image_preprocess: preprocess_pipeline.append(transforms.RandomHorizontalFlip()) preprocess_pipeline.append(transforms.Resize((self.image_height, self.image_width))) @@ -565,6 +565,22 @@ def __init__(self, args, dataset_path, batch_size, global_rank, world_size, loca preprocess_pipeline.append(transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))) self.transform = transforms.Compose(preprocess_pipeline) + def expand2square(self, img, background_color=(122, 116, 104)): + from PIL import Image + width, height = img.size + if img.mode != "RGB": + img = img.convert("RGB") + if width == height: + return img + elif width > height: + result = Image.new(img.mode, (width, width), background_color) + result.paste(img, (0, (width - height) // 2)) + return result + else: + result = Image.new(img.mode, (height, height), background_color) + result.paste(img, ((height - width) // 2, 0)) + return result + class VitDataloader(VisionDataloader): def __iter__(self): @@ -1011,11 +1027,30 @@ def __iter__(self): seg_image = [] seg_tgt = [] image_pos = [] + for ins in instances: ins_src, ins_tgt = ins[0] ins_seg_nums_src, ins_seg_nums_tgt = ins[1] ins_src_image, ins_image_pos = ins[2] + try: + if "pad" in self.args.image_preprocess: + from PIL import Image + import numpy as np + import torchvision.transforms.functional as transform + image = Image.open(ins_src_image) + image = self.expand2square(image) + image = torch.from_numpy((np.array(image).transpose(2,0,1))) + else: + image = read_image(ins_src_image, ImageReadMode.RGB) + except: + print("Something is wrong when reading {}, just skipped!".format(ins_src_image)) + continue + image = image.cuda(self.local_rank) + src_image.append(self.transform(image)) + seg_image.append([1] * (seg_image_num + 1)) + image_pos.append(ins_image_pos) + src_text.append(ins_src[:text_seq_length]) ins_seg_src = [1] * ins_seg_nums_src[0] + [0] * ins_seg_nums_src[1] seg_text.append(ins_seg_src[:text_seq_length]) @@ -1027,15 +1062,6 @@ def __iter__(self): ins_seg_tgt = ins_seg_tgt + [i % 2] * num seg_tgt.append(ins_seg_tgt[:self.args.seq_length]) - try: - image = read_image(ins_src_image, ImageReadMode.RGB) - except: - print("Something is wrong when reading {}, just skipped!".format(ins_src_image)) - continue - image = image.cuda(self.local_rank) - src_image.append(self.transform(image)) - seg_image.append([1] * (seg_image_num + 1)) - image_pos.append(ins_image_pos) if len(src_image) == 0: continue yield torch.LongTensor(src_text), \ From 97387511f2429dcc8c32fe91c3c28b012bd49a13 Mon Sep 17 00:00:00 2001 From: janinezhao Date: Thu, 11 Jan 2024 14:45:24 +0800 Subject: [PATCH 13/18] fix seq_length;expand2square --- scripts/generate_lm_llava_deepspeed.py | 18 +----------------- 
tencentpretrain/utils/dataloader.py | 26 ++++++-------------------- tencentpretrain/utils/misc.py | 18 ++++++++++++++++++ 3 files changed, 25 insertions(+), 37 deletions(-) diff --git a/scripts/generate_lm_llava_deepspeed.py b/scripts/generate_lm_llava_deepspeed.py index 38d240c5..f6d5473b 100755 --- a/scripts/generate_lm_llava_deepspeed.py +++ b/scripts/generate_lm_llava_deepspeed.py @@ -28,7 +28,7 @@ from tencentpretrain.opts import deepspeed_opts from tencentpretrain.utils.logging import init_logger from tencentpretrain.model_loader import _load_state_dict_into_model, load_model -from tencentpretrain.utils.misc import pooling, ZeroOneNormalize +from tencentpretrain.utils.misc import pooling, ZeroOneNormalize, expand2square class LLaVaGenerate(nn.Module): @@ -104,22 +104,6 @@ def load_or_initialize_parameters(args, model): p.data.normal_(0, 0.02) -def expand2square(img, background_color=(122, 116, 104)): - from PIL import Image - width, height = img.size - if img.mode != "RGB": - img = img.convert("RGB") - if width == height: - return img - elif width > height: - result = Image.new(img.mode, (width, width), background_color) - result.paste(img, (0, (width - height) // 2)) - return result - else: - result = Image.new(img.mode, (height, height), background_color) - result.paste(img, ((height - width) // 2, 0)) - return result - if __name__ == '__main__': parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) diff --git a/tencentpretrain/utils/dataloader.py b/tencentpretrain/utils/dataloader.py index 64a57049..d1fd48ef 100755 --- a/tencentpretrain/utils/dataloader.py +++ b/tencentpretrain/utils/dataloader.py @@ -565,22 +565,6 @@ def __init__(self, args, dataset_path, batch_size, global_rank, world_size, loca preprocess_pipeline.append(transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))) self.transform = transforms.Compose(preprocess_pipeline) - def expand2square(self, img, background_color=(122, 116, 104)): - from PIL import Image - width, height = img.size - if img.mode != "RGB": - img = img.convert("RGB") - if width == height: - return img - elif width > height: - result = Image.new(img.mode, (width, width), background_color) - result.paste(img, (0, (width - height) // 2)) - return result - else: - result = Image.new(img.mode, (height, height), background_color) - result.paste(img, ((height - width) // 2, 0)) - return result - class VitDataloader(VisionDataloader): def __iter__(self): @@ -1007,9 +991,9 @@ def __iter__(self): """ from torchvision.io import read_image from torchvision.io.image import ImageReadMode + from tencentpretrain.utils.misc import expand2square seg_image_num = (self.image_height // self.patch_size) * (self.image_width // self.patch_size) - text_seq_length = self.args.seq_length - seg_image_num while True: while self._empty(): self._fill_buf() @@ -1032,6 +1016,8 @@ def __iter__(self): ins_src, ins_tgt = ins[0] ins_seg_nums_src, ins_seg_nums_tgt = ins[1] ins_src_image, ins_image_pos = ins[2] + seq_length = len(ins_src) + text_seq_length = seq_length - seg_image_num try: if "pad" in self.args.image_preprocess: @@ -1039,7 +1025,7 @@ def __iter__(self): import numpy as np import torchvision.transforms.functional as transform image = Image.open(ins_src_image) - image = self.expand2square(image) + image = expand2square(image) image = torch.from_numpy((np.array(image).transpose(2,0,1))) else: image = read_image(ins_src_image, ImageReadMode.RGB) @@ -1056,11 +1042,11 @@ def __iter__(self): 
seg_text.append(ins_seg_src[:text_seq_length]) ins_tgt_new = [self.vocab.get(PAD_TOKEN)] * seg_image_num + ins_tgt - tgt.append(ins_tgt_new[:self.args.seq_length]) + tgt.append(ins_tgt_new[:seq_length]) ins_seg_tgt = [0] * seg_image_num for i, num in enumerate(ins_seg_nums_tgt): ins_seg_tgt = ins_seg_tgt + [i % 2] * num - seg_tgt.append(ins_seg_tgt[:self.args.seq_length]) + seg_tgt.append(ins_seg_tgt[:seq_length]) if len(src_image) == 0: continue diff --git a/tencentpretrain/utils/misc.py b/tencentpretrain/utils/misc.py index 9cf56f06..eea8541f 100644 --- a/tencentpretrain/utils/misc.py +++ b/tencentpretrain/utils/misc.py @@ -39,6 +39,24 @@ def pooling(memory_bank, seg, pooling_type): features = memory_bank[:, 0, :] return features + class ZeroOneNormalize(object): def __call__(self, img): return img.float().div(255) + + +def expand2square(img, background_color=(122, 116, 104)): + from PIL import Image + width, height = img.size + if img.mode != "RGB": + img = img.convert("RGB") + if width == height: + return img + elif width > height: + result = Image.new(img.mode, (width, width), background_color) + result.paste(img, (0, (width - height) // 2)) + return result + else: + result = Image.new(img.mode, (height, height), background_color) + result.paste(img, ((height - width) // 2, 0)) + return result From d21baebecfebfd4d468ed653233a3ff5b6fc2e26 Mon Sep 17 00:00:00 2001 From: janinezhao Date: Thu, 11 Jan 2024 19:21:11 +0800 Subject: [PATCH 14/18] fix infer --- scripts/generate_lm_llava_deepspeed.py | 244 ++++++++++++------------- 1 file changed, 121 insertions(+), 123 deletions(-) diff --git a/scripts/generate_lm_llava_deepspeed.py b/scripts/generate_lm_llava_deepspeed.py index f6d5473b..1c74442f 100755 --- a/scripts/generate_lm_llava_deepspeed.py +++ b/scripts/generate_lm_llava_deepspeed.py @@ -124,8 +124,6 @@ def load_or_initialize_parameters(args, model): mp_opts(parser) - vision_opts(parser) - args = parser.parse_args() args.target = "lm" @@ -188,7 +186,6 @@ def load_or_initialize_parameters(args, model): im_start, im_end = "", "" num_image_tokens = int(image_width / patch_size) * int(image_height / patch_size) + 1 # 336/14-14 --> 576 dim + 1 seq_text = args.seq_length - num_image_tokens - outf = open(args.prediction_path, mode="a", encoding="utf-8") input_f = open(args.test_path, mode="r", encoding="utf-8") datas = json.load(input_f) try: @@ -196,127 +193,128 @@ def load_or_initialize_parameters(args, model): except: args.logger.info("unsupported prompt template!") NotImplementedError - for line_id, item in enumerate(datas): - try: - id = item["id"] - image_path = "datasets/llava/" + item["image"] - if not os.path.isfile(image_path): - continue - if imghdr.what(image_path) != 'jpeg' and imghdr.what(image_path) != 'png': - continue - if "pad" in args.image_preprocess: - from PIL import Image - import numpy as np - import torchvision.transforms.functional as transform - image = Image.open(image_path) - image = expand2square(image) - image = torch.from_numpy((np.array(image).transpose(2,0,1))) - else: - image = read_image(image_path, ImageReadMode.RGB) - - image = image.to(device) - src_image = transform(image) - except: - print("sth wrong with item{}".format(item)) - continue - - prompt_before_image = prompt_overall + " " + role1 + ": " - ground_truth = [] - prompt_answer_id = [] - if "conversations" in item: - conversations = item["conversations"] - for i, conv in enumerate(conversations): - # 1 round - if i > 1: + with open(args.prediction_path, mode="a", encoding="utf-8") as outf: + 
for line_id, item in enumerate(datas): + try: + id = item["id"] + image_path = "datasets/llava/" + item["image"] + if not os.path.isfile(image_path): continue - if i == 0: - prompt = conv["value"] - if prompt.endswith(""): - prompt_before_image = prompt_before_image + prompt.replace("", im_start) - prompt_after_image = im_end + "\n" + role2 + ": " - elif prompt.startswith(""): - prompt_before_image = prompt_before_image + im_start - prompt_after_image = prompt.replace("", im_end) + "\n" + role2 + ": " - else: - prompt_before_image = prompt_before_image + im_start - prompt_after_image = im_end + "\n" + prompt + " " + role2 + ": " - - prompt_before_image_id = args.tokenizer.convert_tokens_to_ids( - args.tokenizer.tokenize(prompt_before_image) - ) - prompt_after_image_id = args.tokenizer.convert_tokens_to_ids( - args.tokenizer.tokenize(prompt_after_image) - ) - seg_before_image = [1] * len(prompt_before_image_id) - seg_after_image = [1] * len(prompt_after_image_id) - if len(prompt_before_image_id) + len(prompt_after_image_id) > seq_text: - args.logger.info("promt too long, jump for now") - break - prompt_answer_id = [prompt_before_image_id + prompt_after_image_id] - prompt_answer_seg = [seg_before_image + seg_after_image] - elif i % 2 == 0: # human - prompt = conv["value"] - prompt_id = args.tokenizer.convert_tokens_to_ids( - args.tokenizer.tokenize(role1 + ": " + prompt + " " + role2 + ": ") - ) - if prompt_answer_id: - prompt_answer_id.append(prompt_id) - prompt_answer_seg.append(prompt_answer_seg + [1] * len(prompt_id)) - else: - args.logger.info("no prompt, or prompt too long, jumping") - break - else: # gpt - ground_truth.append(conv["value"]) - else: - prompt = item["instruction"] - prompt_before_image = prompt_before_image + im_start - prompt_after_image = im_end + "\n" + prompt + "\n" + role2 + ": " - prompt_before_image_id = args.tokenizer.convert_tokens_to_ids( - args.tokenizer.tokenize(prompt_before_image) - ) - prompt_after_image_id = args.tokenizer.convert_tokens_to_ids( - args.tokenizer.tokenize(prompt_after_image) - ) - seg_before_image = [1] * len(prompt_before_image_id) - seg_after_image = [1] * len(prompt_after_image_id) - if len(prompt_before_image_id) + len(prompt_after_image_id) > seq_text: - args.logger.info("promt too long, jump for now") - break - prompt_answer_id = [prompt_before_image_id + prompt_after_image_id] - prompt_answer_seg = [seg_before_image + seg_after_image] - - image_pos = len(prompt_before_image_id) - - image_tensor = torch.unsqueeze(src_image, 0).half() - image_seg_tensor = torch.ones(1, num_image_tokens).to(device) - image_pos = torch.LongTensor([image_pos]).to(device) - SEP_ID = args.tokenizer.convert_tokens_to_ids([SEP_TOKEN]) - text_tensor = None - - for i, prompt in enumerate(prompt_answer_id): - if text_tensor is None: - text_tensor, text_seg_tensor = torch.LongTensor([prompt]).to(device), torch.LongTensor([prompt_answer_seg[i]]).to(device) + if imghdr.what(image_path) != 'jpeg' and imghdr.what(image_path) != 'png': + continue + if "pad" in args.image_preprocess: + from PIL import Image + import numpy as np + import torchvision.transforms.functional as transform + image = Image.open(image_path) + image = expand2square(image) + image = torch.from_numpy((np.array(image).transpose(2,0,1))) + else: + image = read_image(image_path, ImageReadMode.RGB) + + image = image.to(device) + src_image = transform(image) + except: + print("sth wrong with item{}".format(item)) + continue + + prompt_before_image = prompt_overall + " " + role1 + ": " + ground_truth 
= [] + prompt_answer_id = [] + if "conversations" in item: + conversations = item["conversations"] + for i, conv in enumerate(conversations): + # 1 round + if i > 1: + continue + if i == 0: + prompt = conv["value"] + if prompt.endswith(""): + prompt_before_image = prompt_before_image + prompt.replace("", im_start) + prompt_after_image = im_end + "\n" + role2 + ": " + elif prompt.startswith(""): + prompt_before_image = prompt_before_image + im_start + prompt_after_image = prompt.replace("", im_end) + "\n" + role2 + ": " + else: + prompt_before_image = prompt_before_image + im_start + prompt_after_image = im_end + "\n" + prompt + " " + role2 + ": " + + prompt_before_image_id = args.tokenizer.convert_tokens_to_ids( + args.tokenizer.tokenize(prompt_before_image) + ) + prompt_after_image_id = args.tokenizer.convert_tokens_to_ids( + args.tokenizer.tokenize(prompt_after_image) + ) + seg_before_image = [1] * len(prompt_before_image_id) + seg_after_image = [1] * len(prompt_after_image_id) + if len(prompt_before_image_id) + len(prompt_after_image_id) > seq_text: + args.logger.info("promt too long, jump for now") + break + prompt_answer_id = [prompt_before_image_id + prompt_after_image_id] + prompt_answer_seg = [seg_before_image + seg_after_image] + elif i % 2 == 0: # human + prompt = conv["value"] + prompt_id = args.tokenizer.convert_tokens_to_ids( + args.tokenizer.tokenize(role1 + ": " + prompt + " " + role2 + ": ") + ) + if prompt_answer_id: + prompt_answer_id.append(prompt_id) + prompt_answer_seg.append(prompt_answer_seg + [1] * len(prompt_id)) + else: + args.logger.info("no prompt, or prompt too long, jumping") + break + else: # gpt + ground_truth.append(conv["value"]) else: - text_tensor = torch.cat([text_tensor, torch.LongTensor([prompt]).to(device)], dim=1) - text_seg_tensor = torch.cat([text_seg_tensor, torch.LongTensor([prompt_answer_seg[i]]).to(device)], dim=1) - - while text_tensor.shape[1] + num_image_tokens <= args.seq_length: - output = model(text_tensor, text_seg_tensor, image_tensor, image_seg_tensor, image_pos) - next_token_logits = output[0][-1] / args.temperature - filtered_logits = top_k_top_p_filtering(next_token_logits, args.top_k, args.top_p) - next_token = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1) - - text_tensor = torch.cat([text_tensor, next_token.view(1, 1)], dim=1) - text_seg_tensor = torch.cat([text_seg_tensor, torch.tensor([[1]]).to(device)], dim=1) - if next_token.cpu().tolist() == SEP_ID: + prompt = item["instruction"] + prompt_before_image = prompt_before_image + im_start + prompt_after_image = im_end + "\n" + prompt + "\n" + role2 + ": " + prompt_before_image_id = args.tokenizer.convert_tokens_to_ids( + args.tokenizer.tokenize(prompt_before_image) + ) + prompt_after_image_id = args.tokenizer.convert_tokens_to_ids( + args.tokenizer.tokenize(prompt_after_image) + ) + seg_before_image = [1] * len(prompt_before_image_id) + seg_after_image = [1] * len(prompt_after_image_id) + if len(prompt_before_image_id) + len(prompt_after_image_id) > seq_text: + args.logger.info("promt too long, jump for now") break + prompt_answer_id = [prompt_before_image_id + prompt_after_image_id] + prompt_answer_seg = [seg_before_image + seg_after_image] + + image_pos = len(prompt_before_image_id) + + image_tensor = torch.unsqueeze(src_image, 0).half() + image_seg_tensor = torch.ones(1, num_image_tokens).to(device) + image_pos = torch.LongTensor([image_pos]).to(device) + SEP_ID = args.tokenizer.convert_tokens_to_ids([SEP_TOKEN]) + text_tensor = None + for i, prompt in 
enumerate(prompt_answer_id): + if text_tensor is None: + text_tensor, text_seg_tensor = torch.LongTensor([prompt]).to(device), torch.LongTensor([prompt_answer_seg[i]]).to(device) + else: + text_tensor = torch.cat([text_tensor, torch.LongTensor([prompt]).to(device)], dim=1) + text_seg_tensor = torch.cat([text_seg_tensor, torch.LongTensor([prompt_answer_seg[i]]).to(device)], dim=1) + + while text_tensor.shape[1] + num_image_tokens <= args.seq_length: + output = model(text_tensor, text_seg_tensor, image_tensor, image_seg_tensor, image_pos) + next_token_logits = output[0][-1] / args.temperature + filtered_logits = top_k_top_p_filtering(next_token_logits, args.top_k, args.top_p) + next_token = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1) + + text_tensor = torch.cat([text_tensor, next_token.view(1, 1)], dim=1) + text_seg_tensor = torch.cat([text_seg_tensor, torch.tensor([[1]]).to(device)], dim=1) + if next_token.cpu().tolist() == SEP_ID: + break + if rank == 0 and text_tensor is not None: + tokens = [token_id.item() for token_id in text_tensor[0]] + if args.tokenizer.sp_model is not None: + generated_sentence = args.tokenizer.sp_model.decode(tokens) + else: + generated_sentence = "".join(args.tokenizer.convert_ids_to_tokens(tokens)) + print(item) + print(generated_sentence) + print(item + "\n", file=outf) + print(generated_sentence + "\n\n", file=outf) - if rank == 0 and text_tensor is not None: - tokens = [token_id.item() for token_id in text_tensor[0]] - if args.tokenizer.sp_model is not None: - generated_sentence = args.tokenizer.sp_model.decode(tokens) - else: - generated_sentence = "".join(args.tokenizer.convert_ids_to_tokens(tokens)) - print(item) - print(generated_sentence) - print(generated_sentence + "\n\n", file=outf) From 004f109d3a7dcc10a30d33e8e88806b3c5ffa3e1 Mon Sep 17 00:00:00 2001 From: janinezhao Date: Mon, 15 Jan 2024 10:37:40 +0800 Subject: [PATCH 15/18] fix print --- scripts/generate_lm_llava_deepspeed.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/generate_lm_llava_deepspeed.py b/scripts/generate_lm_llava_deepspeed.py index 1c74442f..8236c82b 100755 --- a/scripts/generate_lm_llava_deepspeed.py +++ b/scripts/generate_lm_llava_deepspeed.py @@ -315,6 +315,6 @@ def load_or_initialize_parameters(args, model): generated_sentence = "".join(args.tokenizer.convert_ids_to_tokens(tokens)) print(item) print(generated_sentence) - print(item + "\n", file=outf) + print("\t".join(item.values()) + "\n", file=outf) print(generated_sentence + "\n\n", file=outf) From b4419c6ccc193b65f3176946f21b4d668aef9fda Mon Sep 17 00:00:00 2001 From: janinezhao Date: Mon, 11 Mar 2024 17:18:08 +0800 Subject: [PATCH 16/18] fix data form and vision features --- models/llava/7b_config.json | 7 +- pretrain.py | 2 - scripts/generate_lm_llava_deepspeed.py | 98 +++++++++++++------ .../embeddings/vision_language_embedding.py | 2 +- tencentpretrain/utils/__init__.py | 17 ++-- tencentpretrain/utils/act_fun.py | 6 ++ tencentpretrain/utils/dataset.py | 42 +++++--- 7 files changed, 116 insertions(+), 58 deletions(-) diff --git a/models/llava/7b_config.json b/models/llava/7b_config.json index 5cb5e863..12d01a02 100755 --- a/models/llava/7b_config.json +++ b/models/llava/7b_config.json @@ -8,10 +8,10 @@ "emb_size": 1024, "feedforward_size": 4096, "hidden_size": 1024, - "hidden_act": "gelu_fast", + "hidden_act": "gelu_quick", "heads_num": 16, "layers_num": 24, - "dropout": 0.1, + "dropout": 0.0, "max_seq_length": 577, "embedding": ["patch", "pos"], "patch_proj_bias": 
false, @@ -22,7 +22,7 @@ "feed_forward": "dense", "mask": "fully_visible", "layernorm_positioning": "pre", - "layernorm":"normal", + "layernorm": "normal", "has_cls": false }, "projection":{ @@ -50,6 +50,5 @@ "mask": "causal", "layernorm_positioning": "pre", "layernorm": "rms", - "layernorm_eps": 1e-5, "target": ["lm"] } diff --git a/pretrain.py b/pretrain.py index 6b56dc08..8a667160 100644 --- a/pretrain.py +++ b/pretrain.py @@ -37,8 +37,6 @@ def main(): help="Number of prediction labels.") parser.add_argument("--dropout", type=float, default=0.1, help="Dropout value.") parser.add_argument("--seed", type=int, default=7, help="Random seed.") - parser.add_argument("--seq_length", type=int, default=128, - help="Sequence length.") # Preprocess options. tokenizer_opts(parser) diff --git a/scripts/generate_lm_llava_deepspeed.py b/scripts/generate_lm_llava_deepspeed.py index 8236c82b..eaf22b62 100755 --- a/scripts/generate_lm_llava_deepspeed.py +++ b/scripts/generate_lm_llava_deepspeed.py @@ -14,6 +14,9 @@ from torchvision.io.image import ImageReadMode import imghdr import deepspeed +import numpy as np +from PIL import Image +import torchvision.transforms.functional as transform tencentpretrain_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) sys.path.append(tencentpretrain_dir) @@ -24,7 +27,7 @@ from tencentpretrain.utils.constants import * from tencentpretrain.utils import * from tencentpretrain.utils.config import load_hyperparam -from tencentpretrain.opts import infer_opts, tokenizer_opts, log_opts, mp_opts +from tencentpretrain.opts import infer_opts, tokenizer_opts, log_opts, mp_opts, lora_opts from tencentpretrain.opts import deepspeed_opts from tencentpretrain.utils.logging import init_logger from tencentpretrain.model_loader import _load_state_dict_into_model, load_model @@ -35,11 +38,18 @@ class LLaVaGenerate(nn.Module): def __init__(self, args): super(LLaVaGenerate, self).__init__() self.args = args + lora_params = args.lora_params + if "embedding" in args.use_lora_except: + args.lora_params = None self.embedding = Embedding(args) for embedding_name in args.embedding: tmp_emb = str2embedding[embedding_name](args, len(args.tokenizer.vocab)) self.embedding.update(tmp_emb, embedding_name) + if "encoder" in args.use_lora_except: + args.lora_params = None + else: + args.lora_params = lora_params self.encoder = str2encoder[args.encoder](args) self.pooling_type = args.pooling @@ -94,6 +104,9 @@ def load_or_initialize_parameters(args, model): keys_info = model.load_state_dict(torch.load(args.pretrained_model_path, map_location="cpu"), strict=False) args.logger.info("missing_keys: {0}".format(keys_info.missing_keys)) args.logger.info("unexpected_keys: {0}".format(keys_info.unexpected_keys)) + if args.lora_pretrained_model_path is not None: + args.logger.info("loading model from {0}".format(args.lora_pretrained_model_path)) + model = load_model(model, args.lora_pretrained_model_path) if args.vision_model_in_VL_emb_path is not None: args.logger.info("loading model from {0}".format(args.vision_model_in_VL_emb_path)) model = load_model(model, args.vision_model_in_VL_emb_path) @@ -112,7 +125,7 @@ def load_or_initialize_parameters(args, model): parser.add_argument("--top_k", type=int, default=70) parser.add_argument("--top_p", type=float, default=0) parser.add_argument("--temperature", type=float, default=1.0) - parser.add_argument("--instruction_template", type=str, choices=["sys0", "sys1", "sys2", "sys3", "sys4"], + parser.add_argument("--instruction_template", type=str, 
choices=["sys0", "sys1", "sys2", "sys3", "sys4", "sys5"], help="The instruction type for training large language-vision model.", default="sys0") parser.add_argument("--vision_model_in_VL_emb_path", type=str, default=None, help="Path of the vision pretrained model in the vision language embedding.") @@ -124,6 +137,8 @@ def load_or_initialize_parameters(args, model): mp_opts(parser) + lora_opts(parser) + args = parser.parse_args() args.target = "lm" @@ -137,6 +152,16 @@ def load_or_initialize_parameters(args, model): args.pretrained_model_path = args.load_model_path + # construct lora dict parameters. + if args.use_lora: + args.lora_params = { + "lora_r": args.lora_r, + "lora_alpha": args.lora_alpha, + "lora_dropout": args.lora_dropout + } + else: + args.lora_params = None + # Load or initialize parameters. if args.enable_zero3: print("enable_zero3:", args.enable_zero3) @@ -144,6 +169,8 @@ def load_or_initialize_parameters(args, model): model = LLaVaGenerate(args) if args.pretrained_model_path: model = _load_state_dict_into_model(model, args.pretrained_model_path) + if args.lora_pretrained_model_path is not None: + model = _load_state_dict_into_model(model, args.lora_pretrained_model_path) if args.vision_model_in_VL_emb_path is not None: model = _load_state_dict_into_model(model, args.vision_model_in_VL_emb_path) else: @@ -180,10 +207,15 @@ def load_or_initialize_parameters(args, model): "sys1": "<>\nYou are a helpful language and vision assistant. You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language.\n<>\n\n", "sys2": "<>\nA chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.\n<>\n\n", "sys3": " You are a helpful language and vision assistant. \n", - "sys4": "[INST]<>\nYou are a helpful language and vision assistant. You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language.\n<>\n\n" + "sys4": "[INST]<>\nYou are a helpful language and vision assistant. You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language.\n<>\n\n", + "sys5": "A chat between a curious user and an artificial intelligence assistant. 
The assistant gives helpful, detailed, and polite answers to the user's questions.\n" } - role1, role2 = "USER", "ASSISTANT" - im_start, im_end = "", "" + if args.instruction_template == "sys0": + role1, role2 = "### Instruction", "### Output" + else: + role1, role2 = "USER", "ASSISTANT" + + im_start, im_end = " ", "" num_image_tokens = int(image_width / patch_size) * int(image_height / patch_size) + 1 # 336/14-14 --> 576 dim + 1 seq_text = args.seq_length - num_image_tokens input_f = open(args.test_path, mode="r", encoding="utf-8") @@ -193,32 +225,29 @@ def load_or_initialize_parameters(args, model): except: args.logger.info("unsupported prompt template!") NotImplementedError - with open(args.prediction_path, mode="a", encoding="utf-8") as outf: + + with open(args.prediction_path, mode="w", encoding="utf-8") as outf: for line_id, item in enumerate(datas): try: - id = item["id"] - image_path = "datasets/llava/" + item["image"] + if "datasets" not in item["image"]: + image_path = "datasets/llava/" + item["image"] + else: + image_path = item["image"] if not os.path.isfile(image_path): continue if imghdr.what(image_path) != 'jpeg' and imghdr.what(image_path) != 'png': continue + image = Image.open(image_path) if "pad" in args.image_preprocess: - from PIL import Image - import numpy as np - import torchvision.transforms.functional as transform - image = Image.open(image_path) image = expand2square(image) - image = torch.from_numpy((np.array(image).transpose(2,0,1))) - else: - image = read_image(image_path, ImageReadMode.RGB) - + image = torch.from_numpy((np.array(image).transpose(2,0,1))) image = image.to(device) src_image = transform(image) except: print("sth wrong with item{}".format(item)) continue - prompt_before_image = prompt_overall + " " + role1 + ": " + prompt_before_image = prompt_overall + role1 + ": " ground_truth = [] prompt_answer_id = [] if "conversations" in item: @@ -228,19 +257,20 @@ def load_or_initialize_parameters(args, model): if i > 1: continue if i == 0: - prompt = conv["value"] - if prompt.endswith(""): - prompt_before_image = prompt_before_image + prompt.replace("", im_start) - prompt_after_image = im_end + "\n" + role2 + ": " - elif prompt.startswith(""): - prompt_before_image = prompt_before_image + im_start - prompt_after_image = prompt.replace("", im_end) + "\n" + role2 + ": " + if isinstance(conv, str): + prompt = conv + else: + prompt = conv["value"] + if "" in prompt: + before_image, after_image = prompt.split("") + prompt_before_image = prompt_before_image + before_image + im_start + prompt_after_image = im_end + "\n" + after_image + " " + role2 + ":" else: prompt_before_image = prompt_before_image + im_start - prompt_after_image = im_end + "\n" + prompt + " " + role2 + ": " + prompt_after_image = im_end + "\n" + prompt + " " + role2 + ":" prompt_before_image_id = args.tokenizer.convert_tokens_to_ids( - args.tokenizer.tokenize(prompt_before_image) + [CLS_TOKEN] + args.tokenizer.tokenize(prompt_before_image) ) prompt_after_image_id = args.tokenizer.convert_tokens_to_ids( args.tokenizer.tokenize(prompt_after_image) @@ -255,7 +285,7 @@ def load_or_initialize_parameters(args, model): elif i % 2 == 0: # human prompt = conv["value"] prompt_id = args.tokenizer.convert_tokens_to_ids( - args.tokenizer.tokenize(role1 + ": " + prompt + " " + role2 + ": ") + args.tokenizer.tokenize(role1 + ":" + prompt + " " + role2 + ":") ) if prompt_answer_id: prompt_answer_id.append(prompt_id) @@ -264,11 +294,15 @@ def load_or_initialize_parameters(args, model): args.logger.info("no 
prompt, or prompt too long, jumping") break else: # gpt - ground_truth.append(conv["value"]) + if isinstance(conv, str): + answer = conv + else: + answer = conv["value"] + ground_truth.append(answer) else: prompt = item["instruction"] prompt_before_image = prompt_before_image + im_start - prompt_after_image = im_end + "\n" + prompt + "\n" + role2 + ": " + prompt_after_image = im_end + "\n" + prompt + " " + role2 + ":" prompt_before_image_id = args.tokenizer.convert_tokens_to_ids( args.tokenizer.tokenize(prompt_before_image) ) @@ -285,7 +319,8 @@ def load_or_initialize_parameters(args, model): image_pos = len(prompt_before_image_id) - image_tensor = torch.unsqueeze(src_image, 0).half() + # image_tensor = torch.unsqueeze(src_image, 0).half() + image_tensor = torch.unsqueeze(src_image, 0).bfloat16() image_seg_tensor = torch.ones(1, num_image_tokens).to(device) image_pos = torch.LongTensor([image_pos]).to(device) SEP_ID = args.tokenizer.convert_tokens_to_ids([SEP_TOKEN]) @@ -315,6 +350,7 @@ def load_or_initialize_parameters(args, model): generated_sentence = "".join(args.tokenizer.convert_ids_to_tokens(tokens)) print(item) print(generated_sentence) - print("\t".join(item.values()) + "\n", file=outf) + print(item, file=outf) + print("\n", file=outf) print(generated_sentence + "\n\n", file=outf) diff --git a/tencentpretrain/embeddings/vision_language_embedding.py b/tencentpretrain/embeddings/vision_language_embedding.py index 9b7c6cd1..d6d9117a 100755 --- a/tencentpretrain/embeddings/vision_language_embedding.py +++ b/tencentpretrain/embeddings/vision_language_embedding.py @@ -52,7 +52,7 @@ def forward(self, src, seg=None): # image features with torch.no_grad(): image_emb = self.vision_embedding(src_image, seg_image) - image_emb = self.vision_encoder(image_emb, seg_image)[:,1:,:] + image_emb = self.vision_encoder(image_emb, seg_image, output_layer=-2)[:,1:,:] image_emb = self.projection(image_emb) # text embedding text_emb = self.text_embedding(src_text, seg_text) diff --git a/tencentpretrain/utils/__init__.py b/tencentpretrain/utils/__init__.py index efa947e4..385ddba1 100644 --- a/tencentpretrain/utils/__init__.py +++ b/tencentpretrain/utils/__init__.py @@ -13,15 +13,17 @@ "t5": T5Dataset, "gsg": GsgDataset, "bart": BartDataset, "cls": ClsDataset, "prefixlm": PrefixlmDataset, "cls_mlm": ClsMlmDataset, "vit": VitDataset, "vilt": ViltDataset, "clip": ClipDataset, "s2t": S2tDataset, - "beit":BeitDataset, "dalle": DalleDataset, "llm_sft": LlmSftDataset, "llm_pretrain": LlmPretrainDataset} + "beit":BeitDataset, "dalle": DalleDataset, "llm_sft": LlmSftDataset, + "llm_pretrain": LlmPretrainDataset, "llava": LlavaDataset} str2dataloader = {"bert": BertDataloader, "lm": LmDataloader, "mlm": MlmDataloader, "bilm": BilmDataloader, "albert": AlbertDataloader, "mt": MtDataloader, "t5": T5Dataloader, "gsg": GsgDataloader, "bart": BartDataloader, "cls": ClsDataloader, "prefixlm": PrefixlmDataloader, "cls_mlm": ClsMlmDataloader, - "vit": VitDataloader, "vilt": ViltDataloader, "clip": ClipDataloader, "s2t": S2tDataloader, - "beit":BeitDataloader, "dalle": DalleDataloader, "llm_sft": LlmSftDataloader} + "vit": VitDataloader, "vilt": ViltDataloader, "clip": ClipDataloader, + "s2t": S2tDataloader, "beit":BeitDataloader, "dalle": DalleDataloader, + "llm_sft": LlmSftDataloader, "llava":LlavaDataloader} -str2act = {"gelu": gelu, "gelu_fast": gelu_fast, "relu": relu, "silu": silu, "linear": linear} +str2act = {"gelu": gelu, "gelu_fast": gelu_fast, "relu": relu, "silu": silu, "linear": linear, "gelu_quick": 
gelu_quick} str2optimizer = {"adamw": AdamW, "adafactor": Adafactor} @@ -38,11 +40,12 @@ "BertDataset", "LmDataset", "MlmDataset", "BilmDataset", "AlbertDataset", "MtDataset", "T5Dataset", "GsgDataset", "BartDataset", "ClsDataset", "PrefixlmDataset", "ClsMlmDataset", - "VitDataset", "ViltDataset", "ClipDataset", "BeitDataset", "DalleDataset", "LlmSftDataset", "str2dataset", - "BertDataloader", "LmDataloader", "MlmDataloader", "BilmDataloader", + "VitDataset", "ViltDataset", "ClipDataset", "BeitDataset", "DalleDataset", "LlmSftDataset", "LlavaDataset", + "str2dataset", "BertDataloader", "LmDataloader", "MlmDataloader", "BilmDataloader", "AlbertDataloader", "MtDataloader", "T5Dataloader", "GsgDataloader", "BartDataloader", "ClsDataloader", "PrefixlmDataloader", "ClsMlmDataloader", - "VitDataloader", "ViltDataloader", "ClipDataloader", "BeitDataloader", "DalleDataloader", "LlmSftDataloader", "str2dataloader", + "VitDataloader", "ViltDataloader", "ClipDataloader", "BeitDataloader", "DalleDataloader", + "LlmSftDataloader", "LlavaDataloader", "str2dataloader", "gelu", "gelu_fast", "relu", "silu", "linear", "str2act", "AdamW", "Adafactor", "str2optimizer", "get_linear_schedule_with_warmup", "get_cosine_schedule_with_warmup", diff --git a/tencentpretrain/utils/act_fun.py b/tencentpretrain/utils/act_fun.py index be306a6b..125607c2 100644 --- a/tencentpretrain/utils/act_fun.py +++ b/tencentpretrain/utils/act_fun.py @@ -10,6 +10,12 @@ def gelu(x): def gelu_fast(x): return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0)))) +def gelu_quick(x): + """ + Applies GELU approximation that is fast but somewhat inaccurate. See: https://github.com/hendrycks/GELUs + """ + return x * torch.sigmoid(1.702 * x) + def relu(x): return F.relu(x) diff --git a/tencentpretrain/utils/dataset.py b/tencentpretrain/utils/dataset.py index c2f03f78..ea927b2e 100755 --- a/tencentpretrain/utils/dataset.py +++ b/tencentpretrain/utils/dataset.py @@ -1091,7 +1091,7 @@ def worker(self, proc_id, start, end): import json PAD_ID = self.tokenizer.convert_tokens_to_ids([PAD_TOKEN])[0] role1, role2 = "USER", "ASSISTANT" - im_start, im_end = "", "" + im_start, im_end = " ", "" print("Worker %d is building dataset ... " % proc_id) set_seed(self.seed) dataset_writer = open("dataset-tmp-" + str(proc_id) + ".pt", "wb") @@ -1110,19 +1110,29 @@ def worker(self, proc_id, start, end): skip_item += 1 continue conversations = item["conversations"] - - prompt_before_image = role1 + ": " + if "instruction" in item.keys(): + inst = item["instruction"] + else: + inst = "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions." 
+ if inst: + prompt_before_image = inst + " " + role1 + ":" + else: + prompt_before_image = role1 + ":" prompt_answer_seg_nums, tgt_seg_nums = [], [] for i, conv in enumerate(conversations): if i == 0: - prompt = conv["value"] - if prompt.endswith(""): - prompt_before_image = prompt_before_image + prompt.replace("", im_start) - prompt_after_image = im_end + "\n" + role2 + ": " - elif prompt.startswith(""): + if isinstance(conv, str): + prompt = conv + else: + prompt = conv["value"] + if "" in prompt: + before_image, after_image = prompt.split("") + prompt_before_image = prompt_before_image + before_image + im_start + prompt_after_image = im_end + "\n" + after_image + " " + role2 + ":" + else: prompt_before_image = prompt_before_image + im_start - prompt_after_image = prompt.replace("", im_end) + "\n" + role2 + ": " - prompt_before_image_id = self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(prompt_before_image)) + prompt_after_image = im_end + "\n" + prompt + " " + role2 + ":" + prompt_before_image_id = self.tokenizer.convert_tokens_to_ids([CLS_TOKEN] + self.tokenizer.tokenize(prompt_before_image)) prompt_after_image_id = self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(prompt_after_image)) seg_before_image = [1] * len(prompt_before_image_id) seg_after_image = [1] * len(prompt_after_image_id) @@ -1133,8 +1143,11 @@ def worker(self, proc_id, start, end): tgt_id = [PAD_ID] * (len(prompt_answer_id) - 1) tgt_seg_nums = [len(tgt_id)] elif i % 2 == 0: # human - prompt = conv["value"] - prompt_id = self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(role1 + ": " + prompt + "\n" + role2 + ": ")) + if isinstance(conv, str): + prompt = conv + else: + prompt = conv["value"] + prompt_id = self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(role1 + ":" + prompt + " " + role2 + ":")) prompt_answer_id = prompt_answer_id + prompt_id tgt_id = tgt_id + [PAD_ID] * len(prompt_id) if len(tgt_seg_nums) == 1: @@ -1142,7 +1155,10 @@ def worker(self, proc_id, start, end): else: tgt_seg_nums = tgt_seg_nums + [len(prompt_id)] else: # gpt - answer = conv["value"] + if isinstance(conv, str): + answer = conv + else: + answer = conv["value"] answer_id = self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(answer) + [SEP_TOKEN]) prompt_answer_id = prompt_answer_id + answer_id tgt_id = tgt_id + answer_id From 018318f1789de056fc8362fef287668d7ad19816 Mon Sep 17 00:00:00 2001 From: janinezhao Date: Wed, 13 Mar 2024 12:08:35 +0800 Subject: [PATCH 17/18] update convert script and transformer_encoder --- scripts/convert_llm_in_llava.py | 30 +++++++++++++++++++ .../encoders/transformer_encoder.py | 6 +++- 2 files changed, 35 insertions(+), 1 deletion(-) create mode 100755 scripts/convert_llm_in_llava.py diff --git a/scripts/convert_llm_in_llava.py b/scripts/convert_llm_in_llava.py new file mode 100755 index 00000000..652bcb26 --- /dev/null +++ b/scripts/convert_llm_in_llava.py @@ -0,0 +1,30 @@ +import argparse +import collections +import torch + + +def main(): + parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) + parser.add_argument("--input_model_path", type=str, default="models/input_model.bin", + help=".") + parser.add_argument("--output_model_path", type=str, default="models/output_model.bin", + help=".") + + + args = parser.parse_args() + + input_model = torch.load(args.input_model_path, map_location="cpu") + + output_model = collections.OrderedDict() + + for k in input_model.keys(): + if k == "embedding.word.embedding.weight": + 
output_model["embedding.vision_language.text_embedding.word.embedding.weight"] = input_model[k] + else: + output_model[k] = input_model[k] + + torch.save(output_model, args.output_model_path) + + +if __name__ == "__main__": + main() diff --git a/tencentpretrain/encoders/transformer_encoder.py b/tencentpretrain/encoders/transformer_encoder.py index f3dd6531..309bf5c2 100644 --- a/tencentpretrain/encoders/transformer_encoder.py +++ b/tencentpretrain/encoders/transformer_encoder.py @@ -63,7 +63,7 @@ def __init__(self, args): self.freqs_cis = precompute_freqs_cis(args.hidden_size // args.heads_num, args.max_seq_length * 2) - def forward(self, emb, seg): + def forward(self, emb, seg, output_layer=-1): """ Args: emb: [batch_size x seq_length x emb_size] @@ -129,12 +129,16 @@ def custom_forward(*inputs): while l < self.layers_num: inputs = checkpointing.checkpoint(custom(l, l + self.deepspeed_checkpoint_layers_num), inputs) l += self.deepspeed_checkpoint_layers_num + if output_layer != -1 and l == self.layers_num + output_layer: + return inputs[0] else: for i in range(self.layers_num): if self.parameter_sharing: inputs = self.transformer(inputs) else: inputs = self.transformer[i](inputs) + if output_layer != -1 and i == self.layers_num + output_layer: + return inputs[0] hidden = inputs[0] From e99c9d3958024cb78d3eb4c3736c78a86020ba6b Mon Sep 17 00:00:00 2001 From: janinezhao Date: Fri, 12 Apr 2024 14:12:09 +0800 Subject: [PATCH 18/18] fix infer script --- scripts/generate_lm_llava_deepspeed.py | 27 ++------------------------ 1 file changed, 2 insertions(+), 25 deletions(-) diff --git a/scripts/generate_lm_llava_deepspeed.py b/scripts/generate_lm_llava_deepspeed.py index eaf22b62..f4d960ed 100755 --- a/scripts/generate_lm_llava_deepspeed.py +++ b/scripts/generate_lm_llava_deepspeed.py @@ -27,7 +27,7 @@ from tencentpretrain.utils.constants import * from tencentpretrain.utils import * from tencentpretrain.utils.config import load_hyperparam -from tencentpretrain.opts import infer_opts, tokenizer_opts, log_opts, mp_opts, lora_opts +from tencentpretrain.opts import infer_opts, tokenizer_opts, log_opts, mp_opts from tencentpretrain.opts import deepspeed_opts from tencentpretrain.utils.logging import init_logger from tencentpretrain.model_loader import _load_state_dict_into_model, load_model @@ -38,18 +38,11 @@ class LLaVaGenerate(nn.Module): def __init__(self, args): super(LLaVaGenerate, self).__init__() self.args = args - lora_params = args.lora_params - if "embedding" in args.use_lora_except: - args.lora_params = None self.embedding = Embedding(args) for embedding_name in args.embedding: tmp_emb = str2embedding[embedding_name](args, len(args.tokenizer.vocab)) self.embedding.update(tmp_emb, embedding_name) - if "encoder" in args.use_lora_except: - args.lora_params = None - else: - args.lora_params = lora_params self.encoder = str2encoder[args.encoder](args) self.pooling_type = args.pooling @@ -104,9 +97,7 @@ def load_or_initialize_parameters(args, model): keys_info = model.load_state_dict(torch.load(args.pretrained_model_path, map_location="cpu"), strict=False) args.logger.info("missing_keys: {0}".format(keys_info.missing_keys)) args.logger.info("unexpected_keys: {0}".format(keys_info.unexpected_keys)) - if args.lora_pretrained_model_path is not None: - args.logger.info("loading model from {0}".format(args.lora_pretrained_model_path)) - model = load_model(model, args.lora_pretrained_model_path) + if args.vision_model_in_VL_emb_path is not None: args.logger.info("loading model from 
{0}".format(args.vision_model_in_VL_emb_path)) model = load_model(model, args.vision_model_in_VL_emb_path) @@ -137,8 +128,6 @@ def load_or_initialize_parameters(args, model): mp_opts(parser) - lora_opts(parser) - args = parser.parse_args() args.target = "lm" @@ -152,16 +141,6 @@ def load_or_initialize_parameters(args, model): args.pretrained_model_path = args.load_model_path - # construct lora dict parameters. - if args.use_lora: - args.lora_params = { - "lora_r": args.lora_r, - "lora_alpha": args.lora_alpha, - "lora_dropout": args.lora_dropout - } - else: - args.lora_params = None - # Load or initialize parameters. if args.enable_zero3: print("enable_zero3:", args.enable_zero3) @@ -169,8 +148,6 @@ def load_or_initialize_parameters(args, model): model = LLaVaGenerate(args) if args.pretrained_model_path: model = _load_state_dict_into_model(model, args.pretrained_model_path) - if args.lora_pretrained_model_path is not None: - model = _load_state_dict_into_model(model, args.lora_pretrained_model_path) if args.vision_model_in_VL_emb_path is not None: model = _load_state_dict_into_model(model, args.vision_model_in_VL_emb_path) else: