mirror of
https://github.com/MartinThoma/LaTeX-examples.git
synced 2025-04-25 22:38:04 +02:00
tikz/validation-curve: Update
This commit is contained in:
parent
9c5d73dbff
commit
7f1684d2c1
24 changed files with 1988 additions and 1715 deletions
66
tikz/validation-curve/edit_curve.py
Normal file → Executable file
66
tikz/validation-curve/edit_curve.py
Normal file → Executable file
|
@ -4,8 +4,10 @@ import csv
|
|||
import glob
|
||||
import numpy as np
|
||||
|
||||
history_files = glob.glob("*.csv")
|
||||
history_files = glob.glob("cifar100_*.csv")
|
||||
print(history_files)
|
||||
data = []
|
||||
loss = []
|
||||
|
||||
for filename in history_files:
|
||||
print(filename)
|
||||
|
@ -18,7 +20,7 @@ for filename in history_files:
|
|||
if i == 0:
|
||||
continue
|
||||
data.append([float(el[2])])
|
||||
print(data)
|
||||
loss.append([float(el[1])])
|
||||
else:
|
||||
for i, el in enumerate(datathis):
|
||||
if i == 0:
|
||||
|
@ -26,20 +28,60 @@ for filename in history_files:
|
|||
if i == len(data):
|
||||
break
|
||||
data[i - 1].append(float(el[2]))
|
||||
loss[i - 1].append(float(el[1]))
|
||||
|
||||
print("!" * 80)
|
||||
print(data)
|
||||
print("-" * 80)
|
||||
# crop to where all are trained
|
||||
print(len(data))
|
||||
orderings = []
|
||||
for i, el in enumerate(data):
|
||||
if len(el) != len(history_files):
|
||||
break
|
||||
data = data[:i]
|
||||
loss = loss[:i]
|
||||
print(len(data))
|
||||
|
||||
|
||||
# orderings
|
||||
def get_changes(ord1, ord2):
|
||||
"""Count how often the order changes between two elements."""
|
||||
changes = 0
|
||||
for i in range(10):
|
||||
for j in range(i + 1, 10):
|
||||
o1go2 = (ord1.index(i) > ord1.index(j) and
|
||||
ord2.index(i) > ord2.index(j))
|
||||
o1lo2 = (ord1.index(i) < ord1.index(j) and
|
||||
ord2.index(i) < ord2.index(j))
|
||||
if not (o1go2 or o1lo2):
|
||||
changes += 1
|
||||
|
||||
return changes
|
||||
|
||||
if len(history_files) > 1:
|
||||
orderings = []
|
||||
for row in data:
|
||||
ordering = zip(range(10), row)
|
||||
ordering = sorted(ordering, key=lambda n: n[1])
|
||||
ordering = [el[0] for el in ordering]
|
||||
orderings.append(ordering)
|
||||
get_changes(orderings[0], orderings[1])
|
||||
|
||||
change_list = []
|
||||
for ord1, ord2 in zip(orderings, orderings[1:]):
|
||||
changes = get_changes(ord1, ord2)
|
||||
change_list.append(changes)
|
||||
change_list = np.array(change_list)
|
||||
|
||||
print("change_list = {}".format(change_list.mean()))
|
||||
|
||||
# write
|
||||
max_range = 0
|
||||
with open('baseline_cifar_test_acc.csv', 'w') as fp:
|
||||
writer = csv.writer(fp, delimiter=',')
|
||||
writer.writerow(["epoch", "min_acc", "max_acc", "mean_acc"])
|
||||
writer.writerow(["epoch", "min_acc", "max_acc", "mean_acc", "mean_loss"])
|
||||
for epoch, row in enumerate(data):
|
||||
if len(row) < 10:
|
||||
print(row)
|
||||
print(len(row))
|
||||
if len(row) < len(history_files):
|
||||
break
|
||||
max_range = max(max_range, max(row) - min(row))
|
||||
print("max range={}, epoch={}".format(max_range, epoch))
|
||||
writer.writerow([epoch, min(row), max(row), np.array(row).mean()])
|
||||
print(max_range)
|
||||
writer.writerow([epoch, min(row), max(row), np.array(row).mean(),
|
||||
np.array(loss[epoch]).mean()])
|
||||
print("max_range={}".format(max_range))
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue