diff --git a/__pycache__/train.cpython-312.pyc b/__pycache__/train.cpython-312.pyc index d28880d..e4893d6 100644 Binary files a/__pycache__/train.cpython-312.pyc and b/__pycache__/train.cpython-312.pyc differ diff --git a/results.tsv b/results.tsv new file mode 100644 index 0000000..ddd1e13 --- /dev/null +++ b/results.tsv @@ -0,0 +1,501 @@ +step lambda train_loss train_ppl eval_ppl eval_bpb lr time_s best_ppl q_neg1 q_zero q_pos1 +1 0.0050 2.136068 8.47 - - 1.00e-06 5.5 inf - - - +2 0.0100 2.560659 12.94 - - 2.00e-06 5.7 inf - - - +3 0.0150 2.504529 12.24 - - 3.00e-06 5.8 inf - - - +4 0.0200 1.823999 6.20 - - 4.00e-06 6.0 inf - - - +5 0.0250 2.411383 11.15 - - 5.00e-06 6.2 inf - - - +6 0.0300 1.781534 5.94 - - 6.00e-06 6.4 inf - - - +7 0.0350 2.355754 10.55 - - 7.00e-06 6.6 inf - - - +8 0.0400 2.197158 9.00 - - 8.00e-06 6.8 inf - - - +9 0.0450 2.136650 8.47 - - 9.00e-06 7.0 inf - - - +10 0.0500 2.071529 7.94 - - 1.00e-05 7.2 inf - - - +11 0.0550 2.267195 9.65 - - 1.10e-05 7.4 inf - - - +12 0.0600 2.020964 7.55 - - 1.20e-05 7.5 inf - - - +13 0.0650 2.299464 9.97 - - 1.30e-05 7.7 inf - - - +14 0.0700 2.289145 9.87 - - 1.40e-05 7.9 inf - - - +15 0.0750 2.129694 8.41 - - 1.50e-05 8.1 inf - - - +16 0.0800 2.081169 8.01 - - 1.60e-05 8.3 inf - - - +17 0.0850 2.236226 9.36 - - 1.70e-05 8.5 inf - - - +18 0.0900 1.992637 7.33 - - 1.80e-05 8.7 inf - - - +19 0.0950 2.251162 9.50 - - 1.90e-05 8.9 inf - - - +20 0.1000 2.385473 10.86 - - 2.00e-05 9.0 inf - - - +21 0.1050 1.958657 7.09 - - 2.10e-05 9.2 inf - - - +22 0.1100 2.121532 8.34 - - 2.20e-05 9.4 inf - - - +23 0.1150 2.175685 8.81 - - 2.30e-05 9.6 inf - - - +24 0.1200 2.106013 8.22 - - 2.40e-05 9.8 inf - - - +25 0.1250 2.277615 9.75 40.64 5.3449 2.50e-05 29.0 40.64 0.3431 0.3135 0.3434 +26 0.1300 1.839106 6.29 - - 2.60e-05 29.2 40.64 - - - +27 0.1350 2.051520 7.78 - - 2.70e-05 29.4 40.64 - - - +28 0.1400 1.955719 7.07 - - 2.80e-05 29.6 40.64 - - - +29 0.1450 1.995286 7.35 - - 2.90e-05 29.7 40.64 - - - +30 0.1500 1.781965 5.94 - - 3.00e-05 29.9 40.64 - - - +31 0.1550 2.062846 7.87 - - 3.10e-05 30.1 40.64 - - - +32 0.1600 1.859248 6.42 - - 3.20e-05 30.3 40.64 - - - +33 0.1650 2.119028 8.32 - - 3.30e-05 30.5 40.64 - - - +34 0.1700 1.926346 6.86 - - 3.40e-05 30.7 40.64 - - - +35 0.1750 2.928543 18.70 - - 3.50e-05 30.9 40.64 - - - +36 0.1800 2.470379 11.83 - - 3.60e-05 31.1 40.64 - - - +37 0.1850 2.307961 10.05 - - 3.70e-05 31.3 40.64 - - - +38 0.1900 2.296540 9.94 - - 3.80e-05 31.5 40.64 - - - +39 0.1950 2.601529 13.48 - - 3.90e-05 31.6 40.64 - - - +40 0.2000 2.572134 13.09 - - 4.00e-05 31.8 40.64 - - - +41 0.2050 2.381441 10.82 - - 4.10e-05 32.0 40.64 - - - +42 0.2100 2.346776 10.45 - - 4.20e-05 32.2 40.64 - - - +43 0.2150 3.014177 20.37 - - 4.30e-05 32.4 40.64 - - - +44 0.2200 2.523110 12.47 - - 4.40e-05 32.6 40.64 - - - +45 0.2250 2.421204 11.26 - - 4.50e-05 32.8 40.64 - - - +46 0.2300 2.435839 11.43 - - 4.60e-05 33.0 40.64 - - - +47 0.2350 2.681941 14.61 - - 4.70e-05 33.2 40.64 - - - +48 0.2400 2.798432 16.42 - - 4.80e-05 33.3 40.64 - - - +49 0.2450 2.496705 12.14 - - 4.90e-05 33.5 40.64 - - - +50 0.2500 2.837209 17.07 58.16 5.8619 5.00e-05 52.8 40.64 0.3431 0.3135 0.3434 +51 0.2550 2.712536 15.07 - - 5.10e-05 52.9 40.64 - - - +52 0.2600 2.741771 15.51 - - 5.20e-05 53.1 40.64 - - - +53 0.2650 2.228770 9.29 - - 5.30e-05 53.3 40.64 - - - +54 0.2700 2.615676 13.68 - - 5.40e-05 53.5 40.64 - - - +55 0.2750 2.806169 16.55 - - 5.50e-05 53.7 40.64 - - - +56 0.2800 2.551644 12.83 - - 5.60e-05 53.9 40.64 - - - +57 0.2850 2.542200 12.71 - - 5.70e-05 54.1 40.64 - - - +58 0.2900 2.780680 16.13 - - 5.80e-05 54.3 40.64 - - - +59 0.2950 2.341232 10.39 - - 5.90e-05 54.5 40.64 - - - +60 0.3000 2.846083 17.22 - - 6.00e-05 54.6 40.64 - - - +61 0.3050 2.464588 11.76 - - 6.10e-05 54.8 40.64 - - - +62 0.3100 2.793186 16.33 - - 6.20e-05 55.0 40.64 - - - +63 0.3150 2.939746 18.91 - - 6.30e-05 55.2 40.64 - - - +64 0.3200 3.012881 20.35 - - 6.40e-05 55.4 40.64 - - - +65 0.3250 3.114154 22.51 - - 6.50e-05 55.6 40.64 - - - +66 0.3300 2.846965 17.24 - - 6.60e-05 55.8 40.64 - - - +67 0.3350 2.722409 15.22 - - 6.70e-05 56.0 40.64 - - - +68 0.3400 2.481803 11.96 - - 6.80e-05 56.2 40.64 - - - +69 0.3450 2.550641 12.82 - - 6.90e-05 56.4 40.64 - - - +70 0.3500 2.568753 13.05 - - 7.00e-05 56.5 40.64 - - - +71 0.3550 2.640688 14.02 - - 7.10e-05 56.7 40.64 - - - +72 0.3600 2.478850 11.93 - - 7.20e-05 56.9 40.64 - - - +73 0.3650 2.698298 14.85 - - 7.30e-05 57.1 40.64 - - - +74 0.3700 2.368236 10.68 - - 7.40e-05 57.3 40.64 - - - +75 0.3750 2.745409 15.57 113.02 6.8205 7.50e-05 76.5 40.64 0.3431 0.3135 0.3434 +76 0.3800 3.099635 22.19 - - 7.60e-05 76.7 40.64 - - - +77 0.3850 3.232097 25.33 - - 7.70e-05 76.9 40.64 - - - +78 0.3900 2.545948 12.76 - - 7.80e-05 77.1 40.64 - - - +79 0.3950 2.684703 14.65 - - 7.90e-05 77.3 40.64 - - - +80 0.4000 2.604475 13.52 - - 8.00e-05 77.5 40.64 - - - +81 0.4050 2.618869 13.72 - - 8.10e-05 77.7 40.64 - - - +82 0.4100 2.712800 15.07 - - 8.20e-05 77.8 40.64 - - - +83 0.4150 2.804403 16.52 - - 8.30e-05 78.0 40.64 - - - +84 0.4200 2.715544 15.11 - - 8.40e-05 78.2 40.64 - - - +85 0.4250 3.329880 27.93 - - 8.50e-05 78.4 40.64 - - - +86 0.4300 2.900910 18.19 - - 8.60e-05 78.6 40.64 - - - +87 0.4350 3.247538 25.73 - - 8.70e-05 78.8 40.64 - - - +88 0.4400 2.764128 15.87 - - 8.80e-05 79.0 40.64 - - - +89 0.4450 3.109798 22.42 - - 8.90e-05 79.2 40.64 - - - +90 0.4500 3.016528 20.42 - - 9.00e-05 79.4 40.64 - - - +91 0.4550 3.504213 33.26 - - 9.10e-05 79.6 40.64 - - - +92 0.4600 3.373166 29.17 - - 9.20e-05 79.7 40.64 - - - +93 0.4650 3.525333 33.97 - - 9.30e-05 79.9 40.64 - - - +94 0.4700 3.573746 35.65 - - 9.40e-05 80.1 40.64 - - - +95 0.4750 3.279140 26.55 - - 9.50e-05 80.3 40.64 - - - +96 0.4800 3.354543 28.63 - - 9.60e-05 80.5 40.64 - - - +97 0.4850 2.990998 19.91 - - 9.70e-05 80.7 40.64 - - - +98 0.4900 3.271327 26.35 - - 9.80e-05 80.9 40.64 - - - +99 0.4950 3.185179 24.17 - - 9.90e-05 81.1 40.64 - - - +100 0.5000 3.800461 44.72 618.35 9.2723 1.00e-04 100.3 40.64 0.3431 0.3135 0.3434 +101 0.5050 3.646878 38.35 - - 1.00e-04 100.5 40.64 - - - +102 0.5100 3.478473 32.41 - - 1.00e-04 100.7 40.64 - - - +103 0.5150 3.571715 35.58 - - 1.00e-04 100.9 40.64 - - - +104 0.5200 3.554683 34.98 - - 1.00e-04 101.1 40.64 - - - +105 0.5250 3.653823 38.62 - - 1.00e-04 101.3 40.64 - - - +106 0.5300 3.917732 50.29 - - 1.00e-04 101.5 40.64 - - - +107 0.5350 4.449574 85.59 - - 1.00e-04 101.6 40.64 - - - +108 0.5400 4.093189 59.93 - - 1.00e-04 101.8 40.64 - - - +109 0.5450 4.310959 74.51 - - 1.00e-04 102.0 40.64 - - - +110 0.5500 3.835424 46.31 - - 1.00e-04 102.2 40.64 - - - +111 0.5550 4.298162 73.56 - - 1.00e-04 102.4 40.64 - - - +112 0.5600 3.930627 50.94 - - 1.00e-04 102.6 40.64 - - - +113 0.5650 3.924658 50.64 - - 1.00e-04 102.8 40.64 - - - +114 0.5700 3.979460 53.49 - - 1.00e-04 103.0 40.64 - - - +115 0.5750 4.577632 97.28 - - 1.00e-04 103.2 40.64 - - - +116 0.5800 4.399222 81.39 - - 1.00e-04 103.4 40.64 - - - +117 0.5850 4.940031 139.77 - - 1.00e-04 103.5 40.64 - - - +118 0.5900 4.668941 106.58 - - 1.00e-04 103.7 40.64 - - - +119 0.5950 4.804703 122.08 - - 1.00e-04 103.9 40.64 - - - +120 0.6000 5.443753 231.31 - - 1.00e-04 104.1 40.64 - - - +121 0.6050 4.958951 142.44 - - 1.00e-04 104.3 40.64 - - - +122 0.6100 5.579468 264.93 - - 1.00e-04 104.5 40.64 - - - +123 0.6150 5.471995 237.93 - - 1.00e-04 104.7 40.64 - - - +124 0.6200 5.912117 369.49 - - 1.00e-04 104.9 40.64 - - - +125 0.6250 5.534120 253.18 27940.34 14.7701 1.00e-04 124.0 40.64 0.3431 0.3135 0.3434 +126 0.6300 6.140091 464.10 - - 1.00e-04 124.2 40.64 - - - +127 0.6350 6.698533 811.21 - - 1.00e-04 124.4 40.64 - - - +128 0.6400 7.445140 1711.53 - - 1.00e-04 124.5 40.64 - - - +129 0.6450 8.332374 4156.27 - - 1.00e-04 124.7 40.64 - - - +130 0.6500 8.018986 3038.09 - - 1.00e-04 124.9 40.64 - - - +131 0.6550 7.844746 2552.29 - - 1.00e-04 125.1 40.64 - - - +132 0.6600 7.028385 1128.21 - - 1.00e-04 125.3 40.64 - - - +133 0.6650 8.018407 3036.34 - - 1.00e-04 125.5 40.64 - - - +134 0.6700 8.925037 7517.86 - - 1.00e-04 125.7 40.64 - - - +135 0.6750 9.367007 11696.06 - - 1.00e-04 125.9 40.64 - - - +136 0.6800 9.680276 15998.91 - - 1.00e-04 126.1 40.64 - - - +137 0.6850 9.984258 21682.43 - - 1.00e-04 126.2 40.64 - - - +138 0.6900 10.559093 38526.19 - - 1.00e-04 126.4 40.64 - - - +139 0.6950 9.374221 11780.73 - - 1.00e-04 126.6 40.64 - - - +140 0.7000 11.076711 64647.89 - - 1.00e-04 126.8 40.64 - - - +141 0.7050 10.831450 50587.03 - - 1.00e-04 127.0 40.64 - - - +142 0.7100 11.326814 83018.08 - - 1.00e-04 127.2 40.64 - - - +143 0.7150 11.741574 125690.06 - - 1.00e-04 127.4 40.64 - - - +144 0.7200 12.815244 367781.29 - - 1.00e-04 127.6 40.64 - - - +145 0.7250 13.129809 503736.90 - - 1.00e-04 127.8 40.64 - - - +146 0.7300 12.875619 390670.17 - - 1.00e-04 127.9 40.64 - - - +147 0.7350 12.901546 400931.36 - - 1.00e-04 128.1 40.64 - - - +148 0.7400 14.442235 1871470.41 - - 1.00e-04 128.3 40.64 - - - +149 0.7450 13.666269 861361.28 - - 1.00e-04 128.5 40.64 - - - +150 0.7500 13.603951 809321.45 2067437.67 20.9794 1.00e-04 147.7 40.64 0.3431 0.3135 0.3434 +151 0.7550 12.629039 305296.53 - - 1.00e-04 147.9 40.64 - - - +152 0.7600 13.005286 444758.28 - - 1.00e-04 148.1 40.64 - - - +153 0.7650 14.194674 1461062.53 - - 1.00e-04 148.3 40.64 - - - +154 0.7700 15.230688 4117217.36 - - 1.00e-04 148.5 40.64 - - - +155 0.7750 13.619554 822047.59 - - 1.00e-04 148.7 40.64 - - - +156 0.7800 13.882224 1068989.20 - - 1.00e-04 148.8 40.64 - - - +157 0.7850 13.324673 612113.31 - - 1.00e-04 149.0 40.64 - - - +158 0.7900 15.784142 7160866.42 - - 1.00e-04 149.2 40.64 - - - +159 0.7950 14.135076 1376528.86 - - 1.00e-04 149.4 40.64 - - - +160 0.8000 13.560103 774601.06 - - 1.00e-04 149.6 40.64 - - - +161 0.8050 15.835391 7537424.06 - - 1.00e-04 149.8 40.64 - - - +162 0.8100 15.722898 6735472.24 - - 1.00e-04 150.0 40.64 - - - +163 0.8150 13.590790 798739.23 - - 1.00e-04 150.2 40.64 - - - +164 0.8200 14.479448 1942426.10 - - 1.00e-04 150.4 40.64 - - - +165 0.8250 16.230492 11189558.65 - - 1.00e-04 150.5 40.64 - - - +166 0.8300 15.013620 3313847.23 - - 1.00e-04 150.7 40.64 - - - +167 0.8350 15.745448 6889080.04 - - 1.00e-04 150.9 40.64 - - - +168 0.8400 14.910224 2988326.12 - - 1.00e-04 151.1 40.64 - - - +169 0.8450 15.106627 3636846.30 - - 1.00e-04 151.3 40.64 - - - +170 0.8500 16.314306 12167831.43 - - 1.00e-04 151.5 40.64 - - - +171 0.8550 15.396415 4859347.64 - - 1.00e-04 151.7 40.64 - - - +172 0.8600 14.927106 3039203.11 - - 1.00e-04 151.9 40.64 - - - +173 0.8650 14.441647 1870369.52 - - 1.00e-04 152.1 40.64 - - - +174 0.8700 15.264482 4258729.97 - - 1.00e-04 152.2 40.64 - - - +175 0.8750 15.554062 5689095.84 8106699.38 22.9507 1.00e-04 171.5 40.64 0.3431 0.3135 0.3434 +176 0.8800 14.270290 1575826.37 - - 1.00e-04 171.6 40.64 - - - +177 0.8850 16.421373 13542905.21 - - 1.00e-04 171.8 40.64 - - - +178 0.8900 13.723610 912195.79 - - 1.00e-04 172.0 40.64 - - - +179 0.8950 15.787707 7186446.14 - - 1.00e-04 172.2 40.64 - - - +180 0.9000 16.114063 9959757.31 - - 1.00e-04 172.4 40.64 - - - +181 0.9050 16.504068 14710445.46 - - 1.00e-04 172.6 40.64 - - - +182 0.9100 16.453072 13979067.56 - - 1.00e-04 172.8 40.64 - - - +183 0.9150 14.644923 2291972.16 - - 1.00e-04 173.0 40.64 - - - +184 0.9200 15.111359 3654093.64 - - 1.00e-04 173.1 40.64 - - - +185 0.9250 15.287539 4358064.21 - - 1.00e-04 173.3 40.64 - - - +186 0.9300 16.466921 14174014.09 - - 1.00e-04 173.5 40.64 - - - +187 0.9350 15.300754 4416038.54 - - 1.00e-04 173.7 40.64 - - - +188 0.9400 16.760338 19007378.99 - - 1.00e-04 173.9 40.64 - - - +189 0.9450 18.445984 102562598.45 - - 1.00e-04 174.1 40.64 - - - +190 0.9500 14.309163 1638289.21 - - 1.00e-04 174.3 40.64 - - - +191 0.9550 15.713504 6672491.17 - - 1.00e-04 174.5 40.64 - - - +192 0.9600 15.771399 7070201.22 - - 1.00e-04 174.7 40.64 - - - +193 0.9650 15.713089 6669723.68 - - 1.00e-04 174.8 40.64 - - - +194 0.9700 17.085546 26312284.87 - - 1.00e-04 175.0 40.64 - - - +195 0.9750 15.827868 7480928.76 - - 1.00e-04 175.2 40.64 - - - +196 0.9800 14.334038 1679552.22 - - 1.00e-04 175.4 40.64 - - - +197 0.9850 17.020346 24651435.40 - - 1.00e-04 175.6 40.64 - - - +198 0.9900 15.319046 4497561.96 - - 1.00e-04 175.8 40.64 - - - +199 0.9950 16.279392 11750334.22 - - 1.00e-04 176.0 40.64 - - - +200 1.0000 15.644218 6225836.97 18402015.78 24.1334 1.00e-04 195.1 40.64 0.3431 0.3135 0.3434 +201 1.0000 20.681583 959174334.90 - - 1.00e-04 195.3 40.64 - - - +202 1.0000 15.649086 6256215.44 - - 1.00e-04 195.5 40.64 - - - +203 1.0000 16.587341 15987877.05 - - 1.00e-04 195.6 40.64 - - - +204 1.0000 17.473787 38794432.72 - - 1.00e-04 195.8 40.64 - - - +205 1.0000 14.719927 2470489.42 - - 1.00e-04 196.0 40.64 - - - +206 1.0000 15.012506 3310158.02 - - 1.00e-04 196.2 40.64 - - - +207 1.0000 16.118345 10002496.41 - - 1.00e-04 196.4 40.64 - - - +208 1.0000 15.265300 4262216.11 - - 1.00e-04 196.6 40.64 - - - +209 1.0000 16.789154 19563067.84 - - 1.00e-04 196.8 40.64 - - - +210 1.0000 16.989750 23908626.86 - - 1.00e-04 197.0 40.64 - - - +211 1.0000 16.242119 11320421.17 - - 1.00e-04 197.1 40.64 - - - +212 1.0000 15.237806 4146629.20 - - 1.00e-04 197.3 40.64 - - - +213 1.0000 16.513100 14843901.57 - - 1.00e-04 197.5 40.64 - - - +214 1.0000 17.315149 33103491.14 - - 1.00e-04 197.7 40.64 - - - +215 1.0000 16.225832 11137540.48 - - 1.00e-04 197.9 40.64 - - - +216 1.0000 14.931105 3051380.47 - - 1.00e-04 198.1 40.64 - - - +217 1.0000 17.014717 24513072.44 - - 1.00e-04 198.3 40.64 - - - +218 1.0000 15.806069 7319622.94 - - 1.00e-04 198.5 40.64 - - - +219 1.0000 17.820438 54867864.28 - - 1.00e-04 198.6 40.64 - - - +220 1.0000 15.284276 4343869.11 - - 1.00e-04 198.8 40.64 - - - +221 1.0000 17.491499 39487666.76 - - 1.00e-04 199.0 40.64 - - - +222 1.0000 16.272469 11669259.71 - - 1.00e-04 199.2 40.64 - - - +223 1.0000 15.695415 6552875.66 - - 1.00e-04 199.4 40.64 - - - +224 1.0000 16.692852 17766976.16 - - 1.00e-04 199.6 40.64 - - - +225 1.0000 15.208409 4026504.99 19929751.22 24.2484 1.00e-04 218.7 40.64 0.3431 0.3135 0.3434 +226 1.0000 16.983114 23750502.42 - - 1.00e-04 218.9 40.64 - - - +227 1.0000 15.883364 7907828.09 - - 1.00e-04 219.1 40.64 - - - +228 1.0000 17.957699 62940401.08 - - 1.00e-04 219.3 40.64 - - - +229 1.0000 16.241711 11315801.42 - - 1.00e-04 219.5 40.64 - - - +230 1.0000 14.475777 1935307.28 - - 1.00e-04 219.6 40.64 - - - +231 1.0000 17.789164 53178438.96 - - 1.00e-04 219.8 40.64 - - - +232 1.0000 14.760505 2572798.37 - - 1.00e-04 220.0 40.64 - - - +233 1.0000 15.785724 7172204.92 - - 1.00e-04 220.2 40.64 - - - +234 1.0000 15.076777 3529890.68 - - 1.00e-04 220.4 40.64 - - - +235 1.0000 16.190544 10751374.11 - - 1.00e-04 220.6 40.64 - - - +236 1.0000 17.371979 35039226.88 - - 1.00e-04 220.8 40.64 - - - +237 1.0000 15.648698 6253787.59 - - 1.00e-04 221.0 40.64 - - - +238 1.0000 14.119370 1355077.91 - - 1.00e-04 221.2 40.64 - - - +239 1.0000 15.223616 4088201.28 - - 1.00e-04 221.3 40.64 - - - +240 1.0000 15.438815 5069816.17 - - 1.00e-04 221.5 40.64 - - - +241 1.0000 17.664293 46935894.84 - - 1.00e-04 221.7 40.64 - - - +242 1.0000 17.876865 58052909.59 - - 1.00e-04 221.9 40.64 - - - +243 1.0000 16.507032 14754112.19 - - 1.00e-04 222.1 40.64 - - - +244 1.0000 17.971159 63793316.22 - - 1.00e-04 222.3 40.64 - - - +245 1.0000 15.368195 4724132.86 - - 1.00e-04 222.5 40.64 - - - +246 1.0000 18.383377 96338354.14 - - 1.00e-04 222.7 40.64 - - - +247 1.0000 15.772984 7081409.66 - - 1.00e-04 222.8 40.64 - - - +248 1.0000 15.465032 5204486.62 - - 1.00e-04 223.0 40.64 - - - +249 1.0000 15.963969 8571636.51 - - 1.00e-04 223.2 40.64 - - - +250 1.0000 16.459703 14072082.74 21525920.50 24.3596 1.00e-04 242.4 40.64 0.3431 0.3135 0.3434 +251 1.0000 16.423126 13566664.75 - - 1.00e-04 242.5 40.64 - - - +252 1.0000 14.955647 3127196.48 - - 1.00e-04 242.7 40.64 - - - +253 1.0000 14.772740 2604471.58 - - 1.00e-04 242.9 40.64 - - - +254 1.0000 17.072338 25967027.14 - - 1.00e-04 243.1 40.64 - - - +255 1.0000 16.750446 18820294.11 - - 1.00e-04 243.3 40.64 - - - +256 1.0000 17.235100 30556860.74 - - 1.00e-04 243.5 40.64 - - - +257 1.0000 16.986397 23828592.65 - - 1.00e-04 243.7 40.64 - - - +258 1.0000 15.759155 6984159.91 - - 1.00e-04 243.8 40.64 - - - +259 1.0000 16.714687 18159190.38 - - 1.00e-04 244.0 40.64 - - - +260 1.0000 15.415783 4954381.67 - - 1.00e-04 244.2 40.64 - - - +261 1.0000 16.314724 12172915.11 - - 1.00e-04 244.4 40.64 - - - +262 1.0000 15.626440 6116129.67 - - 1.00e-04 244.6 40.64 - - - +263 1.0000 18.201193 80292988.26 - - 1.00e-04 244.8 40.64 - - - +264 1.0000 16.578726 15850725.26 - - 1.00e-04 245.0 40.64 - - - +265 1.0000 17.064611 25767161.65 - - 1.00e-04 245.2 40.64 - - - +266 1.0000 14.295428 1615941.43 - - 1.00e-04 245.3 40.64 - - - +267 1.0000 16.349220 12600162.59 - - 1.00e-04 245.5 40.64 - - - +268 1.0000 16.572262 15748596.53 - - 1.00e-04 245.7 40.64 - - - +269 1.0000 17.461777 38331276.26 - - 1.00e-04 245.9 40.64 - - - +270 1.0000 17.657043 46596847.99 - - 1.00e-04 246.1 40.64 - - - +271 1.0000 18.224865 82216367.62 - - 1.00e-04 246.3 40.64 - - - +272 1.0000 18.116886 73801243.49 - - 1.00e-04 246.5 40.64 - - - +273 1.0000 17.457096 38152280.83 - - 1.00e-04 246.7 40.64 - - - +274 1.0000 17.424744 36937714.21 - - 1.00e-04 246.8 40.64 - - - +275 1.0000 15.735705 6822287.44 23966721.82 24.5145 1.00e-04 266.0 40.64 0.3431 0.3135 0.3434 +276 1.0000 16.359583 12731412.22 - - 1.00e-04 266.1 40.64 - - - +277 1.0000 15.829643 7494224.76 - - 1.00e-04 266.3 40.64 - - - +278 1.0000 14.499822 1982405.70 - - 1.00e-04 266.5 40.64 - - - +279 1.0000 16.023691 9099146.51 - - 1.00e-04 266.7 40.64 - - - +280 1.0000 16.804184 19859319.68 - - 1.00e-04 266.9 40.64 - - - +281 1.0000 15.337939 4583343.27 - - 1.00e-04 267.1 40.64 - - - +282 1.0000 16.338114 12460993.27 - - 1.00e-04 267.2 40.64 - - - +283 1.0000 16.782156 19426642.19 - - 1.00e-04 267.4 40.64 - - - +284 1.0000 15.253704 4213078.07 - - 1.00e-04 267.6 40.64 - - - +285 1.0000 16.790190 19583339.61 - - 1.00e-04 267.8 40.64 - - - +286 1.0000 17.665545 46994658.76 - - 1.00e-04 268.0 40.64 - - - +287 1.0000 16.535795 15184644.00 - - 1.00e-04 268.2 40.64 - - - +288 1.0000 18.098822 72480030.46 - - 1.00e-04 268.4 40.64 - - - +289 1.0000 16.959953 23206739.92 - - 1.00e-04 268.6 40.64 - - - +290 1.0000 17.374523 35128494.31 - - 1.00e-04 268.7 40.64 - - - +291 1.0000 15.094248 3592100.97 - - 1.00e-04 268.9 40.64 - - - +292 1.0000 15.874532 7838293.84 - - 1.00e-04 269.1 40.64 - - - +293 1.0000 15.964324 8574677.98 - - 1.00e-04 269.3 40.64 - - - +294 1.0000 16.601570 16216991.84 - - 1.00e-04 269.5 40.64 - - - +295 1.0000 16.576714 15818861.66 - - 1.00e-04 269.7 40.64 - - - +296 1.0000 17.835569 55704382.47 - - 1.00e-04 269.9 40.64 - - - +297 1.0000 15.463197 5194945.83 - - 1.00e-04 270.1 40.64 - - - +298 1.0000 17.415686 36604643.20 - - 1.00e-04 270.2 40.64 - - - +299 1.0000 16.729563 18431334.93 - - 1.00e-04 270.4 40.64 - - - +300 1.0000 17.504755 40014602.99 23127587.13 24.4631 1.00e-04 289.6 40.64 0.3431 0.3135 0.3434 +301 1.0000 17.412655 36493870.64 - - 1.00e-04 289.8 40.64 - - - +302 1.0000 16.575960 15806948.18 - - 1.00e-04 290.0 40.64 - - - +303 1.0000 16.105282 9872679.26 - - 1.00e-04 290.1 40.64 - - - +304 1.0000 17.884157 58477766.54 - - 1.00e-04 290.3 40.64 - - - +305 1.0000 17.408779 36352703.90 - - 1.00e-04 290.5 40.64 - - - +306 1.0000 14.993399 3247508.57 - - 1.00e-04 290.7 40.64 - - - +307 1.0000 17.999483 65626038.83 - - 1.00e-04 290.9 40.64 - - - +308 1.0000 16.599461 16182817.70 - - 1.00e-04 291.1 40.64 - - - +309 1.0000 14.952442 3117188.93 - - 1.00e-04 291.3 40.64 - - - +310 1.0000 15.683398 6474605.56 - - 1.00e-04 291.5 40.64 - - - +311 1.0000 15.943405 8397168.73 - - 1.00e-04 291.7 40.64 - - - +312 1.0000 14.936008 3066380.56 - - 1.00e-04 291.8 40.64 - - - +313 1.0000 17.241177 30743114.40 - - 1.00e-04 292.0 40.64 - - - +314 1.0000 15.993195 8825842.12 - - 1.00e-04 292.2 40.64 - - - +315 1.0000 16.213741 11003690.81 - - 1.00e-04 292.4 40.64 - - - +316 1.0000 15.646375 6239275.97 - - 1.00e-04 292.6 40.64 - - - +317 1.0000 16.648079 16989038.81 - - 1.00e-04 292.8 40.64 - - - +318 1.0000 15.747152 6900830.53 - - 1.00e-04 293.0 40.64 - - - +319 1.0000 16.976950 23604541.59 - - 1.00e-04 293.2 40.64 - - - +320 1.0000 18.516691 110077058.23 - - 1.00e-04 293.3 40.64 - - - +321 1.0000 14.905923 2975500.71 - - 1.00e-04 293.5 40.64 - - - +322 1.0000 15.182674 3924205.26 - - 1.00e-04 293.7 40.64 - - - +323 1.0000 15.287921 4359731.16 - - 1.00e-04 293.9 40.64 - - - +324 1.0000 16.774338 19275352.56 - - 1.00e-04 294.1 40.64 - - - +325 1.0000 16.380091 12995201.26 21985239.37 24.3900 1.00e-04 313.3 40.64 0.3431 0.3135 0.3434 +326 1.0000 18.700535 132293675.08 - - 1.00e-04 313.4 40.64 - - - +327 1.0000 15.726874 6762304.90 - - 1.00e-04 313.6 40.64 - - - +328 1.0000 16.364964 12800099.87 - - 1.00e-04 313.8 40.64 - - - +329 1.0000 15.850266 7650384.43 - - 1.00e-04 314.0 40.64 - - - +330 1.0000 17.046324 25300223.21 - - 1.00e-04 314.2 40.64 - - - +331 1.0000 14.607670 2208159.30 - - 1.00e-04 314.4 40.64 - - - +332 1.0000 15.602778 5973110.88 - - 1.00e-04 314.6 40.64 - - - +333 1.0000 17.155127 28208293.03 - - 1.00e-04 314.8 40.64 - - - +334 1.0000 19.121439 201528018.66 - - 1.00e-04 315.0 40.64 - - - +335 1.0000 21.599886 2402764870.39 - - 1.00e-04 315.1 40.64 - - - +336 1.0000 16.656918 17139864.61 - - 1.00e-04 315.3 40.64 - - - +337 1.0000 18.170324 77852325.36 - - 1.00e-04 315.5 40.64 - - - +338 1.0000 15.976001 8675389.53 - - 1.00e-04 315.7 40.64 - - - +339 1.0000 17.651363 46332924.23 - - 1.00e-04 315.9 40.64 - - - +340 1.0000 15.614954 6046281.39 - - 1.00e-04 316.1 40.64 - - - +341 1.0000 17.012737 24464588.78 - - 1.00e-04 316.3 40.64 - - - +342 1.0000 17.504669 40011168.65 - - 1.00e-04 316.5 40.64 - - - +343 1.0000 15.771162 7068522.49 - - 1.00e-04 316.6 40.64 - - - +344 1.0000 16.822245 20221251.10 - - 1.00e-04 316.8 40.64 - - - +345 1.0000 16.596968 16142525.73 - - 1.00e-04 317.0 40.64 - - - +346 1.0000 18.967236 172729192.58 - - 1.00e-04 317.2 40.64 - - - +347 1.0000 17.785002 52957579.00 - - 1.00e-04 317.4 40.64 - - - +348 1.0000 16.835760 20496406.10 - - 1.00e-04 317.6 40.64 - - - +349 1.0000 17.962944 63271403.95 - - 1.00e-04 317.8 40.64 - - - +350 1.0000 15.125232 3705140.52 20564188.92 24.2936 1.00e-04 336.9 40.64 0.3431 0.3135 0.3434 +351 1.0000 16.323538 12280679.67 - - 1.00e-04 337.1 40.64 - - - +352 1.0000 17.803726 53958521.89 - - 1.00e-04 337.3 40.64 - - - +353 1.0000 13.705504 895828.63 - - 1.00e-04 337.5 40.64 - - - +354 1.0000 15.707844 6634831.17 - - 1.00e-04 337.7 40.64 - - - +355 1.0000 15.741318 6860684.31 - - 1.00e-04 337.9 40.64 - - - +356 1.0000 15.197801 3984014.89 - - 1.00e-04 338.0 40.64 - - - +357 1.0000 15.476267 5263289.98 - - 1.00e-04 338.2 40.64 - - - +358 1.0000 16.191271 10759189.97 - - 1.00e-04 338.4 40.64 - - - +359 1.0000 16.276575 11717278.30 - - 1.00e-04 338.6 40.64 - - - +360 1.0000 16.514671 14867249.41 - - 1.00e-04 338.8 40.64 - - - +361 1.0000 17.057405 25582151.82 - - 1.00e-04 339.0 40.64 - - - +362 1.0000 16.177185 10608700.86 - - 1.00e-04 339.2 40.64 - - - +363 1.0000 16.845821 20703666.76 - - 1.00e-04 339.4 40.64 - - - +364 1.0000 17.109182 26941606.45 - - 1.00e-04 339.5 40.64 - - - +365 1.0000 16.646099 16955436.72 - - 1.00e-04 339.7 40.64 - - - +366 1.0000 16.622593 16561527.22 - - 1.00e-04 339.9 40.64 - - - +367 1.0000 16.251612 11428396.18 - - 1.00e-04 340.1 40.64 - - - +368 1.0000 17.486971 39309269.00 - - 1.00e-04 340.3 40.64 - - - +369 1.0000 15.996777 8857512.98 - - 1.00e-04 340.5 40.64 - - - +370 1.0000 16.390179 13126959.89 - - 1.00e-04 340.7 40.64 - - - +371 1.0000 16.078348 9610320.78 - - 1.00e-04 340.9 40.64 - - - +372 1.0000 16.320705 12245944.93 - - 1.00e-04 341.0 40.64 - - - +373 1.0000 17.404583 36200481.48 - - 1.00e-04 341.2 40.64 - - - +374 1.0000 15.051805 3442833.89 - - 1.00e-04 341.4 40.64 - - - +375 1.0000 17.337799 33861833.15 10474727.44 23.3204 1.00e-04 360.6 40.64 0.3431 0.3135 0.3434 +376 1.0000 15.483030 5299008.68 - - 1.00e-04 360.7 40.64 - - - +377 1.0000 17.054203 25500357.60 - - 1.00e-04 360.9 40.64 - - - +378 1.0000 16.264286 11574165.28 - - 1.00e-04 361.1 40.64 - - - +379 1.0000 14.534522 2052403.40 - - 1.00e-04 361.3 40.64 - - - +380 1.0000 19.237947 226430239.93 - - 1.00e-04 361.5 40.64 - - - +381 1.0000 15.563834 5744963.93 - - 1.00e-04 361.7 40.64 - - - +382 1.0000 16.259228 11515767.63 - - 1.00e-04 361.9 40.64 - - - +383 1.0000 15.994675 8838914.94 - - 1.00e-04 362.0 40.64 - - - +384 1.0000 15.382977 4794483.44 - - 1.00e-04 362.2 40.64 - - - +385 1.0000 16.727409 18391687.66 - - 1.00e-04 362.4 40.64 - - - +386 1.0000 16.073967 9568308.33 - - 1.00e-04 362.6 40.64 - - - +387 1.0000 16.334675 12418214.14 - - 1.00e-04 362.8 40.64 - - - +388 1.0000 17.176657 28822206.96 - - 1.00e-04 363.0 40.64 - - - +389 1.0000 16.566597 15659635.70 - - 1.00e-04 363.2 40.64 - - - +390 1.0000 15.541000 5615270.76 - - 1.00e-04 363.4 40.64 - - - +391 1.0000 15.630643 6141888.90 - - 1.00e-04 363.5 40.64 - - - +392 1.0000 15.551783 5676143.55 - - 1.00e-04 363.7 40.64 - - - +393 1.0000 16.352301 12639035.57 - - 1.00e-04 363.9 40.64 - - - +394 1.0000 14.554330 2093462.33 - - 1.00e-04 364.1 40.64 - - - +395 1.0000 16.409468 13382623.87 - - 1.00e-04 364.3 40.64 - - - +396 1.0000 15.559231 5718578.42 - - 1.00e-04 364.5 40.64 - - - +397 1.0000 14.673759 2359023.94 - - 1.00e-04 364.7 40.64 - - - +398 1.0000 17.763540 51833139.84 - - 1.00e-04 364.9 40.64 - - - +399 1.0000 16.750401 18819432.61 - - 1.00e-04 365.0 40.64 - - - +400 1.0000 15.040730 3404911.56 8403641.47 23.0026 1.00e-04 384.2 40.64 0.3431 0.3135 0.3434 +401 1.0000 14.322229 1659836.04 - - 1.00e-04 384.4 40.64 - - - +402 1.0000 14.282641 1595410.15 - - 1.00e-04 384.5 40.64 - - - +403 1.0000 16.290316 11879391.24 - - 1.00e-04 384.7 40.64 - - - +404 1.0000 16.676601 17480585.27 - - 1.00e-04 384.9 40.64 - - - +405 1.0000 14.881516 2903755.63 - - 1.00e-04 385.1 40.64 - - - +406 1.0000 14.902369 2964943.53 - - 1.00e-04 385.3 40.64 - - - +407 1.0000 16.123323 10052414.78 - - 1.00e-04 385.5 40.64 - - - +408 1.0000 15.940021 8368795.80 - - 1.00e-04 385.7 40.64 - - - +409 1.0000 17.176556 28819293.49 - - 1.00e-04 385.9 40.64 - - - +410 1.0000 17.039173 25119955.45 - - 1.00e-04 386.1 40.64 - - - +411 1.0000 16.088869 9711964.09 - - 1.00e-04 386.2 40.64 - - - +412 1.0000 16.546625 15349986.27 - - 1.00e-04 386.4 40.64 - - - +413 1.0000 15.773336 7083908.85 - - 1.00e-04 386.6 40.64 - - - +414 1.0000 15.346972 4624928.70 - - 1.00e-04 386.8 40.64 - - - +415 1.0000 16.043835 9284293.75 - - 1.00e-04 387.0 40.64 - - - +416 1.0000 16.283541 11799181.54 - - 1.00e-04 387.2 40.64 - - - +417 1.0000 16.496347 14597304.22 - - 1.00e-04 387.4 40.64 - - - +418 1.0000 14.502753 1988225.83 - - 1.00e-04 387.6 40.64 - - - +419 1.0000 15.335470 4572040.67 - - 1.00e-04 387.8 40.64 - - - +420 1.0000 14.096808 1324848.19 - - 1.00e-04 387.9 40.64 - - - +421 1.0000 17.281000 31992118.62 - - 1.00e-04 388.1 40.64 - - - +422 1.0000 16.930304 22528766.77 - - 1.00e-04 388.3 40.64 - - - +423 1.0000 14.584811 2158256.36 - - 1.00e-04 388.5 40.64 - - - +424 1.0000 15.976306 8678037.46 - - 1.00e-04 388.7 40.64 - - - +425 1.0000 15.359017 4680976.84 8820417.12 23.0724 1.00e-04 407.8 40.64 0.3431 0.3135 0.3434 +426 1.0000 16.494467 14569877.66 - - 1.00e-04 408.0 40.64 - - - +427 1.0000 14.895820 2945589.91 - - 1.00e-04 408.2 40.64 - - - +428 1.0000 16.700575 17904719.13 - - 1.00e-04 408.4 40.64 - - - +429 1.0000 14.957518 3133050.30 - - 1.00e-04 408.6 40.64 - - - +430 1.0000 17.895708 59157153.78 - - 1.00e-04 408.8 40.64 - - - +431 1.0000 17.215714 29970181.67 - - 1.00e-04 409.0 40.64 - - - +432 1.0000 14.244315 1535420.97 - - 1.00e-04 409.1 40.64 - - - +433 1.0000 14.527661 2038370.60 - - 1.00e-04 409.3 40.64 - - - +434 1.0000 16.402826 13294039.16 - - 1.00e-04 409.5 40.64 - - - +435 1.0000 14.736773 2512459.06 - - 1.00e-04 409.7 40.64 - - - +436 1.0000 16.392395 13156085.95 - - 1.00e-04 409.9 40.64 - - - +437 1.0000 15.652086 6275013.90 - - 1.00e-04 410.1 40.64 - - - +438 1.0000 14.430334 1849330.23 - - 1.00e-04 410.3 40.64 - - - +439 1.0000 15.864594 7760780.80 - - 1.00e-04 410.5 40.64 - - - +440 1.0000 16.068602 9517108.28 - - 1.00e-04 410.7 40.64 - - - +441 1.0000 15.038841 3398488.21 - - 1.00e-04 410.8 40.64 - - - +442 1.0000 17.391508 35730245.55 - - 1.00e-04 411.0 40.64 - - - +443 1.0000 16.897987 21812360.57 - - 1.00e-04 411.2 40.64 - - - +444 1.0000 15.083410 3553381.57 - - 1.00e-04 411.4 40.64 - - - +445 1.0000 17.950001 62457742.31 - - 1.00e-04 411.6 40.64 - - - +446 1.0000 16.585436 15957442.09 - - 1.00e-04 411.8 40.64 - - - +447 1.0000 14.971326 3176612.31 - - 1.00e-04 412.0 40.64 - - - +448 1.0000 15.332577 4558830.83 - - 1.00e-04 412.2 40.64 - - - +449 1.0000 14.162621 1414972.34 - - 1.00e-04 412.3 40.64 - - - +450 1.0000 15.698256 6571518.82 8496263.42 23.0184 1.00e-04 431.5 40.64 0.3431 0.3135 0.3434 +451 1.0000 15.054428 3451874.93 - - 1.00e-04 431.7 40.64 - - - +452 1.0000 15.212586 4043359.25 - - 1.00e-04 431.9 40.64 - - - +453 1.0000 15.066353 3493284.06 - - 1.00e-04 432.0 40.64 - - - +454 1.0000 16.376699 12951205.72 - - 1.00e-04 432.2 40.64 - - - +455 1.0000 15.645795 6235659.27 - - 1.00e-04 432.4 40.64 - - - +456 1.0000 15.419443 4972548.93 - - 1.00e-04 432.6 40.64 - - - +457 1.0000 16.470572 14225853.24 - - 1.00e-04 432.8 40.64 - - - +458 1.0000 16.947273 22914334.98 - - 1.00e-04 433.0 40.64 - - - +459 1.0000 15.739437 6847793.92 - - 1.00e-04 433.2 40.64 - - - +460 1.0000 14.501537 1985809.75 - - 1.00e-04 433.4 40.64 - - - +461 1.0000 15.664286 6352034.16 - - 1.00e-04 433.6 40.64 - - - +462 1.0000 17.169739 28623504.37 - - 1.00e-04 433.7 40.64 - - - +463 1.0000 15.566946 5762869.18 - - 1.00e-04 433.9 40.64 - - - +464 1.0000 14.081024 1304100.61 - - 1.00e-04 434.1 40.64 - - - +465 1.0000 15.382178 4790657.89 - - 1.00e-04 434.3 40.64 - - - +466 1.0000 16.377834 12965912.03 - - 1.00e-04 434.5 40.64 - - - +467 1.0000 17.427671 37046018.18 - - 1.00e-04 434.7 40.64 - - - +468 1.0000 16.139488 10216230.31 - - 1.00e-04 434.9 40.64 - - - +469 1.0000 15.881677 7894498.43 - - 1.00e-04 435.1 40.64 - - - +470 1.0000 16.874821 21312849.45 - - 1.00e-04 435.2 40.64 - - - +471 1.0000 14.063340 1281241.63 - - 1.00e-04 435.4 40.64 - - - +472 1.0000 15.073185 3517232.33 - - 1.00e-04 435.6 40.64 - - - +473 1.0000 17.112825 27039934.69 - - 1.00e-04 435.8 40.64 - - - +474 1.0000 16.001474 8899221.69 - - 1.00e-04 436.0 40.64 - - - +475 1.0000 15.138662 3755238.43 8578213.56 23.0322 1.00e-04 455.2 40.64 0.3431 0.3135 0.3434 +476 1.0000 15.831899 7511146.61 - - 1.00e-04 455.3 40.64 - - - +477 1.0000 14.439591 1866529.55 - - 1.00e-04 455.5 40.64 - - - +478 1.0000 15.221462 4079407.23 - - 1.00e-04 455.7 40.64 - - - +479 1.0000 15.805831 7317878.02 - - 1.00e-04 455.9 40.64 - - - +480 1.0000 15.026462 3356675.48 - - 1.00e-04 456.1 40.64 - - - +481 1.0000 15.943274 8396063.68 - - 1.00e-04 456.3 40.64 - - - +482 1.0000 14.808249 2698615.64 - - 1.00e-04 456.5 40.64 - - - +483 1.0000 16.028025 9138663.19 - - 1.00e-04 456.7 40.64 - - - +484 1.0000 18.031034 67729640.92 - - 1.00e-04 456.8 40.64 - - - +485 1.0000 16.722122 18294704.28 - - 1.00e-04 457.0 40.64 - - - +486 1.0000 15.834826 7533162.63 - - 1.00e-04 457.2 40.64 - - - +487 1.0000 15.912129 8138605.45 - - 1.00e-04 457.4 40.64 - - - +488 1.0000 15.672761 6406098.36 - - 1.00e-04 457.6 40.64 - - - +489 1.0000 15.834946 7534067.89 - - 1.00e-04 457.8 40.64 - - - +490 1.0000 13.973710 1171399.87 - - 1.00e-04 458.0 40.64 - - - +491 1.0000 14.330355 1673377.66 - - 1.00e-04 458.2 40.64 - - - +492 1.0000 16.400848 13267770.59 - - 1.00e-04 458.4 40.64 - - - +493 1.0000 14.427358 1843834.03 - - 1.00e-04 458.5 40.64 - - - +494 1.0000 15.368487 4725516.18 - - 1.00e-04 458.7 40.64 - - - +495 1.0000 15.380669 4783431.05 - - 1.00e-04 458.9 40.64 - - - +496 1.0000 14.697004 2414503.70 - - 1.00e-04 459.1 40.64 - - - +497 1.0000 15.141884 3767355.48 - - 1.00e-04 459.3 40.64 - - - +498 1.0000 16.121389 10032991.68 - - 1.00e-04 459.5 40.64 - - - +499 1.0000 16.472588 14254562.45 - - 1.00e-04 459.7 40.64 - - - +500 1.0000 15.577523 5824147.71 10033380.70 23.2583 1.00e-04 478.9 40.64 0.3431 0.3135 0.3434 diff --git a/train.py b/train.py index b5b305e..3583a73 100644 --- a/train.py +++ b/train.py @@ -37,13 +37,14 @@ DEFAULTS = dict( steps=1000, batch_size=4, seq_len=256, - learning_rate=1e-4, + learning_rate=1e-5, warmup_steps=100, eval_every=50, # Quantization group_size=128, # Bonsai uses 128 - quant_warmup_steps=1000, # lambda warmup over N steps - activation_bits=8, # INT8 activations + quant_warmup_steps=2000, # lambda warmup over N steps + activation_bits=16, # 16 = no activation quant (use 8 for INT8) + threshold=0.0, # deadzone threshold (0 = no deadzone, 0.5 = Bonsai-style) # Data train_dataset="roneneldan/TinyStories", eval_data_path=str(Path(__file__).parent / "data" / "wikitext_eval.json"), @@ -55,7 +56,7 @@ DEFAULTS = dict( # Ternary Quantization Primitives # --------------------------------------------------------------------------- -def ternary_quantize(w, group_size=128): +def ternary_quantize(w, group_size=128, threshold=0.5): """Quantize weights to {-1, 0, +1} with per-group scale. Groups are formed by flattening the weight tensor and taking consecutive @@ -64,6 +65,7 @@ def ternary_quantize(w, group_size=128): Args: w: weight tensor of any shape group_size: number of weights per quantization group + threshold: deadzone threshold (0 < t < 1). Weights with |w_norm| < t are zeroed. Returns: w_quant: ternary weights {-1, 0, +1} in original shape @@ -80,14 +82,21 @@ def ternary_quantize(w, group_size=128): w_groups = w_flat.reshape(-1, group_size) - # Scale: mean(|w|) per group (HF blog approach) + # Scale: mean(|w|) per group (balanced ternary distribution) abs_mean = w_groups.abs().mean(dim=-1, keepdim=True).clamp(min=1e-6) scale = abs_mean - # Normalize, clamp to [-1, 1], round to nearest ternary + # Normalize, clamp to [-1, 1], apply threshold, round to nearest ternary w_norm = w_groups / scale w_clamped = w_norm.clamp(-1.0, 1.0) - w_quant = torch.round(w_clamped) # {-1, 0, +1} + if threshold > 0: + # Hard threshold: |w_norm| < threshold → 0, else ±1 + w_quant = torch.where(w_clamped.abs() < threshold, + torch.zeros_like(w_clamped), + torch.sign(w_clamped)) + else: + # Soft rounding: round to nearest ternary + w_quant = torch.round(w_clamped) # Reshape back to original (trim padding) w_quant = w_quant.reshape(-1)[:n].reshape(original_shape) @@ -153,14 +162,15 @@ class BitLinear(nn.Module): """ def __init__(self, in_features, out_features, bias=True, group_size=128, - activation_bits=8): + activation_bits=8, threshold=0.5): super().__init__() self.in_features = in_features self.out_features = out_features self.group_size = group_size self.activation_bits = activation_bits + self.threshold = threshold - # FP16 weights (learned in full precision, quantized at forward time) + # FP32 weights (learned in full precision, quantized at forward time) self.weight = nn.Parameter(torch.randn(out_features, in_features, dtype=DEFAULTS['dtype']) * 0.02) if bias: self.bias = nn.Parameter(torch.zeros(out_features, dtype=DEFAULTS['dtype'])) @@ -179,7 +189,8 @@ class BitLinear(nn.Module): return out # Quantize weights - w_quant, scale = ternary_quantize(self.weight, group_size=self.group_size) + w_quant, scale = ternary_quantize(self.weight, group_size=self.group_size, + threshold=self.threshold) # Dequantize for forward pass w_dequant = ternary_dequantize(w_quant, scale, self.group_size) @@ -195,8 +206,8 @@ class BitLinear(nn.Module): # Straight-through estimator with lambda warmup: # out = out_fp + lambda * (out_quant - out_fp).detach() - # This lets gradients flow through FP weights while gradually - # exposing the network to quantized forward pass + # When lambda=0: pure FP forward (no quantization) + # When lambda=1: quantized forward, gradients through FP (full STE) out_fp = F.linear(x, self.weight, self.bias) out = out_fp + lambda_ * (out_quant - out_fp).detach() @@ -212,7 +223,7 @@ class BitLinear(nn.Module): # --------------------------------------------------------------------------- def replace_linears_with_bitlinear(model, group_size=128, activation_bits=8, - exclude_embeddings=True): + exclude_embeddings=True, threshold=0.5): """Replace all nn.Linear layers in model with BitLinear. Args: @@ -220,6 +231,7 @@ def replace_linears_with_bitlinear(model, group_size=128, activation_bits=8, group_size: quantization group size activation_bits: activation quantization bits (16 = no quant) exclude_embeddings: don't replace lm_head/embedding (usually) + threshold: deadzone threshold for ternary quantization (0 < t < 1) """ count = 0 for name, module in model.named_modules(): @@ -235,6 +247,7 @@ def replace_linears_with_bitlinear(model, group_size=128, activation_bits=8, bias=module.bias is not None, group_size=group_size, activation_bits=activation_bits, + threshold=threshold, ) # Initialize from FP weights (critical for warmup to work) @@ -410,6 +423,7 @@ def train(args): model, group_size=args.group_size, activation_bits=args.activation_bits, + threshold=args.threshold, ) print(f"Replaced {n_replaced} Linear layers with BitLinear") @@ -581,6 +595,7 @@ def main(): parser.add_argument("--group-size", type=int, default=DEFAULTS["group_size"], dest="group_size") parser.add_argument("--quant-warmup-steps", type=int, default=DEFAULTS["quant_warmup_steps"], dest="quant_warmup_steps") parser.add_argument("--activation-bits", type=int, default=DEFAULTS["activation_bits"], dest="activation_bits") + parser.add_argument("--threshold", type=float, default=DEFAULTS["threshold"], dest="threshold") parser.add_argument("--train-dataset", default=DEFAULTS["train_dataset"], dest="train_dataset") parser.add_argument("--eval-data-path", default=DEFAULTS["eval_data_path"], dest="eval_data_path") parser.add_argument("--log-file", default=DEFAULTS["log_file"])