Vendoring (godep). Closes #3.

This commit is contained in:
Matthew Dillon 2015-01-26 12:19:41 -09:00
parent af515e3e3c
commit 9a48e8ef3a
149 changed files with 27820 additions and 0 deletions

58
Godeps/Godeps.json generated Normal file
View file

@ -0,0 +1,58 @@
{
"ImportPath": "github.com/thermokarst/bactdb",
"GoVersion": "go1.4.1",
"Deps": [
{
"ImportPath": "github.com/DavidHuie/gomigrate",
"Rev": "7883dd76040debc099811feadd58425b1495103d"
},
{
"ImportPath": "github.com/codegangsta/cli",
"Comment": "1.2.0-62-gbf4a526",
"Rev": "bf4a526f48af7badd25d2cb02d587e1b01be3b50"
},
{
"ImportPath": "github.com/dgrijalva/jwt-go",
"Comment": "v2.2.0-11-g7c18dce",
"Rev": "7c18dce7b8ff32db34d4b7054d488db656d75cc3"
},
{
"ImportPath": "github.com/google/go-querystring/query",
"Rev": "d8840cbb2baa915f4836edda4750050a2c0b7aea"
},
{
"ImportPath": "github.com/gorilla/context",
"Rev": "215affda49addc4c8ef7e2534915df2c8c35c6cd"
},
{
"ImportPath": "github.com/gorilla/mux",
"Rev": "e444e69cbd2e2e3e0749a2f3c717cec491552bbf"
},
{
"ImportPath": "github.com/gorilla/schema",
"Rev": "ad8849d14e5cdf1f4423b1f0fc53827ebbdd7477"
},
{
"ImportPath": "github.com/jmoiron/modl",
"Rev": "b0cda508e4027c836cf78bd482a3edd0458b6d2f"
},
{
"ImportPath": "github.com/jmoiron/sqlx",
"Comment": "sqlx-v1.0-67-g69738bd",
"Rev": "69738bd209812c5a381786011aff17944a384554"
},
{
"ImportPath": "github.com/lib/pq",
"Comment": "go1.0-cutoff-13-g19eeca3",
"Rev": "19eeca3e30d2577b1761db471ec130810e67f532"
},
{
"ImportPath": "golang.org/x/crypto/bcrypt",
"Rev": "632d287f9f3f54b09809eebbd3cacbcb00b9f2fc"
},
{
"ImportPath": "golang.org/x/crypto/blowfish",
"Rev": "632d287f9f3f54b09809eebbd3cacbcb00b9f2fc"
}
]
}

5
Godeps/Readme generated Normal file
View file

@ -0,0 +1,5 @@
This directory tree is generated automatically by godep.
Please do not edit.
See https://github.com/tools/godep for more information.

2
Godeps/_workspace/.gitignore generated vendored Normal file
View file

@ -0,0 +1,2 @@
/pkg
/bin

View file

@ -0,0 +1 @@
*.test

View file

@ -0,0 +1,14 @@
language: go
go: 1.3
addons:
postgresql: "9.3"
before_script:
- psql -c 'create database gomigrate;' -U postgres
- mysql -uroot -e "CREATE USER 'gomigrate'@'localhost' IDENTIFIED BY 'password';"
- mysql -uroot -e "GRANT ALL PRIVILEGES ON * . * TO 'gomigrate'@'localhost';"
- mysql -uroot -e "CREATE DATABASE gomigrate;"
- go get github.com/lib/pq
- go get github.com/go-sql-driver/mysql
script:
- DB=pg go test
- DB=mysql go test

View file

@ -0,0 +1,20 @@
Copyright (c) 2014 David Huie
Permission is hereby granted, free of charge, to any person obtaining
a copy of this software and associated documentation files (the
"Software"), to deal in the Software without restriction, including
without limitation the rights to use, copy, modify, merge, publish,
distribute, sublicense, and/or sell copies of the Software, and to
permit persons to whom the Software is furnished to do so, subject to
the following conditions:
The above copyright notice and this permission notice shall be
included in all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

View file

@ -0,0 +1,72 @@
# gomigrate
[![Build Status](https://travis-ci.org/DavidHuie/gomigrate.svg?branch=master)](https://travis-ci.org/DavidHuie/gomigrate)
A SQL database migration toolkit in Golang.
## Supported databases
- PostgreSQL
- MariaDB
- MySQL
## Usage
First import the package:
```go
import "github.com/DavidHuie/gomigrate"
```
Given a `database/sql` database connection to a PostgreSQL database, `db`,
and a directory to migration files, create a migrator:
```go
migrator := gomigrate.NewMigrator(db, gomigrate.Postgres{}, "./migrations")
```
To migrate the database, run:
```go
err := migrator.Migrate()
```
To rollback the last migration, run:
```go
err := migrator.Rollback()
```
## Migration files
Migration files need to follow a standard format and must be present
in the same directory. Given "up" and "down" steps for a migration,
create a file for each by following this template:
```
{{ id }}_{{ name }}_{{ "up" or "down" }}.sql
```
For a given migration, the `id` and `name` fields must be the same.
The id field is an integer that corresponds to the order in which
the migration should run relative to the other migrations.
### Example
If I'm trying to add a "users" table to the database, I would create
the following two files:
#### 1_add_users_table_up.sql
```
CREATE TABLE users();
```
#### 1_add_users_table_down.sql
```
DROP TABLE users;
```
## Copyright
Copyright (c) 2014 David Huie. See LICENSE.txt for further details.

View file

@ -0,0 +1,70 @@
package gomigrate
type Migratable interface {
SelectMigrationTableSql() string
CreateMigrationTableSql() string
GetMigrationSql() string
MigrationLogInsertSql() string
MigrationLogDeleteSql() string
}
// POSTGRES
type Postgres struct{}
func (p Postgres) SelectMigrationTableSql() string {
return "SELECT tablename FROM pg_catalog.pg_tables WHERE tablename = $1"
}
func (p Postgres) CreateMigrationTableSql() string {
return `CREATE TABLE gomigrate (
id SERIAL PRIMARY KEY,
migration_id BIGINT UNIQUE NOT NULL
)`
}
func (p Postgres) GetMigrationSql() string {
return `SELECT migration_id FROM gomigrate WHERE migration_id = $1`
}
func (p Postgres) MigrationLogInsertSql() string {
return "INSERT INTO gomigrate (migration_id) values ($1)"
}
func (p Postgres) MigrationLogDeleteSql() string {
return "DELETE FROM gomigrate WHERE migration_id = $1"
}
// MYSQL
type Mysql struct{}
func (m Mysql) SelectMigrationTableSql() string {
return "SELECT table_name FROM information_schema.tables WHERE table_name = ?"
}
func (m Mysql) CreateMigrationTableSql() string {
return `CREATE TABLE gomigrate (
id INT NOT NULL AUTO_INCREMENT,
migration_id BIGINT NOT NULL UNIQUE,
PRIMARY KEY (id)
) ENGINE=MyISAM`
}
func (m Mysql) GetMigrationSql() string {
return `SELECT migration_id FROM gomigrate WHERE migration_id = ?`
}
func (m Mysql) MigrationLogInsertSql() string {
return "INSERT INTO gomigrate (migration_id) values (?)"
}
func (m Mysql) MigrationLogDeleteSql() string {
return "DELETE FROM gomigrate WHERE migration_id = ?"
}
// MARIADB
type Mariadb struct {
Mysql
}

View file

@ -0,0 +1,327 @@
// A simple database migrator for PostgreSQL.
package gomigrate
import (
"bytes"
"database/sql"
"errors"
"io/ioutil"
"log"
"path/filepath"
"sort"
)
type migrationType string
const (
migrationTableName = "gomigrate"
upMigration = migrationType("up")
downMigration = migrationType("down")
)
var (
InvalidMigrationFile = errors.New("Invalid migration file")
InvalidMigrationPair = errors.New("Invalid pair of migration files")
InvalidMigrationsPath = errors.New("Invalid migrations path")
InvalidMigrationType = errors.New("Invalid migration type")
NoActiveMigrations = errors.New("No active migrations to rollback")
)
type Migrator struct {
DB *sql.DB
MigrationsPath string
dbAdapter Migratable
migrations map[uint64]*Migration
}
// Returns true if the migration table already exists.
func (m *Migrator) MigrationTableExists() (bool, error) {
row := m.DB.QueryRow(m.dbAdapter.SelectMigrationTableSql(), migrationTableName)
var tableName string
err := row.Scan(&tableName)
if err == sql.ErrNoRows {
log.Print("Migrations table not found")
return false, nil
}
if err != nil {
log.Printf("Error checking for migration table: %v", err)
return false, err
}
log.Print("Migrations table found")
return true, nil
}
// Creates the migrations table if it doesn't exist.
func (m *Migrator) CreateMigrationsTable() error {
_, err := m.DB.Query(m.dbAdapter.CreateMigrationTableSql())
if err != nil {
log.Fatalf("Error creating migrations table: %v", err)
}
log.Printf("Created migrations table: %s", migrationTableName)
return nil
}
// Returns a new migrator.
func NewMigrator(db *sql.DB, adapter Migratable, migrationsPath string) (*Migrator, error) {
// Normalize the migrations path.
path := []byte(migrationsPath)
pathLength := len(path)
if path[pathLength-1] != '/' {
path = append(path, '/')
}
log.Printf("Migrations path: %s", path)
migrator := Migrator{
db,
string(path),
adapter,
make(map[uint64]*Migration),
}
// Create the migrations table if it doesn't exist.
tableExists, err := migrator.MigrationTableExists()
if err != nil {
return nil, err
}
if !tableExists {
if err := migrator.CreateMigrationsTable(); err != nil {
return nil, err
}
}
// Get all metadata from the database.
if err := migrator.fetchMigrations(); err != nil {
return nil, err
}
if err := migrator.getMigrationStatuses(); err != nil {
return nil, err
}
return &migrator, nil
}
// Populates a migrator with a sorted list of migrations from the file system.
func (m *Migrator) fetchMigrations() error {
pathGlob := append([]byte(m.MigrationsPath), []byte("*")...)
matches, err := filepath.Glob(string(pathGlob))
if err != nil {
log.Fatalf("Error while globbing migrations: %v", err)
}
for _, match := range matches {
num, migrationType, name, err := parseMigrationPath(match)
if err != nil {
log.Printf("Invalid migration file found: %s", match)
continue
}
log.Printf("Migration file found: %s", match)
migration, ok := m.migrations[num]
if !ok {
migration = &Migration{Id: num, Name: name, Status: Inactive}
m.migrations[num] = migration
}
if migrationType == upMigration {
migration.UpPath = match
} else {
migration.DownPath = match
}
}
// Validate each migration.
for _, migration := range m.migrations {
if !migration.valid() {
path := migration.UpPath
if path == "" {
path = migration.DownPath
}
log.Printf("Invalid migration pair for path: %s", path)
return InvalidMigrationPair
}
}
log.Printf("Migrations file pairs found: %v", len(m.migrations))
return nil
}
// Queries the migration table to determine the status of each
// migration.
func (m *Migrator) getMigrationStatuses() error {
for _, migration := range m.migrations {
row := m.DB.QueryRow(m.dbAdapter.GetMigrationSql(), migration.Id)
var mid uint64
err := row.Scan(&mid)
if err == sql.ErrNoRows {
continue
}
if err != nil {
log.Printf(
"Error getting migration status for %s: %v",
migration.Name,
err,
)
return err
}
migration.Status = Active
}
return nil
}
// Returns a sorted list of migration ids for a given status. -1 returns
// all migrations.
func (m *Migrator) Migrations(status int) []*Migration {
// Sort all migration ids.
ids := make([]uint64, 0)
for id, _ := range m.migrations {
ids = append(ids, id)
}
sort.Sort(uint64slice(ids))
// Find ids for the given status.
migrations := make([]*Migration, 0)
for _, id := range ids {
migration := m.migrations[id]
if status == -1 || migration.Status == status {
migrations = append(migrations, migration)
}
}
return migrations
}
// Applies a single migration.
func (m *Migrator) ApplyMigration(migration *Migration, mType migrationType) error {
var path string
if mType == upMigration {
path = migration.UpPath
} else if mType == downMigration {
path = migration.DownPath
} else {
return InvalidMigrationType
}
log.Printf("Applying migration: %s", path)
sql, err := ioutil.ReadFile(path)
if err != nil {
log.Printf("Error reading migration: %s", path)
return err
}
transaction, err := m.DB.Begin()
if err != nil {
log.Printf("Error opening transaction: %v", err)
return err
}
for n, subMigration := range splitMigrationString(string(sql)) {
if allWhitespace.Match([]byte(subMigration)) {
continue
}
log.Printf("Applying submigration: %v", n+1)
for _, line := range bytes.Split([]byte(subMigration), []byte("\n")) {
log.Printf("MIGRATION: %s", line)
}
// Perform the migration.
result, err := transaction.Exec(string(subMigration))
if err != nil {
log.Printf("Error executing migration: %v", err)
if rollbackErr := transaction.Rollback(); rollbackErr != nil {
log.Printf("Error rolling back transaction: %v", rollbackErr)
return rollbackErr
}
return err
}
if rowsAffected, err := result.RowsAffected(); err != nil {
log.Printf("Error getting rows affected: %v", err)
if rollbackErr := transaction.Rollback(); rollbackErr != nil {
log.Printf("Error rolling back transaction: %v", rollbackErr)
return rollbackErr
}
return err
} else {
log.Printf("Rows affected: %v", rowsAffected)
}
}
// Log the event.
if mType == upMigration {
_, err = transaction.Exec(
m.dbAdapter.MigrationLogInsertSql(),
migration.Id,
)
} else {
_, err = transaction.Exec(
m.dbAdapter.MigrationLogDeleteSql(),
migration.Id,
)
}
if err != nil {
log.Printf("Error logging migration: %v", err)
if rollbackErr := transaction.Rollback(); rollbackErr != nil {
log.Printf("Error rolling back transaction: %v", rollbackErr)
return rollbackErr
}
return err
}
// Commit and update the struct status.
if err := transaction.Commit(); err != nil {
log.Printf("Error commiting transaction: %v", err)
return err
}
if mType == upMigration {
migration.Status = Active
} else {
migration.Status = Inactive
}
return nil
}
// Applies all inactive migrations.
func (m *Migrator) Migrate() error {
for _, migration := range m.Migrations(Inactive) {
if err := m.ApplyMigration(migration, upMigration); err != nil {
return err
}
}
return nil
}
// Rolls back the last migration.
func (m *Migrator) Rollback() error {
return m.RollbackN(1)
}
// Rolls back N migrations.
func (m *Migrator) RollbackN(n int) error {
migrations := m.Migrations(Active)
if len(migrations) == 0 {
return NoActiveMigrations
}
last_migration := len(migrations) - 1 - n
for i := len(migrations) - 1; i != last_migration; i-- {
if err := m.ApplyMigration(migrations[i], downMigration); err != nil {
return err
}
}
return nil
}
// Rolls back all migrations.
func (m *Migrator) RollbackAll() error {
migrations := m.Migrations(Active)
return m.RollbackN(len(migrations))
}

View file

@ -0,0 +1,163 @@
package gomigrate
import (
"database/sql"
"fmt"
"log"
"os"
"testing"
_ "github.com/go-sql-driver/mysql"
_ "github.com/lib/pq"
)
var (
db *sql.DB
adapter Migratable
)
func GetMigrator(test string) *Migrator {
var suffix string
if os.Getenv("DB") == "pg" {
suffix = "pg"
} else {
suffix = "mysql"
}
path := fmt.Sprintf("test_migrations/%s_%s", test, suffix)
m, err := NewMigrator(db, adapter, path)
if err != nil {
panic(err)
}
return m
}
func TestNewMigrator(t *testing.T) {
m := GetMigrator("test1")
if len(m.migrations) != 1 {
t.Errorf("Invalid number of migrations detected")
}
migration := m.migrations[1]
if migration.Name != "test" {
t.Errorf("Invalid migration name detected: %s", migration.Name)
}
if migration.Id != 1 {
t.Errorf("Invalid migration num detected: %s", migration.Id)
}
if migration.Status != Inactive {
t.Errorf("Invalid migration num detected: %s", migration.Status)
}
cleanup()
}
func TestCreatingMigratorWhenTableExists(t *testing.T) {
// Create the table and populate it with a row.
_, err := db.Exec(adapter.CreateMigrationTableSql())
if err != nil {
t.Error(err)
}
_, err = db.Exec(adapter.MigrationLogInsertSql(), 123)
if err != nil {
t.Error(err)
}
GetMigrator("test1")
// Check that our row is still present.
row := db.QueryRow("select migration_id from gomigrate")
var id uint64
err = row.Scan(&id)
if err != nil {
t.Error(err)
}
if id != 123 {
t.Error("Invalid id found in database")
}
cleanup()
}
func TestMigrationAndRollback(t *testing.T) {
m := GetMigrator("test1")
if err := m.Migrate(); err != nil {
t.Error(err)
}
// Ensure that the migration ran.
row := db.QueryRow(
adapter.SelectMigrationTableSql(),
"test",
)
var tableName string
if err := row.Scan(&tableName); err != nil {
t.Error(err)
}
if tableName != "test" {
t.Errorf("Migration table not created")
}
// Ensure that the migrate status is correct.
row = db.QueryRow(
adapter.GetMigrationSql(),
1,
)
var status int
if err := row.Scan(&status); err != nil {
t.Error(err)
}
if status != Active || m.migrations[1].Status != Active {
t.Error("Invalid status for migration")
}
if err := m.Rollback(); err != nil {
t.Error(err)
}
// Ensure that the down migration ran.
row = db.QueryRow(
adapter.SelectMigrationTableSql(),
"test",
)
err := row.Scan(&tableName)
if err != sql.ErrNoRows {
t.Errorf("Migration table should be deleted")
}
// Ensure that the migration log is missing.
row = db.QueryRow(
adapter.GetMigrationSql(),
1,
)
if err := row.Scan(&status); err != sql.ErrNoRows {
t.Error(err)
}
if m.migrations[1].Status != Inactive {
t.Errorf("Invalid status for migration, %v", m.migrations[1].Status)
}
cleanup()
}
func cleanup() {
_, err := db.Exec("drop table gomigrate")
if err != nil {
panic(err)
}
}
func init() {
var err error
if os.Getenv("DB") == "pg" {
log.Print("Using postgres")
adapter = Postgres{}
db, err = sql.Open("postgres", "host=localhost dbname=gomigrate sslmode=disable")
if err != nil {
panic(err)
}
} else {
log.Print("Using mysql")
adapter = Mariadb{}
db, err = sql.Open("mysql", "gomigrate:password@/gomigrate")
}
}

View file

@ -0,0 +1,26 @@
// Holds metadata about a migration.
package gomigrate
// Migration statuses.
const (
Inactive = iota
Active
)
// Holds configuration information for a given migration.
type Migration struct {
DownPath string
Id uint64
Name string
Status int
UpPath string
}
// Performs a basic validation of a migration.
func (m *Migration) valid() bool {
if m.Id != 0 && m.Name != "" && m.UpPath != "" && m.DownPath != "" {
return true
}
return false
}

View file

@ -0,0 +1 @@
DROP TABLE test;

View file

@ -0,0 +1,4 @@
CREATE TABLE test (
id INT NOT NULL AUTO_INCREMENT,
PRIMARY KEY (id)
) ENGINE=MyISAM;

View file

@ -0,0 +1 @@
drop table test;

View file

@ -0,0 +1 @@
create table test();

View file

@ -0,0 +1,63 @@
package gomigrate
import (
"path/filepath"
"regexp"
"strconv"
)
var (
upMigrationFile = regexp.MustCompile(`(\d+)_(\w+)_up\.sql`)
downMigrationFile = regexp.MustCompile(`(\d+)_(\w+)_down\.sql`)
subMigrationSplit = regexp.MustCompile(`;\s*`)
allWhitespace = regexp.MustCompile(`^\s*$`)
)
// Returns the migration number, type and base name, so 1, "up", "migration" from "01_migration_up.sql"
func parseMigrationPath(path string) (uint64, migrationType, string, error) {
filebase := filepath.Base(path)
matches := upMigrationFile.FindAllSubmatch([]byte(filebase), -1)
if matches != nil {
return parseMatches(matches, upMigration)
}
matches = downMigrationFile.FindAllSubmatch([]byte(filebase), -1)
if matches != nil {
return parseMatches(matches, downMigration)
}
return 0, "", "", InvalidMigrationFile
}
// Parses matches given by a migration file regex.
func parseMatches(matches [][][]byte, mType migrationType) (uint64, migrationType, string, error) {
num := matches[0][1]
name := matches[0][2]
parsedNum, err := strconv.ParseUint(string(num), 10, 64)
if err != nil {
return 0, "", "", err
}
return parsedNum, mType, string(name), nil
}
// Splits migration sql into different strings separated by a semi-colon.
func splitMigrationString(sql string) []string {
return subMigrationSplit.Split(sql, -1)
}
// This type is used to sort migration ids.
type uint64slice []uint64
func (u uint64slice) Len() int {
return len(u)
}
func (u uint64slice) Less(a, b int) bool {
return u[a] < u[b]
}
func (u uint64slice) Swap(a, b int) {
tempA := u[a]
u[a] = u[b]
u[b] = tempA
}

View file

@ -0,0 +1,6 @@
language: go
go: 1.1
script:
- go vet ./...
- go test -v ./...

View file

@ -0,0 +1,21 @@
Copyright (C) 2013 Jeremy Saenz
All Rights Reserved.
MIT LICENSE
Permission is hereby granted, free of charge, to any person obtaining a copy of
this software and associated documentation files (the "Software"), to deal in
the Software without restriction, including without limitation the rights to
use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
the Software, and to permit persons to whom the Software is furnished to do so,
subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

View file

@ -0,0 +1,298 @@
[![Build Status](https://travis-ci.org/codegangsta/cli.png?branch=master)](https://travis-ci.org/codegangsta/cli)
# cli.go
cli.go is simple, fast, and fun package for building command line apps in Go. The goal is to enable developers to write fast and distributable command line applications in an expressive way.
You can view the API docs here:
http://godoc.org/github.com/codegangsta/cli
## Overview
Command line apps are usually so tiny that there is absolutely no reason why your code should *not* be self-documenting. Things like generating help text and parsing command flags/options should not hinder productivity when writing a command line app.
**This is where cli.go comes into play.** cli.go makes command line programming fun, organized, and expressive!
## Installation
Make sure you have a working Go environment (go 1.1 is *required*). [See the install instructions](http://golang.org/doc/install.html).
To install `cli.go`, simply run:
```
$ go get github.com/codegangsta/cli
```
Make sure your `PATH` includes to the `$GOPATH/bin` directory so your commands can be easily used:
```
export PATH=$PATH:$GOPATH/bin
```
## Getting Started
One of the philosophies behind cli.go is that an API should be playful and full of discovery. So a cli.go app can be as little as one line of code in `main()`.
``` go
package main
import (
"os"
"github.com/codegangsta/cli"
)
func main() {
cli.NewApp().Run(os.Args)
}
```
This app will run and show help text, but is not very useful. Let's give an action to execute and some help documentation:
``` go
package main
import (
"os"
"github.com/codegangsta/cli"
)
func main() {
app := cli.NewApp()
app.Name = "boom"
app.Usage = "make an explosive entrance"
app.Action = func(c *cli.Context) {
println("boom! I say!")
}
app.Run(os.Args)
}
```
Running this already gives you a ton of functionality, plus support for things like subcommands and flags, which are covered below.
## Example
Being a programmer can be a lonely job. Thankfully by the power of automation that is not the case! Let's create a greeter app to fend off our demons of loneliness!
Start by creating a directory named `greet`, and within it, add a file, `greet.go` with the following code in it:
``` go
package main
import (
"os"
"github.com/codegangsta/cli"
)
func main() {
app := cli.NewApp()
app.Name = "greet"
app.Usage = "fight the loneliness!"
app.Action = func(c *cli.Context) {
println("Hello friend!")
}
app.Run(os.Args)
}
```
Install our command to the `$GOPATH/bin` directory:
```
$ go install
```
Finally run our new command:
```
$ greet
Hello friend!
```
cli.go also generates some bitchass help text:
```
$ greet help
NAME:
greet - fight the loneliness!
USAGE:
greet [global options] command [command options] [arguments...]
VERSION:
0.0.0
COMMANDS:
help, h Shows a list of commands or help for one command
GLOBAL OPTIONS
--version Shows version information
```
### Arguments
You can lookup arguments by calling the `Args` function on `cli.Context`.
``` go
...
app.Action = func(c *cli.Context) {
println("Hello", c.Args()[0])
}
...
```
### Flags
Setting and querying flags is simple.
``` go
...
app.Flags = []cli.Flag {
cli.StringFlag{
Name: "lang",
Value: "english",
Usage: "language for the greeting",
},
}
app.Action = func(c *cli.Context) {
name := "someone"
if len(c.Args()) > 0 {
name = c.Args()[0]
}
if c.String("lang") == "spanish" {
println("Hola", name)
} else {
println("Hello", name)
}
}
...
```
#### Alternate Names
You can set alternate (or short) names for flags by providing a comma-delimited list for the `Name`. e.g.
``` go
app.Flags = []cli.Flag {
cli.StringFlag{
Name: "lang, l",
Value: "english",
Usage: "language for the greeting",
},
}
```
That flag can then be set with `--lang spanish` or `-l spanish`. Note that giving two different forms of the same flag in the same command invocation is an error.
#### Values from the Environment
You can also have the default value set from the environment via `EnvVar`. e.g.
``` go
app.Flags = []cli.Flag {
cli.StringFlag{
Name: "lang, l",
Value: "english",
Usage: "language for the greeting",
EnvVar: "APP_LANG",
},
}
```
The `EnvVar` may also be given as a comma-delimited "cascade", where the first environment variable that resolves is used as the default.
``` go
app.Flags = []cli.Flag {
cli.StringFlag{
Name: "lang, l",
Value: "english",
Usage: "language for the greeting",
EnvVar: "LEGACY_COMPAT_LANG,APP_LANG,LANG",
},
}
```
### Subcommands
Subcommands can be defined for a more git-like command line app.
```go
...
app.Commands = []cli.Command{
{
Name: "add",
ShortName: "a",
Usage: "add a task to the list",
Action: func(c *cli.Context) {
println("added task: ", c.Args().First())
},
},
{
Name: "complete",
ShortName: "c",
Usage: "complete a task on the list",
Action: func(c *cli.Context) {
println("completed task: ", c.Args().First())
},
},
{
Name: "template",
ShortName: "r",
Usage: "options for task templates",
Subcommands: []cli.Command{
{
Name: "add",
Usage: "add a new template",
Action: func(c *cli.Context) {
println("new task template: ", c.Args().First())
},
},
{
Name: "remove",
Usage: "remove an existing template",
Action: func(c *cli.Context) {
println("removed task template: ", c.Args().First())
},
},
},
},
}
...
```
### Bash Completion
You can enable completion commands by setting the `EnableBashCompletion`
flag on the `App` object. By default, this setting will only auto-complete to
show an app's subcommands, but you can write your own completion methods for
the App or its subcommands.
```go
...
var tasks = []string{"cook", "clean", "laundry", "eat", "sleep", "code"}
app := cli.NewApp()
app.EnableBashCompletion = true
app.Commands = []cli.Command{
{
Name: "complete",
ShortName: "c",
Usage: "complete a task on the list",
Action: func(c *cli.Context) {
println("completed task: ", c.Args().First())
},
BashComplete: func(c *cli.Context) {
// This will complete if no args are passed
if len(c.Args()) > 0 {
return
}
for _, t := range tasks {
fmt.Println(t)
}
},
}
}
...
```
#### To Enable
Source the `autocomplete/bash_autocomplete` file in your `.bashrc` file while
setting the `PROG` variable to the name of your program:
`PROG=myprogram source /.../cli/autocomplete/bash_autocomplete`
## Contribution Guidelines
Feel free to put up a pull request to fix a bug or maybe add a feature. I will give it a code review and make sure that it does not break backwards compatibility. If I or any other collaborators agree that it is in line with the vision of the project, we will work with you to get the code into a mergeable state and merge it into the master branch.
If you are have contributed something significant to the project, I will most likely add you as a collaborator. As a collaborator you are given the ability to merge others pull requests. It is very important that new code does not break existing code, so be careful about what code you do choose to merge. If you have any questions feel free to link @codegangsta to the issue in question and we can review it together.
If you feel like you have contributed to the project but have not yet been added as a collaborator, I probably forgot to add you. Hit @codegangsta up over email and we will get it figured out.

279
Godeps/_workspace/src/github.com/codegangsta/cli/app.go generated vendored Normal file
View file

@ -0,0 +1,279 @@
package cli
import (
"fmt"
"io"
"io/ioutil"
"os"
"text/tabwriter"
"text/template"
"time"
)
// App is the main structure of a cli application. It is recomended that
// and app be created with the cli.NewApp() function
type App struct {
// The name of the program. Defaults to os.Args[0]
Name string
// Description of the program.
Usage string
// Version of the program
Version string
// List of commands to execute
Commands []Command
// List of flags to parse
Flags []Flag
// Boolean to enable bash completion commands
EnableBashCompletion bool
// Boolean to hide built-in help command
HideHelp bool
// Boolean to hide built-in version flag
HideVersion bool
// An action to execute when the bash-completion flag is set
BashComplete func(context *Context)
// An action to execute before any subcommands are run, but after the context is ready
// If a non-nil error is returned, no subcommands are run
Before func(context *Context) error
// The action to execute when no subcommands are specified
Action func(context *Context)
// Execute this function if the proper command cannot be found
CommandNotFound func(context *Context, command string)
// Compilation date
Compiled time.Time
// Author
Author string
// Author e-mail
Email string
// Writer writer to write output to
Writer io.Writer
}
// Tries to find out when this binary was compiled.
// Returns the current time if it fails to find it.
func compileTime() time.Time {
info, err := os.Stat(os.Args[0])
if err != nil {
return time.Now()
}
return info.ModTime()
}
// Creates a new cli Application with some reasonable defaults for Name, Usage, Version and Action.
func NewApp() *App {
return &App{
Name: os.Args[0],
Usage: "A new cli application",
Version: "0.0.0",
BashComplete: DefaultAppComplete,
Action: helpCommand.Action,
Compiled: compileTime(),
Author: "Author",
Email: "unknown@email",
Writer: os.Stdout,
}
}
// Entry point to the cli app. Parses the arguments slice and routes to the proper flag/args combination
func (a *App) Run(arguments []string) error {
if HelpPrinter == nil {
defer func() {
HelpPrinter = nil
}()
HelpPrinter = func(templ string, data interface{}) {
w := tabwriter.NewWriter(a.Writer, 0, 8, 1, '\t', 0)
t := template.Must(template.New("help").Parse(templ))
err := t.Execute(w, data)
if err != nil {
panic(err)
}
w.Flush()
}
}
// append help to commands
if a.Command(helpCommand.Name) == nil && !a.HideHelp {
a.Commands = append(a.Commands, helpCommand)
if (HelpFlag != BoolFlag{}) {
a.appendFlag(HelpFlag)
}
}
//append version/help flags
if a.EnableBashCompletion {
a.appendFlag(BashCompletionFlag)
}
if !a.HideVersion {
a.appendFlag(VersionFlag)
}
// parse flags
set := flagSet(a.Name, a.Flags)
set.SetOutput(ioutil.Discard)
err := set.Parse(arguments[1:])
nerr := normalizeFlags(a.Flags, set)
if nerr != nil {
fmt.Fprintln(a.Writer, nerr)
context := NewContext(a, set, set)
ShowAppHelp(context)
fmt.Fprintln(a.Writer)
return nerr
}
context := NewContext(a, set, set)
if err != nil {
fmt.Fprintf(a.Writer, "Incorrect Usage.\n\n")
ShowAppHelp(context)
fmt.Fprintln(a.Writer)
return err
}
if checkCompletions(context) {
return nil
}
if checkHelp(context) {
return nil
}
if checkVersion(context) {
return nil
}
if a.Before != nil {
err := a.Before(context)
if err != nil {
return err
}
}
args := context.Args()
if args.Present() {
name := args.First()
c := a.Command(name)
if c != nil {
return c.Run(context)
}
}
// Run default Action
a.Action(context)
return nil
}
// Another entry point to the cli app, takes care of passing arguments and error handling
func (a *App) RunAndExitOnError() {
if err := a.Run(os.Args); err != nil {
fmt.Fprintln(os.Stderr, err)
os.Exit(1)
}
}
// Invokes the subcommand given the context, parses ctx.Args() to generate command-specific flags
func (a *App) RunAsSubcommand(ctx *Context) error {
// append help to commands
if len(a.Commands) > 0 {
if a.Command(helpCommand.Name) == nil && !a.HideHelp {
a.Commands = append(a.Commands, helpCommand)
if (HelpFlag != BoolFlag{}) {
a.appendFlag(HelpFlag)
}
}
}
// append flags
if a.EnableBashCompletion {
a.appendFlag(BashCompletionFlag)
}
// parse flags
set := flagSet(a.Name, a.Flags)
set.SetOutput(ioutil.Discard)
err := set.Parse(ctx.Args().Tail())
nerr := normalizeFlags(a.Flags, set)
context := NewContext(a, set, ctx.globalSet)
if nerr != nil {
fmt.Fprintln(a.Writer, nerr)
if len(a.Commands) > 0 {
ShowSubcommandHelp(context)
} else {
ShowCommandHelp(ctx, context.Args().First())
}
fmt.Fprintln(a.Writer)
return nerr
}
if err != nil {
fmt.Fprintf(a.Writer, "Incorrect Usage.\n\n")
ShowSubcommandHelp(context)
return err
}
if checkCompletions(context) {
return nil
}
if len(a.Commands) > 0 {
if checkSubcommandHelp(context) {
return nil
}
} else {
if checkCommandHelp(ctx, context.Args().First()) {
return nil
}
}
if a.Before != nil {
err := a.Before(context)
if err != nil {
return err
}
}
args := context.Args()
if args.Present() {
name := args.First()
c := a.Command(name)
if c != nil {
return c.Run(context)
}
}
// Run default Action
if len(a.Commands) > 0 {
a.Action(context)
} else {
a.Action(ctx)
}
return nil
}
// Returns the named command on App. Returns nil if the command does not exist
func (a *App) Command(name string) *Command {
for _, c := range a.Commands {
if c.HasName(name) {
return &c
}
}
return nil
}
func (a *App) hasFlag(flag Flag) bool {
for _, f := range a.Flags {
if flag == f {
return true
}
}
return false
}
func (a *App) appendFlag(flag Flag) {
if !a.hasFlag(flag) {
a.Flags = append(a.Flags, flag)
}
}

View file

@ -0,0 +1,528 @@
package cli_test
import (
"flag"
"fmt"
"os"
"testing"
"github.com/codegangsta/cli"
)
func ExampleApp() {
// set args for examples sake
os.Args = []string{"greet", "--name", "Jeremy"}
app := cli.NewApp()
app.Name = "greet"
app.Flags = []cli.Flag{
cli.StringFlag{Name: "name", Value: "bob", Usage: "a name to say"},
}
app.Action = func(c *cli.Context) {
fmt.Printf("Hello %v\n", c.String("name"))
}
app.Run(os.Args)
// Output:
// Hello Jeremy
}
func ExampleAppSubcommand() {
// set args for examples sake
os.Args = []string{"say", "hi", "english", "--name", "Jeremy"}
app := cli.NewApp()
app.Name = "say"
app.Commands = []cli.Command{
{
Name: "hello",
ShortName: "hi",
Usage: "use it to see a description",
Description: "This is how we describe hello the function",
Subcommands: []cli.Command{
{
Name: "english",
ShortName: "en",
Usage: "sends a greeting in english",
Description: "greets someone in english",
Flags: []cli.Flag{
cli.StringFlag{
Name: "name",
Value: "Bob",
Usage: "Name of the person to greet",
},
},
Action: func(c *cli.Context) {
fmt.Println("Hello,", c.String("name"))
},
},
},
},
}
app.Run(os.Args)
// Output:
// Hello, Jeremy
}
func ExampleAppHelp() {
// set args for examples sake
os.Args = []string{"greet", "h", "describeit"}
app := cli.NewApp()
app.Name = "greet"
app.Flags = []cli.Flag{
cli.StringFlag{Name: "name", Value: "bob", Usage: "a name to say"},
}
app.Commands = []cli.Command{
{
Name: "describeit",
ShortName: "d",
Usage: "use it to see a description",
Description: "This is how we describe describeit the function",
Action: func(c *cli.Context) {
fmt.Printf("i like to describe things")
},
},
}
app.Run(os.Args)
// Output:
// NAME:
// describeit - use it to see a description
//
// USAGE:
// command describeit [arguments...]
//
// DESCRIPTION:
// This is how we describe describeit the function
}
func ExampleAppBashComplete() {
// set args for examples sake
os.Args = []string{"greet", "--generate-bash-completion"}
app := cli.NewApp()
app.Name = "greet"
app.EnableBashCompletion = true
app.Commands = []cli.Command{
{
Name: "describeit",
ShortName: "d",
Usage: "use it to see a description",
Description: "This is how we describe describeit the function",
Action: func(c *cli.Context) {
fmt.Printf("i like to describe things")
},
}, {
Name: "next",
Usage: "next example",
Description: "more stuff to see when generating bash completion",
Action: func(c *cli.Context) {
fmt.Printf("the next example")
},
},
}
app.Run(os.Args)
// Output:
// describeit
// d
// next
// help
// h
}
func TestApp_Run(t *testing.T) {
s := ""
app := cli.NewApp()
app.Action = func(c *cli.Context) {
s = s + c.Args().First()
}
err := app.Run([]string{"command", "foo"})
expect(t, err, nil)
err = app.Run([]string{"command", "bar"})
expect(t, err, nil)
expect(t, s, "foobar")
}
var commandAppTests = []struct {
name string
expected bool
}{
{"foobar", true},
{"batbaz", true},
{"b", true},
{"f", true},
{"bat", false},
{"nothing", false},
}
func TestApp_Command(t *testing.T) {
app := cli.NewApp()
fooCommand := cli.Command{Name: "foobar", ShortName: "f"}
batCommand := cli.Command{Name: "batbaz", ShortName: "b"}
app.Commands = []cli.Command{
fooCommand,
batCommand,
}
for _, test := range commandAppTests {
expect(t, app.Command(test.name) != nil, test.expected)
}
}
func TestApp_CommandWithArgBeforeFlags(t *testing.T) {
var parsedOption, firstArg string
app := cli.NewApp()
command := cli.Command{
Name: "cmd",
Flags: []cli.Flag{
cli.StringFlag{Name: "option", Value: "", Usage: "some option"},
},
Action: func(c *cli.Context) {
parsedOption = c.String("option")
firstArg = c.Args().First()
},
}
app.Commands = []cli.Command{command}
app.Run([]string{"", "cmd", "my-arg", "--option", "my-option"})
expect(t, parsedOption, "my-option")
expect(t, firstArg, "my-arg")
}
func TestApp_CommandWithFlagBeforeTerminator(t *testing.T) {
var parsedOption string
var args []string
app := cli.NewApp()
command := cli.Command{
Name: "cmd",
Flags: []cli.Flag{
cli.StringFlag{Name: "option", Value: "", Usage: "some option"},
},
Action: func(c *cli.Context) {
parsedOption = c.String("option")
args = c.Args()
},
}
app.Commands = []cli.Command{command}
app.Run([]string{"", "cmd", "my-arg", "--option", "my-option", "--", "--notARealFlag"})
expect(t, parsedOption, "my-option")
expect(t, args[0], "my-arg")
expect(t, args[1], "--")
expect(t, args[2], "--notARealFlag")
}
func TestApp_CommandWithNoFlagBeforeTerminator(t *testing.T) {
var args []string
app := cli.NewApp()
command := cli.Command{
Name: "cmd",
Action: func(c *cli.Context) {
args = c.Args()
},
}
app.Commands = []cli.Command{command}
app.Run([]string{"", "cmd", "my-arg", "--", "notAFlagAtAll"})
expect(t, args[0], "my-arg")
expect(t, args[1], "--")
expect(t, args[2], "notAFlagAtAll")
}
func TestApp_Float64Flag(t *testing.T) {
var meters float64
app := cli.NewApp()
app.Flags = []cli.Flag{
cli.Float64Flag{Name: "height", Value: 1.5, Usage: "Set the height, in meters"},
}
app.Action = func(c *cli.Context) {
meters = c.Float64("height")
}
app.Run([]string{"", "--height", "1.93"})
expect(t, meters, 1.93)
}
func TestApp_ParseSliceFlags(t *testing.T) {
var parsedOption, firstArg string
var parsedIntSlice []int
var parsedStringSlice []string
app := cli.NewApp()
command := cli.Command{
Name: "cmd",
Flags: []cli.Flag{
cli.IntSliceFlag{Name: "p", Value: &cli.IntSlice{}, Usage: "set one or more ip addr"},
cli.StringSliceFlag{Name: "ip", Value: &cli.StringSlice{}, Usage: "set one or more ports to open"},
},
Action: func(c *cli.Context) {
parsedIntSlice = c.IntSlice("p")
parsedStringSlice = c.StringSlice("ip")
parsedOption = c.String("option")
firstArg = c.Args().First()
},
}
app.Commands = []cli.Command{command}
app.Run([]string{"", "cmd", "my-arg", "-p", "22", "-p", "80", "-ip", "8.8.8.8", "-ip", "8.8.4.4"})
IntsEquals := func(a, b []int) bool {
if len(a) != len(b) {
return false
}
for i, v := range a {
if v != b[i] {
return false
}
}
return true
}
StrsEquals := func(a, b []string) bool {
if len(a) != len(b) {
return false
}
for i, v := range a {
if v != b[i] {
return false
}
}
return true
}
var expectedIntSlice = []int{22, 80}
var expectedStringSlice = []string{"8.8.8.8", "8.8.4.4"}
if !IntsEquals(parsedIntSlice, expectedIntSlice) {
t.Errorf("%v does not match %v", parsedIntSlice, expectedIntSlice)
}
if !StrsEquals(parsedStringSlice, expectedStringSlice) {
t.Errorf("%v does not match %v", parsedStringSlice, expectedStringSlice)
}
}
func TestApp_DefaultStdout(t *testing.T) {
app := cli.NewApp()
if app.Writer != os.Stdout {
t.Error("Default output writer not set.")
}
}
type mockWriter struct {
written []byte
}
func (fw *mockWriter) Write(p []byte) (n int, err error) {
if fw.written == nil {
fw.written = p
} else {
fw.written = append(fw.written, p...)
}
return len(p), nil
}
func (fw *mockWriter) GetWritten() (b []byte) {
return fw.written
}
func TestApp_SetStdout(t *testing.T) {
w := &mockWriter{}
app := cli.NewApp()
app.Name = "test"
app.Writer = w
err := app.Run([]string{"help"})
if err != nil {
t.Fatalf("Run error: %s", err)
}
if len(w.written) == 0 {
t.Error("App did not write output to desired writer.")
}
}
func TestApp_BeforeFunc(t *testing.T) {
beforeRun, subcommandRun := false, false
beforeError := fmt.Errorf("fail")
var err error
app := cli.NewApp()
app.Before = func(c *cli.Context) error {
beforeRun = true
s := c.String("opt")
if s == "fail" {
return beforeError
}
return nil
}
app.Commands = []cli.Command{
cli.Command{
Name: "sub",
Action: func(c *cli.Context) {
subcommandRun = true
},
},
}
app.Flags = []cli.Flag{
cli.StringFlag{Name: "opt"},
}
// run with the Before() func succeeding
err = app.Run([]string{"command", "--opt", "succeed", "sub"})
if err != nil {
t.Fatalf("Run error: %s", err)
}
if beforeRun == false {
t.Errorf("Before() not executed when expected")
}
if subcommandRun == false {
t.Errorf("Subcommand not executed when expected")
}
// reset
beforeRun, subcommandRun = false, false
// run with the Before() func failing
err = app.Run([]string{"command", "--opt", "fail", "sub"})
// should be the same error produced by the Before func
if err != beforeError {
t.Errorf("Run error expected, but not received")
}
if beforeRun == false {
t.Errorf("Before() not executed when expected")
}
if subcommandRun == true {
t.Errorf("Subcommand executed when NOT expected")
}
}
func TestAppNoHelpFlag(t *testing.T) {
oldFlag := cli.HelpFlag
defer func() {
cli.HelpFlag = oldFlag
}()
cli.HelpFlag = cli.BoolFlag{}
app := cli.NewApp()
err := app.Run([]string{"test", "-h"})
if err != flag.ErrHelp {
t.Errorf("expected error about missing help flag, but got: %s (%T)", err, err)
}
}
func TestAppHelpPrinter(t *testing.T) {
oldPrinter := cli.HelpPrinter
defer func() {
cli.HelpPrinter = oldPrinter
}()
var wasCalled = false
cli.HelpPrinter = func(template string, data interface{}) {
wasCalled = true
}
app := cli.NewApp()
app.Run([]string{"-h"})
if wasCalled == false {
t.Errorf("Help printer expected to be called, but was not")
}
}
func TestAppVersionPrinter(t *testing.T) {
oldPrinter := cli.VersionPrinter
defer func() {
cli.VersionPrinter = oldPrinter
}()
var wasCalled = false
cli.VersionPrinter = func(c *cli.Context) {
wasCalled = true
}
app := cli.NewApp()
ctx := cli.NewContext(app, nil, nil)
cli.ShowVersion(ctx)
if wasCalled == false {
t.Errorf("Version printer expected to be called, but was not")
}
}
func TestAppCommandNotFound(t *testing.T) {
beforeRun, subcommandRun := false, false
app := cli.NewApp()
app.CommandNotFound = func(c *cli.Context, command string) {
beforeRun = true
}
app.Commands = []cli.Command{
cli.Command{
Name: "bar",
Action: func(c *cli.Context) {
subcommandRun = true
},
},
}
app.Run([]string{"command", "foo"})
expect(t, beforeRun, true)
expect(t, subcommandRun, false)
}
func TestGlobalFlagsInSubcommands(t *testing.T) {
subcommandRun := false
app := cli.NewApp()
app.Flags = []cli.Flag{
cli.BoolFlag{Name: "debug, d", Usage: "Enable debugging"},
}
app.Commands = []cli.Command{
cli.Command{
Name: "foo",
Subcommands: []cli.Command{
{
Name: "bar",
Action: func(c *cli.Context) {
if c.GlobalBool("debug") {
subcommandRun = true
}
},
},
},
},
}
app.Run([]string{"command", "-d", "foo", "bar"})
expect(t, subcommandRun, true)
}

View file

@ -0,0 +1,13 @@
#! /bin/bash
_cli_bash_autocomplete() {
local cur prev opts base
COMPREPLY=()
cur="${COMP_WORDS[COMP_CWORD]}"
prev="${COMP_WORDS[COMP_CWORD-1]}"
opts=$( ${COMP_WORDS[@]:0:$COMP_CWORD} --generate-bash-completion )
COMPREPLY=( $(compgen -W "${opts}" -- ${cur}) )
return 0
}
complete -F _cli_bash_autocomplete $PROG

View file

@ -0,0 +1,5 @@
autoload -U compinit && compinit
autoload -U bashcompinit && bashcompinit
script_dir=$(dirname $0)
source ${script_dir}/bash_autocomplete

View file

@ -0,0 +1,19 @@
// Package cli provides a minimal framework for creating and organizing command line
// Go applications. cli is designed to be easy to understand and write, the most simple
// cli application can be written as follows:
// func main() {
// cli.NewApp().Run(os.Args)
// }
//
// Of course this application does not do much, so let's make this an actual application:
// func main() {
// app := cli.NewApp()
// app.Name = "greet"
// app.Usage = "say a greeting"
// app.Action = func(c *cli.Context) {
// println("Greetings")
// }
//
// app.Run(os.Args)
// }
package cli

View file

@ -0,0 +1,100 @@
package cli_test
import (
"os"
"github.com/codegangsta/cli"
)
func Example() {
app := cli.NewApp()
app.Name = "todo"
app.Usage = "task list on the command line"
app.Commands = []cli.Command{
{
Name: "add",
ShortName: "a",
Usage: "add a task to the list",
Action: func(c *cli.Context) {
println("added task: ", c.Args().First())
},
},
{
Name: "complete",
ShortName: "c",
Usage: "complete a task on the list",
Action: func(c *cli.Context) {
println("completed task: ", c.Args().First())
},
},
}
app.Run(os.Args)
}
func ExampleSubcommand() {
app := cli.NewApp()
app.Name = "say"
app.Commands = []cli.Command{
{
Name: "hello",
ShortName: "hi",
Usage: "use it to see a description",
Description: "This is how we describe hello the function",
Subcommands: []cli.Command{
{
Name: "english",
ShortName: "en",
Usage: "sends a greeting in english",
Description: "greets someone in english",
Flags: []cli.Flag{
cli.StringFlag{
Name: "name",
Value: "Bob",
Usage: "Name of the person to greet",
},
},
Action: func(c *cli.Context) {
println("Hello, ", c.String("name"))
},
}, {
Name: "spanish",
ShortName: "sp",
Usage: "sends a greeting in spanish",
Flags: []cli.Flag{
cli.StringFlag{
Name: "surname",
Value: "Jones",
Usage: "Surname of the person to greet",
},
},
Action: func(c *cli.Context) {
println("Hola, ", c.String("surname"))
},
}, {
Name: "french",
ShortName: "fr",
Usage: "sends a greeting in french",
Flags: []cli.Flag{
cli.StringFlag{
Name: "nickname",
Value: "Stevie",
Usage: "Nickname of the person to greet",
},
},
Action: func(c *cli.Context) {
println("Bonjour, ", c.String("nickname"))
},
},
},
}, {
Name: "bye",
Usage: "says goodbye",
Action: func(c *cli.Context) {
println("bye")
},
},
}
app.Run(os.Args)
}

View file

@ -0,0 +1,156 @@
package cli
import (
"fmt"
"io/ioutil"
"strings"
)
// Command is a subcommand for a cli.App.
type Command struct {
// The name of the command
Name string
// short name of the command. Typically one character
ShortName string
// A short description of the usage of this command
Usage string
// A longer explanation of how the command works
Description string
// The function to call when checking for bash command completions
BashComplete func(context *Context)
// An action to execute before any sub-subcommands are run, but after the context is ready
// If a non-nil error is returned, no sub-subcommands are run
Before func(context *Context) error
// The function to call when this command is invoked
Action func(context *Context)
// List of child commands
Subcommands []Command
// List of flags to parse
Flags []Flag
// Treat all flags as normal arguments if true
SkipFlagParsing bool
// Boolean to hide built-in help command
HideHelp bool
}
// Invokes the command given the context, parses ctx.Args() to generate command-specific flags
func (c Command) Run(ctx *Context) error {
if len(c.Subcommands) > 0 || c.Before != nil {
return c.startApp(ctx)
}
if !c.HideHelp && (HelpFlag != BoolFlag{}) {
// append help to flags
c.Flags = append(
c.Flags,
HelpFlag,
)
}
if ctx.App.EnableBashCompletion {
c.Flags = append(c.Flags, BashCompletionFlag)
}
set := flagSet(c.Name, c.Flags)
set.SetOutput(ioutil.Discard)
firstFlagIndex := -1
terminatorIndex := -1
for index, arg := range ctx.Args() {
if arg == "--" {
terminatorIndex = index
break
} else if strings.HasPrefix(arg, "-") && firstFlagIndex == -1 {
firstFlagIndex = index
}
}
var err error
if firstFlagIndex > -1 && !c.SkipFlagParsing {
args := ctx.Args()
regularArgs := make([]string, len(args[1:firstFlagIndex]))
copy(regularArgs, args[1:firstFlagIndex])
var flagArgs []string
if terminatorIndex > -1 {
flagArgs = args[firstFlagIndex:terminatorIndex]
regularArgs = append(regularArgs, args[terminatorIndex:]...)
} else {
flagArgs = args[firstFlagIndex:]
}
err = set.Parse(append(flagArgs, regularArgs...))
} else {
err = set.Parse(ctx.Args().Tail())
}
if err != nil {
fmt.Fprint(ctx.App.Writer, "Incorrect Usage.\n\n")
ShowCommandHelp(ctx, c.Name)
fmt.Fprintln(ctx.App.Writer)
return err
}
nerr := normalizeFlags(c.Flags, set)
if nerr != nil {
fmt.Fprintln(ctx.App.Writer, nerr)
fmt.Fprintln(ctx.App.Writer)
ShowCommandHelp(ctx, c.Name)
fmt.Fprintln(ctx.App.Writer)
return nerr
}
context := NewContext(ctx.App, set, ctx.globalSet)
if checkCommandCompletions(context, c.Name) {
return nil
}
if checkCommandHelp(context, c.Name) {
return nil
}
context.Command = c
c.Action(context)
return nil
}
// Returns true if Command.Name or Command.ShortName matches given name
func (c Command) HasName(name string) bool {
return c.Name == name || c.ShortName == name
}
func (c Command) startApp(ctx *Context) error {
app := NewApp()
// set the name and usage
app.Name = fmt.Sprintf("%s %s", ctx.App.Name, c.Name)
if c.Description != "" {
app.Usage = c.Description
} else {
app.Usage = c.Usage
}
// set CommandNotFound
app.CommandNotFound = ctx.App.CommandNotFound
// set the flags and commands
app.Commands = c.Subcommands
app.Flags = c.Flags
app.HideHelp = c.HideHelp
// bash completion
app.EnableBashCompletion = ctx.App.EnableBashCompletion
if c.BashComplete != nil {
app.BashComplete = c.BashComplete
}
// set the actions
app.Before = c.Before
if c.Action != nil {
app.Action = c.Action
} else {
app.Action = helpSubcommand.Action
}
return app.RunAsSubcommand(ctx)
}

View file

@ -0,0 +1,49 @@
package cli_test
import (
"flag"
"testing"
"github.com/codegangsta/cli"
)
func TestCommandDoNotIgnoreFlags(t *testing.T) {
app := cli.NewApp()
set := flag.NewFlagSet("test", 0)
test := []string{"blah", "blah", "-break"}
set.Parse(test)
c := cli.NewContext(app, set, set)
command := cli.Command{
Name: "test-cmd",
ShortName: "tc",
Usage: "this is for testing",
Description: "testing",
Action: func(_ *cli.Context) {},
}
err := command.Run(c)
expect(t, err.Error(), "flag provided but not defined: -break")
}
func TestCommandIgnoreFlags(t *testing.T) {
app := cli.NewApp()
set := flag.NewFlagSet("test", 0)
test := []string{"blah", "blah"}
set.Parse(test)
c := cli.NewContext(app, set, set)
command := cli.Command{
Name: "test-cmd",
ShortName: "tc",
Usage: "this is for testing",
Description: "testing",
Action: func(_ *cli.Context) {},
SkipFlagParsing: true,
}
err := command.Run(c)
expect(t, err, nil)
}

View file

@ -0,0 +1,339 @@
package cli
import (
"errors"
"flag"
"strconv"
"strings"
"time"
)
// Context is a type that is passed through to
// each Handler action in a cli application. Context
// can be used to retrieve context-specific Args and
// parsed command-line options.
type Context struct {
App *App
Command Command
flagSet *flag.FlagSet
globalSet *flag.FlagSet
setFlags map[string]bool
globalSetFlags map[string]bool
}
// Creates a new context. For use in when invoking an App or Command action.
func NewContext(app *App, set *flag.FlagSet, globalSet *flag.FlagSet) *Context {
return &Context{App: app, flagSet: set, globalSet: globalSet}
}
// Looks up the value of a local int flag, returns 0 if no int flag exists
func (c *Context) Int(name string) int {
return lookupInt(name, c.flagSet)
}
// Looks up the value of a local time.Duration flag, returns 0 if no time.Duration flag exists
func (c *Context) Duration(name string) time.Duration {
return lookupDuration(name, c.flagSet)
}
// Looks up the value of a local float64 flag, returns 0 if no float64 flag exists
func (c *Context) Float64(name string) float64 {
return lookupFloat64(name, c.flagSet)
}
// Looks up the value of a local bool flag, returns false if no bool flag exists
func (c *Context) Bool(name string) bool {
return lookupBool(name, c.flagSet)
}
// Looks up the value of a local boolT flag, returns false if no bool flag exists
func (c *Context) BoolT(name string) bool {
return lookupBoolT(name, c.flagSet)
}
// Looks up the value of a local string flag, returns "" if no string flag exists
func (c *Context) String(name string) string {
return lookupString(name, c.flagSet)
}
// Looks up the value of a local string slice flag, returns nil if no string slice flag exists
func (c *Context) StringSlice(name string) []string {
return lookupStringSlice(name, c.flagSet)
}
// Looks up the value of a local int slice flag, returns nil if no int slice flag exists
func (c *Context) IntSlice(name string) []int {
return lookupIntSlice(name, c.flagSet)
}
// Looks up the value of a local generic flag, returns nil if no generic flag exists
func (c *Context) Generic(name string) interface{} {
return lookupGeneric(name, c.flagSet)
}
// Looks up the value of a global int flag, returns 0 if no int flag exists
func (c *Context) GlobalInt(name string) int {
return lookupInt(name, c.globalSet)
}
// Looks up the value of a global time.Duration flag, returns 0 if no time.Duration flag exists
func (c *Context) GlobalDuration(name string) time.Duration {
return lookupDuration(name, c.globalSet)
}
// Looks up the value of a global bool flag, returns false if no bool flag exists
func (c *Context) GlobalBool(name string) bool {
return lookupBool(name, c.globalSet)
}
// Looks up the value of a global string flag, returns "" if no string flag exists
func (c *Context) GlobalString(name string) string {
return lookupString(name, c.globalSet)
}
// Looks up the value of a global string slice flag, returns nil if no string slice flag exists
func (c *Context) GlobalStringSlice(name string) []string {
return lookupStringSlice(name, c.globalSet)
}
// Looks up the value of a global int slice flag, returns nil if no int slice flag exists
func (c *Context) GlobalIntSlice(name string) []int {
return lookupIntSlice(name, c.globalSet)
}
// Looks up the value of a global generic flag, returns nil if no generic flag exists
func (c *Context) GlobalGeneric(name string) interface{} {
return lookupGeneric(name, c.globalSet)
}
// Determines if the flag was actually set
func (c *Context) IsSet(name string) bool {
if c.setFlags == nil {
c.setFlags = make(map[string]bool)
c.flagSet.Visit(func(f *flag.Flag) {
c.setFlags[f.Name] = true
})
}
return c.setFlags[name] == true
}
// Determines if the global flag was actually set
func (c *Context) GlobalIsSet(name string) bool {
if c.globalSetFlags == nil {
c.globalSetFlags = make(map[string]bool)
c.globalSet.Visit(func(f *flag.Flag) {
c.globalSetFlags[f.Name] = true
})
}
return c.globalSetFlags[name] == true
}
// Returns a slice of flag names used in this context.
func (c *Context) FlagNames() (names []string) {
for _, flag := range c.Command.Flags {
name := strings.Split(flag.getName(), ",")[0]
if name == "help" {
continue
}
names = append(names, name)
}
return
}
// Returns a slice of global flag names used by the app.
func (c *Context) GlobalFlagNames() (names []string) {
for _, flag := range c.App.Flags {
name := strings.Split(flag.getName(), ",")[0]
if name == "help" || name == "version" {
continue
}
names = append(names, name)
}
return
}
type Args []string
// Returns the command line arguments associated with the context.
func (c *Context) Args() Args {
args := Args(c.flagSet.Args())
return args
}
// Returns the nth argument, or else a blank string
func (a Args) Get(n int) string {
if len(a) > n {
return a[n]
}
return ""
}
// Returns the first argument, or else a blank string
func (a Args) First() string {
return a.Get(0)
}
// Return the rest of the arguments (not the first one)
// or else an empty string slice
func (a Args) Tail() []string {
if len(a) >= 2 {
return []string(a)[1:]
}
return []string{}
}
// Checks if there are any arguments present
func (a Args) Present() bool {
return len(a) != 0
}
// Swaps arguments at the given indexes
func (a Args) Swap(from, to int) error {
if from >= len(a) || to >= len(a) {
return errors.New("index out of range")
}
a[from], a[to] = a[to], a[from]
return nil
}
func lookupInt(name string, set *flag.FlagSet) int {
f := set.Lookup(name)
if f != nil {
val, err := strconv.Atoi(f.Value.String())
if err != nil {
return 0
}
return val
}
return 0
}
func lookupDuration(name string, set *flag.FlagSet) time.Duration {
f := set.Lookup(name)
if f != nil {
val, err := time.ParseDuration(f.Value.String())
if err == nil {
return val
}
}
return 0
}
func lookupFloat64(name string, set *flag.FlagSet) float64 {
f := set.Lookup(name)
if f != nil {
val, err := strconv.ParseFloat(f.Value.String(), 64)
if err != nil {
return 0
}
return val
}
return 0
}
func lookupString(name string, set *flag.FlagSet) string {
f := set.Lookup(name)
if f != nil {
return f.Value.String()
}
return ""
}
func lookupStringSlice(name string, set *flag.FlagSet) []string {
f := set.Lookup(name)
if f != nil {
return (f.Value.(*StringSlice)).Value()
}
return nil
}
func lookupIntSlice(name string, set *flag.FlagSet) []int {
f := set.Lookup(name)
if f != nil {
return (f.Value.(*IntSlice)).Value()
}
return nil
}
func lookupGeneric(name string, set *flag.FlagSet) interface{} {
f := set.Lookup(name)
if f != nil {
return f.Value
}
return nil
}
func lookupBool(name string, set *flag.FlagSet) bool {
f := set.Lookup(name)
if f != nil {
val, err := strconv.ParseBool(f.Value.String())
if err != nil {
return false
}
return val
}
return false
}
func lookupBoolT(name string, set *flag.FlagSet) bool {
f := set.Lookup(name)
if f != nil {
val, err := strconv.ParseBool(f.Value.String())
if err != nil {
return true
}
return val
}
return false
}
func copyFlag(name string, ff *flag.Flag, set *flag.FlagSet) {
switch ff.Value.(type) {
case *StringSlice:
default:
set.Set(name, ff.Value.String())
}
}
func normalizeFlags(flags []Flag, set *flag.FlagSet) error {
visited := make(map[string]bool)
set.Visit(func(f *flag.Flag) {
visited[f.Name] = true
})
for _, f := range flags {
parts := strings.Split(f.getName(), ",")
if len(parts) == 1 {
continue
}
var ff *flag.Flag
for _, name := range parts {
name = strings.Trim(name, " ")
if visited[name] {
if ff != nil {
return errors.New("Cannot use two forms of the same flag: " + name + " " + ff.Name)
}
ff = set.Lookup(name)
}
}
if ff == nil {
continue
}
for _, name := range parts {
name = strings.Trim(name, " ")
if !visited[name] {
copyFlag(name, ff, set)
}
}
}
return nil
}

View file

@ -0,0 +1,99 @@
package cli_test
import (
"flag"
"testing"
"time"
"github.com/codegangsta/cli"
)
func TestNewContext(t *testing.T) {
set := flag.NewFlagSet("test", 0)
set.Int("myflag", 12, "doc")
globalSet := flag.NewFlagSet("test", 0)
globalSet.Int("myflag", 42, "doc")
command := cli.Command{Name: "mycommand"}
c := cli.NewContext(nil, set, globalSet)
c.Command = command
expect(t, c.Int("myflag"), 12)
expect(t, c.GlobalInt("myflag"), 42)
expect(t, c.Command.Name, "mycommand")
}
func TestContext_Int(t *testing.T) {
set := flag.NewFlagSet("test", 0)
set.Int("myflag", 12, "doc")
c := cli.NewContext(nil, set, set)
expect(t, c.Int("myflag"), 12)
}
func TestContext_Duration(t *testing.T) {
set := flag.NewFlagSet("test", 0)
set.Duration("myflag", time.Duration(12*time.Second), "doc")
c := cli.NewContext(nil, set, set)
expect(t, c.Duration("myflag"), time.Duration(12*time.Second))
}
func TestContext_String(t *testing.T) {
set := flag.NewFlagSet("test", 0)
set.String("myflag", "hello world", "doc")
c := cli.NewContext(nil, set, set)
expect(t, c.String("myflag"), "hello world")
}
func TestContext_Bool(t *testing.T) {
set := flag.NewFlagSet("test", 0)
set.Bool("myflag", false, "doc")
c := cli.NewContext(nil, set, set)
expect(t, c.Bool("myflag"), false)
}
func TestContext_BoolT(t *testing.T) {
set := flag.NewFlagSet("test", 0)
set.Bool("myflag", true, "doc")
c := cli.NewContext(nil, set, set)
expect(t, c.BoolT("myflag"), true)
}
func TestContext_Args(t *testing.T) {
set := flag.NewFlagSet("test", 0)
set.Bool("myflag", false, "doc")
c := cli.NewContext(nil, set, set)
set.Parse([]string{"--myflag", "bat", "baz"})
expect(t, len(c.Args()), 2)
expect(t, c.Bool("myflag"), true)
}
func TestContext_IsSet(t *testing.T) {
set := flag.NewFlagSet("test", 0)
set.Bool("myflag", false, "doc")
set.String("otherflag", "hello world", "doc")
globalSet := flag.NewFlagSet("test", 0)
globalSet.Bool("myflagGlobal", true, "doc")
c := cli.NewContext(nil, set, globalSet)
set.Parse([]string{"--myflag", "bat", "baz"})
globalSet.Parse([]string{"--myflagGlobal", "bat", "baz"})
expect(t, c.IsSet("myflag"), true)
expect(t, c.IsSet("otherflag"), false)
expect(t, c.IsSet("bogusflag"), false)
expect(t, c.IsSet("myflagGlobal"), false)
}
func TestContext_GlobalIsSet(t *testing.T) {
set := flag.NewFlagSet("test", 0)
set.Bool("myflag", false, "doc")
set.String("otherflag", "hello world", "doc")
globalSet := flag.NewFlagSet("test", 0)
globalSet.Bool("myflagGlobal", true, "doc")
globalSet.Bool("myflagGlobalUnset", true, "doc")
c := cli.NewContext(nil, set, globalSet)
set.Parse([]string{"--myflag", "bat", "baz"})
globalSet.Parse([]string{"--myflagGlobal", "bat", "baz"})
expect(t, c.GlobalIsSet("myflag"), false)
expect(t, c.GlobalIsSet("otherflag"), false)
expect(t, c.GlobalIsSet("bogusflag"), false)
expect(t, c.GlobalIsSet("myflagGlobal"), true)
expect(t, c.GlobalIsSet("myflagGlobalUnset"), false)
expect(t, c.GlobalIsSet("bogusGlobal"), false)
}

View file

@ -0,0 +1,454 @@
package cli
import (
"flag"
"fmt"
"os"
"strconv"
"strings"
"time"
)
// This flag enables bash-completion for all commands and subcommands
var BashCompletionFlag = BoolFlag{
Name: "generate-bash-completion",
}
// This flag prints the version for the application
var VersionFlag = BoolFlag{
Name: "version, v",
Usage: "print the version",
}
// This flag prints the help for all commands and subcommands
// Set to the zero value (BoolFlag{}) to disable flag -- keeps subcommand
// unless HideHelp is set to true)
var HelpFlag = BoolFlag{
Name: "help, h",
Usage: "show help",
}
// Flag is a common interface related to parsing flags in cli.
// For more advanced flag parsing techniques, it is recomended that
// this interface be implemented.
type Flag interface {
fmt.Stringer
// Apply Flag settings to the given flag set
Apply(*flag.FlagSet)
getName() string
}
func flagSet(name string, flags []Flag) *flag.FlagSet {
set := flag.NewFlagSet(name, flag.ContinueOnError)
for _, f := range flags {
f.Apply(set)
}
return set
}
func eachName(longName string, fn func(string)) {
parts := strings.Split(longName, ",")
for _, name := range parts {
name = strings.Trim(name, " ")
fn(name)
}
}
// Generic is a generic parseable type identified by a specific flag
type Generic interface {
Set(value string) error
String() string
}
// GenericFlag is the flag type for types implementing Generic
type GenericFlag struct {
Name string
Value Generic
Usage string
EnvVar string
}
// String returns the string representation of the generic flag to display the
// help text to the user (uses the String() method of the generic flag to show
// the value)
func (f GenericFlag) String() string {
return withEnvHint(f.EnvVar, fmt.Sprintf("%s%s \"%v\"\t%v", prefixFor(f.Name), f.Name, f.Value, f.Usage))
}
// Apply takes the flagset and calls Set on the generic flag with the value
// provided by the user for parsing by the flag
func (f GenericFlag) Apply(set *flag.FlagSet) {
val := f.Value
if f.EnvVar != "" {
for _, envVar := range strings.Split(f.EnvVar, ",") {
envVar = strings.TrimSpace(envVar)
if envVal := os.Getenv(envVar); envVal != "" {
val.Set(envVal)
break
}
}
}
eachName(f.Name, func(name string) {
set.Var(f.Value, name, f.Usage)
})
}
func (f GenericFlag) getName() string {
return f.Name
}
type StringSlice []string
func (f *StringSlice) Set(value string) error {
*f = append(*f, value)
return nil
}
func (f *StringSlice) String() string {
return fmt.Sprintf("%s", *f)
}
func (f *StringSlice) Value() []string {
return *f
}
type StringSliceFlag struct {
Name string
Value *StringSlice
Usage string
EnvVar string
}
func (f StringSliceFlag) String() string {
firstName := strings.Trim(strings.Split(f.Name, ",")[0], " ")
pref := prefixFor(firstName)
return withEnvHint(f.EnvVar, fmt.Sprintf("%s [%v]\t%v", prefixedNames(f.Name), pref+firstName+" option "+pref+firstName+" option", f.Usage))
}
func (f StringSliceFlag) Apply(set *flag.FlagSet) {
if f.EnvVar != "" {
for _, envVar := range strings.Split(f.EnvVar, ",") {
envVar = strings.TrimSpace(envVar)
if envVal := os.Getenv(envVar); envVal != "" {
newVal := &StringSlice{}
for _, s := range strings.Split(envVal, ",") {
s = strings.TrimSpace(s)
newVal.Set(s)
}
f.Value = newVal
break
}
}
}
eachName(f.Name, func(name string) {
set.Var(f.Value, name, f.Usage)
})
}
func (f StringSliceFlag) getName() string {
return f.Name
}
type IntSlice []int
func (f *IntSlice) Set(value string) error {
tmp, err := strconv.Atoi(value)
if err != nil {
return err
} else {
*f = append(*f, tmp)
}
return nil
}
func (f *IntSlice) String() string {
return fmt.Sprintf("%d", *f)
}
func (f *IntSlice) Value() []int {
return *f
}
type IntSliceFlag struct {
Name string
Value *IntSlice
Usage string
EnvVar string
}
func (f IntSliceFlag) String() string {
firstName := strings.Trim(strings.Split(f.Name, ",")[0], " ")
pref := prefixFor(firstName)
return withEnvHint(f.EnvVar, fmt.Sprintf("%s [%v]\t%v", prefixedNames(f.Name), pref+firstName+" option "+pref+firstName+" option", f.Usage))
}
func (f IntSliceFlag) Apply(set *flag.FlagSet) {
if f.EnvVar != "" {
for _, envVar := range strings.Split(f.EnvVar, ",") {
envVar = strings.TrimSpace(envVar)
if envVal := os.Getenv(envVar); envVal != "" {
newVal := &IntSlice{}
for _, s := range strings.Split(envVal, ",") {
s = strings.TrimSpace(s)
err := newVal.Set(s)
if err != nil {
fmt.Fprintf(os.Stderr, err.Error())
}
}
f.Value = newVal
break
}
}
}
eachName(f.Name, func(name string) {
set.Var(f.Value, name, f.Usage)
})
}
func (f IntSliceFlag) getName() string {
return f.Name
}
type BoolFlag struct {
Name string
Usage string
EnvVar string
}
func (f BoolFlag) String() string {
return withEnvHint(f.EnvVar, fmt.Sprintf("%s\t%v", prefixedNames(f.Name), f.Usage))
}
func (f BoolFlag) Apply(set *flag.FlagSet) {
val := false
if f.EnvVar != "" {
for _, envVar := range strings.Split(f.EnvVar, ",") {
envVar = strings.TrimSpace(envVar)
if envVal := os.Getenv(envVar); envVal != "" {
envValBool, err := strconv.ParseBool(envVal)
if err == nil {
val = envValBool
}
break
}
}
}
eachName(f.Name, func(name string) {
set.Bool(name, val, f.Usage)
})
}
func (f BoolFlag) getName() string {
return f.Name
}
type BoolTFlag struct {
Name string
Usage string
EnvVar string
}
func (f BoolTFlag) String() string {
return withEnvHint(f.EnvVar, fmt.Sprintf("%s\t%v", prefixedNames(f.Name), f.Usage))
}
func (f BoolTFlag) Apply(set *flag.FlagSet) {
val := true
if f.EnvVar != "" {
for _, envVar := range strings.Split(f.EnvVar, ",") {
envVar = strings.TrimSpace(envVar)
if envVal := os.Getenv(envVar); envVal != "" {
envValBool, err := strconv.ParseBool(envVal)
if err == nil {
val = envValBool
break
}
}
}
}
eachName(f.Name, func(name string) {
set.Bool(name, val, f.Usage)
})
}
func (f BoolTFlag) getName() string {
return f.Name
}
type StringFlag struct {
Name string
Value string
Usage string
EnvVar string
}
func (f StringFlag) String() string {
var fmtString string
fmtString = "%s %v\t%v"
if len(f.Value) > 0 {
fmtString = "%s \"%v\"\t%v"
} else {
fmtString = "%s %v\t%v"
}
return withEnvHint(f.EnvVar, fmt.Sprintf(fmtString, prefixedNames(f.Name), f.Value, f.Usage))
}
func (f StringFlag) Apply(set *flag.FlagSet) {
if f.EnvVar != "" {
for _, envVar := range strings.Split(f.EnvVar, ",") {
envVar = strings.TrimSpace(envVar)
if envVal := os.Getenv(envVar); envVal != "" {
f.Value = envVal
break
}
}
}
eachName(f.Name, func(name string) {
set.String(name, f.Value, f.Usage)
})
}
func (f StringFlag) getName() string {
return f.Name
}
type IntFlag struct {
Name string
Value int
Usage string
EnvVar string
}
func (f IntFlag) String() string {
return withEnvHint(f.EnvVar, fmt.Sprintf("%s \"%v\"\t%v", prefixedNames(f.Name), f.Value, f.Usage))
}
func (f IntFlag) Apply(set *flag.FlagSet) {
if f.EnvVar != "" {
for _, envVar := range strings.Split(f.EnvVar, ",") {
envVar = strings.TrimSpace(envVar)
if envVal := os.Getenv(envVar); envVal != "" {
envValInt, err := strconv.ParseInt(envVal, 0, 64)
if err == nil {
f.Value = int(envValInt)
break
}
}
}
}
eachName(f.Name, func(name string) {
set.Int(name, f.Value, f.Usage)
})
}
func (f IntFlag) getName() string {
return f.Name
}
type DurationFlag struct {
Name string
Value time.Duration
Usage string
EnvVar string
}
func (f DurationFlag) String() string {
return withEnvHint(f.EnvVar, fmt.Sprintf("%s \"%v\"\t%v", prefixedNames(f.Name), f.Value, f.Usage))
}
func (f DurationFlag) Apply(set *flag.FlagSet) {
if f.EnvVar != "" {
for _, envVar := range strings.Split(f.EnvVar, ",") {
envVar = strings.TrimSpace(envVar)
if envVal := os.Getenv(envVar); envVal != "" {
envValDuration, err := time.ParseDuration(envVal)
if err == nil {
f.Value = envValDuration
break
}
}
}
}
eachName(f.Name, func(name string) {
set.Duration(name, f.Value, f.Usage)
})
}
func (f DurationFlag) getName() string {
return f.Name
}
type Float64Flag struct {
Name string
Value float64
Usage string
EnvVar string
}
func (f Float64Flag) String() string {
return withEnvHint(f.EnvVar, fmt.Sprintf("%s \"%v\"\t%v", prefixedNames(f.Name), f.Value, f.Usage))
}
func (f Float64Flag) Apply(set *flag.FlagSet) {
if f.EnvVar != "" {
for _, envVar := range strings.Split(f.EnvVar, ",") {
envVar = strings.TrimSpace(envVar)
if envVal := os.Getenv(envVar); envVal != "" {
envValFloat, err := strconv.ParseFloat(envVal, 10)
if err == nil {
f.Value = float64(envValFloat)
}
}
}
}
eachName(f.Name, func(name string) {
set.Float64(name, f.Value, f.Usage)
})
}
func (f Float64Flag) getName() string {
return f.Name
}
func prefixFor(name string) (prefix string) {
if len(name) == 1 {
prefix = "-"
} else {
prefix = "--"
}
return
}
func prefixedNames(fullName string) (prefixed string) {
parts := strings.Split(fullName, ",")
for i, name := range parts {
name = strings.Trim(name, " ")
prefixed += prefixFor(name) + name
if i < len(parts)-1 {
prefixed += ", "
}
}
return
}
func withEnvHint(envVar, str string) string {
envText := ""
if envVar != "" {
envText = fmt.Sprintf(" [$%s]", strings.Join(strings.Split(envVar, ","), ", $"))
}
return str + envText
}

View file

@ -0,0 +1,742 @@
package cli_test
import (
"fmt"
"os"
"reflect"
"strings"
"testing"
"github.com/codegangsta/cli"
)
var boolFlagTests = []struct {
name string
expected string
}{
{"help", "--help\t"},
{"h", "-h\t"},
}
func TestBoolFlagHelpOutput(t *testing.T) {
for _, test := range boolFlagTests {
flag := cli.BoolFlag{Name: test.name}
output := flag.String()
if output != test.expected {
t.Errorf("%s does not match %s", output, test.expected)
}
}
}
var stringFlagTests = []struct {
name string
value string
expected string
}{
{"help", "", "--help \t"},
{"h", "", "-h \t"},
{"h", "", "-h \t"},
{"test", "Something", "--test \"Something\"\t"},
}
func TestStringFlagHelpOutput(t *testing.T) {
for _, test := range stringFlagTests {
flag := cli.StringFlag{Name: test.name, Value: test.value}
output := flag.String()
if output != test.expected {
t.Errorf("%s does not match %s", output, test.expected)
}
}
}
func TestStringFlagWithEnvVarHelpOutput(t *testing.T) {
os.Clearenv()
os.Setenv("APP_FOO", "derp")
for _, test := range stringFlagTests {
flag := cli.StringFlag{Name: test.name, Value: test.value, EnvVar: "APP_FOO"}
output := flag.String()
if !strings.HasSuffix(output, " [$APP_FOO]") {
t.Errorf("%s does not end with [$APP_FOO]", output)
}
}
}
var stringSliceFlagTests = []struct {
name string
value *cli.StringSlice
expected string
}{
{"help", func() *cli.StringSlice {
s := &cli.StringSlice{}
s.Set("")
return s
}(), "--help [--help option --help option]\t"},
{"h", func() *cli.StringSlice {
s := &cli.StringSlice{}
s.Set("")
return s
}(), "-h [-h option -h option]\t"},
{"h", func() *cli.StringSlice {
s := &cli.StringSlice{}
s.Set("")
return s
}(), "-h [-h option -h option]\t"},
{"test", func() *cli.StringSlice {
s := &cli.StringSlice{}
s.Set("Something")
return s
}(), "--test [--test option --test option]\t"},
}
func TestStringSliceFlagHelpOutput(t *testing.T) {
for _, test := range stringSliceFlagTests {
flag := cli.StringSliceFlag{Name: test.name, Value: test.value}
output := flag.String()
if output != test.expected {
t.Errorf("%q does not match %q", output, test.expected)
}
}
}
func TestStringSliceFlagWithEnvVarHelpOutput(t *testing.T) {
os.Clearenv()
os.Setenv("APP_QWWX", "11,4")
for _, test := range stringSliceFlagTests {
flag := cli.StringSliceFlag{Name: test.name, Value: test.value, EnvVar: "APP_QWWX"}
output := flag.String()
if !strings.HasSuffix(output, " [$APP_QWWX]") {
t.Errorf("%q does not end with [$APP_QWWX]", output)
}
}
}
var intFlagTests = []struct {
name string
expected string
}{
{"help", "--help \"0\"\t"},
{"h", "-h \"0\"\t"},
}
func TestIntFlagHelpOutput(t *testing.T) {
for _, test := range intFlagTests {
flag := cli.IntFlag{Name: test.name}
output := flag.String()
if output != test.expected {
t.Errorf("%s does not match %s", output, test.expected)
}
}
}
func TestIntFlagWithEnvVarHelpOutput(t *testing.T) {
os.Clearenv()
os.Setenv("APP_BAR", "2")
for _, test := range intFlagTests {
flag := cli.IntFlag{Name: test.name, EnvVar: "APP_BAR"}
output := flag.String()
if !strings.HasSuffix(output, " [$APP_BAR]") {
t.Errorf("%s does not end with [$APP_BAR]", output)
}
}
}
var durationFlagTests = []struct {
name string
expected string
}{
{"help", "--help \"0\"\t"},
{"h", "-h \"0\"\t"},
}
func TestDurationFlagHelpOutput(t *testing.T) {
for _, test := range durationFlagTests {
flag := cli.DurationFlag{Name: test.name}
output := flag.String()
if output != test.expected {
t.Errorf("%s does not match %s", output, test.expected)
}
}
}
func TestDurationFlagWithEnvVarHelpOutput(t *testing.T) {
os.Clearenv()
os.Setenv("APP_BAR", "2h3m6s")
for _, test := range durationFlagTests {
flag := cli.DurationFlag{Name: test.name, EnvVar: "APP_BAR"}
output := flag.String()
if !strings.HasSuffix(output, " [$APP_BAR]") {
t.Errorf("%s does not end with [$APP_BAR]", output)
}
}
}
var intSliceFlagTests = []struct {
name string
value *cli.IntSlice
expected string
}{
{"help", &cli.IntSlice{}, "--help [--help option --help option]\t"},
{"h", &cli.IntSlice{}, "-h [-h option -h option]\t"},
{"h", &cli.IntSlice{}, "-h [-h option -h option]\t"},
{"test", func() *cli.IntSlice {
i := &cli.IntSlice{}
i.Set("9")
return i
}(), "--test [--test option --test option]\t"},
}
func TestIntSliceFlagHelpOutput(t *testing.T) {
for _, test := range intSliceFlagTests {
flag := cli.IntSliceFlag{Name: test.name, Value: test.value}
output := flag.String()
if output != test.expected {
t.Errorf("%q does not match %q", output, test.expected)
}
}
}
func TestIntSliceFlagWithEnvVarHelpOutput(t *testing.T) {
os.Clearenv()
os.Setenv("APP_SMURF", "42,3")
for _, test := range intSliceFlagTests {
flag := cli.IntSliceFlag{Name: test.name, Value: test.value, EnvVar: "APP_SMURF"}
output := flag.String()
if !strings.HasSuffix(output, " [$APP_SMURF]") {
t.Errorf("%q does not end with [$APP_SMURF]", output)
}
}
}
var float64FlagTests = []struct {
name string
expected string
}{
{"help", "--help \"0\"\t"},
{"h", "-h \"0\"\t"},
}
func TestFloat64FlagHelpOutput(t *testing.T) {
for _, test := range float64FlagTests {
flag := cli.Float64Flag{Name: test.name}
output := flag.String()
if output != test.expected {
t.Errorf("%s does not match %s", output, test.expected)
}
}
}
func TestFloat64FlagWithEnvVarHelpOutput(t *testing.T) {
os.Clearenv()
os.Setenv("APP_BAZ", "99.4")
for _, test := range float64FlagTests {
flag := cli.Float64Flag{Name: test.name, EnvVar: "APP_BAZ"}
output := flag.String()
if !strings.HasSuffix(output, " [$APP_BAZ]") {
t.Errorf("%s does not end with [$APP_BAZ]", output)
}
}
}
var genericFlagTests = []struct {
name string
value cli.Generic
expected string
}{
{"test", &Parser{"abc", "def"}, "--test \"abc,def\"\ttest flag"},
{"t", &Parser{"abc", "def"}, "-t \"abc,def\"\ttest flag"},
}
func TestGenericFlagHelpOutput(t *testing.T) {
for _, test := range genericFlagTests {
flag := cli.GenericFlag{Name: test.name, Value: test.value, Usage: "test flag"}
output := flag.String()
if output != test.expected {
t.Errorf("%q does not match %q", output, test.expected)
}
}
}
func TestGenericFlagWithEnvVarHelpOutput(t *testing.T) {
os.Clearenv()
os.Setenv("APP_ZAP", "3")
for _, test := range genericFlagTests {
flag := cli.GenericFlag{Name: test.name, EnvVar: "APP_ZAP"}
output := flag.String()
if !strings.HasSuffix(output, " [$APP_ZAP]") {
t.Errorf("%s does not end with [$APP_ZAP]", output)
}
}
}
func TestParseMultiString(t *testing.T) {
(&cli.App{
Flags: []cli.Flag{
cli.StringFlag{Name: "serve, s"},
},
Action: func(ctx *cli.Context) {
if ctx.String("serve") != "10" {
t.Errorf("main name not set")
}
if ctx.String("s") != "10" {
t.Errorf("short name not set")
}
},
}).Run([]string{"run", "-s", "10"})
}
func TestParseMultiStringFromEnv(t *testing.T) {
os.Clearenv()
os.Setenv("APP_COUNT", "20")
(&cli.App{
Flags: []cli.Flag{
cli.StringFlag{Name: "count, c", EnvVar: "APP_COUNT"},
},
Action: func(ctx *cli.Context) {
if ctx.String("count") != "20" {
t.Errorf("main name not set")
}
if ctx.String("c") != "20" {
t.Errorf("short name not set")
}
},
}).Run([]string{"run"})
}
func TestParseMultiStringFromEnvCascade(t *testing.T) {
os.Clearenv()
os.Setenv("APP_COUNT", "20")
(&cli.App{
Flags: []cli.Flag{
cli.StringFlag{Name: "count, c", EnvVar: "COMPAT_COUNT,APP_COUNT"},
},
Action: func(ctx *cli.Context) {
if ctx.String("count") != "20" {
t.Errorf("main name not set")
}
if ctx.String("c") != "20" {
t.Errorf("short name not set")
}
},
}).Run([]string{"run"})
}
func TestParseMultiStringSlice(t *testing.T) {
(&cli.App{
Flags: []cli.Flag{
cli.StringSliceFlag{Name: "serve, s", Value: &cli.StringSlice{}},
},
Action: func(ctx *cli.Context) {
if !reflect.DeepEqual(ctx.StringSlice("serve"), []string{"10", "20"}) {
t.Errorf("main name not set")
}
if !reflect.DeepEqual(ctx.StringSlice("s"), []string{"10", "20"}) {
t.Errorf("short name not set")
}
},
}).Run([]string{"run", "-s", "10", "-s", "20"})
}
func TestParseMultiStringSliceFromEnv(t *testing.T) {
os.Clearenv()
os.Setenv("APP_INTERVALS", "20,30,40")
(&cli.App{
Flags: []cli.Flag{
cli.StringSliceFlag{Name: "intervals, i", Value: &cli.StringSlice{}, EnvVar: "APP_INTERVALS"},
},
Action: func(ctx *cli.Context) {
if !reflect.DeepEqual(ctx.StringSlice("intervals"), []string{"20", "30", "40"}) {
t.Errorf("main name not set from env")
}
if !reflect.DeepEqual(ctx.StringSlice("i"), []string{"20", "30", "40"}) {
t.Errorf("short name not set from env")
}
},
}).Run([]string{"run"})
}
func TestParseMultiStringSliceFromEnvCascade(t *testing.T) {
os.Clearenv()
os.Setenv("APP_INTERVALS", "20,30,40")
(&cli.App{
Flags: []cli.Flag{
cli.StringSliceFlag{Name: "intervals, i", Value: &cli.StringSlice{}, EnvVar: "COMPAT_INTERVALS,APP_INTERVALS"},
},
Action: func(ctx *cli.Context) {
if !reflect.DeepEqual(ctx.StringSlice("intervals"), []string{"20", "30", "40"}) {
t.Errorf("main name not set from env")
}
if !reflect.DeepEqual(ctx.StringSlice("i"), []string{"20", "30", "40"}) {
t.Errorf("short name not set from env")
}
},
}).Run([]string{"run"})
}
func TestParseMultiInt(t *testing.T) {
a := cli.App{
Flags: []cli.Flag{
cli.IntFlag{Name: "serve, s"},
},
Action: func(ctx *cli.Context) {
if ctx.Int("serve") != 10 {
t.Errorf("main name not set")
}
if ctx.Int("s") != 10 {
t.Errorf("short name not set")
}
},
}
a.Run([]string{"run", "-s", "10"})
}
func TestParseMultiIntFromEnv(t *testing.T) {
os.Clearenv()
os.Setenv("APP_TIMEOUT_SECONDS", "10")
a := cli.App{
Flags: []cli.Flag{
cli.IntFlag{Name: "timeout, t", EnvVar: "APP_TIMEOUT_SECONDS"},
},
Action: func(ctx *cli.Context) {
if ctx.Int("timeout") != 10 {
t.Errorf("main name not set")
}
if ctx.Int("t") != 10 {
t.Errorf("short name not set")
}
},
}
a.Run([]string{"run"})
}
func TestParseMultiIntFromEnvCascade(t *testing.T) {
os.Clearenv()
os.Setenv("APP_TIMEOUT_SECONDS", "10")
a := cli.App{
Flags: []cli.Flag{
cli.IntFlag{Name: "timeout, t", EnvVar: "COMPAT_TIMEOUT_SECONDS,APP_TIMEOUT_SECONDS"},
},
Action: func(ctx *cli.Context) {
if ctx.Int("timeout") != 10 {
t.Errorf("main name not set")
}
if ctx.Int("t") != 10 {
t.Errorf("short name not set")
}
},
}
a.Run([]string{"run"})
}
func TestParseMultiIntSlice(t *testing.T) {
(&cli.App{
Flags: []cli.Flag{
cli.IntSliceFlag{Name: "serve, s", Value: &cli.IntSlice{}},
},
Action: func(ctx *cli.Context) {
if !reflect.DeepEqual(ctx.IntSlice("serve"), []int{10, 20}) {
t.Errorf("main name not set")
}
if !reflect.DeepEqual(ctx.IntSlice("s"), []int{10, 20}) {
t.Errorf("short name not set")
}
},
}).Run([]string{"run", "-s", "10", "-s", "20"})
}
func TestParseMultiIntSliceFromEnv(t *testing.T) {
os.Clearenv()
os.Setenv("APP_INTERVALS", "20,30,40")
(&cli.App{
Flags: []cli.Flag{
cli.IntSliceFlag{Name: "intervals, i", Value: &cli.IntSlice{}, EnvVar: "APP_INTERVALS"},
},
Action: func(ctx *cli.Context) {
if !reflect.DeepEqual(ctx.IntSlice("intervals"), []int{20, 30, 40}) {
t.Errorf("main name not set from env")
}
if !reflect.DeepEqual(ctx.IntSlice("i"), []int{20, 30, 40}) {
t.Errorf("short name not set from env")
}
},
}).Run([]string{"run"})
}
func TestParseMultiIntSliceFromEnvCascade(t *testing.T) {
os.Clearenv()
os.Setenv("APP_INTERVALS", "20,30,40")
(&cli.App{
Flags: []cli.Flag{
cli.IntSliceFlag{Name: "intervals, i", Value: &cli.IntSlice{}, EnvVar: "COMPAT_INTERVALS,APP_INTERVALS"},
},
Action: func(ctx *cli.Context) {
if !reflect.DeepEqual(ctx.IntSlice("intervals"), []int{20, 30, 40}) {
t.Errorf("main name not set from env")
}
if !reflect.DeepEqual(ctx.IntSlice("i"), []int{20, 30, 40}) {
t.Errorf("short name not set from env")
}
},
}).Run([]string{"run"})
}
func TestParseMultiFloat64(t *testing.T) {
a := cli.App{
Flags: []cli.Flag{
cli.Float64Flag{Name: "serve, s"},
},
Action: func(ctx *cli.Context) {
if ctx.Float64("serve") != 10.2 {
t.Errorf("main name not set")
}
if ctx.Float64("s") != 10.2 {
t.Errorf("short name not set")
}
},
}
a.Run([]string{"run", "-s", "10.2"})
}
func TestParseMultiFloat64FromEnv(t *testing.T) {
os.Clearenv()
os.Setenv("APP_TIMEOUT_SECONDS", "15.5")
a := cli.App{
Flags: []cli.Flag{
cli.Float64Flag{Name: "timeout, t", EnvVar: "APP_TIMEOUT_SECONDS"},
},
Action: func(ctx *cli.Context) {
if ctx.Float64("timeout") != 15.5 {
t.Errorf("main name not set")
}
if ctx.Float64("t") != 15.5 {
t.Errorf("short name not set")
}
},
}
a.Run([]string{"run"})
}
func TestParseMultiFloat64FromEnvCascade(t *testing.T) {
os.Clearenv()
os.Setenv("APP_TIMEOUT_SECONDS", "15.5")
a := cli.App{
Flags: []cli.Flag{
cli.Float64Flag{Name: "timeout, t", EnvVar: "COMPAT_TIMEOUT_SECONDS,APP_TIMEOUT_SECONDS"},
},
Action: func(ctx *cli.Context) {
if ctx.Float64("timeout") != 15.5 {
t.Errorf("main name not set")
}
if ctx.Float64("t") != 15.5 {
t.Errorf("short name not set")
}
},
}
a.Run([]string{"run"})
}
func TestParseMultiBool(t *testing.T) {
a := cli.App{
Flags: []cli.Flag{
cli.BoolFlag{Name: "serve, s"},
},
Action: func(ctx *cli.Context) {
if ctx.Bool("serve") != true {
t.Errorf("main name not set")
}
if ctx.Bool("s") != true {
t.Errorf("short name not set")
}
},
}
a.Run([]string{"run", "--serve"})
}
func TestParseMultiBoolFromEnv(t *testing.T) {
os.Clearenv()
os.Setenv("APP_DEBUG", "1")
a := cli.App{
Flags: []cli.Flag{
cli.BoolFlag{Name: "debug, d", EnvVar: "APP_DEBUG"},
},
Action: func(ctx *cli.Context) {
if ctx.Bool("debug") != true {
t.Errorf("main name not set from env")
}
if ctx.Bool("d") != true {
t.Errorf("short name not set from env")
}
},
}
a.Run([]string{"run"})
}
func TestParseMultiBoolFromEnvCascade(t *testing.T) {
os.Clearenv()
os.Setenv("APP_DEBUG", "1")
a := cli.App{
Flags: []cli.Flag{
cli.BoolFlag{Name: "debug, d", EnvVar: "COMPAT_DEBUG,APP_DEBUG"},
},
Action: func(ctx *cli.Context) {
if ctx.Bool("debug") != true {
t.Errorf("main name not set from env")
}
if ctx.Bool("d") != true {
t.Errorf("short name not set from env")
}
},
}
a.Run([]string{"run"})
}
func TestParseMultiBoolT(t *testing.T) {
a := cli.App{
Flags: []cli.Flag{
cli.BoolTFlag{Name: "serve, s"},
},
Action: func(ctx *cli.Context) {
if ctx.BoolT("serve") != true {
t.Errorf("main name not set")
}
if ctx.BoolT("s") != true {
t.Errorf("short name not set")
}
},
}
a.Run([]string{"run", "--serve"})
}
func TestParseMultiBoolTFromEnv(t *testing.T) {
os.Clearenv()
os.Setenv("APP_DEBUG", "0")
a := cli.App{
Flags: []cli.Flag{
cli.BoolTFlag{Name: "debug, d", EnvVar: "APP_DEBUG"},
},
Action: func(ctx *cli.Context) {
if ctx.BoolT("debug") != false {
t.Errorf("main name not set from env")
}
if ctx.BoolT("d") != false {
t.Errorf("short name not set from env")
}
},
}
a.Run([]string{"run"})
}
func TestParseMultiBoolTFromEnvCascade(t *testing.T) {
os.Clearenv()
os.Setenv("APP_DEBUG", "0")
a := cli.App{
Flags: []cli.Flag{
cli.BoolTFlag{Name: "debug, d", EnvVar: "COMPAT_DEBUG,APP_DEBUG"},
},
Action: func(ctx *cli.Context) {
if ctx.BoolT("debug") != false {
t.Errorf("main name not set from env")
}
if ctx.BoolT("d") != false {
t.Errorf("short name not set from env")
}
},
}
a.Run([]string{"run"})
}
type Parser [2]string
func (p *Parser) Set(value string) error {
parts := strings.Split(value, ",")
if len(parts) != 2 {
return fmt.Errorf("invalid format")
}
(*p)[0] = parts[0]
(*p)[1] = parts[1]
return nil
}
func (p *Parser) String() string {
return fmt.Sprintf("%s,%s", p[0], p[1])
}
func TestParseGeneric(t *testing.T) {
a := cli.App{
Flags: []cli.Flag{
cli.GenericFlag{Name: "serve, s", Value: &Parser{}},
},
Action: func(ctx *cli.Context) {
if !reflect.DeepEqual(ctx.Generic("serve"), &Parser{"10", "20"}) {
t.Errorf("main name not set")
}
if !reflect.DeepEqual(ctx.Generic("s"), &Parser{"10", "20"}) {
t.Errorf("short name not set")
}
},
}
a.Run([]string{"run", "-s", "10,20"})
}
func TestParseGenericFromEnv(t *testing.T) {
os.Clearenv()
os.Setenv("APP_SERVE", "20,30")
a := cli.App{
Flags: []cli.Flag{
cli.GenericFlag{Name: "serve, s", Value: &Parser{}, EnvVar: "APP_SERVE"},
},
Action: func(ctx *cli.Context) {
if !reflect.DeepEqual(ctx.Generic("serve"), &Parser{"20", "30"}) {
t.Errorf("main name not set from env")
}
if !reflect.DeepEqual(ctx.Generic("s"), &Parser{"20", "30"}) {
t.Errorf("short name not set from env")
}
},
}
a.Run([]string{"run"})
}
func TestParseGenericFromEnvCascade(t *testing.T) {
os.Clearenv()
os.Setenv("APP_FOO", "99,2000")
a := cli.App{
Flags: []cli.Flag{
cli.GenericFlag{Name: "foos", Value: &Parser{}, EnvVar: "COMPAT_FOO,APP_FOO"},
},
Action: func(ctx *cli.Context) {
if !reflect.DeepEqual(ctx.Generic("foos"), &Parser{"99", "2000"}) {
t.Errorf("value not set from env")
}
},
}
a.Run([]string{"run"})
}

View file

@ -0,0 +1,211 @@
package cli
import "fmt"
// The text template for the Default help topic.
// cli.go uses text/template to render templates. You can
// render custom help text by setting this variable.
var AppHelpTemplate = `NAME:
{{.Name}} - {{.Usage}}
USAGE:
{{.Name}} {{if .Flags}}[global options] {{end}}command{{if .Flags}} [command options]{{end}} [arguments...]
VERSION:
{{.Version}}{{if or .Author .Email}}
AUTHOR:{{if .Author}}
{{.Author}}{{if .Email}} - <{{.Email}}>{{end}}{{else}}
{{.Email}}{{end}}{{end}}
COMMANDS:
{{range .Commands}}{{.Name}}{{with .ShortName}}, {{.}}{{end}}{{ "\t" }}{{.Usage}}
{{end}}{{if .Flags}}
GLOBAL OPTIONS:
{{range .Flags}}{{.}}
{{end}}{{end}}
`
// The text template for the command help topic.
// cli.go uses text/template to render templates. You can
// render custom help text by setting this variable.
var CommandHelpTemplate = `NAME:
{{.Name}} - {{.Usage}}
USAGE:
command {{.Name}}{{if .Flags}} [command options]{{end}} [arguments...]{{if .Description}}
DESCRIPTION:
{{.Description}}{{end}}{{if .Flags}}
OPTIONS:
{{range .Flags}}{{.}}
{{end}}{{ end }}
`
// The text template for the subcommand help topic.
// cli.go uses text/template to render templates. You can
// render custom help text by setting this variable.
var SubcommandHelpTemplate = `NAME:
{{.Name}} - {{.Usage}}
USAGE:
{{.Name}} command{{if .Flags}} [command options]{{end}} [arguments...]
COMMANDS:
{{range .Commands}}{{.Name}}{{with .ShortName}}, {{.}}{{end}}{{ "\t" }}{{.Usage}}
{{end}}{{if .Flags}}
OPTIONS:
{{range .Flags}}{{.}}
{{end}}{{end}}
`
var helpCommand = Command{
Name: "help",
ShortName: "h",
Usage: "Shows a list of commands or help for one command",
Action: func(c *Context) {
args := c.Args()
if args.Present() {
ShowCommandHelp(c, args.First())
} else {
ShowAppHelp(c)
}
},
}
var helpSubcommand = Command{
Name: "help",
ShortName: "h",
Usage: "Shows a list of commands or help for one command",
Action: func(c *Context) {
args := c.Args()
if args.Present() {
ShowCommandHelp(c, args.First())
} else {
ShowSubcommandHelp(c)
}
},
}
// Prints help for the App
type helpPrinter func(templ string, data interface{})
var HelpPrinter helpPrinter = nil
// Prints version for the App
var VersionPrinter = printVersion
func ShowAppHelp(c *Context) {
HelpPrinter(AppHelpTemplate, c.App)
}
// Prints the list of subcommands as the default app completion method
func DefaultAppComplete(c *Context) {
for _, command := range c.App.Commands {
fmt.Fprintln(c.App.Writer, command.Name)
if command.ShortName != "" {
fmt.Fprintln(c.App.Writer, command.ShortName)
}
}
}
// Prints help for the given command
func ShowCommandHelp(c *Context, command string) {
for _, c := range c.App.Commands {
if c.HasName(command) {
HelpPrinter(CommandHelpTemplate, c)
return
}
}
if c.App.CommandNotFound != nil {
c.App.CommandNotFound(c, command)
} else {
fmt.Fprintf(c.App.Writer, "No help topic for '%v'\n", command)
}
}
// Prints help for the given subcommand
func ShowSubcommandHelp(c *Context) {
ShowCommandHelp(c, c.Command.Name)
}
// Prints the version number of the App
func ShowVersion(c *Context) {
VersionPrinter(c)
}
func printVersion(c *Context) {
fmt.Fprintf(c.App.Writer, "%v version %v\n", c.App.Name, c.App.Version)
}
// Prints the lists of commands within a given context
func ShowCompletions(c *Context) {
a := c.App
if a != nil && a.BashComplete != nil {
a.BashComplete(c)
}
}
// Prints the custom completions for a given command
func ShowCommandCompletions(ctx *Context, command string) {
c := ctx.App.Command(command)
if c != nil && c.BashComplete != nil {
c.BashComplete(ctx)
}
}
func checkVersion(c *Context) bool {
if c.GlobalBool("version") {
ShowVersion(c)
return true
}
return false
}
func checkHelp(c *Context) bool {
if c.GlobalBool("h") || c.GlobalBool("help") {
ShowAppHelp(c)
return true
}
return false
}
func checkCommandHelp(c *Context, name string) bool {
if c.Bool("h") || c.Bool("help") {
ShowCommandHelp(c, name)
return true
}
return false
}
func checkSubcommandHelp(c *Context) bool {
if c.GlobalBool("h") || c.GlobalBool("help") {
ShowSubcommandHelp(c)
return true
}
return false
}
func checkCompletions(c *Context) bool {
if (c.GlobalBool(BashCompletionFlag.Name) || c.Bool(BashCompletionFlag.Name)) && c.App.EnableBashCompletion {
ShowCompletions(c)
return true
}
return false
}
func checkCommandCompletions(c *Context, name string) bool {
if c.Bool(BashCompletionFlag.Name) && c.App.EnableBashCompletion {
ShowCommandCompletions(c, name)
return true
}
return false
}

View file

@ -0,0 +1,19 @@
package cli_test
import (
"reflect"
"testing"
)
/* Test Helpers */
func expect(t *testing.T, a interface{}, b interface{}) {
if a != b {
t.Errorf("Expected %v (type %v) - Got %v (type %v)", b, reflect.TypeOf(b), a, reflect.TypeOf(a))
}
}
func refute(t *testing.T, a interface{}, b interface{}) {
if a == b {
t.Errorf("Did not expect %v (type %v) - Got %v (type %v)", b, reflect.TypeOf(b), a, reflect.TypeOf(a))
}
}

View file

@ -0,0 +1,4 @@
.DS_Store
bin

View file

@ -0,0 +1,8 @@
Copyright (c) 2012 Dave Grijalva
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

View file

@ -0,0 +1,55 @@
A [go](http://www.golang.org) (or 'golang' for search engine friendliness) implementation of [JSON Web Tokens](http://self-issued.info/docs/draft-jones-json-web-token.html)
**NOTICE:** We recently introduced a breaking change in the API. Please refer to [VERSION_HISTORY.md](VERSION_HISTORY.md) for details.
## What the heck is a JWT?
In short, it's a signed JSON object that does something useful (for example, authentication). It's commonly used for `Bearer` tokens in Oauth 2. A token is made of three parts, separated by `.`'s. The first two parts are JSON objects, that have been [base64url](http://tools.ietf.org/html/rfc4648) encoded. The last part is the signature, encoded the same way.
The first part is called the header. It contains the necessary information for verifying the last part, the signature. For example, which encryption method was used for signing and what key was used.
The part in the middle is the interesting bit. It's called the Claims and contains the actual stuff you care about. Refer to [the RFC](http://self-issued.info/docs/draft-jones-json-web-token.html) for information about reserved keys and the proper way to add your own.
## What's in the box?
This library supports the parsing and verification as well as the generation and signing of JWTs. Current supported signing algorithms are RSA256 and HMAC SHA256, though hooks are present for adding your own.
## Parse and Verify
Parsing and verifying tokens is pretty straight forward. You pass in the token and a function for looking up the key. This is done as a callback since you may need to parse the token to find out what signing method and key was used.
```go
token, err := jwt.Parse(myToken, func(token *jwt.Token) (interface{}, error) {
return myLookupKey(token.Header["kid"])
})
if err == nil && token.Valid {
deliverGoodness("!")
} else {
deliverUtterRejection(":(")
}
```
## Create a token
```go
// Create the token
token := jwt.New(SigningMethodHS256)
// Set some claims
token.Claims["foo"] = "bar"
token.Claims["exp"] = time.Now().Add(time.Hour * 72).Unix()
// Sign and get the complete encoded token as a string
tokenString, err := token.SignedString(mySigningKey)
```
## Project Status & Versioning
This library is considered production ready. Feedback and feature requests are appreciated. The API should be considered stable. There should be very few backwards-incompatible changes outside of major version updates (and only with good reason).
This project uses [Semantic Versioning 2.0.0](http://semver.org). Accepted pull requests will land on `master`. Periodically, versions will be tagged from `master`. You can find all the releases on [the project releases page](https://github.com/dgrijalva/jwt-go/releases).
## More
Documentation can be found [on godoc.org](http://godoc.org/github.com/dgrijalva/jwt-go).
The command line utility included in this project (cmd/jwt) provides a straightforward example of token creation and parsing as well as a useful tool for debugging your own integration. For a more http centric example, see [this gist](https://gist.github.com/cryptix/45c33ecf0ae54828e63b).

View file

@ -0,0 +1,54 @@
## `jwt-go` Version History
#### 2.2.0
* Gracefully handle a `nil` `Keyfunc` being passed to `Parse`. Result will now be the parsed token and an error, instead of a panic.
#### 2.1.0
Backwards compatible API change that was missed in 2.0.0.
* The `SignedString` method on `Token` now takes `interface{}` instead of `[]byte`
#### 2.0.0
There were two major reasons for breaking backwards compatibility with this update. The first was a refactor required to expand the width of the RSA and HMAC-SHA signing implementations. There will likely be no required code changes to support this change.
The second update, while unfortunately requiring a small change in integration, is required to open up this library to other signing methods. Not all keys used for all signing methods have a single standard on-disk representation. Requiring `[]byte` as the type for all keys proved too limiting. Additionally, this implementation allows for pre-parsed tokens to be reused, which might matter in an application that parses a high volume of tokens with a small set of keys. Backwards compatibilty has been maintained for passing `[]byte` to the RSA signing methods, but they will also accept `*rsa.PublicKey` and `*rsa.PrivateKey`.
It is likely the only integration change required here will be to change `func(t *jwt.Token) ([]byte, error)` to `func(t *jwt.Token) (interface{}, error)` when calling `Parse`.
* **Compatibility Breaking Changes**
* `SigningMethodHS256` is now `*SigningMethodHMAC` instead of `type struct`
* `SigningMethodRS256` is now `*SigningMethodRSA` instead of `type struct`
* `KeyFunc` now returns `interface{}` instead of `[]byte`
* `SigningMethod.Sign` now takes `interface{}` instead of `[]byte` for the key
* `SigningMethod.Verify` now takes `interface{}` instead of `[]byte` for the key
* Renamed type `SigningMethodHS256` to `SigningMethodHMAC`. Specific sizes are now just instances of this type.
* Added public package global `SigningMethodHS256`
* Added public package global `SigningMethodHS384`
* Added public package global `SigningMethodHS512`
* Renamed type `SigningMethodRS256` to `SigningMethodRSA`. Specific sizes are now just instances of this type.
* Added public package global `SigningMethodRS256`
* Added public package global `SigningMethodRS384`
* Added public package global `SigningMethodRS512`
* Moved sample private key for HMAC tests from an inline value to a file on disk. Value is unchanged.
* Refactored the RSA implementation to be easier to read
* Exposed helper methods `ParseRSAPrivateKeyFromPEM` and `ParseRSAPublicKeyFromPEM`
#### 1.0.2
* Fixed bug in parsing public keys from certificates
* Added more tests around the parsing of keys for RS256
* Code refactoring in RS256 implementation. No functional changes
#### 1.0.1
* Fixed panic if RS256 signing method was passed an invalid key
#### 1.0.0
* First versioned release
* API stabilized
* Supports creating, signing, parsing, and validating JWT tokens
* Supports RS256 and HS256 signing methods

View file

@ -0,0 +1,186 @@
// A useful example app. You can use this to debug your tokens on the command line.
// This is also a great place to look at how you might use this library.
//
// Example usage:
// The following will create and sign a token, then verify it and output the original claims.
// echo {\"foo\":\"bar\"} | bin/jwt -key test/sample_key -alg RS256 -sign - | bin/jwt -key test/sample_key.pub -verify -
package main
import (
"encoding/json"
"flag"
"fmt"
"io"
"io/ioutil"
"os"
"regexp"
"github.com/dgrijalva/jwt-go"
)
var (
// Options
flagAlg = flag.String("alg", "", "signing algorithm identifier")
flagKey = flag.String("key", "", "path to key file or '-' to read from stdin")
flagCompact = flag.Bool("compact", false, "output compact JSON")
flagDebug = flag.Bool("debug", false, "print out all kinds of debug data")
// Modes - exactly one of these is required
flagSign = flag.String("sign", "", "path to claims object to sign or '-' to read from stdin")
flagVerify = flag.String("verify", "", "path to JWT token to verify or '-' to read from stdin")
)
func main() {
// Usage message if you ask for -help or if you mess up inputs.
flag.Usage = func() {
fmt.Fprintf(os.Stderr, "Usage of %s:\n", os.Args[0])
fmt.Fprintf(os.Stderr, " One of the following flags is required: sign, verify\n")
flag.PrintDefaults()
}
// Parse command line options
flag.Parse()
// Do the thing. If something goes wrong, print error to stderr
// and exit with a non-zero status code
if err := start(); err != nil {
fmt.Fprintf(os.Stderr, "Error: %v\n", err)
os.Exit(1)
}
}
// Figure out which thing to do and then do that
func start() error {
if *flagSign != "" {
return signToken()
} else if *flagVerify != "" {
return verifyToken()
} else {
flag.Usage()
return fmt.Errorf("None of the required flags are present. What do you want me to do?")
}
}
// Helper func: Read input from specified file or stdin
func loadData(p string) ([]byte, error) {
if p == "" {
return nil, fmt.Errorf("No path specified")
}
var rdr io.Reader
if p == "-" {
rdr = os.Stdin
} else {
if f, err := os.Open(p); err == nil {
rdr = f
defer f.Close()
} else {
return nil, err
}
}
return ioutil.ReadAll(rdr)
}
// Print a json object in accordance with the prophecy (or the command line options)
func printJSON(j interface{}) error {
var out []byte
var err error
if *flagCompact == false {
out, err = json.MarshalIndent(j, "", " ")
} else {
out, err = json.Marshal(j)
}
if err == nil {
fmt.Println(string(out))
}
return err
}
// Verify a token and output the claims. This is a great example
// of how to verify and view a token.
func verifyToken() error {
// get the token
tokData, err := loadData(*flagVerify)
if err != nil {
return fmt.Errorf("Couldn't read token: %v", err)
}
// trim possible whitespace from token
tokData = regexp.MustCompile(`\s*$`).ReplaceAll(tokData, []byte{})
if *flagDebug {
fmt.Fprintf(os.Stderr, "Token len: %v bytes\n", len(tokData))
}
// Parse the token. Load the key from command line option
token, err := jwt.Parse(string(tokData), func(t *jwt.Token) (interface{}, error) {
return loadData(*flagKey)
})
// Print some debug data
if *flagDebug && token != nil {
fmt.Fprintf(os.Stderr, "Header:\n%v\n", token.Header)
fmt.Fprintf(os.Stderr, "Claims:\n%v\n", token.Claims)
}
// Print an error if we can't parse for some reason
if err != nil {
return fmt.Errorf("Couldn't parse token: %v", err)
}
// Is token invalid?
if !token.Valid {
return fmt.Errorf("Token is invalid")
}
// Print the token details
if err := printJSON(token.Claims); err != nil {
return fmt.Errorf("Failed to output claims: %v", err)
}
return nil
}
// Create, sign, and output a token. This is a great, simple example of
// how to use this library to create and sign a token.
func signToken() error {
// get the token data from command line arguments
tokData, err := loadData(*flagSign)
if err != nil {
return fmt.Errorf("Couldn't read token: %v", err)
} else if *flagDebug {
fmt.Fprintf(os.Stderr, "Token: %v bytes", len(tokData))
}
// parse the JSON of the claims
var claims map[string]interface{}
if err := json.Unmarshal(tokData, &claims); err != nil {
return fmt.Errorf("Couldn't parse claims JSON: %v", err)
}
// get the key
keyData, err := loadData(*flagKey)
if err != nil {
return fmt.Errorf("Couldn't read key: %v", err)
}
// get the signing alg
alg := jwt.GetSigningMethod(*flagAlg)
if alg == nil {
return fmt.Errorf("Couldn't find signing method: %v", *flagAlg)
}
// create a new token
token := jwt.New(alg)
token.Claims = claims
if out, err := token.SignedString(keyData); err == nil {
fmt.Println(out)
} else {
return fmt.Errorf("Error signing token: %v", err)
}
return nil
}

View file

@ -0,0 +1,4 @@
// Package jwt is a Go implementation of JSON Web Tokens: http://self-issued.info/docs/draft-jones-json-web-token.html
//
// See README.md for more info.
package jwt

View file

@ -0,0 +1,52 @@
package jwt_test
import (
"fmt"
"github.com/dgrijalva/jwt-go"
"time"
)
func ExampleParse(myToken string, myLookupKey func(interface{}) (interface{}, error)) {
token, err := jwt.Parse(myToken, func(token *jwt.Token) (interface{}, error) {
return myLookupKey(token.Header["kid"])
})
if err == nil && token.Valid {
fmt.Println("Your token is valid. I like your style.")
} else {
fmt.Println("This token is terrible! I cannot accept this.")
}
}
func ExampleNew(mySigningKey string) (string, error) {
// Create the token
token := jwt.New(jwt.SigningMethodHS256)
// Set some claims
token.Claims["foo"] = "bar"
token.Claims["exp"] = time.Now().Add(time.Hour * 72).Unix()
// Sign and get the complete encoded token as a string
tokenString, err := token.SignedString(mySigningKey)
return tokenString, err
}
func ExampleParse_errorChecking(myToken string, myLookupKey func(interface{}) (interface{}, error)) {
token, err := jwt.Parse(myToken, func(token *jwt.Token) (interface{}, error) {
return myLookupKey(token.Header["kid"])
})
if token.Valid {
fmt.Println("You look nice today")
} else if ve, ok := err.(*jwt.ValidationError); ok {
if ve.Errors&jwt.ValidationErrorMalformed != 0 {
fmt.Println("That's not even a token")
} else if ve.Errors&(jwt.ValidationErrorExpired|jwt.ValidationErrorNotValidYet) != 0 {
// Token is either expired or not active yet
fmt.Println("Timing is everything")
} else {
fmt.Println("Couldn't handle this token:", err)
}
} else {
fmt.Println("Couldn't handle this token:", err)
}
}

View file

@ -0,0 +1,82 @@
package jwt
import (
"crypto"
"crypto/hmac"
"errors"
)
// Implements the HMAC-SHA family of signing methods signing methods
type SigningMethodHMAC struct {
Name string
Hash crypto.Hash
}
// Specific instances for HS256 and company
var (
SigningMethodHS256 *SigningMethodHMAC
SigningMethodHS384 *SigningMethodHMAC
SigningMethodHS512 *SigningMethodHMAC
ErrSignatureInvalid = errors.New("signature is invalid")
)
func init() {
// HS256
SigningMethodHS256 = &SigningMethodHMAC{"HS256", crypto.SHA256}
RegisterSigningMethod(SigningMethodHS256.Alg(), func() SigningMethod {
return SigningMethodHS256
})
// HS384
SigningMethodHS384 = &SigningMethodHMAC{"HS384", crypto.SHA384}
RegisterSigningMethod(SigningMethodHS384.Alg(), func() SigningMethod {
return SigningMethodHS384
})
// HS512
SigningMethodHS512 = &SigningMethodHMAC{"HS512", crypto.SHA512}
RegisterSigningMethod(SigningMethodHS512.Alg(), func() SigningMethod {
return SigningMethodHS512
})
}
func (m *SigningMethodHMAC) Alg() string {
return m.Name
}
func (m *SigningMethodHMAC) Verify(signingString, signature string, key interface{}) error {
if keyBytes, ok := key.([]byte); ok {
var sig []byte
var err error
if sig, err = DecodeSegment(signature); err == nil {
if !m.Hash.Available() {
return ErrHashUnavailable
}
hasher := hmac.New(m.Hash.New, keyBytes)
hasher.Write([]byte(signingString))
if !hmac.Equal(sig, hasher.Sum(nil)) {
err = ErrSignatureInvalid
}
}
return err
}
return ErrInvalidKey
}
func (m *SigningMethodHMAC) Sign(signingString string, key interface{}) (string, error) {
if keyBytes, ok := key.([]byte); ok {
if !m.Hash.Available() {
return "", ErrHashUnavailable
}
hasher := hmac.New(m.Hash.New, keyBytes)
hasher.Write([]byte(signingString))
return EncodeSegment(hasher.Sum(nil)), nil
}
return "", ErrInvalidKey
}

View file

@ -0,0 +1,79 @@
package jwt_test
import (
"github.com/dgrijalva/jwt-go"
"io/ioutil"
"strings"
"testing"
)
var hmacTestData = []struct {
name string
tokenString string
alg string
claims map[string]interface{}
valid bool
}{
{
"web sample",
"eyJ0eXAiOiJKV1QiLA0KICJhbGciOiJIUzI1NiJ9.eyJpc3MiOiJqb2UiLA0KICJleHAiOjEzMDA4MTkzODAsDQogImh0dHA6Ly9leGFtcGxlLmNvbS9pc19yb290Ijp0cnVlfQ.dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk",
"HS256",
map[string]interface{}{"iss": "joe", "exp": 1300819380, "http://example.com/is_root": true},
true,
},
{
"HS384",
"eyJhbGciOiJIUzM4NCIsInR5cCI6IkpXVCJ9.eyJleHAiOjEuMzAwODE5MzhlKzA5LCJodHRwOi8vZXhhbXBsZS5jb20vaXNfcm9vdCI6dHJ1ZSwiaXNzIjoiam9lIn0.KWZEuOD5lbBxZ34g7F-SlVLAQ_r5KApWNWlZIIMyQVz5Zs58a7XdNzj5_0EcNoOy",
"HS384",
map[string]interface{}{"iss": "joe", "exp": 1300819380, "http://example.com/is_root": true},
true,
},
{
"HS512",
"eyJhbGciOiJIUzUxMiIsInR5cCI6IkpXVCJ9.eyJleHAiOjEuMzAwODE5MzhlKzA5LCJodHRwOi8vZXhhbXBsZS5jb20vaXNfcm9vdCI6dHJ1ZSwiaXNzIjoiam9lIn0.CN7YijRX6Aw1n2jyI2Id1w90ja-DEMYiWixhYCyHnrZ1VfJRaFQz1bEbjjA5Fn4CLYaUG432dEYmSbS4Saokmw",
"HS512",
map[string]interface{}{"iss": "joe", "exp": 1300819380, "http://example.com/is_root": true},
true,
},
{
"web sample: invalid",
"eyJ0eXAiOiJKV1QiLA0KICJhbGciOiJIUzI1NiJ9.eyJpc3MiOiJqb2UiLA0KICJleHAiOjEzMDA4MTkzODAsDQogImh0dHA6Ly9leGFtcGxlLmNvbS9pc19yb290Ijp0cnVlfQ.dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXo",
"HS256",
map[string]interface{}{"iss": "joe", "exp": 1300819380, "http://example.com/is_root": true},
false,
},
}
// Sample data from http://tools.ietf.org/html/draft-jones-json-web-signature-04#appendix-A.1
var hmacTestKey, _ = ioutil.ReadFile("test/hmacTestKey")
func TestHMACVerify(t *testing.T) {
for _, data := range hmacTestData {
parts := strings.Split(data.tokenString, ".")
method := jwt.GetSigningMethod(data.alg)
err := method.Verify(strings.Join(parts[0:2], "."), parts[2], hmacTestKey)
if data.valid && err != nil {
t.Errorf("[%v] Error while verifying key: %v", data.name, err)
}
if !data.valid && err == nil {
t.Errorf("[%v] Invalid key passed validation", data.name)
}
}
}
func TestHMACSign(t *testing.T) {
for _, data := range hmacTestData {
if data.valid {
parts := strings.Split(data.tokenString, ".")
method := jwt.GetSigningMethod(data.alg)
sig, err := method.Sign(strings.Join(parts[0:2], "."), hmacTestKey)
if err != nil {
t.Errorf("[%v] Error signing token: %v", data.name, err)
}
if sig != parts[2] {
t.Errorf("[%v] Incorrect signature.\nwas:\n%v\nexpecting:\n%v", data.name, sig, parts[2])
}
}
}
}

View file

@ -0,0 +1,237 @@
package jwt
import (
"encoding/base64"
"encoding/json"
"errors"
"net/http"
"strings"
"time"
)
// TimeFunc provides the current time when parsing token to validate "exp" claim (expiration time).
// You can override it to use another time value. This is useful for testing or if your
// server uses a different time zone than your tokens.
var TimeFunc = time.Now
// Parse methods use this callback function to supply
// the key for verification. The function receives the parsed,
// but unverified Token. This allows you to use propries in the
// Header of the token (such as `kid`) to identify which key to use.
type Keyfunc func(*Token) (interface{}, error)
// Error constants
var (
ErrInvalidKey = errors.New("key is invalid or of invalid type")
ErrHashUnavailable = errors.New("the requested hash function is unavailable")
ErrNoTokenInRequest = errors.New("no token present in request")
)
// A JWT Token. Different fields will be used depending on whether you're
// creating or parsing/verifying a token.
type Token struct {
Raw string // The raw token. Populated when you Parse a token
Method SigningMethod // The signing method used or to be used
Header map[string]interface{} // The first segment of the token
Claims map[string]interface{} // The second segment of the token
Signature string // The third segment of the token. Populated when you Parse a token
Valid bool // Is the token valid? Populated when you Parse/Verify a token
}
// Create a new Token. Takes a signing method
func New(method SigningMethod) *Token {
return &Token{
Header: map[string]interface{}{
"typ": "JWT",
"alg": method.Alg(),
},
Claims: make(map[string]interface{}),
Method: method,
}
}
// Get the complete, signed token
func (t *Token) SignedString(key interface{}) (string, error) {
var sig, sstr string
var err error
if sstr, err = t.SigningString(); err != nil {
return "", err
}
if sig, err = t.Method.Sign(sstr, key); err != nil {
return "", err
}
return strings.Join([]string{sstr, sig}, "."), nil
}
// Generate the signing string. This is the
// most expensive part of the whole deal. Unless you
// need this for something special, just go straight for
// the SignedString.
func (t *Token) SigningString() (string, error) {
var err error
parts := make([]string, 2)
for i, _ := range parts {
var source map[string]interface{}
if i == 0 {
source = t.Header
} else {
source = t.Claims
}
var jsonValue []byte
if jsonValue, err = json.Marshal(source); err != nil {
return "", err
}
parts[i] = EncodeSegment(jsonValue)
}
return strings.Join(parts, "."), nil
}
// Parse, validate, and return a token.
// keyFunc will receive the parsed token and should return the key for validating.
// If everything is kosher, err will be nil
func Parse(tokenString string, keyFunc Keyfunc) (*Token, error) {
parts := strings.Split(tokenString, ".")
if len(parts) != 3 {
return nil, &ValidationError{err: "token contains an invalid number of segments", Errors: ValidationErrorMalformed}
}
var err error
token := &Token{Raw: tokenString}
// parse Header
var headerBytes []byte
if headerBytes, err = DecodeSegment(parts[0]); err != nil {
return token, &ValidationError{err: err.Error(), Errors: ValidationErrorMalformed}
}
if err = json.Unmarshal(headerBytes, &token.Header); err != nil {
return token, &ValidationError{err: err.Error(), Errors: ValidationErrorMalformed}
}
// parse Claims
var claimBytes []byte
if claimBytes, err = DecodeSegment(parts[1]); err != nil {
return token, &ValidationError{err: err.Error(), Errors: ValidationErrorMalformed}
}
if err = json.Unmarshal(claimBytes, &token.Claims); err != nil {
return token, &ValidationError{err: err.Error(), Errors: ValidationErrorMalformed}
}
// Lookup signature method
if method, ok := token.Header["alg"].(string); ok {
if token.Method = GetSigningMethod(method); token.Method == nil {
return token, &ValidationError{err: "signing method (alg) is unavailable.", Errors: ValidationErrorUnverifiable}
}
} else {
return token, &ValidationError{err: "signing method (alg) is unspecified.", Errors: ValidationErrorUnverifiable}
}
// Lookup key
var key interface{}
if keyFunc == nil {
// keyFunc was not provided. short circuiting validation
return token, &ValidationError{err: "no Keyfunc was provided.", Errors: ValidationErrorUnverifiable}
}
if key, err = keyFunc(token); err != nil {
// keyFunc returned an error
return token, &ValidationError{err: err.Error(), Errors: ValidationErrorUnverifiable}
}
// Check expiration times
vErr := &ValidationError{}
now := TimeFunc().Unix()
if exp, ok := token.Claims["exp"].(float64); ok {
if now > int64(exp) {
vErr.err = "token is expired"
vErr.Errors |= ValidationErrorExpired
}
}
if nbf, ok := token.Claims["nbf"].(float64); ok {
if now < int64(nbf) {
vErr.err = "token is not valid yet"
vErr.Errors |= ValidationErrorNotValidYet
}
}
// Perform validation
if err = token.Method.Verify(strings.Join(parts[0:2], "."), parts[2], key); err != nil {
vErr.err = err.Error()
vErr.Errors |= ValidationErrorSignatureInvalid
}
if vErr.valid() {
token.Valid = true
return token, nil
}
return token, vErr
}
// The errors that might occur when parsing and validating a token
const (
ValidationErrorMalformed uint32 = 1 << iota // Token is malformed
ValidationErrorUnverifiable // Token could not be verified because of signing problems
ValidationErrorSignatureInvalid // Signature validation failed
ValidationErrorExpired // Exp validation failed
ValidationErrorNotValidYet // NBF validation failed
)
// The error from Parse if token is not valid
type ValidationError struct {
err string
Errors uint32 // bitfield. see ValidationError... constants
}
// Validation error is an error type
func (e ValidationError) Error() string {
if e.err == "" {
return "token is invalid"
}
return e.err
}
// No errors
func (e *ValidationError) valid() bool {
if e.Errors > 0 {
return false
}
return true
}
// Try to find the token in an http.Request.
// This method will call ParseMultipartForm if there's no token in the header.
// Currently, it looks in the Authorization header as well as
// looking for an 'access_token' request parameter in req.Form.
func ParseFromRequest(req *http.Request, keyFunc Keyfunc) (token *Token, err error) {
// Look for an Authorization header
if ah := req.Header.Get("Authorization"); ah != "" {
// Should be a bearer token
if len(ah) > 6 && strings.ToUpper(ah[0:6]) == "BEARER" {
return Parse(ah[7:], keyFunc)
}
}
// Look for "access_token" parameter
req.ParseMultipartForm(10e6)
if tokStr := req.Form.Get("access_token"); tokStr != "" {
return Parse(tokStr, keyFunc)
}
return nil, ErrNoTokenInRequest
}
// Encode JWT specific base64url encoding with padding stripped
func EncodeSegment(seg []byte) string {
return strings.TrimRight(base64.URLEncoding.EncodeToString(seg), "=")
}
// Decode JWT specific base64url encoding with padding stripped
func DecodeSegment(seg string) ([]byte, error) {
if l := len(seg) % 4; l > 0 {
seg += strings.Repeat("=", 4-l)
}
return base64.URLEncoding.DecodeString(seg)
}

View file

@ -0,0 +1,174 @@
package jwt_test
import (
"fmt"
"github.com/dgrijalva/jwt-go"
"io/ioutil"
"net/http"
"reflect"
"testing"
"time"
)
var (
jwtTestDefaultKey []byte
defaultKeyFunc jwt.Keyfunc = func(t *jwt.Token) (interface{}, error) { return jwtTestDefaultKey, nil }
emptyKeyFunc jwt.Keyfunc = func(t *jwt.Token) (interface{}, error) { return nil, nil }
errorKeyFunc jwt.Keyfunc = func(t *jwt.Token) (interface{}, error) { return nil, fmt.Errorf("error loading key") }
nilKeyFunc jwt.Keyfunc = nil
)
var jwtTestData = []struct {
name string
tokenString string
keyfunc jwt.Keyfunc
claims map[string]interface{}
valid bool
errors uint32
}{
{
"basic",
"eyJ0eXAiOiJKV1QiLCJhbGciOiJSUzI1NiJ9.eyJmb28iOiJiYXIifQ.FhkiHkoESI_cG3NPigFrxEk9Z60_oXrOT2vGm9Pn6RDgYNovYORQmmA0zs1AoAOf09ly2Nx2YAg6ABqAYga1AcMFkJljwxTT5fYphTuqpWdy4BELeSYJx5Ty2gmr8e7RonuUztrdD5WfPqLKMm1Ozp_T6zALpRmwTIW0QPnaBXaQD90FplAg46Iy1UlDKr-Eupy0i5SLch5Q-p2ZpaL_5fnTIUDlxC3pWhJTyx_71qDI-mAA_5lE_VdroOeflG56sSmDxopPEG3bFlSu1eowyBfxtu0_CuVd-M42RU75Zc4Gsj6uV77MBtbMrf4_7M_NUTSgoIF3fRqxrj0NzihIBg",
defaultKeyFunc,
map[string]interface{}{"foo": "bar"},
true,
0,
},
{
"basic expired",
"", // autogen
defaultKeyFunc,
map[string]interface{}{"foo": "bar", "exp": float64(time.Now().Unix() - 100)},
false,
jwt.ValidationErrorExpired,
},
{
"basic nbf",
"", // autogen
defaultKeyFunc,
map[string]interface{}{"foo": "bar", "nbf": float64(time.Now().Unix() + 100)},
false,
jwt.ValidationErrorNotValidYet,
},
{
"expired and nbf",
"", // autogen
defaultKeyFunc,
map[string]interface{}{"foo": "bar", "nbf": float64(time.Now().Unix() + 100), "exp": float64(time.Now().Unix() - 100)},
false,
jwt.ValidationErrorNotValidYet | jwt.ValidationErrorExpired,
},
{
"basic invalid",
"eyJ0eXAiOiJKV1QiLCJhbGciOiJSUzI1NiJ9.eyJmb28iOiJiYXIifQ.EhkiHkoESI_cG3NPigFrxEk9Z60_oXrOT2vGm9Pn6RDgYNovYORQmmA0zs1AoAOf09ly2Nx2YAg6ABqAYga1AcMFkJljwxTT5fYphTuqpWdy4BELeSYJx5Ty2gmr8e7RonuUztrdD5WfPqLKMm1Ozp_T6zALpRmwTIW0QPnaBXaQD90FplAg46Iy1UlDKr-Eupy0i5SLch5Q-p2ZpaL_5fnTIUDlxC3pWhJTyx_71qDI-mAA_5lE_VdroOeflG56sSmDxopPEG3bFlSu1eowyBfxtu0_CuVd-M42RU75Zc4Gsj6uV77MBtbMrf4_7M_NUTSgoIF3fRqxrj0NzihIBg",
defaultKeyFunc,
map[string]interface{}{"foo": "bar"},
false,
jwt.ValidationErrorSignatureInvalid,
},
{
"basic nokeyfunc",
"eyJ0eXAiOiJKV1QiLCJhbGciOiJSUzI1NiJ9.eyJmb28iOiJiYXIifQ.FhkiHkoESI_cG3NPigFrxEk9Z60_oXrOT2vGm9Pn6RDgYNovYORQmmA0zs1AoAOf09ly2Nx2YAg6ABqAYga1AcMFkJljwxTT5fYphTuqpWdy4BELeSYJx5Ty2gmr8e7RonuUztrdD5WfPqLKMm1Ozp_T6zALpRmwTIW0QPnaBXaQD90FplAg46Iy1UlDKr-Eupy0i5SLch5Q-p2ZpaL_5fnTIUDlxC3pWhJTyx_71qDI-mAA_5lE_VdroOeflG56sSmDxopPEG3bFlSu1eowyBfxtu0_CuVd-M42RU75Zc4Gsj6uV77MBtbMrf4_7M_NUTSgoIF3fRqxrj0NzihIBg",
nilKeyFunc,
map[string]interface{}{"foo": "bar"},
false,
jwt.ValidationErrorUnverifiable,
},
{
"basic nokey",
"eyJ0eXAiOiJKV1QiLCJhbGciOiJSUzI1NiJ9.eyJmb28iOiJiYXIifQ.FhkiHkoESI_cG3NPigFrxEk9Z60_oXrOT2vGm9Pn6RDgYNovYORQmmA0zs1AoAOf09ly2Nx2YAg6ABqAYga1AcMFkJljwxTT5fYphTuqpWdy4BELeSYJx5Ty2gmr8e7RonuUztrdD5WfPqLKMm1Ozp_T6zALpRmwTIW0QPnaBXaQD90FplAg46Iy1UlDKr-Eupy0i5SLch5Q-p2ZpaL_5fnTIUDlxC3pWhJTyx_71qDI-mAA_5lE_VdroOeflG56sSmDxopPEG3bFlSu1eowyBfxtu0_CuVd-M42RU75Zc4Gsj6uV77MBtbMrf4_7M_NUTSgoIF3fRqxrj0NzihIBg",
emptyKeyFunc,
map[string]interface{}{"foo": "bar"},
false,
jwt.ValidationErrorSignatureInvalid,
},
{
"basic errorkey",
"eyJ0eXAiOiJKV1QiLCJhbGciOiJSUzI1NiJ9.eyJmb28iOiJiYXIifQ.FhkiHkoESI_cG3NPigFrxEk9Z60_oXrOT2vGm9Pn6RDgYNovYORQmmA0zs1AoAOf09ly2Nx2YAg6ABqAYga1AcMFkJljwxTT5fYphTuqpWdy4BELeSYJx5Ty2gmr8e7RonuUztrdD5WfPqLKMm1Ozp_T6zALpRmwTIW0QPnaBXaQD90FplAg46Iy1UlDKr-Eupy0i5SLch5Q-p2ZpaL_5fnTIUDlxC3pWhJTyx_71qDI-mAA_5lE_VdroOeflG56sSmDxopPEG3bFlSu1eowyBfxtu0_CuVd-M42RU75Zc4Gsj6uV77MBtbMrf4_7M_NUTSgoIF3fRqxrj0NzihIBg",
errorKeyFunc,
map[string]interface{}{"foo": "bar"},
false,
jwt.ValidationErrorUnverifiable,
},
}
func init() {
var e error
if jwtTestDefaultKey, e = ioutil.ReadFile("test/sample_key.pub"); e != nil {
panic(e)
}
}
func makeSample(c map[string]interface{}) string {
key, e := ioutil.ReadFile("test/sample_key")
if e != nil {
panic(e.Error())
}
token := jwt.New(jwt.SigningMethodRS256)
token.Claims = c
s, e := token.SignedString(key)
if e != nil {
panic(e.Error())
}
return s
}
func TestJWT(t *testing.T) {
for _, data := range jwtTestData {
if data.tokenString == "" {
data.tokenString = makeSample(data.claims)
}
token, err := jwt.Parse(data.tokenString, data.keyfunc)
if !reflect.DeepEqual(data.claims, token.Claims) {
t.Errorf("[%v] Claims mismatch. Expecting: %v Got: %v", data.name, data.claims, token.Claims)
}
if data.valid && err != nil {
t.Errorf("[%v] Error while verifying token: %T:%v", data.name, err, err)
}
if !data.valid && err == nil {
t.Errorf("[%v] Invalid token passed validation", data.name)
}
if data.errors != 0 {
if err == nil {
t.Errorf("[%v] Expecting error. Didn't get one.", data.name)
} else {
// compare the bitfield part of the error
if err.(*jwt.ValidationError).Errors != data.errors {
t.Errorf("[%v] Errors don't match expectation", data.name)
}
}
}
}
}
func TestParseRequest(t *testing.T) {
// Bearer token request
for _, data := range jwtTestData {
if data.tokenString == "" {
data.tokenString = makeSample(data.claims)
}
r, _ := http.NewRequest("GET", "/", nil)
r.Header.Set("Authorization", fmt.Sprintf("Bearer %v", data.tokenString))
token, err := jwt.ParseFromRequest(r, data.keyfunc)
if token == nil {
t.Errorf("[%v] Token was not found: %v", data.name, err)
continue
}
if !reflect.DeepEqual(data.claims, token.Claims) {
t.Errorf("[%v] Claims mismatch. Expecting: %v Got: %v", data.name, data.claims, token.Claims)
}
if data.valid && err != nil {
t.Errorf("[%v] Error while verifying token: %v", data.name, err)
}
if !data.valid && err == nil {
t.Errorf("[%v] Invalid token passed validation", data.name)
}
}
}

View file

@ -0,0 +1,114 @@
package jwt
import (
"crypto"
"crypto/rand"
"crypto/rsa"
)
// Implements the RSA family of signing methods signing methods
type SigningMethodRSA struct {
Name string
Hash crypto.Hash
}
// Specific instances for RS256 and company
var (
SigningMethodRS256 *SigningMethodRSA
SigningMethodRS384 *SigningMethodRSA
SigningMethodRS512 *SigningMethodRSA
)
func init() {
// RS256
SigningMethodRS256 = &SigningMethodRSA{"RS256", crypto.SHA256}
RegisterSigningMethod(SigningMethodRS256.Alg(), func() SigningMethod {
return SigningMethodRS256
})
// RS384
SigningMethodRS384 = &SigningMethodRSA{"RS384", crypto.SHA384}
RegisterSigningMethod(SigningMethodRS384.Alg(), func() SigningMethod {
return SigningMethodRS384
})
// RS512
SigningMethodRS512 = &SigningMethodRSA{"RS512", crypto.SHA512}
RegisterSigningMethod(SigningMethodRS512.Alg(), func() SigningMethod {
return SigningMethodRS512
})
}
func (m *SigningMethodRSA) Alg() string {
return m.Name
}
// Implements the Verify method from SigningMethod
// For this signing method, must be either a PEM encoded PKCS1 or PKCS8 RSA public key as
// []byte, or an rsa.PublicKey structure.
func (m *SigningMethodRSA) Verify(signingString, signature string, key interface{}) error {
var err error
// Decode the signature
var sig []byte
if sig, err = DecodeSegment(signature); err != nil {
return err
}
var rsaKey *rsa.PublicKey
switch k := key.(type) {
case []byte:
if rsaKey, err = ParseRSAPublicKeyFromPEM(k); err != nil {
return err
}
case *rsa.PublicKey:
rsaKey = k
default:
return ErrInvalidKey
}
// Create hasher
if !m.Hash.Available() {
return ErrHashUnavailable
}
hasher := m.Hash.New()
hasher.Write([]byte(signingString))
// Verify the signature
return rsa.VerifyPKCS1v15(rsaKey, m.Hash, hasher.Sum(nil), sig)
}
// Implements the Sign method from SigningMethod
// For this signing method, must be either a PEM encoded PKCS1 or PKCS8 RSA private key as
// []byte, or an rsa.PrivateKey structure.
func (m *SigningMethodRSA) Sign(signingString string, key interface{}) (string, error) {
var err error
var rsaKey *rsa.PrivateKey
switch k := key.(type) {
case []byte:
if rsaKey, err = ParseRSAPrivateKeyFromPEM(k); err != nil {
return "", err
}
case *rsa.PrivateKey:
rsaKey = k
default:
return "", ErrInvalidKey
}
// Create the hasher
if !m.Hash.Available() {
return "", ErrHashUnavailable
}
hasher := m.Hash.New()
hasher.Write([]byte(signingString))
// Sign the string and return the encoded bytes
if sigBytes, err := rsa.SignPKCS1v15(rand.Reader, rsaKey, m.Hash, hasher.Sum(nil)); err == nil {
return EncodeSegment(sigBytes), nil
} else {
return "", err
}
}

View file

@ -0,0 +1,144 @@
package jwt_test
import (
"github.com/dgrijalva/jwt-go"
"io/ioutil"
"strings"
"testing"
)
var rsaTestData = []struct {
name string
tokenString string
alg string
claims map[string]interface{}
valid bool
}{
{
"Basic RS256",
"eyJ0eXAiOiJKV1QiLCJhbGciOiJSUzI1NiJ9.eyJmb28iOiJiYXIifQ.FhkiHkoESI_cG3NPigFrxEk9Z60_oXrOT2vGm9Pn6RDgYNovYORQmmA0zs1AoAOf09ly2Nx2YAg6ABqAYga1AcMFkJljwxTT5fYphTuqpWdy4BELeSYJx5Ty2gmr8e7RonuUztrdD5WfPqLKMm1Ozp_T6zALpRmwTIW0QPnaBXaQD90FplAg46Iy1UlDKr-Eupy0i5SLch5Q-p2ZpaL_5fnTIUDlxC3pWhJTyx_71qDI-mAA_5lE_VdroOeflG56sSmDxopPEG3bFlSu1eowyBfxtu0_CuVd-M42RU75Zc4Gsj6uV77MBtbMrf4_7M_NUTSgoIF3fRqxrj0NzihIBg",
"RS256",
map[string]interface{}{"foo": "bar"},
true,
},
{
"Basic RS384",
"eyJhbGciOiJSUzM4NCIsInR5cCI6IkpXVCJ9.eyJmb28iOiJiYXIifQ.W-jEzRfBigtCWsinvVVuldiuilzVdU5ty0MvpLaSaqK9PlAWWlDQ1VIQ_qSKzwL5IXaZkvZFJXT3yL3n7OUVu7zCNJzdwznbC8Z-b0z2lYvcklJYi2VOFRcGbJtXUqgjk2oGsiqUMUMOLP70TTefkpsgqDxbRh9CDUfpOJgW-dU7cmgaoswe3wjUAUi6B6G2YEaiuXC0XScQYSYVKIzgKXJV8Zw-7AN_DBUI4GkTpsvQ9fVVjZM9csQiEXhYekyrKu1nu_POpQonGd8yqkIyXPECNmmqH5jH4sFiF67XhD7_JpkvLziBpI-uh86evBUadmHhb9Otqw3uV3NTaXLzJw",
"RS384",
map[string]interface{}{"foo": "bar"},
true,
},
{
"Basic RS512",
"eyJhbGciOiJSUzUxMiIsInR5cCI6IkpXVCJ9.eyJmb28iOiJiYXIifQ.zBlLlmRrUxx4SJPUbV37Q1joRcI9EW13grnKduK3wtYKmDXbgDpF1cZ6B-2Jsm5RB8REmMiLpGms-EjXhgnyh2TSHE-9W2gA_jvshegLWtwRVDX40ODSkTb7OVuaWgiy9y7llvcknFBTIg-FnVPVpXMmeV_pvwQyhaz1SSwSPrDyxEmksz1hq7YONXhXPpGaNbMMeDTNP_1oj8DZaqTIL9TwV8_1wb2Odt_Fy58Ke2RVFijsOLdnyEAjt2n9Mxihu9i3PhNBkkxa2GbnXBfq3kzvZ_xxGGopLdHhJjcGWXO-NiwI9_tiu14NRv4L2xC0ItD9Yz68v2ZIZEp_DuzwRQ",
"RS512",
map[string]interface{}{"foo": "bar"},
true,
},
{
"basic invalid: foo => bar",
"eyJ0eXAiOiJKV1QiLCJhbGciOiJSUzI1NiJ9.eyJmb28iOiJiYXIifQ.EhkiHkoESI_cG3NPigFrxEk9Z60_oXrOT2vGm9Pn6RDgYNovYORQmmA0zs1AoAOf09ly2Nx2YAg6ABqAYga1AcMFkJljwxTT5fYphTuqpWdy4BELeSYJx5Ty2gmr8e7RonuUztrdD5WfPqLKMm1Ozp_T6zALpRmwTIW0QPnaBXaQD90FplAg46Iy1UlDKr-Eupy0i5SLch5Q-p2ZpaL_5fnTIUDlxC3pWhJTyx_71qDI-mAA_5lE_VdroOeflG56sSmDxopPEG3bFlSu1eowyBfxtu0_CuVd-M42RU75Zc4Gsj6uV77MBtbMrf4_7M_NUTSgoIF3fRqxrj0NzihIBg",
"RS256",
map[string]interface{}{"foo": "bar"},
false,
},
}
func TestRSAVerify(t *testing.T) {
key, _ := ioutil.ReadFile("test/sample_key.pub")
for _, data := range rsaTestData {
parts := strings.Split(data.tokenString, ".")
method := jwt.GetSigningMethod(data.alg)
err := method.Verify(strings.Join(parts[0:2], "."), parts[2], key)
if data.valid && err != nil {
t.Errorf("[%v] Error while verifying key: %v", data.name, err)
}
if !data.valid && err == nil {
t.Errorf("[%v] Invalid key passed validation", data.name)
}
}
}
func TestRSASign(t *testing.T) {
key, _ := ioutil.ReadFile("test/sample_key")
for _, data := range rsaTestData {
if data.valid {
parts := strings.Split(data.tokenString, ".")
method := jwt.GetSigningMethod(data.alg)
sig, err := method.Sign(strings.Join(parts[0:2], "."), key)
if err != nil {
t.Errorf("[%v] Error signing token: %v", data.name, err)
}
if sig != parts[2] {
t.Errorf("[%v] Incorrect signature.\nwas:\n%v\nexpecting:\n%v", data.name, sig, parts[2])
}
}
}
}
func TestRSAVerifyWithPreParsedPrivateKey(t *testing.T) {
key, _ := ioutil.ReadFile("test/sample_key.pub")
parsedKey, err := jwt.ParseRSAPublicKeyFromPEM(key)
if err != nil {
t.Fatal(err)
}
testData := rsaTestData[0]
parts := strings.Split(testData.tokenString, ".")
err = jwt.SigningMethodRS256.Verify(strings.Join(parts[0:2], "."), parts[2], parsedKey)
if err != nil {
t.Errorf("[%v] Error while verifying key: %v", testData.name, err)
}
}
func TestRSAWithPreParsedPrivateKey(t *testing.T) {
key, _ := ioutil.ReadFile("test/sample_key")
parsedKey, err := jwt.ParseRSAPrivateKeyFromPEM(key)
if err != nil {
t.Fatal(err)
}
testData := rsaTestData[0]
parts := strings.Split(testData.tokenString, ".")
sig, err := jwt.SigningMethodRS256.Sign(strings.Join(parts[0:2], "."), parsedKey)
if err != nil {
t.Errorf("[%v] Error signing token: %v", testData.name, err)
}
if sig != parts[2] {
t.Errorf("[%v] Incorrect signature.\nwas:\n%v\nexpecting:\n%v", testData.name, sig, parts[2])
}
}
func TestRSAKeyParsing(t *testing.T) {
key, _ := ioutil.ReadFile("test/sample_key")
pubKey, _ := ioutil.ReadFile("test/sample_key.pub")
badKey := []byte("All your base are belong to key")
// Test parsePrivateKey
if _, e := jwt.ParseRSAPrivateKeyFromPEM(key); e != nil {
t.Errorf("Failed to parse valid private key: %v", e)
}
if k, e := jwt.ParseRSAPrivateKeyFromPEM(pubKey); e == nil {
t.Errorf("Parsed public key as valid private key: %v", k)
}
if k, e := jwt.ParseRSAPrivateKeyFromPEM(badKey); e == nil {
t.Errorf("Parsed invalid key as valid private key: %v", k)
}
// Test parsePublicKey
if _, e := jwt.ParseRSAPublicKeyFromPEM(pubKey); e != nil {
t.Errorf("Failed to parse valid public key: %v", e)
}
if k, e := jwt.ParseRSAPublicKeyFromPEM(key); e == nil {
t.Errorf("Parsed private key as valid public key: %v", k)
}
if k, e := jwt.ParseRSAPublicKeyFromPEM(badKey); e == nil {
t.Errorf("Parsed invalid key as valid private key: %v", k)
}
}

View file

@ -0,0 +1,68 @@
package jwt
import (
"crypto/rsa"
"crypto/x509"
"encoding/pem"
"errors"
)
var (
ErrKeyMustBePEMEncoded = errors.New("Invalid Key: Key must be PEM encoded PKCS1 or PKCS8 private key")
ErrNotRSAPrivateKey = errors.New("Key is not a valid RSA private key")
)
// Parse PEM encoded PKCS1 or PKCS8 private key
func ParseRSAPrivateKeyFromPEM(key []byte) (*rsa.PrivateKey, error) {
var err error
// Parse PEM block
var block *pem.Block
if block, _ = pem.Decode(key); block == nil {
return nil, ErrKeyMustBePEMEncoded
}
var parsedKey interface{}
if parsedKey, err = x509.ParsePKCS1PrivateKey(block.Bytes); err != nil {
if parsedKey, err = x509.ParsePKCS8PrivateKey(block.Bytes); err != nil {
return nil, err
}
}
var pkey *rsa.PrivateKey
var ok bool
if pkey, ok = parsedKey.(*rsa.PrivateKey); !ok {
return nil, ErrNotRSAPrivateKey
}
return pkey, nil
}
// Parse PEM encoded PKCS1 or PKCS8 public key
func ParseRSAPublicKeyFromPEM(key []byte) (*rsa.PublicKey, error) {
var err error
// Parse PEM block
var block *pem.Block
if block, _ = pem.Decode(key); block == nil {
return nil, ErrKeyMustBePEMEncoded
}
// Parse the key
var parsedKey interface{}
if parsedKey, err = x509.ParsePKIXPublicKey(block.Bytes); err != nil {
if cert, err := x509.ParseCertificate(block.Bytes); err == nil {
parsedKey = cert.PublicKey
} else {
return nil, err
}
}
var pkey *rsa.PublicKey
var ok bool
if pkey, ok = parsedKey.(*rsa.PublicKey); !ok {
return nil, ErrNotRSAPrivateKey
}
return pkey, nil
}

View file

@ -0,0 +1,24 @@
package jwt
var signingMethods = map[string]func() SigningMethod{}
// Signing method
type SigningMethod interface {
Verify(signingString, signature string, key interface{}) error
Sign(signingString string, key interface{}) (string, error)
Alg() string
}
// Register the "alg" name and a factory function for signing method.
// This is typically done during init() in the method's implementation
func RegisterSigningMethod(alg string, f func() SigningMethod) {
signingMethods[alg] = f
}
// Get a signing method from an "alg" string
func GetSigningMethod(alg string) (method SigningMethod) {
if methodF, ok := signingMethods[alg]; ok {
method = methodF()
}
return
}

View file

@ -0,0 +1 @@
#5K+・シミew{ヲ住ウ(跼Tノ(ゥ┫メP.ソモ燾辻G<>感テwb="=.!r.Oタヘ奎gミ€

View file

@ -0,0 +1,27 @@
-----BEGIN RSA PRIVATE KEY-----
MIIEowIBAAKCAQEA4f5wg5l2hKsTeNem/V41fGnJm6gOdrj8ym3rFkEU/wT8RDtn
SgFEZOQpHEgQ7JL38xUfU0Y3g6aYw9QT0hJ7mCpz9Er5qLaMXJwZxzHzAahlfA0i
cqabvJOMvQtzD6uQv6wPEyZtDTWiQi9AXwBpHssPnpYGIn20ZZuNlX2BrClciHhC
PUIIZOQn/MmqTD31jSyjoQoV7MhhMTATKJx2XrHhR+1DcKJzQBSTAGnpYVaqpsAR
ap+nwRipr3nUTuxyGohBTSmjJ2usSeQXHI3bODIRe1AuTyHceAbewn8b462yEWKA
Rdpd9AjQW5SIVPfdsz5B6GlYQ5LdYKtznTuy7wIDAQABAoIBAQCwia1k7+2oZ2d3
n6agCAbqIE1QXfCmh41ZqJHbOY3oRQG3X1wpcGH4Gk+O+zDVTV2JszdcOt7E5dAy
MaomETAhRxB7hlIOnEN7WKm+dGNrKRvV0wDU5ReFMRHg31/Lnu8c+5BvGjZX+ky9
POIhFFYJqwCRlopGSUIxmVj5rSgtzk3iWOQXr+ah1bjEXvlxDOWkHN6YfpV5ThdE
KdBIPGEVqa63r9n2h+qazKrtiRqJqGnOrHzOECYbRFYhexsNFz7YT02xdfSHn7gM
IvabDDP/Qp0PjE1jdouiMaFHYnLBbgvlnZW9yuVf/rpXTUq/njxIXMmvmEyyvSDn
FcFikB8pAoGBAPF77hK4m3/rdGT7X8a/gwvZ2R121aBcdPwEaUhvj/36dx596zvY
mEOjrWfZhF083/nYWE2kVquj2wjs+otCLfifEEgXcVPTnEOPO9Zg3uNSL0nNQghj
FuD3iGLTUBCtM66oTe0jLSslHe8gLGEQqyMzHOzYxNqibxcOZIe8Qt0NAoGBAO+U
I5+XWjWEgDmvyC3TrOSf/KCGjtu0TSv30ipv27bDLMrpvPmD/5lpptTFwcxvVhCs
2b+chCjlghFSWFbBULBrfci2FtliClOVMYrlNBdUSJhf3aYSG2Doe6Bgt1n2CpNn
/iu37Y3NfemZBJA7hNl4dYe+f+uzM87cdQ214+jrAoGAXA0XxX8ll2+ToOLJsaNT
OvNB9h9Uc5qK5X5w+7G7O998BN2PC/MWp8H+2fVqpXgNENpNXttkRm1hk1dych86
EunfdPuqsX+as44oCyJGFHVBnWpm33eWQw9YqANRI+pCJzP08I5WK3osnPiwshd+
hR54yjgfYhBFNI7B95PmEQkCgYBzFSz7h1+s34Ycr8SvxsOBWxymG5zaCsUbPsL0
4aCgLScCHb9J+E86aVbbVFdglYa5Id7DPTL61ixhl7WZjujspeXZGSbmq0Kcnckb
mDgqkLECiOJW2NHP/j0McAkDLL4tysF8TLDO8gvuvzNC+WQ6drO2ThrypLVZQ+ry
eBIPmwKBgEZxhqa0gVvHQG/7Od69KWj4eJP28kq13RhKay8JOoN0vPmspXJo1HY3
CKuHRG+AP579dncdUnOMvfXOtkdM4vk0+hWASBQzM9xzVcztCa+koAugjVaLS9A+
9uQoqEeVNTckxx0S2bYevRy7hGQmUJTyQm3j1zEUR5jpdbL83Fbq
-----END RSA PRIVATE KEY-----

View file

@ -0,0 +1,9 @@
-----BEGIN PUBLIC KEY-----
MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEA4f5wg5l2hKsTeNem/V41
fGnJm6gOdrj8ym3rFkEU/wT8RDtnSgFEZOQpHEgQ7JL38xUfU0Y3g6aYw9QT0hJ7
mCpz9Er5qLaMXJwZxzHzAahlfA0icqabvJOMvQtzD6uQv6wPEyZtDTWiQi9AXwBp
HssPnpYGIn20ZZuNlX2BrClciHhCPUIIZOQn/MmqTD31jSyjoQoV7MhhMTATKJx2
XrHhR+1DcKJzQBSTAGnpYVaqpsARap+nwRipr3nUTuxyGohBTSmjJ2usSeQXHI3b
ODIRe1AuTyHceAbewn8b462yEWKARdpd9AjQW5SIVPfdsz5B6GlYQ5LdYKtznTuy
7wIDAQAB
-----END PUBLIC KEY-----

View file

@ -0,0 +1,274 @@
// Copyright 2013 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package query implements encoding of structs into URL query parameters.
//
// As a simple example:
//
// type Options struct {
// Query string `url:"q"`
// ShowAll bool `url:"all"`
// Page int `url:"page"`
// }
//
// opt := Options{ "foo", true, 2 }
// v, _ := query.Values(opt)
// fmt.Print(v.Encode()) // will output: "q=foo&all=true&page=2"
//
// The exact mapping between Go values and url.Values is described in the
// documentation for the Values() function.
package query
import (
"bytes"
"errors"
"fmt"
"net/url"
"reflect"
"strconv"
"strings"
"time"
)
var timeType = reflect.TypeOf(time.Time{})
var encoderType = reflect.TypeOf(new(Encoder)).Elem()
// Encoder is an interface implemented by any type that wishes to encode
// itself into URL values in a non-standard way.
type Encoder interface {
EncodeValues(key string, v *url.Values) error
}
// Values returns the url.Values encoding of v.
//
// Values expects to be passed a struct, and traverses it recursively using the
// following encoding rules.
//
// Each exported struct field is encoded as a URL parameter unless
//
// - the field's tag is "-", or
// - the field is empty and its tag specifies the "omitempty" option
//
// The empty values are false, 0, any nil pointer or interface value, any array
// slice, map, or string of length zero, and any time.Time that returns true
// for IsZero().
//
// The URL parameter name defaults to the struct field name but can be
// specified in the struct field's tag value. The "url" key in the struct
// field's tag value is the key name, followed by an optional comma and
// options. For example:
//
// // Field is ignored by this package.
// Field int `url:"-"`
//
// // Field appears as URL parameter "myName".
// Field int `url:"myName"`
//
// // Field appears as URL parameter "myName" and the field is omitted if
// // its value is empty
// Field int `url:"myName,omitempty"`
//
// // Field appears as URL parameter "Field" (the default), but the field
// // is skipped if empty. Note the leading comma.
// Field int `url:",omitempty"`
//
// For encoding individual field values, the following type-dependent rules
// apply:
//
// Boolean values default to encoding as the strings "true" or "false".
// Including the "int" option signals that the field should be encoded as the
// strings "1" or "0".
//
// time.Time values default to encoding as RFC3339 timestamps. Including the
// "unix" option signals that the field should be encoded as a Unix time (see
// time.Unix())
//
// Slice and Array values default to encoding as multiple URL values of the
// same name. Including the "comma" option signals that the field should be
// encoded as a single comma-delimited value. Including the "space" option
// similarly encodes the value as a single space-delimited string.
//
// Anonymous struct fields are usually encoded as if their inner exported
// fields were fields in the outer struct, subject to the standard Go
// visibility rules. An anonymous struct field with a name given in its URL
// tag is treated as having that name, rather than being anonymous.
//
// Non-nil pointer values are encoded as the value pointed to.
//
// All other values are encoded using their default string representation.
//
// Multiple fields that encode to the same URL parameter name will be included
// as multiple URL values of the same name.
func Values(v interface{}) (url.Values, error) {
val := reflect.ValueOf(v)
for val.Kind() == reflect.Ptr {
if val.IsNil() {
return nil, errors.New("query: Values() expects non-nil value")
}
val = val.Elem()
}
if val.Kind() != reflect.Struct {
return nil, fmt.Errorf("query: Values() expects struct input. Got %v", val.Kind())
}
values := make(url.Values)
err := reflectValue(values, val)
return values, err
}
// reflectValue populates the values parameter from the struct fields in val.
// Embedded structs are followed recursively (using the rules defined in the
// Values function documentation) breadth-first.
func reflectValue(values url.Values, val reflect.Value) error {
var embedded []reflect.Value
typ := val.Type()
for i := 0; i < typ.NumField(); i++ {
sf := typ.Field(i)
if sf.PkgPath != "" { // unexported
continue
}
sv := val.Field(i)
tag := sf.Tag.Get("url")
if tag == "-" {
continue
}
name, opts := parseTag(tag)
if name == "" {
if sf.Anonymous && sv.Kind() == reflect.Struct {
// save embedded struct for later processing
embedded = append(embedded, sv)
continue
}
name = sf.Name
}
if opts.Contains("omitempty") && isEmptyValue(sv) {
continue
}
if sv.Type().Implements(encoderType) {
m := sv.Interface().(Encoder)
if err := m.EncodeValues(name, &values); err != nil {
return err
}
continue
}
if sv.Kind() == reflect.Slice || sv.Kind() == reflect.Array {
var del byte
if opts.Contains("comma") {
del = ','
} else if opts.Contains("space") {
del = ' '
}
if del != 0 {
s := new(bytes.Buffer)
first := true
for i := 0; i < sv.Len(); i++ {
if first {
first = false
} else {
s.WriteByte(del)
}
s.WriteString(valueString(sv.Index(i), opts))
}
values.Add(name, s.String())
} else {
for i := 0; i < sv.Len(); i++ {
values.Add(name, valueString(sv.Index(i), opts))
}
}
continue
}
values.Add(name, valueString(sv, opts))
}
for _, f := range embedded {
if err := reflectValue(values, f); err != nil {
return err
}
}
return nil
}
// valueString returns the string representation of a value.
func valueString(v reflect.Value, opts tagOptions) string {
for v.Kind() == reflect.Ptr {
if v.IsNil() {
return ""
}
v = v.Elem()
}
if v.Kind() == reflect.Bool && opts.Contains("int") {
if v.Bool() {
return "1"
}
return "0"
}
if v.Type() == timeType {
t := v.Interface().(time.Time)
if opts.Contains("unix") {
return strconv.FormatInt(t.Unix(), 10)
}
return t.Format(time.RFC3339)
}
return fmt.Sprint(v.Interface())
}
// isEmptyValue checks if a value should be considered empty for the purposes
// of omitting fields with the "omitempty" option.
func isEmptyValue(v reflect.Value) bool {
switch v.Kind() {
case reflect.Array, reflect.Map, reflect.Slice, reflect.String:
return v.Len() == 0
case reflect.Bool:
return !v.Bool()
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
return v.Int() == 0
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
return v.Uint() == 0
case reflect.Float32, reflect.Float64:
return v.Float() == 0
case reflect.Interface, reflect.Ptr:
return v.IsNil()
}
if v.Type() == timeType {
return v.Interface().(time.Time).IsZero()
}
return false
}
// tagOptions is the string following a comma in a struct field's "url" tag, or
// the empty string. It does not include the leading comma.
type tagOptions []string
// parseTag splits a struct field's url tag into its name and comma-separated
// options.
func parseTag(tag string) (string, tagOptions) {
s := strings.Split(tag, ",")
return s[0], s[1:]
}
// Contains checks whether the tagOptions contains the specified option.
func (o tagOptions) Contains(option string) bool {
for _, s := range o {
if s == option {
return true
}
}
return false
}

View file

@ -0,0 +1,238 @@
// Copyright 2013 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package query
import (
"fmt"
"net/url"
"reflect"
"testing"
"time"
)
func TestValues_types(t *testing.T) {
str := "string"
strPtr := &str
tests := []struct {
in interface{}
want url.Values
}{
{
// basic primitives
struct {
A string
B int
C uint
D float32
E bool
}{},
url.Values{
"A": {""},
"B": {"0"},
"C": {"0"},
"D": {"0"},
"E": {"false"},
},
},
{
// pointers
struct {
A *string
B *int
C **string
}{A: strPtr, C: &strPtr},
url.Values{
"A": {str},
"B": {""},
"C": {str},
},
},
{
// slices and arrays
struct {
A []string
B []string `url:",comma"`
C []string `url:",space"`
D [2]string
E [2]string `url:",comma"`
F [2]string `url:",space"`
G []*string `url:",space"`
H []bool `url:",int,space"`
}{
A: []string{"a", "b"},
B: []string{"a", "b"},
C: []string{"a", "b"},
D: [2]string{"a", "b"},
E: [2]string{"a", "b"},
F: [2]string{"a", "b"},
G: []*string{&str, &str},
H: []bool{true, false},
},
url.Values{
"A": {"a", "b"},
"B": {"a,b"},
"C": {"a b"},
"D": {"a", "b"},
"E": {"a,b"},
"F": {"a b"},
"G": {"string string"},
"H": {"1 0"},
},
},
{
// other types
struct {
A time.Time
B time.Time `url:",unix"`
C bool `url:",int"`
D bool `url:",int"`
}{
A: time.Date(2000, 1, 1, 12, 34, 56, 0, time.UTC),
B: time.Date(2000, 1, 1, 12, 34, 56, 0, time.UTC),
C: true,
D: false,
},
url.Values{
"A": {"2000-01-01T12:34:56Z"},
"B": {"946730096"},
"C": {"1"},
"D": {"0"},
},
},
}
for i, tt := range tests {
v, err := Values(tt.in)
if err != nil {
t.Errorf("%d. Values(%q) returned error: %v", i, tt.in, err)
}
if !reflect.DeepEqual(tt.want, v) {
t.Errorf("%d. Values(%q) returned %v, want %v", i, tt.in, v, tt.want)
}
}
}
func TestValues_omitEmpty(t *testing.T) {
str := ""
s := struct {
a string
A string
B string `url:",omitempty"`
C string `url:"-"`
D string `url:"omitempty"` // actually named omitempty, not an option
E *string `url:",omitempty"`
}{E: &str}
v, err := Values(s)
if err != nil {
t.Errorf("Values(%q) returned error: %v", s, err)
}
want := url.Values{
"A": {""},
"omitempty": {""},
"E": {""}, // E is included because the pointer is not empty, even though the string being pointed to is
}
if !reflect.DeepEqual(want, v) {
t.Errorf("Values(%q) returned %v, want %v", s, v, want)
}
}
type A struct {
B
}
type B struct {
C string
}
type D struct {
B
C string
}
func TestValues_embeddedStructs(t *testing.T) {
tests := []struct {
in interface{}
want url.Values
}{
{
A{B{C: "foo"}},
url.Values{"C": {"foo"}},
},
{
D{B: B{C: "bar"}, C: "foo"},
url.Values{"C": {"foo", "bar"}},
},
}
for i, tt := range tests {
v, err := Values(tt.in)
if err != nil {
t.Errorf("%d. Values(%q) returned error: %v", i, tt.in, err)
}
if !reflect.DeepEqual(tt.want, v) {
t.Errorf("%d. Values(%q) returned %v, want %v", i, tt.in, v, tt.want)
}
}
}
func TestValues_invalidInput(t *testing.T) {
_, err := Values("")
if err == nil {
t.Errorf("expected Values() to return an error on invalid input")
}
}
type EncodedArgs []string
func (m EncodedArgs) EncodeValues(key string, v *url.Values) error {
for i, arg := range m {
v.Set(fmt.Sprintf("%s.%d", key, i), arg)
}
return nil
}
func TestValues_Marshaler(t *testing.T) {
s := struct {
Args EncodedArgs `url:"arg"`
}{[]string{"a", "b", "c"}}
v, err := Values(s)
if err != nil {
t.Errorf("Values(%q) returned error: %v", s, err)
}
want := url.Values{
"arg.0": {"a"},
"arg.1": {"b"},
"arg.2": {"c"},
}
if !reflect.DeepEqual(want, v) {
t.Errorf("Values(%q) returned %v, want %v", s, v, want)
}
}
func TestTagParsing(t *testing.T) {
name, opts := parseTag("field,foobar,foo")
if name != "field" {
t.Fatalf("name = %q, want field", name)
}
for _, tt := range []struct {
opt string
want bool
}{
{"foobar", true},
{"foo", true},
{"bar", false},
{"field", false},
} {
if opts.Contains(tt.opt) != tt.want {
t.Errorf("Contains(%q) = %v", tt.opt, !tt.want)
}
}
}

View file

@ -0,0 +1,9 @@
language: go
go:
- 1.0
- 1.1
- 1.2
- 1.3
- 1.4
- tip

View file

@ -0,0 +1,27 @@
Copyright (c) 2012 Rodrigo Moraes. All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
* Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above
copyright notice, this list of conditions and the following disclaimer
in the documentation and/or other materials provided with the
distribution.
* Neither the name of Google Inc. nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

View file

@ -0,0 +1,7 @@
context
=======
[![Build Status](https://travis-ci.org/gorilla/context.png?branch=master)](https://travis-ci.org/gorilla/context)
gorilla/context is a general purpose registry for global request variables.
Read the full documentation here: http://www.gorillatoolkit.org/pkg/context

View file

@ -0,0 +1,143 @@
// Copyright 2012 The Gorilla Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package context
import (
"net/http"
"sync"
"time"
)
var (
mutex sync.RWMutex
data = make(map[*http.Request]map[interface{}]interface{})
datat = make(map[*http.Request]int64)
)
// Set stores a value for a given key in a given request.
func Set(r *http.Request, key, val interface{}) {
mutex.Lock()
if data[r] == nil {
data[r] = make(map[interface{}]interface{})
datat[r] = time.Now().Unix()
}
data[r][key] = val
mutex.Unlock()
}
// Get returns a value stored for a given key in a given request.
func Get(r *http.Request, key interface{}) interface{} {
mutex.RLock()
if ctx := data[r]; ctx != nil {
value := ctx[key]
mutex.RUnlock()
return value
}
mutex.RUnlock()
return nil
}
// GetOk returns stored value and presence state like multi-value return of map access.
func GetOk(r *http.Request, key interface{}) (interface{}, bool) {
mutex.RLock()
if _, ok := data[r]; ok {
value, ok := data[r][key]
mutex.RUnlock()
return value, ok
}
mutex.RUnlock()
return nil, false
}
// GetAll returns all stored values for the request as a map. Nil is returned for invalid requests.
func GetAll(r *http.Request) map[interface{}]interface{} {
mutex.RLock()
if context, ok := data[r]; ok {
result := make(map[interface{}]interface{}, len(context))
for k, v := range context {
result[k] = v
}
mutex.RUnlock()
return result
}
mutex.RUnlock()
return nil
}
// GetAllOk returns all stored values for the request as a map and a boolean value that indicates if
// the request was registered.
func GetAllOk(r *http.Request) (map[interface{}]interface{}, bool) {
mutex.RLock()
context, ok := data[r]
result := make(map[interface{}]interface{}, len(context))
for k, v := range context {
result[k] = v
}
mutex.RUnlock()
return result, ok
}
// Delete removes a value stored for a given key in a given request.
func Delete(r *http.Request, key interface{}) {
mutex.Lock()
if data[r] != nil {
delete(data[r], key)
}
mutex.Unlock()
}
// Clear removes all values stored for a given request.
//
// This is usually called by a handler wrapper to clean up request
// variables at the end of a request lifetime. See ClearHandler().
func Clear(r *http.Request) {
mutex.Lock()
clear(r)
mutex.Unlock()
}
// clear is Clear without the lock.
func clear(r *http.Request) {
delete(data, r)
delete(datat, r)
}
// Purge removes request data stored for longer than maxAge, in seconds.
// It returns the amount of requests removed.
//
// If maxAge <= 0, all request data is removed.
//
// This is only used for sanity check: in case context cleaning was not
// properly set some request data can be kept forever, consuming an increasing
// amount of memory. In case this is detected, Purge() must be called
// periodically until the problem is fixed.
func Purge(maxAge int) int {
mutex.Lock()
count := 0
if maxAge <= 0 {
count = len(data)
data = make(map[*http.Request]map[interface{}]interface{})
datat = make(map[*http.Request]int64)
} else {
min := time.Now().Unix() - int64(maxAge)
for r := range data {
if datat[r] < min {
clear(r)
count++
}
}
}
mutex.Unlock()
return count
}
// ClearHandler wraps an http.Handler and clears request values at the end
// of a request lifetime.
func ClearHandler(h http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
defer Clear(r)
h.ServeHTTP(w, r)
})
}

View file

@ -0,0 +1,161 @@
// Copyright 2012 The Gorilla Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package context
import (
"net/http"
"testing"
)
type keyType int
const (
key1 keyType = iota
key2
)
func TestContext(t *testing.T) {
assertEqual := func(val interface{}, exp interface{}) {
if val != exp {
t.Errorf("Expected %v, got %v.", exp, val)
}
}
r, _ := http.NewRequest("GET", "http://localhost:8080/", nil)
emptyR, _ := http.NewRequest("GET", "http://localhost:8080/", nil)
// Get()
assertEqual(Get(r, key1), nil)
// Set()
Set(r, key1, "1")
assertEqual(Get(r, key1), "1")
assertEqual(len(data[r]), 1)
Set(r, key2, "2")
assertEqual(Get(r, key2), "2")
assertEqual(len(data[r]), 2)
//GetOk
value, ok := GetOk(r, key1)
assertEqual(value, "1")
assertEqual(ok, true)
value, ok = GetOk(r, "not exists")
assertEqual(value, nil)
assertEqual(ok, false)
Set(r, "nil value", nil)
value, ok = GetOk(r, "nil value")
assertEqual(value, nil)
assertEqual(ok, true)
// GetAll()
values := GetAll(r)
assertEqual(len(values), 3)
// GetAll() for empty request
values = GetAll(emptyR)
if values != nil {
t.Error("GetAll didn't return nil value for invalid request")
}
// GetAllOk()
values, ok = GetAllOk(r)
assertEqual(len(values), 3)
assertEqual(ok, true)
// GetAllOk() for empty request
values, ok = GetAllOk(emptyR)
assertEqual(value, nil)
assertEqual(ok, false)
// Delete()
Delete(r, key1)
assertEqual(Get(r, key1), nil)
assertEqual(len(data[r]), 2)
Delete(r, key2)
assertEqual(Get(r, key2), nil)
assertEqual(len(data[r]), 1)
// Clear()
Clear(r)
assertEqual(len(data), 0)
}
func parallelReader(r *http.Request, key string, iterations int, wait, done chan struct{}) {
<-wait
for i := 0; i < iterations; i++ {
Get(r, key)
}
done <- struct{}{}
}
func parallelWriter(r *http.Request, key, value string, iterations int, wait, done chan struct{}) {
<-wait
for i := 0; i < iterations; i++ {
Set(r, key, value)
}
done <- struct{}{}
}
func benchmarkMutex(b *testing.B, numReaders, numWriters, iterations int) {
b.StopTimer()
r, _ := http.NewRequest("GET", "http://localhost:8080/", nil)
done := make(chan struct{})
b.StartTimer()
for i := 0; i < b.N; i++ {
wait := make(chan struct{})
for i := 0; i < numReaders; i++ {
go parallelReader(r, "test", iterations, wait, done)
}
for i := 0; i < numWriters; i++ {
go parallelWriter(r, "test", "123", iterations, wait, done)
}
close(wait)
for i := 0; i < numReaders+numWriters; i++ {
<-done
}
}
}
func BenchmarkMutexSameReadWrite1(b *testing.B) {
benchmarkMutex(b, 1, 1, 32)
}
func BenchmarkMutexSameReadWrite2(b *testing.B) {
benchmarkMutex(b, 2, 2, 32)
}
func BenchmarkMutexSameReadWrite4(b *testing.B) {
benchmarkMutex(b, 4, 4, 32)
}
func BenchmarkMutex1(b *testing.B) {
benchmarkMutex(b, 2, 8, 32)
}
func BenchmarkMutex2(b *testing.B) {
benchmarkMutex(b, 16, 4, 64)
}
func BenchmarkMutex3(b *testing.B) {
benchmarkMutex(b, 1, 2, 128)
}
func BenchmarkMutex4(b *testing.B) {
benchmarkMutex(b, 128, 32, 256)
}
func BenchmarkMutex5(b *testing.B) {
benchmarkMutex(b, 1024, 2048, 64)
}
func BenchmarkMutex6(b *testing.B) {
benchmarkMutex(b, 2048, 1024, 512)
}

View file

@ -0,0 +1,82 @@
// Copyright 2012 The Gorilla Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
/*
Package context stores values shared during a request lifetime.
For example, a router can set variables extracted from the URL and later
application handlers can access those values, or it can be used to store
sessions values to be saved at the end of a request. There are several
others common uses.
The idea was posted by Brad Fitzpatrick to the go-nuts mailing list:
http://groups.google.com/group/golang-nuts/msg/e2d679d303aa5d53
Here's the basic usage: first define the keys that you will need. The key
type is interface{} so a key can be of any type that supports equality.
Here we define a key using a custom int type to avoid name collisions:
package foo
import (
"github.com/gorilla/context"
)
type key int
const MyKey key = 0
Then set a variable. Variables are bound to an http.Request object, so you
need a request instance to set a value:
context.Set(r, MyKey, "bar")
The application can later access the variable using the same key you provided:
func MyHandler(w http.ResponseWriter, r *http.Request) {
// val is "bar".
val := context.Get(r, foo.MyKey)
// returns ("bar", true)
val, ok := context.GetOk(r, foo.MyKey)
// ...
}
And that's all about the basic usage. We discuss some other ideas below.
Any type can be stored in the context. To enforce a given type, make the key
private and wrap Get() and Set() to accept and return values of a specific
type:
type key int
const mykey key = 0
// GetMyKey returns a value for this package from the request values.
func GetMyKey(r *http.Request) SomeType {
if rv := context.Get(r, mykey); rv != nil {
return rv.(SomeType)
}
return nil
}
// SetMyKey sets a value for this package in the request values.
func SetMyKey(r *http.Request, val SomeType) {
context.Set(r, mykey, val)
}
Variables must be cleared at the end of a request, to remove all values
that were stored. This can be done in an http.Handler, after a request was
served. Just call Clear() passing the request:
context.Clear(r)
...or use ClearHandler(), which conveniently wraps an http.Handler to clear
variables at the end of a request lifetime.
The Routers from the packages gorilla/mux and gorilla/pat call Clear()
so if you are using either of them you don't need to clear the context manually.
*/
package context

View file

@ -0,0 +1,7 @@
language: go
go:
- 1.0
- 1.1
- 1.2
- tip

27
Godeps/_workspace/src/github.com/gorilla/mux/LICENSE generated vendored Normal file
View file

@ -0,0 +1,27 @@
Copyright (c) 2012 Rodrigo Moraes. All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
* Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above
copyright notice, this list of conditions and the following disclaimer
in the documentation and/or other materials provided with the
distribution.
* Neither the name of Google Inc. nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

View file

@ -0,0 +1,7 @@
mux
===
[![Build Status](https://travis-ci.org/gorilla/mux.png?branch=master)](https://travis-ci.org/gorilla/mux)
gorilla/mux is a powerful URL router and dispatcher.
Read the full documentation here: http://www.gorillatoolkit.org/pkg/mux

View file

@ -0,0 +1,21 @@
// Copyright 2012 The Gorilla Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package mux
import (
"net/http"
"testing"
)
func BenchmarkMux(b *testing.B) {
router := new(Router)
handler := func(w http.ResponseWriter, r *http.Request) {}
router.HandleFunc("/v1/{v1}", handler)
request, _ := http.NewRequest("GET", "/v1/anything", nil)
for i := 0; i < b.N; i++ {
router.ServeHTTP(nil, request)
}
}

199
Godeps/_workspace/src/github.com/gorilla/mux/doc.go generated vendored Normal file
View file

@ -0,0 +1,199 @@
// Copyright 2012 The Gorilla Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
/*
Package gorilla/mux implements a request router and dispatcher.
The name mux stands for "HTTP request multiplexer". Like the standard
http.ServeMux, mux.Router matches incoming requests against a list of
registered routes and calls a handler for the route that matches the URL
or other conditions. The main features are:
* Requests can be matched based on URL host, path, path prefix, schemes,
header and query values, HTTP methods or using custom matchers.
* URL hosts and paths can have variables with an optional regular
expression.
* Registered URLs can be built, or "reversed", which helps maintaining
references to resources.
* Routes can be used as subrouters: nested routes are only tested if the
parent route matches. This is useful to define groups of routes that
share common conditions like a host, a path prefix or other repeated
attributes. As a bonus, this optimizes request matching.
* It implements the http.Handler interface so it is compatible with the
standard http.ServeMux.
Let's start registering a couple of URL paths and handlers:
func main() {
r := mux.NewRouter()
r.HandleFunc("/", HomeHandler)
r.HandleFunc("/products", ProductsHandler)
r.HandleFunc("/articles", ArticlesHandler)
http.Handle("/", r)
}
Here we register three routes mapping URL paths to handlers. This is
equivalent to how http.HandleFunc() works: if an incoming request URL matches
one of the paths, the corresponding handler is called passing
(http.ResponseWriter, *http.Request) as parameters.
Paths can have variables. They are defined using the format {name} or
{name:pattern}. If a regular expression pattern is not defined, the matched
variable will be anything until the next slash. For example:
r := mux.NewRouter()
r.HandleFunc("/products/{key}", ProductHandler)
r.HandleFunc("/articles/{category}/", ArticlesCategoryHandler)
r.HandleFunc("/articles/{category}/{id:[0-9]+}", ArticleHandler)
The names are used to create a map of route variables which can be retrieved
calling mux.Vars():
vars := mux.Vars(request)
category := vars["category"]
And this is all you need to know about the basic usage. More advanced options
are explained below.
Routes can also be restricted to a domain or subdomain. Just define a host
pattern to be matched. They can also have variables:
r := mux.NewRouter()
// Only matches if domain is "www.domain.com".
r.Host("www.domain.com")
// Matches a dynamic subdomain.
r.Host("{subdomain:[a-z]+}.domain.com")
There are several other matchers that can be added. To match path prefixes:
r.PathPrefix("/products/")
...or HTTP methods:
r.Methods("GET", "POST")
...or URL schemes:
r.Schemes("https")
...or header values:
r.Headers("X-Requested-With", "XMLHttpRequest")
...or query values:
r.Queries("key", "value")
...or to use a custom matcher function:
r.MatcherFunc(func(r *http.Request, rm *RouteMatch) bool {
return r.ProtoMajor == 0
})
...and finally, it is possible to combine several matchers in a single route:
r.HandleFunc("/products", ProductsHandler).
Host("www.domain.com").
Methods("GET").
Schemes("http")
Setting the same matching conditions again and again can be boring, so we have
a way to group several routes that share the same requirements.
We call it "subrouting".
For example, let's say we have several URLs that should only match when the
host is "www.domain.com". Create a route for that host and get a "subrouter"
from it:
r := mux.NewRouter()
s := r.Host("www.domain.com").Subrouter()
Then register routes in the subrouter:
s.HandleFunc("/products/", ProductsHandler)
s.HandleFunc("/products/{key}", ProductHandler)
s.HandleFunc("/articles/{category}/{id:[0-9]+}"), ArticleHandler)
The three URL paths we registered above will only be tested if the domain is
"www.domain.com", because the subrouter is tested first. This is not
only convenient, but also optimizes request matching. You can create
subrouters combining any attribute matchers accepted by a route.
Subrouters can be used to create domain or path "namespaces": you define
subrouters in a central place and then parts of the app can register its
paths relatively to a given subrouter.
There's one more thing about subroutes. When a subrouter has a path prefix,
the inner routes use it as base for their paths:
r := mux.NewRouter()
s := r.PathPrefix("/products").Subrouter()
// "/products/"
s.HandleFunc("/", ProductsHandler)
// "/products/{key}/"
s.HandleFunc("/{key}/", ProductHandler)
// "/products/{key}/details"
s.HandleFunc("/{key}/details", ProductDetailsHandler)
Now let's see how to build registered URLs.
Routes can be named. All routes that define a name can have their URLs built,
or "reversed". We define a name calling Name() on a route. For example:
r := mux.NewRouter()
r.HandleFunc("/articles/{category}/{id:[0-9]+}", ArticleHandler).
Name("article")
To build a URL, get the route and call the URL() method, passing a sequence of
key/value pairs for the route variables. For the previous route, we would do:
url, err := r.Get("article").URL("category", "technology", "id", "42")
...and the result will be a url.URL with the following path:
"/articles/technology/42"
This also works for host variables:
r := mux.NewRouter()
r.Host("{subdomain}.domain.com").
Path("/articles/{category}/{id:[0-9]+}").
HandlerFunc(ArticleHandler).
Name("article")
// url.String() will be "http://news.domain.com/articles/technology/42"
url, err := r.Get("article").URL("subdomain", "news",
"category", "technology",
"id", "42")
All variables defined in the route are required, and their values must
conform to the corresponding patterns. These requirements guarantee that a
generated URL will always match a registered route -- the only exception is
for explicitly defined "build-only" routes which never match.
There's also a way to build only the URL host or path for a route:
use the methods URLHost() or URLPath() instead. For the previous route,
we would do:
// "http://news.domain.com/"
host, err := r.Get("article").URLHost("subdomain", "news")
// "/articles/technology/42"
path, err := r.Get("article").URLPath("category", "technology", "id", "42")
And if you use subrouters, host and path defined separately can be built
as well:
r := mux.NewRouter()
s := r.Host("{subdomain}.domain.com").Subrouter()
s.Path("/articles/{category}/{id:[0-9]+}").
HandlerFunc(ArticleHandler).
Name("article")
// "http://news.domain.com/articles/technology/42"
url, err := r.Get("article").URL("subdomain", "news",
"category", "technology",
"id", "42")
*/
package mux

353
Godeps/_workspace/src/github.com/gorilla/mux/mux.go generated vendored Normal file
View file

@ -0,0 +1,353 @@
// Copyright 2012 The Gorilla Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package mux
import (
"fmt"
"net/http"
"path"
"github.com/gorilla/context"
)
// NewRouter returns a new router instance.
func NewRouter() *Router {
return &Router{namedRoutes: make(map[string]*Route), KeepContext: false}
}
// Router registers routes to be matched and dispatches a handler.
//
// It implements the http.Handler interface, so it can be registered to serve
// requests:
//
// var router = mux.NewRouter()
//
// func main() {
// http.Handle("/", router)
// }
//
// Or, for Google App Engine, register it in a init() function:
//
// func init() {
// http.Handle("/", router)
// }
//
// This will send all incoming requests to the router.
type Router struct {
// Configurable Handler to be used when no route matches.
NotFoundHandler http.Handler
// Parent route, if this is a subrouter.
parent parentRoute
// Routes to be matched, in order.
routes []*Route
// Routes by name for URL building.
namedRoutes map[string]*Route
// See Router.StrictSlash(). This defines the flag for new routes.
strictSlash bool
// If true, do not clear the request context after handling the request
KeepContext bool
}
// Match matches registered routes against the request.
func (r *Router) Match(req *http.Request, match *RouteMatch) bool {
for _, route := range r.routes {
if route.Match(req, match) {
return true
}
}
return false
}
// ServeHTTP dispatches the handler registered in the matched route.
//
// When there is a match, the route variables can be retrieved calling
// mux.Vars(request).
func (r *Router) ServeHTTP(w http.ResponseWriter, req *http.Request) {
// Clean path to canonical form and redirect.
if p := cleanPath(req.URL.Path); p != req.URL.Path {
// Added 3 lines (Philip Schlump) - It was droping the query string and #whatever from query.
// This matches with fix in go 1.2 r.c. 4 for same problem. Go Issue:
// http://code.google.com/p/go/issues/detail?id=5252
url := *req.URL
url.Path = p
p = url.String()
w.Header().Set("Location", p)
w.WriteHeader(http.StatusMovedPermanently)
return
}
var match RouteMatch
var handler http.Handler
if r.Match(req, &match) {
handler = match.Handler
setVars(req, match.Vars)
setCurrentRoute(req, match.Route)
}
if handler == nil {
handler = r.NotFoundHandler
if handler == nil {
handler = http.NotFoundHandler()
}
}
if !r.KeepContext {
defer context.Clear(req)
}
handler.ServeHTTP(w, req)
}
// Get returns a route registered with the given name.
func (r *Router) Get(name string) *Route {
return r.getNamedRoutes()[name]
}
// GetRoute returns a route registered with the given name. This method
// was renamed to Get() and remains here for backwards compatibility.
func (r *Router) GetRoute(name string) *Route {
return r.getNamedRoutes()[name]
}
// StrictSlash defines the trailing slash behavior for new routes. The initial
// value is false.
//
// When true, if the route path is "/path/", accessing "/path" will redirect
// to the former and vice versa. In other words, your application will always
// see the path as specified in the route.
//
// When false, if the route path is "/path", accessing "/path/" will not match
// this route and vice versa.
//
// Special case: when a route sets a path prefix using the PathPrefix() method,
// strict slash is ignored for that route because the redirect behavior can't
// be determined from a prefix alone. However, any subrouters created from that
// route inherit the original StrictSlash setting.
func (r *Router) StrictSlash(value bool) *Router {
r.strictSlash = value
return r
}
// ----------------------------------------------------------------------------
// parentRoute
// ----------------------------------------------------------------------------
// getNamedRoutes returns the map where named routes are registered.
func (r *Router) getNamedRoutes() map[string]*Route {
if r.namedRoutes == nil {
if r.parent != nil {
r.namedRoutes = r.parent.getNamedRoutes()
} else {
r.namedRoutes = make(map[string]*Route)
}
}
return r.namedRoutes
}
// getRegexpGroup returns regexp definitions from the parent route, if any.
func (r *Router) getRegexpGroup() *routeRegexpGroup {
if r.parent != nil {
return r.parent.getRegexpGroup()
}
return nil
}
// ----------------------------------------------------------------------------
// Route factories
// ----------------------------------------------------------------------------
// NewRoute registers an empty route.
func (r *Router) NewRoute() *Route {
route := &Route{parent: r, strictSlash: r.strictSlash}
r.routes = append(r.routes, route)
return route
}
// Handle registers a new route with a matcher for the URL path.
// See Route.Path() and Route.Handler().
func (r *Router) Handle(path string, handler http.Handler) *Route {
return r.NewRoute().Path(path).Handler(handler)
}
// HandleFunc registers a new route with a matcher for the URL path.
// See Route.Path() and Route.HandlerFunc().
func (r *Router) HandleFunc(path string, f func(http.ResponseWriter,
*http.Request)) *Route {
return r.NewRoute().Path(path).HandlerFunc(f)
}
// Headers registers a new route with a matcher for request header values.
// See Route.Headers().
func (r *Router) Headers(pairs ...string) *Route {
return r.NewRoute().Headers(pairs...)
}
// Host registers a new route with a matcher for the URL host.
// See Route.Host().
func (r *Router) Host(tpl string) *Route {
return r.NewRoute().Host(tpl)
}
// MatcherFunc registers a new route with a custom matcher function.
// See Route.MatcherFunc().
func (r *Router) MatcherFunc(f MatcherFunc) *Route {
return r.NewRoute().MatcherFunc(f)
}
// Methods registers a new route with a matcher for HTTP methods.
// See Route.Methods().
func (r *Router) Methods(methods ...string) *Route {
return r.NewRoute().Methods(methods...)
}
// Path registers a new route with a matcher for the URL path.
// See Route.Path().
func (r *Router) Path(tpl string) *Route {
return r.NewRoute().Path(tpl)
}
// PathPrefix registers a new route with a matcher for the URL path prefix.
// See Route.PathPrefix().
func (r *Router) PathPrefix(tpl string) *Route {
return r.NewRoute().PathPrefix(tpl)
}
// Queries registers a new route with a matcher for URL query values.
// See Route.Queries().
func (r *Router) Queries(pairs ...string) *Route {
return r.NewRoute().Queries(pairs...)
}
// Schemes registers a new route with a matcher for URL schemes.
// See Route.Schemes().
func (r *Router) Schemes(schemes ...string) *Route {
return r.NewRoute().Schemes(schemes...)
}
// ----------------------------------------------------------------------------
// Context
// ----------------------------------------------------------------------------
// RouteMatch stores information about a matched route.
type RouteMatch struct {
Route *Route
Handler http.Handler
Vars map[string]string
}
type contextKey int
const (
varsKey contextKey = iota
routeKey
)
// Vars returns the route variables for the current request, if any.
func Vars(r *http.Request) map[string]string {
if rv := context.Get(r, varsKey); rv != nil {
return rv.(map[string]string)
}
return nil
}
// CurrentRoute returns the matched route for the current request, if any.
func CurrentRoute(r *http.Request) *Route {
if rv := context.Get(r, routeKey); rv != nil {
return rv.(*Route)
}
return nil
}
func setVars(r *http.Request, val interface{}) {
context.Set(r, varsKey, val)
}
func setCurrentRoute(r *http.Request, val interface{}) {
context.Set(r, routeKey, val)
}
// ----------------------------------------------------------------------------
// Helpers
// ----------------------------------------------------------------------------
// cleanPath returns the canonical path for p, eliminating . and .. elements.
// Borrowed from the net/http package.
func cleanPath(p string) string {
if p == "" {
return "/"
}
if p[0] != '/' {
p = "/" + p
}
np := path.Clean(p)
// path.Clean removes trailing slash except for root;
// put the trailing slash back if necessary.
if p[len(p)-1] == '/' && np != "/" {
np += "/"
}
return np
}
// uniqueVars returns an error if two slices contain duplicated strings.
func uniqueVars(s1, s2 []string) error {
for _, v1 := range s1 {
for _, v2 := range s2 {
if v1 == v2 {
return fmt.Errorf("mux: duplicated route variable %q", v2)
}
}
}
return nil
}
// mapFromPairs converts variadic string parameters to a string map.
func mapFromPairs(pairs ...string) (map[string]string, error) {
length := len(pairs)
if length%2 != 0 {
return nil, fmt.Errorf(
"mux: number of parameters must be multiple of 2, got %v", pairs)
}
m := make(map[string]string, length/2)
for i := 0; i < length; i += 2 {
m[pairs[i]] = pairs[i+1]
}
return m, nil
}
// matchInArray returns true if the given string value is in the array.
func matchInArray(arr []string, value string) bool {
for _, v := range arr {
if v == value {
return true
}
}
return false
}
// matchMap returns true if the given key/value pairs exist in a given map.
func matchMap(toCheck map[string]string, toMatch map[string][]string,
canonicalKey bool) bool {
for k, v := range toCheck {
// Check if key exists.
if canonicalKey {
k = http.CanonicalHeaderKey(k)
}
if values := toMatch[k]; values == nil {
return false
} else if v != "" {
// If value was defined as an empty string we only check that the
// key exists. Otherwise we also check for equality.
valueExists := false
for _, value := range values {
if v == value {
valueExists = true
break
}
}
if !valueExists {
return false
}
}
}
return true
}

View file

@ -0,0 +1,943 @@
// Copyright 2012 The Gorilla Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package mux
import (
"fmt"
"net/http"
"testing"
"github.com/gorilla/context"
)
type routeTest struct {
title string // title of the test
route *Route // the route being tested
request *http.Request // a request to test the route
vars map[string]string // the expected vars of the match
host string // the expected host of the match
path string // the expected path of the match
shouldMatch bool // whether the request is expected to match the route at all
shouldRedirect bool // whether the request should result in a redirect
}
func TestHost(t *testing.T) {
// newRequestHost a new request with a method, url, and host header
newRequestHost := func(method, url, host string) *http.Request {
req, err := http.NewRequest(method, url, nil)
if err != nil {
panic(err)
}
req.Host = host
return req
}
tests := []routeTest{
{
title: "Host route match",
route: new(Route).Host("aaa.bbb.ccc"),
request: newRequest("GET", "http://aaa.bbb.ccc/111/222/333"),
vars: map[string]string{},
host: "aaa.bbb.ccc",
path: "",
shouldMatch: true,
},
{
title: "Host route, wrong host in request URL",
route: new(Route).Host("aaa.bbb.ccc"),
request: newRequest("GET", "http://aaa.222.ccc/111/222/333"),
vars: map[string]string{},
host: "aaa.bbb.ccc",
path: "",
shouldMatch: false,
},
{
title: "Host route with port, match",
route: new(Route).Host("aaa.bbb.ccc:1234"),
request: newRequest("GET", "http://aaa.bbb.ccc:1234/111/222/333"),
vars: map[string]string{},
host: "aaa.bbb.ccc:1234",
path: "",
shouldMatch: true,
},
{
title: "Host route with port, wrong port in request URL",
route: new(Route).Host("aaa.bbb.ccc:1234"),
request: newRequest("GET", "http://aaa.bbb.ccc:9999/111/222/333"),
vars: map[string]string{},
host: "aaa.bbb.ccc:1234",
path: "",
shouldMatch: false,
},
{
title: "Host route, match with host in request header",
route: new(Route).Host("aaa.bbb.ccc"),
request: newRequestHost("GET", "/111/222/333", "aaa.bbb.ccc"),
vars: map[string]string{},
host: "aaa.bbb.ccc",
path: "",
shouldMatch: true,
},
{
title: "Host route, wrong host in request header",
route: new(Route).Host("aaa.bbb.ccc"),
request: newRequestHost("GET", "/111/222/333", "aaa.222.ccc"),
vars: map[string]string{},
host: "aaa.bbb.ccc",
path: "",
shouldMatch: false,
},
// BUG {new(Route).Host("aaa.bbb.ccc:1234"), newRequestHost("GET", "/111/222/333", "aaa.bbb.ccc:1234"), map[string]string{}, "aaa.bbb.ccc:1234", "", true},
{
title: "Host route with port, wrong host in request header",
route: new(Route).Host("aaa.bbb.ccc:1234"),
request: newRequestHost("GET", "/111/222/333", "aaa.bbb.ccc:9999"),
vars: map[string]string{},
host: "aaa.bbb.ccc:1234",
path: "",
shouldMatch: false,
},
{
title: "Host route with pattern, match",
route: new(Route).Host("aaa.{v1:[a-z]{3}}.ccc"),
request: newRequest("GET", "http://aaa.bbb.ccc/111/222/333"),
vars: map[string]string{"v1": "bbb"},
host: "aaa.bbb.ccc",
path: "",
shouldMatch: true,
},
{
title: "Host route with pattern, wrong host in request URL",
route: new(Route).Host("aaa.{v1:[a-z]{3}}.ccc"),
request: newRequest("GET", "http://aaa.222.ccc/111/222/333"),
vars: map[string]string{"v1": "bbb"},
host: "aaa.bbb.ccc",
path: "",
shouldMatch: false,
},
{
title: "Host route with multiple patterns, match",
route: new(Route).Host("{v1:[a-z]{3}}.{v2:[a-z]{3}}.{v3:[a-z]{3}}"),
request: newRequest("GET", "http://aaa.bbb.ccc/111/222/333"),
vars: map[string]string{"v1": "aaa", "v2": "bbb", "v3": "ccc"},
host: "aaa.bbb.ccc",
path: "",
shouldMatch: true,
},
{
title: "Host route with multiple patterns, wrong host in request URL",
route: new(Route).Host("{v1:[a-z]{3}}.{v2:[a-z]{3}}.{v3:[a-z]{3}}"),
request: newRequest("GET", "http://aaa.222.ccc/111/222/333"),
vars: map[string]string{"v1": "aaa", "v2": "bbb", "v3": "ccc"},
host: "aaa.bbb.ccc",
path: "",
shouldMatch: false,
},
}
for _, test := range tests {
testRoute(t, test)
}
}
func TestPath(t *testing.T) {
tests := []routeTest{
{
title: "Path route, match",
route: new(Route).Path("/111/222/333"),
request: newRequest("GET", "http://localhost/111/222/333"),
vars: map[string]string{},
host: "",
path: "/111/222/333",
shouldMatch: true,
},
{
title: "Path route, match with trailing slash in request and path",
route: new(Route).Path("/111/"),
request: newRequest("GET", "http://localhost/111/"),
vars: map[string]string{},
host: "",
path: "/111/",
shouldMatch: true,
},
{
title: "Path route, do not match with trailing slash in path",
route: new(Route).Path("/111/"),
request: newRequest("GET", "http://localhost/111"),
vars: map[string]string{},
host: "",
path: "/111",
shouldMatch: false,
},
{
title: "Path route, do not match with trailing slash in request",
route: new(Route).Path("/111"),
request: newRequest("GET", "http://localhost/111/"),
vars: map[string]string{},
host: "",
path: "/111/",
shouldMatch: false,
},
{
title: "Path route, wrong path in request in request URL",
route: new(Route).Path("/111/222/333"),
request: newRequest("GET", "http://localhost/1/2/3"),
vars: map[string]string{},
host: "",
path: "/111/222/333",
shouldMatch: false,
},
{
title: "Path route with pattern, match",
route: new(Route).Path("/111/{v1:[0-9]{3}}/333"),
request: newRequest("GET", "http://localhost/111/222/333"),
vars: map[string]string{"v1": "222"},
host: "",
path: "/111/222/333",
shouldMatch: true,
},
{
title: "Path route with pattern, URL in request does not match",
route: new(Route).Path("/111/{v1:[0-9]{3}}/333"),
request: newRequest("GET", "http://localhost/111/aaa/333"),
vars: map[string]string{"v1": "222"},
host: "",
path: "/111/222/333",
shouldMatch: false,
},
{
title: "Path route with multiple patterns, match",
route: new(Route).Path("/{v1:[0-9]{3}}/{v2:[0-9]{3}}/{v3:[0-9]{3}}"),
request: newRequest("GET", "http://localhost/111/222/333"),
vars: map[string]string{"v1": "111", "v2": "222", "v3": "333"},
host: "",
path: "/111/222/333",
shouldMatch: true,
},
{
title: "Path route with multiple patterns, URL in request does not match",
route: new(Route).Path("/{v1:[0-9]{3}}/{v2:[0-9]{3}}/{v3:[0-9]{3}}"),
request: newRequest("GET", "http://localhost/111/aaa/333"),
vars: map[string]string{"v1": "111", "v2": "222", "v3": "333"},
host: "",
path: "/111/222/333",
shouldMatch: false,
},
}
for _, test := range tests {
testRoute(t, test)
}
}
func TestPathPrefix(t *testing.T) {
tests := []routeTest{
{
title: "PathPrefix route, match",
route: new(Route).PathPrefix("/111"),
request: newRequest("GET", "http://localhost/111/222/333"),
vars: map[string]string{},
host: "",
path: "/111",
shouldMatch: true,
},
{
title: "PathPrefix route, match substring",
route: new(Route).PathPrefix("/1"),
request: newRequest("GET", "http://localhost/111/222/333"),
vars: map[string]string{},
host: "",
path: "/1",
shouldMatch: true,
},
{
title: "PathPrefix route, URL prefix in request does not match",
route: new(Route).PathPrefix("/111"),
request: newRequest("GET", "http://localhost/1/2/3"),
vars: map[string]string{},
host: "",
path: "/111",
shouldMatch: false,
},
{
title: "PathPrefix route with pattern, match",
route: new(Route).PathPrefix("/111/{v1:[0-9]{3}}"),
request: newRequest("GET", "http://localhost/111/222/333"),
vars: map[string]string{"v1": "222"},
host: "",
path: "/111/222",
shouldMatch: true,
},
{
title: "PathPrefix route with pattern, URL prefix in request does not match",
route: new(Route).PathPrefix("/111/{v1:[0-9]{3}}"),
request: newRequest("GET", "http://localhost/111/aaa/333"),
vars: map[string]string{"v1": "222"},
host: "",
path: "/111/222",
shouldMatch: false,
},
{
title: "PathPrefix route with multiple patterns, match",
route: new(Route).PathPrefix("/{v1:[0-9]{3}}/{v2:[0-9]{3}}"),
request: newRequest("GET", "http://localhost/111/222/333"),
vars: map[string]string{"v1": "111", "v2": "222"},
host: "",
path: "/111/222",
shouldMatch: true,
},
{
title: "PathPrefix route with multiple patterns, URL prefix in request does not match",
route: new(Route).PathPrefix("/{v1:[0-9]{3}}/{v2:[0-9]{3}}"),
request: newRequest("GET", "http://localhost/111/aaa/333"),
vars: map[string]string{"v1": "111", "v2": "222"},
host: "",
path: "/111/222",
shouldMatch: false,
},
}
for _, test := range tests {
testRoute(t, test)
}
}
func TestHostPath(t *testing.T) {
tests := []routeTest{
{
title: "Host and Path route, match",
route: new(Route).Host("aaa.bbb.ccc").Path("/111/222/333"),
request: newRequest("GET", "http://aaa.bbb.ccc/111/222/333"),
vars: map[string]string{},
host: "",
path: "",
shouldMatch: true,
},
{
title: "Host and Path route, wrong host in request URL",
route: new(Route).Host("aaa.bbb.ccc").Path("/111/222/333"),
request: newRequest("GET", "http://aaa.222.ccc/111/222/333"),
vars: map[string]string{},
host: "",
path: "",
shouldMatch: false,
},
{
title: "Host and Path route with pattern, match",
route: new(Route).Host("aaa.{v1:[a-z]{3}}.ccc").Path("/111/{v2:[0-9]{3}}/333"),
request: newRequest("GET", "http://aaa.bbb.ccc/111/222/333"),
vars: map[string]string{"v1": "bbb", "v2": "222"},
host: "aaa.bbb.ccc",
path: "/111/222/333",
shouldMatch: true,
},
{
title: "Host and Path route with pattern, URL in request does not match",
route: new(Route).Host("aaa.{v1:[a-z]{3}}.ccc").Path("/111/{v2:[0-9]{3}}/333"),
request: newRequest("GET", "http://aaa.222.ccc/111/222/333"),
vars: map[string]string{"v1": "bbb", "v2": "222"},
host: "aaa.bbb.ccc",
path: "/111/222/333",
shouldMatch: false,
},
{
title: "Host and Path route with multiple patterns, match",
route: new(Route).Host("{v1:[a-z]{3}}.{v2:[a-z]{3}}.{v3:[a-z]{3}}").Path("/{v4:[0-9]{3}}/{v5:[0-9]{3}}/{v6:[0-9]{3}}"),
request: newRequest("GET", "http://aaa.bbb.ccc/111/222/333"),
vars: map[string]string{"v1": "aaa", "v2": "bbb", "v3": "ccc", "v4": "111", "v5": "222", "v6": "333"},
host: "aaa.bbb.ccc",
path: "/111/222/333",
shouldMatch: true,
},
{
title: "Host and Path route with multiple patterns, URL in request does not match",
route: new(Route).Host("{v1:[a-z]{3}}.{v2:[a-z]{3}}.{v3:[a-z]{3}}").Path("/{v4:[0-9]{3}}/{v5:[0-9]{3}}/{v6:[0-9]{3}}"),
request: newRequest("GET", "http://aaa.222.ccc/111/222/333"),
vars: map[string]string{"v1": "aaa", "v2": "bbb", "v3": "ccc", "v4": "111", "v5": "222", "v6": "333"},
host: "aaa.bbb.ccc",
path: "/111/222/333",
shouldMatch: false,
},
}
for _, test := range tests {
testRoute(t, test)
}
}
func TestHeaders(t *testing.T) {
// newRequestHeaders creates a new request with a method, url, and headers
newRequestHeaders := func(method, url string, headers map[string]string) *http.Request {
req, err := http.NewRequest(method, url, nil)
if err != nil {
panic(err)
}
for k, v := range headers {
req.Header.Add(k, v)
}
return req
}
tests := []routeTest{
{
title: "Headers route, match",
route: new(Route).Headers("foo", "bar", "baz", "ding"),
request: newRequestHeaders("GET", "http://localhost", map[string]string{"foo": "bar", "baz": "ding"}),
vars: map[string]string{},
host: "",
path: "",
shouldMatch: true,
},
{
title: "Headers route, bad header values",
route: new(Route).Headers("foo", "bar", "baz", "ding"),
request: newRequestHeaders("GET", "http://localhost", map[string]string{"foo": "bar", "baz": "dong"}),
vars: map[string]string{},
host: "",
path: "",
shouldMatch: false,
},
}
for _, test := range tests {
testRoute(t, test)
}
}
func TestMethods(t *testing.T) {
tests := []routeTest{
{
title: "Methods route, match GET",
route: new(Route).Methods("GET", "POST"),
request: newRequest("GET", "http://localhost"),
vars: map[string]string{},
host: "",
path: "",
shouldMatch: true,
},
{
title: "Methods route, match POST",
route: new(Route).Methods("GET", "POST"),
request: newRequest("POST", "http://localhost"),
vars: map[string]string{},
host: "",
path: "",
shouldMatch: true,
},
{
title: "Methods route, bad method",
route: new(Route).Methods("GET", "POST"),
request: newRequest("PUT", "http://localhost"),
vars: map[string]string{},
host: "",
path: "",
shouldMatch: false,
},
}
for _, test := range tests {
testRoute(t, test)
}
}
func TestQueries(t *testing.T) {
tests := []routeTest{
{
title: "Queries route, match",
route: new(Route).Queries("foo", "bar", "baz", "ding"),
request: newRequest("GET", "http://localhost?foo=bar&baz=ding"),
vars: map[string]string{},
host: "",
path: "",
shouldMatch: true,
},
{
title: "Queries route, match with a query string",
route: new(Route).Host("www.example.com").Path("/api").Queries("foo", "bar", "baz", "ding"),
request: newRequest("GET", "http://www.example.com/api?foo=bar&baz=ding"),
vars: map[string]string{},
host: "",
path: "",
shouldMatch: true,
},
{
title: "Queries route, match with a query string out of order",
route: new(Route).Host("www.example.com").Path("/api").Queries("foo", "bar", "baz", "ding"),
request: newRequest("GET", "http://www.example.com/api?baz=ding&foo=bar"),
vars: map[string]string{},
host: "",
path: "",
shouldMatch: true,
},
{
title: "Queries route, bad query",
route: new(Route).Queries("foo", "bar", "baz", "ding"),
request: newRequest("GET", "http://localhost?foo=bar&baz=dong"),
vars: map[string]string{},
host: "",
path: "",
shouldMatch: false,
},
{
title: "Queries route with pattern, match",
route: new(Route).Queries("foo", "{v1}"),
request: newRequest("GET", "http://localhost?foo=bar"),
vars: map[string]string{"v1": "bar"},
host: "",
path: "",
shouldMatch: true,
},
{
title: "Queries route with multiple patterns, match",
route: new(Route).Queries("foo", "{v1}", "baz", "{v2}"),
request: newRequest("GET", "http://localhost?foo=bar&baz=ding"),
vars: map[string]string{"v1": "bar", "v2": "ding"},
host: "",
path: "",
shouldMatch: true,
},
{
title: "Queries route with regexp pattern, match",
route: new(Route).Queries("foo", "{v1:[0-9]+}"),
request: newRequest("GET", "http://localhost?foo=10"),
vars: map[string]string{"v1": "10"},
host: "",
path: "",
shouldMatch: true,
},
{
title: "Queries route with regexp pattern, regexp does not match",
route: new(Route).Queries("foo", "{v1:[0-9]+}"),
request: newRequest("GET", "http://localhost?foo=a"),
vars: map[string]string{},
host: "",
path: "",
shouldMatch: false,
},
}
for _, test := range tests {
testRoute(t, test)
}
}
func TestSchemes(t *testing.T) {
tests := []routeTest{
// Schemes
{
title: "Schemes route, match https",
route: new(Route).Schemes("https", "ftp"),
request: newRequest("GET", "https://localhost"),
vars: map[string]string{},
host: "",
path: "",
shouldMatch: true,
},
{
title: "Schemes route, match ftp",
route: new(Route).Schemes("https", "ftp"),
request: newRequest("GET", "ftp://localhost"),
vars: map[string]string{},
host: "",
path: "",
shouldMatch: true,
},
{
title: "Schemes route, bad scheme",
route: new(Route).Schemes("https", "ftp"),
request: newRequest("GET", "http://localhost"),
vars: map[string]string{},
host: "",
path: "",
shouldMatch: false,
},
}
for _, test := range tests {
testRoute(t, test)
}
}
func TestMatcherFunc(t *testing.T) {
m := func(r *http.Request, m *RouteMatch) bool {
if r.URL.Host == "aaa.bbb.ccc" {
return true
}
return false
}
tests := []routeTest{
{
title: "MatchFunc route, match",
route: new(Route).MatcherFunc(m),
request: newRequest("GET", "http://aaa.bbb.ccc"),
vars: map[string]string{},
host: "",
path: "",
shouldMatch: true,
},
{
title: "MatchFunc route, non-match",
route: new(Route).MatcherFunc(m),
request: newRequest("GET", "http://aaa.222.ccc"),
vars: map[string]string{},
host: "",
path: "",
shouldMatch: false,
},
}
for _, test := range tests {
testRoute(t, test)
}
}
func TestSubRouter(t *testing.T) {
subrouter1 := new(Route).Host("{v1:[a-z]+}.google.com").Subrouter()
subrouter2 := new(Route).PathPrefix("/foo/{v1}").Subrouter()
tests := []routeTest{
{
route: subrouter1.Path("/{v2:[a-z]+}"),
request: newRequest("GET", "http://aaa.google.com/bbb"),
vars: map[string]string{"v1": "aaa", "v2": "bbb"},
host: "aaa.google.com",
path: "/bbb",
shouldMatch: true,
},
{
route: subrouter1.Path("/{v2:[a-z]+}"),
request: newRequest("GET", "http://111.google.com/111"),
vars: map[string]string{"v1": "aaa", "v2": "bbb"},
host: "aaa.google.com",
path: "/bbb",
shouldMatch: false,
},
{
route: subrouter2.Path("/baz/{v2}"),
request: newRequest("GET", "http://localhost/foo/bar/baz/ding"),
vars: map[string]string{"v1": "bar", "v2": "ding"},
host: "",
path: "/foo/bar/baz/ding",
shouldMatch: true,
},
{
route: subrouter2.Path("/baz/{v2}"),
request: newRequest("GET", "http://localhost/foo/bar"),
vars: map[string]string{"v1": "bar", "v2": "ding"},
host: "",
path: "/foo/bar/baz/ding",
shouldMatch: false,
},
}
for _, test := range tests {
testRoute(t, test)
}
}
func TestNamedRoutes(t *testing.T) {
r1 := NewRouter()
r1.NewRoute().Name("a")
r1.NewRoute().Name("b")
r1.NewRoute().Name("c")
r2 := r1.NewRoute().Subrouter()
r2.NewRoute().Name("d")
r2.NewRoute().Name("e")
r2.NewRoute().Name("f")
r3 := r2.NewRoute().Subrouter()
r3.NewRoute().Name("g")
r3.NewRoute().Name("h")
r3.NewRoute().Name("i")
if r1.namedRoutes == nil || len(r1.namedRoutes) != 9 {
t.Errorf("Expected 9 named routes, got %v", r1.namedRoutes)
} else if r1.Get("i") == nil {
t.Errorf("Subroute name not registered")
}
}
func TestStrictSlash(t *testing.T) {
r := NewRouter()
r.StrictSlash(true)
tests := []routeTest{
{
title: "Redirect path without slash",
route: r.NewRoute().Path("/111/"),
request: newRequest("GET", "http://localhost/111"),
vars: map[string]string{},
host: "",
path: "/111/",
shouldMatch: true,
shouldRedirect: true,
},
{
title: "Do not redirect path with slash",
route: r.NewRoute().Path("/111/"),
request: newRequest("GET", "http://localhost/111/"),
vars: map[string]string{},
host: "",
path: "/111/",
shouldMatch: true,
shouldRedirect: false,
},
{
title: "Redirect path with slash",
route: r.NewRoute().Path("/111"),
request: newRequest("GET", "http://localhost/111/"),
vars: map[string]string{},
host: "",
path: "/111",
shouldMatch: true,
shouldRedirect: true,
},
{
title: "Do not redirect path without slash",
route: r.NewRoute().Path("/111"),
request: newRequest("GET", "http://localhost/111"),
vars: map[string]string{},
host: "",
path: "/111",
shouldMatch: true,
shouldRedirect: false,
},
{
title: "Propagate StrictSlash to subrouters",
route: r.NewRoute().PathPrefix("/static/").Subrouter().Path("/images/"),
request: newRequest("GET", "http://localhost/static/images"),
vars: map[string]string{},
host: "",
path: "/static/images/",
shouldMatch: true,
shouldRedirect: true,
},
{
title: "Ignore StrictSlash for path prefix",
route: r.NewRoute().PathPrefix("/static/"),
request: newRequest("GET", "http://localhost/static/logo.png"),
vars: map[string]string{},
host: "",
path: "/static/",
shouldMatch: true,
shouldRedirect: false,
},
}
for _, test := range tests {
testRoute(t, test)
}
}
// ----------------------------------------------------------------------------
// Helpers
// ----------------------------------------------------------------------------
func getRouteTemplate(route *Route) string {
host, path := "none", "none"
if route.regexp != nil {
if route.regexp.host != nil {
host = route.regexp.host.template
}
if route.regexp.path != nil {
path = route.regexp.path.template
}
}
return fmt.Sprintf("Host: %v, Path: %v", host, path)
}
func testRoute(t *testing.T, test routeTest) {
request := test.request
route := test.route
vars := test.vars
shouldMatch := test.shouldMatch
host := test.host
path := test.path
url := test.host + test.path
shouldRedirect := test.shouldRedirect
var match RouteMatch
ok := route.Match(request, &match)
if ok != shouldMatch {
msg := "Should match"
if !shouldMatch {
msg = "Should not match"
}
t.Errorf("(%v) %v:\nRoute: %#v\nRequest: %#v\nVars: %v\n", test.title, msg, route, request, vars)
return
}
if shouldMatch {
if test.vars != nil && !stringMapEqual(test.vars, match.Vars) {
t.Errorf("(%v) Vars not equal: expected %v, got %v", test.title, vars, match.Vars)
return
}
if host != "" {
u, _ := test.route.URLHost(mapToPairs(match.Vars)...)
if host != u.Host {
t.Errorf("(%v) URLHost not equal: expected %v, got %v -- %v", test.title, host, u.Host, getRouteTemplate(route))
return
}
}
if path != "" {
u, _ := route.URLPath(mapToPairs(match.Vars)...)
if path != u.Path {
t.Errorf("(%v) URLPath not equal: expected %v, got %v -- %v", test.title, path, u.Path, getRouteTemplate(route))
return
}
}
if url != "" {
u, _ := route.URL(mapToPairs(match.Vars)...)
if url != u.Host+u.Path {
t.Errorf("(%v) URL not equal: expected %v, got %v -- %v", test.title, url, u.Host+u.Path, getRouteTemplate(route))
return
}
}
if shouldRedirect && match.Handler == nil {
t.Errorf("(%v) Did not redirect", test.title)
return
}
if !shouldRedirect && match.Handler != nil {
t.Errorf("(%v) Unexpected redirect", test.title)
return
}
}
}
// Tests that the context is cleared or not cleared properly depending on
// the configuration of the router
func TestKeepContext(t *testing.T) {
func1 := func(w http.ResponseWriter, r *http.Request) {}
r := NewRouter()
r.HandleFunc("/", func1).Name("func1")
req, _ := http.NewRequest("GET", "http://localhost/", nil)
context.Set(req, "t", 1)
res := new(http.ResponseWriter)
r.ServeHTTP(*res, req)
if _, ok := context.GetOk(req, "t"); ok {
t.Error("Context should have been cleared at end of request")
}
r.KeepContext = true
req, _ = http.NewRequest("GET", "http://localhost/", nil)
context.Set(req, "t", 1)
r.ServeHTTP(*res, req)
if _, ok := context.GetOk(req, "t"); !ok {
t.Error("Context should NOT have been cleared at end of request")
}
}
type TestA301ResponseWriter struct {
hh http.Header
status int
}
func (ho TestA301ResponseWriter) Header() http.Header {
return http.Header(ho.hh)
}
func (ho TestA301ResponseWriter) Write(b []byte) (int, error) {
return 0, nil
}
func (ho TestA301ResponseWriter) WriteHeader(code int) {
ho.status = code
}
func Test301Redirect(t *testing.T) {
m := make(http.Header)
func1 := func(w http.ResponseWriter, r *http.Request) {}
func2 := func(w http.ResponseWriter, r *http.Request) {}
r := NewRouter()
r.HandleFunc("/api/", func2).Name("func2")
r.HandleFunc("/", func1).Name("func1")
req, _ := http.NewRequest("GET", "http://localhost//api/?abc=def", nil)
res := TestA301ResponseWriter{
hh: m,
status: 0,
}
r.ServeHTTP(&res, req)
if "http://localhost/api/?abc=def" != res.hh["Location"][0] {
t.Errorf("Should have complete URL with query string")
}
}
// https://plus.google.com/101022900381697718949/posts/eWy6DjFJ6uW
func TestSubrouterHeader(t *testing.T) {
expected := "func1 response"
func1 := func(w http.ResponseWriter, r *http.Request) {
fmt.Fprint(w, expected)
}
func2 := func(http.ResponseWriter, *http.Request) {}
r := NewRouter()
s := r.Headers("SomeSpecialHeader", "").Subrouter()
s.HandleFunc("/", func1).Name("func1")
r.HandleFunc("/", func2).Name("func2")
req, _ := http.NewRequest("GET", "http://localhost/", nil)
req.Header.Add("SomeSpecialHeader", "foo")
match := new(RouteMatch)
matched := r.Match(req, match)
if !matched {
t.Errorf("Should match request")
}
if match.Route.GetName() != "func1" {
t.Errorf("Expecting func1 handler, got %s", match.Route.GetName())
}
resp := NewRecorder()
match.Handler.ServeHTTP(resp, req)
if resp.Body.String() != expected {
t.Errorf("Expecting %q", expected)
}
}
// mapToPairs converts a string map to a slice of string pairs
func mapToPairs(m map[string]string) []string {
var i int
p := make([]string, len(m)*2)
for k, v := range m {
p[i] = k
p[i+1] = v
i += 2
}
return p
}
// stringMapEqual checks the equality of two string maps
func stringMapEqual(m1, m2 map[string]string) bool {
nil1 := m1 == nil
nil2 := m2 == nil
if nil1 != nil2 || len(m1) != len(m2) {
return false
}
for k, v := range m1 {
if v != m2[k] {
return false
}
}
return true
}
// newRequest is a helper function to create a new request with a method and url
func newRequest(method, url string) *http.Request {
req, err := http.NewRequest(method, url, nil)
if err != nil {
panic(err)
}
return req
}

View file

@ -0,0 +1,714 @@
// Old tests ported to Go1. This is a mess. Want to drop it one day.
// Copyright 2011 Gorilla Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package mux
import (
"bytes"
"net/http"
"testing"
)
// ----------------------------------------------------------------------------
// ResponseRecorder
// ----------------------------------------------------------------------------
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// ResponseRecorder is an implementation of http.ResponseWriter that
// records its mutations for later inspection in tests.
type ResponseRecorder struct {
Code int // the HTTP response code from WriteHeader
HeaderMap http.Header // the HTTP response headers
Body *bytes.Buffer // if non-nil, the bytes.Buffer to append written data to
Flushed bool
}
// NewRecorder returns an initialized ResponseRecorder.
func NewRecorder() *ResponseRecorder {
return &ResponseRecorder{
HeaderMap: make(http.Header),
Body: new(bytes.Buffer),
}
}
// DefaultRemoteAddr is the default remote address to return in RemoteAddr if
// an explicit DefaultRemoteAddr isn't set on ResponseRecorder.
const DefaultRemoteAddr = "1.2.3.4"
// Header returns the response headers.
func (rw *ResponseRecorder) Header() http.Header {
return rw.HeaderMap
}
// Write always succeeds and writes to rw.Body, if not nil.
func (rw *ResponseRecorder) Write(buf []byte) (int, error) {
if rw.Body != nil {
rw.Body.Write(buf)
}
if rw.Code == 0 {
rw.Code = http.StatusOK
}
return len(buf), nil
}
// WriteHeader sets rw.Code.
func (rw *ResponseRecorder) WriteHeader(code int) {
rw.Code = code
}
// Flush sets rw.Flushed to true.
func (rw *ResponseRecorder) Flush() {
rw.Flushed = true
}
// ----------------------------------------------------------------------------
func TestRouteMatchers(t *testing.T) {
var scheme, host, path, query, method string
var headers map[string]string
var resultVars map[bool]map[string]string
router := NewRouter()
router.NewRoute().Host("{var1}.google.com").
Path("/{var2:[a-z]+}/{var3:[0-9]+}").
Queries("foo", "bar").
Methods("GET").
Schemes("https").
Headers("x-requested-with", "XMLHttpRequest")
router.NewRoute().Host("www.{var4}.com").
PathPrefix("/foo/{var5:[a-z]+}/{var6:[0-9]+}").
Queries("baz", "ding").
Methods("POST").
Schemes("http").
Headers("Content-Type", "application/json")
reset := func() {
// Everything match.
scheme = "https"
host = "www.google.com"
path = "/product/42"
query = "?foo=bar"
method = "GET"
headers = map[string]string{"X-Requested-With": "XMLHttpRequest"}
resultVars = map[bool]map[string]string{
true: {"var1": "www", "var2": "product", "var3": "42"},
false: {},
}
}
reset2 := func() {
// Everything match.
scheme = "http"
host = "www.google.com"
path = "/foo/product/42/path/that/is/ignored"
query = "?baz=ding"
method = "POST"
headers = map[string]string{"Content-Type": "application/json"}
resultVars = map[bool]map[string]string{
true: {"var4": "google", "var5": "product", "var6": "42"},
false: {},
}
}
match := func(shouldMatch bool) {
url := scheme + "://" + host + path + query
request, _ := http.NewRequest(method, url, nil)
for key, value := range headers {
request.Header.Add(key, value)
}
var routeMatch RouteMatch
matched := router.Match(request, &routeMatch)
if matched != shouldMatch {
// Need better messages. :)
if matched {
t.Errorf("Should match.")
} else {
t.Errorf("Should not match.")
}
}
if matched {
currentRoute := routeMatch.Route
if currentRoute == nil {
t.Errorf("Expected a current route.")
}
vars := routeMatch.Vars
expectedVars := resultVars[shouldMatch]
if len(vars) != len(expectedVars) {
t.Errorf("Expected vars: %v Got: %v.", expectedVars, vars)
}
for name, value := range vars {
if expectedVars[name] != value {
t.Errorf("Expected vars: %v Got: %v.", expectedVars, vars)
}
}
}
}
// 1st route --------------------------------------------------------------
// Everything match.
reset()
match(true)
// Scheme doesn't match.
reset()
scheme = "http"
match(false)
// Host doesn't match.
reset()
host = "www.mygoogle.com"
match(false)
// Path doesn't match.
reset()
path = "/product/notdigits"
match(false)
// Query doesn't match.
reset()
query = "?foo=baz"
match(false)
// Method doesn't match.
reset()
method = "POST"
match(false)
// Header doesn't match.
reset()
headers = map[string]string{}
match(false)
// Everything match, again.
reset()
match(true)
// 2nd route --------------------------------------------------------------
// Everything match.
reset2()
match(true)
// Scheme doesn't match.
reset2()
scheme = "https"
match(false)
// Host doesn't match.
reset2()
host = "sub.google.com"
match(false)
// Path doesn't match.
reset2()
path = "/bar/product/42"
match(false)
// Query doesn't match.
reset2()
query = "?foo=baz"
match(false)
// Method doesn't match.
reset2()
method = "GET"
match(false)
// Header doesn't match.
reset2()
headers = map[string]string{}
match(false)
// Everything match, again.
reset2()
match(true)
}
type headerMatcherTest struct {
matcher headerMatcher
headers map[string]string
result bool
}
var headerMatcherTests = []headerMatcherTest{
{
matcher: headerMatcher(map[string]string{"x-requested-with": "XMLHttpRequest"}),
headers: map[string]string{"X-Requested-With": "XMLHttpRequest"},
result: true,
},
{
matcher: headerMatcher(map[string]string{"x-requested-with": ""}),
headers: map[string]string{"X-Requested-With": "anything"},
result: true,
},
{
matcher: headerMatcher(map[string]string{"x-requested-with": "XMLHttpRequest"}),
headers: map[string]string{},
result: false,
},
}
type hostMatcherTest struct {
matcher *Route
url string
vars map[string]string
result bool
}
var hostMatcherTests = []hostMatcherTest{
{
matcher: NewRouter().NewRoute().Host("{foo:[a-z][a-z][a-z]}.{bar:[a-z][a-z][a-z]}.{baz:[a-z][a-z][a-z]}"),
url: "http://abc.def.ghi/",
vars: map[string]string{"foo": "abc", "bar": "def", "baz": "ghi"},
result: true,
},
{
matcher: NewRouter().NewRoute().Host("{foo:[a-z][a-z][a-z]}.{bar:[a-z][a-z][a-z]}.{baz:[a-z][a-z][a-z]}"),
url: "http://a.b.c/",
vars: map[string]string{"foo": "abc", "bar": "def", "baz": "ghi"},
result: false,
},
}
type methodMatcherTest struct {
matcher methodMatcher
method string
result bool
}
var methodMatcherTests = []methodMatcherTest{
{
matcher: methodMatcher([]string{"GET", "POST", "PUT"}),
method: "GET",
result: true,
},
{
matcher: methodMatcher([]string{"GET", "POST", "PUT"}),
method: "POST",
result: true,
},
{
matcher: methodMatcher([]string{"GET", "POST", "PUT"}),
method: "PUT",
result: true,
},
{
matcher: methodMatcher([]string{"GET", "POST", "PUT"}),
method: "DELETE",
result: false,
},
}
type pathMatcherTest struct {
matcher *Route
url string
vars map[string]string
result bool
}
var pathMatcherTests = []pathMatcherTest{
{
matcher: NewRouter().NewRoute().Path("/{foo:[0-9][0-9][0-9]}/{bar:[0-9][0-9][0-9]}/{baz:[0-9][0-9][0-9]}"),
url: "http://localhost:8080/123/456/789",
vars: map[string]string{"foo": "123", "bar": "456", "baz": "789"},
result: true,
},
{
matcher: NewRouter().NewRoute().Path("/{foo:[0-9][0-9][0-9]}/{bar:[0-9][0-9][0-9]}/{baz:[0-9][0-9][0-9]}"),
url: "http://localhost:8080/1/2/3",
vars: map[string]string{"foo": "123", "bar": "456", "baz": "789"},
result: false,
},
}
type schemeMatcherTest struct {
matcher schemeMatcher
url string
result bool
}
var schemeMatcherTests = []schemeMatcherTest{
{
matcher: schemeMatcher([]string{"http", "https"}),
url: "http://localhost:8080/",
result: true,
},
{
matcher: schemeMatcher([]string{"http", "https"}),
url: "https://localhost:8080/",
result: true,
},
{
matcher: schemeMatcher([]string{"https"}),
url: "http://localhost:8080/",
result: false,
},
{
matcher: schemeMatcher([]string{"http"}),
url: "https://localhost:8080/",
result: false,
},
}
type urlBuildingTest struct {
route *Route
vars []string
url string
}
var urlBuildingTests = []urlBuildingTest{
{
route: new(Route).Host("foo.domain.com"),
vars: []string{},
url: "http://foo.domain.com",
},
{
route: new(Route).Host("{subdomain}.domain.com"),
vars: []string{"subdomain", "bar"},
url: "http://bar.domain.com",
},
{
route: new(Route).Host("foo.domain.com").Path("/articles"),
vars: []string{},
url: "http://foo.domain.com/articles",
},
{
route: new(Route).Path("/articles"),
vars: []string{},
url: "/articles",
},
{
route: new(Route).Path("/articles/{category}/{id:[0-9]+}"),
vars: []string{"category", "technology", "id", "42"},
url: "/articles/technology/42",
},
{
route: new(Route).Host("{subdomain}.domain.com").Path("/articles/{category}/{id:[0-9]+}"),
vars: []string{"subdomain", "foo", "category", "technology", "id", "42"},
url: "http://foo.domain.com/articles/technology/42",
},
}
func TestHeaderMatcher(t *testing.T) {
for _, v := range headerMatcherTests {
request, _ := http.NewRequest("GET", "http://localhost:8080/", nil)
for key, value := range v.headers {
request.Header.Add(key, value)
}
var routeMatch RouteMatch
result := v.matcher.Match(request, &routeMatch)
if result != v.result {
if v.result {
t.Errorf("%#v: should match %v.", v.matcher, request.Header)
} else {
t.Errorf("%#v: should not match %v.", v.matcher, request.Header)
}
}
}
}
func TestHostMatcher(t *testing.T) {
for _, v := range hostMatcherTests {
request, _ := http.NewRequest("GET", v.url, nil)
var routeMatch RouteMatch
result := v.matcher.Match(request, &routeMatch)
vars := routeMatch.Vars
if result != v.result {
if v.result {
t.Errorf("%#v: should match %v.", v.matcher, v.url)
} else {
t.Errorf("%#v: should not match %v.", v.matcher, v.url)
}
}
if result {
if len(vars) != len(v.vars) {
t.Errorf("%#v: vars length should be %v, got %v.", v.matcher, len(v.vars), len(vars))
}
for name, value := range vars {
if v.vars[name] != value {
t.Errorf("%#v: expected value %v for key %v, got %v.", v.matcher, v.vars[name], name, value)
}
}
} else {
if len(vars) != 0 {
t.Errorf("%#v: vars length should be 0, got %v.", v.matcher, len(vars))
}
}
}
}
func TestMethodMatcher(t *testing.T) {
for _, v := range methodMatcherTests {
request, _ := http.NewRequest(v.method, "http://localhost:8080/", nil)
var routeMatch RouteMatch
result := v.matcher.Match(request, &routeMatch)
if result != v.result {
if v.result {
t.Errorf("%#v: should match %v.", v.matcher, v.method)
} else {
t.Errorf("%#v: should not match %v.", v.matcher, v.method)
}
}
}
}
func TestPathMatcher(t *testing.T) {
for _, v := range pathMatcherTests {
request, _ := http.NewRequest("GET", v.url, nil)
var routeMatch RouteMatch
result := v.matcher.Match(request, &routeMatch)
vars := routeMatch.Vars
if result != v.result {
if v.result {
t.Errorf("%#v: should match %v.", v.matcher, v.url)
} else {
t.Errorf("%#v: should not match %v.", v.matcher, v.url)
}
}
if result {
if len(vars) != len(v.vars) {
t.Errorf("%#v: vars length should be %v, got %v.", v.matcher, len(v.vars), len(vars))
}
for name, value := range vars {
if v.vars[name] != value {
t.Errorf("%#v: expected value %v for key %v, got %v.", v.matcher, v.vars[name], name, value)
}
}
} else {
if len(vars) != 0 {
t.Errorf("%#v: vars length should be 0, got %v.", v.matcher, len(vars))
}
}
}
}
func TestSchemeMatcher(t *testing.T) {
for _, v := range schemeMatcherTests {
request, _ := http.NewRequest("GET", v.url, nil)
var routeMatch RouteMatch
result := v.matcher.Match(request, &routeMatch)
if result != v.result {
if v.result {
t.Errorf("%#v: should match %v.", v.matcher, v.url)
} else {
t.Errorf("%#v: should not match %v.", v.matcher, v.url)
}
}
}
}
func TestUrlBuilding(t *testing.T) {
for _, v := range urlBuildingTests {
u, _ := v.route.URL(v.vars...)
url := u.String()
if url != v.url {
t.Errorf("expected %v, got %v", v.url, url)
/*
reversePath := ""
reverseHost := ""
if v.route.pathTemplate != nil {
reversePath = v.route.pathTemplate.Reverse
}
if v.route.hostTemplate != nil {
reverseHost = v.route.hostTemplate.Reverse
}
t.Errorf("%#v:\nexpected: %q\ngot: %q\nreverse path: %q\nreverse host: %q", v.route, v.url, url, reversePath, reverseHost)
*/
}
}
ArticleHandler := func(w http.ResponseWriter, r *http.Request) {
}
router := NewRouter()
router.HandleFunc("/articles/{category}/{id:[0-9]+}", ArticleHandler).Name("article")
url, _ := router.Get("article").URL("category", "technology", "id", "42")
expected := "/articles/technology/42"
if url.String() != expected {
t.Errorf("Expected %v, got %v", expected, url.String())
}
}
func TestMatchedRouteName(t *testing.T) {
routeName := "stock"
router := NewRouter()
route := router.NewRoute().Path("/products/").Name(routeName)
url := "http://www.domain.com/products/"
request, _ := http.NewRequest("GET", url, nil)
var rv RouteMatch
ok := router.Match(request, &rv)
if !ok || rv.Route != route {
t.Errorf("Expected same route, got %+v.", rv.Route)
}
retName := rv.Route.GetName()
if retName != routeName {
t.Errorf("Expected %q, got %q.", routeName, retName)
}
}
func TestSubRouting(t *testing.T) {
// Example from docs.
router := NewRouter()
subrouter := router.NewRoute().Host("www.domain.com").Subrouter()
route := subrouter.NewRoute().Path("/products/").Name("products")
url := "http://www.domain.com/products/"
request, _ := http.NewRequest("GET", url, nil)
var rv RouteMatch
ok := router.Match(request, &rv)
if !ok || rv.Route != route {
t.Errorf("Expected same route, got %+v.", rv.Route)
}
u, _ := router.Get("products").URL()
builtUrl := u.String()
// Yay, subroute aware of the domain when building!
if builtUrl != url {
t.Errorf("Expected %q, got %q.", url, builtUrl)
}
}
func TestVariableNames(t *testing.T) {
route := new(Route).Host("{arg1}.domain.com").Path("/{arg1}/{arg2:[0-9]+}")
if route.err == nil {
t.Errorf("Expected error for duplicated variable names")
}
}
func TestRedirectSlash(t *testing.T) {
var route *Route
var routeMatch RouteMatch
r := NewRouter()
r.StrictSlash(false)
route = r.NewRoute()
if route.strictSlash != false {
t.Errorf("Expected false redirectSlash.")
}
r.StrictSlash(true)
route = r.NewRoute()
if route.strictSlash != true {
t.Errorf("Expected true redirectSlash.")
}
route = new(Route)
route.strictSlash = true
route.Path("/{arg1}/{arg2:[0-9]+}/")
request, _ := http.NewRequest("GET", "http://localhost/foo/123", nil)
routeMatch = RouteMatch{}
_ = route.Match(request, &routeMatch)
vars := routeMatch.Vars
if vars["arg1"] != "foo" {
t.Errorf("Expected foo.")
}
if vars["arg2"] != "123" {
t.Errorf("Expected 123.")
}
rsp := NewRecorder()
routeMatch.Handler.ServeHTTP(rsp, request)
if rsp.HeaderMap.Get("Location") != "http://localhost/foo/123/" {
t.Errorf("Expected redirect header.")
}
route = new(Route)
route.strictSlash = true
route.Path("/{arg1}/{arg2:[0-9]+}")
request, _ = http.NewRequest("GET", "http://localhost/foo/123/", nil)
routeMatch = RouteMatch{}
_ = route.Match(request, &routeMatch)
vars = routeMatch.Vars
if vars["arg1"] != "foo" {
t.Errorf("Expected foo.")
}
if vars["arg2"] != "123" {
t.Errorf("Expected 123.")
}
rsp = NewRecorder()
routeMatch.Handler.ServeHTTP(rsp, request)
if rsp.HeaderMap.Get("Location") != "http://localhost/foo/123" {
t.Errorf("Expected redirect header.")
}
}
// Test for the new regexp library, still not available in stable Go.
func TestNewRegexp(t *testing.T) {
var p *routeRegexp
var matches []string
tests := map[string]map[string][]string{
"/{foo:a{2}}": {
"/a": nil,
"/aa": {"aa"},
"/aaa": nil,
"/aaaa": nil,
},
"/{foo:a{2,}}": {
"/a": nil,
"/aa": {"aa"},
"/aaa": {"aaa"},
"/aaaa": {"aaaa"},
},
"/{foo:a{2,3}}": {
"/a": nil,
"/aa": {"aa"},
"/aaa": {"aaa"},
"/aaaa": nil,
},
"/{foo:[a-z]{3}}/{bar:[a-z]{2}}": {
"/a": nil,
"/ab": nil,
"/abc": nil,
"/abcd": nil,
"/abc/ab": {"abc", "ab"},
"/abc/abc": nil,
"/abcd/ab": nil,
},
`/{foo:\w{3,}}/{bar:\d{2,}}`: {
"/a": nil,
"/ab": nil,
"/abc": nil,
"/abc/1": nil,
"/abc/12": {"abc", "12"},
"/abcd/12": {"abcd", "12"},
"/abcd/123": {"abcd", "123"},
},
}
for pattern, paths := range tests {
p, _ = newRouteRegexp(pattern, false, false, false, false)
for path, result := range paths {
matches = p.regexp.FindStringSubmatch(path)
if result == nil {
if matches != nil {
t.Errorf("%v should not match %v.", pattern, path)
}
} else {
if len(matches) != len(result)+1 {
t.Errorf("Expected %v matches, got %v.", len(result)+1, len(matches))
} else {
for k, v := range result {
if matches[k+1] != v {
t.Errorf("Expected %v, got %v.", v, matches[k+1])
}
}
}
}
}
}
}

276
Godeps/_workspace/src/github.com/gorilla/mux/regexp.go generated vendored Normal file
View file

@ -0,0 +1,276 @@
// Copyright 2012 The Gorilla Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package mux
import (
"bytes"
"fmt"
"net/http"
"net/url"
"regexp"
"strings"
)
// newRouteRegexp parses a route template and returns a routeRegexp,
// used to match a host, a path or a query string.
//
// It will extract named variables, assemble a regexp to be matched, create
// a "reverse" template to build URLs and compile regexps to validate variable
// values used in URL building.
//
// Previously we accepted only Python-like identifiers for variable
// names ([a-zA-Z_][a-zA-Z0-9_]*), but currently the only restriction is that
// name and pattern can't be empty, and names can't contain a colon.
func newRouteRegexp(tpl string, matchHost, matchPrefix, matchQuery, strictSlash bool) (*routeRegexp, error) {
// Check if it is well-formed.
idxs, errBraces := braceIndices(tpl)
if errBraces != nil {
return nil, errBraces
}
// Backup the original.
template := tpl
// Now let's parse it.
defaultPattern := "[^/]+"
if matchQuery {
defaultPattern = "[^?&]+"
matchPrefix = true
} else if matchHost {
defaultPattern = "[^.]+"
matchPrefix = false
}
// Only match strict slash if not matching
if matchPrefix || matchHost || matchQuery {
strictSlash = false
}
// Set a flag for strictSlash.
endSlash := false
if strictSlash && strings.HasSuffix(tpl, "/") {
tpl = tpl[:len(tpl)-1]
endSlash = true
}
varsN := make([]string, len(idxs)/2)
varsR := make([]*regexp.Regexp, len(idxs)/2)
pattern := bytes.NewBufferString("")
if !matchQuery {
pattern.WriteByte('^')
}
reverse := bytes.NewBufferString("")
var end int
var err error
for i := 0; i < len(idxs); i += 2 {
// Set all values we are interested in.
raw := tpl[end:idxs[i]]
end = idxs[i+1]
parts := strings.SplitN(tpl[idxs[i]+1:end-1], ":", 2)
name := parts[0]
patt := defaultPattern
if len(parts) == 2 {
patt = parts[1]
}
// Name or pattern can't be empty.
if name == "" || patt == "" {
return nil, fmt.Errorf("mux: missing name or pattern in %q",
tpl[idxs[i]:end])
}
// Build the regexp pattern.
fmt.Fprintf(pattern, "%s(%s)", regexp.QuoteMeta(raw), patt)
// Build the reverse template.
fmt.Fprintf(reverse, "%s%%s", raw)
// Append variable name and compiled pattern.
varsN[i/2] = name
varsR[i/2], err = regexp.Compile(fmt.Sprintf("^%s$", patt))
if err != nil {
return nil, err
}
}
// Add the remaining.
raw := tpl[end:]
pattern.WriteString(regexp.QuoteMeta(raw))
if strictSlash {
pattern.WriteString("[/]?")
}
if !matchPrefix {
pattern.WriteByte('$')
}
reverse.WriteString(raw)
if endSlash {
reverse.WriteByte('/')
}
// Compile full regexp.
reg, errCompile := regexp.Compile(pattern.String())
if errCompile != nil {
return nil, errCompile
}
// Done!
return &routeRegexp{
template: template,
matchHost: matchHost,
matchQuery: matchQuery,
strictSlash: strictSlash,
regexp: reg,
reverse: reverse.String(),
varsN: varsN,
varsR: varsR,
}, nil
}
// routeRegexp stores a regexp to match a host or path and information to
// collect and validate route variables.
type routeRegexp struct {
// The unmodified template.
template string
// True for host match, false for path or query string match.
matchHost bool
// True for query string match, false for path and host match.
matchQuery bool
// The strictSlash value defined on the route, but disabled if PathPrefix was used.
strictSlash bool
// Expanded regexp.
regexp *regexp.Regexp
// Reverse template.
reverse string
// Variable names.
varsN []string
// Variable regexps (validators).
varsR []*regexp.Regexp
}
// Match matches the regexp against the URL host or path.
func (r *routeRegexp) Match(req *http.Request, match *RouteMatch) bool {
if !r.matchHost {
if r.matchQuery {
return r.regexp.MatchString(req.URL.RawQuery)
} else {
return r.regexp.MatchString(req.URL.Path)
}
}
return r.regexp.MatchString(getHost(req))
}
// url builds a URL part using the given values.
func (r *routeRegexp) url(pairs ...string) (string, error) {
values, err := mapFromPairs(pairs...)
if err != nil {
return "", err
}
urlValues := make([]interface{}, len(r.varsN))
for k, v := range r.varsN {
value, ok := values[v]
if !ok {
return "", fmt.Errorf("mux: missing route variable %q", v)
}
urlValues[k] = value
}
rv := fmt.Sprintf(r.reverse, urlValues...)
if !r.regexp.MatchString(rv) {
// The URL is checked against the full regexp, instead of checking
// individual variables. This is faster but to provide a good error
// message, we check individual regexps if the URL doesn't match.
for k, v := range r.varsN {
if !r.varsR[k].MatchString(values[v]) {
return "", fmt.Errorf(
"mux: variable %q doesn't match, expected %q", values[v],
r.varsR[k].String())
}
}
}
return rv, nil
}
// braceIndices returns the first level curly brace indices from a string.
// It returns an error in case of unbalanced braces.
func braceIndices(s string) ([]int, error) {
var level, idx int
idxs := make([]int, 0)
for i := 0; i < len(s); i++ {
switch s[i] {
case '{':
if level++; level == 1 {
idx = i
}
case '}':
if level--; level == 0 {
idxs = append(idxs, idx, i+1)
} else if level < 0 {
return nil, fmt.Errorf("mux: unbalanced braces in %q", s)
}
}
}
if level != 0 {
return nil, fmt.Errorf("mux: unbalanced braces in %q", s)
}
return idxs, nil
}
// ----------------------------------------------------------------------------
// routeRegexpGroup
// ----------------------------------------------------------------------------
// routeRegexpGroup groups the route matchers that carry variables.
type routeRegexpGroup struct {
host *routeRegexp
path *routeRegexp
queries []*routeRegexp
}
// setMatch extracts the variables from the URL once a route matches.
func (v *routeRegexpGroup) setMatch(req *http.Request, m *RouteMatch, r *Route) {
// Store host variables.
if v.host != nil {
hostVars := v.host.regexp.FindStringSubmatch(getHost(req))
if hostVars != nil {
for k, v := range v.host.varsN {
m.Vars[v] = hostVars[k+1]
}
}
}
// Store path variables.
if v.path != nil {
pathVars := v.path.regexp.FindStringSubmatch(req.URL.Path)
if pathVars != nil {
for k, v := range v.path.varsN {
m.Vars[v] = pathVars[k+1]
}
// Check if we should redirect.
if v.path.strictSlash {
p1 := strings.HasSuffix(req.URL.Path, "/")
p2 := strings.HasSuffix(v.path.template, "/")
if p1 != p2 {
u, _ := url.Parse(req.URL.String())
if p1 {
u.Path = u.Path[:len(u.Path)-1]
} else {
u.Path += "/"
}
m.Handler = http.RedirectHandler(u.String(), 301)
}
}
}
}
// Store query string variables.
rawQuery := req.URL.RawQuery
for _, q := range v.queries {
queryVars := q.regexp.FindStringSubmatch(rawQuery)
if queryVars != nil {
for k, v := range q.varsN {
m.Vars[v] = queryVars[k+1]
}
}
}
}
// getHost tries its best to return the request host.
func getHost(r *http.Request) string {
if r.URL.IsAbs() {
return r.URL.Host
}
host := r.Host
// Slice off any port information.
if i := strings.Index(host, ":"); i != -1 {
host = host[:i]
}
return host
}

524
Godeps/_workspace/src/github.com/gorilla/mux/route.go generated vendored Normal file
View file

@ -0,0 +1,524 @@
// Copyright 2012 The Gorilla Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package mux
import (
"errors"
"fmt"
"net/http"
"net/url"
"strings"
)
// Route stores information to match a request and build URLs.
type Route struct {
// Parent where the route was registered (a Router).
parent parentRoute
// Request handler for the route.
handler http.Handler
// List of matchers.
matchers []matcher
// Manager for the variables from host and path.
regexp *routeRegexpGroup
// If true, when the path pattern is "/path/", accessing "/path" will
// redirect to the former and vice versa.
strictSlash bool
// If true, this route never matches: it is only used to build URLs.
buildOnly bool
// The name used to build URLs.
name string
// Error resulted from building a route.
err error
}
// Match matches the route against the request.
func (r *Route) Match(req *http.Request, match *RouteMatch) bool {
if r.buildOnly || r.err != nil {
return false
}
// Match everything.
for _, m := range r.matchers {
if matched := m.Match(req, match); !matched {
return false
}
}
// Yay, we have a match. Let's collect some info about it.
if match.Route == nil {
match.Route = r
}
if match.Handler == nil {
match.Handler = r.handler
}
if match.Vars == nil {
match.Vars = make(map[string]string)
}
// Set variables.
if r.regexp != nil {
r.regexp.setMatch(req, match, r)
}
return true
}
// ----------------------------------------------------------------------------
// Route attributes
// ----------------------------------------------------------------------------
// GetError returns an error resulted from building the route, if any.
func (r *Route) GetError() error {
return r.err
}
// BuildOnly sets the route to never match: it is only used to build URLs.
func (r *Route) BuildOnly() *Route {
r.buildOnly = true
return r
}
// Handler --------------------------------------------------------------------
// Handler sets a handler for the route.
func (r *Route) Handler(handler http.Handler) *Route {
if r.err == nil {
r.handler = handler
}
return r
}
// HandlerFunc sets a handler function for the route.
func (r *Route) HandlerFunc(f func(http.ResponseWriter, *http.Request)) *Route {
return r.Handler(http.HandlerFunc(f))
}
// GetHandler returns the handler for the route, if any.
func (r *Route) GetHandler() http.Handler {
return r.handler
}
// Name -----------------------------------------------------------------------
// Name sets the name for the route, used to build URLs.
// If the name was registered already it will be overwritten.
func (r *Route) Name(name string) *Route {
if r.name != "" {
r.err = fmt.Errorf("mux: route already has name %q, can't set %q",
r.name, name)
}
if r.err == nil {
r.name = name
r.getNamedRoutes()[name] = r
}
return r
}
// GetName returns the name for the route, if any.
func (r *Route) GetName() string {
return r.name
}
// ----------------------------------------------------------------------------
// Matchers
// ----------------------------------------------------------------------------
// matcher types try to match a request.
type matcher interface {
Match(*http.Request, *RouteMatch) bool
}
// addMatcher adds a matcher to the route.
func (r *Route) addMatcher(m matcher) *Route {
if r.err == nil {
r.matchers = append(r.matchers, m)
}
return r
}
// addRegexpMatcher adds a host or path matcher and builder to a route.
func (r *Route) addRegexpMatcher(tpl string, matchHost, matchPrefix, matchQuery bool) error {
if r.err != nil {
return r.err
}
r.regexp = r.getRegexpGroup()
if !matchHost && !matchQuery {
if len(tpl) == 0 || tpl[0] != '/' {
return fmt.Errorf("mux: path must start with a slash, got %q", tpl)
}
if r.regexp.path != nil {
tpl = strings.TrimRight(r.regexp.path.template, "/") + tpl
}
}
rr, err := newRouteRegexp(tpl, matchHost, matchPrefix, matchQuery, r.strictSlash)
if err != nil {
return err
}
for _, q := range r.regexp.queries {
if err = uniqueVars(rr.varsN, q.varsN); err != nil {
return err
}
}
if matchHost {
if r.regexp.path != nil {
if err = uniqueVars(rr.varsN, r.regexp.path.varsN); err != nil {
return err
}
}
r.regexp.host = rr
} else {
if r.regexp.host != nil {
if err = uniqueVars(rr.varsN, r.regexp.host.varsN); err != nil {
return err
}
}
if matchQuery {
r.regexp.queries = append(r.regexp.queries, rr)
} else {
r.regexp.path = rr
}
}
r.addMatcher(rr)
return nil
}
// Headers --------------------------------------------------------------------
// headerMatcher matches the request against header values.
type headerMatcher map[string]string
func (m headerMatcher) Match(r *http.Request, match *RouteMatch) bool {
return matchMap(m, r.Header, true)
}
// Headers adds a matcher for request header values.
// It accepts a sequence of key/value pairs to be matched. For example:
//
// r := mux.NewRouter()
// r.Headers("Content-Type", "application/json",
// "X-Requested-With", "XMLHttpRequest")
//
// The above route will only match if both request header values match.
//
// It the value is an empty string, it will match any value if the key is set.
func (r *Route) Headers(pairs ...string) *Route {
if r.err == nil {
var headers map[string]string
headers, r.err = mapFromPairs(pairs...)
return r.addMatcher(headerMatcher(headers))
}
return r
}
// Host -----------------------------------------------------------------------
// Host adds a matcher for the URL host.
// It accepts a template with zero or more URL variables enclosed by {}.
// Variables can define an optional regexp pattern to me matched:
//
// - {name} matches anything until the next dot.
//
// - {name:pattern} matches the given regexp pattern.
//
// For example:
//
// r := mux.NewRouter()
// r.Host("www.domain.com")
// r.Host("{subdomain}.domain.com")
// r.Host("{subdomain:[a-z]+}.domain.com")
//
// Variable names must be unique in a given route. They can be retrieved
// calling mux.Vars(request).
func (r *Route) Host(tpl string) *Route {
r.err = r.addRegexpMatcher(tpl, true, false, false)
return r
}
// MatcherFunc ----------------------------------------------------------------
// MatcherFunc is the function signature used by custom matchers.
type MatcherFunc func(*http.Request, *RouteMatch) bool
func (m MatcherFunc) Match(r *http.Request, match *RouteMatch) bool {
return m(r, match)
}
// MatcherFunc adds a custom function to be used as request matcher.
func (r *Route) MatcherFunc(f MatcherFunc) *Route {
return r.addMatcher(f)
}
// Methods --------------------------------------------------------------------
// methodMatcher matches the request against HTTP methods.
type methodMatcher []string
func (m methodMatcher) Match(r *http.Request, match *RouteMatch) bool {
return matchInArray(m, r.Method)
}
// Methods adds a matcher for HTTP methods.
// It accepts a sequence of one or more methods to be matched, e.g.:
// "GET", "POST", "PUT".
func (r *Route) Methods(methods ...string) *Route {
for k, v := range methods {
methods[k] = strings.ToUpper(v)
}
return r.addMatcher(methodMatcher(methods))
}
// Path -----------------------------------------------------------------------
// Path adds a matcher for the URL path.
// It accepts a template with zero or more URL variables enclosed by {}. The
// template must start with a "/".
// Variables can define an optional regexp pattern to me matched:
//
// - {name} matches anything until the next slash.
//
// - {name:pattern} matches the given regexp pattern.
//
// For example:
//
// r := mux.NewRouter()
// r.Path("/products/").Handler(ProductsHandler)
// r.Path("/products/{key}").Handler(ProductsHandler)
// r.Path("/articles/{category}/{id:[0-9]+}").
// Handler(ArticleHandler)
//
// Variable names must be unique in a given route. They can be retrieved
// calling mux.Vars(request).
func (r *Route) Path(tpl string) *Route {
r.err = r.addRegexpMatcher(tpl, false, false, false)
return r
}
// PathPrefix -----------------------------------------------------------------
// PathPrefix adds a matcher for the URL path prefix. This matches if the given
// template is a prefix of the full URL path. See Route.Path() for details on
// the tpl argument.
//
// Note that it does not treat slashes specially ("/foobar/" will be matched by
// the prefix "/foo") so you may want to use a trailing slash here.
//
// Also note that the setting of Router.StrictSlash() has no effect on routes
// with a PathPrefix matcher.
func (r *Route) PathPrefix(tpl string) *Route {
r.err = r.addRegexpMatcher(tpl, false, true, false)
return r
}
// Query ----------------------------------------------------------------------
// Queries adds a matcher for URL query values.
// It accepts a sequence of key/value pairs. Values may define variables.
// For example:
//
// r := mux.NewRouter()
// r.Queries("foo", "bar", "id", "{id:[0-9]+}")
//
// The above route will only match if the URL contains the defined queries
// values, e.g.: ?foo=bar&id=42.
//
// It the value is an empty string, it will match any value if the key is set.
//
// Variables can define an optional regexp pattern to me matched:
//
// - {name} matches anything until the next slash.
//
// - {name:pattern} matches the given regexp pattern.
func (r *Route) Queries(pairs ...string) *Route {
length := len(pairs)
if length%2 != 0 {
r.err = fmt.Errorf(
"mux: number of parameters must be multiple of 2, got %v", pairs)
return nil
}
for i := 0; i < length; i += 2 {
if r.err = r.addRegexpMatcher(pairs[i]+"="+pairs[i+1], false, true, true); r.err != nil {
return r
}
}
return r
}
// Schemes --------------------------------------------------------------------
// schemeMatcher matches the request against URL schemes.
type schemeMatcher []string
func (m schemeMatcher) Match(r *http.Request, match *RouteMatch) bool {
return matchInArray(m, r.URL.Scheme)
}
// Schemes adds a matcher for URL schemes.
// It accepts a sequence of schemes to be matched, e.g.: "http", "https".
func (r *Route) Schemes(schemes ...string) *Route {
for k, v := range schemes {
schemes[k] = strings.ToLower(v)
}
return r.addMatcher(schemeMatcher(schemes))
}
// Subrouter ------------------------------------------------------------------
// Subrouter creates a subrouter for the route.
//
// It will test the inner routes only if the parent route matched. For example:
//
// r := mux.NewRouter()
// s := r.Host("www.domain.com").Subrouter()
// s.HandleFunc("/products/", ProductsHandler)
// s.HandleFunc("/products/{key}", ProductHandler)
// s.HandleFunc("/articles/{category}/{id:[0-9]+}"), ArticleHandler)
//
// Here, the routes registered in the subrouter won't be tested if the host
// doesn't match.
func (r *Route) Subrouter() *Router {
router := &Router{parent: r, strictSlash: r.strictSlash}
r.addMatcher(router)
return router
}
// ----------------------------------------------------------------------------
// URL building
// ----------------------------------------------------------------------------
// URL builds a URL for the route.
//
// It accepts a sequence of key/value pairs for the route variables. For
// example, given this route:
//
// r := mux.NewRouter()
// r.HandleFunc("/articles/{category}/{id:[0-9]+}", ArticleHandler).
// Name("article")
//
// ...a URL for it can be built using:
//
// url, err := r.Get("article").URL("category", "technology", "id", "42")
//
// ...which will return an url.URL with the following path:
//
// "/articles/technology/42"
//
// This also works for host variables:
//
// r := mux.NewRouter()
// r.Host("{subdomain}.domain.com").
// HandleFunc("/articles/{category}/{id:[0-9]+}", ArticleHandler).
// Name("article")
//
// // url.String() will be "http://news.domain.com/articles/technology/42"
// url, err := r.Get("article").URL("subdomain", "news",
// "category", "technology",
// "id", "42")
//
// All variables defined in the route are required, and their values must
// conform to the corresponding patterns.
func (r *Route) URL(pairs ...string) (*url.URL, error) {
if r.err != nil {
return nil, r.err
}
if r.regexp == nil {
return nil, errors.New("mux: route doesn't have a host or path")
}
var scheme, host, path string
var err error
if r.regexp.host != nil {
// Set a default scheme.
scheme = "http"
if host, err = r.regexp.host.url(pairs...); err != nil {
return nil, err
}
}
if r.regexp.path != nil {
if path, err = r.regexp.path.url(pairs...); err != nil {
return nil, err
}
}
return &url.URL{
Scheme: scheme,
Host: host,
Path: path,
}, nil
}
// URLHost builds the host part of the URL for a route. See Route.URL().
//
// The route must have a host defined.
func (r *Route) URLHost(pairs ...string) (*url.URL, error) {
if r.err != nil {
return nil, r.err
}
if r.regexp == nil || r.regexp.host == nil {
return nil, errors.New("mux: route doesn't have a host")
}
host, err := r.regexp.host.url(pairs...)
if err != nil {
return nil, err
}
return &url.URL{
Scheme: "http",
Host: host,
}, nil
}
// URLPath builds the path part of the URL for a route. See Route.URL().
//
// The route must have a path defined.
func (r *Route) URLPath(pairs ...string) (*url.URL, error) {
if r.err != nil {
return nil, r.err
}
if r.regexp == nil || r.regexp.path == nil {
return nil, errors.New("mux: route doesn't have a path")
}
path, err := r.regexp.path.url(pairs...)
if err != nil {
return nil, err
}
return &url.URL{
Path: path,
}, nil
}
// ----------------------------------------------------------------------------
// parentRoute
// ----------------------------------------------------------------------------
// parentRoute allows routes to know about parent host and path definitions.
type parentRoute interface {
getNamedRoutes() map[string]*Route
getRegexpGroup() *routeRegexpGroup
}
// getNamedRoutes returns the map where named routes are registered.
func (r *Route) getNamedRoutes() map[string]*Route {
if r.parent == nil {
// During tests router is not always set.
r.parent = NewRouter()
}
return r.parent.getNamedRoutes()
}
// getRegexpGroup returns regexp definitions from this route.
func (r *Route) getRegexpGroup() *routeRegexpGroup {
if r.regexp == nil {
if r.parent == nil {
// During tests router is not always set.
r.parent = NewRouter()
}
regexp := r.parent.getRegexpGroup()
if regexp == nil {
r.regexp = new(routeRegexpGroup)
} else {
// Copy.
r.regexp = &routeRegexpGroup{
host: regexp.host,
path: regexp.path,
queries: regexp.queries,
}
}
}
return r.regexp
}

View file

@ -0,0 +1,7 @@
language: go
go:
- 1.0
- 1.1
- 1.2
- tip

View file

@ -0,0 +1,27 @@
Copyright (c) 2012 Rodrigo Moraes. All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
* Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above
copyright notice, this list of conditions and the following disclaimer
in the documentation and/or other materials provided with the
distribution.
* Neither the name of Google Inc. nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

View file

@ -0,0 +1,3 @@
schema
======
[![Build Status](https://travis-ci.org/gorilla/schema.png?branch=master)](https://travis-ci.org/gorilla/schema)

View file

@ -0,0 +1,221 @@
// Copyright 2012 The Gorilla Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package schema
import (
"errors"
"reflect"
"strconv"
"strings"
"sync"
)
var invalidPath = errors.New("schema: invalid path")
// newCache returns a new cache.
func newCache() *cache {
c := cache{
m: make(map[reflect.Type]*structInfo),
conv: make(map[reflect.Type]Converter),
tag: "schema",
}
for k, v := range converters {
c.conv[k] = v
}
return &c
}
// cache caches meta-data about a struct.
type cache struct {
l sync.RWMutex
m map[reflect.Type]*structInfo
conv map[reflect.Type]Converter
tag string
}
// parsePath parses a path in dotted notation verifying that it is a valid
// path to a struct field.
//
// It returns "path parts" which contain indices to fields to be used by
// reflect.Value.FieldByString(). Multiple parts are required for slices of
// structs.
func (c *cache) parsePath(p string, t reflect.Type) ([]pathPart, error) {
var struc *structInfo
var field *fieldInfo
var index64 int64
var err error
parts := make([]pathPart, 0)
path := make([]string, 0)
keys := strings.Split(p, ".")
for i := 0; i < len(keys); i++ {
if struc = c.get(t); struc == nil {
return nil, invalidPath
}
if field = struc.get(keys[i]); field == nil {
return nil, invalidPath
}
// Valid field. Append index.
path = append(path, field.name)
if field.ss {
// Parse a special case: slices of structs.
// i+1 must be the slice index, and i+2 must exist.
i++
if i+1 >= len(keys) {
return nil, invalidPath
}
if index64, err = strconv.ParseInt(keys[i], 10, 0); err != nil {
return nil, invalidPath
}
parts = append(parts, pathPart{
path: path,
field: field,
index: int(index64),
})
path = make([]string, 0)
// Get the next struct type, dropping ptrs.
if field.typ.Kind() == reflect.Ptr {
t = field.typ.Elem()
} else {
t = field.typ
}
if t.Kind() == reflect.Slice {
t = t.Elem()
if t.Kind() == reflect.Ptr {
t = t.Elem()
}
}
} else if field.typ.Kind() == reflect.Struct {
t = field.typ
} else if field.typ.Kind() == reflect.Ptr && field.typ.Elem().Kind() == reflect.Struct {
t = field.typ.Elem()
}
}
// Add the remaining.
parts = append(parts, pathPart{
path: path,
field: field,
index: -1,
})
return parts, nil
}
// get returns a cached structInfo, creating it if necessary.
func (c *cache) get(t reflect.Type) *structInfo {
c.l.RLock()
info := c.m[t]
c.l.RUnlock()
if info == nil {
info = c.create(t, nil)
c.l.Lock()
c.m[t] = info
c.l.Unlock()
}
return info
}
// creat creates a structInfo with meta-data about a struct.
func (c *cache) create(t reflect.Type, info *structInfo) *structInfo {
if info == nil {
info = &structInfo{fields: []*fieldInfo{}}
}
for i := 0; i < t.NumField(); i++ {
field := t.Field(i)
if field.Anonymous {
ft := field.Type
if ft.Kind() == reflect.Ptr {
ft = ft.Elem()
}
if ft.Kind() == reflect.Struct {
c.create(ft, info)
}
}
c.createField(field, info)
}
return info
}
// createField creates a fieldInfo for the given field.
func (c *cache) createField(field reflect.StructField, info *structInfo) {
alias := fieldAlias(field, c.tag)
if alias == "-" {
// Ignore this field.
return
}
// Check if the type is supported and don't cache it if not.
// First let's get the basic type.
isSlice, isStruct := false, false
ft := field.Type
if ft.Kind() == reflect.Ptr {
ft = ft.Elem()
}
if isSlice = ft.Kind() == reflect.Slice; isSlice {
ft = ft.Elem()
if ft.Kind() == reflect.Ptr {
ft = ft.Elem()
}
}
if isStruct = ft.Kind() == reflect.Struct; !isStruct {
if conv := c.conv[ft]; conv == nil {
// Type is not supported.
return
}
}
info.fields = append(info.fields, &fieldInfo{
typ: field.Type,
name: field.Name,
ss: isSlice && isStruct,
alias: alias,
})
}
// ----------------------------------------------------------------------------
type structInfo struct {
fields []*fieldInfo
}
func (i *structInfo) get(alias string) *fieldInfo {
for _, field := range i.fields {
if strings.EqualFold(field.alias, alias) {
return field
}
}
return nil
}
type fieldInfo struct {
typ reflect.Type
name string // field name in the struct.
ss bool // true if this is a slice of structs.
alias string
}
type pathPart struct {
field *fieldInfo
path []string // path to the field: walks structs using field names.
index int // struct index in slices of structs.
}
// ----------------------------------------------------------------------------
// fieldAlias parses a field tag to get a field alias.
func fieldAlias(field reflect.StructField, tagName string) string {
var alias string
if tag := field.Tag.Get(tagName); tag != "" {
// For now tags only support the name but let's folow the
// comma convention from encoding/json and others.
if idx := strings.Index(tag, ","); idx == -1 {
alias = tag
} else {
alias = tag[:idx]
}
}
if alias == "" {
alias = field.Name
}
return alias
}

View file

@ -0,0 +1,145 @@
// Copyright 2012 The Gorilla Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package schema
import (
"reflect"
"strconv"
)
type Converter func(string) reflect.Value
var (
invalidValue = reflect.Value{}
boolType = reflect.TypeOf(false)
float32Type = reflect.TypeOf(float32(0))
float64Type = reflect.TypeOf(float64(0))
intType = reflect.TypeOf(int(0))
int8Type = reflect.TypeOf(int8(0))
int16Type = reflect.TypeOf(int16(0))
int32Type = reflect.TypeOf(int32(0))
int64Type = reflect.TypeOf(int64(0))
stringType = reflect.TypeOf("")
uintType = reflect.TypeOf(uint(0))
uint8Type = reflect.TypeOf(uint8(0))
uint16Type = reflect.TypeOf(uint16(0))
uint32Type = reflect.TypeOf(uint32(0))
uint64Type = reflect.TypeOf(uint64(0))
)
// Default converters for basic types.
var converters = map[reflect.Type]Converter{
boolType: convertBool,
float32Type: convertFloat32,
float64Type: convertFloat64,
intType: convertInt,
int8Type: convertInt8,
int16Type: convertInt16,
int32Type: convertInt32,
int64Type: convertInt64,
stringType: convertString,
uintType: convertUint,
uint8Type: convertUint8,
uint16Type: convertUint16,
uint32Type: convertUint32,
uint64Type: convertUint64,
}
func convertBool(value string) reflect.Value {
if value == "on" {
return reflect.ValueOf(true)
} else if v, err := strconv.ParseBool(value); err == nil {
return reflect.ValueOf(v)
}
return invalidValue
}
func convertFloat32(value string) reflect.Value {
if v, err := strconv.ParseFloat(value, 32); err == nil {
return reflect.ValueOf(float32(v))
}
return invalidValue
}
func convertFloat64(value string) reflect.Value {
if v, err := strconv.ParseFloat(value, 64); err == nil {
return reflect.ValueOf(v)
}
return invalidValue
}
func convertInt(value string) reflect.Value {
if v, err := strconv.ParseInt(value, 10, 0); err == nil {
return reflect.ValueOf(int(v))
}
return invalidValue
}
func convertInt8(value string) reflect.Value {
if v, err := strconv.ParseInt(value, 10, 8); err == nil {
return reflect.ValueOf(int8(v))
}
return invalidValue
}
func convertInt16(value string) reflect.Value {
if v, err := strconv.ParseInt(value, 10, 16); err == nil {
return reflect.ValueOf(int16(v))
}
return invalidValue
}
func convertInt32(value string) reflect.Value {
if v, err := strconv.ParseInt(value, 10, 32); err == nil {
return reflect.ValueOf(int32(v))
}
return invalidValue
}
func convertInt64(value string) reflect.Value {
if v, err := strconv.ParseInt(value, 10, 64); err == nil {
return reflect.ValueOf(v)
}
return invalidValue
}
func convertString(value string) reflect.Value {
return reflect.ValueOf(value)
}
func convertUint(value string) reflect.Value {
if v, err := strconv.ParseUint(value, 10, 0); err == nil {
return reflect.ValueOf(uint(v))
}
return invalidValue
}
func convertUint8(value string) reflect.Value {
if v, err := strconv.ParseUint(value, 10, 8); err == nil {
return reflect.ValueOf(uint8(v))
}
return invalidValue
}
func convertUint16(value string) reflect.Value {
if v, err := strconv.ParseUint(value, 10, 16); err == nil {
return reflect.ValueOf(uint16(v))
}
return invalidValue
}
func convertUint32(value string) reflect.Value {
if v, err := strconv.ParseUint(value, 10, 32); err == nil {
return reflect.ValueOf(uint32(v))
}
return invalidValue
}
func convertUint64(value string) reflect.Value {
if v, err := strconv.ParseUint(value, 10, 64); err == nil {
return reflect.ValueOf(v)
}
return invalidValue
}

View file

@ -0,0 +1,226 @@
// Copyright 2012 The Gorilla Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package schema
import (
"errors"
"fmt"
"reflect"
)
// NewDecoder returns a new Decoder.
func NewDecoder() *Decoder {
return &Decoder{cache: newCache()}
}
// Decoder decodes values from a map[string][]string to a struct.
type Decoder struct {
cache *cache
zeroEmpty bool
ignoreUnknownKeys bool
}
// SetAliasTag changes the tag used to locate custome field aliases.
// The default tag is "schema".
func (d *Decoder) SetAliasTag(tag string) {
d.cache.tag = tag
}
// ZeroEmpty controls the behaviour when the decoder encounters empty values
// in a map.
// If z is true and a key in the map has the empty string as a value
// then the corresponding struct field is set to the zero value.
// If z is false then empty strings are ignored.
//
// The default value is false, that is empty values do not change
// the value of the struct field.
func (d *Decoder) ZeroEmpty(z bool) {
d.zeroEmpty = z
}
// IgnoreUnknownKeys controls the behaviour when the decoder encounters unknown
// keys in the map.
// If i is true and an unknown field is encountered, it is ignored. This is
// similar to how unknown keys are handled by encoding/json.
// If i is false then Decode will return an error. Note that any valid keys
// will still be decoded in to the target struct.
//
// To preserve backwards compatibility, the default value is false.
func (d *Decoder) IgnoreUnknownKeys(i bool) {
d.ignoreUnknownKeys = i
}
// RegisterConverter registers a converter function for a custom type.
func (d *Decoder) RegisterConverter(value interface{}, converterFunc Converter) {
d.cache.conv[reflect.TypeOf(value)] = converterFunc
}
// Decode decodes a map[string][]string to a struct.
//
// The first parameter must be a pointer to a struct.
//
// The second parameter is a map, typically url.Values from an HTTP request.
// Keys are "paths" in dotted notation to the struct fields and nested structs.
//
// See the package documentation for a full explanation of the mechanics.
func (d *Decoder) Decode(dst interface{}, src map[string][]string) error {
v := reflect.ValueOf(dst)
if v.Kind() != reflect.Ptr || v.Elem().Kind() != reflect.Struct {
return errors.New("schema: interface must be a pointer to struct")
}
v = v.Elem()
t := v.Type()
errors := MultiError{}
for path, values := range src {
if parts, err := d.cache.parsePath(path, t); err == nil {
if err = d.decode(v, path, parts, values); err != nil {
errors[path] = err
}
} else if !d.ignoreUnknownKeys {
errors[path] = fmt.Errorf("schema: invalid path %q", path)
}
}
if len(errors) > 0 {
return errors
}
return nil
}
// decode fills a struct field using a parsed path.
func (d *Decoder) decode(v reflect.Value, path string, parts []pathPart,
values []string) error {
// Get the field walking the struct fields by index.
for _, name := range parts[0].path {
if v.Type().Kind() == reflect.Ptr {
if v.IsNil() {
v.Set(reflect.New(v.Type().Elem()))
}
v = v.Elem()
}
v = v.FieldByName(name)
}
// Don't even bother for unexported fields.
if !v.CanSet() {
return nil
}
// Dereference if needed.
t := v.Type()
if t.Kind() == reflect.Ptr {
t = t.Elem()
if v.IsNil() {
v.Set(reflect.New(t))
}
v = v.Elem()
}
// Slice of structs. Let's go recursive.
if len(parts) > 1 {
idx := parts[0].index
if v.IsNil() || v.Len() < idx+1 {
value := reflect.MakeSlice(t, idx+1, idx+1)
if v.Len() < idx+1 {
// Resize it.
reflect.Copy(value, v)
}
v.Set(value)
}
return d.decode(v.Index(idx), path, parts[1:], values)
}
// Simple case.
if t.Kind() == reflect.Slice {
var items []reflect.Value
elemT := t.Elem()
isPtrElem := elemT.Kind() == reflect.Ptr
if isPtrElem {
elemT = elemT.Elem()
}
conv := d.cache.conv[elemT]
if conv == nil {
return fmt.Errorf("schema: converter not found for %v", elemT)
}
for key, value := range values {
if value == "" {
if d.zeroEmpty {
items = append(items, reflect.Zero(elemT))
}
} else if item := conv(value); item.IsValid() {
if isPtrElem {
ptr := reflect.New(elemT)
ptr.Elem().Set(item)
item = ptr
}
items = append(items, item)
} else {
// If a single value is invalid should we give up
// or set a zero value?
return ConversionError{path, key}
}
}
value := reflect.Append(reflect.MakeSlice(t, 0, 0), items...)
v.Set(value)
} else {
val := ""
// Use the last value provided if any values were provided
if len(values) > 0 {
val = values[len(values)-1]
}
if val == "" {
if d.zeroEmpty {
v.Set(reflect.Zero(t))
}
} else if conv := d.cache.conv[t]; conv != nil {
if value := conv(val); value.IsValid() {
v.Set(value)
} else {
return ConversionError{path, -1}
}
} else {
return fmt.Errorf("schema: converter not found for %v", t)
}
}
return nil
}
// Errors ---------------------------------------------------------------------
// ConversionError stores information about a failed conversion.
type ConversionError struct {
Key string // key from the source map.
Index int // index for multi-value fields; -1 for single-value fields.
}
func (e ConversionError) Error() string {
if e.Index < 0 {
return fmt.Sprintf("schema: error converting value for %q", e.Key)
}
return fmt.Sprintf("schema: error converting value for index %d of %q",
e.Index, e.Key)
}
// MultiError stores multiple decoding errors.
//
// Borrowed from the App Engine SDK.
type MultiError map[string]error
func (e MultiError) Error() string {
s := ""
for _, err := range e {
s = err.Error()
break
}
switch len(e) {
case 0:
return "(0 errors)"
case 1:
return s
case 2:
return s + " (and 1 other error)"
}
return fmt.Sprintf("%s (and %d other errors)", s, len(e)-1)
}

File diff suppressed because it is too large Load diff

123
Godeps/_workspace/src/github.com/gorilla/schema/doc.go generated vendored Normal file
View file

@ -0,0 +1,123 @@
// Copyright 2012 The Gorilla Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
/*
Package gorilla/schema fills a struct with form values.
The basic usage is really simple. Given this struct:
type Person struct {
Name string
Phone string
}
...we can fill it passing a map to the Load() function:
values := map[string][]string{
"Name": {"John"},
"Phone": {"999-999-999"},
}
person := new(Person)
decoder := schema.NewDecoder()
decoder.Decode(person, values)
This is just a simple example and it doesn't make a lot of sense to create
the map manually. Typically it will come from a http.Request object and
will be of type url.Values: http.Request.Form or http.Request.MultipartForm:
func MyHandler(w http.ResponseWriter, r *http.Request) {
err := r.ParseForm()
if err != nil {
// Handle error
}
decoder := schema.NewDecoder()
// r.PostForm is a map of our POST form values
err := decoder.Decode(person, r.PostForm)
if err != nil {
// Handle error
}
// Do something with person.Name or person.Phone
}
Note: it is a good idea to set a Decoder instance as a package global,
because it caches meta-data about structs, and a instance can be shared safely:
var decoder = schema.NewDecoder()
To define custom names for fields, use a struct tag "schema". To not populate
certain fields, use a dash for the name and it will be ignored:
type Person struct {
Name string `schema:"name"` // custom name
Phone string `schema:"phone"` // custom name
Admin bool `schema:"-"` // this field is never set
}
The supported field types in the destination struct are:
* bool
* float variants (float32, float64)
* int variants (int, int8, int16, int32, int64)
* string
* uint variants (uint, uint8, uint16, uint32, uint64)
* struct
* a pointer to one of the above types
* a slice or a pointer to a slice of one of the above types
Non-supported types are simply ignored, however custom types can be registered
to be converted.
To fill nested structs, keys must use a dotted notation as the "path" for the
field. So for example, to fill the struct Person below:
type Phone struct {
Label string
Number string
}
type Person struct {
Name string
Phone Phone
}
...the source map must have the keys "Name", "Phone.Label" and "Phone.Number".
This means that an HTML form to fill a Person struct must look like this:
<form>
<input type="text" name="Name">
<input type="text" name="Phone.Label">
<input type="text" name="Phone.Number">
</form>
Single values are filled using the first value for a key from the source map.
Slices are filled using all values for a key from the source map. So to fill
a Person with multiple Phone values, like:
type Person struct {
Name string
Phones []Phone
}
...an HTML form that accepts three Phone values would look like this:
<form>
<input type="text" name="Name">
<input type="text" name="Phones.0.Label">
<input type="text" name="Phones.0.Number">
<input type="text" name="Phones.1.Label">
<input type="text" name="Phones.1.Number">
<input type="text" name="Phones.2.Label">
<input type="text" name="Phones.2.Number">
</form>
Notice that only for slices of structs the slice index is required.
This is needed for disambiguation: if the nested struct also had a slice
field, we could not translate multiple values to it if we did not use an
index for the parent struct.
*/
package schema

View file

@ -0,0 +1,7 @@
_test
_testmain.go
_obj
*~
*.6
6.out
environ

22
Godeps/_workspace/src/github.com/jmoiron/modl/LICENSE generated vendored Normal file
View file

@ -0,0 +1,22 @@
(The MIT License)
Copyright (c) 2012 James Cooper <james@bitmechanic.com>
Permission is hereby granted, free of charge, to any person obtaining
a copy of this software and associated documentation files (the
'Software'), to deal in the Software without restriction, including
without limitation the rights to use, copy, modify, merge, publish,
distribute, sublicense, and/or sell copies of the Software, and to
permit persons to whom the Software is furnished to do so, subject to
the following conditions:
The above copyright notice and this permission notice shall be
included in all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED 'AS IS', WITHOUT WARRANTY OF ANY KIND,
EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

112
Godeps/_workspace/src/github.com/jmoiron/modl/README.md generated vendored Normal file
View file

@ -0,0 +1,112 @@
# Modl: Go database model & mapping
[![Build Status](https://drone.io/github.com/jmoiron/modl/status.png)](https://drone.io/github.com/jmoiron/modl/latest)
[![Godoc](http://img.shields.io/badge/godoc-reference-blue.svg?style=flat)](https://godoc.org/github.com/jmoiron/modl)
Modl is a library which provides database modelling and mapping. It is a fork of
James Cooper's wonderful [gorp](http://github.com/coopernurse/gorp).
**Note**. Modl's public facing interface is considered unfinished and open to
change. The current API will not be broken lightly, but additions are likely.
As Gorp's behavior moves on, Modl may adopt some of it or may not.
## Goals
Modl's goal is to clean up bits of gorp's API, add some additional features
like query building helpers and additional control over SQL generation, and
to reuse lower level abstractions provided in [sqlx](http://github.com/jmoiron/sqlx).
The driving philosophies behind modl are:
* If something can be done in `database/sql`, do it that way
* Go is not a good declarative language, do not abuse struct tags
* Expose as much as possible to facilitate extension by other libraries
* Avoid abstractions which provide initial convenience but must eventually be abandoned
* Avoid reflect and cache its results where possible, use where necessary
### Features
* Bind struct models to tables
* CRUD helpers for bound structs
* Create schema from database model (great for testing)
* Pre/post insert/update/delete hooks
* Automatic binding of auto increment PKs after insert
* Delete & Fetch by primary keys (w/ multi-key support)
* Sql trace logging
* Bind arbitrary SQL queries to a struct
* Optional optimistic locking using a version column (for update/deletes)
### Differences from Gorp
Since modl is a gorp fork, here are some of its major behavioral differences:
* Field names are mapped to all-lowercase sql equivalents. Capitalization is
an artifact of visibility, and this was required for compatibility with `sqlx`.
* No more struct/slice returns, pass pointers to methods instead
* Many panics in gorp are errors in modl
* TypeConverters are removed in favor of implementing `sql.Scanner` & `driver.Valuer`
for custom types.
## Testing
To use the `test-all` script, set the following environment variables:
```sh
# mysql DSN, note parseTime arg for time scanning support:
MODL_MYSQL_DSN="username:password@/dbname?parseTime=true"
# postgres DSN, like:
MODL_POSTGRES_DSN="user=username password=pw dbname=dbname sslmode=disable"
# sqlite DSN, which is a path
MODL_SQLITE_DSN="/dev/shm/modltest.db"
# optional, will fail the test if any DBs are skipped (for CI, mostly)
MODL_FAIL_ON_SKIP=true
```
In addition to this, you can create an `environ` file in this directory which
will be sourced and ignored by git. You can continue to use the `MODL_TEST_DSN`
and `MODL_TEST_DIALECT` variables if you want to manually run `go test` or if
you want to run the benchmarks, as described below.
The original README.md follows:
## API Stability
The API of Modl has been quite stable since its conception. Changes to the API
are avoided as much as possible but there is currently no promise of forward or
backward compatibility.
## Supported Databases
Modl relies heavily upon the `database/sql` package, and has a
[Dialect](dialect.go) interface which can be used to smooth over
differences between databases. There is a [list of sql drivers](http://code.google.com/p/go-wiki/wiki/SQLDrivers)
on the Go wiki, most of which Modl should be compatible with. Dialects
are provided for:
* MySQL
* PostgreSQL
* sqlite3
The test suite is continuously run against all of these databases.
## Documentation
API Documentation is available on [godoc](https://godoc.org/github.com/jmoiron/modl).
## Performance ##
Modl performs similar to Gorp. There are benchmarks in `modl_test.go` which
will benchmark native querying w/ `database/sql` and manual Scanning with what
modl does. Modl should perform between 2-3% slower than hand-done SQL.
## Contributors
The original contributors to gorp are:
* [@coopernurse](http://github.com/coopernurse) - James Cooper
* [@robfig](http://github.com/robfig) - Rob Figueiredo
* [@sqs](http://github.com/sqs) - Quinn Slack
* matthias-margush - column aliasing via tags

View file

@ -0,0 +1,17 @@
Current benchmarks:
PASS
BenchmarkNativeCrud 100 22605756 ns/op 5402 B/op 107 allocs/op
BenchmarkModlCrud 100 21107020 ns/op 7329 B/op 153 allocs/op
ok _/mnt/omocha/jmoiron/dev/go/modl 5.975s
PASS
BenchmarkNativeCrud 100 13153419 ns/op 8822 B/op 334 allocs/op
BenchmarkModlCrud 100 11676837 ns/op 10577 B/op 381 allocs/op
ok _/mnt/omocha/jmoiron/dev/go/modl 3.813s
PASS
BenchmarkNativeCrud 5000 281356 ns/op 1455 B/op 51 allocs/op
BenchmarkModlCrud 5000 313549 ns/op 2713 B/op 90 allocs/op
ok _/mnt/omocha/jmoiron/dev/go/modl 3.136s
(In order: MySQL, PostgreSQL, SQLite)

467
Godeps/_workspace/src/github.com/jmoiron/modl/dbmap.go generated vendored Normal file
View file

@ -0,0 +1,467 @@
// Package modl provides a non-declarative database modelling layer to ease
// the use of frequently repeated patterns in database-backed applications
// and centralize database use to ease profiling and reporting.
//
// It is a fork of the wonderful github.com/coopernurse/gorp package, but is
// rewritten to use github.com/jmoiron/sqlx as a base.
//
// Use of this source code is governed by a MIT-style license that can be
// found in the LICENSE file.
//
package modl
import (
"bytes"
"database/sql"
"fmt"
"log"
"reflect"
"strings"
"github.com/jmoiron/sqlx"
"github.com/jmoiron/sqlx/reflectx"
)
// TableNameMapper is the function used by AddTable to map struct names to database table names, in analogy
// to sqlx.NameMapper which does the same for struct field name to database column names.
var TableNameMapper = strings.ToLower
// DbMap is the root modl mapping object. Create one of these for each
// database schema you wish to map. Each DbMap contains a list of
// mapped tables.
//
// Example:
//
// dialect := modl.MySQLDialect{"InnoDB", "UTF8"}
// dbmap := &modl.DbMap{Db: db, Dialect: dialect}
//
type DbMap struct {
// Db handle to use with this map
Db *sql.DB
Dbx *sqlx.DB
// Dialect implementation to use with this map
Dialect Dialect
tables []*TableMap
logger *log.Logger
logPrefix string
mapper *reflectx.Mapper
}
// NewDbMap returns a new DbMap using the db connection and dialect.
func NewDbMap(db *sql.DB, dialect Dialect) *DbMap {
return &DbMap{
Db: db,
Dialect: dialect,
Dbx: sqlx.NewDb(db, dialect.DriverName()),
mapper: reflectx.NewMapperFunc("db", sqlx.NameMapper),
}
}
// TraceOn turns on SQL statement logging for this DbMap. After this is
// called, all SQL statements will be sent to the logger. If prefix is
// a non-empty string, it will be written to the front of all logged
// strings, which can aid in filtering log lines.
//
// Use TraceOn if you want to spy on the SQL statements that modl
// generates.
func (m *DbMap) TraceOn(prefix string, logger *log.Logger) {
m.logger = logger
if len(prefix) == 0 {
m.logPrefix = prefix
} else {
m.logPrefix = prefix + " "
}
}
// TraceOff turns off tracing. It is idempotent.
func (m *DbMap) TraceOff() {
m.logger = nil
m.logPrefix = ""
}
// AddTable registers the given interface type with modl. The table name
// will be given the name of the TypeOf(i), lowercased.
//
// This operation is idempotent. If i's type is already mapped, the
// existing *TableMap is returned.
func (m *DbMap) AddTable(i interface{}, name ...string) *TableMap {
Name := ""
if len(name) > 0 {
Name = name[0]
}
t := reflect.TypeOf(i)
// Use sqlx's NameMapper function if no name is supplied
if len(Name) == 0 {
Name = TableNameMapper(t.Name())
}
// check if we have a table for this type already
// if so, update the name and return the existing pointer
for i := range m.tables {
table := m.tables[i]
if table.gotype == t {
table.TableName = Name
return table
}
}
tmap := &TableMap{gotype: t, TableName: Name, dbmap: m, mapper: m.mapper}
tmap.setupHooks(i)
n := t.NumField()
tmap.Columns = make([]*ColumnMap, 0, n)
for i := 0; i < n; i++ {
f := t.Field(i)
columnName := f.Tag.Get("db")
if columnName == "" {
columnName = sqlx.NameMapper(f.Name)
}
cm := &ColumnMap{
ColumnName: columnName,
Transient: columnName == "-",
fieldName: f.Name,
gotype: f.Type,
table: tmap,
}
tmap.Columns = append(tmap.Columns, cm)
if cm.fieldName == "Version" {
tmap.version = tmap.Columns[len(tmap.Columns)-1]
}
}
m.tables = append(m.tables, tmap)
return tmap
}
// AddTableWithName adds a new mapping of the interface to a table name.
func (m *DbMap) AddTableWithName(i interface{}, name string) *TableMap {
return m.AddTable(i, name)
}
// CreateTablesSql returns create table SQL as a map of table names to
// their associated CREATE TABLE statements.
func (m *DbMap) CreateTablesSql() (map[string]string, error) {
return m.createTables(false, false)
}
// CreateTables iterates through TableMaps registered to this DbMap and
// executes "create table" statements against the database for each.
//
// This is particularly useful in unit tests where you want to create
// and destroy the schema automatically.
func (m *DbMap) CreateTables() error {
_, err := m.createTables(false, true)
return err
}
// CreateTablesIfNotExists is similar to CreateTables, but starts
// each statement with "create table if not exists" so that existing
// tables do not raise errors.
func (m *DbMap) CreateTablesIfNotExists() error {
_, err := m.createTables(true, true)
return err
}
func writeColumnSql(sql *bytes.Buffer, col *ColumnMap) {
if len(col.createSql) > 0 {
sql.WriteString(col.createSql)
return
}
sqltype := col.sqltype
if len(sqltype) == 0 {
sqltype = col.table.dbmap.Dialect.ToSqlType(col)
}
sql.WriteString(fmt.Sprintf("%s %s", col.table.dbmap.Dialect.QuoteField(col.ColumnName), sqltype))
if col.isPK {
sql.WriteString(" not null")
if len(col.table.Keys) == 1 {
sql.WriteString(" primary key")
}
}
if col.Unique {
sql.WriteString(" unique")
}
if col.isAutoIncr {
sql.WriteString(" " + col.table.dbmap.Dialect.AutoIncrStr())
}
}
func (m *DbMap) createTables(ifNotExists, exec bool) (map[string]string, error) {
var err error
ret := map[string]string{}
sep := ", "
prefix := ""
if !exec {
sep = ",\n"
prefix = " "
}
for i := range m.tables {
table := m.tables[i]
s := bytes.Buffer{}
s.WriteString("create table ")
if ifNotExists {
s.WriteString("if not exists ")
}
s.WriteString(m.Dialect.QuoteField(table.TableName))
s.WriteString(" (")
if !exec {
s.WriteString("\n")
}
x := 0
for _, col := range table.Columns {
if !col.Transient {
if x > 0 {
s.WriteString(sep)
}
s.WriteString(prefix)
writeColumnSql(&s, col)
x++
}
}
if len(table.Keys) > 1 {
s.WriteString(", primary key (")
for x := range table.Keys {
if x > 0 {
s.WriteString(", ")
}
s.WriteString(m.Dialect.QuoteField(table.Keys[x].ColumnName))
}
s.WriteString(")")
}
s.WriteString(fmt.Sprintf(")%s;", m.Dialect.CreateTableSuffix()))
if exec {
_, err = m.Exec(s.String())
if err != nil {
break
}
} else {
ret[table.TableName] = s.String()
}
}
return ret, err
}
// DropTables iterates through TableMaps registered to this DbMap and
// executes "drop table" statements against the database for each.
func (m *DbMap) DropTables() error {
var err error
for i := range m.tables {
table := m.tables[i]
_, e := m.Exec(fmt.Sprintf("drop table %s;", m.Dialect.QuoteField(table.TableName)))
if e != nil {
err = e
}
}
return err
}
// Insert runs a SQL INSERT statement for each element in list. List
// items must be pointers, because any interface whose TableMap has an
// auto-increment PK will have its insert Id bound to the PK struct field.
//
// Hook functions PreInsert() and/or PostInsert() will be executed
// before/after the INSERT statement if the interface defines them.
func (m *DbMap) Insert(list ...interface{}) error {
return insert(m, m, list...)
}
// Update runs a SQL UPDATE statement for each element in list. List
// items must be pointers.
//
// Hook functions PreUpdate() and/or PostUpdate() will be executed
// before/after the UPDATE statement if the interface defines them.
//
// Returns number of rows updated.
//
// Returns an error if SetKeys has not been called on the TableMap or if
// any interface in the list has not been registered with AddTable.
func (m *DbMap) Update(list ...interface{}) (int64, error) {
return update(m, m, list...)
}
// Delete runs a SQL DELETE statement for each element in list. List
// items must be pointers.
//
// Hook functions PreDelete() and/or PostDelete() will be executed
// before/after the DELETE statement if the interface defines them.
//
// Returns number of rows deleted.
//
// Returns an error if SetKeys has not been called on the TableMap or if
// any interface in the list has not been registered with AddTable.
func (m *DbMap) Delete(list ...interface{}) (int64, error) {
return deletes(m, m, list...)
}
// Get runs a SQL SELECT to fetch a single row from the table based on the
// primary key(s)
//
// dest should be an empty value for the struct to load.
// keys should be the primary key value(s) for the row to load. If
// multiple keys exist on the table, the order should match the column
// order specified in SetKeys() when the table mapping was defined.
//
// Hook function PostGet() will be executed
// after the SELECT statement if the interface defines it.
//
// Returns a pointer to a struct that matches or nil if no row is found.
//
// Returns an error if SetKeys has not been called on the TableMap or
// if any interface in the list has not been registered with AddTable.
func (m *DbMap) Get(dest interface{}, keys ...interface{}) error {
return get(m, m, dest, keys...)
}
// Select runs an arbitrary SQL query, binding the columns in the result
// to fields on the struct specified by dest. args represent the bind
// parameters for the SQL statement.
//
// Column names on the SELECT statement should be aliased to the field names
// on the struct dest. Returns an error if one or more columns in the result
// do not match. It is OK if fields on i are not part of the SQL
// statement.
//
// Hook function PostGet() will be executed
// after the SELECT statement if the interface defines it.
//
// Values are returned in one of two ways:
//
// 1. If dest is a struct or a pointer to a struct, returns a slice of pointers to
// matching rows of type dest.
//
// 2. If dest is a pointer to a slice, the results will be appended to that slice
// and nil returned.
//
// dest does NOT need to be registered with AddTable().
func (m *DbMap) Select(dest interface{}, query string, args ...interface{}) error {
return hookedselect(m, m, dest, query, args...)
}
// SelectOne runs an arbitrary SQL Query, binding the columns in the result to
// fields on the struct specified by dest.
func (m *DbMap) SelectOne(dest interface{}, query string, args ...interface{}) error {
return hookedget(m, m, dest, query, args...)
}
// Exec runs an arbitrary SQL statement. args represent the bind parameters.
// This is equivalent to running Exec() using database/sql.
func (m *DbMap) Exec(query string, args ...interface{}) (sql.Result, error) {
m.trace(query, args)
//stmt, err := m.Db.Prepare(query)
//if err != nil {
// return nil, err
//}
//fmt.Println("Exec", query, args)
return m.Db.Exec(query, args...)
}
// Begin starts a modl Transaction.
func (m *DbMap) Begin() (*Transaction, error) {
m.trace("begin;")
tx, err := m.Dbx.Beginx()
if err != nil {
return nil, err
}
return &Transaction{m, tx}, nil
}
// FIXME: This is a poor interface. Checking for nils is un-go-like, and this
// function should be TableFor(i interface{}) (*TableMap, error)
// FIXME: rewrite this in terms of sqlx's reflect helpers
// TableFor returns any matching tables for the interface i or nil if not found.
// If i is a slice, then the table is given for the base slice type.
func (m *DbMap) TableFor(i interface{}) *TableMap {
var t reflect.Type
v := reflect.ValueOf(i)
start:
switch v.Kind() {
case reflect.Ptr:
// dereference pointer and try again; we never want to store pointer
// types anywhere, that way we always know how to do lookups
v = v.Elem()
goto start
case reflect.Slice:
// if this is a slice of X's, we're interested in the type of X
t = v.Type().Elem()
default:
t = v.Type()
}
if t.Kind() == reflect.Ptr {
t = t.Elem()
}
return m.TableForType(t)
}
// FIXME: returning a nil pointer is not go-like; return (*TableMap, err) instead.
// TableForType returns any matching tables for the type t or nil if not found.
func (m *DbMap) TableForType(t reflect.Type) *TableMap {
for _, table := range m.tables {
if table.gotype == t {
return table
}
}
return nil
}
// TruncateTables truncates all tables in the DbMap.
func (m *DbMap) TruncateTables() error {
return m.truncateTables(false)
}
// TruncateTablesIdentityRestart truncates all tables in the DbMap and
// resets the identity counter.
func (m *DbMap) TruncateTablesIdentityRestart() error {
return m.truncateTables(true)
}
func (m *DbMap) truncateTables(restartIdentity bool) error {
var err error
var restartClause string
for i := range m.tables {
table := m.tables[i]
if restartIdentity {
restartClause = m.Dialect.RestartIdentityClause(table.TableName)
}
// if the restart clause exists and starts with ';', then assume it's an
// additional query to run after we truncate. This is true with MySQL and
// SQLite, which do not have extra clauses for this during table truncation.
if len(restartClause) > 0 && restartClause[0] == ';' {
_, err = m.Exec(fmt.Sprintf("%s %s;", m.Dialect.TruncateClause(),
m.Dialect.QuoteField(table.TableName)))
if err != nil {
return err
}
_, err = m.Exec(restartClause[1:])
if err != nil {
return err
}
} else {
_, err := m.Exec(fmt.Sprintf("%s %s %s;", m.Dialect.TruncateClause(), m.Dialect.QuoteField(table.TableName), restartClause))
if err != nil {
return err
}
}
}
return nil
}
func (m *DbMap) handle() handle {
return &tracingHandle{h: m.Dbx, d: m}
}
func (m *DbMap) trace(query string, args ...interface{}) {
if m.logger != nil {
m.logger.Printf("%s%s %v", m.logPrefix, query, args)
}
}

View file

@ -0,0 +1,420 @@
package modl
import (
"errors"
"fmt"
"reflect"
"strings"
"github.com/jmoiron/sqlx"
)
// Dialect is an interface that encapsulates behaviors that differ across
// SQL databases.
type Dialect interface {
// ToSqlType returns the SQL column type to use when creating a table of the
// given Go Type. maxsize can be used to switch based on size. For example,
// in MySQL []byte could map to BLOB, MEDIUMBLOB, or LONGBLOB depending on the
// maxsize
ToSqlType(col *ColumnMap) string
// AutoIncrStr returns a string to append to primary key column definitions
AutoIncrStr() string
//AutoIncrBindValue returns the default value for an auto increment row in an insert.
AutoIncrBindValue() string
// AutoIncrInsertSuffix returns a string to append to an insert statement which
// will make it return the auto-increment key on insert.
AutoIncrInsertSuffix(col *ColumnMap) string
// CreateTableSuffix returns a string to append to "create table" statement for
// vendor specific table attributes, eg. MySQL engine
CreateTableSuffix() string
// InsertAutoIncr
InsertAutoIncr(e SqlExecutor, insertSql string, params ...interface{}) (int64, error)
// InsertAutIncrAny takes a destination for non-integer auto-incr, like
// uuids which scan to strings, hashes, etc.
InsertAutoIncrAny(e SqlExecutor, insertSql string, dest interface{}, params ...interface{}) error
// BindVar returns the variable string to use when forming SQL statements
// in many dbs it is "?", but Postgres requires '$#'
//
// The supplied int is a zero based index of the bindvar in the statement
BindVar(i int) string
// QuoteField returns a quoted version of the field name.
QuoteField(field string) string
// TruncateClause is a string used to truncate tables.
TruncateClause() string
// RestartIdentityClause returns a string used to reset the identity counter
// when truncating tables. If the string starts with a ';', it is assumed to
// be a separate query and is executed separately.
RestartIdentityClause(table string) string
// DriverName returns the driver name for a dialect.
DriverName() string
}
func standardInsertAutoIncr(e SqlExecutor, insertSql string, params ...interface{}) (int64, error) {
res, err := e.handle().Exec(insertSql, params...)
if err != nil {
return 0, err
}
return res.LastInsertId()
}
func standardAutoIncrAny(e SqlExecutor, insertSql string, dest interface{}, params ...interface{}) error {
rows, err := e.handle().Queryx(insertSql, params...)
if err != nil {
return err
}
defer rows.Close()
if rows.Next() {
return rows.Scan(dest)
}
return fmt.Errorf("No auto-incr value returned for insert: `%s` error: %s", insertSql, rows.Err())
}
// -- sqlite3
// SqliteDialect implements the Dialect interface for Sqlite3.
type SqliteDialect struct {
suffix string
}
// DriverName returns the driverName for sqlite.
func (d SqliteDialect) DriverName() string {
return "sqlite"
}
// ToSqlType maps go types to sqlite types.
func (d SqliteDialect) ToSqlType(col *ColumnMap) string {
switch col.gotype.Kind() {
case reflect.Bool:
return "integer"
case reflect.Int, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint16, reflect.Uint32, reflect.Uint64:
return "integer"
case reflect.Float64, reflect.Float32:
return "real"
case reflect.Slice:
if col.gotype.Elem().Kind() == reflect.Uint8 {
return "blob"
}
}
switch col.gotype.Name() {
case "NullableInt64":
return "integer"
case "NullableFloat64":
return "real"
case "NullableBool":
return "integer"
case "NullableBytes":
return "blob"
case "Time", "NullTime":
return "datetime"
}
// sqlite ignores maxsize, so we will do that here too
return fmt.Sprintf("text")
}
// AutoIncrStr returns autoincrement clause for sqlite.
func (d SqliteDialect) AutoIncrStr() string {
return "autoincrement"
}
// AutoIncrBindValue returns the bind value for auto incr fields in sqlite.
func (d SqliteDialect) AutoIncrBindValue() string {
return "null"
}
// AutoIncrInsertSuffix returns the auto incr insert suffix for sqlite.
func (d SqliteDialect) AutoIncrInsertSuffix(col *ColumnMap) string {
return ""
}
// CreateTableSuffix returns d.suffix
func (d SqliteDialect) CreateTableSuffix() string {
return d.suffix
}
// BindVar returns "?", the simpler of the sqlite bindvars.
func (d SqliteDialect) BindVar(i int) string {
return "?"
}
// InsertAutoIncr runs the standard
func (d SqliteDialect) InsertAutoIncr(e SqlExecutor, insertSql string, params ...interface{}) (int64, error) {
return standardInsertAutoIncr(e, insertSql, params...)
}
func (d SqliteDialect) InsertAutoIncrAny(e SqlExecutor, insertSql string, dest interface{}, params ...interface{}) error {
return standardAutoIncrAny(e, insertSql, dest, params...)
}
// QuoteField quotes f with "" for sqlite
func (d SqliteDialect) QuoteField(f string) string {
return `"` + f + `"`
}
// TruncateClause returns the truncate clause for sqlite. There is no TRUNCATE
// statement in sqlite3, but DELETE FROM uses a truncate optimization:
// http://www.sqlite.org/lang_delete.html
func (d SqliteDialect) TruncateClause() string {
return "delete from"
}
// RestartIdentityClause restarts the sqlite_sequence for the provided table.
// It is executed by TruncateTable as a separate query.
func (d SqliteDialect) RestartIdentityClause(table string) string {
return "; DELETE FROM sqlite_sequence WHERE name='" + table + "'"
}
// -- PostgreSQL
// PostgresDialect implements the Dialect interface for PostgreSQL.
type PostgresDialect struct {
suffix string
}
// DriverName returns the driverName for postgres.
func (d PostgresDialect) DriverName() string {
return "postgres"
}
// ToSqlType maps go types to postgres types.
func (d PostgresDialect) ToSqlType(col *ColumnMap) string {
switch col.gotype.Kind() {
case reflect.Bool:
return "boolean"
case reflect.Int, reflect.Int16, reflect.Int32, reflect.Uint16, reflect.Uint32:
if col.isAutoIncr {
return "serial"
}
return "integer"
case reflect.Int64, reflect.Uint64:
if col.isAutoIncr {
return "bigserial"
}
return "bigint"
case reflect.Float64, reflect.Float32:
return "real"
case reflect.Slice:
if col.gotype.Elem().Kind() == reflect.Uint8 {
return "bytea"
}
}
switch col.gotype.Name() {
case "NullableInt64":
return "bigint"
case "NullableFloat64":
return "double"
case "NullableBool":
return "smallint"
case "NullableBytes":
return "bytea"
case "Time", "Nulltime":
return "timestamp with time zone"
}
maxsize := col.MaxSize
if col.MaxSize < 1 {
maxsize = 255
}
return fmt.Sprintf("varchar(%d)", maxsize)
}
// AutoIncrStr returns empty string, as it's not used in postgres.
func (d PostgresDialect) AutoIncrStr() string {
return ""
}
// AutoIncrBindValue returns 'default' for default auto incr bind values.
func (d PostgresDialect) AutoIncrBindValue() string {
return "default"
}
// AutoIncrInsertSuffix appnds 'returning colname' to a query so that the
// new autoincrement value for the column name is returned by Insert.
func (d PostgresDialect) AutoIncrInsertSuffix(col *ColumnMap) string {
return " returning " + col.ColumnName
}
// CreateTableSuffix returns the configured suffix.
func (d PostgresDialect) CreateTableSuffix() string {
return d.suffix
}
// BindVar returns "$(i+1)"
func (d PostgresDialect) BindVar(i int) string {
return fmt.Sprintf("$%d", i+1)
}
// InsertAutoIncr inserts via a query and reads the resultant rows for the new
// auto increment ID, as it's not returned with the result in PostgreSQL.
func (d PostgresDialect) InsertAutoIncr(e SqlExecutor, insertSql string, params ...interface{}) (int64, error) {
rows, err := e.handle().Queryx(insertSql, params...)
if err != nil {
return 0, err
}
defer rows.Close()
if rows.Next() {
var id int64
err := rows.Scan(&id)
return id, err
}
return 0, errors.New("No serial value returned for insert: " + insertSql + ", error: " + rows.Err().Error())
}
func (d PostgresDialect) InsertAutoIncrAny(e SqlExecutor, insertSql string, dest interface{}, params ...interface{}) error {
return standardAutoIncrAny(e, insertSql, dest, params...)
}
// QuoteField quotes f with ""
func (d PostgresDialect) QuoteField(f string) string {
return `"` + sqlx.NameMapper(f) + `"`
}
// TruncateClause returns 'truncate'
func (d PostgresDialect) TruncateClause() string {
return "truncate"
}
// RestartIdentityClause returns 'restart identity', which will restart serial
// sequences for this table at the same time as a truncation is performed.
func (d PostgresDialect) RestartIdentityClause(table string) string {
return "restart identity"
}
// -- MySQL
// MySQLDialect is an implementation of Dialect for MySQL databases.
type MySQLDialect struct {
// Engine is the storage engine to use "InnoDB" vs "MyISAM" for example
Engine string
// Encoding is the character encoding to use for created tables
Encoding string
}
// DriverName returns "mysql", used by the ziutek and the go-sql-driver
// versions of the MySQL driver.
func (d MySQLDialect) DriverName() string {
return "mysql"
}
// ToSqlType maps go types to MySQL types.
func (d MySQLDialect) ToSqlType(col *ColumnMap) string {
switch col.gotype.Kind() {
case reflect.Bool:
return "boolean"
case reflect.Int, reflect.Int16, reflect.Int32, reflect.Uint16, reflect.Uint32:
return "int"
case reflect.Int64, reflect.Uint64:
return "bigint"
case reflect.Float64, reflect.Float32:
return "double"
case reflect.Slice:
if col.gotype.Elem().Kind() == reflect.Uint8 {
return "mediumblob"
}
}
switch col.gotype.Name() {
case "NullableInt64":
return "bigint"
case "NullableFloat64":
return "double"
case "NullableBool":
return "tinyint"
case "NullableBytes":
return "mediumblob"
case "Time", "NullTime":
return "datetime"
}
maxsize := col.MaxSize
if col.MaxSize < 1 {
maxsize = 255
}
return fmt.Sprintf("varchar(%d)", maxsize)
}
// AutoIncrStr returns "auto_increment".
func (d MySQLDialect) AutoIncrStr() string {
return "auto_increment"
}
// AutoIncrBindValue returns "null", default for auto increment fields in MySQL.
func (d MySQLDialect) AutoIncrBindValue() string {
return "null"
}
// AutoIncrInsertSuffix returns "".
func (d MySQLDialect) AutoIncrInsertSuffix(col *ColumnMap) string {
return ""
}
// CreateTableSuffix returns engine=%s charset=%s based on values stored on struct
func (d MySQLDialect) CreateTableSuffix() string {
return fmt.Sprintf(" engine=%s charset=%s", d.Engine, d.Encoding)
}
// BindVar returns "?"
func (d MySQLDialect) BindVar(i int) string {
return "?"
}
// InsertAutoIncr runs the standard Insert Exec, which uses LastInsertId to get
// the value of the auto increment column.
func (d MySQLDialect) InsertAutoIncr(e SqlExecutor, insertSql string, params ...interface{}) (int64, error) {
return standardInsertAutoIncr(e, insertSql, params...)
}
func (d MySQLDialect) InsertAutoIncrAny(e SqlExecutor, insertSql string, dest interface{}, params ...interface{}) error {
return standardAutoIncrAny(e, insertSql, dest, params...)
}
// QuoteField quotes f using ``.
func (d MySQLDialect) QuoteField(f string) string {
return "`" + f + "`"
}
// FIXME: use sqlx's rebind, which was written after it had been created for modl
// ReBind formats the bindvars in the query string (these are '?') for the dialect.
func ReBind(query string, dialect Dialect) string {
binder := dialect.BindVar(0)
if binder == "?" {
return query
}
for i, j := 0, strings.Index(query, "?"); j >= 0; i++ {
query = strings.Replace(query, "?", dialect.BindVar(i), 1)
j = strings.Index(query, "?")
}
return query
}
// TruncateClause returns 'truncate'.
func (d MySQLDialect) TruncateClause() string {
return "truncate"
}
// RestartIdentityClause alters the table's AUTO_INCREMENT value after truncation,
// as MySQL doesn't have an identity clause for the truncate statement.
func (d MySQLDialect) RestartIdentityClause(table string) string {
return "; alter table " + table + " AUTO_INCREMENT = 1"
}

View file

@ -0,0 +1,51 @@
package modl
import (
"database/sql"
"github.com/jmoiron/sqlx"
)
// a cursor is either an sqlx.Db or an sqlx.Tx
type handle interface {
Select(dest interface{}, query string, args ...interface{}) error
Get(dest interface{}, query string, args ...interface{}) error
Queryx(query string, args ...interface{}) (*sqlx.Rows, error)
QueryRowx(query string, args ...interface{}) *sqlx.Row
Exec(query string, args ...interface{}) (sql.Result, error)
/*
Query(query string, args ...interface{}) (*sql.Rows, error)
QueryRow(query string, args ...interface{}) *sql.Row
*/
}
// an implmentation of handle which traces using dbmap
type tracingHandle struct {
d *DbMap
h handle
}
func (t *tracingHandle) Select(dest interface{}, query string, args ...interface{}) error {
t.d.trace(query, args...)
return t.h.Select(dest, query, args...)
}
func (t *tracingHandle) Get(dest interface{}, query string, args ...interface{}) error {
t.d.trace(query, args...)
return t.h.Get(dest, query, args...)
}
func (t *tracingHandle) Queryx(query string, args ...interface{}) (*sqlx.Rows, error) {
t.d.trace(query, args...)
return t.h.Queryx(query, args...)
}
func (t *tracingHandle) QueryRowx(query string, args ...interface{}) *sqlx.Row {
t.d.trace(query, args...)
return t.h.QueryRowx(query, args...)
}
func (t *tracingHandle) Exec(query string, args ...interface{}) (sql.Result, error) {
t.d.trace(query, args...)
return t.h.Exec(query, args...)
}

66
Godeps/_workspace/src/github.com/jmoiron/modl/hooks.go generated vendored Normal file
View file

@ -0,0 +1,66 @@
package modl
import (
"reflect"
)
// PreInserter is an interface used to determine if a table type implements
// a PreInsert hook
type PreInserter interface {
PreInsert(SqlExecutor) error
}
// PostInserter is an interface used to determine if a table type implements
// a PostInsert hook
type PostInserter interface {
PostInsert(SqlExecutor) error
}
// PostGetter is an interface used to determine if a table type implements
// a PostGet hook
type PostGetter interface {
PostGet(SqlExecutor) error
}
// PreUpdater is an interface used to determine if a table type implements
// a PreUpdate hook
type PreUpdater interface {
PreUpdate(SqlExecutor) error
}
// PostUpdater is an interface used to determine if a table type implements
// a PostUpdate hook
type PostUpdater interface {
PostUpdate(SqlExecutor) error
}
// PreDeleter is an interface used to determine if a table type implements
// a PreDelete hook
type PreDeleter interface {
PreDelete(SqlExecutor) error
}
// PostDeleter is an interface used to determine if a table type implements
// a PostDelete hook
type PostDeleter interface {
PostDelete(SqlExecutor) error
}
// Determine which hooks are supported by the mapper struct i
func (t *TableMap) setupHooks(i interface{}) {
// These hooks must be implemented on a pointer, so if a value is passed in
// we have to get a pointer for a new value of that type in order for the
// type assertions to pass.
ptr := i
if reflect.ValueOf(i).Kind() == reflect.Struct {
ptr = reflect.New(reflect.ValueOf(i).Type()).Interface()
}
_, t.CanPreInsert = ptr.(PreInserter)
_, t.CanPostInsert = ptr.(PostInserter)
_, t.CanPostGet = ptr.(PostGetter)
_, t.CanPreUpdate = ptr.(PreUpdater)
_, t.CanPostUpdate = ptr.(PostUpdater)
_, t.CanPreDelete = ptr.(PreDeleter)
_, t.CanPostDelete = ptr.(PostDeleter)
}

375
Godeps/_workspace/src/github.com/jmoiron/modl/modl.go generated vendored Normal file
View file

@ -0,0 +1,375 @@
package modl
// Changes Copyright 2013 Jason Moiron. Original Gorp code
// Copyright 2012 James Cooper. All rights reserved.
//
// Use of this source code is governed by a MIT-style
// license that can be found in the LICENSE file.
//
// Source code and project home:
// https://github.com/jmoiron/modl
import (
"database/sql"
"fmt"
"reflect"
)
// NoKeysErr is a special error type returned when modl's CRUD helpers are
// used on tables which have not been set up with a primary key.
type NoKeysErr struct {
Table *TableMap
}
// Error returns the string representation of a NoKeysError.
func (n NoKeysErr) Error() string {
return fmt.Sprintf("Could not find keys for table %v", n.Table)
}
const versFieldConst = "[modl_ver_field]"
// OptimisticLockError is returned by Update() or Delete() if the
// struct being modified has a Version field and the value is not equal to
// the current value in the database
type OptimisticLockError struct {
// Table name where the lock error occurred
TableName string
// Primary key values of the row being updated/deleted
Keys []interface{}
// true if a row was found with those keys, indicating the
// LocalVersion is stale. false if no value was found with those
// keys, suggesting the row has been deleted since loaded, or
// was never inserted to begin with
RowExists bool
// Version value on the struct passed to Update/Delete. This value is
// out of sync with the database.
LocalVersion int64
}
// Error returns a description of the cause of the lock error
func (e OptimisticLockError) Error() string {
if e.RowExists {
return fmt.Sprintf("OptimisticLockError table=%s keys=%v out of date version=%d", e.TableName, e.Keys, e.LocalVersion)
}
return fmt.Sprintf("OptimisticLockError no row found for table=%s keys=%v", e.TableName, e.Keys)
}
// A bindPlan saves a query type (insert, get, updated, delete) so it doesn't
// have to be re-created every time it's executed.
type bindPlan struct {
query string
argFields []string
keyFields []string
versField string
autoIncrIdx int
}
func (plan bindPlan) createBindInstance(elem reflect.Value) bindInstance {
bi := bindInstance{query: plan.query, autoIncrIdx: plan.autoIncrIdx, versField: plan.versField}
if plan.versField != "" {
bi.existingVersion = elem.FieldByName(plan.versField).Int()
}
for i := 0; i < len(plan.argFields); i++ {
k := plan.argFields[i]
if k == versFieldConst {
newVer := bi.existingVersion + 1
bi.args = append(bi.args, newVer)
if bi.existingVersion == 0 {
elem.FieldByName(plan.versField).SetInt(int64(newVer))
}
} else {
val := elem.FieldByName(k).Interface()
bi.args = append(bi.args, val)
}
}
for i := 0; i < len(plan.keyFields); i++ {
k := plan.keyFields[i]
val := elem.FieldByName(k).Interface()
bi.keys = append(bi.keys, val)
}
return bi
}
type bindInstance struct {
query string
args []interface{}
keys []interface{}
existingVersion int64
versField string
autoIncrIdx int
}
// SqlExecutor exposes modl operations that can be run from Pre/Post
// hooks. This hides whether the current operation that triggered the
// hook is in a transaction.
//
// See the DbMap function docs for each of the functions below for more
// information.
type SqlExecutor interface {
Get(dest interface{}, keys ...interface{}) error
Insert(list ...interface{}) error
Update(list ...interface{}) (int64, error)
Delete(list ...interface{}) (int64, error)
Exec(query string, args ...interface{}) (sql.Result, error)
Select(dest interface{}, query string, args ...interface{}) error
SelectOne(dest interface{}, query string, args ...interface{}) error
handle() handle
}
// Compile-time check that DbMap and Transaction implement the SqlExecutor
// interface.
var (
_ SqlExecutor = &DbMap{}
_ SqlExecutor = &Transaction{}
)
///////////////
func hookedget(m *DbMap, e SqlExecutor, dest interface{}, query string, args ...interface{}) error {
err := e.handle().Get(dest, query, args...)
if err != nil {
return err
}
table := m.TableFor(dest)
if table != nil && table.CanPostGet {
err = dest.(PostGetter).PostGet(e)
if err != nil {
return err
}
}
return nil
}
func hookedselect(m *DbMap, e SqlExecutor, dest interface{}, query string, args ...interface{}) error {
err := e.handle().Select(dest, query, args...)
if err != nil {
return err
}
// select can use arbitrary structs for join queries, so we needn't find a table
table := m.TableFor(dest)
if table != nil && table.CanPostGet {
var x interface{}
v := reflect.ValueOf(dest)
if v.Kind() == reflect.Ptr {
v = reflect.Indirect(v)
}
l := v.Len()
for i := 0; i < l; i++ {
x = v.Index(i).Interface()
err = x.(PostGetter).PostGet(e)
if err != nil {
return err
}
}
}
return nil
}
func get(m *DbMap, e SqlExecutor, dest interface{}, keys ...interface{}) error {
table := m.TableFor(dest)
if table == nil {
return fmt.Errorf("could not find table for %v", dest)
}
if len(table.Keys) < 1 {
return &NoKeysErr{table}
}
plan := table.bindGet()
err := e.handle().Get(dest, plan.query, keys...)
if err != nil {
return err
}
if table.CanPostGet {
err = dest.(PostGetter).PostGet(e)
if err != nil {
return err
}
}
return nil
}
func deletes(m *DbMap, e SqlExecutor, list ...interface{}) (int64, error) {
var err error
var table *TableMap
var elem reflect.Value
var count int64
for _, ptr := range list {
table, elem, err = tableForPointer(m, ptr, true)
if err != nil {
return -1, err
}
if table.CanPreDelete {
err = ptr.(PreDeleter).PreDelete(e)
if err != nil {
return -1, err
}
}
bi := table.bindDelete(elem)
res, err := e.Exec(bi.query, bi.args...)
if err != nil {
return -1, err
}
rows, err := res.RowsAffected()
if err != nil {
return -1, err
}
if rows == 0 && bi.existingVersion > 0 {
return lockError(m, e, table.TableName, bi.existingVersion, elem, bi.keys...)
}
count += rows
if table.CanPostDelete {
err = ptr.(PostDeleter).PostDelete(e)
if err != nil {
return -1, err
}
}
}
return count, nil
}
func update(m *DbMap, e SqlExecutor, list ...interface{}) (int64, error) {
var err error
var table *TableMap
var elem reflect.Value
var count int64
for _, ptr := range list {
table, elem, err = tableForPointer(m, ptr, true)
if err != nil {
return -1, err
}
if table.CanPreUpdate {
err = ptr.(PreUpdater).PreUpdate(e)
if err != nil {
return -1, err
}
}
bi := table.bindUpdate(elem)
if err != nil {
return -1, err
}
res, err := e.Exec(bi.query, bi.args...)
if err != nil {
return -1, err
}
rows, err := res.RowsAffected()
if err != nil {
return -1, err
}
if rows == 0 && bi.existingVersion > 0 {
return lockError(m, e, table.TableName,
bi.existingVersion, elem, bi.keys...)
}
if bi.versField != "" {
elem.FieldByName(bi.versField).SetInt(bi.existingVersion + 1)
}
count += rows
if table.CanPostUpdate {
err = ptr.(PostUpdater).PostUpdate(e)
if err != nil {
return -1, err
}
}
}
return count, nil
}
func insert(m *DbMap, e SqlExecutor, list ...interface{}) error {
var err error
var table *TableMap
var elem reflect.Value
for _, ptr := range list {
table, elem, err = tableForPointer(m, ptr, false)
if err != nil {
return err
}
if table.CanPreInsert {
err = ptr.(PreInserter).PreInsert(e)
if err != nil {
return err
}
}
bi := table.bindInsert(elem)
if bi.autoIncrIdx > -1 {
id, err := m.Dialect.InsertAutoIncr(e, bi.query, bi.args...)
if err != nil {
return err
}
f := elem.Field(bi.autoIncrIdx)
k := f.Kind()
if (k == reflect.Int) || (k == reflect.Int16) || (k == reflect.Int32) || (k == reflect.Int64) {
f.SetInt(id)
} else {
return fmt.Errorf("modl: Cannot set autoincrement value on non-Int field. SQL=%s autoIncrIdx=%d", bi.query, bi.autoIncrIdx)
}
} else {
_, err := e.Exec(bi.query, bi.args...)
if err != nil {
return err
}
}
if table.CanPostInsert {
err = ptr.(PostInserter).PostInsert(e)
if err != nil {
return err
}
}
}
return nil
}
func lockError(m *DbMap, e SqlExecutor, tableName string, existingVer int64, elem reflect.Value, keys ...interface{}) (int64, error) {
dest := reflect.New(elem.Type()).Interface()
err := get(m, e, dest, keys...)
if err != nil {
return -1, err
}
ole := OptimisticLockError{tableName, keys, true, existingVer}
if dest == nil {
ole.RowExists = false
}
return -1, ole
}

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,375 @@
package modl
import (
"bytes"
"fmt"
"reflect"
"github.com/jmoiron/sqlx"
"github.com/jmoiron/sqlx/reflectx"
)
// TableMap represents a mapping between a Go struct and a database table
// Use dbmap.AddTable() or dbmap.AddTableWithName() to create these
type TableMap struct {
// Name of database table.
TableName string
Keys []*ColumnMap
Columns []*ColumnMap
gotype reflect.Type
version *ColumnMap
insertPlan bindPlan
updatePlan bindPlan
deletePlan bindPlan
getPlan bindPlan
dbmap *DbMap
mapper *reflectx.Mapper
// Cached capabilities for the struct mapped to this table
CanPreInsert bool
CanPostInsert bool
CanPostGet bool
CanPreUpdate bool
CanPostUpdate bool
CanPreDelete bool
CanPostDelete bool
}
// ResetSql removes cached insert/update/select/delete SQL strings
// associated with this TableMap. Call this if you've modified
// any column names or the table name itself.
func (t *TableMap) ResetSql() {
t.insertPlan = bindPlan{}
t.updatePlan = bindPlan{}
t.deletePlan = bindPlan{}
t.getPlan = bindPlan{}
}
// SetKeys lets you specify the fields on a struct that map to primary
// key columns on the table. If isAutoIncr is set, result.LastInsertId()
// will be used after INSERT to bind the generated id to the Go struct.
//
// Automatically calls ResetSql() to ensure SQL statements are regenerated.
func (t *TableMap) SetKeys(isAutoIncr bool, fieldNames ...string) *TableMap {
t.Keys = make([]*ColumnMap, 0)
for _, name := range fieldNames {
// FIXME: sqlx.NameMapper is a deprecated API. modl should have its
// own API which sets sqlx's mapping funcs as necessary
colmap := t.ColMap(sqlx.NameMapper(name))
colmap.isPK = true
colmap.isAutoIncr = isAutoIncr
t.Keys = append(t.Keys, colmap)
}
t.ResetSql()
return t
}
// ColMap returns the ColumnMap pointer matching the given struct field
// name. It panics if the struct does not contain a field matching this
// name.
func (t *TableMap) ColMap(field string) *ColumnMap {
for _, col := range t.Columns {
if col.fieldName == field || col.ColumnName == field {
return col
}
}
panic(fmt.Sprintf("No ColumnMap in table %s type %s with field %s",
t.TableName, t.gotype.Name(), field))
}
// SetVersionCol sets the column to use as the Version field. By default
// the "Version" field is used. Returns the column found, or panics
// if the struct does not contain a field matching this name.
//
// Automatically calls ResetSql() to ensure SQL statements are regenerated.
func (t *TableMap) SetVersionCol(field string) *ColumnMap {
c := t.ColMap(field)
t.version = c
t.ResetSql()
return c
}
func (t *TableMap) bindGet() bindPlan {
plan := t.getPlan
if plan.query == "" {
s := bytes.Buffer{}
s.WriteString("select ")
x := 0
for _, col := range t.Columns {
if !col.Transient {
if x > 0 {
s.WriteString(",")
}
s.WriteString(t.dbmap.Dialect.QuoteField(col.ColumnName))
plan.argFields = append(plan.argFields, col.fieldName)
x++
}
}
s.WriteString(" from ")
s.WriteString(t.dbmap.Dialect.QuoteField(t.TableName))
s.WriteString(" where ")
for x := range t.Keys {
col := t.Keys[x]
if x > 0 {
s.WriteString(" and ")
}
s.WriteString(t.dbmap.Dialect.QuoteField(col.ColumnName))
s.WriteString("=")
s.WriteString(t.dbmap.Dialect.BindVar(x))
plan.keyFields = append(plan.keyFields, col.fieldName)
}
s.WriteString(";")
plan.query = s.String()
t.getPlan = plan
}
return plan
}
func (t *TableMap) bindDelete(elem reflect.Value) bindInstance {
plan := t.deletePlan
if plan.query == "" {
s := bytes.Buffer{}
s.WriteString(fmt.Sprintf("delete from %s", t.dbmap.Dialect.QuoteField(t.TableName)))
for y := range t.Columns {
col := t.Columns[y]
if !col.Transient {
if col == t.version {
plan.versField = col.fieldName
}
}
}
s.WriteString(" where ")
for x := range t.Keys {
k := t.Keys[x]
if x > 0 {
s.WriteString(" and ")
}
s.WriteString(t.dbmap.Dialect.QuoteField(k.ColumnName))
s.WriteString("=")
s.WriteString(t.dbmap.Dialect.BindVar(x))
plan.keyFields = append(plan.keyFields, k.fieldName)
plan.argFields = append(plan.argFields, k.fieldName)
}
if plan.versField != "" {
s.WriteString(" and ")
s.WriteString(t.dbmap.Dialect.QuoteField(t.version.ColumnName))
s.WriteString("=")
s.WriteString(t.dbmap.Dialect.BindVar(len(plan.argFields)))
plan.argFields = append(plan.argFields, plan.versField)
}
s.WriteString(";")
plan.query = s.String()
t.deletePlan = plan
}
return plan.createBindInstance(elem)
}
func (t *TableMap) bindUpdate(elem reflect.Value) bindInstance {
plan := t.updatePlan
if plan.query == "" {
s := bytes.Buffer{}
s.WriteString(fmt.Sprintf("update %s set ", t.dbmap.Dialect.QuoteField(t.TableName)))
x := 0
for y := range t.Columns {
col := t.Columns[y]
if !col.isPK && !col.Transient {
if x > 0 {
s.WriteString(", ")
}
s.WriteString(t.dbmap.Dialect.QuoteField(col.ColumnName))
s.WriteString("=")
s.WriteString(t.dbmap.Dialect.BindVar(x))
if col == t.version {
plan.versField = col.fieldName
plan.argFields = append(plan.argFields, versFieldConst)
} else {
plan.argFields = append(plan.argFields, col.fieldName)
}
x++
}
}
s.WriteString(" where ")
for y := range t.Keys {
col := t.Keys[y]
if y > 0 {
s.WriteString(" and ")
}
s.WriteString(t.dbmap.Dialect.QuoteField(col.ColumnName))
s.WriteString("=")
s.WriteString(t.dbmap.Dialect.BindVar(x))
plan.argFields = append(plan.argFields, col.fieldName)
plan.keyFields = append(plan.keyFields, col.fieldName)
x++
}
if plan.versField != "" {
s.WriteString(" and ")
s.WriteString(t.dbmap.Dialect.QuoteField(t.version.ColumnName))
s.WriteString("=")
s.WriteString(t.dbmap.Dialect.BindVar(x))
plan.argFields = append(plan.argFields, plan.versField)
}
s.WriteString(";")
plan.query = s.String()
t.updatePlan = plan
}
return plan.createBindInstance(elem)
}
func (t *TableMap) bindInsert(elem reflect.Value) bindInstance {
plan := t.insertPlan
if plan.query == "" {
plan.autoIncrIdx = -1
s := bytes.Buffer{}
s2 := bytes.Buffer{}
s.WriteString(fmt.Sprintf("insert into %s (", t.dbmap.Dialect.QuoteField(t.TableName)))
x := 0
first := true
for y := range t.Columns {
col := t.Columns[y]
if !col.Transient {
if !first {
s.WriteString(",")
s2.WriteString(",")
}
s.WriteString(t.dbmap.Dialect.QuoteField(col.ColumnName))
if col.isAutoIncr {
s2.WriteString(t.dbmap.Dialect.AutoIncrBindValue())
plan.autoIncrIdx = y
} else {
s2.WriteString(t.dbmap.Dialect.BindVar(x))
if col == t.version {
plan.versField = col.fieldName
plan.argFields = append(plan.argFields, versFieldConst)
} else {
plan.argFields = append(plan.argFields, col.fieldName)
}
x++
}
first = false
}
}
s.WriteString(") values (")
s.WriteString(s2.String())
s.WriteString(")")
if plan.autoIncrIdx > -1 {
s.WriteString(t.dbmap.Dialect.AutoIncrInsertSuffix(t.Columns[plan.autoIncrIdx]))
}
s.WriteString(";")
plan.query = s.String()
t.insertPlan = plan
}
return plan.createBindInstance(elem)
}
// ColumnMap represents a mapping between a Go struct field and a single
// column in a table.
// Unique and MaxSize only inform the CreateTables() function and are not
// used for validation by Insert/Update/Delete/Get.
type ColumnMap struct {
// Column name in db table
ColumnName string
// If true, this column is skipped in generated SQL statements
Transient bool
// If true, " unique" is added to create table statements.
Unique bool
// Passed to Dialect.ToSqlType() to assist in informing the
// correct column type to map to in CreateTables()
MaxSize int
// the table this column belongs to
table *TableMap
fieldName string
gotype reflect.Type
sqltype string
createSql string
isPK bool
isAutoIncr bool
}
// SetTransient allows you to mark the column as transient. If true
// this column will be skipped when SQL statements are generated
func (c *ColumnMap) SetTransient(b bool) *ColumnMap {
c.Transient = b
return c
}
// SetUnique sets the unqiue clause for this column. If true, a unique clause
// will be added to create table statements for this column.
func (c *ColumnMap) SetUnique(b bool) *ColumnMap {
c.Unique = b
return c
}
// SetSqlCreate overrides the default create statement used when this column
// is created by CreateTable. This will override all other options (like
// SetMaxSize, SetSqlType, etc). To unset, call with the empty string.
func (c *ColumnMap) SetSqlCreate(s string) *ColumnMap {
c.createSql = s
return c
}
// SetSqlType sets an override for the column's sql type. This is a string,
// such as 'varchar(32)' or 'text', which will be used by CreateTable and
// nothing else. It is the caller's responsibility to ensure this will map
// cleanly to the underlying struct field via rows.Scan
func (c *ColumnMap) SetSqlType(t string) *ColumnMap {
c.sqltype = t
return c
}
// SetMaxSize specifies the max length of values of this column. This is
// passed to the dialect.ToSqlType() function, which can use the value
// to alter the generated type for "create table" statements
func (c *ColumnMap) SetMaxSize(size int) *ColumnMap {
c.MaxSize = size
return c
}
// Return a table for a pointer; error if i is not a pointer or if the
// table is not found
func tableForPointer(m *DbMap, i interface{}, checkPk bool) (*TableMap, reflect.Value, error) {
v := reflect.ValueOf(i)
if v.Kind() != reflect.Ptr {
return nil, v, fmt.Errorf("value %v not a pointer", v)
}
v = v.Elem()
t := m.TableForType(v.Type())
if t == nil {
return nil, v, fmt.Errorf("could not find table for %v", t)
}
if checkPk && len(t.Keys) < 1 {
return t, v, &NoKeysErr{t}
}
return t, v, nil
}

69
Godeps/_workspace/src/github.com/jmoiron/modl/test-all generated vendored Normal file
View file

@ -0,0 +1,69 @@
#!/bin/bash
#
# To use this script, set the following environment variables:
#
# MODL_MYSQL_DSN - mysql connect DSN, like "modltest/modltest/modltest"
# MODL_POSTGRES_DSN - postgres connect DSN, eg:
# "username=modltest password=modltest dbname=modltest ssl-mode=disable"
# MODL_SQLITE_DSN - sqlite connect DSN, which is a path to a sqlite file.
# MODL_FAIL_ON_SKIP - optional, will fail if any DBs are skipped (for CI, mostly)
#
# In addition to this, you can create an `environ` file in this directory which
# will be sourced and ignored by git.
#
# Additional arguments to test-all will be used after go-test, so to run
# the benchmark suite on all dbs you can do:
# ./test-all -bench=. -benchmem
#
if [ -f "./environ" ]; then
. ./environ
fi
function exit_on_error {
if [ $1 != 0 ]; then
exit $1
fi
}
# set -e
if [ -n "$MODL_MYSQL_DSN" ]; then
export MODL_TEST_DSN="$MODL_MYSQL_DSN"
export MODL_TEST_DIALECT="mysql"
echo "Testing MySQL"
go test $@
exit_on_error $?
else
echo "Skipping MySQL, \$MODL_MYSQL_DSN=$MODL_MYSQL_DSN"
if [ -n "$MODL_FAIL_ON_SKIP" ]; then
exit -1
fi
fi
if [ -n "$MODL_POSTGRES_DSN" ]; then
export MODL_TEST_DSN="$MODL_POSTGRES_DSN"
export MODL_TEST_DIALECT="postgres"
echo "Testing PostgreSQL"
go test $@
exit_on_error $?
else
echo "Skipping PostgreSQL, \$MODL_POSTGRES_DSN=$MODL_POSTGRES_DSN"
if [ -n "$MODL_FAIL_ON_SKIP" ]; then
exit -1
fi
fi
if [ -n "$MODL_SQLITE_DSN" ]; then
export MODL_TEST_DSN="$MODL_SQLITE_DSN"
export MODL_TEST_DIALECT="sqlite"
echo "Testing SQLite"
go test $@
exit_on_error $?
else
echo "Skipping SQLite, \$MODL_SQLITE_DSN=$MODL_SQLITE_DSN"
if [ -n "$MODL_FAIL_ON_SKIP" ]; then
exit -1
fi
fi

25
Godeps/_workspace/src/github.com/jmoiron/modl/todo.md generated vendored Normal file
View file

@ -0,0 +1,25 @@
Future:
- some way to designate a column as a foreign key
Todo:
- benchmarks that can compare mainline gorp to this fork
- cache/store as much reflect stuff as possible
- add query builder
- update docs with new examples
- add better interfaces to control underlying types to TableMap
In Progress:
(Both of these need tests)
- alter schema creation to take advantage of ColMap.sqltype
- alter schema creation to be able to return output (so people can look at or inspect it)
Done:
- remove list & new struct support form in favor of filling pointers and slices
- replace reflect struct filling with structscan from sqlx
- use strings.ToLower on table & field names by default, aligning behavior w/ sqlx
- replace hook calling process with one that uses interfaces

View file

@ -0,0 +1,67 @@
package modl
import (
"database/sql"
"github.com/jmoiron/sqlx"
)
// Transaction represents a database transaction.
// Insert/Update/Delete/Get/Exec operations will be run in the context
// of that transaction. Transactions should be terminated with
// a call to Commit() or Rollback()
type Transaction struct {
dbmap *DbMap
Tx *sqlx.Tx
}
// Insert has the same behavior as DbMap.Insert(), but runs in a transaction.
func (t *Transaction) Insert(list ...interface{}) error {
return insert(t.dbmap, t, list...)
}
// Update has the same behavior as DbMap.Update(), but runs in a transaction.
func (t *Transaction) Update(list ...interface{}) (int64, error) {
return update(t.dbmap, t, list...)
}
// Delete has the same behavior as DbMap.Delete(), but runs in a transaction.
func (t *Transaction) Delete(list ...interface{}) (int64, error) {
return deletes(t.dbmap, t, list...)
}
// Get has the Same behavior as DbMap.Get(), but runs in a transaction.
func (t *Transaction) Get(dest interface{}, keys ...interface{}) error {
return get(t.dbmap, t, dest, keys...)
}
// Select has the Same behavior as DbMap.Select(), but runs in a transaction.
func (t *Transaction) Select(dest interface{}, query string, args ...interface{}) error {
return hookedselect(t.dbmap, t, dest, query, args...)
}
func (t *Transaction) SelectOne(dest interface{}, query string, args ...interface{}) error {
return hookedget(t.dbmap, t, dest, query, args...)
}
// Exec has the same behavior as DbMap.Exec(), but runs in a transaction.
func (t *Transaction) Exec(query string, args ...interface{}) (sql.Result, error) {
t.dbmap.trace(query, args)
return t.Tx.Exec(query, args...)
}
// Commit commits the underlying database transaction.
func (t *Transaction) Commit() error {
t.dbmap.trace("commit;")
return t.Tx.Commit()
}
// Rollback rolls back the underlying database transaction.
func (t *Transaction) Rollback() error {
t.dbmap.trace("rollback;")
return t.Tx.Rollback()
}
func (t *Transaction) handle() handle {
return &tracingHandle{h: t.Tx, d: t.dbmap}
}

View file

@ -0,0 +1,24 @@
# Compiled Object files, Static and Dynamic libs (Shared Objects)
*.o
*.a
*.so
# Folders
_obj
_test
# Architecture specific extensions/prefixes
*.[568vq]
[568vq].out
*.cgo1.go
*.cgo2.c
_cgo_defun.c
_cgo_gotypes.go
_cgo_export.*
_testmain.go
*.exe
tags
environ

23
Godeps/_workspace/src/github.com/jmoiron/sqlx/LICENSE generated vendored Normal file
View file

@ -0,0 +1,23 @@
Copyright (c) 2013, Jason Moiron
Permission is hereby granted, free of charge, to any person
obtaining a copy of this software and associated documentation
files (the "Software"), to deal in the Software without
restriction, including without limitation the rights to use,
copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the
Software is furnished to do so, subject to the following
conditions:
The above copyright notice and this permission notice shall be
included in all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES
OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT
HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,
WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
OTHER DEALINGS IN THE SOFTWARE.

258
Godeps/_workspace/src/github.com/jmoiron/sqlx/README.md generated vendored Normal file
View file

@ -0,0 +1,258 @@
#sqlx
[![Build Status](https://drone.io/github.com/jmoiron/sqlx/status.png)](https://drone.io/github.com/jmoiron/sqlx/latest) [![Godoc](http://img.shields.io/badge/godoc-reference-blue.svg?style=flat)](https://godoc.org/github.com/jmoiron/sqlx) [![license](http://img.shields.io/badge/license-MIT-red.svg?style=flat)](https://raw.githubusercontent.com/jmoiron/sqlx/master/LICENSE)
sqlx is a library which provides a set of extensions on go's standard
`database/sql` library. The sqlx versions of `sql.DB`, `sql.TX`, `sql.Stmt`,
et al. all leave the underlying interfaces untouched, so that their interfaces
are a superset on the standard ones. This makes it relatively painless to
integrate existing codebases using database/sql with sqlx.
Major additional concepts are:
* Marshal rows into structs (with embedded struct support), maps, and slices
* Named parameter support including prepared statements
* `Get` and `Select` to go quickly from query to struct/slice
* `LoadFile` for executing statements from a file
There is now some [fairly comprehensive documentation](http://jmoiron.github.io/sqlx/) for sqlx.
You can also read the usage below for a quick sample on how sqlx works, or check out the [API
documentation on godoc](http://godoc.org/github.com/jmoiron/sqlx).
## Recent Changes
The ability to use basic types as Select and Get destinations was added. This
is only valid when there is one column in the result set, and both functions
return an error if this isn't the case. This allows for much simpler patterns
of access for single column results:
```go
var count int
err := db.Get(&count, "SELECT count(*) FROM person;")
var names []string
err := db.Select(&names, "SELECT name FROM person;")
```
See the note on Scannability at the bottom of this README for some more info.
### Backwards Compatibility
There is no Go1-like promise of absolute stability, but I take the issue
seriously and will maintain the library in a compatible state unless vital
bugs prevent me from doing so. Since [#59](https://github.com/jmoiron/sqlx/issues/59) and [#60](https://github.com/jmoiron/sqlx/issues/60) necessitated
breaking behavior, a wider API cleanup was done at the time of fixing.
## install
go get github.com/jmoiron/sqlx
## issues
Row headers can be ambiguous (`SELECT 1 AS a, 2 AS a`), and the result of
`Columns()` can have duplicate names on queries like:
```sql
SELECT a.id, a.name, b.id, b.name FROM foos AS a JOIN foos AS b ON a.parent = b.id;
```
making a struct or map destination ambiguous. Use `AS` in your queries
to give rows distinct names, `rows.Scan` to scan them manually, or
`SliceScan` to get a slice of results.
## usage
Below is an example which shows some common use cases for sqlx. Check
[sqlx_test.go](https://github.com/jmoiron/sqlx/blob/master/sqlx_test.go) for more
usage.
```go
package main
import (
_ "github.com/lib/pq"
"database/sql"
"github.com/jmoiron/sqlx"
"log"
)
var schema = `
CREATE TABLE person (
first_name text,
last_name text,
email text
);
CREATE TABLE place (
country text,
city text NULL,
telcode integer
)`
type Person struct {
FirstName string `db:"first_name"`
LastName string `db:"last_name"`
Email string
}
type Place struct {
Country string
City sql.NullString
TelCode int
}
func main() {
// this connects & tries a simple 'SELECT 1', panics on error
// use sqlx.Open() for sql.Open() semantics
db, err := sqlx.Connect("postgres", "user=foo dbname=bar sslmode=disable")
if err != nil {
log.Fatalln(err)
}
// exec the schema or fail; multi-statement Exec behavior varies between
// database drivers; pq will exec them all, sqlite3 won't, ymmv
db.MustExec(schema)
tx := db.MustBegin()
tx.MustExec("INSERT INTO person (first_name, last_name, email) VALUES ($1, $2, $3)", "Jason", "Moiron", "jmoiron@jmoiron.net")
tx.MustExec("INSERT INTO person (first_name, last_name, email) VALUES ($1, $2, $3)", "John", "Doe", "johndoeDNE@gmail.net")
tx.MustExec("INSERT INTO place (country, city, telcode) VALUES ($1, $2, $3)", "United States", "New York", "1")
tx.MustExec("INSERT INTO place (country, telcode) VALUES ($1, $2)", "Hong Kong", "852")
tx.MustExec("INSERT INTO place (country, telcode) VALUES ($1, $2)", "Singapore", "65")
// Named queries can use structs, so if you have an existing struct (i.e. person := &Person{}) that you have populated, you can pass it in as &person
tx.NamedExec("INSERT INTO person (first_name, last_name, email) VALUES (:first_name, :last_name, :email)", &Person{"Jane", "Citizen", "jane.citzen@example.com"})
tx.Commit()
// Query the database, storing results in a []Person (wrapped in []interface{})
people := []Person{}
db.Select(&people, "SELECT * FROM person ORDER BY first_name ASC")
jason, john := people[0], people[1]
fmt.Printf("%#v\n%#v", jason, john)
// Person{FirstName:"Jason", LastName:"Moiron", Email:"jmoiron@jmoiron.net"}
// Person{FirstName:"John", LastName:"Doe", Email:"johndoeDNE@gmail.net"}
// You can also get a single result, a la QueryRow
jason = Person{}
err = db.Get(&jason, "SELECT * FROM person WHERE first_name=$1", "Jason")
fmt.Printf("%#v\n", jason)
// Person{FirstName:"Jason", LastName:"Moiron", Email:"jmoiron@jmoiron.net"}
// if you have null fields and use SELECT *, you must use sql.Null* in your struct
places := []Place{}
err = db.Select(&places, "SELECT * FROM place ORDER BY telcode ASC")
if err != nil {
fmt.Println(err)
return
}
usa, singsing, honkers := places[0], places[1], places[2]
fmt.Printf("%#v\n%#v\n%#v\n", usa, singsing, honkers)
// Place{Country:"United States", City:sql.NullString{String:"New York", Valid:true}, TelCode:1}
// Place{Country:"Singapore", City:sql.NullString{String:"", Valid:false}, TelCode:65}
// Place{Country:"Hong Kong", City:sql.NullString{String:"", Valid:false}, TelCode:852}
// Loop through rows using only one struct
place := Place{}
rows, err := db.Queryx("SELECT * FROM place")
for rows.Next() {
err := rows.StructScan(&place)
if err != nil {
log.Fatalln(err)
}
fmt.Printf("%#v\n", place)
}
// Place{Country:"United States", City:sql.NullString{String:"New York", Valid:true}, TelCode:1}
// Place{Country:"Hong Kong", City:sql.NullString{String:"", Valid:false}, TelCode:852}
// Place{Country:"Singapore", City:sql.NullString{String:"", Valid:false}, TelCode:65}
// Named queries, using `:name` as the bindvar. Automatic bindvar support
// which takes into account the dbtype based on the driverName on sqlx.Open/Connect
_, err = db.NamedExec(`INSERT INTO person (first_name,last_name,email) VALUES (:first,:last,:email)`,
map[string]interface{}{
"first": "Bin",
"last": "Smuth",
"email": "bensmith@allblacks.nz",
})
// Selects Mr. Smith from the database
rows, err = db.NamedQuery(`SELECT * FROM person WHERE first_name=:fn`, map[string]interface{}{"fn": "Bin"})
// Named queries can also use structs. Their bind names follow the same rules
// as the name -> db mapping, so struct fields are lowercased and the `db` tag
// is taken into consideration.
rows, err = db.NamedQuery(`SELECT * FROM person WHERE first_name=:first_name`, jason)
}
```
## Scannability
Get and Select are able to take base types, so the following is now possible:
```go
var name string
db.Get(&name, "SELECT first_name FROM person WHERE id=$1", 10)
var ids []int64
db.Select(&ids, "SELECT id FROM person LIMIT 20;")
```
This can get complicated with destination types which are structs, like `sql.NullString`. Because of this, straightforward rules for *scannability* had to be developed. Iff something is "Scannable", then it is used directly in `rows.Scan`; if it's not, then the standard sqlx struct rules apply.
Something is scannable if any of the following are true:
* It is not a struct, ie. `reflect.ValueOf(v).Kind() != reflect.Struct`
* It implements the `sql.Scanner` interface
* It has no exported fields (eg. `time.Time`)
## embedded structs
Scan targets obey Go attribute rules directly, including nested embedded structs. Older versions of sqlx would attempt to also descend into non-embedded structs, but this is no longer supported.
Go makes *accessing* '[ambiguous selectors](http://play.golang.org/p/MGRxdjLaUc)' a compile time error, defining structs with ambiguous selectors is legal. Sqlx will decide which field to use on a struct based on a breadth first search of the struct and any structs it embeds, as specified by the order of the fields as accessible by `reflect`, which generally means in source-order. This means that sqlx chooses the outer-most, top-most matching name for targets, even when the selector might technically be ambiguous.
## scan safety
By default, scanning into structs requires the structs to have fields for all of the
columns in the query. This was done for a few reasons:
* A mistake in naming during development could lead you to believe that data is
being written to a field when actually it can't be found and it is being dropped
* This behavior mirrors the behavior of the Go compiler with respect to unused
variables
* Selecting more data than you need is wasteful (more data on the wire, more time
marshalling, etc)
Unlike Marshallers in the stdlib, the programmer scanning an sql result into a struct
will generally have a full understanding of what the underlying data model is *and*
full control over the SQL statement.
Despite this, there are use cases where it's convenient to be able to ignore unknown
columns. In most of these cases, you might be better off with `ScanSlice`, but where
you want to still use structs, there is now the `Unsafe` method. Its usage is most
simply shown in an example:
```go
db, err := sqlx.Connect("postgres", "user=foo dbname=bar sslmode=disable")
if err != nil {
log.Fatal(err)
}
type Person {
Name string
}
var p Person
// This fails, because there is no destination for location in Person
err = db.Get(&p, "SELECT name, location FROM person LIMIT 1")
udb := db.Unsafe()
// This succeeds and just sets `Name` in the p struct
err = udb.Get(&p, "SELECT name, location FROM person LIMIT 1")
```
The `Unsafe` method is implemented on `Tx`, `DB`, and `Stmt`. When you use an unsafe
`Tx` or `DB` to create a new `Tx` or `Stmt`, those inherit its lack of safety.

84
Godeps/_workspace/src/github.com/jmoiron/sqlx/bind.go generated vendored Normal file
View file

@ -0,0 +1,84 @@
package sqlx
import (
"bytes"
"strconv"
)
// Bindvar types supported by Rebind, BindMap and BindStruct.
const (
UNKNOWN = iota
QUESTION
DOLLAR
NAMED
)
// BindType returns the bindtype for a given database given a drivername.
func BindType(driverName string) int {
switch driverName {
case "postgres", "pgx":
return DOLLAR
case "mysql":
return QUESTION
case "sqlite3":
return QUESTION
case "oci8":
return NAMED
}
return UNKNOWN
}
// FIXME: this should be able to be tolerant of escaped ?'s in queries without
// losing much speed, and should be to avoid confusion.
// FIXME: this is now produces the wrong results for oracle's NAMED bindtype
// Rebind a query from the default bindtype (QUESTION) to the target bindtype.
func Rebind(bindType int, query string) string {
if bindType != DOLLAR {
return query
}
qb := []byte(query)
// Add space enough for 10 params before we have to allocate
rqb := make([]byte, 0, len(qb)+10)
j := 1
for _, b := range qb {
if b == '?' {
rqb = append(rqb, '$')
for _, b := range strconv.Itoa(j) {
rqb = append(rqb, byte(b))
}
j++
} else {
rqb = append(rqb, b)
}
}
return string(rqb)
}
// Experimental implementation of Rebind which uses a bytes.Buffer. The code is
// much simpler and should be more resistant to odd unicode, but it is twice as
// slow. Kept here for benchmarking purposes and to possibly replace Rebind if
// problems arise with its somewhat naive handling of unicode.
func rebindBuff(bindType int, query string) string {
if bindType != DOLLAR {
return query
}
b := make([]byte, 0, len(query))
rqb := bytes.NewBuffer(b)
j := 1
for _, r := range query {
if r == '?' {
rqb.WriteRune('$')
rqb.WriteString(strconv.Itoa(j))
j++
} else {
rqb.WriteRune(r)
}
}
return rqb.String()
}

12
Godeps/_workspace/src/github.com/jmoiron/sqlx/doc.go generated vendored Normal file
View file

@ -0,0 +1,12 @@
// Package sqlx provides general purpose extensions to database/sql.
//
// It is intended to seamlessly wrap database/sql and provide convenience
// methods which are useful in the development of database driven applications.
// None of the underlying database/sql methods are changed. Instead all extended
// behavior is implemented through new methods defined on wrapper types.
//
// Additions include scanning into structs, named query support, rebinding
// queries for different drivers, convenient shorthands for common error handling
// and more.
//
package sqlx

321
Godeps/_workspace/src/github.com/jmoiron/sqlx/named.go generated vendored Normal file
View file

@ -0,0 +1,321 @@
package sqlx
// Named Query Support
//
// * BindMap - bind query bindvars to map/struct args
// * NamedExec, NamedQuery - named query w/ struct or map
// * NamedStmt - a pre-compiled named query which is a prepared statement
//
// Internal Interfaces:
//
// * compileNamedQuery - rebind a named query, returning a query and list of names
// * bindArgs, bindMapArgs, bindAnyArgs - given a list of names, return an arglist
//
import (
"database/sql"
"errors"
"fmt"
"reflect"
"strconv"
"unicode"
"github.com/jmoiron/sqlx/reflectx"
)
// NamedStmt is a prepared statement that executes named queries. Prepare it
// how you would execute a NamedQuery, but pass in a struct or map when executing.
type NamedStmt struct {
Params []string
QueryString string
Stmt *Stmt
}
// Close closes the named statement.
func (n *NamedStmt) Close() error {
return n.Stmt.Close()
}
// Exec executes a named statement using the struct passed.
func (n *NamedStmt) Exec(arg interface{}) (sql.Result, error) {
args, err := bindAnyArgs(n.Params, arg, n.Stmt.Mapper)
if err != nil {
return *new(sql.Result), err
}
return n.Stmt.Exec(args...)
}
// Query executes a named statement using the struct argument, returning rows.
func (n *NamedStmt) Query(arg interface{}) (*sql.Rows, error) {
args, err := bindAnyArgs(n.Params, arg, n.Stmt.Mapper)
if err != nil {
return nil, err
}
return n.Stmt.Query(args...)
}
// QueryRow executes a named statement against the database. Because sqlx cannot
// create a *sql.Row with an error condition pre-set for binding errors, sqlx
// returns a *sqlx.Row instead.
func (n *NamedStmt) QueryRow(arg interface{}) *Row {
args, err := bindAnyArgs(n.Params, arg, n.Stmt.Mapper)
if err != nil {
return &Row{err: err}
}
return n.Stmt.QueryRowx(args...)
}
// MustExec execs a NamedStmt, panicing on error
func (n *NamedStmt) MustExec(arg interface{}) sql.Result {
res, err := n.Exec(arg)
if err != nil {
panic(err)
}
return res
}
// Queryx using this NamedStmt
func (n *NamedStmt) Queryx(arg interface{}) (*Rows, error) {
r, err := n.Query(arg)
if err != nil {
return nil, err
}
return &Rows{Rows: r, Mapper: n.Stmt.Mapper}, err
}
// QueryRowx this NamedStmt. Because of limitations with QueryRow, this is
// an alias for QueryRow.
func (n *NamedStmt) QueryRowx(arg interface{}) *Row {
return n.QueryRow(arg)
}
// Select using this NamedStmt
func (n *NamedStmt) Select(dest interface{}, arg interface{}) error {
rows, err := n.Query(arg)
if err != nil {
return err
}
// if something happens here, we want to make sure the rows are Closed
defer rows.Close()
return scanAll(rows, dest, false)
}
// Get using this NamedStmt
func (n *NamedStmt) Get(dest interface{}, arg interface{}) error {
r := n.QueryRowx(arg)
return r.scanAny(dest, false)
}
// A union interface of preparer and binder, required to be able to prepare
// named statements (as the bindtype must be determined).
type namedPreparer interface {
Preparer
binder
}
func prepareNamed(p namedPreparer, query string) (*NamedStmt, error) {
bindType := BindType(p.DriverName())
q, args, err := compileNamedQuery([]byte(query), bindType)
if err != nil {
return nil, err
}
stmt, err := Preparex(p, q)
if err != nil {
return nil, err
}
return &NamedStmt{
QueryString: q,
Params: args,
Stmt: stmt,
}, nil
}
func bindAnyArgs(names []string, arg interface{}, m *reflectx.Mapper) ([]interface{}, error) {
if maparg, ok := arg.(map[string]interface{}); ok {
return bindMapArgs(names, maparg)
}
return bindArgs(names, arg, m)
}
// private interface to generate a list of interfaces from a given struct
// type, given a list of names to pull out of the struct. Used by public
// BindStruct interface.
func bindArgs(names []string, arg interface{}, m *reflectx.Mapper) ([]interface{}, error) {
arglist := make([]interface{}, 0, len(names))
// grab the indirected value of arg
v := reflect.ValueOf(arg)
for v = reflect.ValueOf(arg); v.Kind() == reflect.Ptr; {
v = v.Elem()
}
fields := m.TraversalsByName(v.Type(), names)
for i, t := range fields {
if len(t) == 0 {
return arglist, fmt.Errorf("could not find name %s in %#v", names[i], arg)
}
val := reflectx.FieldByIndexesReadOnly(v, t)
arglist = append(arglist, val.Interface())
}
return arglist, nil
}
// like bindArgs, but for maps.
func bindMapArgs(names []string, arg map[string]interface{}) ([]interface{}, error) {
arglist := make([]interface{}, 0, len(names))
for _, name := range names {
val, ok := arg[name]
if !ok {
return arglist, fmt.Errorf("could not find name %s in %#v", name, arg)
}
arglist = append(arglist, val)
}
return arglist, nil
}
// bindStruct binds a named parameter query with fields from a struct argument.
// The rules for binding field names to parameter names follow the same
// conventions as for StructScan, including obeying the `db` struct tags.
func bindStruct(bindType int, query string, arg interface{}, m *reflectx.Mapper) (string, []interface{}, error) {
bound, names, err := compileNamedQuery([]byte(query), bindType)
if err != nil {
return "", []interface{}{}, err
}
arglist, err := bindArgs(names, arg, m)
if err != nil {
return "", []interface{}{}, err
}
return bound, arglist, nil
}
// bindMap binds a named parameter query with a map of arguments.
func bindMap(bindType int, query string, args map[string]interface{}) (string, []interface{}, error) {
bound, names, err := compileNamedQuery([]byte(query), bindType)
if err != nil {
return "", []interface{}{}, err
}
arglist, err := bindMapArgs(names, args)
return bound, arglist, err
}
// -- Compilation of Named Queries
// Allow digits and letters in bind params; additionally runes are
// checked against underscores, meaning that bind params can have be
// alphanumeric with underscores. Mind the difference between unicode
// digits and numbers, where '5' is a digit but '五' is not.
var allowedBindRunes = []*unicode.RangeTable{unicode.Letter, unicode.Digit}
// FIXME: this function isn't safe for unicode named params, as a failing test
// can testify. This is not a regression but a failure of the original code
// as well. It should be modified to range over runes in a string rather than
// bytes, even though this is less convenient and slower. Hopefully the
// addition of the prepared NamedStmt (which will only do this once) will make
// up for the slightly slower ad-hoc NamedExec/NamedQuery.
// compile a NamedQuery into an unbound query (using the '?' bindvar) and
// a list of names.
func compileNamedQuery(qs []byte, bindType int) (query string, names []string, err error) {
names = make([]string, 0, 10)
rebound := make([]byte, 0, len(qs))
inName := false
last := len(qs) - 1
currentVar := 1
name := make([]byte, 0, 10)
for i, b := range qs {
// a ':' while we're in a name is an error
if b == ':' {
// if this is the second ':' in a '::' escape sequence, append a ':'
if inName && i > 0 && qs[i-1] == ':' {
rebound = append(rebound, ':')
inName = false
continue
} else if inName {
err = errors.New("unexpected `:` while reading named param at " + strconv.Itoa(i))
return query, names, err
}
inName = true
name = []byte{}
// if we're in a name, and this is an allowed character, continue
} else if inName && (unicode.IsOneOf(allowedBindRunes, rune(b)) || b == '_') && i != last {
// append the byte to the name if we are in a name and not on the last byte
name = append(name, b)
// if we're in a name and it's not an allowed character, the name is done
} else if inName {
inName = false
// if this is the final byte of the string and it is part of the name, then
// make sure to add it to the name
if i == last && unicode.IsOneOf(allowedBindRunes, rune(b)) {
name = append(name, b)
}
// add the string representation to the names list
names = append(names, string(name))
// add a proper bindvar for the bindType
switch bindType {
// oracle only supports named type bind vars even for positional
case NAMED:
rebound = append(rebound, ':')
rebound = append(rebound, name...)
case QUESTION, UNKNOWN:
rebound = append(rebound, '?')
case DOLLAR:
rebound = append(rebound, '$')
for _, b := range strconv.Itoa(currentVar) {
rebound = append(rebound, byte(b))
}
currentVar++
}
// add this byte to string unless it was not part of the name
if i != last {
rebound = append(rebound, b)
} else if !unicode.IsOneOf(allowedBindRunes, rune(b)) {
rebound = append(rebound, b)
}
} else {
// this is a normal byte and should just go onto the rebound query
rebound = append(rebound, b)
}
}
return string(rebound), names, err
}
// Bind binds a struct or a map to a query with named parameters.
func BindNamed(bindType int, query string, arg interface{}) (string, []interface{}, error) {
return bindNamedMapper(bindType, query, arg, mapper())
}
func bindNamedMapper(bindType int, query string, arg interface{}, m *reflectx.Mapper) (string, []interface{}, error) {
if maparg, ok := arg.(map[string]interface{}); ok {
return bindMap(bindType, query, maparg)
}
return bindStruct(bindType, query, arg, m)
}
// NamedQuery binds a named query and then runs Query on the result using the
// provided Ext (sqlx.Tx, sqlx.Db). It works with both structs and with
// map[string]interface{} types.
func NamedQuery(e Ext, query string, arg interface{}) (*Rows, error) {
q, args, err := bindNamedMapper(BindType(e.DriverName()), query, arg, mapperFor(e))
if err != nil {
return nil, err
}
return e.Queryx(q, args...)
}
// NamedExec uses BindStruct to get a query executable by the driver and
// then runs Exec on the result. Returns an error from the binding
// or the query excution itself.
func NamedExec(e Ext, query string, arg interface{}) (sql.Result, error) {
q, args, err := bindNamedMapper(BindType(e.DriverName()), query, arg, mapperFor(e))
if err != nil {
return nil, err
}
return e.Exec(q, args...)
}

View file

@ -0,0 +1,227 @@
package sqlx
import (
"database/sql"
"testing"
)
func TestCompileQuery(t *testing.T) {
table := []struct {
Q, R, D, N string
V []string
}{
// basic test for named parameters, invalid char ',' terminating
{
Q: `INSERT INTO foo (a,b,c,d) VALUES (:name, :age, :first, :last)`,
R: `INSERT INTO foo (a,b,c,d) VALUES (?, ?, ?, ?)`,
D: `INSERT INTO foo (a,b,c,d) VALUES ($1, $2, $3, $4)`,
N: `INSERT INTO foo (a,b,c,d) VALUES (:name, :age, :first, :last)`,
V: []string{"name", "age", "first", "last"},
},
// This query tests a named parameter ending the string as well as numbers
{
Q: `SELECT * FROM a WHERE first_name=:name1 AND last_name=:name2`,
R: `SELECT * FROM a WHERE first_name=? AND last_name=?`,
D: `SELECT * FROM a WHERE first_name=$1 AND last_name=$2`,
N: `SELECT * FROM a WHERE first_name=:name1 AND last_name=:name2`,
V: []string{"name1", "name2"},
},
{
Q: `SELECT "::foo" FROM a WHERE first_name=:name1 AND last_name=:name2`,
R: `SELECT ":foo" FROM a WHERE first_name=? AND last_name=?`,
D: `SELECT ":foo" FROM a WHERE first_name=$1 AND last_name=$2`,
N: `SELECT ":foo" FROM a WHERE first_name=:name1 AND last_name=:name2`,
V: []string{"name1", "name2"},
},
{
Q: `SELECT 'a::b::c' || first_name, '::::ABC::_::' FROM person WHERE first_name=:first_name AND last_name=:last_name`,
R: `SELECT 'a:b:c' || first_name, '::ABC:_:' FROM person WHERE first_name=? AND last_name=?`,
D: `SELECT 'a:b:c' || first_name, '::ABC:_:' FROM person WHERE first_name=$1 AND last_name=$2`,
N: `SELECT 'a:b:c' || first_name, '::ABC:_:' FROM person WHERE first_name=:first_name AND last_name=:last_name`,
V: []string{"first_name", "last_name"},
},
/* This unicode awareness test sadly fails, because of our byte-wise worldview.
* We could certainly iterate by Rune instead, though it's a great deal slower,
* it's probably the RightWay(tm)
{
Q: `INSERT INTO foo (a,b,c,d) VALUES (:あ, :b, :キコ, :名前)`,
R: `INSERT INTO foo (a,b,c,d) VALUES (?, ?, ?, ?)`,
D: `INSERT INTO foo (a,b,c,d) VALUES ($1, $2, $3, $4)`,
N: []string{"name", "age", "first", "last"},
},
*/
}
for _, test := range table {
qr, names, err := compileNamedQuery([]byte(test.Q), QUESTION)
if err != nil {
t.Error(err)
}
if qr != test.R {
t.Errorf("expected %s, got %s", test.R, qr)
}
if len(names) != len(test.V) {
t.Errorf("expected %#v, got %#v", test.V, names)
} else {
for i, name := range names {
if name != test.V[i] {
t.Errorf("expected %dth name to be %s, got %s", i+1, test.V[i], name)
}
}
}
qd, _, _ := compileNamedQuery([]byte(test.Q), DOLLAR)
if qd != test.D {
t.Errorf("\nexpected: `%s`\ngot: `%s`", test.D, qd)
}
qq, _, _ := compileNamedQuery([]byte(test.Q), NAMED)
if qq != test.N {
t.Errorf("\nexpected: `%s`\ngot: `%s`\n(len: %d vs %d)", test.N, qq, len(test.N), len(qq))
}
}
}
type Test struct {
t *testing.T
}
func (t Test) Error(err error, msg ...interface{}) {
if err != nil {
if len(msg) == 0 {
t.t.Error(err)
} else {
t.t.Error(msg...)
}
}
}
func (t Test) Errorf(err error, format string, args ...interface{}) {
if err != nil {
t.t.Errorf(format, args...)
}
}
func TestNamedQueries(t *testing.T) {
RunWithSchema(defaultSchema, t, func(db *DB, t *testing.T) {
loadDefaultFixture(db, t)
test := Test{t}
var ns *NamedStmt
var err error
// Check that invalid preparations fail
ns, err = db.PrepareNamed("SELECT * FROM person WHERE first_name=:first:name")
if err == nil {
t.Error("Expected an error with invalid prepared statement.")
}
ns, err = db.PrepareNamed("invalid sql")
if err == nil {
t.Error("Expected an error with invalid prepared statement.")
}
// Check closing works as anticipated
ns, err = db.PrepareNamed("SELECT * FROM person WHERE first_name=:first_name")
test.Error(err)
err = ns.Close()
test.Error(err)
ns, err = db.PrepareNamed(`
SELECT first_name, last_name, email
FROM person WHERE first_name=:first_name AND email=:email`)
test.Error(err)
// test Queryx w/ uses Query
p := Person{FirstName: "Jason", LastName: "Moiron", Email: "jmoiron@jmoiron.net"}
rows, err := ns.Queryx(p)
test.Error(err)
for rows.Next() {
var p2 Person
rows.StructScan(&p2)
if p.FirstName != p2.FirstName {
t.Errorf("got %s, expected %s", p.FirstName, p2.FirstName)
}
if p.LastName != p2.LastName {
t.Errorf("got %s, expected %s", p.LastName, p2.LastName)
}
if p.Email != p2.Email {
t.Errorf("got %s, expected %s", p.Email, p2.Email)
}
}
// test Select
people := make([]Person, 0, 5)
err = ns.Select(&people, p)
test.Error(err)
if len(people) != 1 {
t.Errorf("got %d results, expected %d", len(people), 1)
}
if p.FirstName != people[0].FirstName {
t.Errorf("got %s, expected %s", p.FirstName, people[0].FirstName)
}
if p.LastName != people[0].LastName {
t.Errorf("got %s, expected %s", p.LastName, people[0].LastName)
}
if p.Email != people[0].Email {
t.Errorf("got %s, expected %s", p.Email, people[0].Email)
}
// test Exec
ns, err = db.PrepareNamed(`
INSERT INTO person (first_name, last_name, email)
VALUES (:first_name, :last_name, :email)`)
test.Error(err)
js := Person{
FirstName: "Julien",
LastName: "Savea",
Email: "jsavea@ab.co.nz",
}
_, err = ns.Exec(js)
test.Error(err)
// Make sure we can pull him out again
p2 := Person{}
db.Get(&p2, db.Rebind("SELECT * FROM person WHERE email=?"), js.Email)
if p2.Email != js.Email {
t.Errorf("expected %s, got %s", js.Email, p2.Email)
}
// test Txn NamedStmts
tx := db.MustBegin()
txns := tx.NamedStmt(ns)
// We're going to add Steven in this txn
sl := Person{
FirstName: "Steven",
LastName: "Luatua",
Email: "sluatua@ab.co.nz",
}
_, err = txns.Exec(sl)
test.Error(err)
// then rollback...
tx.Rollback()
// looking for Steven after a rollback should fail
err = db.Get(&p2, db.Rebind("SELECT * FROM person WHERE email=?"), sl.Email)
if err != sql.ErrNoRows {
t.Errorf("expected no rows error, got %v", err)
}
// now do the same, but commit
tx = db.MustBegin()
txns = tx.NamedStmt(ns)
_, err = txns.Exec(sl)
test.Error(err)
tx.Commit()
// looking for Steven after a Commit should succeed
err = db.Get(&p2, db.Rebind("SELECT * FROM person WHERE email=?"), sl.Email)
test.Error(err)
if p2.Email != sl.Email {
t.Errorf("expected %s, got %s", sl.Email, p2.Email)
}
})
}

View file

@ -0,0 +1,17 @@
# reflectx
The sqlx package has special reflect needs. In particular, it needs to:
* be able to map a name to a field
* understand embedded structs
* understand mapping names to fields by a particular tag
* user specified name -> field mapping functions
These behaviors mimic the behaviors by the standard library marshallers and also the
behavior of standard Go accessors.
The first two are amply taken care of by `Reflect.Value.FieldByName`, and the third is
addressed by `Reflect.Value.FieldByNameFunc`, but these don't quite understand struct
tags in the ways that are vital to most marshalers, and they are slow.
This reflectx package extends reflect to achieve these goals.

View file

@ -0,0 +1,268 @@
// Package reflect implements extensions to the standard reflect lib suitable
// for implementing marshaling and unmarshaling packages. The main Mapper type
// allows for Go-compatible named atribute access, including accessing embedded
// struct attributes and the ability to use functions and struct tags to
// customize field names.
//
package reflectx
import "sync"
import (
"reflect"
"runtime"
)
type fieldMap map[string][]int
// Mapper is a general purpose mapper of names to struct fields. A Mapper
// behaves like most marshallers, optionally obeying a field tag for name
// mapping and a function to provide a basic mapping of fields to names.
type Mapper struct {
cache map[reflect.Type]fieldMap
tagName string
tagMapFunc func(string) string
mapFunc func(string) string
mutex sync.Mutex
}
// NewMapper returns a new mapper which optionally obeys the field tag given
// by tagName. If tagName is the empty string, it is ignored.
func NewMapper(tagName string) *Mapper {
return &Mapper{
cache: make(map[reflect.Type]fieldMap),
tagName: tagName,
}
}
// NewMapperTagFunc returns a new mapper which contains a mapper for field names
// AND a mapper for tag values. This is useful for tags like json which can
// have values like "name,omitempty".
func NewMapperTagFunc(tagName string, mapFunc, tagMapFunc func(string) string) *Mapper {
return &Mapper{
cache: make(map[reflect.Type]fieldMap),
tagName: tagName,
mapFunc: mapFunc,
tagMapFunc: tagMapFunc,
}
}
// NewMapperFunc returns a new mapper which optionally obeys a field tag and
// a struct field name mapper func given by f. Tags will take precedence, but
// for any other field, the mapped name will be f(field.Name)
func NewMapperFunc(tagName string, f func(string) string) *Mapper {
return &Mapper{
cache: make(map[reflect.Type]fieldMap),
tagName: tagName,
mapFunc: f,
}
}
// TypeMap returns a mapping of field strings to int slices representing
// the traversal down the struct to reach the field.
func (m *Mapper) TypeMap(t reflect.Type) fieldMap {
m.mutex.Lock()
mapping, ok := m.cache[t]
if !ok {
mapping = getMapping(t, m.tagName, m.mapFunc, m.tagMapFunc)
m.cache[t] = mapping
}
m.mutex.Unlock()
return mapping
}
// FieldMap returns the mapper's mapping of field names to reflect values. Panics
// if v's Kind is not Struct, or v is not Indirectable to a struct kind.
func (m *Mapper) FieldMap(v reflect.Value) map[string]reflect.Value {
v = reflect.Indirect(v)
mustBe(v, reflect.Struct)
r := map[string]reflect.Value{}
nm := m.TypeMap(v.Type())
for tagName, indexes := range nm {
r[tagName] = FieldByIndexes(v, indexes)
}
return r
}
// FieldByName returns a field by the its mapped name as a reflect.Value.
// Panics if v's Kind is not Struct or v is not Indirectable to a struct Kind.
// Returns zero Value if the name is not found.
func (m *Mapper) FieldByName(v reflect.Value, name string) reflect.Value {
v = reflect.Indirect(v)
mustBe(v, reflect.Struct)
nm := m.TypeMap(v.Type())
traversal, ok := nm[name]
if !ok {
return *new(reflect.Value)
}
return FieldByIndexes(v, traversal)
}
// FieldsByName returns a slice of values corresponding to the slice of names
// for the value. Panics if v's Kind is not Struct or v is not Indirectable
// to a struct Kind. Returns zero Value for each name not found.
func (m *Mapper) FieldsByName(v reflect.Value, names []string) []reflect.Value {
v = reflect.Indirect(v)
mustBe(v, reflect.Struct)
nm := m.TypeMap(v.Type())
vals := make([]reflect.Value, 0, len(names))
for _, name := range names {
traversal, ok := nm[name]
if !ok {
vals = append(vals, *new(reflect.Value))
} else {
vals = append(vals, FieldByIndexes(v, traversal))
}
}
return vals
}
// Traversals by name returns a slice of int slices which represent the struct
// traversals for each mapped name. Panics if t is not a struct or Indirectable
// to a struct. Returns empty int slice for each name not found.
func (m *Mapper) TraversalsByName(t reflect.Type, names []string) [][]int {
t = Deref(t)
mustBe(t, reflect.Struct)
nm := m.TypeMap(t)
r := make([][]int, 0, len(names))
for _, name := range names {
traversal, ok := nm[name]
if !ok {
r = append(r, []int{})
} else {
r = append(r, traversal)
}
}
return r
}
// FieldByIndexes returns a value for a particular struct traversal.
func FieldByIndexes(v reflect.Value, indexes []int) reflect.Value {
for _, i := range indexes {
v = reflect.Indirect(v).Field(i)
// if this is a pointer, it's possible it is nil
if v.Kind() == reflect.Ptr && v.IsNil() {
alloc := reflect.New(Deref(v.Type()))
v.Set(alloc)
}
if v.Kind() == reflect.Map && v.IsNil() {
v.Set(reflect.MakeMap(v.Type()))
}
}
return v
}
// FieldByIndexesReadOnly returns a value for a particular struct traversal,
// but is not concerned with allocating nil pointers because the value is
// going to be used for reading and not setting.
func FieldByIndexesReadOnly(v reflect.Value, indexes []int) reflect.Value {
for _, i := range indexes {
v = reflect.Indirect(v).Field(i)
}
return v
}
// Deref is Indirect for reflect.Types
func Deref(t reflect.Type) reflect.Type {
if t.Kind() == reflect.Ptr {
t = t.Elem()
}
return t
}
// -- helpers & utilities --
type Kinder interface {
Kind() reflect.Kind
}
// mustBe checks a value against a kind, panicing with a reflect.ValueError
// if the kind isn't that which is required.
func mustBe(v Kinder, expected reflect.Kind) {
k := v.Kind()
if k != expected {
panic(&reflect.ValueError{Method: methodName(), Kind: k})
}
}
// methodName is returns the caller of the function calling methodName
func methodName() string {
pc, _, _, _ := runtime.Caller(2)
f := runtime.FuncForPC(pc)
if f == nil {
return "unknown method"
}
return f.Name()
}
type typeQueue struct {
t reflect.Type
p []int
}
// A copying append that creates a new slice each time.
func apnd(is []int, i int) []int {
x := make([]int, len(is)+1)
for p, n := range is {
x[p] = n
}
x[len(x)-1] = i
return x
}
// getMapping returns a mapping for the t type, using the tagName, mapFunc and
// tagMapFunc to determine the canonical names of fields.
func getMapping(t reflect.Type, tagName string, mapFunc, tagMapFunc func(string) string) fieldMap {
queue := []typeQueue{}
queue = append(queue, typeQueue{Deref(t), []int{}})
m := fieldMap{}
for len(queue) != 0 {
// pop the first item off of the queue
tq := queue[0]
queue = queue[1:]
// iterate through all of its fields
for fieldPos := 0; fieldPos < tq.t.NumField(); fieldPos++ {
f := tq.t.Field(fieldPos)
name := f.Tag.Get(tagName)
if len(name) == 0 {
if mapFunc != nil {
name = mapFunc(f.Name)
} else {
name = f.Name
}
} else if tagMapFunc != nil {
name = tagMapFunc(name)
}
// if the name is "-", disabled via a tag, skip it
if name == "-" {
continue
}
// skip unexported fields
if len(f.PkgPath) != 0 {
continue
}
// bfs search of anonymous embedded structs
if f.Anonymous {
queue = append(queue, typeQueue{Deref(f.Type), apnd(tq.p, fieldPos)})
continue
}
// if the name is shadowed by an earlier identical name in the search, skip it
if _, ok := m[name]; ok {
continue
}
// add it to the map at the current position
m[name] = apnd(tq.p, fieldPos)
}
}
return m
}

Some files were not shown because too many files have changed in this diff Show more