Skip to content

Commit

Permalink
Add auth token (#19)
Browse files Browse the repository at this point in the history
* add auth token

add CLI argument and a little parsing. This allows using private models from the huggingface hub.

* better parsing

* linter fix

Co-authored-by: pommedeterresautee <pommedeterresautee@msn.com>
  • Loading branch information
sam-writer and pommedeterresautee authored Dec 8, 2021
1 parent 98c6903 commit e1b2f38
Showing 1 changed file with 19 additions and 2 deletions.
21 changes: 19 additions & 2 deletions src/transformer_deploy/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,14 @@ def main():
description="optimize and deploy transformers", formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument("-m", "--model", required=True, help="path to model or URL to Hugging Face Hub")
parser.add_argument(
"--auth-token",
default=None,
help=(
"HuggingFace Hub auth token. Set to `None` (default) for public models. "
"For private models, use `True` to use local cached token, or a string of your HF API token"
),
)
parser.add_argument(
"-b",
"--batch-size",
Expand Down Expand Up @@ -86,17 +94,26 @@ def main():

torch.manual_seed(args.seed)

if isinstance(args.auth_token, str) and args.auth_token.lower() in ["true", "t"]:
auth_token = True
elif isinstance(args.auth_token, str):
auth_token = args.auth_token
else:
auth_token = None

Path(args.output).mkdir(parents=True, exist_ok=True)
onnx_model_path = os.path.join(args.output, "model-original.onnx")
onnx_optim_fp16_path = os.path.join(args.output, "model.onnx")
tensorrt_path = os.path.join(args.output, "model.plan")

assert torch.cuda.is_available(), "CUDA is not available. Please check your CUDA installation"
tokenizer: PreTrainedTokenizer = AutoTokenizer.from_pretrained(args.model)
tokenizer: PreTrainedTokenizer = AutoTokenizer.from_pretrained(args.model, use_auth_token=auth_token)
input_names: List[str] = tokenizer.model_input_names
logging.info(f"axis: {input_names}")
include_token_ids = "token_type_ids" in input_names
model_pytorch: PreTrainedModel = AutoModelForSequenceClassification.from_pretrained(args.model)
model_pytorch: PreTrainedModel = AutoModelForSequenceClassification.from_pretrained(
args.model, use_auth_token=auth_token
)
model_pytorch.cuda()
model_pytorch.eval()

Expand Down

0 comments on commit e1b2f38

Please sign in to comment.