mirror of
https://github.com/modelscope/modelscope.git
synced 2025-12-21 18:49:23 +01:00
fix checkpoint, same device bug (#427)
This commit is contained in:
@@ -7,7 +7,7 @@ import sys
|
||||
from dataclasses import dataclass, field
|
||||
from functools import partial
|
||||
from types import MethodType
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
from typing import Any, Callable, Counter, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
@@ -152,10 +152,9 @@ def print_example(example: Dict[str, Any], tokenizer) -> None:
|
||||
print(f'[INPUT_IDS] {input_ids}')
|
||||
print(f'[INPUT] {tokenizer.decode(input_ids)}')
|
||||
print()
|
||||
n_mask = Counter(labels)[-100]
|
||||
print(f'[LABLES_IDS] {labels}')
|
||||
print(
|
||||
f'[LABLES] {tokenizer.decode([lb if lb != -100 else 0 for lb in labels])}'
|
||||
)
|
||||
print(f'[LABLES] <-100 * {n_mask}>{tokenizer.decode(labels[n_mask:])}')
|
||||
|
||||
|
||||
def data_collate_fn(batch: List[Dict[str, Any]], tokenizer) -> Dict[str, Any]:
|
||||
@@ -198,10 +197,10 @@ def print_model_info(model: Module, name: Optional[str] = None) -> None:
|
||||
logger.info(''.join(s))
|
||||
|
||||
|
||||
def show_freeze_layers(model: Module, max_lines: int = 20) -> None:
|
||||
def show_freeze_layers(model: Module, max_lines: Optional[int] = 20) -> None:
|
||||
named_p = list(model.named_parameters())
|
||||
for i, (n, p) in enumerate(named_p):
|
||||
if i >= max_lines:
|
||||
if max_lines is not None and i >= max_lines:
|
||||
logger.info('...')
|
||||
break
|
||||
logger.info(f'{n}: requires_grad={p.requires_grad}')
|
||||
|
||||
Reference in New Issue
Block a user