methods.go (12485B)
1 // Copyright 2019 The Hugo Authors. All rights reserved.
2 // Some functions in this file (see comments) is based on the Go source code,
3 // copyright The Go Authors and governed by a BSD-style license.
4 //
5 // Licensed under the Apache License, Version 2.0 (the "License");
6 // you may not use this file except in compliance with the License.
7 // You may obtain a copy of the License at
8 // http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15
16 // Package codegen contains helpers for code generation.
17 package codegen
18
19 import (
20 "fmt"
21 "go/ast"
22 "go/parser"
23 "go/token"
24 "os"
25 "path"
26 "path/filepath"
27 "reflect"
28 "regexp"
29 "sort"
30 "strings"
31 "sync"
32 )
33
const (
	// weightWidth is the gap between the weights given to consecutive
	// entries of an interface. The room in between is used to slot in the
	// methods of embedded interfaces right after the embedding entry.
	weightWidth = 1000
)
36
37 // NewInspector creates a new Inspector given a source root.
38 func NewInspector(root string) *Inspector {
39 return &Inspector{ProjectRootDir: root}
40 }
41
// Inspector provides methods to help code generation. It uses a combination
// of reflection and source code AST to do the heavy lifting.
type Inspector struct {
	// ProjectRootDir is the root directory whose Go sources are parsed.
	ProjectRootDir string

	// init guards the one-time parsing of the sources.
	init sync.Once

	// methodWeight maps "pkg.Interface" to a map from method name to weight.
	// Go's reflect lists methods in lexicographic order, so the source is
	// parsed to recover the order in which they were declared.
	methodWeight map[string]map[string]int
}
53
54 // MethodsFromTypes create a method set from the include slice, excluding any
55 // method in exclude.
56 func (c *Inspector) MethodsFromTypes(include []reflect.Type, exclude []reflect.Type) Methods {
57 c.parseSource()
58
59 var methods Methods
60
61 excludes := make(map[string]bool)
62
63 if len(exclude) > 0 {
64 for _, m := range c.MethodsFromTypes(exclude, nil) {
65 excludes[m.Name] = true
66 }
67 }
68
69 // There may be overlapping interfaces in types. Do a simple check for now.
70 seen := make(map[string]bool)
71
72 nameAndPackage := func(t reflect.Type) (string, string) {
73 var name, pkg string
74
75 isPointer := t.Kind() == reflect.Ptr
76
77 if isPointer {
78 t = t.Elem()
79 }
80
81 pkgPrefix := ""
82 if pkgPath := t.PkgPath(); pkgPath != "" {
83 pkgPath = strings.TrimSuffix(pkgPath, "/")
84 _, shortPath := path.Split(pkgPath)
85 pkgPrefix = shortPath + "."
86 pkg = pkgPath
87 }
88
89 name = t.Name()
90 if name == "" {
91 // interface{}
92 name = t.String()
93 }
94
95 if isPointer {
96 pkgPrefix = "*" + pkgPrefix
97 }
98
99 name = pkgPrefix + name
100
101 return name, pkg
102 }
103
104 for _, t := range include {
105 for i := 0; i < t.NumMethod(); i++ {
106
107 m := t.Method(i)
108 if excludes[m.Name] || seen[m.Name] {
109 continue
110 }
111
112 seen[m.Name] = true
113
114 if m.PkgPath != "" {
115 // Not exported
116 continue
117 }
118
119 numIn := m.Type.NumIn()
120
121 ownerName, _ := nameAndPackage(t)
122
123 method := Method{Owner: t, OwnerName: ownerName, Name: m.Name}
124
125 for i := 0; i < numIn; i++ {
126 in := m.Type.In(i)
127
128 name, pkg := nameAndPackage(in)
129
130 if pkg != "" {
131 method.Imports = append(method.Imports, pkg)
132 }
133
134 method.In = append(method.In, name)
135 }
136
137 numOut := m.Type.NumOut()
138
139 if numOut > 0 {
140 for i := 0; i < numOut; i++ {
141 out := m.Type.Out(i)
142 name, pkg := nameAndPackage(out)
143
144 if pkg != "" {
145 method.Imports = append(method.Imports, pkg)
146 }
147
148 method.Out = append(method.Out, name)
149 }
150 }
151
152 methods = append(methods, method)
153 }
154 }
155
156 sort.SliceStable(methods, func(i, j int) bool {
157 mi, mj := methods[i], methods[j]
158
159 wi := c.methodWeight[mi.OwnerName][mi.Name]
160 wj := c.methodWeight[mj.OwnerName][mj.Name]
161
162 if wi == wj {
163 return mi.Name < mj.Name
164 }
165
166 return wi < wj
167 })
168
169 return methods
170 }
171
172 func (c *Inspector) parseSource() {
173 c.init.Do(func() {
174 if !strings.Contains(c.ProjectRootDir, "hugo") {
175 panic("dir must be set to the Hugo root")
176 }
177
178 c.methodWeight = make(map[string]map[string]int)
179 dirExcludes := regexp.MustCompile("docs|examples")
180 fileExcludes := regexp.MustCompile("autogen")
181 var filenames []string
182
183 filepath.Walk(c.ProjectRootDir, func(path string, info os.FileInfo, err error) error {
184 if info.IsDir() {
185 if dirExcludes.MatchString(info.Name()) {
186 return filepath.SkipDir
187 }
188 }
189
190 if !strings.HasSuffix(path, ".go") || fileExcludes.MatchString(path) {
191 return nil
192 }
193
194 filenames = append(filenames, path)
195
196 return nil
197 })
198
199 for _, filename := range filenames {
200
201 pkg := c.packageFromPath(filename)
202
203 fset := token.NewFileSet()
204 node, err := parser.ParseFile(fset, filename, nil, parser.ParseComments)
205 if err != nil {
206 panic(err)
207 }
208
209 ast.Inspect(node, func(n ast.Node) bool {
210 switch t := n.(type) {
211 case *ast.TypeSpec:
212 if t.Name.IsExported() {
213 switch it := t.Type.(type) {
214 case *ast.InterfaceType:
215 iface := pkg + "." + t.Name.Name
216 methodNames := collectMethodsRecursive(pkg, it.Methods.List)
217 weights := make(map[string]int)
218 weight := weightWidth
219 for _, name := range methodNames {
220 weights[name] = weight
221 weight += weightWidth
222 }
223 c.methodWeight[iface] = weights
224 }
225 }
226 }
227 return true
228 })
229
230 }
231
232 // Complement
233 for _, v1 := range c.methodWeight {
234 for k2, w := range v1 {
235 if v, found := c.methodWeight[k2]; found {
236 for k3, v3 := range v {
237 v1[k3] = (v3 / weightWidth) + w
238 }
239 }
240 }
241 }
242 })
243 }
244
245 func (c *Inspector) packageFromPath(p string) string {
246 p = filepath.ToSlash(p)
247 base := path.Base(p)
248 if !strings.Contains(base, ".") {
249 return base
250 }
251 return path.Base(strings.TrimSuffix(p, base))
252 }
253
// Method describes one method in enough detail to regenerate its signature.
type Method struct {
	// Owner is the interface the method was extracted from.
	Owner reflect.Type

	// OwnerName is Owner spelled as PACKAGE.NAME, e.g. page.Page, where
	// PACKAGE is the last element of the package path.
	OwnerName string

	// Name is the method name.
	Name string

	// Imports lists the package paths needed to satisfy the method signature.
	Imports []string

	// In holds the argument types as spelled in Go source, including any
	// package prefix, e.g. string, int, interface {}, *url.URL.
	In []string

	// Out holds the return types, spelled the same way as In.
	Out []string
}
276
277 // Declaration creates a method declaration (without any body) for the given receiver.
278 func (m Method) Declaration(receiver string) string {
279 return fmt.Sprintf("func (%s %s) %s%s %s", receiverShort(receiver), receiver, m.Name, m.inStr(), m.outStr())
280 }
281
282 // DeclarationNamed creates a method declaration (without any body) for the given receiver
283 // with named return values.
284 func (m Method) DeclarationNamed(receiver string) string {
285 return fmt.Sprintf("func (%s %s) %s%s %s", receiverShort(receiver), receiver, m.Name, m.inStr(), m.outStrNamed())
286 }
287
288 // Delegate creates a delegate call string.
289 func (m Method) Delegate(receiver, delegate string) string {
290 ret := ""
291 if len(m.Out) > 0 {
292 ret = "return "
293 }
294 return fmt.Sprintf("%s%s.%s.%s%s", ret, receiverShort(receiver), delegate, m.Name, m.inOutStr())
295 }
296
297 func (m Method) String() string {
298 return m.Name + m.inStr() + " " + m.outStr() + "\n"
299 }
300
301 func (m Method) inOutStr() string {
302 if len(m.In) == 0 {
303 return "()"
304 }
305
306 args := make([]string, len(m.In))
307 for i := 0; i < len(args); i++ {
308 args[i] = fmt.Sprintf("arg%d", i)
309 }
310 return "(" + strings.Join(args, ", ") + ")"
311 }
312
313 func (m Method) inStr() string {
314 if len(m.In) == 0 {
315 return "()"
316 }
317
318 args := make([]string, len(m.In))
319 for i := 0; i < len(args); i++ {
320 args[i] = fmt.Sprintf("arg%d %s", i, m.In[i])
321 }
322 return "(" + strings.Join(args, ", ") + ")"
323 }
324
325 func (m Method) outStr() string {
326 if len(m.Out) == 0 {
327 return ""
328 }
329 if len(m.Out) == 1 {
330 return m.Out[0]
331 }
332
333 return "(" + strings.Join(m.Out, ", ") + ")"
334 }
335
336 func (m Method) outStrNamed() string {
337 if len(m.Out) == 0 {
338 return ""
339 }
340
341 outs := make([]string, len(m.Out))
342 for i := 0; i < len(outs); i++ {
343 outs[i] = fmt.Sprintf("o%d %s", i, m.Out[i])
344 }
345
346 return "(" + strings.Join(outs, ", ") + ")"
347 }
348
349 // Methods represents a list of methods for one or more interfaces.
350 // The order matches the defined order in their source file(s).
351 type Methods []Method
352
353 // Imports returns a sorted list of package imports needed to satisfy the
354 // signatures of all methods.
355 func (m Methods) Imports() []string {
356 var pkgImports []string
357 for _, method := range m {
358 pkgImports = append(pkgImports, method.Imports...)
359 }
360 if len(pkgImports) > 0 {
361 pkgImports = uniqueNonEmptyStrings(pkgImports)
362 sort.Strings(pkgImports)
363 }
364 return pkgImports
365 }
366
367 // ToMarshalJSON creates a MarshalJSON method for these methods. Any method name
368 // matching any of the regexps in excludes will be ignored.
369 func (m Methods) ToMarshalJSON(receiver, pkgPath string, excludes ...string) (string, []string) {
370 var sb strings.Builder
371
372 r := receiverShort(receiver)
373 what := firstToUpper(trimAsterisk(receiver))
374 pgkName := path.Base(pkgPath)
375
376 fmt.Fprintf(&sb, "func Marshal%sToJSON(%s %s) ([]byte, error) {\n", what, r, receiver)
377
378 var methods Methods
379 excludeRes := make([]*regexp.Regexp, len(excludes))
380
381 for i, exclude := range excludes {
382 excludeRes[i] = regexp.MustCompile(exclude)
383 }
384
385 for _, method := range m {
386 // Exclude methods with arguments and incompatible return values
387 if len(method.In) > 0 || len(method.Out) == 0 || len(method.Out) > 2 {
388 continue
389 }
390
391 if len(method.Out) == 2 {
392 if method.Out[1] != "error" {
393 continue
394 }
395 }
396
397 for _, re := range excludeRes {
398 if re.MatchString(method.Name) {
399 continue
400 }
401 }
402
403 methods = append(methods, method)
404 }
405
406 for _, method := range methods {
407 varn := varName(method.Name)
408 if len(method.Out) == 1 {
409 fmt.Fprintf(&sb, "\t%s := %s.%s()\n", varn, r, method.Name)
410 } else {
411 fmt.Fprintf(&sb, "\t%s, err := %s.%s()\n", varn, r, method.Name)
412 fmt.Fprint(&sb, "\tif err != nil {\n\t\treturn nil, err\n\t}\n")
413 }
414 }
415
416 fmt.Fprint(&sb, "\n\ts := struct {\n")
417
418 for _, method := range methods {
419 fmt.Fprintf(&sb, "\t\t%s %s\n", method.Name, typeName(method.Out[0], pgkName))
420 }
421
422 fmt.Fprint(&sb, "\n\t}{\n")
423
424 for _, method := range methods {
425 varn := varName(method.Name)
426 fmt.Fprintf(&sb, "\t\t%s: %s,\n", method.Name, varn)
427 }
428
429 fmt.Fprint(&sb, "\n\t}\n\n")
430 fmt.Fprint(&sb, "\treturn json.Marshal(&s)\n}")
431
432 pkgImports := append(methods.Imports(), "encoding/json")
433
434 if pkgPath != "" {
435 // Exclude self
436 for i, pkgImp := range pkgImports {
437 if pkgImp == pkgPath {
438 pkgImports = append(pkgImports[:i], pkgImports[i+1:]...)
439 }
440 }
441 }
442
443 return sb.String(), pkgImports
444 }
445
446 func collectMethodsRecursive(pkg string, f []*ast.Field) []string {
447 var methodNames []string
448 for _, m := range f {
449 if m.Names != nil {
450 methodNames = append(methodNames, m.Names[0].Name)
451 continue
452 }
453
454 if ident, ok := m.Type.(*ast.Ident); ok && ident.Obj != nil {
455 // Embedded interface
456 methodNames = append(
457 methodNames,
458 collectMethodsRecursive(
459 pkg,
460 ident.Obj.Decl.(*ast.TypeSpec).Type.(*ast.InterfaceType).Methods.List)...)
461 } else {
462 // Embedded, but in a different file/package. Return the
463 // package.Name and deal with that later.
464 name := packageName(m.Type)
465 if !strings.Contains(name, ".") {
466 // Assume current package
467 name = pkg + "." + name
468 }
469 methodNames = append(methodNames, name)
470 }
471 }
472
473 return methodNames
474 }
475
// firstToLower returns name with its first rune lower-cased. A multi-byte
// (non-ASCII) first rune is converted as a whole, and an empty name is
// returned unchanged instead of panicking.
func firstToLower(name string) string {
	head, tail := splitFirstRune(name)
	return strings.ToLower(head) + tail
}

// firstToUpper returns name with its first rune upper-cased. A multi-byte
// (non-ASCII) first rune is converted as a whole, and an empty name is
// returned unchanged instead of panicking.
func firstToUpper(name string) string {
	head, tail := splitFirstRune(name)
	return strings.ToUpper(head) + tail
}

// splitFirstRune splits s into its first UTF-8 encoded rune and the rest.
func splitFirstRune(s string) (head, tail string) {
	// Ranging over a string yields the byte offset of each rune; the first
	// non-zero offset is where the second rune starts.
	for i := range s {
		if i > 0 {
			return s[:i], s[i:]
		}
	}
	return s, ""
}
483
// packageName renders an identifier ("Name") or a selector expression
// ("pkg.Name", recursively for longer chains) as a string. Any other
// expression yields "".
func packageName(e ast.Expr) string {
	if sel, ok := e.(*ast.SelectorExpr); ok {
		return fmt.Sprintf("%s.%s", packageName(sel.X), sel.Sel.Name)
	}
	if id, ok := e.(*ast.Ident); ok {
		return id.Name
	}
	return ""
}
493
// receiverShort returns a one-letter receiver name for receiver, a type name
// optionally prefixed with "*": "*pageState" gives "p". The first rune is
// lower-cased as a whole, so non-ASCII type names do not yield a truncated
// UTF-8 byte. An empty type name gives "" instead of panicking.
func receiverShort(receiver string) string {
	name := trimAsterisk(receiver)
	// Ranging over a string yields rune start offsets; the first non-zero
	// offset ends the first rune.
	for i := range name {
		if i > 0 {
			return strings.ToLower(name[:i])
		}
	}
	return strings.ToLower(name)
}

// trimAsterisk removes a single leading "*" (pointer marker) from name.
func trimAsterisk(name string) string {
	return strings.TrimPrefix(name, "*")
}
501
// typeName returns the type expression name with the package qualifier
// pkg+"." removed wherever it qualifies an identifier, so the type can be
// spelled from inside package pkg. With pkg "page":
//   - "page.Page" becomes "Page";
//   - "*page.Page" becomes "*Page";
//   - "map[string]page.Page" becomes "map[string]Page";
//   - "mypage.Page" is left alone.
func typeName(name, pkg string) string {
	qualifier := pkg + "."
	var sb strings.Builder
	for {
		before, after, found := strings.Cut(name, qualifier)
		if !found {
			sb.WriteString(name)
			return sb.String()
		}
		sb.WriteString(before)
		if before != "" && isIdentByte(before[len(before)-1]) {
			// The match is the tail of a longer identifier; keep it.
			sb.WriteString(qualifier)
		}
		name = after
	}
}

// isIdentByte reports whether b may be part of a Go identifier. Bytes of
// multi-byte UTF-8 sequences are conservatively treated as identifier bytes.
func isIdentByte(b byte) bool {
	return b == '_' ||
		'a' <= b && b <= 'z' ||
		'A' <= b && b <= 'Z' ||
		'0' <= b && b <= '9' ||
		b >= 0x80
}
505
// uniqueNonEmptyStrings returns the non-empty strings of s with duplicates
// removed, keeping the first occurrence of each in its original order. It
// returns nil when s holds no non-empty strings.
func uniqueNonEmptyStrings(s []string) []string {
	var unique []string
	// A zero-width value set avoids boxing every string into an interface.
	seen := make(map[string]struct{}, len(s))
	for _, val := range s {
		if val == "" {
			continue
		}
		if _, dup := seen[val]; dup {
			continue
		}
		seen[val] = struct{}{}
		unique = append(unique, val)
	}
	return unique
}
520
521 func varName(name string) string {
522 name = firstToLower(name)
523
524 // Adjust some reserved keywords, see https://golang.org/ref/spec#Keywords
525 switch name {
526 case "type":
527 name = "typ"
528 case "package":
529 name = "pkg"
530 // Not reserved, but syntax highlighters has it as a keyword.
531 case "len":
532 name = "length"
533 }
534
535 return name
536 }