Bug in interpolate.dmrg_cross #32

Open
MazenAli opened this issue Dec 10, 2024 · 1 comment

@MazenAli

Describe the bug
interpolate.dmrg_cross fails with a shape mismatch in the matrix product U @ Rtemp.T during the right-to-left pass of the first sweep.

To Reproduce
Too long to include everything here.

This is the call:

cross_func = lambda idx: torch.tensor(self._cross_func(idx, eval_point)[i, j])
tensor_shape = [2]*num_sites
tt = tn.interpolate.dmrg_cross(cross_func, tensor_shape, verbose=True)

This is the printed output and traceback:

Sweep 1: 
	LR supercore 1,2
		number evaluations 8
		rank updated 2 -> 2, local error 1.566035e+02
	LR supercore 2,3
		number evaluations 16
		rank updated 2 -> 4, local error 8.495978e-01
	LR supercore 3,4
		number evaluations 32
		rank updated 2 -> 5, local error 8.368856e-01
	LR supercore 4,5
		number evaluations 40
		rank updated 2 -> 5, local error 1.144697e+00
	LR supercore 5,6
		number evaluations 20
		rank updated 2 -> 4, local error 4.816946e-16
	RL supercore 5,6
		number evaluations 20
		rank updated 4 -> 2, local error 3.952443e-16
	RL supercore 4,5
		number evaluations 40

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[15], line 1
----> 1 J = get_tensor_networks()
      2 print(J)

File /app/tn.py:318, in TN.get_tensor_networks(self)
    316     cross_func = lambda idx: torch.tensor(self._cross_func(idx, eval_point)[i, j])
    317     tensor_shape = [2]*num_sites
--> 318     tt = tn.interpolate.dmrg_cross(cross_func, tensor_shape, verbose=True)
    320     col.append(tt)
    322 row.append(col)

File /usr/local/lib/python3.12/site-packages/torchtt/interpolate.py:693, in dmrg_cross(function, N, eps, nswp, x_start, kick, dtype, device, eval_vect, rmax, verbose)
    690 if radd > 0:
    691     U = tn.cat(
    692         (U, tn.zeros((U.shape[0], radd), dtype=dtype, device=device)), 1)
--> 693     U = U @ Rtemp.T
    694     V = V.t()
    696 # print('kkt new',tn.linalg.norm(supercore-U@V))
    697 # compute err (dx)

RuntimeError: mat1 and mat2 shapes cannot be multiplied (10x4 and 5x4)
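
The shapes in the error message can be checked by hand. The sketch below uses only shape arithmetic (no torch), and the helper names matmul_shape and missing_columns are hypothetical, introduced here for illustration. It assumes mat1 is U and mat2 is Rtemp.T, as the failing line suggests. It shows that U would need one more column for the product to be defined; whether radd was zero or too small at that point cannot be determined from the traceback alone.

```python
# Pure-Python shape arithmetic; helper names are hypothetical, not part of torchtt.

def matmul_shape(a, b):
    """Return the shape of a 2-D product a @ b, or raise ValueError on mismatch."""
    (m, k), (k2, n) = a, b
    if k != k2:
        raise ValueError(f"shapes cannot be multiplied ({m}x{k} and {k2}x{n})")
    return (m, n)


def missing_columns(a, b):
    """Zero columns that would have to be appended to `a` so that a @ b is defined."""
    return b[0] - a[1]


u_shape = (10, 4)   # mat1 in the error message (assumed to be U)
rt_shape = (5, 4)   # mat2 in the error message (assumed to be Rtemp.T)

# Reproduce the mismatch reported in the traceback.
try:
    matmul_shape(u_shape, rt_shape)
    mismatch = None
except ValueError as exc:
    mismatch = str(exc)

# Width U would need after zero-padding for the product to go through.
pad = missing_columns(u_shape, rt_shape)
padded_shape = (u_shape[0], u_shape[1] + pad)
result_shape = matmul_shape(padded_shape, rt_shape)

print(mismatch)
print(pad, padded_shape, result_shape)
```

A check of this kind placed just before the failing product would report the mismatch together with the current value of radd, which may help narrow down which functions trigger it.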
@ion-g-ion
Owner

ion-g-ion commented Dec 14, 2024

Hi,
Thanks for reporting this. I could not reproduce it (it seems to appear only for some functions), so I made some blind changes to avoid the multiplication error. They are pushed to the latest git version. I hope this solves it...

BR,
Ion
