cmd/sb: consolidate format flag logic in one place
Use a custom flag to encapsulate the logic around selecting the format
in one place.
Change-Id: Ie488f0439496dbcdcde71c3c46c9ba548b4f29c6
diff --git a/cmd/sb/internal/writer/writer.go b/cmd/sb/internal/writer/writer.go
index cd8ae6c..71c66b6 100644
--- a/cmd/sb/internal/writer/writer.go
+++ b/cmd/sb/internal/writer/writer.go
@@ -26,8 +26,20 @@
Right
)
-// WriteTable formats the results as ASCII tables.
-func WriteTable(out io.Writer, columnNames []string, rs syncbase.ResultStream) error {
+type FormattingWriter interface {
+ Write(columnNames []string, rs syncbase.ResultStream) error
+}
+
+type tableWriter struct {
+ w io.Writer
+}
+
+func NewTableWriter(w io.Writer) FormattingWriter {
+ return &tableWriter{w}
+}
+
+// Write formats the results as ASCII tables.
+func (t *tableWriter) Write(columnNames []string, rs syncbase.ResultStream) error {
// Buffer the results so we can compute the column widths.
columnWidths := make([]int, len(columnNames))
for i, cName := range columnNames {
@@ -62,27 +74,27 @@
return rs.Err()
}
- writeBorder(out, columnWidths)
+ writeBorder(t.w, columnWidths)
sep := "| "
for i, cName := range columnNames {
- io.WriteString(out, fmt.Sprintf("%s%*s", sep, columnWidths[i], cName))
+ io.WriteString(t.w, fmt.Sprintf("%s%*s", sep, columnWidths[i], cName))
sep = " | "
}
- io.WriteString(out, " |\n")
- writeBorder(out, columnWidths)
+ io.WriteString(t.w, " |\n")
+ writeBorder(t.w, columnWidths)
for _, result := range results {
sep = "| "
for i, column := range result {
if justification[i] == Right {
- io.WriteString(out, fmt.Sprintf("%s%*s", sep, columnWidths[i], column))
+ io.WriteString(t.w, fmt.Sprintf("%s%*s", sep, columnWidths[i], column))
} else {
- io.WriteString(out, fmt.Sprintf("%s%-*s", sep, columnWidths[i], column))
+ io.WriteString(t.w, fmt.Sprintf("%s%-*s", sep, columnWidths[i], column))
}
sep = " | "
}
- io.WriteString(out, " |\n")
+ io.WriteString(t.w, " |\n")
}
- writeBorder(out, columnWidths)
+ writeBorder(t.w, columnWidths)
return nil
}
@@ -107,15 +119,24 @@
}
}
-// WriteCSV formats the results as CSV as specified by https://tools.ietf.org/html/rfc4180.
-func WriteCSV(out io.Writer, columnNames []string, rs syncbase.ResultStream, delimiter string) error {
+type csvWriter struct {
+ w io.Writer
+ delimiter string
+}
+
+func NewCSVWriter(w io.Writer, delimiter string) FormattingWriter {
+ return &csvWriter{w, delimiter}
+}
+
+// Write formats the results as CSV as specified by https://tools.ietf.org/html/rfc4180.
+func (c *csvWriter) Write(columnNames []string, rs syncbase.ResultStream) error {
delim := ""
for _, cName := range columnNames {
- str := doubleQuoteForCSV(cName, delimiter)
- io.WriteString(out, fmt.Sprintf("%s%s", delim, str))
- delim = delimiter
+ str := doubleQuoteForCSV(cName, c.delimiter)
+ io.WriteString(c.w, fmt.Sprintf("%s%s", delim, str))
+ delim = c.delimiter
}
- io.WriteString(out, "\n")
+ io.WriteString(c.w, "\n")
for rs.Advance() {
delim := ""
for i, n := 0, rs.ResultCount(); i != n; i++ {
@@ -123,11 +144,11 @@
if err := rs.Result(i, &column); err != nil {
return err
}
- str := doubleQuoteForCSV(toStringRaw(column, false), delimiter)
- io.WriteString(out, fmt.Sprintf("%s%s", delim, str))
- delim = delimiter
+ str := doubleQuoteForCSV(toStringRaw(column, false), c.delimiter)
+ io.WriteString(c.w, fmt.Sprintf("%s%s", delim, str))
+ delim = c.delimiter
}
- io.WriteString(out, "\n")
+ io.WriteString(c.w, "\n")
}
return rs.Err()
}
@@ -148,9 +169,17 @@
return str
}
-// WriteJson formats the result as a JSON array of arrays (rows) of values.
-func WriteJson(out io.Writer, columnNames []string, rs syncbase.ResultStream) error {
- io.WriteString(out, "[")
+type jsonWriter struct {
+ w io.Writer
+}
+
+func NewJSONWriter(w io.Writer) FormattingWriter {
+ return &jsonWriter{w}
+}
+
+// Write formats the result as a JSON array of arrays (rows) of values.
+func (j *jsonWriter) Write(columnNames []string, rs syncbase.ResultStream) error {
+ io.WriteString(j.w, "[")
jsonColNames := make([][]byte, len(columnNames))
for i, cName := range columnNames {
jsonCName, err := json.Marshal(cName)
@@ -161,7 +190,7 @@
}
bOpen := "{"
for rs.Advance() {
- io.WriteString(out, bOpen)
+ io.WriteString(j.w, bOpen)
linestart := "\n "
for i, n := 0, rs.ResultCount(); i != n; i++ {
var column interface{}
@@ -169,13 +198,13 @@
return err
}
str := toJson(column)
- io.WriteString(out, fmt.Sprintf("%s%s: %s", linestart, jsonColNames[i], str))
+ io.WriteString(j.w, fmt.Sprintf("%s%s: %s", linestart, jsonColNames[i], str))
linestart = ",\n "
}
- io.WriteString(out, "\n}")
+ io.WriteString(j.w, "\n}")
bOpen = ", {"
}
- io.WriteString(out, "]\n")
+ io.WriteString(j.w, "]\n")
return rs.Err()
}
diff --git a/cmd/sb/internal/writer/writer_test.go b/cmd/sb/internal/writer/writer_test.go
index fed9851..982e6ce 100644
--- a/cmd/sb/internal/writer/writer_test.go
+++ b/cmd/sb/internal/writer/writer_test.go
@@ -278,7 +278,7 @@
}
for _, test := range tests {
var b bytes.Buffer
- if err := writer.WriteTable(&b, test.columns, newResultStream(test.rows)); err != nil {
+ if err := writer.NewTableWriter(&b).Write(test.columns, newResultStream(test.rows)); err != nil {
t.Errorf("Unexpected error: %v", err)
continue
}
@@ -392,7 +392,7 @@
}
for _, test := range tests {
var b bytes.Buffer
- if err := writer.WriteCSV(&b, test.columns, newResultStream(test.rows), test.delimiter); err != nil {
+ if err := writer.NewCSVWriter(&b, test.delimiter).Write(test.columns, newResultStream(test.rows)); err != nil {
t.Errorf("Unexpected error: %v", err)
continue
}
@@ -558,7 +558,7 @@
}
for _, test := range tests {
var b bytes.Buffer
- if err := writer.WriteJson(&b, test.columns, newResultStream(test.rows)); err != nil {
+ if err := writer.NewJSONWriter(&b).Write(test.columns, newResultStream(test.rows)); err != nil {
t.Errorf("Unexpected error: %v", err)
continue
}
diff --git a/cmd/sb/shell.go b/cmd/sb/shell.go
index 43709ce..afa412a 100644
--- a/cmd/sb/shell.go
+++ b/cmd/sb/shell.go
@@ -40,15 +40,45 @@
`,
}
+type formatFlag string
+
+func (f *formatFlag) Set(s string) error {
+ for _, v := range []string{"table", "csv", "json"} {
+ if s == v {
+ *f = formatFlag(s)
+ return nil
+ }
+ }
+ return fmt.Errorf("unsupported -format %q", s)
+}
+
+func (f *formatFlag) String() string {
+ return string(*f)
+}
+
+func (f formatFlag) NewWriter(w io.Writer) writer.FormattingWriter {
+ switch f {
+ case "table":
+ return writer.NewTableWriter(w)
+ case "csv":
+ return writer.NewCSVWriter(w, flagCSVDelimiter)
+ case "json":
+ return writer.NewJSONWriter(w)
+ default:
+ panic("unexpected format:" + f)
+ }
+ return nil
+}
+
var (
- flagFormat string
+ flagFormat formatFlag = "table"
flagCSVDelimiter string
flagCreateIfNotExists bool
flagMakeDemoCollections bool
)
func init() {
- cmdSbShell.Flags.StringVar(&flagFormat, "format", "table", "Output format. 'table': human-readable table; 'csv': comma-separated values, use -csv-delimiter to control the delimiter; 'json': JSON objects.")
+ cmdSbShell.Flags.Var(&flagFormat, "format", "Output format. 'table': human-readable table; 'csv': comma-separated values, use -csv-delimiter to control the delimiter; 'json': JSON objects.")
cmdSbShell.Flags.StringVar(&flagCSVDelimiter, "csv-delimiter", ",", "Delimiter to use when printing data as CSV (e.g. \"\t\", \",\")")
cmdSbShell.Flags.BoolVar(&flagCreateIfNotExists, "create-missing", false, "Create the database if it does not exist.")
cmdSbShell.Flags.BoolVar(&flagMakeDemoCollections, "make-demo", false, "(Re)create demo collections in the database.")
@@ -304,26 +334,10 @@
}
func queryExec(ctx *context.T, w io.Writer, d syncbase.Database, q string) error {
- if columnNames, rs, err := d.Exec(ctx, q); err != nil {
+ columnNames, rs, err := d.Exec(ctx, q)
+ if err != nil {
off, msg := splitError(err)
return fmt.Errorf("\n%s\n%s^\n%d: %s", q, strings.Repeat(" ", int(off)), off+1, msg)
- } else {
- switch flagFormat {
- case "table":
- if err := writer.WriteTable(w, columnNames, rs); err != nil {
- return err
- }
- case "csv":
- if err := writer.WriteCSV(w, columnNames, rs, flagCSVDelimiter); err != nil {
- return err
- }
- case "json":
- if err := writer.WriteJson(w, columnNames, rs); err != nil {
- return err
- }
- default:
- panic(fmt.Sprintf("invalid format flag value: %v", flagFormat))
- }
}
- return nil
+ return flagFormat.NewWriter(w).Write(columnNames, rs)
}