e34b61ba5769 — Sean Russell 11 years ago
Adds support for produce/consume mime type specification
4 files changed, 113 insertions(+), 26 deletions(-)

M api.go
M generator.go
M get.go
M post.go
M api.go +2 -2
@@ 106,9 106,9 @@ func (p DataTypeFields) typeOf() string 
 		case "byte":
 			return "[]byte"
 		case "date":
-			return "date.Time"
+			return "time.Time"
 		case "date-time":
-			return "date.Time"
+			return "time.Time"
 		}
 	case "boolean":
 	case "array":

          
M generator.go +75 -21
@@ 4,6 4,8 @@ import (
 	"bytes"
 	"fmt"
 	"strings"
+	"github.com/glenn-brown/skiplist"
+	"sync"
 )
 
 var swaggerToGo map[string]string

          
@@ 13,27 15,40 @@ func Generate(pack string, j []byte, out
 	if e != nil {
 		return e
 	}
-	out <- fmt.Sprintf("package %s", pack)
-	out <- ""
-	// FIXME: only import packages that are used
-	out <- "import ("
-	out <- "\t\"bytes\""
-	out <- "\t\"encoding/json\""
-	out <- "\t\"errors\""
-	out <- "\t\"fmt\""
-	out <- "\t\"io\""
-	out <- "\t\"io/ioutil\""
-	out <- "\t\"mime/multipart\""
-	out <- "\t\"net/http\""
-	out <- "\t\"os\""
-	out <- ")"
-	out <- ""
+	outWrapper := make(chan string)
+	wait := sync.WaitGroup{}
+	wait.Add(1)
+	go func() {
+		lines := make([]string, 0, 100)
+		imports := skiplist.New()
+		imports.Set("import \"fmt\"", true)
+		imports.Set("import \"net/http\"", true)
+		for line := range outWrapper {
+			if strings.HasPrefix(line, "import ") {
+				imports.Set(line, true)
+			} else {
+				lines = append(lines, line)
+			}
+		}
+		out <- fmt.Sprintf("package %s", pack)
+		out <- ""
+		for f := imports.Front(); f != nil; f = f.Next() {
+			out <- f.Key().(string)
+		}
+		out <- ""
+		for _, line := range lines {
+			out <- line
+		}
+		wait.Done()
+	}()
 	switch t.(type) {
 	case ApiDecl:
-		generateApi(t.(ApiDecl), out)
+		generateApi(t.(ApiDecl), outWrapper)
 	case ResourceListing:
-		generateResourceListing(t.(ResourceListing), out)
+		generateResourceListing(t.(ResourceListing), outWrapper)
 	}
+	close(outWrapper)
+	wait.Wait()
 	return nil
 }
 

          
@@ 64,6 79,11 @@ func generateModels(a ApiDecl, out chan 
 			if pv.Description != "" {
 				out <- fmt.Sprintf("\t// %s", pv.Description)
 			}
+			if strings.HasPrefix(pv.typeOf(), "time.") {
+				out <- "import \"time\""
+			} else if strings.HasPrefix(pv.typeOf(), "*os.") {
+				out <- "import \"os\""
+			}
 			out <- fmt.Sprintf("\t%s %s", upcase(pk), pv.typeOf())
 		}
 		out <- "}"

          
@@ 88,7 108,7 @@ func generateDocs(op Operation, out chan
 
 func generateFunc(api Api, op Operation, out chan string) {
 	fname := upcase(op.Nickname)
-	s := fmt.Sprintf("func %s(%s) %s {", fname, genArgs(op), genRv(op))
+	s := fmt.Sprintf("func %s(%s) %s {", fname, genArgs(op, out), genRv(op))
 	out <- s
 	switch op.Method {
 	case "DELETE":

          
@@ 106,19 126,21 @@ func generateFunc(api Api, op Operation,
 	out <- "}"
 }
 
-func genArrayToString(p Parameter) string {
+func genArrayToString(p Parameter, out chan string) string {
 	rv := bytes.Buffer{}
 	xs := fmt.Sprintf("%ss", p.Name)
 	rv.WriteString(fmt.Sprintf("\t%s := []string{}", xs))
 	rv.WriteString(fmt.Sprintf("\tfor _, v := range %s {", p.Name))
 	rv.WriteString(fmt.Sprintf("\t\t%s = append(%s, toString(v))", xs, xs))
 	rv.WriteString("\t}")
+	out <- "import \"strings\""
 	rv.WriteString(fmt.Sprintf("\t%sString := strings.Join(%s, \"%2C\")", p.Name, xs))
 	return string(rv.Bytes())
 }
 
 func genDeserialize(op Operation, out chan string) {
 	out <- "\tdefer resp.Body.Close()"
+	out <- "import \"io/ioutil\""
 	out <- "\ts, e := ioutil.ReadAll(resp.Body)"
 	out <- "\tif e != nil {"
 	out <- fmt.Sprintf("\t\treturn %s, e", op.zero())

          
@@ 127,15 149,17 @@ func genDeserialize(op Operation, out ch
 	case "string":
 		out <- "\treturn s, nil"
 	case "int":
+		out <- "import \"strconv\""
 		out <- "\treturn strconv.Atoi(s)"
 	case "int64":
+		out <- "import \"strconv\""
 		out <- "\tk, e := strconv.ParseInt(s, 10, 64)"
 		out <- "\tif e != nil {"
 		out <- fmt.Sprintf("\t\treturn %s, e", op.zero())
 		out <- "\t}"
 		out <- "\treturn k, nil"
 	default:
-		out <- "\te = json.Unmarshal(s, &rv)"
+		encoding(op.Produces, "\te = %s.Unmarshal(s, &rv)", out)
 		out <- "\tif e != nil {"
 		out <- fmt.Sprintf("\t\treturn %s, e", op.zero())
 		out <- "\t}"

          
@@ 143,12 167,42 @@ func genDeserialize(op Operation, out ch
 	}
 }
 
-func genArgs(op Operation) string {
+func encoding(items []string, format string, out chan string) {
+	hasJson := false
+	hasXml := false
+	for _, prod := range items {
+		if strings.HasSuffix(prod, "json") {
+			hasJson = true
+		} else if strings.HasSuffix(prod, "xml"){
+			hasXml = true
+		}
+	}
+	if len(items) == 0 {
+		hasJson = true
+	}
+	if !(hasJson || hasXml) {
+		panic("API must support either XML or JSON")
+	}
+	if hasJson {
+		out <- "import \"encoding/json\""
+		out <- fmt.Sprintf(format, "json")
+	} else {
+		out <- "import \"encoding/xml\""
+		out <- fmt.Sprintf(format, "xml")
+	}
+}
+
+func genArgs(op Operation, out chan string) string {
 	b := bytes.Buffer{}
 	np := len(op.Parameters) - 1
 	for i, p := range op.Parameters {
 		b.WriteString(p.Name)
 		b.WriteString(" ")
+		if strings.HasPrefix(p.typeOf(), "time.") {
+			out <- "import \"time\""
+		} else if strings.HasPrefix(p.typeOf(), "*os.") {
+			out <- "import \"os\""
+		}
 		b.WriteString(p.typeOf())
 		if i < np {
 			b.WriteString(", ")

          
M get.go +28 -1
@@ 18,11 18,13 @@ func genGetBase(api Api, op Operation, o
 	for _, p := range op.Parameters {
 		if p.Minimum != "" {
 			out <- fmt.Sprintf("\tif %s < %s(%s) {", p.Name, p.typeOf(), p.Minimum)
+			out <- "import \"errors\""
 			out <- fmt.Sprintf("\t\treturn %serrors.New(fmt.Sprintf(\"invalid value (%%d < %s)\", %s))", rv, p.Minimum, p.Name)
 			out <- "\t}"
 		}
 		if p.Maximum != "" {
 			out <- fmt.Sprintf("\tif %s > %s(%s) {", p.Name, p.typeOf(), p.Maximum)
+			out <- "import \"errors\""
 			out <- fmt.Sprintf("\t\treturn %serrors.New(fmt.Sprintf(\"invalid value (%%d > %s)\", %s))", rv, p.Maximum, p.Name)
 			out <- "\t}"
 		}

          
@@ 33,7 35,7 @@ func genGetBase(api Api, op Operation, o
 		} else {
 			v := p.Name
 			if p.Type == "array" {
-				v = genArrayToString(p)
+				v = genArrayToString(p, out)
 			}
 			params[p.Name] = v
 		}

          
@@ 59,17 61,42 @@ func genGetBase(api Api, op Operation, o
 	out <- url
 	out <- "\tclient := &http.Client{}"
 	out <- fmt.Sprintf("\treq, err := http.NewRequest(%#v, url, nil)", op.Method)
+	addProduce(op, out)
 	out <- "\tresp, err := client.Do(req)"
 	out <- "\tif err != nil {"
 	out <- fmt.Sprintf("\t\treturn %serr", zero)
 	out <- "\t}"
 	for _, c := range op.ResponseMessages {
 		out <- fmt.Sprintf("\tif resp.StatusCode == %d {", c.Code)
+		out <- "import \"errors\""
 		out <- fmt.Sprintf("\t\treturn %serrors.New(%#v)", rv, c.Message)
 		out <- "\t}"
 	}
 }
 
+func addProduce(op Operation, out chan string) {
+	hasJson := false
+	hasXml := false
+	for _, prod := range op.Produces {
+		if strings.HasSuffix(prod, "json") {
+			hasJson = true
+		} else if strings.HasSuffix(prod, "xml"){
+			hasXml = true
+		}
+	}
+	if len(op.Produces) == 0 {
+		hasJson = true
+	}
+	if !(hasJson || hasXml) {
+		panic("API must support either XML or JSON")
+	}
+	if hasJson {
+		out <- "\treq.Header.Add(\"Accept-Type\", \"application/json\")"
+	} else {
+		out <- "\treq.Header.Add(\"Accept-Type\", \"application/xml\")"
+	}
+}
+
 func genGet(api Api, op Operation, out chan string) {
 	out <- fmt.Sprintf("\trv := %s", op.zero())
 	genGetBase(api, op, out)

          
M post.go +8 -2
@@ 38,7 38,7 @@ func genPostBase(api Api, op Operation, 
 				file = p
 			} else {
 				hasBody = true
-				out <- fmt.Sprintf("\tb, e := json.Marshal(%s)", p.Name)
+				encoding(op.Consumes, fmt.Sprintf("\tb, e := %%s.Marshal(%s)", p.Name), out)
 				out <- "\tif e != nil {"
 				out <- fmt.Sprintf("\t\treturn %se", rv)
 				out <- "\t}"

          
@@ 46,7 46,7 @@ func genPostBase(api Api, op Operation, 
 		} else {
 			v := p.Name
 			if p.Type == "array" {
-				v = genArrayToString(p)
+				v = genArrayToString(p, out)
 			}
 			params[p.Name] = v
 		}

          
@@ 60,7 60,9 @@ func genPostBase(api Api, op Operation, 
 	out <- url
 	out <- "\tclient := &http.Client{}"
 	if hasBody {
+		out <- "import \"bytes\""
 		out <- fmt.Sprintf("\treq, err := http.NewRequest(%#v, url, bytes.NewReader(b))", op.Method)
+		addProduce(op, out)
 	} else {
 		postMultipartFile(op, file, params, out)
 	}

          
@@ 74,6 76,7 @@ func genPostBase(api Api, op Operation, 
 		out <- "\t}"
 	}
 	out <- "\tif resp.StatusCode >= 500 && resp.StatusCode < 600 {"
+	out <- "import \"errors\""
 	out <- fmt.Sprintf("\t\treturn %serrors.New(\"Internal server error\")", rv)
 	out <- "\t}"
 	if op.Type != "void" {

          
@@ 88,8 91,11 @@ func genPost(api Api, op Operation, out 
 }
 
 func postMultipartFile(op Operation, file Parameter, params map[string]string, out chan string) {
+	out <- "import \"bytes\""
     out <- "\tvar b bytes.Buffer"
+    out <- "import \"mime/multipart\""
     out <- "\tw := multipart.NewWriter(&b)"
+    out <- "import \"io\""
     out <- "\tvar fw io.Writer"
     out <- "\tvar err error"
     if file.Name != "" {