by swsdsailor
1169 days ago
I got it to work with MPS by installing PyTorch with MPS support and then editing cli.py to allow the use of MPS.

Allow passing in --device="mps", i.e.:

    choices=["cuda", "cpu", "mps"]

Set kwargs:

    kwargs = {
        "torch_dtype": torch.float16
    }

Then add .to("mps") on line 98 (note the double asterisk on **kwargs):

    model = AutoModelForCausalLM.from_pretrained(model_name,
        low_cpu_mem_usage=True, **kwargs).to("mps")

Comment out:

    raise ValueError(f"Invalid device: {args.device}")

And change cuda to mps on line 80:

    if args.device == "mps":

I'm not sure it's working correctly, but at least it's a step. It's told me how to catch a duck, but it often falls into some "renewable energy" sequence. :D
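For anyone who would rather not comment out the error check, here is a minimal sketch of the same idea, assuming a Hugging Face transformers AutoModelForCausalLM checkpoint as in the snippet above. The names build_parser, model_kwargs and load_model are hypothetical and are not the actual contents of cli.py. Using float16 on cuda/mps and the default float32 on cpu is also my assumption. Instead of removing the ValueError, it adds "mps" to the accepted devices and checks torch.backends.mps.is_available() before loading:

```python
import argparse

# Devices accepted by the patched CLI; "mps" is the addition described above.
DEVICE_CHOICES = ["cuda", "cpu", "mps"]


def build_parser() -> argparse.ArgumentParser:
    """Hypothetical minimal parser mirroring the --device option in cli.py."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--device", choices=DEVICE_CHOICES, default="cpu")
    return parser


def model_kwargs(device: str) -> dict:
    """Extra from_pretrained kwargs per device (assumed: fp16 on GPU backends).

    The dtype is kept as a string name so this function runs without torch;
    load_model converts it to a real torch dtype.
    """
    if device not in DEVICE_CHOICES:
        # Keep the original check instead of commenting it out.
        raise ValueError(f"Invalid device: {device}")
    if device in ("cuda", "mps"):
        return {"torch_dtype": "float16"}
    return {}  # CPU: leave the default float32 weights


def load_model(model_name: str, device: str):
    """Load a causal LM onto the chosen device (needs torch + transformers)."""
    import torch
    from transformers import AutoModelForCausalLM

    # Fail early with a clear message if the backend is not usable.
    if device == "mps" and not torch.backends.mps.is_available():
        raise RuntimeError("MPS requested but not available in this PyTorch build")
    if device == "cuda" and not torch.cuda.is_available():
        raise RuntimeError("CUDA requested but not available")

    kwargs = model_kwargs(device)
    if "torch_dtype" in kwargs:
        # e.g. "float16" -> torch.float16
        kwargs["torch_dtype"] = getattr(torch, kwargs["torch_dtype"])

    model = AutoModelForCausalLM.from_pretrained(
        model_name, low_cpu_mem_usage=True, **kwargs
    )
    return model.to(device)
```

Usage would be something like load_model(model_name, build_parser().parse_args().device), where model_name is whatever checkpoint cli.py already uses. Keeping the dtype choice in a small pure function makes it easy to switch MPS back to float32 if the odd "renewable energy" loops turn out to be a half-precision issue, though I haven't confirmed that is the cause.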