ci : switch from pyright to ty (#20826)
* type fixes * switch to ty * tweak rules * tweak more rules * more tweaks * final tweak * use common import-not-found rule
This commit is contained in:
@@ -199,10 +199,13 @@ class LoraTorchTensor:
|
||||
kwargs = {}
|
||||
|
||||
if func is torch.permute:
|
||||
assert len(args)
|
||||
return type(args[0]).permute(*args, **kwargs)
|
||||
elif func is torch.reshape:
|
||||
assert len(args)
|
||||
return type(args[0]).reshape(*args, **kwargs)
|
||||
elif func is torch.stack:
|
||||
assert len(args)
|
||||
assert isinstance(args[0], Sequence)
|
||||
dim = kwargs.get("dim", 0)
|
||||
assert dim == 0
|
||||
@@ -211,6 +214,7 @@ class LoraTorchTensor:
|
||||
torch.stack([b._lora_B for b in args[0]], dim),
|
||||
)
|
||||
elif func is torch.cat:
|
||||
assert len(args)
|
||||
assert isinstance(args[0], Sequence)
|
||||
dim = kwargs.get("dim", 0)
|
||||
assert dim == 0
|
||||
@@ -362,7 +366,7 @@ if __name__ == '__main__':
|
||||
logger.error(f"Model {hparams['architectures'][0]} is not supported")
|
||||
sys.exit(1)
|
||||
|
||||
class LoraModel(model_class):
|
||||
class LoraModel(model_class): # ty: ignore[unsupported-base]
|
||||
model_arch = model_class.model_arch
|
||||
|
||||
lora_alpha: float
|
||||
|
||||
Reference in New Issue
Block a user