# HG changeset patch # User Sean Russell # Date 1397532124 14400 # Mon Apr 14 23:22:04 2014 -0400 # Node ID e34b61ba5769175ab108fc030e5a445c0d793fea # Parent 5d1797efb606fab776ce8863cd21167414542464 Adds support for produce/consume mime type specification diff --git a/api.go b/api.go --- a/api.go +++ b/api.go @@ -106,9 +106,9 @@ case "byte": return "[]byte" case "date": - return "date.Time" + return "time.Time" case "date-time": - return "date.Time" + return "time.Time" } case "boolean": case "array": diff --git a/generator.go b/generator.go --- a/generator.go +++ b/generator.go @@ -4,6 +4,8 @@ "bytes" "fmt" "strings" + "github.com/glenn-brown/skiplist" + "sync" ) var swaggerToGo map[string]string @@ -13,27 +15,40 @@ if e != nil { return e } - out <- fmt.Sprintf("package %s", pack) - out <- "" - // FIXME: only import packages that are used - out <- "import (" - out <- "\t\"bytes\"" - out <- "\t\"encoding/json\"" - out <- "\t\"errors\"" - out <- "\t\"fmt\"" - out <- "\t\"io\"" - out <- "\t\"io/ioutil\"" - out <- "\t\"mime/multipart\"" - out <- "\t\"net/http\"" - out <- "\t\"os\"" - out <- ")" - out <- "" + outWrapper := make(chan string) + wait := sync.WaitGroup{} + wait.Add(1) + go func() { + lines := make([]string, 0, 100) + imports := skiplist.New() + imports.Set("import \"fmt\"", true) + imports.Set("import \"net/http\"", true) + for line := range outWrapper { + if strings.HasPrefix(line, "import ") { + imports.Set(line, true) + } else { + lines = append(lines, line) + } + } + out <- fmt.Sprintf("package %s", pack) + out <- "" + for f := imports.Front(); f != nil; f = f.Next() { + out <- f.Key().(string) + } + out <- "" + for _, line := range lines { + out <- line + } + wait.Done() + }() switch t.(type) { case ApiDecl: - generateApi(t.(ApiDecl), out) + generateApi(t.(ApiDecl), outWrapper) case ResourceListing: - generateResourceListing(t.(ResourceListing), out) + generateResourceListing(t.(ResourceListing), outWrapper) } + close(outWrapper) + wait.Wait() return nil } @@ -64,6 +79,11 @@ if pv.Description != "" { out <- fmt.Sprintf("\t// %s", pv.Description) } + if strings.HasPrefix(pv.typeOf(), "time.") { + out <- "import \"time\"" + } else if strings.HasPrefix(pv.typeOf(), "*os.") { + out <- "import \"os\"" + } out <- fmt.Sprintf("\t%s %s", upcase(pk), pv.typeOf()) } out <- "}" @@ -88,7 +108,7 @@ func generateFunc(api Api, op Operation, out chan string) { fname := upcase(op.Nickname) - s := fmt.Sprintf("func %s(%s) %s {", fname, genArgs(op), genRv(op)) + s := fmt.Sprintf("func %s(%s) %s {", fname, genArgs(op, out), genRv(op)) out <- s switch op.Method { case "DELETE": @@ -106,19 +126,21 @@ out <- "}" } -func genArrayToString(p Parameter) string { +func genArrayToString(p Parameter, out chan string) string { rv := bytes.Buffer{} xs := fmt.Sprintf("%ss", p.Name) rv.WriteString(fmt.Sprintf("\t%s := []string{}", xs)) rv.WriteString(fmt.Sprintf("\tfor _, v := range %s {", p.Name)) rv.WriteString(fmt.Sprintf("\t\t%s = append(%s, toString(v))", xs, xs)) rv.WriteString("\t}") + out <- "import \"strings\"" rv.WriteString(fmt.Sprintf("\t%sString := strings.Join(%s, \"%2C\")", p.Name, xs)) return string(rv.Bytes()) } func genDeserialize(op Operation, out chan string) { out <- "\tdefer resp.Body.Close()" + out <- "import \"io/ioutil\"" out <- "\ts, e := ioutil.ReadAll(resp.Body)" out <- "\tif e != nil {" out <- fmt.Sprintf("\t\treturn %s, e", op.zero()) @@ -127,15 +149,17 @@ case "string": out <- "\treturn s, nil" case "int": + out <- "import \"strconv\"" out <- "\treturn strconv.Atoi(s)" case "int64": + out <- "import \"strconv\"" out <- "\tk, e := strconv.ParseInt(s, 10, 64)" out <- "\tif e != nil {" out <- fmt.Sprintf("\t\treturn %s, e", op.zero()) out <- "\t}" out <- "\treturn k, nil" default: - out <- "\te = json.Unmarshal(s, &rv)" + encoding(op.Produces, "\te = %s.Unmarshal(s, &rv)", out) out <- "\tif e != nil {" out <- fmt.Sprintf("\t\treturn %s, e", op.zero()) out <- "\t}" @@ -143,12 +167,42 @@ } } -func genArgs(op Operation) string { +func encoding(items []string, format string, out chan string) { + hasJson := false + hasXml := false + for _, prod := range items { + if strings.HasSuffix(prod, "json") { + hasJson = true + } else if strings.HasSuffix(prod, "xml"){ + hasXml = true + } + } + if len(items) == 0 { + hasJson = true + } + if !(hasJson || hasXml) { + panic("API must support either XML or JSON") + } + if hasJson { + out <- "import \"encoding/json\"" + out <- fmt.Sprintf(format, "json") + } else { + out <- "import \"encoding/xml\"" + out <- fmt.Sprintf(format, "xml") + } +} + +func genArgs(op Operation, out chan string) string { b := bytes.Buffer{} np := len(op.Parameters) - 1 for i, p := range op.Parameters { b.WriteString(p.Name) b.WriteString(" ") + if strings.HasPrefix(p.typeOf(), "time.") { + out <- "import \"time\"" + } else if strings.HasPrefix(p.typeOf(), "*os.") { + out <- "import \"os\"" + } b.WriteString(p.typeOf()) if i < np { b.WriteString(", ") diff --git a/get.go b/get.go --- a/get.go +++ b/get.go @@ -18,11 +18,13 @@ for _, p := range op.Parameters { if p.Minimum != "" { out <- fmt.Sprintf("\tif %s < %s(%s) {", p.Name, p.typeOf(), p.Minimum) + out <- "import \"errors\"" out <- fmt.Sprintf("\t\treturn %serrors.New(fmt.Sprintf(\"invalid value (%%d < %s)\", %s))", rv, p.Minimum, p.Name) out <- "\t}" } if p.Maximum != "" { out <- fmt.Sprintf("\tif %s > %s(%s) {", p.Name, p.typeOf(), p.Maximum) + out <- "import \"errors\"" out <- fmt.Sprintf("\t\treturn %serrors.New(fmt.Sprintf(\"invalid value (%%d > %s)\", %s))", rv, p.Maximum, p.Name) out <- "\t}" } @@ -33,7 +35,7 @@ } else { v := p.Name if p.Type == "array" { - v = genArrayToString(p) + v = genArrayToString(p, out) } params[p.Name] = v } @@ -59,17 +61,42 @@ out <- url out <- "\tclient := &http.Client{}" out <- fmt.Sprintf("\treq, err := http.NewRequest(%#v, url, nil)", op.Method) + addProduce(op, out) out <- "\tresp, err := client.Do(req)" out <- "\tif err != nil {" out <- fmt.Sprintf("\t\treturn %serr", zero) out <- "\t}" for _, c := range op.ResponseMessages { out <- fmt.Sprintf("\tif resp.StatusCode == %d {", c.Code) + out <- "import \"errors\"" out <- fmt.Sprintf("\t\treturn %serrors.New(%#v)", rv, c.Message) out <- "\t}" } } +func addProduce(op Operation, out chan string) { + hasJson := false + hasXml := false + for _, prod := range op.Produces { + if strings.HasSuffix(prod, "json") { + hasJson = true + } else if strings.HasSuffix(prod, "xml"){ + hasXml = true + } + } + if len(op.Produces) == 0 { + hasJson = true + } + if !(hasJson || hasXml) { + panic("API must support either XML or JSON") + } + if hasJson { + out <- "\treq.Header.Add(\"Accept-Type\", \"application/json\")" + } else { + out <- "\treq.Header.Add(\"Accept-Type\", \"application/xml\")" + } +} + func genGet(api Api, op Operation, out chan string) { out <- fmt.Sprintf("\trv := %s", op.zero()) genGetBase(api, op, out) diff --git a/post.go b/post.go --- a/post.go +++ b/post.go @@ -38,7 +38,7 @@ file = p } else { hasBody = true - out <- fmt.Sprintf("\tb, e := json.Marshal(%s)", p.Name) + encoding(op.Consumes, fmt.Sprintf("\tb, e := %%s.Marshal(%s)", p.Name), out) out <- "\tif e != nil {" out <- fmt.Sprintf("\t\treturn %se", rv) out <- "\t}" @@ -46,7 +46,7 @@ } else { v := p.Name if p.Type == "array" { - v = genArrayToString(p) + v = genArrayToString(p, out) } params[p.Name] = v } @@ -60,7 +60,9 @@ out <- url out <- "\tclient := &http.Client{}" if hasBody { + out <- "import \"bytes\"" out <- fmt.Sprintf("\treq, err := http.NewRequest(%#v, url, bytes.NewReader(b))", op.Method) + addProduce(op, out) } else { postMultipartFile(op, file, params, out) } @@ -74,6 +76,7 @@ out <- "\t}" } out <- "\tif resp.StatusCode >= 500 && resp.StatusCode < 600 {" + out <- "import \"errors\"" out <- fmt.Sprintf("\t\treturn %serrors.New(\"Internal server error\")", rv) out <- "\t}" if op.Type != "void" { @@ -88,8 +91,11 @@ } func postMultipartFile(op Operation, file Parameter, params map[string]string, out chan string) { + out <- "import \"bytes\"" out <- "\tvar b bytes.Buffer" + out <- "import \"mime/multipart\"" out <- "\tw := multipart.NewWriter(&b)" + out <- "import \"io\"" out <- "\tvar fw io.Writer" out <- "\tvar err error" if file.Name != "" {