Fix checkpoint and same-device bugs (#427)

This commit is contained in:
Jintao
2023-07-29 00:06:27 +08:00
committed by GitHub
parent 972298813b
commit 312b63fe06
8 changed files with 48 additions and 54 deletions

View File

@@ -7,7 +7,7 @@ import sys
from dataclasses import dataclass, field
from functools import partial
from types import MethodType
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from typing import Any, Callable, Counter, Dict, List, Optional, Tuple, Union
import matplotlib.pyplot as plt
import numpy as np
@@ -152,10 +152,9 @@ def print_example(example: Dict[str, Any], tokenizer) -> None:
print(f'[INPUT_IDS] {input_ids}')
print(f'[INPUT] {tokenizer.decode(input_ids)}')
print()
n_mask = Counter(labels)[-100]
print(f'[LABLES_IDS] {labels}')
print(
f'[LABLES] {tokenizer.decode([lb if lb != -100 else 0 for lb in labels])}'
)
print(f'[LABLES] <-100 * {n_mask}>{tokenizer.decode(labels[n_mask:])}')
def data_collate_fn(batch: List[Dict[str, Any]], tokenizer) -> Dict[str, Any]:
@@ -198,10 +197,10 @@ def print_model_info(model: Module, name: Optional[str] = None) -> None:
logger.info(''.join(s))
def show_freeze_layers(model: Module, max_lines: int = 20) -> None:
def show_freeze_layers(model: Module, max_lines: Optional[int] = 20) -> None:
named_p = list(model.named_parameters())
for i, (n, p) in enumerate(named_p):
if i >= max_lines:
if max_lines is not None and i >= max_lines:
logger.info('...')
break
logger.info(f'{n}: requires_grad={p.requires_grad}')