cmd/sb: consolidate format flag logic in one place

Use a custom flag to encapsulate the logic around selecting the format
in one place.

Change-Id: Ie488f0439496dbcdcde71c3c46c9ba548b4f29c6
diff --git a/cmd/sb/internal/writer/writer.go b/cmd/sb/internal/writer/writer.go
index cd8ae6c..71c66b6 100644
--- a/cmd/sb/internal/writer/writer.go
+++ b/cmd/sb/internal/writer/writer.go
@@ -26,8 +26,20 @@
 	Right
 )
 
-// WriteTable formats the results as ASCII tables.
-func WriteTable(out io.Writer, columnNames []string, rs syncbase.ResultStream) error {
+type FormattingWriter interface {
+	Write(columnNames []string, rs syncbase.ResultStream) error
+}
+
+type tableWriter struct {
+	w io.Writer
+}
+
+func NewTableWriter(w io.Writer) FormattingWriter {
+	return &tableWriter{w}
+}
+
+// Write formats the results as ASCII tables.
+func (t *tableWriter) Write(columnNames []string, rs syncbase.ResultStream) error {
 	// Buffer the results so we can compute the column widths.
 	columnWidths := make([]int, len(columnNames))
 	for i, cName := range columnNames {
@@ -62,27 +74,27 @@
 		return rs.Err()
 	}
 
-	writeBorder(out, columnWidths)
+	writeBorder(t.w, columnWidths)
 	sep := "| "
 	for i, cName := range columnNames {
-		io.WriteString(out, fmt.Sprintf("%s%*s", sep, columnWidths[i], cName))
+		io.WriteString(t.w, fmt.Sprintf("%s%*s", sep, columnWidths[i], cName))
 		sep = " | "
 	}
-	io.WriteString(out, " |\n")
-	writeBorder(out, columnWidths)
+	io.WriteString(t.w, " |\n")
+	writeBorder(t.w, columnWidths)
 	for _, result := range results {
 		sep = "| "
 		for i, column := range result {
 			if justification[i] == Right {
-				io.WriteString(out, fmt.Sprintf("%s%*s", sep, columnWidths[i], column))
+				io.WriteString(t.w, fmt.Sprintf("%s%*s", sep, columnWidths[i], column))
 			} else {
-				io.WriteString(out, fmt.Sprintf("%s%-*s", sep, columnWidths[i], column))
+				io.WriteString(t.w, fmt.Sprintf("%s%-*s", sep, columnWidths[i], column))
 			}
 			sep = " | "
 		}
-		io.WriteString(out, " |\n")
+		io.WriteString(t.w, " |\n")
 	}
-	writeBorder(out, columnWidths)
+	writeBorder(t.w, columnWidths)
 	return nil
 }
 
@@ -107,15 +119,24 @@
 	}
 }
 
-// WriteCSV formats the results as CSV as specified by https://tools.ietf.org/html/rfc4180.
-func WriteCSV(out io.Writer, columnNames []string, rs syncbase.ResultStream, delimiter string) error {
+type csvWriter struct {
+	w         io.Writer
+	delimiter string
+}
+
+func NewCSVWriter(w io.Writer, delimiter string) FormattingWriter {
+	return &csvWriter{w, delimiter}
+}
+
+// Write formats the results as CSV as specified by https://tools.ietf.org/html/rfc4180.
+func (c *csvWriter) Write(columnNames []string, rs syncbase.ResultStream) error {
 	delim := ""
 	for _, cName := range columnNames {
-		str := doubleQuoteForCSV(cName, delimiter)
-		io.WriteString(out, fmt.Sprintf("%s%s", delim, str))
-		delim = delimiter
+		str := doubleQuoteForCSV(cName, c.delimiter)
+		io.WriteString(c.w, fmt.Sprintf("%s%s", delim, str))
+		delim = c.delimiter
 	}
-	io.WriteString(out, "\n")
+	io.WriteString(c.w, "\n")
 	for rs.Advance() {
 		delim := ""
 		for i, n := 0, rs.ResultCount(); i != n; i++ {
@@ -123,11 +144,11 @@
 			if err := rs.Result(i, &column); err != nil {
 				return err
 			}
-			str := doubleQuoteForCSV(toStringRaw(column, false), delimiter)
-			io.WriteString(out, fmt.Sprintf("%s%s", delim, str))
-			delim = delimiter
+			str := doubleQuoteForCSV(toStringRaw(column, false), c.delimiter)
+			io.WriteString(c.w, fmt.Sprintf("%s%s", delim, str))
+			delim = c.delimiter
 		}
-		io.WriteString(out, "\n")
+		io.WriteString(c.w, "\n")
 	}
 	return rs.Err()
 }
@@ -148,9 +169,17 @@
 	return str
 }
 
-// WriteJson formats the result as a JSON array of arrays (rows) of values.
-func WriteJson(out io.Writer, columnNames []string, rs syncbase.ResultStream) error {
-	io.WriteString(out, "[")
+type jsonWriter struct {
+	w io.Writer
+}
+
+func NewJSONWriter(w io.Writer) FormattingWriter {
+	return &jsonWriter{w}
+}
+
+// Write formats the result as a JSON array of arrays (rows) of values.
+func (j *jsonWriter) Write(columnNames []string, rs syncbase.ResultStream) error {
+	io.WriteString(j.w, "[")
 	jsonColNames := make([][]byte, len(columnNames))
 	for i, cName := range columnNames {
 		jsonCName, err := json.Marshal(cName)
@@ -161,7 +190,7 @@
 	}
 	bOpen := "{"
 	for rs.Advance() {
-		io.WriteString(out, bOpen)
+		io.WriteString(j.w, bOpen)
 		linestart := "\n  "
 		for i, n := 0, rs.ResultCount(); i != n; i++ {
 			var column interface{}
@@ -169,13 +198,13 @@
 				return err
 			}
 			str := toJson(column)
-			io.WriteString(out, fmt.Sprintf("%s%s: %s", linestart, jsonColNames[i], str))
+			io.WriteString(j.w, fmt.Sprintf("%s%s: %s", linestart, jsonColNames[i], str))
 			linestart = ",\n  "
 		}
-		io.WriteString(out, "\n}")
+		io.WriteString(j.w, "\n}")
 		bOpen = ", {"
 	}
-	io.WriteString(out, "]\n")
+	io.WriteString(j.w, "]\n")
 	return rs.Err()
 }
 
diff --git a/cmd/sb/internal/writer/writer_test.go b/cmd/sb/internal/writer/writer_test.go
index fed9851..982e6ce 100644
--- a/cmd/sb/internal/writer/writer_test.go
+++ b/cmd/sb/internal/writer/writer_test.go
@@ -278,7 +278,7 @@
 	}
 	for _, test := range tests {
 		var b bytes.Buffer
-		if err := writer.WriteTable(&b, test.columns, newResultStream(test.rows)); err != nil {
+		if err := writer.NewTableWriter(&b).Write(test.columns, newResultStream(test.rows)); err != nil {
 			t.Errorf("Unexpected error: %v", err)
 			continue
 		}
@@ -392,7 +392,7 @@
 	}
 	for _, test := range tests {
 		var b bytes.Buffer
-		if err := writer.WriteCSV(&b, test.columns, newResultStream(test.rows), test.delimiter); err != nil {
+		if err := writer.NewCSVWriter(&b, test.delimiter).Write(test.columns, newResultStream(test.rows)); err != nil {
 			t.Errorf("Unexpected error: %v", err)
 			continue
 		}
@@ -558,7 +558,7 @@
 	}
 	for _, test := range tests {
 		var b bytes.Buffer
-		if err := writer.WriteJson(&b, test.columns, newResultStream(test.rows)); err != nil {
+		if err := writer.NewJSONWriter(&b).Write(test.columns, newResultStream(test.rows)); err != nil {
 			t.Errorf("Unexpected error: %v", err)
 			continue
 		}
diff --git a/cmd/sb/shell.go b/cmd/sb/shell.go
index 43709ce..afa412a 100644
--- a/cmd/sb/shell.go
+++ b/cmd/sb/shell.go
@@ -40,15 +40,45 @@
 `,
 }
 
+type formatFlag string
+
+func (f *formatFlag) Set(s string) error {
+	for _, v := range []string{"table", "csv", "json"} {
+		if s == v {
+			*f = formatFlag(s)
+			return nil
+		}
+	}
+	return fmt.Errorf("unsupported -format %q", s)
+}
+
+func (f *formatFlag) String() string {
+	return string(*f)
+}
+
+func (f formatFlag) NewWriter(w io.Writer) writer.FormattingWriter {
+	switch f {
+	case "table":
+		return writer.NewTableWriter(w)
+	case "csv":
+		return writer.NewCSVWriter(w, flagCSVDelimiter)
+	case "json":
+		return writer.NewJSONWriter(w)
+	default:
+		panic("unexpected format:" + f)
+	}
+	return nil
+}
+
 var (
-	flagFormat              string
+	flagFormat              formatFlag = "table"
 	flagCSVDelimiter        string
 	flagCreateIfNotExists   bool
 	flagMakeDemoCollections bool
 )
 
 func init() {
-	cmdSbShell.Flags.StringVar(&flagFormat, "format", "table", "Output format. 'table': human-readable table; 'csv': comma-separated values, use -csv-delimiter to control the delimiter; 'json': JSON objects.")
+	cmdSbShell.Flags.Var(&flagFormat, "format", "Output format. 'table': human-readable table; 'csv': comma-separated values, use -csv-delimiter to control the delimiter; 'json': JSON objects.")
 	cmdSbShell.Flags.StringVar(&flagCSVDelimiter, "csv-delimiter", ",", "Delimiter to use when printing data as CSV (e.g. \"\t\", \",\")")
 	cmdSbShell.Flags.BoolVar(&flagCreateIfNotExists, "create-missing", false, "Create the database if it does not exist.")
 	cmdSbShell.Flags.BoolVar(&flagMakeDemoCollections, "make-demo", false, "(Re)create demo collections in the database.")
@@ -304,26 +334,10 @@
 }
 
 func queryExec(ctx *context.T, w io.Writer, d syncbase.Database, q string) error {
-	if columnNames, rs, err := d.Exec(ctx, q); err != nil {
+	columnNames, rs, err := d.Exec(ctx, q)
+	if err != nil {
 		off, msg := splitError(err)
 		return fmt.Errorf("\n%s\n%s^\n%d: %s", q, strings.Repeat(" ", int(off)), off+1, msg)
-	} else {
-		switch flagFormat {
-		case "table":
-			if err := writer.WriteTable(w, columnNames, rs); err != nil {
-				return err
-			}
-		case "csv":
-			if err := writer.WriteCSV(w, columnNames, rs, flagCSVDelimiter); err != nil {
-				return err
-			}
-		case "json":
-			if err := writer.WriteJson(w, columnNames, rs); err != nil {
-				return err
-			}
-		default:
-			panic(fmt.Sprintf("invalid format flag value: %v", flagFormat))
-		}
 	}
-	return nil
+	return flagFormat.NewWriter(w).Write(columnNames, rs)
 }