# NOTE(review): these star imports are presumably what bring names such as
# `available_corpus` into this module's namespace — confirm against
# .data / .model before narrowing them to explicit imports.
from .data import *
from .model import *
# ========================= data ==========================
# NOTE(review): "${...}" values are plain strings at this point; they look
# like interpolation placeholders resolved later by the config loader —
# confirm against the loader.
train_corpus = "webvid_cc3m"
train_file = "${available_corpus[${train_corpus}]}"
# Unlike train_file, the test corpus entry is looked up eagerly in Python.
test_file = dict(msrvtt_1k_test=available_corpus["msrvtt_1k_test"])
test_types = ["msrvtt_1k_test"]
num_workers = 6

# NOTE(review): presumably the metric key used to select the best checkpoint;
# it is None here — confirm how the training script treats None.
stop_key = None
# ========================= input ==========================
# Scalars referenced by the "${...}" placeholders in `inputs` below.
num_frames = 4
num_frames_test = 4
batch_size = 64
max_txt_l = 32

# Input settings: image resolution, video frame sampling, and per-modality
# (image / video) text length and batch sizes. The test batch size reuses
# the same "${batch_size}" placeholder as training.
inputs = dict(
    image_res=224,
    video_input=dict(
        num_frames="${num_frames}",
        sample_type="rand",
        num_frames_test="${num_frames_test}",
        sample_type_test="middle",
        random_aug=False,
    ),
    max_txt_l=dict(image="${max_txt_l}", video="${max_txt_l}"),
    batch_size=dict(image="${batch_size}", video="${batch_size}"),
    batch_size_test=dict(image="${batch_size}", video="${batch_size}"),
)
# ========================= model ==========================
# Keys used by the "${VisionEncoders[...]}" / "${TextEncoders[...]}"
# placeholders below.
vision_enc = "beit"
text_enc = "bert"

# Model settings. All "${...}" values are plain strings here.
# NOTE(review): they are presumably resolved by the config loader, with
# `VisionEncoders` / `TextEncoders` coming from the star imports — confirm.
model = dict(
    vision_encoder="${VisionEncoders[${vision_enc}]}",
    text_encoder="${TextEncoders[${text_enc}]}",
    temporal_modeling=dict(
        num_frames="${num_frames}",
        temporal_model_block="timesformer",
        temporal_model_position="last",
        temporal_model_config=dict(input_dim="${model.vision_encoder.d_model}"),
        use_temporal_position_embedding=True,
    ),
    vit_add_ln=True,
    multimodal=dict(enable=True),
    embed_dim=256,
    temp=0.07,
)
# Loss settings: per-objective weights (mvm is weighted 0.0), hard-negative
# flag for vtm, and the mlm masking probability.
# NOTE(review): vtc / mlm / vtm / mvm presumably stand for video-text
# contrastive, masked language modeling, video-text matching and masked video
# modeling — confirm against the criterion implementation.
criterion = dict(
    loss_weight=dict(vtc=1.0, mlm=1.0, vtm=1.0, mvm=0.0),
    vtm_hard_neg=True,
    mlm_masking_prob=0.5,
)
# Optimizer settings.
optimizer = dict(
    opt="adamW",
    lr=1e-4,
    opt_betas=[0.9, 0.999],
    weight_decay=0.02,
    # NOTE(review): -1 presumably disables gradient clipping — confirm
    # against the training loop.
    max_grad_norm=-1,
    # Per-module learning-rate override; switched off here (enable=False,
    # no module names listed).
    different_lr=dict(enable=False, module_names=[], lr=1e-3),
)
# Learning-rate schedule: cosine, 10 epochs with 1 warmup epoch.
# NOTE(review): min_lr_multi is presumably a multiplier on the base lr that
# gives the floor of the schedule — confirm against the scheduler code.
scheduler = dict(
    sched="cosine",
    epochs=10,
    min_lr_multi=0.01,
    warmup_epochs=1,
)
# Evaluation switches and settings.
# NOTE(review): `evaluate` presumably selects evaluation-only runs — confirm
# with the entry script.
evaluate = False
deep_fusion = False
evaluation = dict(
    eval_frame_ensemble="concat",
    eval_x_only=False,
    k_test=128,
    eval_offload=True,
)
# Precision / memory flags, both enabled.
fp16 = True
gradient_checkpointing = True
# ========================= wandb ==========================
# NOTE(review): `entity` is a hard-coded account name; other users will
# presumably need to override it or set enable=False — confirm.
wandb = dict(
    enable=True,
    entity="likunchang",
    project="vindlu",
)
dist_url = "env://"
device = "cuda"
mode = "pt"
# ========================= others ==========================
# output_dir is None here.
# NOTE(review): presumably supplied at launch time — confirm with the
# launch script.
output_dir = None
resume = False
debug = False
log_freq = 100
seed = 42

# Checkpoint handling flags and the (empty) pretrained checkpoint path.
save_latest = True
auto_resume = True
pretrained_path = ""