diff --git a/modules/cli/src/main/scala/tfr/Cli.scala b/modules/cli/src/main/scala/tfr/Cli.scala index f4c2212..80ad653 100644 --- a/modules/cli/src/main/scala/tfr/Cli.scala +++ b/modules/cli/src/main/scala/tfr/Cli.scala @@ -34,16 +34,29 @@ object Cli { given ioContextShift as ContextShift[IO] = IO.contextShift(scala.concurrent.ExecutionContext.Implicits.global) + object Options { + enum RecordType(value: String) { + case Example extends RecordType("example") + case PredictionLog extends RecordType("prediction_log") + } + + given recordValueConverter as ValueConverter[RecordType] = + singleArgConverter[RecordType] { s => + RecordType.valueOf(s.split("_").fold("")(_ + _.capitalize)) + } + } + final class Options(arguments: Seq[String]) extends ScallopConf(arguments) { + import Options.{given _, _} printedName = "tfr" banner("""Usage: tfr [options] |TensorFlow TFRecord reader CLI tool |Options: |""".stripMargin) - val record: ScallopOption[String] = - opt[String]( - default = Some("example"), + val record: ScallopOption[RecordType] = + opt[RecordType]( + default = Some(RecordType.Example), descr = "Record type to be read { example | prediction_log }" ) val checkCrc32 = opt[Boolean]( @@ -55,7 +68,7 @@ object Cli { default = Some(false), descr = "Output examples as flat JSON objects" ) - val files =trailArg[List[String]]( + val files = trailArg[List[String]]( required = false, descr = "files? | STDIN", default = Some(List.empty) @@ -65,19 +78,18 @@ object Cli { def main(args: Array[String]): Unit = { val options = Options(ArraySeq.unsafeWrapArray(args)) - println(options.files()) val resources = options.files() match case Nil => Resources.stdin[IO] :: Nil case l => l.iterator.map(Resources.file[IO]).toList options.record() match - case "example" => + case Options.RecordType.Example => given exampleEncoder as Encoder[Example] = if options.flat() then flat.exampleEncoder else tfr.instances.example.exampleEncoder run[Example](options, resources) - case "prediction_log" => + case Options.RecordType.PredictionLog => given predictionLogEncoder as Encoder[PredictionLog] = tfr.instances.prediction.predictionLogEncoder