1
2
3
4
5 package main
6
7 import (
8 "bytes"
9 "flag"
10 "fmt"
11 "go/ast"
12 "go/format"
13 "go/parser"
14 "go/token"
15 "log"
16 "os"
17 "strings"
18
19 "golang.org/x/tools/go/ast/astutil"
20
21 internalastutil "runtime/_mkmalloc/astutil"
22 )
23
// stdout selects the -stdout mode: when *stdout is true, main writes only the
// formatted size classes source to standard output and returns without writing
// any of the generated files.
var stdout = flag.Bool("stdout", false, "write sizeclasses source to stdout instead of sizeclasses.go")
25
26 func makeSizeToSizeClass(classes []class) []uint8 {
27 sc := uint8(0)
28 ret := make([]uint8, smallScanNoHeaderMax+1)
29 for i := range ret {
30 if i > classes[sc].size {
31 sc++
32 }
33 ret[i] = sc
34 }
35 return ret
36 }
37
38 func main() {
39 log.SetFlags(0)
40 log.SetPrefix("mkmalloc: ")
41
42 classes := makeClasses()
43 sizeToSizeClass := makeSizeToSizeClass(classes)
44
45 if *stdout {
46 if _, err := os.Stdout.Write(mustFormat(generateSizeClasses(classes))); err != nil {
47 log.Fatal(err)
48 }
49 return
50 }
51
52 sizeclasesesfile := "../../internal/runtime/gc/sizeclasses.go"
53 if err := os.WriteFile(sizeclasesesfile, mustFormat(generateSizeClasses(classes)), 0666); err != nil {
54 log.Fatal(err)
55 }
56
57 outfile := "../malloc_generated.go"
58 if err := os.WriteFile(outfile, mustFormat(inline(specializedMallocConfig(classes, sizeToSizeClass))), 0666); err != nil {
59 log.Fatal(err)
60 }
61
62 tablefile := "../malloc_tables_generated.go"
63 if err := os.WriteFile(tablefile, mustFormat(generateTable(sizeToSizeClass)), 0666); err != nil {
64 log.Fatal(err)
65 }
66 }
67
68
69 func withLineNumbers(b []byte) []byte {
70 var buf bytes.Buffer
71 i := 1
72 for line := range bytes.Lines(b) {
73 fmt.Fprintf(&buf, "%d: %s", i, line)
74 i++
75 }
76 return buf.Bytes()
77 }
78
79
80 func mustFormat(b []byte) []byte {
81 formatted, err := format.Source(b)
82 if err != nil {
83 log.Fatalf("error formatting source: %v\nsource:\n%s\n", err, withLineNumbers(b))
84 }
85 return formatted
86 }
87
88
89
// generatorConfig describes one run of inline: the template file to parse and
// the functions to generate from it.
type generatorConfig struct {
	// file is the path of the Go source file holding the template functions.
	file string
	// specs lists the functions to generate; inline emits them in this order.
	specs []spec
}
94
95
96
97
// spec describes a single generated function.
type spec struct {
	// name is the name given to the generated function.
	name string
	// templateFunc is the name of the function in the template file that is
	// cloned to produce the generated function.
	templateFunc string
	// ops are the transformations applied to the clone, in order.
	ops []op
}
103
104
// replacementKind selects which transformation an op performs on a cloned
// template function.
type replacementKind int

const (
	// inlineFunc replaces calls to op.from with the body of the template-file
	// function named op.to.
	inlineFunc replacementKind = iota
	// subBasicLit replaces identifiers named op.from with the literal op.to.
	subBasicLit
	// foldCondition statically folds if statements conditioned on op.from,
	// using op.to ("true" or "false") as its value.
	foldCondition
)
112
113
114
115
// op is one transformation applied to a cloned template function. Specs build
// ops with positional literals, so field order matters. By kind:
//
//   - inlineFunc: assignments whose right-hand side is a call to from are
//     replaced by the body of the template-file function named to. The op is
//     skipped if no function named to exists in the template file.
//   - subBasicLit: every identifier named from is replaced by the basic
//     literal parsed from to.
//   - foldCondition: if statements whose condition is from or !from are
//     folded, with from taking the boolean value to ("true" or "false").
type op struct {
	kind replacementKind
	from string
	to   string
}
121
// smallScanNoHeaderSCFuncName returns the name of the generated scan
// function for size class sc. It returns "mallocPanic" when sc is 0 or above
// scMax, the range for which no such function is generated.
func smallScanNoHeaderSCFuncName(sc, scMax uint8) string {
	switch {
	case sc == 0, sc > scMax:
		return "mallocPanic"
	default:
		return fmt.Sprintf("mallocgcSmallScanNoHeaderSC%d", sc)
	}
}
128
129 func tinyFuncName(size uintptr) string {
130 if size == 0 || size > smallScanNoHeaderMax {
131 return "mallocPanic"
132 }
133 return fmt.Sprintf("mallocgcTinySize%d", size)
134 }
135
// smallNoScanSCFuncName returns the name of the generated no-scan function
// for size class sc. Size classes 0 and 1, and any class above scMax, map to
// "mallocPanic".
func smallNoScanSCFuncName(sc, scMax uint8) string {
	switch {
	case sc < 2, sc > scMax:
		return "mallocPanic"
	default:
		return fmt.Sprintf("mallocgcSmallNoScanSC%d", sc)
	}
}
142
143
144
145 func specializedMallocConfig(classes []class, sizeToSizeClass []uint8) generatorConfig {
146 config := generatorConfig{file: "../malloc_stubs.go"}
147
148
149
150
151 scMax := sizeToSizeClass[smallScanNoHeaderMax]
152
153 str := fmt.Sprint
154
155
156 {
157 const noscan = 0
158 for sc := uint8(0); sc <= scMax; sc++ {
159 if sc == 0 {
160 continue
161 }
162 name := smallScanNoHeaderSCFuncName(sc, scMax)
163 elemsize := classes[sc].size
164 config.specs = append(config.specs, spec{
165 templateFunc: "mallocStub",
166 name: name,
167 ops: []op{
168 {inlineFunc, "inlinedMalloc", "smallScanNoHeaderStub"},
169 {inlineFunc, "heapSetTypeNoHeaderStub", "heapSetTypeNoHeaderStub"},
170 {inlineFunc, "nextFreeFastStub", "nextFreeFastStub"},
171 {inlineFunc, "writeHeapBitsSmallStub", "writeHeapBitsSmallStub"},
172 {subBasicLit, "elemsize_", str(elemsize)},
173 {subBasicLit, "sizeclass_", str(sc)},
174 {subBasicLit, "noscanint_", str(noscan)},
175 {foldCondition, "isTiny_", str(false)},
176 },
177 })
178 }
179 }
180
181
182 {
183 const noscan = 1
184
185
186 tinySizeClass := sizeToSizeClass[tinySize]
187 for s := range uintptr(16) {
188 if s == 0 {
189 continue
190 }
191 name := tinyFuncName(s)
192 elemsize := classes[tinySizeClass].size
193 config.specs = append(config.specs, spec{
194 templateFunc: "mallocStub",
195 name: name,
196 ops: []op{
197 {inlineFunc, "inlinedMalloc", "tinyStub"},
198 {inlineFunc, "nextFreeFastTiny", "nextFreeFastTiny"},
199 {subBasicLit, "elemsize_", str(elemsize)},
200 {subBasicLit, "sizeclass_", str(tinySizeClass)},
201 {subBasicLit, "size_", str(s)},
202 {subBasicLit, "noscanint_", str(noscan)},
203 {foldCondition, "isTiny_", str(true)},
204 },
205 })
206 }
207
208
209 for sc := uint8(tinySizeClass); sc <= scMax; sc++ {
210 name := smallNoScanSCFuncName(sc, scMax)
211 elemsize := classes[sc].size
212 config.specs = append(config.specs, spec{
213 templateFunc: "mallocStub",
214 name: name,
215 ops: []op{
216 {inlineFunc, "inlinedMalloc", "smallNoScanStub"},
217 {inlineFunc, "nextFreeFastStub", "nextFreeFastStub"},
218 {subBasicLit, "elemsize_", str(elemsize)},
219 {subBasicLit, "sizeclass_", str(sc)},
220 {subBasicLit, "noscanint_", str(noscan)},
221 {foldCondition, "isTiny_", str(false)},
222 },
223 })
224 }
225 }
226
227 return config
228 }
229
230
231 func inline(config generatorConfig) []byte {
232 var out bytes.Buffer
233
234
235 fset := token.NewFileSet()
236 f, err := parser.ParseFile(fset, config.file, nil, 0)
237 if err != nil {
238 log.Fatalf("parsing %s: %v", config.file, err)
239 }
240
241
242
243
244
245
246 funcDecls := map[string]*ast.FuncDecl{}
247 importDecls := []*ast.GenDecl{}
248 for _, decl := range f.Decls {
249 switch decl := decl.(type) {
250 case *ast.FuncDecl:
251 funcDecls[decl.Name.Name] = decl
252 case *ast.GenDecl:
253 if decl.Tok.String() == "import" {
254 importDecls = append(importDecls, decl)
255 continue
256 }
257 }
258 }
259
260
261 out.WriteString("// Code generated by mkmalloc.go; DO NOT EDIT.\n")
262 out.WriteString("// See overview in malloc_stubs.go.\n\n")
263 out.WriteString("package " + f.Name.Name + "\n\n")
264 for _, importDecl := range importDecls {
265 out.Write(mustFormatNode(fset, importDecl))
266 out.WriteString("\n\n")
267 }
268
269
270 for _, spec := range config.specs {
271
272 containingFuncCopy := internalastutil.CloneNode(funcDecls[spec.templateFunc])
273 if containingFuncCopy == nil {
274 log.Fatal("did not find", spec.templateFunc)
275 }
276 containingFuncCopy.Name.Name = spec.name
277
278
279 stamped := ast.Node(containingFuncCopy)
280 for _, repl := range spec.ops {
281 switch repl.kind {
282 case inlineFunc:
283 if toDecl, ok := funcDecls[repl.to]; ok {
284 stamped = inlineFunction(stamped, repl.from, toDecl)
285 }
286 case subBasicLit:
287 stamped = substituteWithBasicLit(stamped, repl.from, repl.to)
288 case foldCondition:
289 stamped = foldIfCondition(stamped, repl.from, repl.to)
290 default:
291 log.Fatalf("unknown op kind %v", repl.kind)
292 }
293 }
294
295 out.Write(mustFormatNode(fset, stamped))
296 out.WriteString("\n\n")
297 }
298
299 return out.Bytes()
300 }
301
302
303
304 func substituteWithBasicLit(node ast.Node, from, to string) ast.Node {
305
306 toExpr, err := parser.ParseExpr(to)
307 if err != nil {
308 log.Fatalf("parsing expr %q: %v", to, err)
309 }
310 if _, ok := toExpr.(*ast.BasicLit); !ok {
311 log.Fatalf("op 'to' expr %q is not a basic literal", to)
312 }
313 return astutil.Apply(node, func(cursor *astutil.Cursor) bool {
314 if isIdentWithName(cursor.Node(), from) {
315 cursor.Replace(toExpr)
316 }
317 return true
318 }, nil)
319 }
320
321
322
323
324 func foldIfCondition(node ast.Node, from, to string) ast.Node {
325 var isTrue bool
326 switch to {
327 case "true":
328 isTrue = true
329 case "false":
330 isTrue = false
331 default:
332 log.Fatalf("op 'to' expr %q is not true or false", to)
333 }
334 return astutil.Apply(node, func(cursor *astutil.Cursor) bool {
335 var foldIfTrue bool
336 ifexpr, ok := cursor.Node().(*ast.IfStmt)
337 if !ok {
338 return true
339 }
340 if isIdentWithName(ifexpr.Cond, from) {
341 foldIfTrue = true
342 } else if unaryexpr, ok := ifexpr.Cond.(*ast.UnaryExpr); ok && unaryexpr.Op == token.NOT && isIdentWithName(unaryexpr.X, from) {
343 foldIfTrue = false
344 } else {
345
346 return true
347 }
348 if foldIfTrue == isTrue {
349 for _, stmt := range ifexpr.Body.List {
350 cursor.InsertBefore(stmt)
351 }
352 }
353 cursor.Delete()
354 return true
355 }, nil)
356 }
357
358
359
360
361
362
363 func inlineFunction(node ast.Node, from string, toDecl *ast.FuncDecl) ast.Node {
364 return astutil.Apply(node, func(cursor *astutil.Cursor) bool {
365 switch node := cursor.Node().(type) {
366 case *ast.AssignStmt:
367
368
369 if len(node.Rhs) == 1 && isCallTo(node.Rhs[0], from) {
370 args := node.Rhs[0].(*ast.CallExpr).Args
371 if !argsMatchParameters(args, toDecl.Type.Params) {
372 log.Fatalf("applying op: arguments to %v don't match parameter names of %v: %v", from, toDecl.Name, debugPrint(args...))
373 }
374 replaceAssignment(cursor, node, toDecl)
375 }
376 return false
377 case *ast.CallExpr:
378
379 if isCallTo(node, from) {
380 if _, ok := cursor.Parent().(*ast.AssignStmt); !ok {
381 log.Fatalf("applying op: all calls to function %q being replaced must appear in an assignment statement, appears in %T", from, cursor.Parent())
382 }
383 }
384 }
385 return true
386 }, nil)
387 }
388
389
390
391 func argsMatchParameters(args []ast.Expr, params *ast.FieldList) bool {
392 var paramIdents []*ast.Ident
393 for _, f := range params.List {
394 paramIdents = append(paramIdents, f.Names...)
395 }
396
397 if len(args) != len(paramIdents) {
398 return false
399 }
400
401 for i := range args {
402 if !isIdentWithName(args[i], paramIdents[i].Name) {
403 return false
404 }
405 }
406
407 return true
408 }
409
410
// isIdentWithName reports whether expr is an identifier spelled name. Any
// other node, including a nil interface, reports false.
func isIdentWithName(expr ast.Node, name string) bool {
	if ident, ok := expr.(*ast.Ident); ok {
		return ident.Name == name
	}
	return false
}
418
419
420 func isCallTo(expr ast.Expr, name string) bool {
421 callexpr, ok := expr.(*ast.CallExpr)
422 if !ok {
423 return false
424 }
425 return isIdentWithName(callexpr.Fun, name)
426 }
427
428
429
430
// replaceAssignment inlines funcdecl at the assignment statement assign, which
// the cursor points at and whose right-hand side is a call to funcdecl.
//
// funcdecl's body must end in a return. A clone of the body is used, so
// funcdecl itself is not modified. All statements of the clone except the
// final return are inserted before the assignment. The assignment is then
// rewritten by replaceWithAssignment to bind assign.Lhs to the final return's
// results, using assign.Tok.
//
// If the body also contains return statements before the final one,
// addContinues first rewrites each of them into an assignment followed by a
// copy of all statements that follow the original assignment in its
// enclosing block (see everythingFollowingInParent). That way, each early
// return continues with its own copy of the remaining code.
func replaceAssignment(cursor *astutil.Cursor, assign *ast.AssignStmt, funcdecl *ast.FuncDecl) {
	if !hasTerminatingReturn(funcdecl.Body) {
		log.Fatal("function being inlined must have a return at the end")
	}

	// funcdecl is looked up again for every spec by inline, so work on a copy.
	body := internalastutil.CloneNode(funcdecl.Body)
	if hasTerminatingAndNonterminatingReturn(funcdecl.Body) {
		// Early returns cannot simply fall through to the code after the
		// assignment; give each its own copy of that code.
		body = addContinues(cursor, assign, body, everythingFollowingInParent(cursor)).(*ast.BlockStmt)
	}

	if len(body.List) < 1 {
		log.Fatal("replacing with empty bodied function")
	}

	// Split the body into the statements to splice in and the final return,
	// whose results become the new right-hand side of the assignment.
	beforeReturn, ret := body.List[:len(body.List)-1], body.List[len(body.List)-1]
	returnStmt, ok := ret.(*ast.ReturnStmt)
	if !ok {
		log.Fatal("last stmt in function we're replacing with should be a return")
	}
	results := returnStmt.Results

	for _, stmt := range beforeReturn {
		cursor.InsertBefore(stmt)
	}

	// Replace (or delete, if every result already has the target's name) the
	// original assignment itself.
	replaceWithAssignment(cursor, assign.Lhs, results, assign.Tok)
}
469
470
// hasTerminatingReturn reports whether the last statement of block is a
// return statement. A nil block (a function declared without a body) and an
// empty block have no terminating return. Checking this first lets callers
// report their own errors instead of panicking on an out-of-range index.
func hasTerminatingReturn(block *ast.BlockStmt) bool {
	if block == nil || len(block.List) == 0 {
		return false
	}
	_, ok := block.List[len(block.List)-1].(*ast.ReturnStmt)
	return ok
}
475
476
477
478 func hasTerminatingAndNonterminatingReturn(block *ast.BlockStmt) bool {
479 if !hasTerminatingReturn(block) {
480 return false
481 }
482 var ret bool
483 for i := range block.List[:len(block.List)-1] {
484 ast.Inspect(block.List[i], func(node ast.Node) bool {
485 _, ok := node.(*ast.ReturnStmt)
486 if ok {
487 ret = true
488 return false
489 }
490 return true
491 })
492 }
493 return ret
494 }
495
496
497
498 func everythingFollowingInParent(cursor *astutil.Cursor) *ast.BlockStmt {
499 parent := cursor.Parent()
500 block, ok := parent.(*ast.BlockStmt)
501 if !ok {
502 log.Fatal("internal error: in everythingFollowingInParent, cursor doesn't point to element in block list")
503 }
504
505 blockcopy := internalastutil.CloneNode(block)
506 blockcopy.List = blockcopy.List[cursor.Index()+1:]
507
508 if _, ok := blockcopy.List[len(blockcopy.List)-1].(*ast.ReturnStmt); !ok {
509 log.Printf("%s", mustFormatNode(token.NewFileSet(), blockcopy))
510 log.Fatal("internal error: parent doesn't end in a return")
511 }
512 return blockcopy
513 }
514
515
516
517
518
519 func addContinues(cursor *astutil.Cursor, assignNode *ast.AssignStmt, toBlock *ast.BlockStmt, continueBlock *ast.BlockStmt) ast.Node {
520 if !hasTerminatingReturn(continueBlock) {
521 log.Fatal("the block being continued to in addContinues must end in a return")
522 }
523 applyFunc := func(cursor *astutil.Cursor) bool {
524 ret, ok := cursor.Node().(*ast.ReturnStmt)
525 if !ok {
526 return true
527 }
528
529 if cursor.Parent() == toBlock && cursor.Index() == len(toBlock.List)-1 {
530 return false
531 }
532
533
534
535
536 replaceWithAssignment(cursor, assignNode.Lhs, ret.Results, assignNode.Tok)
537 cursor.InsertAfter(internalastutil.CloneNode(continueBlock))
538
539 return false
540 }
541 return astutil.Apply(toBlock, applyFunc, nil)
542 }
543
544
545 func debugPrint(nodes ...ast.Expr) string {
546 var b strings.Builder
547 for i, node := range nodes {
548 b.Write(mustFormatNode(token.NewFileSet(), node))
549 if i != len(nodes)-1 {
550 b.WriteString(", ")
551 }
552 }
553 return b.String()
554 }
555
556
// mustFormatNode renders node as gofmt-formatted Go source using fset for
// position information. node must be a type accepted by format.Node. If
// formatting fails, it exits via log.Fatalf rather than returning truncated
// output, matching the "must" contract of mustFormat.
func mustFormatNode(fset *token.FileSet, node any) []byte {
	var buf bytes.Buffer
	if err := format.Node(&buf, fset, node); err != nil {
		log.Fatalf("error formatting node: %v", err)
	}
	return buf.Bytes()
}
562
563
564
565
566
567 func mustMatchExprs(lhs []ast.Expr, rhs []ast.Expr) ([]ast.Expr, []ast.Expr) {
568 if len(lhs) != len(rhs) {
569 log.Fatal("exprs don't match", debugPrint(lhs...), debugPrint(rhs...))
570 }
571
572 var newLhs, newRhs []ast.Expr
573 for i := range lhs {
574 lhsIdent, ok1 := lhs[i].(*ast.Ident)
575 rhsIdent, ok2 := rhs[i].(*ast.Ident)
576 if ok1 && ok2 && lhsIdent.Name == rhsIdent.Name {
577 continue
578 }
579 newLhs = append(newLhs, lhs[i])
580 newRhs = append(newRhs, rhs[i])
581 }
582
583 return newLhs, newRhs
584 }
585
586
587
588
589 func replaceWithAssignment(cursor *astutil.Cursor, lhs, rhs []ast.Expr, tok token.Token) {
590 newLhs, newRhs := mustMatchExprs(lhs, rhs)
591 if len(newLhs) == 0 {
592 cursor.Delete()
593 return
594 }
595 if len(newRhs) == 1 {
596 if lit, ok := newRhs[0].(*ast.BasicLit); ok {
597 constDecl := &ast.DeclStmt{
598 Decl: &ast.GenDecl{
599 Tok: token.CONST,
600 Specs: []ast.Spec{
601 &ast.ValueSpec{
602 Names: []*ast.Ident{newLhs[0].(*ast.Ident)},
603 Values: []ast.Expr{lit},
604 },
605 },
606 },
607 }
608 cursor.Replace(constDecl)
609 return
610 }
611 }
612 newAssignment := &ast.AssignStmt{
613 Lhs: newLhs,
614 Rhs: newRhs,
615 Tok: tok,
616 }
617 cursor.Replace(newAssignment)
618 }
619
620
621 func generateTable(sizeToSizeClass []uint8) []byte {
622 scMax := sizeToSizeClass[smallScanNoHeaderMax]
623
624 var b bytes.Buffer
625 fmt.Fprintln(&b, `// Code generated by mkmalloc.go; DO NOT EDIT.
626 //go:build !plan9
627
628 package runtime
629
630 import "unsafe"
631
632 var mallocScanTable = [513]func(size uintptr, typ *_type, needzero bool) unsafe.Pointer{`)
633
634 for i := range uintptr(smallScanNoHeaderMax + 1) {
635 fmt.Fprintf(&b, "%s,\n", smallScanNoHeaderSCFuncName(sizeToSizeClass[i], scMax))
636 }
637
638 fmt.Fprintln(&b, `
639 }
640
641 var mallocNoScanTable = [513]func(size uintptr, typ *_type, needzero bool) unsafe.Pointer{`)
642 for i := range uintptr(smallScanNoHeaderMax + 1) {
643 if i < 16 {
644 fmt.Fprintf(&b, "%s,\n", tinyFuncName(i))
645 } else {
646 fmt.Fprintf(&b, "%s,\n", smallNoScanSCFuncName(sizeToSizeClass[i], scMax))
647 }
648 }
649
650 fmt.Fprintln(&b, `
651 }`)
652
653 return b.Bytes()
654 }
655
View as plain text