mirror of
https://github.com/modelscope/modelscope.git
synced 2025-12-16 16:27:45 +01:00
remove requirements to avoid a low version of timm
This commit is contained in:
@@ -4,11 +4,14 @@ import os
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from modelscope import get_logger
|
||||
from modelscope.metainfo import Models
|
||||
from modelscope.models.base.base_torch_model import TorchModel
|
||||
from modelscope.models.builder import MODELS
|
||||
from modelscope.utils.constant import ModelFile, Tasks
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
|
||||
def normalize_fn(tensor, mean, std):
|
||||
"""Differentiable version of torchvision.functional.normalize"""
|
||||
@@ -41,10 +44,15 @@ class NormalizeByChannelMeanStd(nn.Module):
|
||||
class EasyRobustModel(TorchModel):
|
||||
|
||||
def __init__(self, model_dir: str, **kwargs):
|
||||
try:
|
||||
import easyrobust.models
|
||||
except ImportError as e:
|
||||
logger.error(
|
||||
'You are using `EasyRobustModel`, but this model requires `easyrobust`,'
|
||||
'please install it with command `pip install easyrobust`')
|
||||
raise e
|
||||
from timm.models import create_model
|
||||
from mmcls.datasets import ImageNet
|
||||
import modelscope.models.cv.image_classification.backbones
|
||||
from modelscope.utils.hub import read_config
|
||||
|
||||
super().__init__(model_dir)
|
||||
|
||||
@@ -8,7 +8,6 @@ control_ldm
|
||||
ddpm_guided_diffusion
|
||||
diffusers
|
||||
easydict
|
||||
easyrobust
|
||||
edit_distance
|
||||
face_alignment>=1.3.5
|
||||
fairscale>=0.4.1
|
||||
|
||||
Reference in New Issue
Block a user