methods.go (12485B)
1 // Copyright 2019 The Hugo Authors. All rights reserved. 2 // Some functions in this file (see comments) is based on the Go source code, 3 // copyright The Go Authors and governed by a BSD-style license. 4 // 5 // Licensed under the Apache License, Version 2.0 (the "License"); 6 // you may not use this file except in compliance with the License. 7 // You may obtain a copy of the License at 8 // http://www.apache.org/licenses/LICENSE-2.0 9 // 10 // Unless required by applicable law or agreed to in writing, software 11 // distributed under the License is distributed on an "AS IS" BASIS, 12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 // See the License for the specific language governing permissions and 14 // limitations under the License. 15 16 // Package codegen contains helpers for code generation. 17 package codegen 18 19 import ( 20 "fmt" 21 "go/ast" 22 "go/parser" 23 "go/token" 24 "os" 25 "path" 26 "path/filepath" 27 "reflect" 28 "regexp" 29 "sort" 30 "strings" 31 "sync" 32 ) 33 34 // Make room for insertions 35 const weightWidth = 1000 36 37 // NewInspector creates a new Inspector given a source root. 38 func NewInspector(root string) *Inspector { 39 return &Inspector{ProjectRootDir: root} 40 } 41 42 // Inspector provides methods to help code generation. It uses a combination 43 // of reflection and source code AST to do the heavy lifting. 44 type Inspector struct { 45 ProjectRootDir string 46 47 init sync.Once 48 49 // Determines method order. Go's reflect sorts lexicographically, so 50 // we must parse the source to preserve this order. 51 methodWeight map[string]map[string]int 52 } 53 54 // MethodsFromTypes create a method set from the include slice, excluding any 55 // method in exclude. 56 func (c *Inspector) MethodsFromTypes(include []reflect.Type, exclude []reflect.Type) Methods { 57 c.parseSource() 58 59 var methods Methods 60 61 excludes := make(map[string]bool) 62 63 if len(exclude) > 0 { 64 for _, m := range c.MethodsFromTypes(exclude, nil) { 65 excludes[m.Name] = true 66 } 67 } 68 69 // There may be overlapping interfaces in types. Do a simple check for now. 70 seen := make(map[string]bool) 71 72 nameAndPackage := func(t reflect.Type) (string, string) { 73 var name, pkg string 74 75 isPointer := t.Kind() == reflect.Ptr 76 77 if isPointer { 78 t = t.Elem() 79 } 80 81 pkgPrefix := "" 82 if pkgPath := t.PkgPath(); pkgPath != "" { 83 pkgPath = strings.TrimSuffix(pkgPath, "/") 84 _, shortPath := path.Split(pkgPath) 85 pkgPrefix = shortPath + "." 86 pkg = pkgPath 87 } 88 89 name = t.Name() 90 if name == "" { 91 // interface{} 92 name = t.String() 93 } 94 95 if isPointer { 96 pkgPrefix = "*" + pkgPrefix 97 } 98 99 name = pkgPrefix + name 100 101 return name, pkg 102 } 103 104 for _, t := range include { 105 for i := 0; i < t.NumMethod(); i++ { 106 107 m := t.Method(i) 108 if excludes[m.Name] || seen[m.Name] { 109 continue 110 } 111 112 seen[m.Name] = true 113 114 if m.PkgPath != "" { 115 // Not exported 116 continue 117 } 118 119 numIn := m.Type.NumIn() 120 121 ownerName, _ := nameAndPackage(t) 122 123 method := Method{Owner: t, OwnerName: ownerName, Name: m.Name} 124 125 for i := 0; i < numIn; i++ { 126 in := m.Type.In(i) 127 128 name, pkg := nameAndPackage(in) 129 130 if pkg != "" { 131 method.Imports = append(method.Imports, pkg) 132 } 133 134 method.In = append(method.In, name) 135 } 136 137 numOut := m.Type.NumOut() 138 139 if numOut > 0 { 140 for i := 0; i < numOut; i++ { 141 out := m.Type.Out(i) 142 name, pkg := nameAndPackage(out) 143 144 if pkg != "" { 145 method.Imports = append(method.Imports, pkg) 146 } 147 148 method.Out = append(method.Out, name) 149 } 150 } 151 152 methods = append(methods, method) 153 } 154 } 155 156 sort.SliceStable(methods, func(i, j int) bool { 157 mi, mj := methods[i], methods[j] 158 159 wi := c.methodWeight[mi.OwnerName][mi.Name] 160 wj := c.methodWeight[mj.OwnerName][mj.Name] 161 162 if wi == wj { 163 return mi.Name < mj.Name 164 } 165 166 return wi < wj 167 }) 168 169 return methods 170 } 171 172 func (c *Inspector) parseSource() { 173 c.init.Do(func() { 174 if !strings.Contains(c.ProjectRootDir, "hugo") { 175 panic("dir must be set to the Hugo root") 176 } 177 178 c.methodWeight = make(map[string]map[string]int) 179 dirExcludes := regexp.MustCompile("docs|examples") 180 fileExcludes := regexp.MustCompile("autogen") 181 var filenames []string 182 183 filepath.Walk(c.ProjectRootDir, func(path string, info os.FileInfo, err error) error { 184 if info.IsDir() { 185 if dirExcludes.MatchString(info.Name()) { 186 return filepath.SkipDir 187 } 188 } 189 190 if !strings.HasSuffix(path, ".go") || fileExcludes.MatchString(path) { 191 return nil 192 } 193 194 filenames = append(filenames, path) 195 196 return nil 197 }) 198 199 for _, filename := range filenames { 200 201 pkg := c.packageFromPath(filename) 202 203 fset := token.NewFileSet() 204 node, err := parser.ParseFile(fset, filename, nil, parser.ParseComments) 205 if err != nil { 206 panic(err) 207 } 208 209 ast.Inspect(node, func(n ast.Node) bool { 210 switch t := n.(type) { 211 case *ast.TypeSpec: 212 if t.Name.IsExported() { 213 switch it := t.Type.(type) { 214 case *ast.InterfaceType: 215 iface := pkg + "." + t.Name.Name 216 methodNames := collectMethodsRecursive(pkg, it.Methods.List) 217 weights := make(map[string]int) 218 weight := weightWidth 219 for _, name := range methodNames { 220 weights[name] = weight 221 weight += weightWidth 222 } 223 c.methodWeight[iface] = weights 224 } 225 } 226 } 227 return true 228 }) 229 230 } 231 232 // Complement 233 for _, v1 := range c.methodWeight { 234 for k2, w := range v1 { 235 if v, found := c.methodWeight[k2]; found { 236 for k3, v3 := range v { 237 v1[k3] = (v3 / weightWidth) + w 238 } 239 } 240 } 241 } 242 }) 243 } 244 245 func (c *Inspector) packageFromPath(p string) string { 246 p = filepath.ToSlash(p) 247 base := path.Base(p) 248 if !strings.Contains(base, ".") { 249 return base 250 } 251 return path.Base(strings.TrimSuffix(p, base)) 252 } 253 254 // Method holds enough information about it to recreate it. 255 type Method struct { 256 // The interface we extracted this method from. 257 Owner reflect.Type 258 259 // String version of the above, on the form PACKAGE.NAME, e.g. 260 // page.Page 261 OwnerName string 262 263 // Method name. 264 Name string 265 266 // Imports needed to satisfy the method signature. 267 Imports []string 268 269 // Argument types, including any package prefix, e.g. string, int, interface{}, 270 // net.Url 271 In []string 272 273 // Return types. 274 Out []string 275 } 276 277 // Declaration creates a method declaration (without any body) for the given receiver. 278 func (m Method) Declaration(receiver string) string { 279 return fmt.Sprintf("func (%s %s) %s%s %s", receiverShort(receiver), receiver, m.Name, m.inStr(), m.outStr()) 280 } 281 282 // DeclarationNamed creates a method declaration (without any body) for the given receiver 283 // with named return values. 284 func (m Method) DeclarationNamed(receiver string) string { 285 return fmt.Sprintf("func (%s %s) %s%s %s", receiverShort(receiver), receiver, m.Name, m.inStr(), m.outStrNamed()) 286 } 287 288 // Delegate creates a delegate call string. 289 func (m Method) Delegate(receiver, delegate string) string { 290 ret := "" 291 if len(m.Out) > 0 { 292 ret = "return " 293 } 294 return fmt.Sprintf("%s%s.%s.%s%s", ret, receiverShort(receiver), delegate, m.Name, m.inOutStr()) 295 } 296 297 func (m Method) String() string { 298 return m.Name + m.inStr() + " " + m.outStr() + "\n" 299 } 300 301 func (m Method) inOutStr() string { 302 if len(m.In) == 0 { 303 return "()" 304 } 305 306 args := make([]string, len(m.In)) 307 for i := 0; i < len(args); i++ { 308 args[i] = fmt.Sprintf("arg%d", i) 309 } 310 return "(" + strings.Join(args, ", ") + ")" 311 } 312 313 func (m Method) inStr() string { 314 if len(m.In) == 0 { 315 return "()" 316 } 317 318 args := make([]string, len(m.In)) 319 for i := 0; i < len(args); i++ { 320 args[i] = fmt.Sprintf("arg%d %s", i, m.In[i]) 321 } 322 return "(" + strings.Join(args, ", ") + ")" 323 } 324 325 func (m Method) outStr() string { 326 if len(m.Out) == 0 { 327 return "" 328 } 329 if len(m.Out) == 1 { 330 return m.Out[0] 331 } 332 333 return "(" + strings.Join(m.Out, ", ") + ")" 334 } 335 336 func (m Method) outStrNamed() string { 337 if len(m.Out) == 0 { 338 return "" 339 } 340 341 outs := make([]string, len(m.Out)) 342 for i := 0; i < len(outs); i++ { 343 outs[i] = fmt.Sprintf("o%d %s", i, m.Out[i]) 344 } 345 346 return "(" + strings.Join(outs, ", ") + ")" 347 } 348 349 // Methods represents a list of methods for one or more interfaces. 350 // The order matches the defined order in their source file(s). 351 type Methods []Method 352 353 // Imports returns a sorted list of package imports needed to satisfy the 354 // signatures of all methods. 355 func (m Methods) Imports() []string { 356 var pkgImports []string 357 for _, method := range m { 358 pkgImports = append(pkgImports, method.Imports...) 359 } 360 if len(pkgImports) > 0 { 361 pkgImports = uniqueNonEmptyStrings(pkgImports) 362 sort.Strings(pkgImports) 363 } 364 return pkgImports 365 } 366 367 // ToMarshalJSON creates a MarshalJSON method for these methods. Any method name 368 // matching any of the regexps in excludes will be ignored. 369 func (m Methods) ToMarshalJSON(receiver, pkgPath string, excludes ...string) (string, []string) { 370 var sb strings.Builder 371 372 r := receiverShort(receiver) 373 what := firstToUpper(trimAsterisk(receiver)) 374 pgkName := path.Base(pkgPath) 375 376 fmt.Fprintf(&sb, "func Marshal%sToJSON(%s %s) ([]byte, error) {\n", what, r, receiver) 377 378 var methods Methods 379 excludeRes := make([]*regexp.Regexp, len(excludes)) 380 381 for i, exclude := range excludes { 382 excludeRes[i] = regexp.MustCompile(exclude) 383 } 384 385 for _, method := range m { 386 // Exclude methods with arguments and incompatible return values 387 if len(method.In) > 0 || len(method.Out) == 0 || len(method.Out) > 2 { 388 continue 389 } 390 391 if len(method.Out) == 2 { 392 if method.Out[1] != "error" { 393 continue 394 } 395 } 396 397 for _, re := range excludeRes { 398 if re.MatchString(method.Name) { 399 continue 400 } 401 } 402 403 methods = append(methods, method) 404 } 405 406 for _, method := range methods { 407 varn := varName(method.Name) 408 if len(method.Out) == 1 { 409 fmt.Fprintf(&sb, "\t%s := %s.%s()\n", varn, r, method.Name) 410 } else { 411 fmt.Fprintf(&sb, "\t%s, err := %s.%s()\n", varn, r, method.Name) 412 fmt.Fprint(&sb, "\tif err != nil {\n\t\treturn nil, err\n\t}\n") 413 } 414 } 415 416 fmt.Fprint(&sb, "\n\ts := struct {\n") 417 418 for _, method := range methods { 419 fmt.Fprintf(&sb, "\t\t%s %s\n", method.Name, typeName(method.Out[0], pgkName)) 420 } 421 422 fmt.Fprint(&sb, "\n\t}{\n") 423 424 for _, method := range methods { 425 varn := varName(method.Name) 426 fmt.Fprintf(&sb, "\t\t%s: %s,\n", method.Name, varn) 427 } 428 429 fmt.Fprint(&sb, "\n\t}\n\n") 430 fmt.Fprint(&sb, "\treturn json.Marshal(&s)\n}") 431 432 pkgImports := append(methods.Imports(), "encoding/json") 433 434 if pkgPath != "" { 435 // Exclude self 436 for i, pkgImp := range pkgImports { 437 if pkgImp == pkgPath { 438 pkgImports = append(pkgImports[:i], pkgImports[i+1:]...) 439 } 440 } 441 } 442 443 return sb.String(), pkgImports 444 } 445 446 func collectMethodsRecursive(pkg string, f []*ast.Field) []string { 447 var methodNames []string 448 for _, m := range f { 449 if m.Names != nil { 450 methodNames = append(methodNames, m.Names[0].Name) 451 continue 452 } 453 454 if ident, ok := m.Type.(*ast.Ident); ok && ident.Obj != nil { 455 // Embedded interface 456 methodNames = append( 457 methodNames, 458 collectMethodsRecursive( 459 pkg, 460 ident.Obj.Decl.(*ast.TypeSpec).Type.(*ast.InterfaceType).Methods.List)...) 461 } else { 462 // Embedded, but in a different file/package. Return the 463 // package.Name and deal with that later. 464 name := packageName(m.Type) 465 if !strings.Contains(name, ".") { 466 // Assume current package 467 name = pkg + "." + name 468 } 469 methodNames = append(methodNames, name) 470 } 471 } 472 473 return methodNames 474 } 475 476 func firstToLower(name string) string { 477 return strings.ToLower(name[:1]) + name[1:] 478 } 479 480 func firstToUpper(name string) string { 481 return strings.ToUpper(name[:1]) + name[1:] 482 } 483 484 func packageName(e ast.Expr) string { 485 switch tp := e.(type) { 486 case *ast.Ident: 487 return tp.Name 488 case *ast.SelectorExpr: 489 return fmt.Sprintf("%s.%s", packageName(tp.X), packageName(tp.Sel)) 490 } 491 return "" 492 } 493 494 func receiverShort(receiver string) string { 495 return strings.ToLower(trimAsterisk(receiver))[:1] 496 } 497 498 func trimAsterisk(name string) string { 499 return strings.TrimPrefix(name, "*") 500 } 501 502 func typeName(name, pkg string) string { 503 return strings.TrimPrefix(name, pkg+".") 504 } 505 506 func uniqueNonEmptyStrings(s []string) []string { 507 var unique []string 508 set := map[string]any{} 509 for _, val := range s { 510 if val == "" { 511 continue 512 } 513 if _, ok := set[val]; !ok { 514 unique = append(unique, val) 515 set[val] = val 516 } 517 } 518 return unique 519 } 520 521 func varName(name string) string { 522 name = firstToLower(name) 523 524 // Adjust some reserved keywords, see https://golang.org/ref/spec#Keywords 525 switch name { 526 case "type": 527 name = "typ" 528 case "package": 529 name = "pkg" 530 // Not reserved, but syntax highlighters has it as a keyword. 531 case "len": 532 name = "length" 533 } 534 535 return name 536 }