Skip to content

Commit

Permalink
Make record type enum
Browse files Browse the repository at this point in the history
  • Loading branch information
regadas committed Aug 19, 2020
1 parent 7266306 commit d512ba3
Showing 1 changed file with 19 additions and 7 deletions.
26 changes: 19 additions & 7 deletions modules/cli/src/main/scala/tfr/Cli.scala
Original file line number Diff line number Diff line change
Expand Up @@ -34,16 +34,29 @@ object Cli {
given ioContextShift as ContextShift[IO] =
IO.contextShift(scala.concurrent.ExecutionContext.Implicits.global)

object Options {
enum RecordType(value: String) {
case Example extends RecordType("example")
case PredictionLog extends RecordType("prediction_log")
}

given recordValueConverter as ValueConverter[RecordType] =
singleArgConverter[RecordType] { s =>
RecordType.valueOf(s.split("_").fold("")(_ + _.capitalize))
}
}

final class Options(arguments: Seq[String]) extends ScallopConf(arguments) {
import Options.{given _, _}
printedName = "tfr"
banner("""Usage: tfr [options] <files? | STDIN>
|TensorFlow TFRecord reader CLI tool
|Options:
|""".stripMargin)

val record: ScallopOption[String] =
opt[String](
default = Some("example"),
val record: ScallopOption[RecordType] =
opt[RecordType](
default = Some(RecordType.Example),
descr = "Record type to be read { example | prediction_log }"
)
val checkCrc32 = opt[Boolean](
Expand All @@ -55,7 +68,7 @@ object Cli {
default = Some(false),
descr = "Output examples as flat JSON objects"
)
val files =trailArg[List[String]](
val files = trailArg[List[String]](
required = false,
descr = "files? | STDIN",
default = Some(List.empty)
Expand All @@ -65,19 +78,18 @@ object Cli {

def main(args: Array[String]): Unit = {
val options = Options(ArraySeq.unsafeWrapArray(args))
println(options.files())
val resources = options.files() match
case Nil => Resources.stdin[IO] :: Nil
case l => l.iterator.map(Resources.file[IO]).toList

options.record() match
case "example" =>
case Options.RecordType.Example =>
given exampleEncoder as Encoder[Example] =
if options.flat() then flat.exampleEncoder
else tfr.instances.example.exampleEncoder

run[Example](options, resources)
case "prediction_log" =>
case Options.RecordType.PredictionLog =>
given predictionLogEncoder as Encoder[PredictionLog] =
tfr.instances.prediction.predictionLogEncoder

Expand Down

0 comments on commit d512ba3

Please sign in to comment.