diff --git a/README.md b/README.md index 576439da..d1eab395 100644 --- a/README.md +++ b/README.md @@ -174,6 +174,7 @@ switches are most important to you to have implemented next in the new sqlcmd. - `:Connect` now has an optional `-G` parameter to select one of the authentication methods for Azure SQL Database - `SqlAuthentication`, `ActiveDirectoryDefault`, `ActiveDirectoryIntegrated`, `ActiveDirectoryServicePrincipal`, `ActiveDirectoryManagedIdentity`, `ActiveDirectoryPassword`. If `-G` is not provided, either Integrated security or SQL Authentication will be used, dependent on the presence of a `-U` username parameter. - The new `--driver-logging-level` command line parameter allows you to see traces from the `go-mssqldb` client driver. Use `64` to see all traces. - Sqlcmd can now print results using a vertical format. Use the new `--vertical` command line option to set it. It's also controlled by the `SQLCMDFORMAT` scripting variable. +- Sqlcmd defaults to a horizontal output format (space separated, no borders). To use the new ASCII table format, use the new `--ascii` command line option or set `SQLCMDFORMAT` to `ascii` (`-v SQLCMDFORMAT=ascii`). Note that when using the ASCII table format, individual column widths are determined by the content, but the `SQLCMDCOLWIDTH` variable and the `-w` parameter are still used to control the maximum screen width, determining when columns wrap into separate table segments. The following variables are ignored: `SQLCMDMAXFIXEDTYPEWIDTH`, `SQLCMDMAXVARTYPEWIDTH`, and `SQLCMDHEADERS`. ``` 1> select session_id, client_interface_name, program_name from sys.dm_exec_sessions where session_id=@@spid diff --git a/cmd/sqlcmd/sqlcmd.go b/cmd/sqlcmd/sqlcmd.go index 5abc0860..b27ecb0b 100644 --- a/cmd/sqlcmd/sqlcmd.go +++ b/cmd/sqlcmd/sqlcmd.go @@ -84,7 +84,8 @@ type SQLCmdArguments struct { TraceFile string ServerNameOverride string // Keep Help at the end of the list - Help bool + Help bool + Ascii bool } func (args *SQLCmdArguments) useEnvVars() bool { @@ -151,6 +152,8 @@ func (a *SQLCmdArguments) Validate(c *cobra.Command) (err error) { switch { case len(a.InputFile) > 0 && (len(a.Query) > 0 || len(a.InitialQuery) > 0): err = mutuallyExclusiveError("i", `-Q/-q`) + case a.Vertical && a.Ascii: + err = mutuallyExclusiveError("--vertical", "--ascii") case a.UseTrustedConnection && (len(a.UserName) > 0 || len(a.Password) > 0): err = mutuallyExclusiveError("-E", `-U/-P`) case a.UseAad && len(a.AuthenticationMethod) > 0: @@ -465,6 +468,7 @@ func setFlags(rootCmd *cobra.Command, args *SQLCmdArguments) { rootCmd.Flags().BoolVarP(&args.DisableVariableSubstitution, "disable-variable-substitution", "x", false, localizer.Sprintf("Causes sqlcmd to ignore scripting variables. This parameter is useful when a script contains many %s statements that may contain strings that have the same format as regular variables, such as $(variable_name)", localizer.InsertKeyword)) var variables map[string]string rootCmd.Flags().StringToStringVarP(&args.Variables, "variables", "v", variables, localizer.Sprintf("Creates a sqlcmd scripting variable that can be used in a sqlcmd script. Enclose the value in quotation marks if the value contains spaces. You can specify multiple var=values values. If there are errors in any of the values specified, sqlcmd generates an error message and then exits")) + rootCmd.Flags().IntVarP(&args.PacketSize, "packet-size", "a", 0, localizer.Sprintf("Requests a packet of a different size. This option sets the sqlcmd scripting variable %s. packet_size must be a value between 512 and 32767. The default = 4096. A larger packet size can enhance performance for execution of scripts that have lots of SQL statements between %s commands. You can request a larger packet size. However, if the request is denied, sqlcmd uses the server default for packet size", localizer.PacketSizeVar, localizer.BatchTerminatorGo)) rootCmd.Flags().IntVarP(&args.LoginTimeout, "login-timeOut", "l", -1, localizer.Sprintf("Specifies the number of seconds before a sqlcmd login to the go-mssqldb driver times out when you try to connect to a server. This option sets the sqlcmd scripting variable %s. The default value is 30. 0 means infinite", localizer.LoginTimeOutVar)) rootCmd.Flags().StringVarP(&args.WorkstationName, "workstation-name", "H", "", localizer.Sprintf("This option sets the sqlcmd scripting variable %s. The workstation name is listed in the hostname column of the sys.sysprocesses catalog view and can be returned using the stored procedure sp_who. If this option is not specified, the default is the current computer name. This name can be used to identify different sqlcmd sessions", localizer.WorkstationVar)) @@ -477,6 +481,8 @@ func setFlags(rootCmd *cobra.Command, args *SQLCmdArguments) { // Can't use NoOptDefVal until this fix: https://github.com/spf13/cobra/issues/866 //rootCmd.Flags().Lookup(encryptConnection).NoOptDefVal = "true" rootCmd.Flags().BoolVarP(&args.Vertical, "vertical", "", false, localizer.Sprintf("Prints the output in vertical format. This option sets the sqlcmd scripting variable %s to '%s'. The default is false", sqlcmd.SQLCMDFORMAT, "vert")) + rootCmd.Flags().BoolVarP(&args.Ascii, "ascii", "", false, localizer.Sprintf("Prints the output in ASCII table format. This option sets the sqlcmd scripting variable %s to '%s'. The default is false", sqlcmd.SQLCMDFORMAT, "ascii")) + _ = rootCmd.Flags().IntP(errorsToStderr, "r", -1, localizer.Sprintf("%s Redirects error messages with severity >= 11 output to stderr. Pass 1 to to redirect all errors including PRINT.", "-r[0 | 1]")) rootCmd.Flags().IntVar(&args.DriverLoggingLevel, "driver-logging-level", 0, localizer.Sprintf("Level of mssql driver messages to print")) rootCmd.Flags().BoolVarP(&args.ExitOnError, "exit-on-error", "b", false, localizer.Sprintf("Specifies that sqlcmd exits and returns a %s value when an error occurs", localizer.DosErrorLevel)) @@ -713,7 +719,10 @@ func setVars(vars *sqlcmd.Variables, args *SQLCmdArguments) { if a.Vertical { return "vert" } - return "horizontal" + if a.Ascii { + return "ascii" + } + return "" }, } for varname, set := range varmap { @@ -862,7 +871,7 @@ func run(vars *sqlcmd.Variables, args *SQLCmdArguments) (int, error) { } s.Connect = &connectConfig - s.Format = sqlcmd.NewSQLCmdDefaultFormatter(args.TrimSpaces, args.getControlCharacterBehavior()) + s.Format = sqlcmd.NewSQLCmdDefaultFormatter(vars, args.TrimSpaces, args.getControlCharacterBehavior()) if args.OutputFile != "" { err = s.RunCommand(s.Cmd["OUT"], []string{args.OutputFile}) if err != nil { diff --git a/internal/sql/mssql.go b/internal/sql/mssql.go index 442e514a..961846cf 100644 --- a/internal/sql/mssql.go +++ b/internal/sql/mssql.go @@ -32,7 +32,7 @@ func (m *mssql) Connect( m.console = nil } m.sqlcmd = sqlcmd.New(m.console, "", v) - m.sqlcmd.Format = sqlcmd.NewSQLCmdDefaultFormatter(false, sqlcmd.ControlIgnore) + m.sqlcmd.Format = sqlcmd.NewSQLCmdDefaultFormatter(v, false, sqlcmd.ControlIgnore) connect := sqlcmd.ConnectSettings{ ServerName: fmt.Sprintf( "%s,%#v", diff --git a/pkg/sqlcmd/commands_test.go b/pkg/sqlcmd/commands_test.go index 56d509da..76c509a8 100644 --- a/pkg/sqlcmd/commands_test.go +++ b/pkg/sqlcmd/commands_test.go @@ -242,7 +242,7 @@ func TestListCommandUsesColorizer(t *testing.T) { func TestListColorPrintsStyleSamples(t *testing.T) { vars := InitializeVariables(false) s := New(nil, "", vars) - s.Format = NewSQLCmdDefaultFormatter(false, ControlIgnore) + s.Format = NewSQLCmdDefaultFormatter(vars, false, ControlIgnore) // force colorizer on s.colorizer = color.New(true) buf := &memoryBuffer{buf: new(bytes.Buffer)} diff --git a/pkg/sqlcmd/format.go b/pkg/sqlcmd/format.go index 55bd2e25..71bb00e2 100644 --- a/pkg/sqlcmd/format.go +++ b/pkg/sqlcmd/format.go @@ -87,8 +87,12 @@ type sqlCmdFormatterType struct { xml bool } -// NewSQLCmdDefaultFormatter returns a Formatter that mimics the original ODBC-based sqlcmd formatter -func NewSQLCmdDefaultFormatter(removeTrailingSpaces bool, ccb ControlCharacterBehavior) Formatter { +// NewSQLCmdDefaultFormatter returns a Formatter based on the configuration. +// It returns an ASCII formatter if the format is set to "ascii", otherwise it returns a formatter that mimics the original ODBC-based sqlcmd formatter. +func NewSQLCmdDefaultFormatter(vars *Variables, removeTrailingSpaces bool, ccb ControlCharacterBehavior) Formatter { + if vars.Format() == "ascii" { + return NewSQLCmdAsciiFormatter(vars, removeTrailingSpaces, ccb) + } return &sqlCmdFormatterType{ removeTrailingSpaces: removeTrailingSpaces, format: "horizontal", diff --git a/pkg/sqlcmd/format_ascii.go b/pkg/sqlcmd/format_ascii.go new file mode 100644 index 00000000..5f3a5130 --- /dev/null +++ b/pkg/sqlcmd/format_ascii.go @@ -0,0 +1,201 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +package sqlcmd + +import ( + "database/sql" + "os" + "strings" + "unicode/utf8" + + "github.com/microsoft/go-sqlcmd/internal/color" + "golang.org/x/term" +) + +type asciiFormatter struct { + *sqlCmdFormatterType + rows [][]string + colWidths []int +} + +func NewSQLCmdAsciiFormatter(vars *Variables, removeTrailingSpaces bool, ccb ControlCharacterBehavior) Formatter { + return &asciiFormatter{ + sqlCmdFormatterType: &sqlCmdFormatterType{ + removeTrailingSpaces: removeTrailingSpaces, + format: "ascii", + colorizer: color.New(false), + ccb: ccb, + vars: vars, + }, + } +} + +func (f *asciiFormatter) BeginResultSet(cols []*sql.ColumnType) { + f.sqlCmdFormatterType.BeginResultSet(cols) + f.rows = make([][]string, 0) + f.colWidths = make([]int, len(f.columnDetails)) + for i, c := range f.columnDetails { + f.colWidths[i] = utf8.RuneCountInString(c.col.Name()) + } +} + +func (f *asciiFormatter) AddRow(row *sql.Rows) string { + values, err := f.scanRow(row) + if err != nil { + f.mustWriteErr(err.Error()) + return "" + } + f.rows = append(f.rows, values) + f.rowcount++ + for i, val := range values { + if i < len(f.colWidths) { + l := utf8.RuneCountInString(val) + if l > f.colWidths[i] { + f.colWidths[i] = l + } + } + } + if len(values) > 0 { + return values[0] + } + return "" +} + +func (f *asciiFormatter) EndResultSet() { + if len(f.rows) > 0 || len(f.columnDetails) > 0 { + f.printAsciiTable() + } + f.rows = nil + f.colWidths = nil +} + +func (f *asciiFormatter) printAsciiTable() { + maxWidth := int(f.vars.ScreenWidth()) + if maxWidth <= 0 { + if w, _, err := term.GetSize(int(os.Stdout.Fd())); err == nil { + maxWidth = w - 1 + } else { + maxWidth = 1000000 + } + } + + // Limit column width to maxWidth - 4 (border + padding) + // 1 (left border) + 1 (space) + content + 1 (space) + 1 (right border) = content + 4 + maxColContentWidth := maxWidth - 4 + if maxColContentWidth < 1 { + maxColContentWidth = 1 + } + + for i := range f.colWidths { + if f.colWidths[i] > maxColContentWidth { + f.colWidths[i] = maxColContentWidth + } + } + + totalWidth := 1 + for _, w := range f.colWidths { + totalWidth += w + 3 + } + + if totalWidth <= maxWidth { + f.printTableSegment(f.colWidths, 0, len(f.colWidths)-1) + } else { + startCol := 0 + for startCol < len(f.colWidths) { + currentWidth := 1 + endCol := startCol + for endCol < len(f.colWidths) { + w := f.colWidths[endCol] + 3 + if currentWidth+w > maxWidth { + break + } + currentWidth += w + endCol++ + } + + if endCol == startCol { + endCol++ + } + + f.printTableSegment(f.colWidths, startCol, endCol-1) + startCol = endCol + } + } +} + +func (f *asciiFormatter) printTableSegment(colWidths []int, startCol, endCol int) { + if startCol > endCol { + return + } + + sep := f.vars.ColumnSeparator() + if sep == "" || sep == " " { + sep = "|" + } + + divider := "+" + for i := startCol; i <= endCol; i++ { + divider += strings.Repeat("-", colWidths[i]+2) + "+" + } + f.writeOut(divider+SqlcmdEol, color.TextTypeSeparator) + + header := sep + for i := startCol; i <= endCol; i++ { + name := f.columnDetails[i].col.Name() + header += " " + padRightString(name, colWidths[i]) + " " + sep + } + f.writeOut(header+SqlcmdEol, color.TextTypeHeader) + f.writeOut(divider+SqlcmdEol, color.TextTypeSeparator) + + for _, row := range f.rows { + line := sep + for i := startCol; i <= endCol; i++ { + val := "" + if i < len(row) { + val = row[i] + } + isNumeric := isNumericType(f.columnDetails[i].col.DatabaseTypeName()) + + if isNumeric { + line += " " + padLeftString(val, colWidths[i]) + " " + sep + } else { + line += " " + padRightString(val, colWidths[i]) + " " + sep + } + } + f.writeOut(line+SqlcmdEol, color.TextTypeCell) + } + f.writeOut(divider+SqlcmdEol, color.TextTypeSeparator) +} + +func padRightString(s string, width int) string { + l := utf8.RuneCountInString(s) + if l > width { + r := []rune(s) + if width >= 3 { + return string(r[:width-3]) + "..." + } + return string(r[:width]) + } + return s + strings.Repeat(" ", width-l) +} + +func padLeftString(s string, width int) string { + l := utf8.RuneCountInString(s) + if l > width { + r := []rune(s) + if width >= 3 { + return string(r[:width-3]) + "..." + } + return string(r[:width]) + } + return strings.Repeat(" ", width-l) + s +} + +func isNumericType(typeName string) bool { + switch typeName { + case "TINYINT", "SMALLINT", "INT", "BIGINT", "REAL", "FLOAT", "DECIMAL", "NUMERIC", "MONEY", "SMALLMONEY": + return true + } + return false +} diff --git a/pkg/sqlcmd/format_ascii_test.go b/pkg/sqlcmd/format_ascii_test.go new file mode 100644 index 00000000..c339fead --- /dev/null +++ b/pkg/sqlcmd/format_ascii_test.go @@ -0,0 +1,140 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +package sqlcmd + +import ( + "bytes" + "database/sql" + "reflect" + "testing" + "unsafe" + + "github.com/microsoft/go-sqlcmd/internal/color" + "github.com/stretchr/testify/assert" +) + +func setColumnInfo(c *sql.ColumnType, name string, dbType string) { + v := reflect.ValueOf(c).Elem() + fName := v.FieldByName("name") + if fName.IsValid() { + reflect.NewAt(fName.Type(), unsafe.Pointer(fName.UnsafeAddr())).Elem().SetString(name) + } + fType := v.FieldByName("databaseType") + if fType.IsValid() { + reflect.NewAt(fType.Type(), unsafe.Pointer(fType.UnsafeAddr())).Elem().SetString(dbType) + } +} + +func TestAsciiFormatter(t *testing.T) { + vars := InitializeVariables(false) + vars.Set(SQLCMDFORMAT, "ascii") + + buf := new(bytes.Buffer) + f := &asciiFormatter{ + sqlCmdFormatterType: &sqlCmdFormatterType{ + vars: vars, + out: buf, + colorizer: color.New(false), + format: "ascii", + }, + rows: [][]string{{"1", "test"}}, + colWidths: []int{2, 4}, + } + + // Mock column details + f.columnDetails = make([]columnDetail, 2) + setColumnInfo(&f.columnDetails[0].col, "id", "INT") + setColumnInfo(&f.columnDetails[1].col, "name", "VARCHAR") + + f.printAsciiTable() + + expected := `+----+------+` + SqlcmdEol + + `| id | name |` + SqlcmdEol + + `+----+------+` + SqlcmdEol + + `| 1 | test |` + SqlcmdEol + + `+----+------+` + SqlcmdEol + + assert.Equal(t, expected, buf.String()) +} + +func TestAsciiFormatterWrapping(t *testing.T) { + s, buf := setupSqlCmdWithMemoryOutput(t) + if s.db == nil { + t.Skip("No database connection available") + } + defer func() { + assert.NoError(t, buf.Close()) + }() + + s.vars.Set(SQLCMDFORMAT, "ascii") + s.vars.Set(SQLCMDCOLWIDTH, "20") // Small width to force wrapping + s.Format = NewSQLCmdDefaultFormatter(s.vars, false, ControlIgnore) + + err := runSqlCmd(t, s, []string{"select 1 as id, 'test' as name, '0123456789' as descr", "GO"}) + assert.NoError(t, err, "runSqlCmd returned error") + + expectedPart1 := `+----+------+` + SqlcmdEol + + `| id | name |` + SqlcmdEol + + `+----+------+` + SqlcmdEol + + `| 1 | test |` + SqlcmdEol + + `+----+------+` + SqlcmdEol + + expectedPart2 := `+------------+` + SqlcmdEol + + `| descr |` + SqlcmdEol + + `+------------+` + SqlcmdEol + + `| 0123456789 |` + SqlcmdEol + + `+------------+` + SqlcmdEol + + `(1 row affected)` + SqlcmdEol + + assert.Contains(t, buf.buf.String(), expectedPart1) + assert.Contains(t, buf.buf.String(), expectedPart2) +} + +func TestAsciiFormatterTruncation(t *testing.T) { + vars := InitializeVariables(false) + vars.Set(SQLCMDCOLWIDTH, "20") + + buf := new(bytes.Buffer) + f := &asciiFormatter{ + sqlCmdFormatterType: &sqlCmdFormatterType{ + vars: vars, + out: buf, + colorizer: color.New(false), + format: "ascii", + }, + rows: [][]string{{"AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA"}}, // 50 chars + colWidths: []int{50}, + } + + // Mock column details with empty column type (defaults to non-numeric, empty name) + f.columnDetails = []columnDetail{{}} + + f.printAsciiTable() + + output := buf.String() + + // Expected behavior: + // maxWidth = 20 + // maxColContentWidth = 20 - 4 = 16 + // colWidths[0] should be clamped to 16 + // The value should be truncated to 13 chars + "..." = 16 chars total. + + // Divider: + followed by 16 dashes + 2 dashes (padding) + + + // Total width: 1 + 16 + 2 + 1 = 20 + // Divider line: +------------------+ + + // Header: | | (padded to 16) + // Since name is empty. + + // Value: | AAAAAAAAAAAAA... | (13 A's followed by ...) + + expectedDivider := "+------------------+" + expectedValue := "| AAAAAAAAAAAAA... |" + + assert.Contains(t, output, expectedDivider) + assert.Contains(t, output, expectedValue) + + // Verify it does NOT contain the full string + assert.NotContains(t, output, "AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA") +} diff --git a/pkg/sqlcmd/sqlcmd_test.go b/pkg/sqlcmd/sqlcmd_test.go index ade6dd8c..2c325fed 100644 --- a/pkg/sqlcmd/sqlcmd_test.go +++ b/pkg/sqlcmd/sqlcmd_test.go @@ -635,7 +635,7 @@ func setupSqlCmdWithMemoryOutput(t testing.TB) (*Sqlcmd, *memoryBuffer) { v.Set(SQLCMDMAXVARTYPEWIDTH, "0") s := New(nil, "", v) s.Connect = newConnect(t) - s.Format = NewSQLCmdDefaultFormatter(true, ControlIgnore) + s.Format = NewSQLCmdDefaultFormatter(v, true, ControlIgnore) buf := &memoryBuffer{buf: new(bytes.Buffer)} s.SetOutput(buf) err := s.ConnectDb(nil, true) @@ -649,7 +649,7 @@ func setupSqlcmdWithFileOutput(t testing.TB) (*Sqlcmd, *os.File) { v.Set(SQLCMDMAXVARTYPEWIDTH, "0") s := New(nil, "", v) s.Connect = newConnect(t) - s.Format = NewSQLCmdDefaultFormatter(true, ControlIgnore) + s.Format = NewSQLCmdDefaultFormatter(v, true, ControlIgnore) file, err := os.CreateTemp("", "sqlcmdout") assert.NoError(t, err, "os.CreateTemp") s.SetOutput(file) @@ -667,7 +667,7 @@ func setupSqlcmdWithFileErrorOutput(t testing.TB) (*Sqlcmd, *os.File, *os.File) v.Set(SQLCMDMAXVARTYPEWIDTH, "0") s := New(nil, "", v) s.Connect = newConnect(t) - s.Format = NewSQLCmdDefaultFormatter(true, ControlIgnore) + s.Format = NewSQLCmdDefaultFormatter(v, true, ControlIgnore) outfile, err := os.CreateTemp("", "sqlcmdout") assert.NoError(t, err, "os.CreateTemp") errfile, err := os.CreateTemp("", "sqlcmderr") diff --git a/pkg/sqlcmd/variables.go b/pkg/sqlcmd/variables.go index aa601627..d4f7fa7f 100644 --- a/pkg/sqlcmd/variables.go +++ b/pkg/sqlcmd/variables.go @@ -179,6 +179,10 @@ func (v Variables) Format() string { switch v[SQLCMDFORMAT] { case "vert", "vertical": return "vertical" + case "ascii": + return "ascii" + case "horiz", "horizontal": + return "horizontal" } return "horizontal" } @@ -246,6 +250,7 @@ func InitializeVariables(fromEnvironment bool) *Variables { SQLCMDUSER: "", SQLCMDUSEAAD: "", SQLCMDCOLORSCHEME: "", + SQLCMDFORMAT: "", } hostname, _ := os.Hostname() variables.Set(SQLCMDWORKSTATION, hostname)