When I tried to serve LLaMA on a v3-8 TPU as suggested in the example script, I ran into some errors.

Environment: v3-8 TPU, tpu-vm-base image.

1. Deprecation warning
ImportError: cannot import name 'soft_unicode' from 'markupsafe'
ImportError: Pandas requires version '3.0.0' or newer of 'jinja2'
These can be solved by adding two lines to tpu_requirements.txt (markupsafe 2.1 removed soft_unicode, which older jinja2 still imports, and pandas wants jinja2 >= 3.0):

markupsafe==2.0.1
jinja2~=3.0.0
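A quick way to double-check the edit before reinstalling with `pip install -r tpu_requirements.txt` (the helper name and REQUIRED_PINS set are mine, not part of EasyLM):

```python
# Hypothetical sanity check: confirm both pins are present in the
# requirements text before reinstalling.
REQUIRED_PINS = {"markupsafe==2.0.1", "jinja2~=3.0.0"}

def missing_pins(requirements_text: str) -> set:
    """Return the required pins absent from the requirements text."""
    present = {line.strip() for line in requirements_text.splitlines()}
    return REQUIRED_PINS - present

# A file with only one of the two pins still needs jinja2 pinned:
print(missing_pins("jax\nmarkupsafe==2.0.1\n"))  # {'jinja2~=3.0.0'}
```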
DeprecationWarning: concurrency_count has been deprecated. Set the concurrency_limit directly on event listeners e.g. btn.click(fn, ..., concurrency_limit=10) or gr.Interface(concurrency_limit=10). If necessary, the total number of workers can be configured via max_threads in launch().
I was able to solve this by deleting concurrency_count=1 from serving.py, line 403. According to the Gradio v4.0.0 changelog, concurrency_count was removed and can be replaced with concurrency_limit. Since I don't fully understand what it was supposed to do and it defaulted to 1, I simply removed it.
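For what it's worth, concurrency_limit just caps how many requests an event handler processes at once. A minimal sketch of that idea with a plain semaphore (a toy model, not Gradio's actual internals):

```python
import threading

class LimitedEndpoint:
    """Toy model of a per-endpoint concurrency limit (not Gradio internals)."""
    def __init__(self, fn, concurrency_limit=1):
        self.fn = fn
        self._sem = threading.Semaphore(concurrency_limit)

    def __call__(self, *args, **kwargs):
        # Blocks when concurrency_limit calls are already in flight.
        with self._sem:
            return self.fn(*args, **kwargs)

ep = LimitedEndpoint(lambda x: x * 2, concurrency_limit=4)
print(ep(21))  # 42
```

So removing the old concurrency_count=1 without adding a concurrency_limit should just fall back to Gradio's default limit rather than change correctness.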
2. Structure error
However, after fixing the deprecation warnings above, this error appears:

Error log:
I1107 06:16:48.996244 140573565926464 mesh_utils.py:260] Reordering mesh to physical ring order on single-tray TPU v2/v3.
$HOME/.local/lib/python3.8/site-packages/gradio/blocks.py:889: UserWarning: api_name user_fn already exists, using user_fn_1
warnings.warn(f"api_name {api_name} already exists, using {api_name_}")
$HOME/.local/lib/python3.8/site-packages/gradio/blocks.py:889: UserWarning: api_name model_fn already exists, using model_fn_1
warnings.warn(f"api_name {api_name} already exists, using {api_name_}")
$HOME/.local/lib/python3.8/site-packages/gradio/blocks.py:889: UserWarning: api_name model_fn already exists, using model_fn_2
warnings.warn(f"api_name {api_name} already exists, using {api_name_}")
Traceback (most recent call last):
File "$HOME/EasyLM/EasyLM/models/llama/llama_serve.py", line 386, in <module>
mlxu.run(main)
File "$HOME/.local/lib/python3.8/site-packages/absl/app.py", line 308, in run
_run_main(main, args)
File "$HOME/.local/lib/python3.8/site-packages/absl/app.py", line 254, in _run_main
sys.exit(main(argv))
File "$HOME/EasyLM/EasyLM/models/llama/llama_serve.py", line 382, in main
server.run()
File "$HOME/EasyLM/EasyLM/serving.py", line 417, in run
self.loglikelihood(pre_compile_data, pre_compile_data)
File "$HOME/EasyLM/EasyLM/models/llama/llama_serve.py", line 208, in loglikelihood
loglikelihood, is_greedy, sharded_rng = forward_loglikelihood(
File "$HOME/.local/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "$HOME/.local/lib/python3.8/site-packages/jax/_src/pjit.py", line 250, in cache_miss
outs, out_flat, out_tree, args_flat, jaxpr = _python_pjit_helper(
File "$HOME/.local/lib/python3.8/site-packages/jax/_src/pjit.py", line 158, in _python_pjit_helper
args_flat, _, params, in_tree, out_tree, _ = infer_params_fn(
File "$HOME/.local/lib/python3.8/site-packages/jax/_src/pjit.py", line 775, in infer_params
return common_infer_params(pjit_info_args, *args, **kwargs)
File "$HOME/.local/lib/python3.8/site-packages/jax/_src/pjit.py", line 505, in common_infer_params
jaxpr, consts, canonicalized_out_shardings_flat = _pjit_jaxpr(
File "$HOME/.local/lib/python3.8/site-packages/jax/_src/pjit.py", line 971, in _pjit_jaxpr
jaxpr, final_consts, out_type = _create_pjit_jaxpr(
File "$HOME/.local/lib/python3.8/site-packages/jax/_src/linear_util.py", line 345, in memoized_fun
ans = call(fun, *args)
File "$HOME/.local/lib/python3.8/site-packages/jax/_src/pjit.py", line 924, in _create_pjit_jaxpr
jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_dynamic(
File "$HOME/.local/lib/python3.8/site-packages/jax/_src/profiler.py", line 314, in wrapper
return func(*args, **kwargs)
File "$HOME/.local/lib/python3.8/site-packages/jax/_src/interpreters/partial_eval.py", line 2155, in trace_to_jaxpr_dynamic
jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(
File "$HOME/.local/lib/python3.8/site-packages/jax/_src/interpreters/partial_eval.py", line 2177, in trace_to_subjaxpr_dynamic
ans = fun.call_wrapped(*in_tracers_)
File "$HOME/.local/lib/python3.8/site-packages/jax/_src/linear_util.py", line 188, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File "$HOME/EasyLM/EasyLM/models/llama/llama_serve.py", line 88, in forward_loglikelihood
logits = hf_model.module.apply(
File "$HOME/.local/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "$HOME/.local/lib/python3.8/site-packages/flax/linen/module.py", line 1511, in apply
return apply(
File "$HOME/.local/lib/python3.8/site-packages/flax/core/scope.py", line 930, in wrapper
raise errors.ApplyScopeInvalidVariablesStructureError(variables)
jax._src.traceback_util.UnfilteredStackTrace: flax.errors.ApplyScopeInvalidVariablesStructureError: Expect the `variables` (first argument) passed to apply() to be a dict with the structure {"params": ...}, but got a dict with an extra params layer, i.e. {"params": {"params": ... } }. You should instead pass in your dict's ["params"]. (https://flax.readthedocs.io/en/latest/api_reference/flax.errors.html#flax.errors.ApplyScopeInvalidVariablesStructureError)
The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.
--------------------
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/usr/lib/python3.8/runpy.py", line 194, in _run_module_as_main
return _run_code(code, main_globals, None,
File "/usr/lib/python3.8/runpy.py", line 87, in _run_code
exec(code, run_globals)
File "$HOME/EasyLM/EasyLM/models/llama/llama_serve.py", line 386, in <module>
mlxu.run(main)
File "$HOME/.local/lib/python3.8/site-packages/absl/app.py", line 308, in run
_run_main(main, args)
File "$HOME/.local/lib/python3.8/site-packages/absl/app.py", line 254, in _run_main
sys.exit(main(argv))
File "$HOME/EasyLM/EasyLM/models/llama/llama_serve.py", line 382, in main
server.run()
File "$HOME/EasyLM/EasyLM/serving.py", line 417, in run
self.loglikelihood(pre_compile_data, pre_compile_data)
File "$HOME/EasyLM/EasyLM/models/llama/llama_serve.py", line 208, in loglikelihood
loglikelihood, is_greedy, sharded_rng = forward_loglikelihood(
File "$HOME/EasyLM/EasyLM/models/llama/llama_serve.py", line 88, in forward_loglikelihood
logits = hf_model.module.apply(
It seems something goes wrong with the "params" structure when load_trainstate_checkpoint in checkpoint.py restores the checkpoint, but I couldn't figure out where.
Does anyone know what's wrong?
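The flax error message itself points at a possible workaround: the restored variables carry one nesting level too many. A hedged sketch with plain dicts (fix_variables is my name, not an EasyLM function) of stripping the extra layer before calling module.apply:

```python
def fix_variables(variables: dict) -> dict:
    # Flax's apply() expects {"params": ...}; a checkpoint restored as a
    # full train state can yield {"params": {"params": ...}} instead.
    inner = variables.get("params")
    if isinstance(inner, dict) and set(inner.keys()) == {"params"}:
        return inner  # drop the extra "params" layer
    return variables

doubly_wrapped = {"params": {"params": {"wte": {"embedding": [0.0]}}}}
print(fix_variables(doubly_wrapped))  # {'params': {'wte': {'embedding': [0.0]}}}
```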